(partially) restore exo node equality by forwarding prompts to the dynamically selected head

Alex Cheema 9 months ago
parent
commit
1d5c28aed4
2 changed files with 29 additions and 20 deletions
  1. examples/llama3_distributed.py (+10, −12)
  2. exo/orchestration/standard_node.py (+19, −8)

examples/llama3_distributed.py (+10, −12)

@@ -22,14 +22,15 @@ model_path = get_model_path(path_or_hf_repo)
 tokenizer_config = {}
 tokenizer = load_tokenizer(model_path, tokenizer_config)
 
-peer1 = GRPCPeerHandle(
-    "node1",
-    "localhost:8080",
-    DeviceCapabilities(model="placeholder", chip="placeholder", memory=0)
-)
+# we intentionally leave out peer1 to demonstrate equality of nodes in exo.
+# there is no "master" node in exo, all nodes are equal and can take on any role.
+# peer1 = GRPCPeerHandle(
+#     "node1",
+#     "localhost:8080",
+#     DeviceCapabilities(model="placeholder", chip="placeholder", memory=0)
+# )
 peer2 = GRPCPeerHandle(
     "node2",
-    # "10.0.0.161:8080",
     "localhost:8081",
     DeviceCapabilities(model="placeholder", chip="placeholder", memory=0)
 )
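
The commented-out peer1 block underlines the point of the commit: with prompt forwarding in place, whichever node the client connects to can accept the prompt, so the example only needs one handle. A minimal sketch of swapping the entry node (names and addresses reuse the values from this hunk; entry_peer itself is an illustrative name):

# Sketch: either node can serve as the single entry point, because a node
# that does not hold layer 0 now forwards the prompt to the one that does.
entry_peer = GRPCPeerHandle(
    "node1",                  # or "node2"
    "localhost:8080",         # or "localhost:8081"
    DeviceCapabilities(model="placeholder", chip="placeholder", memory=0)
)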
@@ -48,17 +49,14 @@ async def run_prompt(prompt: str):
             messages, tokenize=False, add_generation_prompt=True
         )
 
-    for peer in [peer1, peer2]:
-        await peer.connect()
-
-    await peer.global_reset(shard, set(), 2)
+    await peer2.connect()
+    await peer2.global_reset(shard, set(), 2)
 
     try:
-        await peer1.send_prompt(shard, prompt, request_id)
+        await peer2.send_prompt(shard, prompt, request_id)
     except Exception as e:
         print(e)
 
-    import sys
     import time
     # poll 10 times per second for result (even though generation is faster, any more than this it's not nice for the user)
     previous_length = 0
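
Taken together, the two hunks reduce the client flow to a single peer; a condensed sketch (the prompt construction and the polling loop body fall outside the diff context and are elided here as well):

async def run_prompt(prompt: str):
    # ... build the chat-templated prompt as in the hunk above ...
    await peer2.connect()                      # one connection, no master node
    await peer2.global_reset(shard, set(), 2)
    try:
        # peer2 forwards the prompt internally if it lacks the first layers
        await peer2.send_prompt(shard, prompt, request_id)
    except Exception as e:
        print(e)
    # ... then poll roughly 10 times per second for new tokens ...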

exo/orchestration/standard_node.py (+19, −8)

@@ -1,4 +1,4 @@
-from typing import List, Dict, Optional, Callable, Tuple
+from typing import List, Dict, Optional, Callable, Tuple, Union
 import numpy as np
 from exo.networking import Discovery, PeerHandle, Server
 from exo.inference.inference_engine import InferenceEngine, Shard
@@ -43,7 +43,12 @@ class StandardNode(Node):
         if request_id not in self.buffered_token_output:
             self.buffered_token_output[request_id] = ([], False)
 
-        if DEBUG >= 2: print(f"[{request_id}] process prompt: {shard}, {prompt}")
+        if DEBUG >= 2: print(f"[{request_id}] process prompt: {shard=} {prompt=}")
+        if self.get_current_shard(shard).start_layer != 0:
+            if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {shard=} {prompt=}")
+            await self.forward_to_next_shard(shard, prompt, request_id)
+            return
+
         result, is_finished = await self.inference_engine.infer_prompt(self.get_current_shard(shard), prompt)
         is_finished = is_finished or len(self.buffered_token_output[request_id]) >= self.max_generate_tokens
         if is_finished:
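
The effect of the new guard is easiest to see as a trace through a two-node pipeline; the node names and layer split below are illustrative assumptions, not values from the repository:

# Hypothetical trace, two equal partitions over a 32-layer model:
#   client -> node2.send_prompt(shard, prompt, request_id)
#   node2.process_prompt: get_current_shard(shard).start_layer == 16 (!= 0),
#     so it forwards the raw prompt via forward_to_next_shard and returns.
#   node1.process_prompt: get_current_shard(shard).start_layer == 0, so it
#     runs infer_prompt and streams the result tensor onward as before.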
@@ -56,7 +61,7 @@ class StandardNode(Node):
         if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")
 
         if not is_finished:
-            asyncio.create_task(self.forward_tensor_to_next_shard(shard, result, request_id))
+            asyncio.create_task(self.forward_to_next_shard(shard, result, request_id))
 
         return np.array(self.buffered_token_output[request_id]) if len(self.buffered_token_output[request_id]) > 0 else None
 
@@ -79,7 +84,7 @@ class StandardNode(Node):
             if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")
 
             if not is_finished:
-                asyncio.create_task(self.forward_tensor_to_next_shard(shard, result, request_id))
+                asyncio.create_task(self.forward_to_next_shard(shard, result, request_id))
 
             return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
         except Exception as e:
@@ -88,7 +93,7 @@ class StandardNode(Node):
             traceback.print_exc()
             return None
 
-    async def forward_tensor_to_next_shard(self, shard: Shard, tensor: np.ndarray, request_id: str) -> None:
+    async def forward_to_next_shard(self, shard: Shard, tensor_or_prompt: Union[np.ndarray, str], request_id: str) -> None:
         if not self.partitioning_strategy:
             if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
             return
@@ -103,7 +108,10 @@ class StandardNode(Node):
 
             if next_partition:
                 if next_partition.node_id == self.id:
-                    await self.process_tensor(shard, tensor, request_id)
+                    if isinstance(tensor_or_prompt, np.ndarray):
+                        await self.process_tensor(shard, tensor_or_prompt, request_id)
+                    else:
+                        await self.process_prompt(shard, tensor_or_prompt, request_id)
                     return
 
                 target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
@@ -114,9 +122,12 @@ class StandardNode(Node):
                 end_layer = int(next_partition.end * shard.n_layers) - 1
                 next_shard = Shard(shard.model_id, start_layer, end_layer, shard.n_layers)
 
-                if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor.size=} {tensor.shape=}")
+                if DEBUG >= 1: print(f"Sending tensor_or_prompt to {target_peer.id()}: {tensor_or_prompt}")
 
-                await target_peer.send_tensor(next_shard, tensor, request_id)
+                if isinstance(tensor_or_prompt, np.ndarray):
+                    await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id)
+                else:
+                    await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id)
 
     def get_current_shard(self, shard: Shard) -> Shard:
         partitions = self.partitioning_strategy.partition(self.topology)
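
The layer arithmetic in the final hunk becomes concrete with numbers plugged in; a small worked sketch (the layer count and partition fractions are assumptions for illustration):

# Worked example of the partition -> layer mapping in forward_to_next_shard.
n_layers = 32
partitions = [(0.0, 0.5), (0.5, 1.0)]     # two equal nodes (assumed)

for start, end in partitions:
    start_layer = int(start * n_layers)   # 0, then 16
    end_layer = int(end * n_layers) - 1   # 15, then 31
    print(start_layer, end_layer)

# The node whose shard has start_layer == 0 acts as the dynamically selected
# head: a prompt sent to any other node trips the start_layer != 0 guard in
# process_prompt above and is forwarded to the head.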