# standard_node.py
  1. from typing import List, Optional, Callable
  2. import numpy as np
  3. from networking import Discovery, PeerHandle, Server
  4. from inference.inference_engine import InferenceEngine, Shard
  5. from .node import Node
  6. from topology.topology import Topology
  7. from topology.device_capabilities import device_capabilities
  8. from topology.partitioning_strategy import PartitioningStrategy
  9. from topology.partitioning_strategy import Partition
  10. class StandardNode(Node):
  11. def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, on_token: Callable[[List[int]], None] = None, max_generate_tokens: int = 50):
  12. self.id = id
  13. self.inference_engine = inference_engine
  14. self.server = server
  15. self.discovery = discovery
  16. self.partitioning_strategy = partitioning_strategy
  17. self.peers: List[PeerHandle] = {}
  18. self.topology: Topology = Topology()
  19. self.device_capabilities = device_capabilities()
  20. self.buffered_token_output: List[int] = []
  21. self.on_token = on_token
  22. self.max_generate_tokens = max_generate_tokens
  23. async def start(self, wait_for_peers: int = 0) -> None:
  24. await self.server.start()
  25. await self.discovery.start()
  26. self.peers = await self.discovery.discover_peers(wait_for_peers)
  27. print(f"Starting with the following peers: {self.peers}")
  28. print("Connecting to peers...")
  29. for peer in self.peers:
  30. await peer.connect()
  31. print(f"Connected to {peer.id()}")
  32. await self.collect_topology()
  33. print(f"Collected topology: {self.topology}")
  34. async def stop(self) -> None:
  35. await self.discovery.stop()
  36. await self.server.stop()
  37. async def process_prompt(self, shard: Shard, prompt: str) -> Optional[np.ndarray]:
  38. print("process prompt", shard, prompt)
  39. result, is_finished = await self.inference_engine.infer_prompt(self.get_current_shard(shard), prompt)
  40. print(f"result size: {result.size}, is finished: {is_finished}")
  41. if result.size == 1:
  42. self.buffered_token_output.append(result.item())
  43. self.on_token(self.buffered_token_output)
  44. if not is_finished and len(self.buffered_token_output) < self.max_generate_tokens:
  45. await self.forward_tensor_to_next_shard(shard, result)
  46. return np.array(self.buffered_token_output) if self.buffered_token_output else None
  47. async def process_tensor(self, shard: Shard, tensor: np.ndarray) -> Optional[np.ndarray]:
  48. result, is_finished = await self.inference_engine.infer_tensor(self.get_current_shard(shard), tensor)
  49. print(f"result size: {result.size}, is finished: {is_finished}")
  50. if result.size == 1:
  51. self.buffered_token_output.append(result.item())
  52. self.on_token(self.buffered_token_output)
  53. if not is_finished and len(self.buffered_token_output) < self.max_generate_tokens:
  54. await self.forward_tensor_to_next_shard(shard, result)
  55. return np.array(self.buffered_token_output) if self.buffered_token_output else None
  56. async def forward_tensor_to_next_shard(self, shard: Shard, tensor: np.ndarray) -> None:
  57. if not self.partitioning_strategy:
  58. print("No partitioning strategy found. Skipping forward.")
  59. return
  60. partitions = self.partitioning_strategy.partition(self.topology)
  61. current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
  62. print(f"Current partition index: {current_partition_index}")
  63. if current_partition_index is not None:
  64. next_partition_index = (current_partition_index + 1) % len(partitions)
  65. next_partition: Partition = partitions[next_partition_index]
  66. print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}")
  67. if next_partition:
  68. if next_partition.node_id == self.id:
  69. await self.process_tensor(shard, tensor)
  70. return
  71. target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
  72. if not target_peer:
  73. raise ValueError(f"Peer for {next_partition} not found")
  74. start_layer = int(next_partition.start * shard.n_layers)
  75. end_layer = int(next_partition.end * shard.n_layers) - 1
  76. next_shard = Shard(shard.model_id, start_layer, end_layer, shard.n_layers)
  77. print(f"Sending tensor to {target_peer.id()} for shard: {next_shard}")
  78. await target_peer.send_tensor(next_shard, tensor)
  79. def get_current_shard(self, shard: Shard) -> Shard:
  80. partitions = self.partitioning_strategy.partition(self.topology)
  81. current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
  82. if current_partition_index is None:
  83. raise ValueError(f"No current partition found for node: {self.id}")
  84. current_partition = partitions[current_partition_index]
  85. start_layer = int(current_partition.start * shard.n_layers)
  86. end_layer = int(current_partition.end * shard.n_layers) - 1
  87. return Shard(shard.model_id, start_layer, end_layer, shard.n_layers)
  88. async def reset_shard(self, shard: Shard) -> None:
  89. # Implement shard reset logic
  90. print(f"Resetting shard: {shard}")
  91. self.buffered_token_output = []
  92. await self.inference_engine.reset_shard(self.get_current_shard(shard))
  93. async def collect_topology(self, max_depth: int = 4) -> Topology:
  94. self.topology.update_node(self.id, self.device_capabilities)
  95. for peer in self.peers:
  96. self.topology.update_node(peer.id(), peer.device_capabilities())
  97. self.topology.add_edge(self.id, peer.id())
  98. if max_depth > 0:
  99. other_topology = await peer.collect_topology(max_depth = max_depth - 1)
  100. print(f"Collected topology from: {peer.id()}: {other_topology}")
  101. self.topology.merge(other_topology)
  102. return self.topology