
download edge cases

Alex Cheema, 9 months ago
commit 6bddb2a9dc

+ 10 - 9
exo/download/hf/hf_shard_download.py

@@ -10,18 +10,19 @@ from exo.helpers import AsyncCallbackSystem, DEBUG
 
 class HFShardDownloader(ShardDownloader):
     def __init__(self):
-        self.active_downloads: List[Tuple[Shard, asyncio.Task]] = []
+        self.active_downloads: Dict[Shard, asyncio.Task] = {}
         self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
 
     async def ensure_shard(self, shard: Shard) -> Path:
         # If a download on this shard is already in progress, keep that one
-        for active_shard, task in self.active_downloads:
+        for active_shard, task in self.active_downloads.items():
             if active_shard == shard:
+                if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
                 return await task
 
         # Cancel any downloads for this model_id on a different shard
-        to_remove = [(active_shard, task) for active_shard, task in self.active_downloads if active_shard.model_id == shard.model_id]
-        for active_shard, task in to_remove:
+        existing_active_downloads = [(active_shard, task) for active_shard, task in self.active_downloads.items() if active_shard.model_id == shard.model_id]
+        for active_shard, task in existing_active_downloads:
             if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
             task.cancel()
             try:
@@ -31,18 +32,18 @@ class HFShardDownloader(ShardDownloader):
             except Exception as e:
                 if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
                 traceback.print_exc()
-        if DEBUG >= 2: print(f"Removing cancelled downloads: {to_remove}")
-        self.active_downloads = [(active_shard, task) for active_shard, task in self.active_downloads if active_shard.model_id != shard.model_id]
+        self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}
 
         # Start new download
         download_task = asyncio.create_task(self._download_shard(shard))
-        self.active_downloads.append((shard, download_task))
+        self.active_downloads[shard] = download_task
         try:
             return await download_task
         finally:
             # Ensure the task is removed even if an exception occurs
-            if (shard, download_task) in self.active_downloads:
-                self.active_downloads.remove((shard, download_task))
+            if DEBUG >= 2: print(f"Removing download task for {shard}: {shard in self.active_downloads}")
+            if shard in self.active_downloads:
+                self.active_downloads.pop(shard)
 
     async def _download_shard(self, shard: Shard) -> Path:
         async def wrapped_progress_callback(event: RepoProgressEvent):
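
This change moves the download registry from a list of (shard, task) tuples to a dict keyed by Shard, so a second request for a shard that is already downloading simply awaits the in-flight task, and cleanup becomes a single pop. A minimal, self-contained sketch of that dedup pattern (Deduper, fetch, and the string keys are illustrative stand-ins, not exo's API):

import asyncio
from typing import Dict

class Deduper:
    def __init__(self):
        self.active: Dict[str, asyncio.Task] = {}

    async def fetch(self, key: str) -> str:
        # A fetch for this key is already in flight: await the existing task.
        if key in self.active:
            return await self.active[key]
        task = asyncio.create_task(self._do_fetch(key))
        self.active[key] = task
        try:
            return await task
        finally:
            # Remove the entry even if the task failed or was cancelled.
            self.active.pop(key, None)

    async def _do_fetch(self, key: str) -> str:
        await asyncio.sleep(0.1)  # stand-in for the real download I/O
        return f"result for {key}"

async def main():
    d = Deduper()
    # Two concurrent requests for the same key share one underlying fetch.
    print(await asyncio.gather(d.fetch("a"), d.fetch("a")))

asyncio.run(main())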

+ 6 - 3
exo/inference/shard.py

@@ -1,13 +1,16 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
 
-@dataclass
+@dataclass(frozen=True)
 class Shard:
   model_id: str
   start_layer: int
   end_layer: int
   n_layers: int
 
+  def __hash__(self):
+    return hash((self.model_id, self.start_layer, self.end_layer, self.n_layers))
+
   def is_first_layer(self) -> bool:
     return self.start_layer == 0
 
@@ -35,4 +38,4 @@ def shards_overlap(shard1: Shard, shard2: Shard) -> bool:
   return (
       shard1.model_id == shard2.model_id
       and max(shard1.start_layer, shard2.start_layer) <= min(shard1.end_layer, shard2.end_layer)
-  )
+  )
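
Making Shard a frozen dataclass is what lets it serve as the dict key above: frozen instances are immutable, and instances with equal fields hash identically. (With frozen=True and the default eq=True, dataclass would generate __hash__ automatically; the explicit override pins the hash to exactly these four fields.) A small sketch of the idea, with a hypothetical Key class standing in for Shard:

from dataclasses import dataclass

# Hypothetical stand-in for Shard: frozen => immutable and hashable.
@dataclass(frozen=True)
class Key:
    model_id: str
    start_layer: int

downloads = {Key("llama-3-8b", 0): "download-task"}
# Equal fields -> equal hash, so a fresh instance finds the existing entry.
assert Key("llama-3-8b", 0) in downloads
# Mutation is rejected, keeping hashes stable while keys live in the dict.
try:
    Key("llama-3-8b", 0).start_layer = 1
except AttributeError as e:  # dataclasses.FrozenInstanceError
    print(e)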

+ 6 - 6
exo/networking/grpc/grpc_server.py

@@ -48,7 +48,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     image_str = request.image_str
     request_id = request.request_id
     result = await self.node.process_prompt(shard, prompt, image_str, request_id)
-    if DEBUG >= 2: print(f"SendPrompt {shard=} {prompt=} {image_str=} {request_id=} result: {result}")
+    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {image_str=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
 
@@ -64,14 +64,14 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     inference_state = request.inference_state
 
     result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
-    if DEBUG >= 2: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
+    if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
 
   async def GetInferenceResult(self, request, context):
     request_id = request.request_id
     result = await self.node.get_inference_result(request_id)
-    if DEBUG >= 2: print(f"GetInferenceResult {request_id=}: {result}")
+    if DEBUG >= 5: print(f"GetInferenceResult {request_id=}: {result}")
     tensor_data = result[0].tobytes() if result[0] is not None else None
     return (
       node_service_pb2.InferenceResult(
@@ -96,20 +96,20 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
       for node_id, cap in topology.nodes.items()
     }
     peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
-    if DEBUG >= 2: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
+    if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
     return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
 
   async def SendResult(self, request, context):
     request_id = request.request_id
     result = request.result
     is_finished = request.is_finished
-    if DEBUG >= 2: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
+    if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
     self.node.on_token.trigger_all(request_id, result, is_finished)
     return node_service_pb2.Empty()
 
   async def SendOpaqueStatus(self, request, context):
     request_id = request.request_id
     status = request.status
-    if DEBUG >= 2: print(f"Received SendOpaqueStatus request: {request_id=} {status=}")
+    if DEBUG >= 5: print(f"Received SendOpaqueStatus request: {request_id=} {status=}")
     self.node.on_opaque_status.trigger_all(request_id, status)
     return node_service_pb2.Empty()
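
The grpc_server.py changes only raise the verbosity threshold of these per-request prints from DEBUG >= 2 to DEBUG >= 5, so routine runs stay quiet while very verbose tracing remains available. The convention is a single integer level, with noisier messages behind higher thresholds; a sketch of the pattern (reading the level from a DEBUG environment variable is an assumption of this sketch):

import os

# Tiered verbosity: one integer level; higher thresholds gate noisier output.
DEBUG = int(os.getenv("DEBUG", "0"))

def send_prompt(request_id: str, prompt: str):
    if DEBUG >= 2: print(f"handling request {request_id}")          # routine trace
    if DEBUG >= 5: print(f"full payload: {request_id=} {prompt=}")  # very chatty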

+ 3 - 3
main.py

@@ -9,7 +9,7 @@ from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
-from exo.download.shard_download import ShardDownloader
+from exo.download.shard_download import ShardDownloader, RepoProgressEvent
 from exo.download.hf.hf_shard_download import HFShardDownloader
 from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info, get_or_create_node_id
 from exo.inference.shard import Shard
@@ -79,10 +79,10 @@ if args.prometheus_client_port:
     start_metrics_server(node, args.prometheus_client_port)
 
 last_broadcast_time = 0
-def throttled_broadcast(shard, event):
+def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
     global last_broadcast_time
     current_time = time.time()
-    if current_time - last_broadcast_time >= 0.1:
+    if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
         last_broadcast_time = current_time
         asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
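
The main.py change types the callback and adds one rule to the throttle: a "complete" event is always broadcast, even if it arrives inside the 100 ms window, so peers never miss the terminal progress state. The core of that rule extracted into a standalone sketch (should_broadcast is illustrative, not exo's API):

import time

last_broadcast_time = 0.0

def should_broadcast(status: str, min_interval: float = 0.1) -> bool:
    # Rate-limit to one broadcast per min_interval, but never drop the
    # terminal "complete" event: the final state must always go out.
    global last_broadcast_time
    now = time.time()
    if status == "complete" or now - last_broadcast_time >= min_interval:
        last_broadcast_time = now
        return True
    return False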