
add global reset

Alex Cheema 10 months ago
parent commit
e17905e295

+ 7 - 5
examples/llama3_distributed.py

@@ -17,19 +17,20 @@ models = {
     "mlx-community/Meta-Llama-3-70B-Instruct-4bit": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80)
 }
 
-path_or_hf_repo = "mlx-community/Meta-Llama-3-70B-Instruct-4bit"
+path_or_hf_repo = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
 model_path = get_model_path(path_or_hf_repo)
 tokenizer_config = {}
 tokenizer = load_tokenizer(model_path, tokenizer_config)
 
-peer2 = GRPCPeerHandle(
+peer1 = GRPCPeerHandle(
     "node1",
     "localhost:8080",
     DeviceCapabilities(model="placeholder", chip="placeholder", memory=0)
 )
-peer1 = GRPCPeerHandle(
+peer2 = GRPCPeerHandle(
     "node2",
-    "10.0.0.161:8080",
+    # "10.0.0.161:8080",
+    "localhost:8081",
     DeviceCapabilities(model="placeholder", chip="placeholder", memory=0)
 )
 shard = models[path_or_hf_repo]
@@ -49,7 +50,8 @@ async def run_prompt(prompt: str):
 
     for peer in [peer1, peer2]:
         await peer.connect()
-        await peer.reset_shard(shard)
+
+    await peer.global_reset(shard, set(), 2)
 
     try:
         await peer1.send_prompt(shard, prompt, request_id)
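A note on the new flow: the reset call now sits outside the connect loop, so it goes through the stale loop variable (still bound to peer2 after the loop ends) and relies on global_reset's recursive fan-out, rather than calling reset_shard on each peer individually. A minimal sketch of the same call shape with the peers passed in explicitly (the function name is hypothetical):

    import asyncio

    async def reset_cluster(peers, shard):
        # Sketch mirroring the example above: connect every peer first,
        # then issue one global reset; the receiving node forwards it
        # onward, decrementing max_depth at each hop.
        for peer in peers:
            await peer.connect()
        await peers[-1].global_reset(shard, set(), 2)

    # e.g. asyncio.run(reset_cluster([peer1, peer2], shard))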

+ 1 - 1
exo/inference/tinygrad/inference.py

@@ -159,7 +159,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
         toks = [self.tokenizer.bos_id] + encode_message("user", prompt) + encode_role("assistant")
         start_pos = prefill(self.model, toks[:-1])
         last_tok = toks[-1]
-        
+
         output_data = np.array(self.model(Tensor([[last_tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist())
         start_pos += 1
 

+ 2 - 0
exo/networking/grpc/grpc_discovery.py

@@ -106,6 +106,8 @@ class GRPCDiscovery(Discovery):
                     self.peer_last_seen[peer_id] = time.time()
             except Exception as e:
                 print(f"Error in peer discovery: {e}")
+                import traceback
+                print(traceback.format_exc())
                 await asyncio.sleep(self.broadcast_interval / 2)
 
     async def _cleanup_peers(self):
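The added traceback dump keeps the discovery loop alive while making failures diagnosable; the except branch also halves the sleep before retrying. A standalone sketch of the same keep-going pattern, standard library only (`step` is a hypothetical coroutine standing in for one discovery iteration):

    import asyncio
    import traceback

    async def discovery_loop(step, broadcast_interval: float):
        # Keep polling; on failure, print the full stack so the error is
        # diagnosable, then back off for half the interval before retrying.
        while True:
            try:
                await step()
                await asyncio.sleep(broadcast_interval)
            except Exception as e:
                print(f"Error in peer discovery: {e}")
                print(traceback.format_exc())
                await asyncio.sleep(broadcast_interval / 2)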

+ 4 - 0
exo/networking/grpc/grpc_peer_handle.py

@@ -88,3 +88,7 @@ class GRPCPeerHandle(PeerHandle):
             for peer_id in peers.peer_ids:
                 topology.add_edge(node_id, peer_id)
         return topology
+
+    async def global_reset(self, base_shard: Shard, visited: set[str], max_depth: int) -> None:
+        request = node_service_pb2.GlobalResetRequest(base_shard=node_service_pb2.Shard(model_id=base_shard.model_id, start_layer=base_shard.start_layer, end_layer=base_shard.end_layer, n_layers=base_shard.n_layers), visited=visited, max_depth=max_depth)
+        await self.stub.GlobalReset(request)
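One detail this client method depends on: protobuf `repeated` fields accept any Python iterable, so the `visited` set can be passed straight through and is serialized as a list of strings. A sketch of the request it builds, assuming the import paths used elsewhere in this repo:

    from exo.inference.shard import Shard  # assumed import path
    from exo.networking.grpc import node_service_pb2

    shard = Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit",
                  start_layer=0, end_layer=31, n_layers=32)
    # `repeated string visited` accepts a set; protobuf copies it into a list.
    request = node_service_pb2.GlobalResetRequest(
        base_shard=node_service_pb2.Shard(
            model_id=shard.model_id, start_layer=shard.start_layer,
            end_layer=shard.end_layer, n_layers=shard.n_layers),
        visited={"node1"},
        max_depth=2)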

+ 7 - 0
exo/networking/grpc/grpc_server.py

@@ -70,3 +70,10 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
         nodes = {node_id: node_service_pb2.DeviceCapabilities(model=cap.model, chip=cap.chip, memory=cap.memory) for node_id, cap in topology.nodes.items()}
         peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
         return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
+
+    async def GlobalReset(self, request, context):
+        base_shard = Shard(model_id=request.base_shard.model_id, start_layer=request.base_shard.start_layer, end_layer=request.base_shard.end_layer, n_layers=request.base_shard.n_layers)
+        visited = set(request.visited)
+        max_depth = request.max_depth
+        await self.node.global_reset(base_shard, visited, max_depth)
+        return node_service_pb2.Empty()
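The server handler is a pure deserialize-and-delegate shim, so the Shard conversion it performs (the inverse of the client-side one above) can be sanity-checked without a running server. A small round-trip sketch, import paths assumed as before:

    from exo.inference.shard import Shard  # assumed import path
    from exo.networking.grpc import node_service_pb2

    # proto -> Shard, exactly as the GlobalReset handler does it.
    proto = node_service_pb2.Shard(model_id="m", start_layer=0,
                                   end_layer=31, n_layers=32)
    shard = Shard(model_id=proto.model_id, start_layer=proto.start_layer,
                  end_layer=proto.end_layer, n_layers=proto.n_layers)
    assert (shard.start_layer, shard.end_layer, shard.n_layers) == (0, 31, 32)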

+ 7 - 0
exo/networking/grpc/node_service.proto

@@ -8,6 +8,7 @@ service NodeService {
   rpc ResetShard (ResetShardRequest) returns (Empty) {}
   rpc GetInferenceResult (GetInferenceResultRequest) returns (InferenceResult) {}
   rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
+  rpc GlobalReset (GlobalResetRequest) returns (Empty) {}
 }
 
 message Shard {
@@ -53,6 +54,12 @@ message CollectTopologyRequest {
   int32 max_depth = 2;
 }
 
+message GlobalResetRequest {
+  Shard base_shard = 1;
+  repeated string visited = 2;
+  int32 max_depth = 3;
+}
+
 message Topology {
   map<string, DeviceCapabilities> nodes = 1;
   map<string, Peers> peer_graph = 2;
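After a proto change like this, the two generated modules below have to be regenerated. A sketch of that step via grpcio-tools, with paths assumed from this repository's layout:

    from grpc_tools import protoc  # pip install grpcio-tools

    # Paths are assumptions based on this repo's layout.
    protoc.main([
        "grpc_tools.protoc",
        "-Iexo/networking/grpc",
        "--python_out=exo/networking/grpc",
        "--grpc_python_out=exo/networking/grpc",
        "exo/networking/grpc/node_service.proto",
    ])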

+ 0 - 0
exo/networking/grpc/node_service_pb2.py

File diff suppressed because it is too large


+ 43 - 0
exo/networking/grpc/node_service_pb2_grpc.py

@@ -64,6 +64,11 @@ class NodeServiceStub(object):
                 request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
                 response_deserializer=node__service__pb2.Topology.FromString,
                 _registered_method=True)
+        self.GlobalReset = channel.unary_unary(
+                '/node_service.NodeService/GlobalReset',
+                request_serializer=node__service__pb2.GlobalResetRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
+                _registered_method=True)
 
 
 class NodeServiceServicer(object):
@@ -99,6 +104,12 @@ class NodeServiceServicer(object):
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')
 
+    def GlobalReset(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
 
 def add_NodeServiceServicer_to_server(servicer, server):
     rpc_method_handlers = {
@@ -127,6 +138,11 @@ def add_NodeServiceServicer_to_server(servicer, server):
                     request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
                     response_serializer=node__service__pb2.Topology.SerializeToString,
             ),
+            'GlobalReset': grpc.unary_unary_rpc_method_handler(
+                    servicer.GlobalReset,
+                    request_deserializer=node__service__pb2.GlobalResetRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
+            ),
     }
     generic_handler = grpc.method_handlers_generic_handler(
             'node_service.NodeService', rpc_method_handlers)
@@ -272,3 +288,30 @@ class NodeService(object):
             timeout,
             metadata,
             _registered_method=True)
+
+    @staticmethod
+    def GlobalReset(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/GlobalReset',
+            node__service__pb2.GlobalResetRequest.SerializeToString,
+            node__service__pb2.Empty.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
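The regenerated stub can also be exercised directly, without going through GRPCPeerHandle. A minimal async sketch over an insecure channel; the address and shard values here are assumptions:

    import asyncio
    import grpc
    from exo.networking.grpc import node_service_pb2, node_service_pb2_grpc

    async def main():
        # Address is an assumption; grpc.aio gives an asyncio-friendly stub.
        async with grpc.aio.insecure_channel("localhost:8080") as channel:
            stub = node_service_pb2_grpc.NodeServiceStub(channel)
            # One RPC; the receiving node fans the reset out to its peers.
            await stub.GlobalReset(node_service_pb2.GlobalResetRequest(
                base_shard=node_service_pb2.Shard(
                    model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit",
                    start_layer=0, end_layer=31, n_layers=32),
                visited=[],
                max_depth=2))

    # asyncio.run(main())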

+ 5 - 0
exo/networking/peer_handle.py

@@ -42,5 +42,10 @@ class PeerHandle(ABC):
     async def reset_shard(self, shard: Shard) -> None:
         pass
 
+    @abstractmethod
     async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
         pass
+
+    @abstractmethod
+    async def global_reset(self, base_shard: Shard, visited: set[str], max_depth: int) -> None:
+        pass
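With @abstractmethod now on collect_topology and global_reset, an incomplete PeerHandle subclass fails at instantiation rather than when the missing method is first called. A self-contained illustration of that behavior:

    from abc import ABC, abstractmethod

    class Handle(ABC):
        @abstractmethod
        async def global_reset(self, base_shard, visited, max_depth): ...

    class Incomplete(Handle):
        pass

    try:
        Incomplete()
    except TypeError as e:
        print(e)  # Can't instantiate abstract class Incomplete ...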

+ 3 - 3
exo/orchestration/node.py

@@ -26,13 +26,13 @@ class Node(ABC):
         pass
 
     @abstractmethod
-    async def collect_topology(self, visited: set[str] = set(), max_depth: int = 2) -> Topology:
+    async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
         pass
 
     @abstractmethod
-    async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+    async def collect_topology(self, visited: set[str] = set(), max_depth: int = 2) -> Topology:
         pass
 
     @abstractmethod
-    async def global_reset(self, visited: set[str], max_depth: int = 2) -> None:
+    async def global_reset(self, base_shard: Shard, visited: set[str] = set(), max_depth: int = 2) -> None:
         pass

+ 42 - 18
exo/orchestration/standard_node.py

@@ -147,18 +147,38 @@ class StandardNode(Node):
                 await peer.connect()
                 if DEBUG >= 2: print(f"Connected to peer {peer.id()}")
 
+    async def periodic_topology_collection(self, interval: int):
+        while True:
+            await asyncio.sleep(interval)
+            try:
+                await self.update_peers()
+                await self.collect_topology()
+            except Exception as e:
+                print(f"Error collecting topology: {e}")
+
+            if DEBUG >= 2: print("Topology collection task executed.")
+            if DEBUG >= 2: print(f"Current topology: {self.topology}")
+
+    async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+        if request_id not in self.buffered_token_output:
+            return None, False
+        return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
+
     async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology:
         self.topology.update_node(self.id, self.device_capabilities)
 
         if DEBUG >= 2: print(f"Collecting topoloy {max_depth=} {visited=}")
+
+        prev_visited = visited.copy()
+        visited.update(p.id() for p in self.peers)
+
         for peer in self.peers:
             self.topology.update_node(peer.id(), peer.device_capabilities())
             self.topology.add_edge(self.id, peer.id())
 
-            if peer.id() in visited:
+            if peer.id() in prev_visited:
                 if DEBUG >= 2: print(f"Already visited {peer.id()}. Skipping...")
                 continue
-            visited.add(peer.id())
 
             if max_depth <= 0:
                 if DEBUG >= 2: print(f"Max depth reached. Skipping...")
@@ -173,22 +193,26 @@ class StandardNode(Node):
 
         return self.topology
 
-    async def periodic_topology_collection(self, interval: int):
-        while True:
-            await asyncio.sleep(interval)
-            try:
-                await self.update_peers()
-                await self.collect_topology()
-            except Exception as e:
-                print(f"Error collecting topology: {e}")
+    # TODO: unify this and collect_topology as global actions
+    async def global_reset(self, base_shard: Shard, visited: set[str] = set(), max_depth: int = 2) -> None:
+        await self.reset_shard(self.get_current_shard(base_shard))
 
-            if DEBUG >= 2: print("Topology collection task executed.")
-            if DEBUG >= 2: print(f"Current topology: {self.topology}")
+        if DEBUG >= 2: print(f"Global reset {base_shard=} {max_depth=} {visited=}")
 
-    async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
-        if request_id not in self.buffered_token_output:
-            return None, False
-        return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
+        prev_visited = visited.copy()
+        visited.update(p.id() for p in self.peers)
 
-    async def global_reset(self, max_depth: int = 2) -> None:
-        pass
+        for peer in self.peers:
+            if peer.id() in prev_visited:
+                if DEBUG >= 2: print(f"Already visited {peer.id()}. Skipping...")
+                continue
+
+            if max_depth <= 0:
+                if DEBUG >= 2: print(f"Max depth reached. Skipping...")
+                continue
+
+            try:
+                print(f"Forwarding global reset to peer {peer.id()}")
+                await peer.global_reset(base_shard, visited, max_depth = max_depth - 1)
+            except Exception as e:
+                print(f"Error collecting topology from {peer.id()}: {e}")

Some files were not shown because too many files changed in this diff