Browse Source

add support for multiple concurrent requests with request ids

Alex Cheema 1 year ago
parent
commit
b01f69bb6b

+ 29 - 5
example_user_2.py

@@ -16,11 +16,16 @@ model_path = get_model_path(path_or_hf_repo)
 tokenizer_config = {}
 tokenizer = load_tokenizer(model_path, tokenizer_config)
 
-peer = GRPCPeerHandle(
+peer1 = GRPCPeerHandle(
     "node1",
     "localhost:8080",
     DeviceCapabilities(model="test1", chip="test1", memory=10000)
 )
+peer2 = GRPCPeerHandle(
+    "node2",
+    "localhost:8081",
+    DeviceCapabilities(model="test1", chip="test1", memory=10000)
+)
 shard = Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=0, n_layers=32)
 
 async def run_prompt(prompt: str):
@@ -35,11 +40,30 @@ async def run_prompt(prompt: str):
             messages, tokenize=False, add_generation_prompt=True
         )
 
-    await peer.connect()
-    await peer.reset_shard(shard)
+    for peer in [peer1, peer2]:
+        await peer.connect()
+        await peer.reset_shard(shard)
+
+    try:
+        await peer1.send_prompt(shard, prompt, "request-id-1")
+    except Exception as e:
+        print(e)
+
+    import sys
+    # poll 10 times per second for result (even though generation is faster, any more than this it's not nice for the user)
+    previous_length = 0
+    while True:
+        result, is_finished = await peer2.get_inference_result("request-id-1")
+        await asyncio.sleep(0.1)
+
+        # Print the updated string in place
+        updated_string = tokenizer.decode(result)
+        print(updated_string[previous_length:], end='', flush=True)
+        previous_length = len(updated_string)
 
-    result = await peer.send_prompt(shard, prompt)
-    print(tokenizer.decode(result))
+        if is_finished:
+            print("\nDone")
+            break
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run prompt")

+ 3 - 2
inference/inference_engine.py

@@ -1,16 +1,17 @@
 import numpy as np
 import mlx.nn as nn
 
+from typing import Tuple
 from abc import ABC, abstractmethod
 from .shard import Shard
 
 class InferenceEngine(ABC):
     @abstractmethod
-    async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> (np.ndarray, bool):
+    async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> Tuple[np.ndarray, bool]:
         pass
 
     @abstractmethod
-    async def infer_prompt(self, shard: Shard, prompt: str) -> (np.ndarray, bool):
+    async def infer_prompt(self, shard: Shard, prompt: str) -> Tuple[np.ndarray, bool]:
         pass
 
     @abstractmethod

+ 2 - 1
networking/grpc/grpc_discovery.py

@@ -99,7 +99,8 @@ class GRPCDiscovery(Discovery):
                     peer_host = addr[0]
                     peer_port = message['grpc_port']
                     device_capabilities = DeviceCapabilities(**message['device_capabilities'])
-                    self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
+                    if peer_id not in self.known_peers:
+                        self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
                     self.peer_last_seen[peer_id] = time.time()
             except Exception as e:
                 print(f"Error in peer discovery: {e}")

+ 24 - 8
networking/grpc/grpc_peer_handle.py

@@ -1,6 +1,6 @@
 import grpc
 import numpy as np
-from typing import Optional
+from typing import Optional, Tuple
 
 # These would be generated from the .proto file
 from . import node_service_pb2
@@ -16,6 +16,8 @@ class GRPCPeerHandle(PeerHandle):
         self._id = id
         self.address = address
         self._device_capabilities = device_capabilities
+        self.channel = None
+        self.stub = None
 
     def id(self) -> str:
         return self._id
@@ -24,23 +26,30 @@ class GRPCPeerHandle(PeerHandle):
         return self._device_capabilities
 
     async def connect(self):
-        self.channel = grpc.aio.insecure_channel(self.address)
+        self.channel = grpc.aio.insecure_channel(self.address, options=[
+            ('grpc.max_metadata_size', 32*1024*1024)
+        ])
         self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
 
+    async def is_connected(self) -> bool:
+        return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY
+
     async def disconnect(self):
-        await self.channel.close()
+        if self.channel:
+            await self.channel.close()
+        self.channel = None
+        self.stub = None
 
-    async def send_prompt(self, shard: Shard, prompt: str) -> Optional[np.array]:
-        request = node_service_pb2.PromptRequest(prompt=prompt, shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers))
+    async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
+        request = node_service_pb2.PromptRequest(prompt=prompt, shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers), request_id=request_id)
         response = await self.stub.SendPrompt(request)
-        print(f"Sent prompt to {self.address}: {prompt}")
 
         if not response.tensor_data or not response.shape or not response.dtype:
             return None
 
         return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
 
-    async def send_tensor(self, shard: Shard, tensor: np.ndarray) -> Optional[np.array]:
+    async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
         request = node_service_pb2.TensorRequest(
             shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers),
             tensor = node_service_pb2.Tensor(
@@ -48,6 +57,7 @@ class GRPCPeerHandle(PeerHandle):
                 shape=tensor.shape,
                 dtype=str(tensor.dtype)
             ),
+            request_id=request_id
         )
         response = await self.stub.SendTensor(request)
 
@@ -56,10 +66,16 @@ class GRPCPeerHandle(PeerHandle):
 
         return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
 
+    async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+        request = node_service_pb2.GetInferenceResultRequest(request_id=request_id)
+        response = await self.stub.GetInferenceResult(request)
+        if response.tensor is None:
+            return None, response.is_finished
+        return np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype)).reshape(response.tensor.shape), response.is_finished
+
     async def reset_shard(self, shard: Shard) -> None:
         request = node_service_pb2.ResetShardRequest(shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers))
         await self.stub.ResetShard(request)
-        print(f"Reset shard {shard} on {self.address}")
 
     async def collect_topology(self, max_depth: int) -> Topology:
         request = node_service_pb2.CollectTopologyRequest(max_depth=max_depth)

+ 19 - 7
networking/grpc/grpc_server.py

@@ -8,6 +8,8 @@ from inference.shard import Shard
 
 from orchestration import Node
 
+import uuid
+
 class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     def __init__(self, node: Node, host: str, port: int):
         self.node = node
@@ -17,7 +19,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
 
     async def start(self) -> None:
         self.server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=10), options=[
-            ('grpc.max_metadata_size', 128*1024)
+            ('grpc.max_metadata_size', 32*1024*1024)
         ])
         node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
         listen_addr = f'{self.host}:{self.port}'
@@ -27,23 +29,33 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
 
     async def stop(self) -> None:
         if self.server:
-            await self.server.stop(5)  # 5 seconds grace period
-            print("Server stopped")
+            await self.server.stop(grace=5)
+            await self.server.wait_for_termination()
+            print("Server stopped and all connections are closed")
 
     async def SendPrompt(self, request, context):
         shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
         prompt = request.prompt
-        result = await self.node.process_prompt(shard, prompt)
+        request_id = request.request_id
+        result = await self.node.process_prompt(shard, prompt, request_id)
         tensor_data = result.tobytes() if result is not None else None
-        return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))
+        return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
 
     async def SendTensor(self, request, context):
         shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
         tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
-        result = await self.node.process_tensor(shard, tensor)
+        request_id = request.request_id
+
+        result = await self.node.process_tensor(shard, tensor, request_id)
         print("SendTensor tensor result", result)
         tensor_data = result.tobytes() if result is not None else None
-        return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))
+        return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
+
+    async def GetInferenceResult(self, request, context):
+        request_id = request.request_id
+        result = await self.node.get_inference_result(request_id)
+        tensor_data = result[0].tobytes() if result[0] is not None else None
+        return node_service_pb2.InferenceResult(tensor=node_service_pb2.Tensor(tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype))) if result[0] is not None else node_service_pb2.InferenceResult()
 
     async def ResetShard(self, request, context):
         shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)

+ 12 - 0
networking/grpc/node_service.proto

@@ -6,6 +6,7 @@ service NodeService {
   rpc SendPrompt (PromptRequest) returns (Tensor) {}
   rpc SendTensor (TensorRequest) returns (Tensor) {}
   rpc ResetShard (ResetShardRequest) returns (Empty) {}
+  rpc GetInferenceResult (GetInferenceResultRequest) returns (InferenceResult) {}
   rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
 }
 
@@ -19,11 +20,22 @@ message Shard {
 message PromptRequest {
   Shard shard = 1;
   string prompt = 2;
+  optional string request_id = 3;
 }
 
 message TensorRequest {
   Shard shard = 1;
   Tensor tensor = 2;
+  optional string request_id = 3;
+}
+
+message GetInferenceResultRequest {
+  string request_id = 1;
+}
+
+message InferenceResult {
+  optional Tensor tensor = 1;
+  bool is_finished = 2;
 }
 
 message Tensor {

File diff suppressed because it is too large
+ 0 - 1
networking/grpc/node_service_pb2.py


+ 43 - 0
networking/grpc/node_service_pb2_grpc.py

@@ -54,6 +54,11 @@ class NodeServiceStub(object):
                 request_serializer=node__service__pb2.ResetShardRequest.SerializeToString,
                 response_deserializer=node__service__pb2.Empty.FromString,
                 _registered_method=True)
+        self.GetInferenceResult = channel.unary_unary(
+                '/node_service.NodeService/GetInferenceResult',
+                request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
+                response_deserializer=node__service__pb2.InferenceResult.FromString,
+                _registered_method=True)
         self.CollectTopology = channel.unary_unary(
                 '/node_service.NodeService/CollectTopology',
                 request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
@@ -82,6 +87,12 @@ class NodeServiceServicer(object):
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')
 
+    def GetInferenceResult(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
     def CollectTopology(self, request, context):
         """Missing associated documentation comment in .proto file."""
         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
@@ -106,6 +117,11 @@ def add_NodeServiceServicer_to_server(servicer, server):
                     request_deserializer=node__service__pb2.ResetShardRequest.FromString,
                     response_serializer=node__service__pb2.Empty.SerializeToString,
             ),
+            'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
+                    servicer.GetInferenceResult,
+                    request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
+                    response_serializer=node__service__pb2.InferenceResult.SerializeToString,
+            ),
             'CollectTopology': grpc.unary_unary_rpc_method_handler(
                     servicer.CollectTopology,
                     request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
@@ -203,6 +219,33 @@ class NodeService(object):
             metadata,
             _registered_method=True)
 
+    @staticmethod
+    def GetInferenceResult(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/GetInferenceResult',
+            node__service__pb2.GetInferenceResultRequest.SerializeToString,
+            node__service__pb2.InferenceResult.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
+
     @staticmethod
     def CollectTopology(request,
             target,

+ 11 - 3
networking/peer_handle.py

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import Optional, Tuple
 import numpy as np
 from inference.shard import Shard
 from topology.device_capabilities import DeviceCapabilities
@@ -18,16 +18,24 @@ class PeerHandle(ABC):
     async def connect(self) -> None:
         pass
 
+    @abstractmethod
+    async def is_connected(self) -> bool:
+        pass
+
     @abstractmethod
     async def disconnect(self) -> None:
         pass
 
     @abstractmethod
-    async def send_prompt(self, shard: Shard, prompt: str) -> Optional[np.array]:
+    async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
+        pass
+
+    @abstractmethod
+    async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None) -> Optional[np.array]:
         pass
 
     @abstractmethod
-    async def send_tensor(self, shard: Shard, tensor: np.array) -> Optional[np.array]:
+    async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
         pass
 
     @abstractmethod

+ 6 - 1
orchestration/node.py

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Tuple
 import numpy as np
 from abc import ABC, abstractmethod
 from inference.shard import Shard
@@ -25,5 +25,10 @@ class Node(ABC):
     async def reset_shard(self, shard: Shard) -> None:
         pass
 
+    @abstractmethod
     async def collect_topology(self, max_depth: int = 2) -> Topology:
         pass
+
+    @abstractmethod
+    async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+        pass

+ 82 - 33
orchestration/standard_node.py

@@ -1,4 +1,4 @@
-from typing import List, Optional, Callable
+from typing import List, Dict, Optional, Callable, Tuple
 import numpy as np
 from networking import Discovery, PeerHandle, Server
 from inference.inference_engine import InferenceEngine, Shard
@@ -7,6 +7,8 @@ from topology.topology import Topology
 from topology.device_capabilities import device_capabilities
 from topology.partitioning_strategy import PartitioningStrategy
 from topology.partitioning_strategy import Partition
+import asyncio
+import uuid
 
 class StandardNode(Node):
     def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, on_token: Callable[[List[int]], None] = None, max_generate_tokens: int = 50):
@@ -18,54 +20,70 @@ class StandardNode(Node):
         self.peers: List[PeerHandle] = {}
         self.topology: Topology = Topology()
         self.device_capabilities = device_capabilities()
-        self.buffered_token_output: List[int] = []
+        self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
         self.on_token = on_token
         self.max_generate_tokens = max_generate_tokens
 
     async def start(self, wait_for_peers: int = 0) -> None:
         await self.server.start()
         await self.discovery.start()
-        self.peers = await self.discovery.discover_peers(wait_for_peers)
-        print(f"Starting with the following peers: {self.peers}")
-        print("Connecting to peers...")
-        for peer in self.peers:
-            await peer.connect()
-            print(f"Connected to {peer.id()}")
+        await self.update_peers(wait_for_peers)
         await self.collect_topology()
         print(f"Collected topology: {self.topology}")
+        asyncio.create_task(self.periodic_topology_collection(5))
 
     async def stop(self) -> None:
         await self.discovery.stop()
         await self.server.stop()
 
-    async def process_prompt(self, shard: Shard, prompt: str) -> Optional[np.ndarray]:
-        print("process prompt", shard, prompt)
+    async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
+        if request_id is None:
+            request_id = str(uuid.uuid4())
+        if request_id not in self.buffered_token_output:
+            self.buffered_token_output[request_id] = ([], False)
+
+        print(f"[{request_id}] process prompt: {shard}, {prompt}")
         result, is_finished = await self.inference_engine.infer_prompt(self.get_current_shard(shard), prompt)
+        self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], is_finished)
 
-        print(f"result size: {result.size}, is finished: {is_finished}")
         if result.size == 1:
-            self.buffered_token_output.append(result.item())
-            self.on_token(self.buffered_token_output)
+            self.buffered_token_output[request_id][0].append(result.item())
+            self.on_token(self.buffered_token_output[request_id][0])
 
-        if not is_finished and len(self.buffered_token_output) < self.max_generate_tokens:
-            await self.forward_tensor_to_next_shard(shard, result)
+        print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")
 
-        return np.array(self.buffered_token_output) if self.buffered_token_output else None
+        if not is_finished and len(self.buffered_token_output[request_id]) < self.max_generate_tokens:
+            asyncio.create_task(self.forward_tensor_to_next_shard(shard, result, request_id))
 
-    async def process_tensor(self, shard: Shard, tensor: np.ndarray) -> Optional[np.ndarray]:
-        result, is_finished = await self.inference_engine.infer_tensor(self.get_current_shard(shard), tensor)
+        return np.array(self.buffered_token_output[request_id]) if len(self.buffered_token_output[request_id]) > 0 else None
 
-        print(f"result size: {result.size}, is finished: {is_finished}")
-        if result.size == 1:
-            self.buffered_token_output.append(result.item())
-            self.on_token(self.buffered_token_output)
+    async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.ndarray]:
+        if request_id is None:
+            request_id = str(uuid.uuid4())
+        if request_id not in self.buffered_token_output:
+            self.buffered_token_output[request_id] = ([], False)
+
+        try:
+            print(f"[{request_id}] process_tensor: {shard}, {tensor}")
+            result, is_finished = await self.inference_engine.infer_tensor(self.get_current_shard(shard), tensor)
+            self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], is_finished)
 
-        if not is_finished and len(self.buffered_token_output) < self.max_generate_tokens:
-            await self.forward_tensor_to_next_shard(shard, result)
+            if result.size == 1:  # we got a new token out
+                self.buffered_token_output[request_id][0].append(result.item())
+                self.on_token(self.buffered_token_output[request_id][0])
+            print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")
 
-        return np.array(self.buffered_token_output) if self.buffered_token_output else None
+            if not is_finished and len(self.buffered_token_output[request_id]) < self.max_generate_tokens:
+                asyncio.create_task(self.forward_tensor_to_next_shard(shard, result, request_id))
 
-    async def forward_tensor_to_next_shard(self, shard: Shard, tensor: np.ndarray) -> None:
+            return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+        except Exception as e:
+            import traceback
+            print(f"Error processing tensor for shard {shard}: {e}")
+            traceback.print_exc()
+            return None
+
+    async def forward_tensor_to_next_shard(self, shard: Shard, tensor: np.ndarray, request_id: str) -> None:
         if not self.partitioning_strategy:
             print("No partitioning strategy found. Skipping forward.")
             return
@@ -80,7 +98,7 @@ class StandardNode(Node):
 
             if next_partition:
                 if next_partition.node_id == self.id:
-                    await self.process_tensor(shard, tensor)
+                    await self.process_tensor(shard, tensor, request_id)
                     return
 
                 target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
@@ -91,9 +109,9 @@ class StandardNode(Node):
                 end_layer = int(next_partition.end * shard.n_layers) - 1
                 next_shard = Shard(shard.model_id, start_layer, end_layer, shard.n_layers)
 
-                print(f"Sending tensor to {target_peer.id()} for shard: {next_shard}")
+                print(f"Sending tensor to {target_peer.id()} for shard: {next_shard}: {tensor}")
 
-                await target_peer.send_tensor(next_shard, tensor)
+                await target_peer.send_tensor(next_shard, tensor, request_id)
 
     def get_current_shard(self, shard: Shard) -> Shard:
         partitions = self.partitioning_strategy.partition(self.topology)
@@ -110,9 +128,20 @@ class StandardNode(Node):
     async def reset_shard(self, shard: Shard) -> None:
         # Implement shard reset logic
         print(f"Resetting shard: {shard}")
-        self.buffered_token_output = []
+        self.buffered_token_output = {}
         await self.inference_engine.reset_shard(self.get_current_shard(shard))
 
+    async def update_peers(self, wait_for_peers: int = 0) -> None:
+        self.peers = await self.discovery.discover_peers(wait_for_peers)
+        print(f"Starting with the following peers: {self.peers}")
+        print("Connecting to new peers...")
+        for peer in self.peers:
+            is_connected = await peer.is_connected()
+            print(f"Connected to {peer.id()}: {is_connected}")
+            if not is_connected:
+                await peer.connect()
+                print(f"Connected to peer {peer.id()}")
+
     async def collect_topology(self, max_depth: int = 4) -> Topology:
         self.topology.update_node(self.id, self.device_capabilities)
 
@@ -121,8 +150,28 @@ class StandardNode(Node):
             self.topology.add_edge(self.id, peer.id())
 
             if max_depth > 0:
-                other_topology = await peer.collect_topology(max_depth = max_depth - 1)
-                print(f"Collected topology from: {peer.id()}: {other_topology}")
-                self.topology.merge(other_topology)
+                try:
+                    other_topology = await peer.collect_topology(max_depth = max_depth - 1)
+                    print(f"Collected topology from: {peer.id()}: {other_topology}")
+                    self.topology.merge(other_topology)
+                except Exception as e:
+                    print(f"Error collecting topology from {peer.id()}: {e}")
 
         return self.topology
+
+    async def periodic_topology_collection(self, interval: int):
+        while True:
+            await asyncio.sleep(interval)
+            try:
+                await self.update_peers()
+                await self.collect_topology()
+            except Exception as e:
+                print(f"Error collecting topology: {e}")
+
+            print("Topology collection task executed.")
+            print(f"Current topology: {self.topology}")
+
+    async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+        if request_id not in self.buffered_token_output:
+            return None, False
+        return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]

Some files were not shown because too many files changed in this diff