Browse Source

adjust grpc settings, ensure connected before sending any grpc commands

Alex Cheema 3 months ago
parent
commit
4081305e60
2 changed files with 22 additions and 13 deletions
  1. 20 13
      exo/networking/grpc/grpc_peer_handle.py
  2. 2 0
      exo/networking/grpc/grpc_server.py

+ 20 - 13
exo/networking/grpc/grpc_peer_handle.py

@@ -29,15 +29,16 @@ class GRPCPeerHandle(PeerHandle):
     self.channel = None
     self.stub = None
     self.channel_options = [
-      ("grpc.max_metadata_size", 64 * 1024 * 1024),
+      ("grpc.max_metadata_size", 32 * 1024 * 1024),
       ("grpc.max_receive_message_length", 256 * 1024 * 1024),
       ("grpc.max_send_message_length", 256 * 1024 * 1024),
       ("grpc.max_concurrent_streams", 100),
       ("grpc.http2.min_time_between_pings_ms", 10000),
-      ("grpc.keepalive_time_ms", 20000),
-      ("grpc.keepalive_timeout_ms", 10000),
+      ("grpc.keepalive_time_ms", 10000),
+      ("grpc.keepalive_timeout_ms", 5000),
       ("grpc.keepalive_permit_without_calls", 1),
       ("grpc.http2.max_pings_without_data", 0),
+      ("grpc.http2.min_ping_interval_without_data_ms", 5000),
       ("grpc.tcp_nodelay", 1),
       ("grpc.optimization_target", "throughput"),
     ]
@@ -55,14 +56,13 @@ class GRPCPeerHandle(PeerHandle):
     return self._device_capabilities
 
   async def connect(self):
-    if self.channel is None:
-      self.channel = grpc.aio.insecure_channel(
-        self.address,
-        options=self.channel_options,
-        compression=grpc.Compression.Gzip
-      )
-      self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
-    await self.channel.channel_ready()
+    self.channel = grpc.aio.insecure_channel(
+      self.address,
+      options=self.channel_options,
+      compression=grpc.Compression.Gzip
+    )
+    self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
+    await asyncio.wait_for(self.channel.channel_ready(), timeout=10.0)
 
   async def is_connected(self) -> bool:
     return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY
@@ -74,7 +74,7 @@ class GRPCPeerHandle(PeerHandle):
     self.stub = None
 
   async def _ensure_connected(self):
-    if not await self.is_connected():
+    if not (await self.is_connected()):
       try:
         await asyncio.wait_for(self.connect(), timeout=10.0)
       except asyncio.TimeoutError:
@@ -98,6 +98,7 @@ class GRPCPeerHandle(PeerHandle):
       return False
 
   async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
+    await self._ensure_connected()
     request = node_service_pb2.PromptRequest(
       prompt=prompt,
       shard=node_service_pb2.Shard(
@@ -112,6 +113,7 @@ class GRPCPeerHandle(PeerHandle):
     await self.stub.SendPrompt(request)
 
   async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
+    await self._ensure_connected()
     request = node_service_pb2.TensorRequest(
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
@@ -131,6 +133,7 @@ class GRPCPeerHandle(PeerHandle):
     return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
 
   async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
+    await self._ensure_connected()
     request = node_service_pb2.ExampleRequest(
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
@@ -153,6 +156,7 @@ class GRPCPeerHandle(PeerHandle):
       return loss
 
   async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
+    await self._ensure_connected()
     request = node_service_pb2.TensorRequest(
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
@@ -171,6 +175,7 @@ class GRPCPeerHandle(PeerHandle):
     return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
 
   async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
+    await self._ensure_connected()
     request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
     response = await self.stub.CollectTopology(request)
     topology = Topology()
@@ -185,6 +190,7 @@ class GRPCPeerHandle(PeerHandle):
     return topology
 
   async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
+    await self._ensure_connected()
     tensor = None
     if isinstance(result, np.ndarray):
       tensor = node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype))
@@ -193,8 +199,9 @@ class GRPCPeerHandle(PeerHandle):
     await self.stub.SendResult(request)
 
   async def send_opaque_status(self, request_id: str, status: str) -> None:
+    await self._ensure_connected()
     request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
-    await self.stub.SendOpaqueStatus(request)
+    await asyncio.wait_for(self.stub.SendOpaqueStatus(request), timeout=10.0)
 
   def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.InferenceState:
     proto_inference_state = node_service_pb2.InferenceState()

+ 2 - 0
exo/networking/grpc/grpc_server.py

@@ -40,6 +40,8 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
         ("grpc.max_concurrent_streams", 100),
         ("grpc.tcp_nodelay", 1),
         ("grpc.optimization_target", "throughput"),
+        ("grpc.keepalive_permit_without_calls", 1),
+        ("grpc.http2.max_concurrent_streams", 0),  # Unlimited concurrent streams
       ],
     )
     node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)