# node.py

import numpy as np
import json
import asyncio
import uuid
import time
import traceback
from typing import List, Dict, Optional, Tuple, Union, Set
from exo.networking import Discovery, PeerHandle, Server
from exo.inference.inference_engine import InferenceEngine, Shard, get_inference_engine
from exo.topology.topology import Topology
from exo.topology.device_capabilities import device_capabilities
from exo.topology.partitioning_strategy import Partition, PartitioningStrategy, map_partitions_to_shards
from exo import DEBUG
from exo.helpers import AsyncCallbackSystem
from exo.viz.topology_viz import TopologyViz
from exo.download.hf.hf_helpers import RepoProgressEvent
from exo.download.hf.hf_shard_download import HFShardDownloader
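

# Node ties one machine into the exo cluster: it runs peer discovery, keeps the
# cluster topology up to date, owns the local inference engine for this node's
# model shard, and forwards prompts, tensors and training examples to the peer
# responsible for the next partition.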
class Node:
  def __init__(
    self,
    _id: str,
    server: Server,
    inference_engine: InferenceEngine,
    discovery: Discovery,
    partitioning_strategy: Optional[PartitioningStrategy] = None,
    max_generate_tokens: int = 1024,
    default_sample_temperature: float = 0.0,
    topology_viz: Optional[TopologyViz] = None,
    shard_downloader: Optional[HFShardDownloader] = None,
  ):
    self.id = _id
    self.inference_engine = inference_engine
    self.server = server
    self.discovery = discovery
    self.partitioning_strategy = partitioning_strategy
    self.peers: List[PeerHandle] = []
    self.topology: Topology = Topology()
    self.device_capabilities = device_capabilities()
    self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
    self.buffered_logits: Dict[str, List[np.ndarray]] = {}
    self.buffered_inputs: Dict[str, List[np.ndarray]] = {}
    self.buffered_partials: Dict[str, List[np.ndarray]] = {}
    self.checkpoints: Dict[str, Dict[int, List[int]]] = {}  # model_id -> shard hash -> saved iterations
    self.max_generate_tokens = max_generate_tokens
    self.topology_viz = topology_viz
    self.default_sample_temperature = default_sample_temperature
    self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
    self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
    self._on_opaque_status.register("node_status").on_next(self.on_node_status)
    self.node_download_progress: Dict[str, RepoProgressEvent] = {}
    self.topology_inference_engines_pool: List[List[str]] = []
    self.shard_downloader = shard_downloader
    self.outstanding_requests = {}

  async def start(self, wait_for_peers: int = 0) -> None:
    await self.server.start()
    await self.discovery.start()
    await self.update_peers(wait_for_peers)
    await self.collect_topology(set())
    if DEBUG >= 2: print(f"Collected topology: {self.topology}")
    asyncio.create_task(self.periodic_topology_collection(2.0))

  async def stop(self) -> None:
    await self.discovery.stop()
    await self.server.stop()

  def on_node_status(self, request_id, opaque_status):
    try:
      status_data = json.loads(opaque_status)
      if status_data.get("type", "") == "supported_inference_engines":
        node_id = status_data.get("node_id")
        engines = status_data.get("engines", [])
        self.topology_inference_engines_pool.append(engines)
      if status_data.get("type", "") == "node_status":
        if status_data.get("status", "").startswith("start_"):
          self.current_topology.active_node_id = status_data.get("node_id")
        elif status_data.get("status", "").startswith("end_"):
          if status_data.get("node_id") == self.current_topology.active_node_id:
            self.current_topology.active_node_id = None
      download_progress = None
      if status_data.get("type", "") == "download_progress":
        if DEBUG >= 8: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('progress')}")
        download_progress = RepoProgressEvent.from_dict(status_data.get('progress'))
        self.node_download_progress[status_data.get('node_id')] = download_progress
      if self.topology_viz:
        self.topology_viz.update_visualization(self.topology, self.partitioning_strategy.partition(self.topology), self.id, self.node_download_progress)
    except Exception as e:
      if DEBUG >= 1: print(f"Error updating visualization: {e}")
      if DEBUG >= 1: traceback.print_exc()

  def get_supported_inference_engines(self):
    supported_engine_names = []
    if self.inference_engine.__class__.__name__ == 'MLXDynamicShardInferenceEngine':
      supported_engine_names.append('mlx')
      supported_engine_names.append('tinygrad')
    else:
      supported_engine_names.append('tinygrad')
    return supported_engine_names

  async def broadcast_supported_engines(self, supported_engines_names: List[str]):
    status_message = json.dumps({"type": "supported_inference_engines", "node_id": self.id, "engines": supported_engines_names})
    await self.broadcast_opaque_status("", status_message)

  def get_topology_inference_engines(self) -> List[List[str]]:
    return self.topology_inference_engines_pool
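
  # process_inference_result consumes the output of the local shard for one step.
  # On the last layer it samples a token, buffers it, fires on_token callbacks and
  # broadcasts the result; otherwise it forwards the raw activations to the next
  # partition. Stable Diffusion requests track progress via inference_state instead
  # of the token buffer.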
  async def process_inference_result(
    self,
    shard,
    result: np.ndarray,
    request_id: Optional[str] = None,
    inference_state: Optional[dict] = None,
  ):
    if shard.model_id != 'stable-diffusion-2-1-base':
      if request_id not in self.buffered_token_output:
        self.buffered_token_output[request_id] = ([], False)
      is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
      if shard.is_last_layer() and not is_finished:
        token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
        await self.inference_engine.ensure_shard(shard)
        self.buffered_token_output[request_id][0].append(token.item())
        is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
        if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
        asyncio.create_task(self.broadcast_result(request_id, *self.buffered_token_output[request_id]))
        forward = token.reshape(1, -1)
        intermediate_result = self.buffered_token_output[request_id][0]
      else:
        forward = result
        # fall back to the buffered tokens so the last-layer callbacks below never
        # see an unbound intermediate_result when generation is already finished
        intermediate_result = self.buffered_token_output[request_id][0]
    else:
      await self.inference_engine.ensure_shard(shard)
      is_finished = inference_state.get("is_finished", False)
      intermediate_result, inference_state = self.handle_stable_diffusion(inference_state, result)
      forward = result
    if shard.is_last_layer():
      self.trigger_on_token_callbacks(request_id, intermediate_result, is_finished)
      asyncio.create_task(self.broadcast_result(request_id, intermediate_result, is_finished))
    if is_finished:
      if shard.model_id != 'stable-diffusion-2-1-base':
        self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
      self.outstanding_requests.pop(request_id)
    else:
      self.outstanding_requests[request_id] = "waiting"
      asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset=1), inference_state))
    return np.array(self.buffered_token_output[request_id][0]) if shard.model_id != 'stable-diffusion-2-1-base' else intermediate_result
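
  # process_prompt runs the first stage of a text request on this node. It emits
  # start/end node_status events (used for timing and the topology visualization)
  # around the actual work in _process_prompt.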
  async def process_prompt(
    self,
    base_shard: Shard,
    prompt: str,
    request_id: Optional[str] = None,
    inference_state: Optional[dict] = None,
  ) -> Optional[np.ndarray]:
    if inference_state is None:
      inference_state = {}
    shard = self.get_current_shard(base_shard)
    asyncio.create_task(
      self.broadcast_opaque_status(
        request_id,
        json.dumps({
          "type": "node_status",
          "node_id": self.id,
          "status": "start_process_prompt",
          "base_shard": base_shard.to_dict(),
          "shard": shard.to_dict(),
          "prompt": prompt,
          "request_id": request_id,
        }),
      )
    )
    start_time = time.perf_counter_ns()
    resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
    end_time = time.perf_counter_ns()
    elapsed_time_ns = end_time - start_time
    asyncio.create_task(
      self.broadcast_opaque_status(
        request_id,
        json.dumps({
          "type": "node_status",
          "node_id": self.id,
          "status": "end_process_prompt",
          "base_shard": base_shard.to_dict(),
          "shard": shard.to_dict(),
          "prompt": prompt,
          "request_id": request_id,
          "elapsed_time_ns": elapsed_time_ns,
          "result_size": resp.size if resp is not None else 0,
        }),
      )
    )
    return resp

  async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]:
    if request_id is None:
      request_id = str(uuid.uuid4())
    shard = self.get_current_shard(base_shard)
    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
    if not shard.is_first_layer():
      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
      self.outstanding_requests[request_id] = "waiting"
      resp = await self.forward_prompt(shard, prompt, request_id, 0, inference_state)
      return None
    else:
      self.outstanding_requests[request_id] = "processing"
      result, inference_state = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state)
      ret = await self.process_inference_result(shard, result, request_id, inference_state)
      return result
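
  # enqueue_example routes a training/eval example: a node holding the first layer
  # processes it locally, every other node forwards it to partition 0.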
  async def enqueue_example(
    self,
    base_shard: Shard,
    example: np.ndarray,
    target: np.ndarray,
    length: np.ndarray,
    request_id: Optional[str] = None,
    train: bool = False,
  ):
    shard = self.get_current_shard(base_shard)
    if shard.is_first_layer():
      loss = await self.process_example(shard, example, target, length, train, request_id)
      return loss
    else:
      if request_id is None:
        request_id = str(uuid.uuid4())
      self.outstanding_requests[request_id] = "waiting"
      loss = await self.forward_example(shard, example, target, length, train, request_id, 0)
      return loss
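
  # coordinate_save writes a checkpoint for this node's shard, keyed by the shard
  # hash, skipping iterations that have already been saved.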
  async def coordinate_save(
    self,
    base_shard: Shard,
    iteration: int,
    destination: str,
  ):
    shard = self.get_current_shard(base_shard)
    model = shard.model_id
    sid = shard.__hash__()
    path = f"{destination}/{model}/{sid}-{iteration}.safetensors"
    self.outstanding_requests[f"{sid}::{iteration}"] = "Checking"
    if model not in self.checkpoints:
      self.checkpoints[model] = {}
    if sid not in self.checkpoints[model]:
      self.checkpoints[model][sid] = []
    if len(self.checkpoints[model][sid]) < 1 or self.checkpoints[model][sid][-1] < iteration:
      print(f"Saving checkpoint to {path}")
      self.outstanding_requests[f"{sid}::{iteration}"] = "Saving"
      import os
      os.makedirs("/".join(path.split("/")[:-1]), exist_ok=True)
      await self.inference_engine.save_checkpoint(shard, path)
      self.checkpoints[model][sid] = sorted(self.checkpoints[model][sid] + [iteration])
    self.outstanding_requests.pop(f"{sid}::{iteration}")
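
  # process_example / _process_example run one training or evaluation step for this
  # shard. Intermediate shards run a forward pass, hand the activations to the next
  # partition via forward_example, and (when training) backpropagate the returned
  # gradient through their own layers.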
  async def process_example(
    self,
    base_shard: Shard,
    example: np.ndarray,
    target: np.ndarray,
    length: np.ndarray,
    train: bool = False,
    request_id: Optional[str] = None,
  ):
    shard = self.get_current_shard(base_shard)
    asyncio.create_task(
      self.broadcast_opaque_status(
        request_id,
        json.dumps({
          "type": "node_status",
          "node_id": self.id,
          "status": f"start_{'train' if train else 'eval'}_example",
          "base_shard": base_shard.to_dict(),
          "shard": shard.to_dict(),
          "example_size": example.size,
          "example_shape": example.shape,
          "request_id": request_id,
        }),
      )
    )
    start_time = time.perf_counter_ns()
    resp = await self._process_example(shard, example, target, length, train, request_id)
    end_time = time.perf_counter_ns()
    elapsed_time_ns = end_time - start_time
    asyncio.create_task(
      self.broadcast_opaque_status(
        request_id,
        json.dumps({
          "type": "node_status",
          "node_id": self.id,
          "status": f"end_{'train' if train else 'eval'}_example",
          "base_shard": base_shard.to_dict(),
          "shard": shard.to_dict(),
          "request_id": request_id,
          "elapsed_time_ns": elapsed_time_ns,
        }),
      )
    )
    return resp

  async def _process_example(
    self,
    base_shard: Shard,
    example: np.ndarray,
    target: np.ndarray,
    length: np.ndarray,
    train: bool = False,
    request_id: Optional[str] = None,
  ) -> Optional[np.ndarray]:
    if request_id is None:
      request_id = str(uuid.uuid4())
    shard = self.get_current_shard(base_shard)
    if DEBUG >= 1: print(f"[{request_id}] process_example: {example.shape=}")
    try:
      target = target.astype(int)
      if train:
        if shard.is_last_layer():
          self.outstanding_requests[request_id] = "training"
          loss, grad = await self.inference_engine.train(request_id, shard, example, target, length)
        else:
          self.outstanding_requests[request_id] = "preprocessing"
          step, _ = await self.inference_engine.infer_tensor(request_id, shard, example)
          self.outstanding_requests[request_id] = "waiting"
          loss, backgrad = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset=1))
          self.outstanding_requests[request_id] = "training"
          partial_loss, grad = await self.inference_engine.train(request_id, shard, example, backgrad, length, loss="back_gradient")
        self.outstanding_requests.pop(request_id)
        if shard.is_first_layer():
          return loss
        else:
          return loss, grad
      else:
        if shard.is_last_layer():
          self.outstanding_requests[request_id] = "evaluating"
          loss = await self.inference_engine.evaluate(request_id, shard, example, target, length)
        else:
          self.outstanding_requests[request_id] = "preprocessing"
          step, _ = await self.inference_engine.infer_tensor(request_id, shard, example)
          self.outstanding_requests[request_id] = "waiting"
          loss = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset=1))
        self.outstanding_requests.pop(request_id)
        return loss
    except Exception as e:
      self.outstanding_requests.pop(request_id)
      print(f"Error processing example for shard {shard}: {e}")
      traceback.print_exc()
      return None
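
  # process_tensor / _process_tensor handle activations arriving from the previous
  # partition: run them through the local shard and hand the output to
  # process_inference_result.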
  async def process_tensor(
    self,
    base_shard: Shard,
    tensor: np.ndarray,
    request_id: Optional[str] = None,
    inference_state: Optional[dict] = None,
  ) -> Optional[np.ndarray]:
    shard = self.get_current_shard(base_shard)
    asyncio.create_task(
      self.broadcast_opaque_status(
        request_id,
        json.dumps({
          "type": "node_status",
          "node_id": self.id,
          "status": "start_process_tensor",
          "base_shard": base_shard.to_dict(),
          "shard": shard.to_dict(),
          "tensor_size": tensor.size,
          "tensor_shape": tensor.shape,
          "request_id": request_id,
        }),
      )
    )
    start_time = time.perf_counter_ns()
    resp = await self._process_tensor(shard, tensor, request_id, inference_state)
    end_time = time.perf_counter_ns()
    elapsed_time_ns = end_time - start_time
    asyncio.create_task(
      self.broadcast_opaque_status(
        request_id,
        json.dumps({
          "type": "node_status",
          "node_id": self.id,
          "status": "end_process_tensor",
          "base_shard": base_shard.to_dict(),
          "shard": shard.to_dict(),
          "request_id": request_id,
          "elapsed_time_ns": elapsed_time_ns,
          "result_size": resp.size if resp is not None else 0,
        }),
      )
    )
    return resp

  async def _process_tensor(
    self,
    base_shard: Shard,
    tensor: np.ndarray,
    request_id: Optional[str] = None,
    inference_state: Optional[dict] = None,
  ) -> Optional[np.ndarray]:
    if request_id is None:
      request_id = str(uuid.uuid4())
    shard = self.get_current_shard(base_shard)
    if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
    try:
      self.outstanding_requests[request_id] = "processing"
      result, inference_state = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state)
      ret = await self.process_inference_result(shard, result, request_id, inference_state)
      return ret
    except Exception as e:
      self.outstanding_requests.pop(request_id)
      print(f"Error processing tensor for shard {shard}: {e}")
      traceback.print_exc()
      return None
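
  # The forward_* helpers look up the node that owns the target partition index and
  # either handle the work locally (when that node is us) or send it to the peer.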
  async def forward_example(
    self,
    base_shard: Shard,
    step: np.ndarray,
    target: np.ndarray,
    length: np.ndarray,
    train: bool,
    request_id: str,
    target_index: int,
  ):
    if DEBUG >= 1: print(f"target partition index: {target_index}")
    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
    target_shard = self.get_current_shard(base_shard, target_index)
    if DEBUG >= 2: print(f"computed target from: {base_shard} {target_index}, {self.topology}. target shard: {target_shard}")
    target_peer = next((p for p in self.peers if p.id() == target_id), None)
    if not target_peer:
      raise ValueError(f"peer for {target_index} not found")
    if DEBUG >= 1: print(f"sending example to {target_peer.id()}: {step} => {target} ({length})")
    resp = await target_peer.send_example(target_shard, step, target, length, request_id=request_id, train=train)
    return resp

  async def forward_prompt(
    self,
    base_shard: Shard,
    prompt: str,
    request_id: str,
    target_index: int,
    inference_state: Optional[dict] = None,
  ) -> None:
    if DEBUG >= 1: print(f"target partition index: {target_index}")
    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
    next_shard = self.get_current_shard(base_shard, target_index)
    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
    if target_id == self.id:
      await self.process_prompt(next_shard, prompt, request_id, inference_state)
    else:
      target_peer = next((p for p in self.peers if p.id() == target_id), None)
      if not target_peer:
        raise ValueError(f"Peer for {target_index} not found")
      if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
      await target_peer.send_prompt(next_shard, prompt, request_id=request_id, inference_state=inference_state)

  async def forward_tensor(
    self,
    base_shard: Shard,
    tensor: np.ndarray,
    request_id: str,
    target_index: int,
    inference_state: Optional[dict] = None,
  ) -> None:
    if DEBUG >= 1: print(f"target partition index: {target_index}")
    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
    next_shard = self.get_current_shard(base_shard, target_index)
    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {next_shard}")
    if target_id == self.id:
      await self.process_tensor(next_shard, tensor, request_id, inference_state)
    else:
      target_peer = next((p for p in self.peers if p.id() == target_id), None)
      if not target_peer:
        raise ValueError(f"Peer for {target_index} not found")
      if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
      await target_peer.send_tensor(next_shard, tensor, request_id=request_id, inference_state=inference_state)
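
  # get_partition_index finds this node's slot in the partition ring (optionally
  # offset, e.g. +1 for "the next node"); get_current_shard maps a partition slot
  # back to the concrete layer range of the model.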
  def get_partition_index(self, offset: int = 0):
    if not self.partitioning_strategy:
      if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
      return None
    partitions = self.partitioning_strategy.partition(self.topology)
    current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
    if current_partition_index is None:
      raise ValueError(f"No current partition found for node: {self.id}")
    return (current_partition_index + offset) % len(partitions)

  def get_current_shard(self, base_shard: Shard, index: Optional[int] = None) -> Shard:
    if index is None:
      index = self.get_partition_index()
    partitions = self.partitioning_strategy.partition(self.topology)
    shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
    return shards[index]
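
  # update_peers reconciles the discovered peer list with the current one:
  # disconnect peers that disappeared, connect new or re-addressed ones, and report
  # whether anything changed so the caller knows to re-collect the topology.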
  async def update_peers(self, wait_for_peers: int = 0) -> bool:
    next_peers = await self.discovery.discover_peers(wait_for_peers)
    current_peer_ids = {peer.id() for peer in self.peers}
    next_peer_ids = {peer.id() for peer in next_peers}
    peers_added = [peer for peer in next_peers if peer.id() not in current_peer_ids]
    peers_removed = [peer for peer in self.peers if peer.id() not in next_peer_ids]
    peers_updated = [peer for peer in next_peers if peer.id() in current_peer_ids and any(p.addr() != peer.addr() for p in self.peers if p.id() == peer.id())]
    peers_unchanged = [peer for peer in next_peers if peer.id() in current_peer_ids and all(p.addr() == peer.addr() for p in self.peers if p.id() == peer.id())]
    peers_to_disconnect = [peer for peer in peers_removed if await peer.is_connected()]
    peers_to_connect = [peer for peer in peers_added + peers_updated + peers_unchanged if not await peer.is_connected()]

    def _pretty(peers: List[PeerHandle]) -> List[str]:
      return [f"{peer.id()}@{peer.addr()}" for peer in peers]

    if DEBUG >= 2:
      print(f"update_peers: added={peers_added} removed={peers_removed} updated={peers_updated} unchanged={peers_unchanged} to_disconnect={peers_to_disconnect} to_connect={peers_to_connect}")

    async def disconnect_with_timeout(peer, timeout=5):
      try:
        await asyncio.wait_for(peer.disconnect(), timeout)
        return True
      except Exception as e:
        print(f"Error disconnecting peer {peer.id()}@{peer.addr()}: {e}")
        traceback.print_exc()
        return False

    async def connect_with_timeout(peer, timeout=5):
      try:
        await asyncio.wait_for(peer.connect(), timeout)
        return True
      except Exception as e:
        print(f"Error connecting peer {peer.id()}@{peer.addr()}: {e}")
        traceback.print_exc()
        return False

    disconnect_results = await asyncio.gather(*(disconnect_with_timeout(peer) for peer in peers_to_disconnect), return_exceptions=True)
    connect_results = await asyncio.gather(*(connect_with_timeout(peer) for peer in peers_to_connect), return_exceptions=True)

    successful_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is True]
    failed_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is False]
    successful_connects = [peer for peer, result in zip(peers_to_connect, connect_results) if result is True]
    failed_connects = [peer for peer, result in zip(peers_to_connect, connect_results) if result is False]
    if DEBUG >= 1:
      if successful_disconnects: print(f"Successfully disconnected peers: {_pretty(successful_disconnects)}")
      if failed_disconnects: print(f"Failed to disconnect peers: {_pretty(failed_disconnects)}")
      if successful_connects: print(f"Successfully connected peers: {_pretty(successful_connects)}")
      if failed_connects: print(f"Failed to connect peers: {_pretty(failed_connects)}")

    self.peers = next_peers
    return len(peers_added) > 0 or len(peers_removed) > 0 or len(peers_updated) > 0
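
  # select_best_inference_engine broadcasts the locally supported engines and, once
  # any peer has reported its engines, re-instantiates the first locally supported
  # one; periodic_topology_collection re-runs discovery and topology collection in
  # the background.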
  async def select_best_inference_engine(self):
    if self.inference_engine.__class__.__name__ == 'DummyInferenceEngine': return
    supported_engines = self.get_supported_inference_engines()
    await self.broadcast_supported_engines(supported_engines)
    if len(self.get_topology_inference_engines()):
      self.inference_engine = get_inference_engine(supported_engines[0], self.shard_downloader)

  async def periodic_topology_collection(self, interval: float):
    while True:
      await asyncio.sleep(interval)
      try:
        did_peers_change = await self.update_peers()
        if DEBUG >= 2: print(f"{did_peers_change=}")
        if did_peers_change:
          await self.collect_topology(set())
          await self.select_best_inference_engine()
      except Exception as e:
        print(f"Error collecting topology: {e}")
        traceback.print_exc()

  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
    if request_id not in self.buffered_token_output:
      return None, False
    return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
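
  # collect_topology does a bounded-depth traversal of the peer graph, merging each
  # peer's view into a fresh Topology so partitioning always works from the latest
  # picture of the cluster.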
  async def collect_topology(self, visited: set[str], max_depth: int = 4) -> Topology:
    next_topology = Topology()
    next_topology.update_node(self.id, self.device_capabilities)
    if DEBUG >= 2: print(f"Collecting topology {max_depth=} {visited=}")

    prev_visited = visited.copy()
    visited.add(self.id)
    visited.update(p.id() for p in self.peers)

    for peer in self.peers:
      next_topology.update_node(peer.id(), peer.device_capabilities())
      next_topology.add_edge(self.id, peer.id(), peer.description())
      if peer.id() in prev_visited:
        continue
      if max_depth <= 0:
        if DEBUG >= 2: print("Max depth reached. Skipping...")
        continue
      try:
        other_topology = await asyncio.wait_for(peer.collect_topology(visited, max_depth=max_depth - 1), timeout=5.0)
        if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}")
        next_topology.merge(peer.id(), other_topology)
      except Exception as e:
        print(f"Error collecting topology from {peer.id()}: {e}")
        traceback.print_exc()

    next_topology.active_node_id = self.topology.active_node_id
    self.topology = next_topology
    if self.topology_viz:
      self.topology_viz.update_visualization(self.topology, self.partitioning_strategy.partition(self.topology), self.id)
    return self.topology

  @property
  def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
    return self._on_token

  @property
  def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
    return self._on_opaque_status

  def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
    if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
    self.on_token.trigger_all(request_id, tokens, is_finished)
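
  # The broadcast_* helpers fan a message out to every peer with a per-peer timeout,
  # swallowing individual failures so one slow peer cannot stall the request.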
  async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
    async def send_result_to_peer(peer):
      try:
        await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0)
      except asyncio.TimeoutError:
        print(f"Timeout broadcasting result to {peer.id()}")
      except Exception as e:
        print(f"Error broadcasting result to {peer.id()}: {e}")
        traceback.print_exc()

    await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)

  async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
    if DEBUG >= 8: print(f"Broadcasting opaque status: {request_id=} {status=}")

    async def send_status_to_peer(peer):
      try:
        await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0)
      except asyncio.TimeoutError:
        print(f"Timeout sending opaque status to {peer.id()}")
      except Exception as e:
        print(f"Error sending opaque status to {peer.id()}: {e}")
        traceback.print_exc()

    await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)
    # in the case of opaque status, we also want to receive our own opaque statuses
    self.on_opaque_status.trigger_all(request_id, status)

  @property
  def current_topology(self) -> Topology:
    return self.topology
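
  # handle_stable_diffusion tracks diffusion progress in inference_state: it bumps
  # the step counter when a denoising step completes and passes the current result
  # tensor through as the intermediate result.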
  def handle_stable_diffusion(self, inference_state, result):
    if inference_state['is_step_finished']:
      inference_state['step'] += 1
    progress = [inference_state['step'], inference_state['total_steps']]
    intermediate_result = result
    if progress[0] == progress[1]:
      intermediate_result = result
    return intermediate_result, inference_state