@@ -1,4 +1,4 @@
-from typing import List, Optional, Callable
+from typing import List, Dict, Optional, Callable, Tuple
 import numpy as np
 from networking import Discovery, PeerHandle, Server
 from inference.inference_engine import InferenceEngine, Shard
@@ -7,6 +7,8 @@ from topology.topology import Topology
 from topology.device_capabilities import device_capabilities
 from topology.partitioning_strategy import PartitioningStrategy
 from topology.partitioning_strategy import Partition
+import asyncio
+import uuid
 
 class StandardNode(Node):
   def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, on_token: Callable[[List[int]], None] = None, max_generate_tokens: int = 50):
@@ -18,54 +20,70 @@ class StandardNode(Node):
     self.peers: List[PeerHandle] = {}
     self.topology: Topology = Topology()
     self.device_capabilities = device_capabilities()
-    self.buffered_token_output: List[int] = []
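+    # Buffer token output per request: request_id -> (list of generated token ids, is_finished flag).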
+    self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
     self.on_token = on_token
     self.max_generate_tokens = max_generate_tokens
 
   async def start(self, wait_for_peers: int = 0) -> None:
     await self.server.start()
     await self.discovery.start()
-    self.peers = await self.discovery.discover_peers(wait_for_peers)
-    print(f"Starting with the following peers: {self.peers}")
-    print("Connecting to peers...")
-    for peer in self.peers:
-      await peer.connect()
-      print(f"Connected to {peer.id()}")
+    await self.update_peers(wait_for_peers)
     await self.collect_topology()
     print(f"Collected topology: {self.topology}")
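+    # Re-discover peers and refresh the topology in the background every 5 seconds.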
+    asyncio.create_task(self.periodic_topology_collection(5))
 
   async def stop(self) -> None:
     await self.discovery.stop()
     await self.server.stop()
 
-  async def process_prompt(self, shard: Shard, prompt: str) -> Optional[np.ndarray]:
-    print("process prompt", shard, prompt)
+  async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
+    if request_id is None:
+      request_id = str(uuid.uuid4())
+    if request_id not in self.buffered_token_output:
+      self.buffered_token_output[request_id] = ([], False)
+
+    print(f"[{request_id}] process prompt: {shard}, {prompt}")
     result, is_finished = await self.inference_engine.infer_prompt(self.get_current_shard(shard), prompt)
+    self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], is_finished)
 
-    print(f"result size: {result.size}, is finished: {is_finished}")
     if result.size == 1:
-      self.buffered_token_output.append(result.item())
-      self.on_token(self.buffered_token_output)
+      self.buffered_token_output[request_id][0].append(result.item())
+      self.on_token(self.buffered_token_output[request_id][0])
 
-    if not is_finished and len(self.buffered_token_output) < self.max_generate_tokens:
-      await self.forward_tensor_to_next_shard(shard, result)
+    print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
 
-    return np.array(self.buffered_token_output) if self.buffered_token_output else None
+    if not is_finished and len(self.buffered_token_output[request_id][0]) < self.max_generate_tokens:
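+      # Forward to the next shard in the background so this call can return the tokens buffered so far.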
+      asyncio.create_task(self.forward_tensor_to_next_shard(shard, result, request_id))
 
-  async def process_tensor(self, shard: Shard, tensor: np.ndarray) -> Optional[np.ndarray]:
-    result, is_finished = await self.inference_engine.infer_tensor(self.get_current_shard(shard), tensor)
+    return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
 
-    print(f"result size: {result.size}, is finished: {is_finished}")
-    if result.size == 1:
-      self.buffered_token_output.append(result.item())
-      self.on_token(self.buffered_token_output)
+  async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.ndarray]:
+    if request_id is None:
+      request_id = str(uuid.uuid4())
+    if request_id not in self.buffered_token_output:
+      self.buffered_token_output[request_id] = ([], False)
+
+    try:
+      print(f"[{request_id}] process_tensor: {shard}, {tensor}")
+      result, is_finished = await self.inference_engine.infer_tensor(self.get_current_shard(shard), tensor)
+      self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], is_finished)
 
-    if not is_finished and len(self.buffered_token_output) < self.max_generate_tokens:
-      await self.forward_tensor_to_next_shard(shard, result)
+      if result.size == 1:  # we got a new token out
+        self.buffered_token_output[request_id][0].append(result.item())
+        self.on_token(self.buffered_token_output[request_id][0])
+      print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
 
-    return np.array(self.buffered_token_output) if self.buffered_token_output else None
+      if not is_finished and len(self.buffered_token_output[request_id][0]) < self.max_generate_tokens:
+        asyncio.create_task(self.forward_tensor_to_next_shard(shard, result, request_id))
 
-  async def forward_tensor_to_next_shard(self, shard: Shard, tensor: np.ndarray) -> None:
+      return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+    except Exception as e:
+      import traceback
+      print(f"Error processing tensor for shard {shard}: {e}")
+      traceback.print_exc()
+      return None
+
+  async def forward_tensor_to_next_shard(self, shard: Shard, tensor: np.ndarray, request_id: str) -> None:
     if not self.partitioning_strategy:
       print("No partitioning strategy found. Skipping forward.")
       return
@@ -80,7 +98,7 @@ class StandardNode(Node):
 
     if next_partition:
       if next_partition.node_id == self.id:
-        await self.process_tensor(shard, tensor)
+        await self.process_tensor(shard, tensor, request_id)
         return
 
       target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
@@ -91,9 +109,9 @@ class StandardNode(Node):
       end_layer = int(next_partition.end * shard.n_layers) - 1
       next_shard = Shard(shard.model_id, start_layer, end_layer, shard.n_layers)
 
-      print(f"Sending tensor to {target_peer.id()} for shard: {next_shard}")
+      print(f"Sending tensor to {target_peer.id()} for shard: {next_shard}: {tensor}")
 
-      await target_peer.send_tensor(next_shard, tensor)
+      await target_peer.send_tensor(next_shard, tensor, request_id)
 
   def get_current_shard(self, shard: Shard) -> Shard:
     partitions = self.partitioning_strategy.partition(self.topology)
@@ -110,9 +128,20 @@ class StandardNode(Node):
   async def reset_shard(self, shard: Shard) -> None:
     # Implement shard reset logic
     print(f"Resetting shard: {shard}")
-    self.buffered_token_output = []
+    self.buffered_token_output = {}
     await self.inference_engine.reset_shard(self.get_current_shard(shard))
 
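+  # Discover peers and connect to any that are not yet connected.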
+  async def update_peers(self, wait_for_peers: int = 0) -> None:
+    self.peers = await self.discovery.discover_peers(wait_for_peers)
+    print(f"Starting with the following peers: {self.peers}")
+    print("Connecting to new peers...")
+    for peer in self.peers:
+      is_connected = await peer.is_connected()
+      print(f"Connected to {peer.id()}: {is_connected}")
+      if not is_connected:
+        await peer.connect()
+        print(f"Connected to peer {peer.id()}")
+
   async def collect_topology(self, max_depth: int = 4) -> Topology:
     self.topology.update_node(self.id, self.device_capabilities)
 
@@ -121,8 +150,28 @@ class StandardNode(Node):
       self.topology.add_edge(self.id, peer.id())
 
       if max_depth > 0:
-        other_topology = await peer.collect_topology(max_depth = max_depth - 1)
-        print(f"Collected topology from: {peer.id()}: {other_topology}")
-        self.topology.merge(other_topology)
+        try:
+          other_topology = await peer.collect_topology(max_depth = max_depth - 1)
+          print(f"Collected topology from: {peer.id()}: {other_topology}")
+          self.topology.merge(other_topology)
+        except Exception as e:
+          print(f"Error collecting topology from {peer.id()}: {e}")
 
     return self.topology
+
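+  # Background loop: every interval seconds, re-discover peers and re-collect the topology.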
+  async def periodic_topology_collection(self, interval: int):
+    while True:
+      await asyncio.sleep(interval)
+      try:
+        await self.update_peers()
+        await self.collect_topology()
+      except Exception as e:
+        print(f"Error collecting topology: {e}")
+
+      print("Topology collection task executed.")
+      print(f"Current topology: {self.topology}")
+
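+  # Return the tokens buffered so far for request_id and whether generation has finished.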
+  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+    if request_id not in self.buffered_token_output:
+      return None, False
+    return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
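
For context, a minimal sketch of how a caller might drive the new request-id flow (not part of the patch itself). The names node, shard, generate and the 0.5 s polling interval are placeholders for illustration; only process_prompt and get_inference_result come from the change above.

  import asyncio
  import uuid

  async def generate(node, shard, prompt):
    # Tag the request so the node buffers its tokens separately from other requests.
    request_id = str(uuid.uuid4())
    await node.process_prompt(shard, prompt, request_id)
    # process_prompt forwards to the next shard in the background, so poll for completion.
    tokens = None
    for _ in range(120):
      tokens, is_finished = await node.get_inference_result(request_id)
      if is_finished:
        break
      await asyncio.sleep(0.5)
    return tokens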