standard_node.py

from typing import List, Optional

import numpy as np

from networking import Discovery, PeerHandle, Server
from inference.inference_engine import InferenceEngine, Shard
from .node import Node
from topology.topology import Topology


class StandardNode(Node):
    def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery):
        self.id = id
        self.inference_engine = inference_engine
        self.server = server
        self.discovery = discovery
        # peers is typed as a list, so initialize it as one (the original used a dict literal).
        self.peers: List[PeerHandle] = []
        self.topology: Topology = Topology()
        self.successor: Optional[PeerHandle] = None

    async def start(self, wait_for_peers: int = 0) -> None:
        await self.server.start()
        await self.discovery.start()
        self.peers = await self.discovery.discover_peers(wait_for_peers)
        print(f"Starting with the following peers: {self.peers}")
        print("Connecting to peers...")
        for peer in self.peers:
            await peer.connect()
            print(f"Connected to {peer.id()}")

    async def stop(self) -> None:
        await self.discovery.stop()
        await self.server.stop()

    async def process_prompt(self, shard: Shard, prompt: str) -> Optional[np.ndarray]:
        print("Process prompt", shard, prompt)
        result = await self.inference_engine.infer_prompt(shard, prompt)
        print(f"Got result from prompt: {prompt}. Result: {result}")
        # Forward the intermediate result to the successor node, if one is set
        # (the original misspelled `successor` and passed no tensor).
        if self.successor:
            await self.successor.send_tensor(result)
        return result

    async def process_tensor(self, shard: Shard, tensor: np.ndarray, target: Optional[str] = None) -> Optional[np.ndarray]:
        print("Process tensor", shard, tensor)
        result = await self.inference_engine.infer_shard(shard, tensor)
        print(f"Got result from tensor of length {len(tensor)}. Result: {result}")
        # If a target peer is named, look it up and forward the result to it.
        if target:
            target_peer = next((p for p in self.peers if p.id() == target), None)
            if not target_peer:
                raise ValueError(f"Peer {target} not found")
            await target_peer.send_tensor(result)
        return result

    async def reset_shard(self, shard: Shard) -> None:
        # Delegate the reset to the inference engine.
        print(f"Resetting shard: {shard}")
        await self.inference_engine.reset_shard(shard)
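
# Usage sketch: how a StandardNode might be wired up and run. The concrete
# `server`, `discovery`, `engine`, and `shard` objects are assumptions here,
# not defined in this file; substitute the real Server, Discovery,
# InferenceEngine, and Shard implementations from the surrounding codebase.
#
#   import asyncio
#
#   async def main():
#       node = StandardNode("node-1", server, engine, discovery)
#       await node.start(wait_for_peers=1)  # wait until one peer is discovered
#       try:
#           result = await node.process_prompt(shard, "Hello, world")
#           print(result)
#       finally:
#           await node.stop()
#
#   asyncio.run(main())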