|
@@ -29,15 +29,16 @@ class GRPCPeerHandle(PeerHandle):
|
|
|
self.channel = None
|
|
|
self.stub = None
|
|
|
self.channel_options = [
|
|
|
- ("grpc.max_metadata_size", 64 * 1024 * 1024),
|
|
|
+ ("grpc.max_metadata_size", 32 * 1024 * 1024),
|
|
|
("grpc.max_receive_message_length", 256 * 1024 * 1024),
|
|
|
("grpc.max_send_message_length", 256 * 1024 * 1024),
|
|
|
("grpc.max_concurrent_streams", 100),
|
|
|
("grpc.http2.min_time_between_pings_ms", 10000),
|
|
|
- ("grpc.keepalive_time_ms", 20000),
|
|
|
- ("grpc.keepalive_timeout_ms", 10000),
|
|
|
+ ("grpc.keepalive_time_ms", 10000),
|
|
|
+ ("grpc.keepalive_timeout_ms", 5000),
|
|
|
("grpc.keepalive_permit_without_calls", 1),
|
|
|
("grpc.http2.max_pings_without_data", 0),
|
|
|
+ ("grpc.http2.min_ping_interval_without_data_ms", 5000),
|
|
|
("grpc.tcp_nodelay", 1),
|
|
|
("grpc.optimization_target", "throughput"),
|
|
|
]
|
|
@@ -55,14 +56,13 @@ class GRPCPeerHandle(PeerHandle):
|
|
|
return self._device_capabilities
|
|
|
|
|
|
async def connect(self):
|
|
|
- if self.channel is None:
|
|
|
- self.channel = grpc.aio.insecure_channel(
|
|
|
- self.address,
|
|
|
- options=self.channel_options,
|
|
|
- compression=grpc.Compression.Gzip
|
|
|
- )
|
|
|
- self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
|
|
|
- await self.channel.channel_ready()
|
|
|
+ self.channel = grpc.aio.insecure_channel(
|
|
|
+ self.address,
|
|
|
+ options=self.channel_options,
|
|
|
+ compression=grpc.Compression.Gzip
|
|
|
+ )
|
|
|
+ self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
|
|
|
+ await asyncio.wait_for(self.channel.channel_ready(), timeout=10.0)
|
|
|
|
|
|
async def is_connected(self) -> bool:
|
|
|
return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY
|
|
@@ -74,7 +74,7 @@ class GRPCPeerHandle(PeerHandle):
|
|
|
self.stub = None
|
|
|
|
|
|
async def _ensure_connected(self):
|
|
|
- if not await self.is_connected():
|
|
|
+ if not (await self.is_connected()):
|
|
|
try:
|
|
|
await asyncio.wait_for(self.connect(), timeout=10.0)
|
|
|
except asyncio.TimeoutError:
|
|
@@ -98,6 +98,7 @@ class GRPCPeerHandle(PeerHandle):
|
|
|
return False
|
|
|
|
|
|
async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
|
|
|
+ await self._ensure_connected()
|
|
|
request = node_service_pb2.PromptRequest(
|
|
|
prompt=prompt,
|
|
|
shard=node_service_pb2.Shard(
|
|
@@ -112,6 +113,7 @@ class GRPCPeerHandle(PeerHandle):
|
|
|
await self.stub.SendPrompt(request)
|
|
|
|
|
|
async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
|
|
|
+ await self._ensure_connected()
|
|
|
request = node_service_pb2.TensorRequest(
|
|
|
shard=node_service_pb2.Shard(
|
|
|
model_id=shard.model_id,
|
|
@@ -131,6 +133,7 @@ class GRPCPeerHandle(PeerHandle):
|
|
|
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
|
|
|
|
|
|
async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
|
|
|
+ await self._ensure_connected()
|
|
|
request = node_service_pb2.ExampleRequest(
|
|
|
shard=node_service_pb2.Shard(
|
|
|
model_id=shard.model_id,
|
|
@@ -153,6 +156,7 @@ class GRPCPeerHandle(PeerHandle):
|
|
|
return loss
|
|
|
|
|
|
async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
|
|
|
+ await self._ensure_connected()
|
|
|
request = node_service_pb2.TensorRequest(
|
|
|
shard=node_service_pb2.Shard(
|
|
|
model_id=shard.model_id,
|
|
@@ -171,6 +175,7 @@ class GRPCPeerHandle(PeerHandle):
|
|
|
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
|
|
|
|
|
|
async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
|
|
|
+ await self._ensure_connected()
|
|
|
request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
|
|
|
response = await self.stub.CollectTopology(request)
|
|
|
topology = Topology()
|
|
@@ -185,6 +190,7 @@ class GRPCPeerHandle(PeerHandle):
|
|
|
return topology
|
|
|
|
|
|
async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
|
|
|
+ await self._ensure_connected()
|
|
|
tensor = None
|
|
|
if isinstance(result, np.ndarray):
|
|
|
tensor = node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype))
|
|
@@ -193,8 +199,9 @@ class GRPCPeerHandle(PeerHandle):
|
|
|
await self.stub.SendResult(request)
|
|
|
|
|
|
async def send_opaque_status(self, request_id: str, status: str) -> None:
|
|
|
+ await self._ensure_connected()
|
|
|
request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
|
|
|
- await self.stub.SendOpaqueStatus(request)
|
|
|
+ await asyncio.wait_for(self.stub.SendOpaqueStatus(request), timeout=10.0)
|
|
|
|
|
|
def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.InferenceState:
|
|
|
proto_inference_state = node_service_pb2.InferenceState()
|