@@ -3,20 +3,19 @@ import json
 import asyncio
 import uuid
 import time
-from typing import List, Dict, Optional, Callable, Tuple, Union
+from typing import List, Dict, Optional, Tuple, Union
 from exo.networking import Discovery, PeerHandle, Server
 from exo.inference.inference_engine import InferenceEngine, Shard
 from .node import Node
 from exo.topology.topology import Topology
 from exo.topology.device_capabilities import device_capabilities
-from exo.topology.partitioning_strategy import PartitioningStrategy
-from exo.topology.partitioning_strategy import Partition
+from exo.topology.partitioning_strategy import Partition, PartitioningStrategy, map_partitions_to_shards
 from exo import DEBUG
-from exo.helpers import AsyncCallback, AsyncCallbackSystem
+from exo.helpers import AsyncCallbackSystem
 from exo.viz.topology_viz import TopologyViz

 class StandardNode(Node):
-  def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, max_generate_tokens: int = 256):
+  def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, max_generate_tokens: int = 256, chatgpt_api_endpoint: Optional[str] = None, web_chat_url: Optional[str] = None):
     self.id = id
     self.inference_engine = inference_engine
     self.server = server
@@ -26,7 +25,7 @@ class StandardNode(Node):
     self.topology: Topology = Topology()
     self.device_capabilities = device_capabilities()
     self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
-    self.topology_viz = TopologyViz()
+    self.topology_viz = TopologyViz(chatgpt_api_endpoint=chatgpt_api_endpoint, web_chat_url=web_chat_url)
     self.max_generate_tokens = max_generate_tokens
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
     self._on_opaque_status = AsyncCallbackSystem[str, str]()
@@ -57,28 +56,29 @@ class StandardNode(Node):
     await self.discovery.stop()
     await self.server.stop()

-  async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "start_process_prompt", "shard": shard.to_dict(), "prompt": prompt, "inference_state": inference_state, "request_id": request_id})))
     start_time = time.perf_counter_ns()
-    resp = await self._process_prompt(shard, prompt, request_id, inference_state)
+    resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
     asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "end_process_prompt", "shard": shard.to_dict(), "prompt": prompt, "inference_state": inference_state, "request_id": request_id, "elapsed_time_ns": elapsed_time_ns, "result_size": resp.size if resp is not None else 0})))
     return resp

-  async def _process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
     if request_id not in self.buffered_token_output:
       self.buffered_token_output[request_id] = ([], False)
+    shard = self.get_current_shard(base_shard)

-    if DEBUG >= 2: print(f"[{request_id}] process prompt: {shard=} {prompt=}")
-    if self.get_current_shard(shard).start_layer != 0:
-      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {shard=} {prompt=}")
+    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
+    if shard.start_layer != 0:
+      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
       await self.forward_to_next_shard(shard, prompt, request_id)
       return

-    result, inference_state, is_finished = await self.inference_engine.infer_prompt(self.get_current_shard(shard), prompt, inference_state=inference_state)
+    result, inference_state, is_finished = await self.inference_engine.infer_prompt(shard, prompt, inference_state=inference_state)
     is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
     if is_finished:
       self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
@@ -104,15 +104,16 @@ class StandardNode(Node):
     asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "end_process_tensor", "shard": shard.to_dict(), "request_id": request_id, "elapsed_time_ns": elapsed_time_ns, "result_size": resp.size if resp is not None else 0})))
     return resp

-  async def _process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def _process_tensor(self, base_shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
     if request_id not in self.buffered_token_output:
       self.buffered_token_output[request_id] = ([], False)
+    shard = self.get_current_shard(base_shard)

     try:
       if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
-      result, inference_state, is_finished = await self.inference_engine.infer_tensor(self.get_current_shard(shard), tensor, inference_state=inference_state)
+      result, inference_state, is_finished = await self.inference_engine.infer_tensor(shard, tensor, inference_state=inference_state)
       is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
       if is_finished:
         self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
@@ -169,23 +170,19 @@ class StandardNode(Node):
       else:
         await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id)

-  def get_current_shard(self, shard: Shard) -> Shard:
+  def get_current_shard(self, base_shard: Shard) -> Shard:
     partitions = self.partitioning_strategy.partition(self.topology)
+    shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
     current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
     if current_partition_index is None:
       raise ValueError(f"No current partition found for node: {self.id}")
+    return shards[current_partition_index]

-    current_partition = partitions[current_partition_index]
-    start_layer = int(current_partition.start * shard.n_layers)
-    end_layer = int(current_partition.end * shard.n_layers) - 1
-    return Shard(shard.model_id, start_layer, end_layer, shard.n_layers)
-
-
-  async def reset_shard(self, shard: Shard) -> None:
+  async def reset_shard(self, base_shard: Shard) -> None:
     # Implement shard reset logic
-    if DEBUG >= 2: print(f"Resetting shard: {shard}")
+    if DEBUG >= 2: print(f"Resetting shard: {base_shard}")
     self.buffered_token_output = {}
-    await self.inference_engine.reset_shard(self.get_current_shard(shard))
+    await self.inference_engine.reset_shard(self.get_current_shard(base_shard))

   async def update_peers(self, wait_for_peers: int = 0) -> None:
     self.peers = await self.discovery.discover_peers(wait_for_peers)
@@ -250,7 +247,8 @@ class StandardNode(Node):

   # TODO: unify this and collect_topology as global actions
   async def global_reset(self, base_shard: Shard, visited: set[str] = set(), max_depth: int = 2) -> None:
-    await self.reset_shard(self.get_current_shard(base_shard))
+    shard = self.get_current_shard(base_shard)
+    await self.reset_shard(shard)

     if DEBUG >= 2: print(f"Global reset {base_shard=} {max_depth=} {visited=}")
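For reference, below is a minimal sketch of the partition-to-shard mapping that the newly imported map_partitions_to_shards presumably performs, reconstructed from the inline arithmetic removed from get_current_shard above. The helper name and call signature come from this diff; its actual implementation in exo.topology.partitioning_strategy may differ, and the _sketch function name here is hypothetical.

from typing import List

from exo.inference.inference_engine import Shard
from exo.topology.partitioning_strategy import Partition


def map_partitions_to_shards_sketch(partitions: List[Partition], n_layers: int, model_id: str) -> List[Shard]:
  # Hypothetical sketch, not the library implementation: convert each
  # partition's fractional [start, end) range into absolute layer indices,
  # mirroring the arithmetic get_current_shard previously did inline.
  shards = []
  for partition in partitions:
    start_layer = int(partition.start * n_layers)
    end_layer = int(partition.end * n_layers) - 1
    shards.append(Shard(model_id, start_layer, end_layer, n_layers))
  return shards

With a mapping like this, get_current_shard reduces to finding the node's partition index and returning shards[current_partition_index], as the added lines in the diff show.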