grpc_peer_handle.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import grpc
  2. import numpy as np
  3. from typing import Optional
  4. # These would be generated from the .proto file
  5. from . import node_service_pb2
  6. from . import node_service_pb2_grpc
  7. from ..peer_handle import PeerHandle
  8. from inference.shard import Shard
  9. class GRPCPeerHandle(PeerHandle):
  10. def __init__(self, id: str, address: str):
  11. self._id = id
  12. self.address = address
  13. def id(self) -> str:
  14. return self._id
  15. async def connect(self):
  16. self.channel = grpc.aio.insecure_channel(self.address)
  17. self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
  18. async def disconnect(self):
  19. await self.channel.close()
  20. async def send_prompt(self, shard: Shard, prompt: str) -> Optional[np.array]:
  21. request = node_service_pb2.PromptRequest(prompt=prompt, shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers))
  22. response = await self.stub.SendPrompt(request)
  23. print(f"Sent prompt to {self.address}: {prompt}")
  24. if not response.tensor_data or not response.shape or not response.dtype:
  25. return None
  26. return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
  27. async def send_tensor(self, shard: Shard, tensor: np.ndarray) -> Optional[np.array]:
  28. request = node_service_pb2.TensorRequest(
  29. shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers),
  30. tensor = node_service_pb2.Tensor(
  31. tensor_data=tensor.tobytes(),
  32. shape=tensor.shape,
  33. dtype=str(tensor.dtype)
  34. ),
  35. )
  36. response = await self.stub.SendTensor(request)
  37. if not response.tensor_data or not response.shape or not response.dtype:
  38. return None
  39. return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
  40. async def reset_shard(self, shard: Shard) -> None:
  41. request = node_service_pb2.ResetShardRequest(shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers))
  42. await self.stub.ResetShard(request)
  43. print(f"Reset shard {shard} on {self.address}")