@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, Callable
 import numpy as np
 from networking import Discovery, PeerHandle, Server
 from inference.inference_engine import InferenceEngine, Shard
@@ -9,7 +9,7 @@ from topology.partitioning_strategy import PartitioningStrategy
 from topology.partitioning_strategy import Partition

 class StandardNode(Node):
-  def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None):
+  def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, on_token: Callable[[List[int]], None] = None, max_generate_tokens: int = 50):
     self.id = id
     self.inference_engine = inference_engine
     self.server = server
@@ -18,6 +18,9 @@ class StandardNode(Node):
     self.peers: List[PeerHandle] = {}
     self.topology: Topology = Topology()
     self.device_capabilities = device_capabilities()
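+    # New generation state: buffered_token_output collects the token ids sampled so far,
+    # on_token (optional) is called with the full buffer whenever a new token is produced,
+    # and max_generate_tokens caps how many tokens a single request may generate.
+    # e.g. on_token=lambda tokens: print(tokens)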
+    self.buffered_token_output: List[int] = []
+    self.on_token = on_token
+    self.max_generate_tokens = max_generate_tokens

   async def start(self, wait_for_peers: int = 0) -> None:
     await self.server.start()
@@ -35,23 +38,32 @@ class StandardNode(Node):
     await self.discovery.stop()
     await self.server.stop()

-  async def process_prompt(self, shard: Shard, prompt: str) -> Optional[np.array]:
-    print("Process prompt", shard, prompt)
-    result = await self.inference_engine.infer_prompt(shard, prompt)
-    print(f"Got result from prompt: {prompt}. Result: {result}")
+  async def process_prompt(self, shard: Shard, prompt: str) -> Optional[np.ndarray]:
+    print("process prompt", shard, prompt)
+    result, is_finished = await self.inference_engine.infer_prompt(self.get_current_shard(shard), prompt)

-    await self.forward_tensor_to_next_shard(shard, result)
+    print(f"result size: {result.size}, is finished: {is_finished}")
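+    # A single-element result is a newly sampled token id: buffer it and notify the callback.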
+    if result.size == 1:
+      self.buffered_token_output.append(result.item())
+      if self.on_token: self.on_token(self.buffered_token_output)

-    return result
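+    # Keep forwarding the activations to the next shard until the model signals completion
+    # or this request reaches its max_generate_tokens budget.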
+    if not is_finished and len(self.buffered_token_output) < self.max_generate_tokens:
+      await self.forward_tensor_to_next_shard(shard, result)

-  async def process_tensor(self, shard: Shard, tensor: np.ndarray) -> None:
-    print("Process tensor", shard, tensor)
-    result = await self.inference_engine.infer_tensor(shard, tensor)
-    print(f"Got result from tensor: {len(tensor)}. Result: {result}")
+    return np.array(self.buffered_token_output) if self.buffered_token_output else None

-    await self.forward_tensor_to_next_shard(shard, result)
+  async def process_tensor(self, shard: Shard, tensor: np.ndarray) -> Optional[np.ndarray]:
+    result, is_finished = await self.inference_engine.infer_tensor(self.get_current_shard(shard), tensor)

-    return result
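+    # Same buffering and forwarding logic as process_prompt, but starting from an incoming tensor.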
+    print(f"result size: {result.size}, is finished: {is_finished}")
+    if result.size == 1:
+      self.buffered_token_output.append(result.item())
+      if self.on_token: self.on_token(self.buffered_token_output)
+
+    if not is_finished and len(self.buffered_token_output) < self.max_generate_tokens:
+      await self.forward_tensor_to_next_shard(shard, result)
+
+    return np.array(self.buffered_token_output) if self.buffered_token_output else None

   async def forward_tensor_to_next_shard(self, shard: Shard, tensor: np.ndarray) -> None:
     if not self.partitioning_strategy:
@@ -67,6 +79,10 @@ class StandardNode(Node):
     print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}")

     if next_partition:
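+      # If the next partition is assigned to this same node, process it locally
+      # instead of sending the tensor over the network.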
+      if next_partition.node_id == self.id:
+        await self.process_tensor(shard, tensor)
+        return
+
       target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
       if not target_peer:
         raise ValueError(f"Peer for {next_partition} not found")
@@ -79,10 +95,23 @@ class StandardNode(Node):

       await target_peer.send_tensor(next_shard, tensor)

+  def get_current_shard(self, shard: Shard) -> Shard:
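+    # Map this node's fractional partition [start, end) onto concrete, inclusive layer indices.
+    # For example, with n_layers=32 a partition of 0.5-1.0 maps to layers 16 through 31.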
+    partitions = self.partitioning_strategy.partition(self.topology)
+    current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
+    if current_partition_index is None:
+      raise ValueError(f"No current partition found for node: {self.id}")
+
+    current_partition = partitions[current_partition_index]
+    start_layer = int(current_partition.start * shard.n_layers)
+    end_layer = int(current_partition.end * shard.n_layers) - 1
+    return Shard(shard.model_id, start_layer, end_layer, shard.n_layers)
+
+
   async def reset_shard(self, shard: Shard) -> None:
     # Implement shard reset logic
     print(f"Resetting shard: {shard}")
-    await self.inference_engine.reset_shard(shard)
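+    # Drop any tokens buffered from the previous generation before resetting the engine.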
+    self.buffered_token_output = []
+    await self.inference_engine.reset_shard(self.get_current_shard(shard))

   async def collect_topology(self, max_depth: int = 4) -> Topology:
     self.topology.update_node(self.id, self.device_capabilities)