
(partially) restore exo node equality by forwarding prompts to the dynamically selected head

Alex Cheema committed 9 months ago
commit 1d5c28aed4

2 changed files with 29 additions and 20 deletions:
  1. examples/llama3_distributed.py (+10, -12)
  2. exo/orchestration/standard_node.py (+19, -8)

examples/llama3_distributed.py (+10, -12)

@@ -22,14 +22,15 @@ model_path = get_model_path(path_or_hf_repo)
 tokenizer_config = {}
 tokenizer = load_tokenizer(model_path, tokenizer_config)
 
-peer1 = GRPCPeerHandle(
-    "node1",
-    "localhost:8080",
-    DeviceCapabilities(model="placeholder", chip="placeholder", memory=0)
-)
+# we intentionally leave out peer1 to demonstrate equality of nodes in exo.
+# there is no "master" node in exo, all nodes are equal and can take on any role.
+# peer1 = GRPCPeerHandle(
+#     "node1",
+#     "localhost:8080",
+#     DeviceCapabilities(model="placeholder", chip="placeholder", memory=0)
+# )
 peer2 = GRPCPeerHandle(
     "node2",
-    # "10.0.0.161:8080",
     "localhost:8081",
     DeviceCapabilities(model="placeholder", chip="placeholder", memory=0)
 )
@@ -48,17 +49,14 @@ async def run_prompt(prompt: str):
             messages, tokenize=False, add_generation_prompt=True
         )
 
-    for peer in [peer1, peer2]:
-        await peer.connect()
-
-    await peer.global_reset(shard, set(), 2)
+    await peer2.connect()
+    await peer2.global_reset(shard, set(), 2)
 
     try:
-        await peer1.send_prompt(shard, prompt, request_id)
+        await peer2.send_prompt(shard, prompt, request_id)
     except Exception as e:
         print(e)
 
-    import sys
     import time
     # poll 10 times per second for result (even though generation is faster, any more than this it's not nice for the user)
     previous_length = 0
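The example above ends where it begins polling the entry peer for generated tokens. A minimal sketch of that polling loop follows, assuming a peer-side helper get_inference_result(request_id) that returns (tokens_so_far, is_finished); the helper name and signature are assumptions for illustration, not taken from this diff.

import asyncio

async def poll_result(peer, tokenizer, request_id: str) -> None:
    # Poll roughly 10 times per second and print only the newly decoded text,
    # matching the comment in the example above.
    previous_length = 0
    while True:
        # Assumed helper: returns (token ids generated so far, is_finished).
        tokens, is_finished = await peer.get_inference_result(request_id)
        if tokens is not None:
            text = tokenizer.decode(tokens)
            print(text[previous_length:], end="", flush=True)
            previous_length = len(text)
        if is_finished:
            break
        await asyncio.sleep(0.1)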

exo/orchestration/standard_node.py (+19, -8)

@@ -1,4 +1,4 @@
-from typing import List, Dict, Optional, Callable, Tuple
+from typing import List, Dict, Optional, Callable, Tuple, Union
 import numpy as np
 from exo.networking import Discovery, PeerHandle, Server
 from exo.inference.inference_engine import InferenceEngine, Shard
@@ -43,7 +43,12 @@ class StandardNode(Node):
         if request_id not in self.buffered_token_output:
             self.buffered_token_output[request_id] = ([], False)
 
-        if DEBUG >= 2: print(f"[{request_id}] process prompt: {shard}, {prompt}")
+        if DEBUG >= 2: print(f"[{request_id}] process prompt: {shard=} {prompt=}")
+        if self.get_current_shard(shard).start_layer != 0:
+            if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {shard=} {prompt=}")
+            await self.forward_to_next_shard(shard, prompt, request_id)
+            return
+
         result, is_finished = await self.inference_engine.infer_prompt(self.get_current_shard(shard), prompt)
         is_finished = is_finished or len(self.buffered_token_output[request_id]) >= self.max_generate_tokens
         if is_finished:
@@ -56,7 +61,7 @@ class StandardNode(Node):
         if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")
 
         if not is_finished:
-            asyncio.create_task(self.forward_tensor_to_next_shard(shard, result, request_id))
+            asyncio.create_task(self.forward_to_next_shard(shard, result, request_id))
 
         return np.array(self.buffered_token_output[request_id]) if len(self.buffered_token_output[request_id]) > 0 else None
 
@@ -79,7 +84,7 @@ class StandardNode(Node):
             if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")
 
             if not is_finished:
-                asyncio.create_task(self.forward_tensor_to_next_shard(shard, result, request_id))
+                asyncio.create_task(self.forward_to_next_shard(shard, result, request_id))
 
             return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
         except Exception as e:
@@ -88,7 +93,7 @@ class StandardNode(Node):
             traceback.print_exc()
             return None
 
-    async def forward_tensor_to_next_shard(self, shard: Shard, tensor: np.ndarray, request_id: str) -> None:
+    async def forward_to_next_shard(self, shard: Shard, tensor_or_prompt: Union[np.ndarray, str], request_id: str) -> None:
         if not self.partitioning_strategy:
             if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
             return
@@ -103,7 +108,10 @@ class StandardNode(Node):
 
             if next_partition:
                 if next_partition.node_id == self.id:
-                    await self.process_tensor(shard, tensor, request_id)
+                    if isinstance(tensor_or_prompt, np.ndarray):
+                        await self.process_tensor(shard, tensor_or_prompt, request_id)
+                    else:
+                        await self.process_prompt(shard, tensor_or_prompt, request_id)
                     return
 
                 target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
@@ -114,9 +122,12 @@ class StandardNode(Node):
                 end_layer = int(next_partition.end * shard.n_layers) - 1
                 next_shard = Shard(shard.model_id, start_layer, end_layer, shard.n_layers)
 
-                if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor.size=} {tensor.shape=}")
+                if DEBUG >= 1: print(f"Sending tensor_or_prompt to {target_peer.id()}: {tensor_or_prompt}")
 
-                await target_peer.send_tensor(next_shard, tensor, request_id)
+                if isinstance(tensor_or_prompt, np.ndarray):
+                    await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id)
+                else:
+                    await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id)
 
     def get_current_shard(self, shard: Shard) -> Shard:
         partitions = self.partitioning_strategy.partition(self.topology)
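The head check in process_prompt relies on get_current_shard(shard).start_layer: only the node whose partition begins at 0.0 gets a shard starting at layer 0, so every other node forwards the raw prompt instead of running inference on it. A small self-contained sketch of that shard arithmetic, using simplified Partition and Shard stand-ins rather than the real exo classes and a hypothetical 32-layer model:

from dataclasses import dataclass

@dataclass
class Partition:
    node_id: str
    start: float  # fraction of the layer range, e.g. 0.0
    end: float    # e.g. 0.5

@dataclass
class Shard:
    model_id: str
    start_layer: int
    end_layer: int
    n_layers: int

def shard_for(p: Partition, model_id: str, n_layers: int) -> Shard:
    # Same arithmetic as in forward_to_next_shard above.
    start_layer = int(p.start * n_layers)
    end_layer = int(p.end * n_layers) - 1
    return Shard(model_id, start_layer, end_layer, n_layers)

# Two equal nodes splitting a hypothetical 32-layer model:
head = shard_for(Partition("node2", 0.0, 0.5), "llama-3-8b", 32)
tail = shard_for(Partition("node1", 0.5, 1.0), "llama-3-8b", 32)
assert head.start_layer == 0   # the head processes the prompt locally
assert tail.start_layer != 0   # a non-head node forwards the prompt onward

Either node can receive the prompt first; a prompt sent to a non-head node keeps moving along the partition ring until it reaches the shard that starts at layer 0, which is what makes the nodes interchangeable.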