@@ -1,4 +1,4 @@
-from typing import List, Dict, Optional, Callable, Tuple
+from typing import List, Dict, Optional, Callable, Tuple, Union
 import numpy as np
 from exo.networking import Discovery, PeerHandle, Server
 from exo.inference.inference_engine import InferenceEngine, Shard
@@ -43,7 +43,12 @@ class StandardNode(Node):
         if request_id not in self.buffered_token_output:
             self.buffered_token_output[request_id] = ([], False)

-        if DEBUG >= 2: print(f"[{request_id}] process prompt: {shard}, {prompt}")
+        if DEBUG >= 2: print(f"[{request_id}] process prompt: {shard=} {prompt=}")
+        if self.get_current_shard(shard).start_layer != 0:
+            if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {shard=} {prompt=}")
+            await self.forward_to_next_shard(shard, prompt, request_id)
+            return
+
         result, is_finished = await self.inference_engine.infer_prompt(self.get_current_shard(shard), prompt)
         is_finished = is_finished or len(self.buffered_token_output[request_id]) >= self.max_generate_tokens
         if is_finished:
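The hunk above adds an early-return path to process_prompt: only the node whose shard starts at layer 0 tokenizes and runs the raw prompt; any other node forwards the prompt onward untouched. A minimal sketch of that check, with a stand-in Shard and hypothetical values (the real Shard comes from exo.inference.inference_engine; the concrete numbers are made up for illustration):

    from dataclasses import dataclass

    @dataclass
    class Shard:  # stand-in for exo.inference.inference_engine.Shard
        model_id: str
        start_layer: int
        end_layer: int
        n_layers: int

    # hypothetical: this node holds the second half of a 32-layer model
    shard = Shard("llama-3-8b", start_layer=16, end_layer=31, n_layers=32)

    # mirrors the new check in process_prompt
    if shard.start_layer != 0:
        print("not the first shard: forward the prompt instead of processing it")
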
@@ -56,7 +61,7 @@ class StandardNode(Node):
         if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")

         if not is_finished:
-            asyncio.create_task(self.forward_tensor_to_next_shard(shard, result, request_id))
+            asyncio.create_task(self.forward_to_next_shard(shard, result, request_id))

         return np.array(self.buffered_token_output[request_id]) if len(self.buffered_token_output[request_id]) > 0 else None
@@ -79,7 +84,7 @@ class StandardNode(Node):
             if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")

             if not is_finished:
-                asyncio.create_task(self.forward_tensor_to_next_shard(shard, result, request_id))
+                asyncio.create_task(self.forward_to_next_shard(shard, result, request_id))

             return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
         except Exception as e:
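In both process_prompt and process_tensor, forwarding goes through asyncio.create_task rather than being awaited, so the hand-off to the next shard happens in the background while the method returns its buffered tokens immediately. A minimal sketch of that fire-and-forget pattern (names are illustrative, not from the patch):

    import asyncio

    async def forward(payload: str) -> None:
        await asyncio.sleep(0.1)  # stand-in for a network send to the next peer
        print(f"forwarded {payload!r}")

    async def main() -> None:
        # schedule the forward without blocking this coroutine
        task = asyncio.create_task(forward("activations"))
        print("returned local result immediately")
        # a long-lived node's event loop keeps running; this demo awaits the
        # task explicitly so it finishes before the loop shuts down
        await task

    asyncio.run(main())
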
@@ -88,7 +93,7 @@ class StandardNode(Node):
             traceback.print_exc()
             return None

-    async def forward_tensor_to_next_shard(self, shard: Shard, tensor: np.ndarray, request_id: str) -> None:
+    async def forward_to_next_shard(self, shard: Shard, tensor_or_prompt: Union[np.ndarray, str], request_id: str) -> None:
         if not self.partitioning_strategy:
             if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
             return
@@ -103,7 +108,10 @@ class StandardNode(Node):

         if next_partition:
             if next_partition.node_id == self.id:
-                await self.process_tensor(shard, tensor, request_id)
+                if isinstance(tensor_or_prompt, np.ndarray):
+                    await self.process_tensor(shard, tensor_or_prompt, request_id)
+                else:
+                    await self.process_prompt(shard, tensor_or_prompt, request_id)
                 return

         target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
@@ -114,9 +122,12 @@ class StandardNode(Node):
         end_layer = int(next_partition.end * shard.n_layers) - 1
         next_shard = Shard(shard.model_id, start_layer, end_layer, shard.n_layers)

-        if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor.size=} {tensor.shape=}")
+        if DEBUG >= 1: print(f"Sending tensor_or_prompt to {target_peer.id()}: {tensor_or_prompt}")

-        await target_peer.send_tensor(next_shard, tensor, request_id)
+        if isinstance(tensor_or_prompt, np.ndarray):
+            await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id)
+        else:
+            await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id)

     def get_current_shard(self, shard: Shard) -> Shard:
         partitions = self.partitioning_strategy.partition(self.topology)
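With these changes, forward_to_next_shard is the single egress point for both payload kinds: it dispatches on isinstance(tensor_or_prompt, np.ndarray), routing arrays through send_tensor/process_tensor and strings through send_prompt/process_prompt. The next peer's layer range is derived from the partition fractions, e.g. a partition ending at 1.0 over a 32-layer model gives end_layer = int(1.0 * 32) - 1 = 31. A minimal sketch of the same type dispatch in isolation (illustrative names, not part of the patch):

    import asyncio
    from typing import Union

    import numpy as np

    async def forward(payload: Union[np.ndarray, str]) -> None:
        # arrays are intermediate activations; strings are raw prompts
        if isinstance(payload, np.ndarray):
            print(f"would send_tensor: shape={payload.shape}")
        else:
            print(f"would send_prompt: {payload!r}")

    asyncio.run(forward(np.zeros((1, 4))))
    asyncio.run(forward("hello"))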