# standard_node.py
  1. from typing import List, Optional
  2. import numpy as np
  3. from networking import Discovery, PeerHandle, Server
  4. from inference.inference_engine import InferenceEngine, Shard
  5. from .node import Node
  6. class StandardNode(Node):
  7. def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery):
  8. self.id = id
  9. self.inference_engine = inference_engine
  10. self.server = server
  11. self.discovery = discovery
  12. self.peers: List[PeerHandle] = {}
  13. self.ring_order: List[str] = []
  14. async def start(self, wait_for_peers: int = 0) -> None:
  15. await self.server.start()
  16. await self.discovery.start()
  17. self.peers = await self.discovery.discover_peers(wait_for_peers)
  18. print(f"Starting with the following peers: {self.peers}")
  19. print("Connecting to peers...")
  20. for peer in self.peers:
  21. await peer.connect()
  22. print(f"Connected to {peer.id()}")
  23. async def stop(self) -> None:
  24. await self.discovery.stop()
  25. await self.server.stop()
  26. async def process_prompt(self, shard: Shard, prompt: str, target: Optional[str] = None) -> Optional[np.array]:
  27. print("Process prompt", shard, prompt, target)
  28. result = await self.inference_engine.infer_prompt(shard, prompt)
  29. # Implement prompt processing logic
  30. print(f"Got result from prompt: {prompt}. Result: {result}")
  31. # You might want to initiate inference here
  32. if target:
  33. target_peer = next((p for p in self.peers if p.id() == target), None)
  34. if not target_peer:
  35. raise ValueError(f"Peer {target} not found")
  36. await target_peer.send_tensor(result)
  37. return result
  38. async def process_tensor(self, shard: Shard, tensor: np.ndarray, target: Optional[str] = None) -> None:
  39. print("Process tensor", shard, tensor)
  40. result = await self.inference_engine.infer_shard(shard, tensor)
  41. # Implement prompt processing logic
  42. print(f"Got result from prompt: {len(tensor)}. Result: {result}")
  43. if target:
  44. target_peer = next((p for p in self.peers if p.id() == target), None)
  45. if not target_peer:
  46. raise ValueError(f"Peer {target} not found")
  47. await target_peer.send_tensor(result)
  48. return result
  49. async def reset_shard(self, shard: Shard) -> None:
  50. # Implement shard reset logic
  51. print(f"Resetting shard: {shard}")
  52. await self.inference_engine.reset_shard(shard)