Browse Source

add DEBUG flag for controlling debug logs

Alex Cheema 1 year ago
parent
commit
a933352ac3
4 changed files with 20 additions and 15 deletions
  1. exo/__init__.py (+1 −0)
  2. exo/helpers.py (+3 −0)
  3. exo/networking/grpc/grpc_discovery.py (+9 −8)
  4. exo/networking/grpc/grpc_server.py (+7 −7)

+ 1 - 0
exo/__init__.py

@@ -0,0 +1 @@
+from tinygrad.helpers import DEBUG

+ 3 - 0
exo/helpers.py

@@ -0,0 +1,3 @@
+import os
+
+DEBUG = int(os.getenv("DEBUG", default="0"))

+ 9 - 8
exo/networking/grpc/grpc_discovery.py

@@ -7,6 +7,7 @@ from ..discovery import Discovery
 from ..peer_handle import PeerHandle
 from .grpc_peer_handle import GRPCPeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities
+from exo import DEBUG
 
 class GRPCDiscovery(Discovery):
     def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1, device_capabilities=None):
@@ -38,26 +39,26 @@ class GRPCDiscovery(Discovery):
             await asyncio.gather(self.broadcast_task, self.listen_task, self.cleanup_task, return_exceptions=True)
 
     async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
-        print("Starting peer discovery process...")
+        if DEBUG >= 2: print("Starting peer discovery process...")
 
         if wait_for_peers > 0:
             while not self.known_peers:
-                print("No peers discovered yet, retrying in 1 second...")
+                if DEBUG >= 2: print("No peers discovered yet, retrying in 1 second...")
                 await asyncio.sleep(1)  # Keep trying to find peers
-            print(f"Discovered first peer: {next(iter(self.known_peers.values()))}")
+            if DEBUG >= 2: print(f"Discovered first peer: {next(iter(self.known_peers.values()))}")
 
         grace_period = 5  # seconds
         while True:
             initial_peer_count = len(self.known_peers)
-            print(f"Current number of known peers: {initial_peer_count}. Waiting {grace_period} seconds to discover more...")
+            if DEBUG >= 2: print(f"Current number of known peers: {initial_peer_count}. Waiting {grace_period} seconds to discover more...")
             await asyncio.sleep(grace_period)
             if len(self.known_peers) == initial_peer_count:
                 if wait_for_peers > 0:
-                    print(f"Waiting additional {wait_for_peers} seconds for more peers.")
+                    if DEBUG >= 2: print(f"Waiting additional {wait_for_peers} seconds for more peers.")
                     await asyncio.sleep(wait_for_peers)
                     wait_for_peers = 0
                 else:
-                    print("No new peers discovered in the last grace period. Ending discovery process.")
+                    if DEBUG >= 2: print("No new peers discovered in the last grace period. Ending discovery process.")
                     break  # No new peers found in the grace period, we are done
 
         return list(self.known_peers.values())
@@ -93,7 +94,7 @@ class GRPCDiscovery(Discovery):
             try:
                 data, addr = await asyncio.get_event_loop().sock_recvfrom(sock, 1024)
                 message = json.loads(data.decode('utf-8'))
-                print(f"received from peer {addr}: {message}")
+                if DEBUG >= 2: print(f"received from peer {addr}: {message}")
                 if message['type'] == 'discovery' and message['node_id'] != self.node_id:
                     peer_id = message['node_id']
                     peer_host = addr[0]
@@ -114,5 +115,5 @@ class GRPCDiscovery(Discovery):
             for peer_id in peers_to_remove:
                 del self.known_peers[peer_id]
                 del self.peer_last_seen[peer_id]
-                print(f"Removed peer {peer_id} due to inactivity.")
+                if DEBUG >= 2: print(f"Removed peer {peer_id} due to inactivity.")
             await asyncio.sleep(self.broadcast_interval)

+ 7 - 7
exo/networking/grpc/grpc_server.py

@@ -4,12 +4,10 @@ import numpy as np
 
 from . import node_service_pb2
 from . import node_service_pb2_grpc
+from exo import DEBUG
 from exo.inference.shard import Shard
-
 from exo.orchestration import Node
 
-import uuid
-
 class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     def __init__(self, node: Node, host: str, port: int):
         self.node = node
@@ -25,19 +23,20 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
         listen_addr = f'{self.host}:{self.port}'
         self.server.add_insecure_port(listen_addr)
         await self.server.start()
-        print(f"Server started, listening on {listen_addr}")
+        if DEBUG >= 1: print(f"Server started, listening on {listen_addr}")
 
     async def stop(self) -> None:
         if self.server:
             await self.server.stop(grace=5)
             await self.server.wait_for_termination()
-            print("Server stopped and all connections are closed")
+            if DEBUG >= 1: print("Server stopped and all connections are closed")
 
     async def SendPrompt(self, request, context):
         shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
         prompt = request.prompt
         request_id = request.request_id
         result = await self.node.process_prompt(shard, prompt, request_id)
+        if DEBUG >= 2: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
         tensor_data = result.tobytes() if result is not None else None
         return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
 
@@ -47,19 +46,20 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
         request_id = request.request_id
 
         result = await self.node.process_tensor(shard, tensor, request_id)
-        print("SendTensor tensor result", result)
+        if DEBUG >= 2: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
         tensor_data = result.tobytes() if result is not None else None
         return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
 
     async def GetInferenceResult(self, request, context):
         request_id = request.request_id
         result = await self.node.get_inference_result(request_id)
+        if DEBUG >= 2: print(f"GetInferenceResult {request_id=}: {result}")
         tensor_data = result[0].tobytes() if result[0] is not None else None
         return node_service_pb2.InferenceResult(tensor=node_service_pb2.Tensor(tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype)), is_finished=result[1]) if result[0] is not None else node_service_pb2.InferenceResult(is_finished=result[1])
 
     async def ResetShard(self, request, context):
         shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
-        print(f"Received ResetShard request: {shard}")
+        if DEBUG >= 2: print(f"Received ResetShard request: {shard}")
         await self.node.reset_shard(shard)
         return node_service_pb2.Empty()