123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263 |
- from typing import List, Optional
- import numpy as np
- from networking import Discovery, PeerHandle, Server
- from inference.inference_engine import InferenceEngine, Shard
- from .node import Node
class StandardNode(Node):
    """Node that runs inference locally and can forward results to a peer.

    The node owns a ``Server``, a ``Discovery`` service and an
    ``InferenceEngine``. ``start``/``stop`` manage the server and discovery
    lifecycles. ``process_prompt``/``process_tensor`` run inference on a shard
    and, when a ``target`` peer id is given, send the result to that peer via
    ``send_tensor``.
    """

    def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery):
        """Store the node's collaborators; no I/O happens until ``start``.

        Args:
            id: Identifier of this node. The name shadows the builtin ``id``
                but is kept for caller compatibility.
            server: Server started in ``start`` and stopped in ``stop``.
            inference_engine: Engine used by the ``process_*`` and
                ``reset_shard`` methods.
            discovery: Service used in ``start`` to find peers.
        """
        self.id = id
        self.inference_engine = inference_engine
        self.server = server
        self.discovery = discovery
        # Empty list until ``start`` replaces it with the discovered peers; the
        # class iterates it and searches it by ``peer.id()``.
        self.peers: List[PeerHandle] = []
        # Not read by any method in this class; presumably maintained
        # elsewhere — TODO confirm.
        self.ring_order: List[str] = []

    async def start(self, wait_for_peers: int = 0) -> None:
        """Start the server and discovery, then connect to each discovered peer.

        Args:
            wait_for_peers: Passed straight through to
                ``Discovery.discover_peers``; presumably the number of peers to
                wait for before returning — TODO confirm against Discovery.
        """
        await self.server.start()
        await self.discovery.start()
        self.peers = await self.discovery.discover_peers(wait_for_peers)
        print(f"Starting with the following peers: {self.peers}")
        print("Connecting to peers...")
        # Connect one at a time, so the first failing peer aborts the loop.
        for peer in self.peers:
            await peer.connect()
            print(f"Connected to {peer.id()}")

    async def stop(self) -> None:
        """Stop discovery first, then the server (reverse of start order)."""
        await self.discovery.stop()
        await self.server.stop()

    def _find_peer(self, peer_id: str) -> PeerHandle:
        """Return the known peer whose ``id()`` equals *peer_id*.

        Raises:
            ValueError: If no peer in ``self.peers`` has that id.
        """
        peer = next((p for p in self.peers if p.id() == peer_id), None)
        if peer is None:
            raise ValueError(f"Peer {peer_id} not found")
        return peer

    async def process_prompt(self, shard: Shard, prompt: str, target: Optional[str] = None) -> Optional[np.ndarray]:
        """Run inference on *prompt* for *shard* and optionally forward the result.

        Args:
            shard: Shard to run the prompt through.
            prompt: Input text handed to ``InferenceEngine.infer_prompt``.
            target: Peer id to send the result to. Falsy values (``None``,
                ``""``) mean "do not forward".

        Returns:
            Whatever ``infer_prompt`` returned.

        Raises:
            ValueError: If *target* is given but is not a known peer.
        """
        print("Process prompt", shard, prompt, target)
        result = await self.inference_engine.infer_prompt(shard, prompt)
        print(f"Got result from prompt: {prompt}. Result: {result}")
        if target:
            await self._find_peer(target).send_tensor(result)
        return result

    async def process_tensor(self, shard: Shard, tensor: np.ndarray, target: Optional[str] = None) -> Optional[np.ndarray]:
        """Run inference on an intermediate *tensor* for *shard* and optionally forward the result.

        Args:
            shard: Shard to run the tensor through.
            tensor: Input handed to ``InferenceEngine.infer_shard``.
            target: Peer id to send the result to. Falsy values mean
                "do not forward".

        Returns:
            Whatever ``infer_shard`` returned (previously mis-annotated as
            ``None`` even though the value was returned).

        Raises:
            ValueError: If *target* is given but is not a known peer.
        """
        print("Process tensor", shard, tensor)
        result = await self.inference_engine.infer_shard(shard, tensor)
        # np.shape works for 0-d arrays, where len() would raise TypeError
        # after inference had already run.
        print(f"Got result from tensor of shape {np.shape(tensor)}. Result: {result}")
        if target:
            await self._find_peer(target).send_tensor(result)
        return result

    async def reset_shard(self, shard: Shard) -> None:
        """Ask the inference engine to reset *shard*."""
        print(f"Resetting shard: {shard}")
        await self.inference_engine.reset_shard(shard)
|