|
@@ -10,6 +10,7 @@ from ..peer_handle import PeerHandle
|
|
|
from exo.inference.shard import Shard
|
|
|
from exo.topology.topology import Topology
|
|
|
from exo.topology.device_capabilities import DeviceCapabilities
|
|
|
+from exo.helpers import DEBUG
|
|
|
|
|
|
|
|
|
class GRPCPeerHandle(PeerHandle):
|
|
@@ -30,8 +31,9 @@ class GRPCPeerHandle(PeerHandle):
|
|
|
return self._device_capabilities
|
|
|
|
|
|
async def connect(self):
|
|
|
- self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32*1024*1024)])
|
|
|
- self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
|
|
|
+ if self.channel is None:
|
|
|
+ self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32*1024*1024)])
|
|
|
+ self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
|
|
|
await self.channel.channel_ready()
|
|
|
|
|
|
async def is_connected(self) -> bool:
|
|
@@ -43,14 +45,23 @@ class GRPCPeerHandle(PeerHandle):
|
|
|
self.channel = None
|
|
|
self.stub = None
|
|
|
|
|
|
+ async def _ensure_connected(self):
|
|
|
+ if not await self.is_connected():
|
|
|
+ await self.connect()
|
|
|
+
|
|
|
async def health_check(self) -> bool:
|
|
|
try:
|
|
|
+ await self._ensure_connected()
|
|
|
request = node_service_pb2.HealthCheckRequest()
|
|
|
response = await asyncio.wait_for(self.stub.HealthCheck(request), timeout=5)
|
|
|
return response.is_healthy
|
|
|
except asyncio.TimeoutError:
|
|
|
return False
|
|
|
except:
|
|
|
+ if DEBUG >= 4:
|
|
|
+ print(f"Health check failed for {self._id}@{self.address}.")
|
|
|
+ import traceback
|
|
|
+ traceback.print_exc()
|
|
|
return False
|
|
|
|
|
|
async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
|