standard_node.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449
  1. import numpy as np
  2. import json
  3. import asyncio
  4. import uuid
  5. import time
  6. import traceback
  7. from typing import List, Dict, Optional, Tuple, Union
  8. from exo.networking import Discovery, PeerHandle, Server
  9. from exo.inference.inference_engine import InferenceEngine, Shard
  10. from .node import Node
  11. from exo.topology.topology import Topology
  12. from exo.topology.device_capabilities import device_capabilities
  13. from exo.topology.partitioning_strategy import Partition, PartitioningStrategy, map_partitions_to_shards
  14. from exo import DEBUG
  15. from exo.helpers import AsyncCallbackSystem
  16. from exo.viz.topology_viz import TopologyViz
  17. from exo.download.hf.hf_helpers import RepoProgressEvent
  18. class StandardNode(Node):
  def __init__(
    self,
    _id: str,
    server: Server,
    inference_engine: InferenceEngine,
    discovery: Discovery,
    partitioning_strategy: PartitioningStrategy = None,
    max_generate_tokens: int = 1024,
    topology_viz: Optional[TopologyViz] = None,
  ):
    """Store the node's collaborators and initialise its per-request state.

    Args:
      _id: Identifier stored as `self.id`; partition node ids are compared against it.
      server: Started in `start()` and stopped in `stop()`.
      inference_engine: Engine whose `infer_prompt` / `infer_tensor` are awaited by the processing methods.
      discovery: Started in `start()`, stopped in `stop()`, and queried by `update_peers()`.
      partitioning_strategy: Used to partition the topology; when falsy, `forward_to_next_shard` skips forwarding.
      max_generate_tokens: Buffered-token count at which a request is treated as finished.
      topology_viz: Optional visualisation refreshed on status and topology updates.
    """
    self.id = _id
    self.inference_engine = inference_engine
    self.server = server
    self.discovery = discovery
    self.partitioning_strategy = partitioning_strategy
    # Fixed: this was initialised to `{}` (a dict) despite being annotated and used as a list of peers.
    self.peers: List[PeerHandle] = []
    self.topology: Topology = Topology()
    self.device_capabilities = device_capabilities()
    # request_id -> (tokens generated so far, is_finished)
    self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
    self.max_generate_tokens = max_generate_tokens
    self.topology_viz = topology_viz
    self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
    self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
    # The node listens to its own opaque-status stream so on_node_status sees every status.
    self._on_opaque_status.register("node_status").on_next(self.on_node_status)
    # node_id -> latest download progress reported by that node
    self.node_download_progress: Dict[str, RepoProgressEvent] = {}
  async def start(self, wait_for_peers: int = 0) -> None:
    """Start the server and discovery, collect a first topology, then schedule periodic refreshes.

    Args:
      wait_for_peers: Forwarded to `update_peers`, which passes it to `discovery.discover_peers`.
    """
    await self.server.start()
    await self.discovery.start()
    await self.update_peers(wait_for_peers)
    await self.collect_topology()
    if DEBUG >= 2: print(f"Collected topology: {self.topology}")
    # Keep a strong reference to the background task: the event loop only holds weak
    # references to tasks, so an unreferenced task may be garbage-collected mid-execution.
    self._topology_collection_task = asyncio.create_task(self.periodic_topology_collection(1.0))
  async def stop(self) -> None:
    """Stop discovery first, then the server."""
    await self.discovery.stop()
    await self.server.stop()
  def on_node_status(self, request_id, opaque_status):
    """Handle an opaque status message: track the active node and download progress, then refresh the viz.

    Args:
      request_id: Request the status belongs to (not used by this handler).
      opaque_status: JSON-encoded status string.

    Any exception raised while handling the message is caught and only reported when DEBUG >= 1.
    """
    try:
      payload = json.loads(opaque_status)
      kind = payload.get("type", "")

      if kind == "node_status":
        status = payload.get("status", "")
        node_id = payload.get("node_id")
        if status.startswith("start_"):
          self.current_topology.active_node_id = node_id
        elif status.startswith("end_") and node_id == self.current_topology.active_node_id:
          # Only clear the marker when the node that finished is the one currently marked active.
          self.current_topology.active_node_id = None

      if kind == "download_progress":
        if DEBUG >= 8: print(f"Download progress from {payload.get('node_id')}: {payload.get('progress')}")
        progress_event = RepoProgressEvent.from_dict(payload.get('progress'))
        self.node_download_progress[payload.get('node_id')] = progress_event

      if self.topology_viz:
        topology = self.current_topology
        self.topology_viz.update_visualization(
          topology,
          self.partitioning_strategy.partition(self.current_topology),
          self.id,
          self.node_download_progress,
        )
    except Exception as e:
      if DEBUG >= 1: print(f"Error updating visualization: {e}")
      if DEBUG >= 1: traceback.print_exc()
  73. def get_supported_inference_engines(self):
  74. supported_engines = []
  75. if self.inferenceEngine == 'mlx':
  76. supported_engines.extend('mlx', 'tinygrad')
  77. else:
  78. supported_engines.append('tinygrad')
  79. return supported_engines
  async def broadcast_supported_engines(self, supported_engines: List):
    """Broadcast this node's supported engine names as an opaque status with an empty request id.

    Args:
      supported_engines: Engine names placed under the "engines" key of the message.
    """
    message = {
      "type": "supported_inference_engines",
      "node_id": self.id,
      "engines": supported_engines,
    }
    await self.broadcast_opaque_status("", json.dumps(message))
  async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
    """Process a prompt, broadcasting "start"/"end" node-status messages around the work.

    Args:
      base_shard: Shard description from which this node's own shard is derived.
      prompt: Prompt text.
      image_str: Optional image payload passed through unchanged.
      request_id: Request identifier; passed through unchanged (may be None).
      inference_state: Optional opaque state passed through unchanged.

    Returns:
      Whatever `_process_prompt` returns (an array of buffered tokens, or None).
    """
    shard = self.get_current_shard(base_shard)

    def _status_json(status: str, **extra) -> str:
      # Built fresh for every broadcast so both shards are serialised at send time.
      return json.dumps({
        "type": "node_status",
        "node_id": self.id,
        "status": status,
        "base_shard": base_shard.to_dict(),
        "shard": shard.to_dict(),
        "prompt": prompt,
        "image_str": image_str,
        "inference_state": inference_state,
        "request_id": request_id,
        **extra,
      })

    # Status broadcasts are fire-and-forget so they do not delay inference.
    asyncio.create_task(self.broadcast_opaque_status(request_id, _status_json("start_process_prompt")))

    started_at = time.perf_counter_ns()
    resp = await self._process_prompt(base_shard, prompt, image_str, request_id, inference_state)
    elapsed_time_ns = time.perf_counter_ns() - started_at

    result_size = 0 if resp is None else resp.size
    asyncio.create_task(
      self.broadcast_opaque_status(
        request_id,
        _status_json("end_process_prompt", elapsed_time_ns=elapsed_time_ns, result_size=result_size),
      )
    )
    return resp
  async def _process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
    """Run the prompt on this node if it owns the first layers, otherwise forward it.

    Args:
      base_shard: Shard description from which this node's own shard is derived.
      prompt: Prompt text.
      image_str: Optional image payload handed to the inference engine / next shard.
      request_id: Request identifier; a UUID4 string is generated when None.
      inference_state: Optional opaque state handed to the inference engine / next shard.

    Returns:
      None when the prompt was forwarded or no tokens are buffered yet, otherwise an
      array of the tokens buffered for this request.
    """
    if request_id is None:
      request_id = str(uuid.uuid4())
    self.buffered_token_output.setdefault(request_id, ([], False))

    shard = self.get_current_shard(base_shard)
    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {image_str=}")

    # Only the node whose shard starts at layer 0 runs the prompt; everyone else passes it on.
    if shard.start_layer != 0:
      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=} {image_str=}")
      await self.forward_to_next_shard(shard, prompt, request_id, image_str=image_str, inference_state=inference_state)
      return

    result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, image_str, inference_state=inference_state)

    # The same list object stays in the buffer entry, so this alias tracks later appends.
    tokens = self.buffered_token_output[request_id][0]
    is_finished = is_finished or len(tokens) >= self.max_generate_tokens
    if is_finished:
      self.buffered_token_output[request_id] = (tokens, True)
    asyncio.create_task(self.broadcast_result(request_id, tokens, is_finished))  # TODO: this is n^2 communication complexity

    if result.size == 1:
      tokens.append(result.item())
      self.trigger_on_token_callbacks(request_id, tokens, is_finished)

    if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(tokens)}")

    if not is_finished:
      asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, image_str=image_str, inference_state=inference_state))

    return np.array(tokens) if tokens else None
  async def process_tensor(
    self,
    base_shard: Shard,
    tensor: np.ndarray,
    request_id: Optional[str] = None,
    inference_state: Optional[str] = None,
  ) -> Optional[np.ndarray]:
    """Process a tensor, broadcasting "start"/"end" node-status messages around the work.

    Args:
      base_shard: Shard description from which this node's own shard is derived.
      tensor: Input array; its size and shape are included in the start status.
      request_id: Request identifier; passed through unchanged (may be None).
      inference_state: Optional opaque state passed through unchanged.

    Returns:
      Whatever `_process_tensor` returns (an array of buffered tokens, or None).
    """
    shard = self.get_current_shard(base_shard)

    def _announce(status: str, **fields) -> None:
      # Fire-and-forget broadcast; the common header is followed by the status-specific fields.
      asyncio.create_task(
        self.broadcast_opaque_status(
          request_id,
          json.dumps({
            "type": "node_status",
            "node_id": self.id,
            "status": status,
            "base_shard": base_shard.to_dict(),
            "shard": shard.to_dict(),
            **fields,
          }),
        )
      )

    _announce(
      "start_process_tensor",
      tensor_size=tensor.size,
      tensor_shape=tensor.shape,
      request_id=request_id,
      inference_state=inference_state,
    )

    started_at = time.perf_counter_ns()
    # Note: the already-resolved shard (not base_shard) is what gets handed to _process_tensor.
    resp = await self._process_tensor(shard, tensor, request_id, inference_state)
    elapsed_time_ns = time.perf_counter_ns() - started_at

    _announce(
      "end_process_tensor",
      request_id=request_id,
      elapsed_time_ns=elapsed_time_ns,
      result_size=0 if resp is None else resp.size,
    )
    return resp
  async def _process_tensor(
    self,
    base_shard: Shard,
    tensor: np.ndarray,
    request_id: Optional[str] = None,
    inference_state: Optional[str] = None,
  ) -> Optional[np.ndarray]:
    """Run this node's shard on a tensor, buffer any new token, and forward the result.

    Args:
      base_shard: Shard description from which this node's own shard is derived.
      tensor: Input array handed to the inference engine.
      request_id: Request identifier; a UUID4 string is generated when None.
      inference_state: Optional opaque state handed to the inference engine / next shard.

    Returns:
      An array of the tokens buffered for this request, or None when nothing is buffered
      or when inference raised (the error is printed, not propagated).
    """
    if request_id is None:
      request_id = str(uuid.uuid4())
    self.buffered_token_output.setdefault(request_id, ([], False))
    shard = self.get_current_shard(base_shard)

    try:
      if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
      result, inference_state, is_finished = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state=inference_state)

      # The same list object stays in the buffer entry, so this alias tracks later appends.
      tokens = self.buffered_token_output[request_id][0]
      is_finished = is_finished or len(tokens) >= self.max_generate_tokens
      if is_finished:
        self.buffered_token_output[request_id] = (tokens, True)
      asyncio.create_task(self.broadcast_result(request_id, tokens, is_finished))  # TODO: this is n^2 communication complexity

      if result.size == 1:  # we got a new token out
        tokens.append(result.item())
        self.trigger_on_token_callbacks(request_id, tokens, is_finished)
      if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(tokens)}")

      if not is_finished:
        asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))

      return np.array(tokens) if tokens else None
    except Exception as e:
      print(f"Error processing tensor for shard {shard}: {e}")
      traceback.print_exc()
      return None
  async def forward_to_next_shard(
    self,
    base_shard: Shard,
    tensor_or_prompt: Union[np.ndarray, str],
    request_id: str,
    image_str: Optional[str] = None,
    inference_state: Optional[str] = None,
  ) -> None:
    """Hand a tensor or prompt to the node that owns the next partition (wrapping around).

    Args:
      base_shard: Shard description used to derive the per-partition shards.
      tensor_or_prompt: An ndarray is sent as a tensor; anything else is sent as a prompt.
      request_id: Request identifier passed along with the payload.
      image_str: Optional image payload; only forwarded together with prompts.
      inference_state: Optional opaque state passed along with the payload.

    Raises:
      ValueError: If the next partition belongs to another node that is not in `self.peers`
        (or, via `get_current_shard`, if this node has no partition).
    """
    if not self.partitioning_strategy:
      if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
      return

    shard = self.get_current_shard(base_shard)
    partitions = self.partitioning_strategy.partition(self.topology)
    # Fixed: the partition used to be computed a second time here; reusing the same result
    # guarantees `shards` lines up index-for-index with `partitions`.
    shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
    current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
    if DEBUG >= 1: print(f"Current partition index: {current_partition_index}")
    if current_partition_index is None:
      return

    # The partitions form a ring: the last one forwards back to the first.
    next_partition_index = (current_partition_index + 1) % len(partitions)
    next_partition: Partition = partitions[next_partition_index]
    next_shard = shards[next_partition_index]
    if DEBUG >= 2: print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}")

    if next_partition.node_id == self.id:
      # This node also owns the next partition, so process locally instead of sending.
      if isinstance(tensor_or_prompt, np.ndarray):
        await self.process_tensor(shard, tensor_or_prompt, request_id, inference_state=inference_state)
      else:
        await self.process_prompt(shard, tensor_or_prompt, image_str, request_id, inference_state=inference_state)
      return

    target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
    if not target_peer:
      raise ValueError(f"Peer for {next_partition} not found")

    if DEBUG >= 1: print(f"Sending tensor_or_prompt to {target_peer.id()}: {tensor_or_prompt}")

    if isinstance(tensor_or_prompt, np.ndarray):
      await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
    else:
      await target_peer.send_prompt(next_shard, tensor_or_prompt, image_str=image_str, request_id=request_id, inference_state=inference_state)
  def get_current_shard(self, base_shard: Shard) -> Shard:
    """Return the shard assigned to this node under the current partitioning of `self.topology`.

    Args:
      base_shard: Supplies `n_layers` and `model_id` for mapping partitions to shards.

    Raises:
      ValueError: If no partition's `node_id` equals `self.id`.
    """
    partitions = self.partitioning_strategy.partition(self.topology)
    shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
    # The first partition owned by this node wins.
    for index, partition in enumerate(partitions):
      if partition.node_id == self.id:
        return shards[index]
    raise ValueError(f"No current partition found for node: {self.id}")
  async def update_peers(self, wait_for_peers: int = 0) -> bool:
    """Refresh `self.peers` from discovery, connecting new peers and disconnecting removed ones.

    Args:
      wait_for_peers: Forwarded to `discovery.discover_peers`.

    Returns:
      True when at least one peer was added, removed, or changed address; False otherwise.
    """
    discovered = await self.discovery.discover_peers(wait_for_peers)
    known_ids = {peer.id() for peer in self.peers}
    discovered_ids = {peer.id() for peer in discovered}

    # Classify discovered peers against the peers currently held.
    added = [peer for peer in discovered if peer.id() not in known_ids]
    removed = [peer for peer in self.peers if peer.id() not in discovered_ids]
    updated = [
      peer for peer in discovered
      if peer.id() in known_ids and any(old.addr() != peer.addr() for old in self.peers if old.id() == peer.id())
    ]
    unchanged = [
      peer for peer in discovered
      if peer.id() in known_ids and all(old.addr() == peer.addr() for old in self.peers if old.id() == peer.id())
    ]

    to_disconnect = [peer for peer in removed if await peer.is_connected()]
    to_connect = [peer for peer in added + updated + unchanged if not await peer.is_connected()]

    def _pretty(peers: List[PeerHandle]) -> List[str]:
      return [f"{peer.id()}@{peer.addr()}" for peer in peers]

    if DEBUG >= 2: print(f"update_peers: added={added} removed={removed} updated={updated} unchanged={unchanged} to_disconnect={to_disconnect} to_connect={to_connect}")

    async def disconnect_with_timeout(peer, timeout=5):
      try:
        await asyncio.wait_for(peer.disconnect(), timeout)
      except Exception as e:
        print(f"Error disconnecting peer {peer.id()}@{peer.addr()}: {e}")
        traceback.print_exc()
        return False
      return True

    async def connect_with_timeout(peer, timeout=5):
      try:
        await asyncio.wait_for(peer.connect(), timeout)
      except Exception as e:
        print(f"Error connecting peer {peer.id()}@{peer.addr()}: {e}")
        traceback.print_exc()
        return False
      return True

    def _with_outcome(peers, results, outcome):
      # Identity check on purpose: a gathered exception object is neither True nor False.
      return [peer for peer, result in zip(peers, results) if result is outcome]

    # All disconnects finish before any connect is attempted.
    disconnect_results = await asyncio.gather(
      *(disconnect_with_timeout(peer) for peer in to_disconnect),
      return_exceptions=True
    )
    connect_results = await asyncio.gather(
      *(connect_with_timeout(peer) for peer in to_connect),
      return_exceptions=True
    )

    successful_disconnects = _with_outcome(to_disconnect, disconnect_results, True)
    failed_disconnects = _with_outcome(to_disconnect, disconnect_results, False)
    successful_connects = _with_outcome(to_connect, connect_results, True)
    failed_connects = _with_outcome(to_connect, connect_results, False)

    if DEBUG >= 1:
      if successful_disconnects: print(f"Successfully disconnected peers: {_pretty(successful_disconnects)}")
      if failed_disconnects: print(f"Failed to disconnect peers: {_pretty(failed_disconnects)}")
      if successful_connects: print(f"Successfully connected peers: {_pretty(successful_connects)}")
      if failed_connects: print(f"Failed to connect peers: {_pretty(failed_connects)}")

    self.peers = discovered
    return bool(added or removed or updated)
  async def periodic_topology_collection(self, interval: int):
    """Loop forever: sleep, refresh peers, and re-collect the topology when the peers changed.

    Args:
      interval: Seconds passed to `asyncio.sleep` between iterations.

    Errors inside an iteration are printed with a traceback and the loop continues.
    """
    while True:
      await asyncio.sleep(interval)
      try:
        did_peers_change = await self.update_peers()
        if DEBUG >= 2: print(f"{did_peers_change=}")
        if not did_peers_change:
          continue
        await self.collect_topology()
      except Exception as e:
        print(f"Error collecting topology: {e}")
        traceback.print_exc()
  331. async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
  332. if request_id not in self.buffered_token_output:
  333. return None, False
  334. return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
  async def collect_topology(self, visited: Optional[set[str]] = None, max_depth: int = 4) -> Topology:
    """Build a fresh topology from this node and its peers and install it as `self.topology`.

    Args:
      visited: Node ids already visited by the caller; mutated in place (this node and its
        peers are added). A new empty set is used when omitted.
      max_depth: Remaining recursion budget; peers are not queried when it is <= 0.

    Returns:
      The newly built topology (also stored in `self.topology`).
    """
    # Fixed: the default used to be a shared mutable `set()`, so ids recorded by one call
    # leaked into every later call and caused peers to be skipped as "already visited".
    if visited is None:
      visited = set()

    next_topology = Topology()
    next_topology.update_node(self.id, self.device_capabilities)

    if DEBUG >= 2: print(f"Collecting topology {max_depth=} {visited=}")

    # Snapshot before marking, so peers first seen by this call are still queried below.
    prev_visited = visited.copy()
    visited.add(self.id)
    visited.update(p.id() for p in self.peers)

    for peer in self.peers:
      next_topology.update_node(peer.id(), peer.device_capabilities())
      next_topology.add_edge(self.id, peer.id())

      if peer.id() in prev_visited:
        continue

      if max_depth <= 0:
        if DEBUG >= 2: print("Max depth reached. Skipping...")
        continue

      try:
        other_topology = await asyncio.wait_for(peer.collect_topology(visited, max_depth=max_depth - 1), timeout=5.0)
        if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}")
        # NOTE(review): this merges into the old `self.topology`, which is replaced by
        # `next_topology` just below, so the merged data looks like it is discarded.
        # Left unchanged because merging into `next_topology` would alter which nodes get
        # partitioned - confirm the intended target.
        self.topology.merge(other_topology)
      except Exception as e:
        print(f"Error collecting topology from {peer.id()}: {e}")

    next_topology.active_node_id = self.topology.active_node_id  # this is not so clean.
    self.topology = next_topology
    if self.topology_viz:
      self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id)
    return next_topology
  @property
  def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
    """Callback system that `trigger_on_token_callbacks` fires with (request_id, tokens, is_finished)."""
    return self._on_token
  @property
  def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
    """Callback system that `broadcast_opaque_status` fires with (request_id, status)."""
    return self._on_opaque_status
  def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
    """Fire every registered on_token callback with the request's current tokens.

    Args:
      request_id: Request the tokens belong to.
      tokens: Tokens passed to the callbacks (only the count is logged).
      is_finished: Whether the request is finished.
    """
    if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
    callbacks = self.on_token
    callbacks.trigger_all(request_id, tokens, is_finished)
  async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
    """Send a request's result to every peer concurrently; per-peer failures are printed, not raised.

    Args:
      request_id: Request the result belongs to.
      result: Tokens to send.
      is_finished: Whether the request is finished.
    """
    async def _deliver(peer):
      # Each send gets its own 15-second budget, so one slow peer does not fail the others.
      try:
        await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0)
      except asyncio.TimeoutError:
        print(f"Timeout broadcasting result to {peer.id()}")
      except Exception as e:
        print(f"Error broadcasting result to {peer.id()}: {e}")
        traceback.print_exc()

    deliveries = [_deliver(peer) for peer in self.peers]
    await asyncio.gather(*deliveries, return_exceptions=True)
  async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
    """Send an opaque status to every peer concurrently, then deliver it to this node's own listeners.

    Args:
      request_id: Request the status belongs to.
      status: Status string sent to the peers and to the local callbacks.

    Per-peer failures are printed, not raised.
    """
    if DEBUG >= 8: print(f"Broadcasting opaque status: {request_id=} {status=}")

    async def _deliver(peer):
      # Each send gets its own 15-second budget, so one slow peer does not fail the others.
      try:
        await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0)
      except asyncio.TimeoutError:
        print(f"Timeout sending opaque status to {peer.id()}")
      except Exception as e:
        print(f"Error sending opaque status to {peer.id()}: {e}")
        traceback.print_exc()

    deliveries = [_deliver(peer) for peer in self.peers]
    await asyncio.gather(*deliveries, return_exceptions=True)
    # Peers are notified first; the node then receives its own opaque status as well.
    self.on_opaque_status.trigger_all(request_id, status)
  @property
  def current_topology(self) -> Topology:
    """The topology currently held in `self.topology`."""
    return self.topology