@@ -39,6 +39,8 @@ class StandardNode(Node):
self.topology: Topology = Topology()
self.device_capabilities = device_capabilities()
self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
+ self.buffered_logits: Dict[str, List[np.ndarray]] = {}
+ self.buffered_inputs: Dict[str, List[np.ndarray]] = {}
self.max_generate_tokens = max_generate_tokens
self.topology_viz = topology_viz
self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
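The two added dicts follow the same convention as the existing buffered_token_output: every buffer is keyed by request id, so several generations can be in flight on one node at once. A minimal standalone sketch of that per-request bookkeeping (illustrative only, not part of the diff; RequestBuffers is a hypothetical name):

import numpy as np
from typing import Dict, List, Tuple

class RequestBuffers:
  # per-request state, keyed by request id, mirroring StandardNode.__init__
  def __init__(self) -> None:
    self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
    self.buffered_logits: Dict[str, List[np.ndarray]] = {}
    self.buffered_inputs: Dict[str, List[np.ndarray]] = {}

  def append_token(self, request_id: str, token: int, logits: np.ndarray) -> None:
    # create the buffers lazily the first time a request id is seen
    tokens, _ = self.buffered_token_output.setdefault(request_id, ([], False))
    tokens.append(token)
    self.buffered_logits.setdefault(request_id, []).append(logits)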
@@ -87,24 +89,53 @@ class StandardNode(Node):
def get_supported_inference_engines(self):
supported_engine_names = []
if self.inference_engine.__class__.__name__ == 'MLXDynamicShardInferenceEngine':
- supported_engine_names.append('mlx')
- supported_engine_names.append('tinygrad')
+ supported_engine_names.append('mlx')
+ supported_engine_names.append('tinygrad')
else:
- supported_engine_names.append('tinygrad')
+ supported_engine_names.append('tinygrad')
return supported_engine_names

async def broadcast_supported_engines(self, supported_engines_names: List[str]):
- status_message = json.dumps({
- "type": "supported_inference_engines",
- "node_id": self.id,
- "engines": supported_engines_names
- })
+ status_message = json.dumps({"type": "supported_inference_engines", "node_id": self.id, "engines": supported_engines_names})
await self.broadcast_opaque_status("", status_message)

def get_topology_inference_engines(self) -> List[List[str]]:
return self.topology_inference_engines_pool
+
+ async def process_inference_result(
+ self,
+ shard,
+ result: np.ndarray,
+ request_id: Optional[str] = None,
+ ):
+ if request_id not in self.buffered_token_output:
+ self.buffered_token_output[request_id] = ([], False)
+ is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+ if shard.is_last_layer() and not is_finished:
+ token = await self.inference_engine.sample(result)
+ await self.inference_engine.ensure_shard(shard)
+ self.buffered_token_output[request_id][0].append(token.item())
+ if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
+ is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id
+ forward = token.reshape(1, -1)
+ self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
+ asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))
+ else:
+ forward = result

- async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+ if is_finished:
+ self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
+ else:
+ asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))
+
+ return np.array(self.buffered_token_output[request_id][0])
+
+ async def process_prompt(
+ self,
+ base_shard: Shard,
+ prompt: str,
+ request_id: Optional[str] = None,
+ ) -> Optional[np.ndarray]:
shard = self.get_current_shard(base_shard)
asyncio.create_task(
self.broadcast_opaque_status(
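The process_inference_result method added above centralizes logic that was previously duplicated in _process_prompt and _process_tensor: on the shard holding the last layer the node samples a token, buffers it, fires the on_token callbacks and broadcasts the result; on any earlier shard it simply passes the raw activations along. A condensed sketch of that branch (illustrative only; it reuses the names from the diff on a hypothetical node object and omits the DEBUG prints and the ensure_shard call):

import numpy as np

async def handle_result(node, shard, result: np.ndarray, request_id: str) -> np.ndarray:
  tokens, _ = node.buffered_token_output.setdefault(request_id, ([], False))
  done = len(tokens) >= node.max_generate_tokens
  if shard.is_last_layer() and not done:
    token = await node.inference_engine.sample(result)    # logits -> next token id
    tokens.append(token.item())
    done = token.item() == node.inference_engine.tokenizer.eos_token_id
    forward = token.reshape(1, -1)                         # feed the sampled token back in
  else:
    forward = result                                       # mid-pipeline: pass activations through
  if not done:
    await node.forward_tensor(shard, forward, request_id, node.get_partition_index(offset=1))
  return np.array(tokens)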
@@ -116,14 +147,12 @@ class StandardNode(Node):
"base_shard": base_shard.to_dict(),
"shard": shard.to_dict(),
"prompt": prompt,
- "image_str": image_str,
- "inference_state": inference_state,
"request_id": request_id,
}),
)
)
start_time = time.perf_counter_ns()
- resp = await self._process_prompt(base_shard, prompt, image_str, request_id, inference_state)
+ resp = await self._process_prompt(base_shard, prompt, request_id)
end_time = time.perf_counter_ns()
elapsed_time_ns = end_time - start_time
asyncio.create_task(
@@ -136,8 +165,6 @@ class StandardNode(Node):
"base_shard": base_shard.to_dict(),
"shard": shard.to_dict(),
"prompt": prompt,
- "image_str": image_str,
- "inference_state": inference_state,
"request_id": request_id,
"elapsed_time_ns": elapsed_time_ns,
"result_size": resp.size if resp is not None else 0,
@@ -146,42 +173,26 @@ class StandardNode(Node):
)
return resp

- async def _process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+ async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
if request_id is None:
request_id = str(uuid.uuid4())
- if request_id not in self.buffered_token_output:
- self.buffered_token_output[request_id] = ([], False)
shard = self.get_current_shard(base_shard)

- if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {image_str=}")
- if shard.start_layer != 0:
- if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=} {image_str=}")
- await self.forward_to_next_shard(shard, prompt, request_id, image_str=image_str, inference_state=inference_state)
- return
-
- result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, image_str, inference_state=inference_state)
- is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
- if is_finished:
- self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
- asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished)) # TODO: this is n^2 communication complexity
-
- if result.size == 1:
- self.buffered_token_output[request_id][0].append(result.item())
- self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-
- if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-
- if not is_finished:
- asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, image_str=image_str, inference_state=inference_state))
-
- return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+ if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
+ if not shard.is_first_layer():
+ if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
+ resp = await self.forward_prompt(shard, prompt, request_id, 0)
+ return None
+ else:
+ result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
+ ret = await self.process_inference_result(shard, result, request_id)
+ return result

async def process_tensor(
self,
base_shard: Shard,
tensor: np.ndarray,
request_id: Optional[str] = None,
- inference_state: Optional[str] = None,
) -> Optional[np.ndarray]:
shard = self.get_current_shard(base_shard)
asyncio.create_task(
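The reworked _process_prompt makes the prompt routing explicit: only the node that owns the first layer runs infer_prompt; every other node hands the prompt to partition index 0 via forward_prompt and returns immediately. A minimal sketch of that entry point (illustrative only, using the method names from the diff on a hypothetical node object):

async def submit_prompt(node, base_shard, prompt: str, request_id: str):
  shard = node.get_current_shard(base_shard)
  if not shard.is_first_layer():
    # this node does not own layer 0: route the raw prompt to the node that does
    await node.forward_prompt(shard, prompt, request_id, 0)
    return None
  # this node owns layer 0: run the prompt locally, then let
  # process_inference_result sample or forward as appropriate
  result = await node.inference_engine.infer_prompt(request_id, shard, prompt)
  return await node.process_inference_result(shard, result, request_id)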
@@ -196,12 +207,11 @@ class StandardNode(Node):
"tensor_size": tensor.size,
"tensor_shape": tensor.shape,
"request_id": request_id,
- "inference_state": inference_state,
}),
)
)
start_time = time.perf_counter_ns()
- resp = await self._process_tensor(shard, tensor, request_id, inference_state)
+ resp = await self._process_tensor(shard, tensor, request_id)
end_time = time.perf_counter_ns()
elapsed_time_ns = end_time - start_time
asyncio.create_task(
@@ -226,84 +236,77 @@ class StandardNode(Node):
base_shard: Shard,
tensor: np.ndarray,
request_id: Optional[str] = None,
- inference_state: Optional[str] = None,
) -> Optional[np.ndarray]:
if request_id is None:
request_id = str(uuid.uuid4())
- if request_id not in self.buffered_token_output:
- self.buffered_token_output[request_id] = ([], False)
shard = self.get_current_shard(base_shard)

+ if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
try:
- if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
- result, inference_state, is_finished = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state=inference_state)
- is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
- if is_finished:
- self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
- asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished)) # TODO: this is n^2 communication complexity
-
- if result.size == 1: # we got a new token out
- self.buffered_token_output[request_id][0].append(result.item())
- self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
- if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-
- if not is_finished:
- asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
-
- return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+ result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
+ ret = await self.process_inference_result(shard, result, request_id)
+ return ret
except Exception as e:
print(f"Error processing tensor for shard {shard}: {e}")
traceback.print_exc()
return None

- async def forward_to_next_shard(
+ async def forward_prompt(
self,
base_shard: Shard,
- tensor_or_prompt: Union[np.ndarray, str],
+ prompt: str,
request_id: str,
- image_str: Optional[str] = None,
- inference_state: Optional[str] = None,
+ target_index: int,
) -> None:
+ if DEBUG >= 1: print(f"target partition index: {target_index}")
+ target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+ next_shard = self.get_current_shard(base_shard, target_index)
+ if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
+ if target_id == self.id:
+ await self.process_prompt(next_shard, prompt, request_id)
+ else:
+ target_peer = next((p for p in self.peers if p.id() == target_id), None)
+ if not target_peer:
+ raise ValueError(f"Peer for {target_index} not found")
+ if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
+ await target_peer.send_prompt(next_shard, prompt, request_id=request_id)
+
+ async def forward_tensor(
+ self,
+ base_shard: Shard,
+ tensor: np.ndarray,
+ request_id: str,
+ target_index: int,
+ ) -> None:
+ if DEBUG >= 1: print(f"target partition index: {target_index}")
+ target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+ next_shard = self.get_current_shard(base_shard, target_index)
+ if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {next_shard}")
+ if target_id == self.id:
+ await self.process_tensor(next_shard, tensor, request_id)
+ else:
+ target_peer = next((p for p in self.peers if p.id() == target_id), None)
+ if not target_peer:
+ raise ValueError(f"Peer for {target_index} not found")
+ if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
+ await target_peer.send_tensor(next_shard, tensor, request_id=request_id)
+
+ def get_partition_index(self, offset: int = 0):
if not self.partitioning_strategy:
if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
- return
- shard = self.get_current_shard(base_shard)
-
+ return None
partitions = self.partitioning_strategy.partition(self.topology)
- shards = map_partitions_to_shards(self.partitioning_strategy.partition(self.topology), base_shard.n_layers, base_shard.model_id)
current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
- if DEBUG >= 1: print(f"Current partition index: {current_partition_index}")
- if current_partition_index is not None:
- next_partition_index = (current_partition_index+1) % len(partitions)
- next_partition: Partition = partitions[next_partition_index]
- next_shard = shards[next_partition_index]
- if DEBUG >= 2: print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}")
-
- if next_partition.node_id == self.id:
- if isinstance(tensor_or_prompt, np.ndarray):
- await self.process_tensor(shard, tensor_or_prompt, request_id, inference_state=inference_state)
- else:
- await self.process_prompt(shard, tensor_or_prompt, image_str, request_id, inference_state=inference_state)
- return
-
- target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
- if not target_peer:
- raise ValueError(f"Peer for {next_partition} not found")
-
- if DEBUG >= 1: print(f"Sending tensor_or_prompt to {target_peer.id()}: {tensor_or_prompt}")
-
- if isinstance(tensor_or_prompt, np.ndarray):
- await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
- else:
- await target_peer.send_prompt(next_shard, tensor_or_prompt, image_str=image_str, request_id=request_id, inference_state=inference_state)
+ if current_partition_index is None:
+ raise ValueError(f"No current partition found for node: {self.id}")
+ return (current_partition_index + offset) % len(partitions)

- def get_current_shard(self, base_shard: Shard) -> Shard:
+ def get_current_shard(self, base_shard: Shard, index: Optional[int] = None) -> Shard:
+ if index is None:
+ index = self.get_partition_index()
partitions = self.partitioning_strategy.partition(self.topology)
shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
- current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
- if current_partition_index is None:
- raise ValueError(f"No current partition found for node: {self.id}")
- return shards[current_partition_index]
+ return shards[index]

async def update_peers(self, wait_for_peers: int = 0) -> bool:
next_peers = await self.discovery.discover_peers(wait_for_peers)
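get_partition_index reduces the old ad-hoc next-node lookup to modular arithmetic over the ordered partition list, and forward_prompt/forward_tensor then resolve that index either to a remote peer or back to the local node. A minimal sketch of the index arithmetic (illustrative only; Partition here is a stand-in dataclass, not the project's type):

from dataclasses import dataclass
from typing import List

@dataclass
class Partition:
  node_id: str  # stand-in for the real Partition

def partition_index(partitions: List[Partition], self_id: str, offset: int = 0) -> int:
  # find our own slot in the ring, then step `offset` partitions forward, wrapping around
  current = next((i for i, p in enumerate(partitions) if p.node_id == self_id), None)
  if current is None:
    raise ValueError(f"No current partition found for node: {self_id}")
  return (current + offset) % len(partitions)

# with three nodes, "b" forwarding one hop lands on "c", and "c" wraps back to "a"
ring = [Partition("a"), Partition("b"), Partition("c")]
assert partition_index(ring, "b", offset=1) == 2
assert partition_index(ring, "c", offset=1) == 0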
@@ -311,20 +314,16 @@ class StandardNode(Node):
next_peer_ids = {peer.id() for peer in next_peers}
peers_added = [peer for peer in next_peers if peer.id() not in current_peer_ids]
peers_removed = [peer for peer in self.peers if peer.id() not in next_peer_ids]
- peers_updated = [
- peer for peer in next_peers
- if peer.id() in current_peer_ids and any(p.addr() != peer.addr() for p in self.peers if p.id() == peer.id())
- ]
- peers_unchanged = [
- peer for peer in next_peers
- if peer.id() in current_peer_ids and all(p.addr() == peer.addr() for p in self.peers if p.id() == peer.id())
- ]
+ peers_updated = [peer for peer in next_peers if peer.id() in current_peer_ids and any(p.addr() != peer.addr() for p in self.peers if p.id() == peer.id())]
+ peers_unchanged = [peer for peer in next_peers if peer.id() in current_peer_ids and all(p.addr() == peer.addr() for p in self.peers if p.id() == peer.id())]
peers_to_disconnect = [peer for peer in peers_removed if await peer.is_connected()]
peers_to_connect = [peer for peer in peers_added + peers_updated + peers_unchanged if not await peer.is_connected()]

def _pretty(peers: List[PeerHandle]) -> List[str]:
return [f"{peer.id()}@{peer.addr()}" for peer in peers]
- if DEBUG >= 2: print(f"update_peers: added={peers_added} removed={peers_removed} updated={peers_updated} unchanged={peers_unchanged} to_disconnect={peers_to_disconnect} to_connect={peers_to_connect}")
+
+ if DEBUG >= 2:
+ print(f"update_peers: added={peers_added} removed={peers_removed} updated={peers_updated} unchanged={peers_unchanged} to_disconnect={peers_to_disconnect} to_connect={peers_to_connect}")

async def disconnect_with_timeout(peer, timeout=5):
try:
@@ -344,14 +343,8 @@ class StandardNode(Node):
traceback.print_exc()
return False

- disconnect_results = await asyncio.gather(
- *(disconnect_with_timeout(peer) for peer in peers_to_disconnect),
- return_exceptions=True
- )
- connect_results = await asyncio.gather(
- *(connect_with_timeout(peer) for peer in peers_to_connect),
- return_exceptions=True
- )
+ disconnect_results = await asyncio.gather(*(disconnect_with_timeout(peer) for peer in peers_to_disconnect), return_exceptions=True)
+ connect_results = await asyncio.gather(*(connect_with_timeout(peer) for peer in peers_to_connect), return_exceptions=True)

successful_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is True]
failed_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is False]
@@ -370,12 +363,7 @@ class StandardNode(Node):
supported_engines = self.get_supported_inference_engines()
await self.broadcast_supported_engines(supported_engines)
if len(self.get_topology_inference_engines()):
- if any(len(engines) == 1 and "tinygrad" in engines for engines in self.get_topology_inference_engines()):
- if DEBUG >= 1: print("Found node with only tinygrad, using tinygrad on all nodes")
- self.inference_engine = get_inference_engine("tinygrad", self.shard_downloader)
- else:
- if DEBUG >= 1: print("All nodes can use mlx, using mlx for inference")
- self.inference_engine = get_inference_engine("mlx", self.shard_downloader)
+ self.inference_engine = get_inference_engine(supported_engines[0], self.shard_downloader)

async def periodic_topology_collection(self, interval: int):
while True:
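The branch removed in the hunk above forced every node onto tinygrad as soon as any peer reported only tinygrad; the replacement line simply instantiates this node's own first supported engine. For comparison, a lowest-common-denominator pick over the broadcast engine pools could look like this (a hypothetical helper, not what the diff implements):

from typing import List, Optional, Sequence

def pick_common_engine(pools: List[List[str]], preference: Sequence[str] = ("mlx", "tinygrad")) -> Optional[str]:
  # intersect the engines every node reported, then take the most preferred survivor
  common = set(preference)
  for engines in pools:
    common &= set(engines)
  return next((name for name in preference if name in common), None)

# one tinygrad-only node drags the whole topology down to tinygrad
assert pick_common_engine([["mlx", "tinygrad"], ["tinygrad"]]) == "tinygrad"
assert pick_common_engine([["mlx", "tinygrad"], ["mlx", "tinygrad"]]) == "mlx"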
@@ -422,6 +410,7 @@ class StandardNode(Node):
self.topology.merge(other_topology)
except Exception as e:
print(f"Error collecting topology from {peer.id()}: {e}")
+ traceback.print_exc()

next_topology.active_node_id = self.topology.active_node_id # this is not so clean.
self.topology = next_topology
@@ -440,7 +429,7 @@ class StandardNode(Node):
def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
self.on_token.trigger_all(request_id, tokens, is_finished)
-
+
async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
async def send_result_to_peer(peer):
try:
@@ -464,6 +453,7 @@ class StandardNode(Node):
except Exception as e:
print(f"Error sending opaque status to {peer.id()}: {e}")
traceback.print_exc()
+
await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)
# in the case of opaque status, we also want to receive our own opaque statuses
self.on_opaque_status.trigger_all(request_id, status)