add support for multiple concurrent requests with request ids

Alex Cheema · 1 year ago
commit b01f69bb6b
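
The change threads a request_id through the prompt/tensor path and buffers per-request token output on each node, exposed over a new GetInferenceResult RPC. A minimal sketch of the resulting flow, assuming the PeerHandle API introduced in this diff (peer and shard setup as in example_user_2.py below):

import asyncio
import uuid

async def run_two_prompts(peer1, peer2, shard, tokenizer):
    # Distinct request ids let two generations run concurrently on the same nodes.
    ids = [str(uuid.uuid4()), str(uuid.uuid4())]
    await asyncio.gather(
        peer1.send_prompt(shard, "first prompt", ids[0]),
        peer1.send_prompt(shard, "second prompt", ids[1]),
    )
    for rid in ids:
        # As in the example below, results can be polled from a different peer
        # than the one that received the prompt.
        tokens, is_finished = await peer2.get_inference_result(rid)
        print(rid, tokenizer.decode(tokens) if tokens is not None else "", is_finished)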

+ 29 - 5
example_user_2.py

@@ -16,11 +16,16 @@ model_path = get_model_path(path_or_hf_repo)
 tokenizer_config = {}
 tokenizer = load_tokenizer(model_path, tokenizer_config)
 
-peer = GRPCPeerHandle(
+peer1 = GRPCPeerHandle(
     "node1",
     "node1",
     "localhost:8080",
     "localhost:8080",
     DeviceCapabilities(model="test1", chip="test1", memory=10000)
     DeviceCapabilities(model="test1", chip="test1", memory=10000)
 )
 )
+peer2 = GRPCPeerHandle(
+    "node2",
+    "localhost:8081",
+    DeviceCapabilities(model="test1", chip="test1", memory=10000)
+)
 shard = Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=0, n_layers=32)
 
 async def run_prompt(prompt: str):
@@ -35,11 +40,30 @@ async def run_prompt(prompt: str):
             messages, tokenize=False, add_generation_prompt=True
         )
 
-    await peer.connect()
-    await peer.reset_shard(shard)
+    for peer in [peer1, peer2]:
+        await peer.connect()
+        await peer.reset_shard(shard)
+
+    try:
+        await peer1.send_prompt(shard, prompt, "request-id-1")
+    except Exception as e:
+        print(e)
+
+    import sys
+    # poll 10 times per second for result (even though generation is faster, any more than this it's not nice for the user)
+    previous_length = 0
+    while True:
+        result, is_finished = await peer2.get_inference_result("request-id-1")
+        await asyncio.sleep(0.1)
+
+        # Print the updated string in place
+        updated_string = tokenizer.decode(result)
+        print(updated_string[previous_length:], end='', flush=True)
+        previous_length = len(updated_string)
 
-    result = await peer.send_prompt(shard, prompt)
-    print(tokenizer.decode(result))
+        if is_finished:
+            print("\nDone")
+            break
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run prompt")
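
A couple of review notes on this example: the `import sys` is unused, and the loop sleeps before printing, so the first tokens show up a tick late. A possible refactor of the polling loop into a reusable helper (a sketch against the same PeerHandle API, not part of the commit):

import asyncio

async def stream_result(peer, tokenizer, request_id: str, interval: float = 0.1):
    # Poll ~10x/second and print only the newly decoded suffix in place.
    previous_length = 0
    while True:
        tokens, is_finished = await peer.get_inference_result(request_id)
        text = tokenizer.decode(tokens) if tokens is not None else ""
        print(text[previous_length:], end='', flush=True)
        previous_length = len(text)
        if is_finished:
            print("\nDone")
            break
        await asyncio.sleep(interval)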

+ 3 - 2
inference/inference_engine.py

@@ -1,16 +1,17 @@
 import numpy as np
 import mlx.nn as nn
 
+from typing import Tuple
 from abc import ABC, abstractmethod
 from .shard import Shard
 
 class InferenceEngine(ABC):
     @abstractmethod
-    async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> (np.ndarray, bool):
+    async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> Tuple[np.ndarray, bool]:
         pass
 
     @abstractmethod
-    async def infer_prompt(self, shard: Shard, prompt: str) -> (np.ndarray, bool):
+    async def infer_prompt(self, shard: Shard, prompt: str) -> Tuple[np.ndarray, bool]:
         pass
 
     @abstractmethod
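
Context for this fix: in an annotation, `(np.ndarray, bool)` evaluates to an ordinary two-element tuple, not a type, so static checkers reject it; `typing.Tuple[np.ndarray, bool]` (or the built-in `tuple[np.ndarray, bool]` on Python 3.9+) is the checkable spelling for a pair return:

from typing import Tuple
import numpy as np

def step() -> Tuple[np.ndarray, bool]:  # parameterized generic, understood by mypy/pyright
    return np.zeros(1), False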

+ 2 - 1
networking/grpc/grpc_discovery.py

@@ -99,7 +99,8 @@ class GRPCDiscovery(Discovery):
                     peer_host = addr[0]
                     peer_port = message['grpc_port']
                     device_capabilities = DeviceCapabilities(**message['device_capabilities'])
-                    self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
+                    if peer_id not in self.known_peers:
+                        self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
                     self.peer_last_seen[peer_id] = time.time()
             except Exception as e:
                 print(f"Error in peer discovery: {e}")

+ 24 - 8
networking/grpc/grpc_peer_handle.py

@@ -1,6 +1,6 @@
 import grpc
 import numpy as np
-from typing import Optional
+from typing import Optional, Tuple
 
 # These would be generated from the .proto file
 from . import node_service_pb2
@@ -16,6 +16,8 @@ class GRPCPeerHandle(PeerHandle):
         self._id = id
         self.address = address
         self._device_capabilities = device_capabilities
+        self.channel = None
+        self.stub = None
 
     def id(self) -> str:
         return self._id
@@ -24,23 +26,30 @@ class GRPCPeerHandle(PeerHandle):
         return self._device_capabilities
 
     async def connect(self):
-        self.channel = grpc.aio.insecure_channel(self.address)
+        self.channel = grpc.aio.insecure_channel(self.address, options=[
+            ('grpc.max_metadata_size', 32*1024*1024)
+        ])
         self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
 
+    async def is_connected(self) -> bool:
+        return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY
+
     async def disconnect(self):
-        await self.channel.close()
+        if self.channel:
+            await self.channel.close()
+        self.channel = None
+        self.stub = None
 
-    async def send_prompt(self, shard: Shard, prompt: str) -> Optional[np.array]:
-        request = node_service_pb2.PromptRequest(prompt=prompt, shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers))
+    async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
+        request = node_service_pb2.PromptRequest(prompt=prompt, shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers), request_id=request_id)
         response = await self.stub.SendPrompt(request)
-        print(f"Sent prompt to {self.address}: {prompt}")
 
         if not response.tensor_data or not response.shape or not response.dtype:
             return None
 
         return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
 
-    async def send_tensor(self, shard: Shard, tensor: np.ndarray) -> Optional[np.array]:
+    async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
         request = node_service_pb2.TensorRequest(
             shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers),
             tensor = node_service_pb2.Tensor(
@@ -48,6 +57,7 @@ class GRPCPeerHandle(PeerHandle):
                 shape=tensor.shape,
                 dtype=str(tensor.dtype)
             ),
+            request_id=request_id
         )
         response = await self.stub.SendTensor(request)
 
@@ -56,10 +66,16 @@ class GRPCPeerHandle(PeerHandle):
 
         return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
 
+    async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+        request = node_service_pb2.GetInferenceResultRequest(request_id=request_id)
+        response = await self.stub.GetInferenceResult(request)
+        if response.tensor is None:
+            return None, response.is_finished
+        return np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype)).reshape(response.tensor.shape), response.is_finished
+
     async def reset_shard(self, shard: Shard) -> None:
         request = node_service_pb2.ResetShardRequest(shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers))
         await self.stub.ResetShard(request)
-        print(f"Reset shard {shard} on {self.address}")
 
     async def collect_topology(self, max_depth: int) -> Topology:
         request = node_service_pb2.CollectTopologyRequest(max_depth=max_depth)
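
One likely bug in the new get_inference_result: protobuf message fields in Python are never None — reading an unset submessage returns a default instance — so `response.tensor is None` can never trigger. The presence check for the `optional Tensor tensor` field is HasField; a sketch of the likely intent:

async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
    request = node_service_pb2.GetInferenceResultRequest(request_id=request_id)
    response = await self.stub.GetInferenceResult(request)
    if not response.HasField("tensor"):  # proto3 presence check for optional/message fields
        return None, response.is_finished
    tensor = np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype))
    return tensor.reshape(response.tensor.shape), response.is_finished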

+ 19 - 7
networking/grpc/grpc_server.py

@@ -8,6 +8,8 @@ from inference.shard import Shard
 
 from orchestration import Node
 
+import uuid
+
 class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     def __init__(self, node: Node, host: str, port: int):
         self.node = node
@@ -17,7 +19,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
 
     async def start(self) -> None:
         self.server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=10), options=[
-            ('grpc.max_metadata_size', 128*1024)
+            ('grpc.max_metadata_size', 32*1024*1024)
         ])
         node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
         listen_addr = f'{self.host}:{self.port}'
@@ -27,23 +29,33 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
 
     async def stop(self) -> None:
         if self.server:
-            await self.server.stop(5)  # 5 seconds grace period
-            print("Server stopped")
+            await self.server.stop(grace=5)
+            await self.server.wait_for_termination()
+            print("Server stopped and all connections are closed")
 
     async def SendPrompt(self, request, context):
         shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
         prompt = request.prompt
-        result = await self.node.process_prompt(shard, prompt)
+        request_id = request.request_id
+        result = await self.node.process_prompt(shard, prompt, request_id)
         tensor_data = result.tobytes() if result is not None else None
-        return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))
+        return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
 
     async def SendTensor(self, request, context):
         shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
         tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
-        result = await self.node.process_tensor(shard, tensor)
+        request_id = request.request_id
+
+        result = await self.node.process_tensor(shard, tensor, request_id)
         print("SendTensor tensor result", result)
         print("SendTensor tensor result", result)
         tensor_data = result.tobytes() if result is not None else None
         tensor_data = result.tobytes() if result is not None else None
-        return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))
+        return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
+
+    async def GetInferenceResult(self, request, context):
+        request_id = request.request_id
+        result = await self.node.get_inference_result(request_id)
+        tensor_data = result[0].tobytes() if result[0] is not None else None
+        return node_service_pb2.InferenceResult(tensor=node_service_pb2.Tensor(tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype))) if result[0] is not None else node_service_pb2.InferenceResult()
 
     async def ResetShard(self, request, context):
         shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
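
A related server-side gap: GetInferenceResult above never sets is_finished on the response, so it always deserializes to the proto default False and a polling client would spin forever. The likely intent (a sketch, pairing with the HasField note above):

async def GetInferenceResult(self, request, context):
    tensor, is_finished = await self.node.get_inference_result(request.request_id)
    if tensor is None:
        return node_service_pb2.InferenceResult(is_finished=is_finished)
    return node_service_pb2.InferenceResult(
        tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
        is_finished=is_finished)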

+ 12 - 0
networking/grpc/node_service.proto

@@ -6,6 +6,7 @@ service NodeService {
   rpc SendPrompt (PromptRequest) returns (Tensor) {}
   rpc SendTensor (TensorRequest) returns (Tensor) {}
   rpc ResetShard (ResetShardRequest) returns (Empty) {}
+  rpc GetInferenceResult (GetInferenceResultRequest) returns (InferenceResult) {}
   rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
 }
 
@@ -19,11 +20,22 @@ message Shard {
 message PromptRequest {
   Shard shard = 1;
   string prompt = 2;
+  optional string request_id = 3;
 }
 
 message TensorRequest {
   Shard shard = 1;
   Tensor tensor = 2;
+  optional string request_id = 3;
+}
+
+message GetInferenceResultRequest {
+  string request_id = 1;
+}
+
+message InferenceResult {
+  optional Tensor tensor = 1;
+  bool is_finished = 2;
 }
 
 message Tensor {
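
With these messages a regenerated client can poll for partial results directly at the stub level; a minimal round trip (sketch, module paths assumed from this repo's networking/grpc layout):

import grpc
from networking.grpc import node_service_pb2, node_service_pb2_grpc

async def poll_once(address: str, request_id: str):
    async with grpc.aio.insecure_channel(address) as channel:
        stub = node_service_pb2_grpc.NodeServiceStub(channel)
        return await stub.GetInferenceResult(
            node_service_pb2.GetInferenceResultRequest(request_id=request_id))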

File diff suppressed because it is too large
+ 0 - 1
networking/grpc/node_service_pb2.py


+ 43 - 0
networking/grpc/node_service_pb2_grpc.py

@@ -54,6 +54,11 @@ class NodeServiceStub(object):
                 request_serializer=node__service__pb2.ResetShardRequest.SerializeToString,
                 response_deserializer=node__service__pb2.Empty.FromString,
                 _registered_method=True)
+        self.GetInferenceResult = channel.unary_unary(
+                '/node_service.NodeService/GetInferenceResult',
+                request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
+                response_deserializer=node__service__pb2.InferenceResult.FromString,
+                _registered_method=True)
         self.CollectTopology = channel.unary_unary(
                 '/node_service.NodeService/CollectTopology',
                 request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
@@ -82,6 +87,12 @@ class NodeServiceServicer(object):
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')
 
+    def GetInferenceResult(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
     def CollectTopology(self, request, context):
         """Missing associated documentation comment in .proto file."""
         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
@@ -106,6 +117,11 @@ def add_NodeServiceServicer_to_server(servicer, server):
                     request_deserializer=node__service__pb2.ResetShardRequest.FromString,
                     response_serializer=node__service__pb2.Empty.SerializeToString,
             ),
+            'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
+                    servicer.GetInferenceResult,
+                    request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
+                    response_serializer=node__service__pb2.InferenceResult.SerializeToString,
+            ),
             'CollectTopology': grpc.unary_unary_rpc_method_handler(
                     servicer.CollectTopology,
                     request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
@@ -203,6 +219,33 @@ class NodeService(object):
             metadata,
             _registered_method=True)
 
+    @staticmethod
+    def GetInferenceResult(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/GetInferenceResult',
+            node__service__pb2.GetInferenceResultRequest.SerializeToString,
+            node__service__pb2.InferenceResult.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
+
     @staticmethod
     def CollectTopology(request,
             target,

+ 11 - 3
networking/peer_handle.py

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import Optional, Tuple
 import numpy as np
 from inference.shard import Shard
 from topology.device_capabilities import DeviceCapabilities
@@ -18,16 +18,24 @@ class PeerHandle(ABC):
     async def connect(self) -> None:
         pass
 
+    @abstractmethod
+    async def is_connected(self) -> bool:
+        pass
+
     @abstractmethod
     async def disconnect(self) -> None:
         pass
 
     @abstractmethod
-    async def send_prompt(self, shard: Shard, prompt: str) -> Optional[np.array]:
+    async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
+        pass
+
+    @abstractmethod
+    async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None) -> Optional[np.array]:
         pass
 
     @abstractmethod
-    async def send_tensor(self, shard: Shard, tensor: np.array) -> Optional[np.array]:
+    async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
         pass
 
     @abstractmethod

+ 6 - 1
orchestration/node.py

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Tuple
 import numpy as np
 from abc import ABC, abstractmethod
 from inference.shard import Shard
@@ -25,5 +25,10 @@ class Node(ABC):
     async def reset_shard(self, shard: Shard) -> None:
         pass
 
+    @abstractmethod
     async def collect_topology(self, max_depth: int = 2) -> Topology:
         pass
+
+    @abstractmethod
+    async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+        pass

+ 82 - 33
orchestration/standard_node.py

@@ -1,4 +1,4 @@
-from typing import List, Optional, Callable
+from typing import List, Dict, Optional, Callable, Tuple
 import numpy as np
 from networking import Discovery, PeerHandle, Server
 from inference.inference_engine import InferenceEngine, Shard
@@ -7,6 +7,8 @@ from topology.topology import Topology
 from topology.device_capabilities import device_capabilities
 from topology.partitioning_strategy import PartitioningStrategy
 from topology.partitioning_strategy import Partition
+import asyncio
+import uuid
 
 class StandardNode(Node):
     def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, on_token: Callable[[List[int]], None] = None, max_generate_tokens: int = 50):
@@ -18,54 +20,70 @@ class StandardNode(Node):
         self.peers: List[PeerHandle] = {}
         self.topology: Topology = Topology()
         self.device_capabilities = device_capabilities()
-        self.buffered_token_output: List[int] = []
+        self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
         self.on_token = on_token
         self.max_generate_tokens = max_generate_tokens
 
     async def start(self, wait_for_peers: int = 0) -> None:
         await self.server.start()
         await self.discovery.start()
-        self.peers = await self.discovery.discover_peers(wait_for_peers)
-        print(f"Starting with the following peers: {self.peers}")
-        print("Connecting to peers...")
-        for peer in self.peers:
-            await peer.connect()
-            print(f"Connected to {peer.id()}")
+        await self.update_peers(wait_for_peers)
         await self.collect_topology()
         print(f"Collected topology: {self.topology}")
+        asyncio.create_task(self.periodic_topology_collection(5))
 
     async def stop(self) -> None:
         await self.discovery.stop()
         await self.server.stop()
 
-    async def process_prompt(self, shard: Shard, prompt: str) -> Optional[np.ndarray]:
-        print("process prompt", shard, prompt)
+    async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
+        if request_id is None:
+            request_id = str(uuid.uuid4())
+        if request_id not in self.buffered_token_output:
+            self.buffered_token_output[request_id] = ([], False)
+
+        print(f"[{request_id}] process prompt: {shard}, {prompt}")
         result, is_finished = await self.inference_engine.infer_prompt(self.get_current_shard(shard), prompt)
+        self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], is_finished)
 
-        print(f"result size: {result.size}, is finished: {is_finished}")
         if result.size == 1:
-            self.buffered_token_output.append(result.item())
-            self.on_token(self.buffered_token_output)
+            self.buffered_token_output[request_id][0].append(result.item())
+            self.on_token(self.buffered_token_output[request_id][0])
 
-        if not is_finished and len(self.buffered_token_output) < self.max_generate_tokens:
-            await self.forward_tensor_to_next_shard(shard, result)
+        print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")
 
-        return np.array(self.buffered_token_output) if self.buffered_token_output else None
+        if not is_finished and len(self.buffered_token_output[request_id]) < self.max_generate_tokens:
+            asyncio.create_task(self.forward_tensor_to_next_shard(shard, result, request_id))
 
-    async def process_tensor(self, shard: Shard, tensor: np.ndarray) -> Optional[np.ndarray]:
-        result, is_finished = await self.inference_engine.infer_tensor(self.get_current_shard(shard), tensor)
+        return np.array(self.buffered_token_output[request_id]) if len(self.buffered_token_output[request_id]) > 0 else None
 
-        print(f"result size: {result.size}, is finished: {is_finished}")
-        if result.size == 1:
-            self.buffered_token_output.append(result.item())
-            self.on_token(self.buffered_token_output)
+    async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.ndarray]:
+        if request_id is None:
+            request_id = str(uuid.uuid4())
+        if request_id not in self.buffered_token_output:
+            self.buffered_token_output[request_id] = ([], False)
+
+        try:
+            print(f"[{request_id}] process_tensor: {shard}, {tensor}")
+            result, is_finished = await self.inference_engine.infer_tensor(self.get_current_shard(shard), tensor)
+            self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], is_finished)
 
-        if not is_finished and len(self.buffered_token_output) < self.max_generate_tokens:
-            await self.forward_tensor_to_next_shard(shard, result)
+            if result.size == 1:  # we got a new token out
+                self.buffered_token_output[request_id][0].append(result.item())
+                self.on_token(self.buffered_token_output[request_id][0])
+            print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")
 
-        return np.array(self.buffered_token_output) if self.buffered_token_output else None
+            if not is_finished and len(self.buffered_token_output[request_id]) < self.max_generate_tokens:
+                asyncio.create_task(self.forward_tensor_to_next_shard(shard, result, request_id))
 
-    async def forward_tensor_to_next_shard(self, shard: Shard, tensor: np.ndarray) -> None:
+            return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+        except Exception as e:
+            import traceback
+            print(f"Error processing tensor for shard {shard}: {e}")
+            traceback.print_exc()
+            return None
+
+    async def forward_tensor_to_next_shard(self, shard: Shard, tensor: np.ndarray, request_id: str) -> None:
         if not self.partitioning_strategy:
             print("No partitioning strategy found. Skipping forward.")
             return
@@ -80,7 +98,7 @@ class StandardNode(Node):
 
             if next_partition:
                 if next_partition.node_id == self.id:
-                    await self.process_tensor(shard, tensor)
+                    await self.process_tensor(shard, tensor, request_id)
                     return
 
                 target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
@@ -91,9 +109,9 @@ class StandardNode(Node):
                 end_layer = int(next_partition.end * shard.n_layers) - 1
                 next_shard = Shard(shard.model_id, start_layer, end_layer, shard.n_layers)
 
-                print(f"Sending tensor to {target_peer.id()} for shard: {next_shard}")
+                print(f"Sending tensor to {target_peer.id()} for shard: {next_shard}: {tensor}")
 
-                await target_peer.send_tensor(next_shard, tensor)
+                await target_peer.send_tensor(next_shard, tensor, request_id)
 
     def get_current_shard(self, shard: Shard) -> Shard:
         partitions = self.partitioning_strategy.partition(self.topology)
@@ -110,9 +128,20 @@ class StandardNode(Node):
     async def reset_shard(self, shard: Shard) -> None:
         # Implement shard reset logic
         print(f"Resetting shard: {shard}")
-        self.buffered_token_output = []
+        self.buffered_token_output = {}
         await self.inference_engine.reset_shard(self.get_current_shard(shard))
 
+    async def update_peers(self, wait_for_peers: int = 0) -> None:
+        self.peers = await self.discovery.discover_peers(wait_for_peers)
+        print(f"Starting with the following peers: {self.peers}")
+        print("Connecting to new peers...")
+        for peer in self.peers:
+            is_connected = await peer.is_connected()
+            print(f"Connected to {peer.id()}: {is_connected}")
+            if not is_connected:
+                await peer.connect()
+                print(f"Connected to peer {peer.id()}")
+
     async def collect_topology(self, max_depth: int = 4) -> Topology:
         self.topology.update_node(self.id, self.device_capabilities)
 
@@ -121,8 +150,28 @@ class StandardNode(Node):
             self.topology.add_edge(self.id, peer.id())
 
             if max_depth > 0:
-                other_topology = await peer.collect_topology(max_depth = max_depth - 1)
-                print(f"Collected topology from: {peer.id()}: {other_topology}")
-                self.topology.merge(other_topology)
+                try:
+                    other_topology = await peer.collect_topology(max_depth = max_depth - 1)
+                    print(f"Collected topology from: {peer.id()}: {other_topology}")
+                    self.topology.merge(other_topology)
+                except Exception as e:
+                    print(f"Error collecting topology from {peer.id()}: {e}")
 
 
         return self.topology
+    async def periodic_topology_collection(self, interval: int):
+        while True:
+            await asyncio.sleep(interval)
+            try:
+                await self.update_peers()
+                await self.collect_topology()
+            except Exception as e:
+                print(f"Error collecting topology: {e}")
+
+            print("Topology collection task executed.")
+            print(f"Current topology: {self.topology}")
+
+    async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+        if request_id not in self.buffered_token_output:
+            return None, False
+        return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
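
A consistency note on the new buffer: buffered_token_output[request_id] is a (tokens, is_finished) tuple, so `len(self.buffered_token_output[request_id])` in the max_generate_tokens checks is always 2, and `np.array(self.buffered_token_output[request_id])` in process_prompt wraps the whole tuple rather than the token list (process_tensor already indexes [0] on return). The checks presumably intend the token list; a sketch:

tokens, is_finished = self.buffered_token_output[request_id]
if not is_finished and len(tokens) < self.max_generate_tokens:
    asyncio.create_task(self.forward_tensor_to_next_shard(shard, result, request_id))
return np.array(tokens) if tokens else None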

Some files were not shown because too many files changed in this diff