peer_handle.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. from abc import ABC, abstractmethod
  2. from typing import Optional, Tuple
  3. import numpy as np
  4. from inference.shard import Shard
  5. from topology.device_capabilities import DeviceCapabilities
  6. from topology.topology import Topology
  7. class PeerHandle(ABC):
  8. @abstractmethod
  9. def id(self) -> str:
  10. pass
  11. @abstractmethod
  12. def device_capabilities(self) -> DeviceCapabilities:
  13. pass
  14. @abstractmethod
  15. async def connect(self) -> None:
  16. pass
  17. @abstractmethod
  18. async def is_connected(self) -> bool:
  19. pass
  20. @abstractmethod
  21. async def disconnect(self) -> None:
  22. pass
  23. @abstractmethod
  24. async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
  25. pass
  26. @abstractmethod
  27. async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None) -> Optional[np.array]:
  28. pass
  29. @abstractmethod
  30. async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
  31. pass
  32. @abstractmethod
  33. async def reset_shard(self, shard: Shard) -> None:
  34. pass
  35. async def collect_topology(self, max_depth: int) -> Topology:
  36. pass