Переглянути джерело

Preemptively start downloads when any node starts processing a prompt. This fixes #104.

Alex Cheema · 9 місяців тому · батьківський коміт: f29963f41e
3 змінених файлів з 36 додано та 14 видалено
  1. + 19 − 14
      exo/download/hf/hf_shard_download.py
  2. + 3 − 0
      exo/inference/shard.py
  3. + 14 − 0
      main.py

+ 19 - 14
exo/download/hf/hf_shard_download.py

@@ -1,11 +1,12 @@
 import asyncio
+import traceback
 from pathlib import Path
 from typing import Dict, List, Tuple
 from exo.inference.shard import Shard
 from exo.download.shard_download import ShardDownloader
 from exo.download.download_progress import RepoProgressEvent
 from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_repo_root, get_weight_map, extract_layer_num
-from exo.helpers import AsyncCallbackSystem
+from exo.helpers import AsyncCallbackSystem, DEBUG
 
 class HFShardDownloader(ShardDownloader):
     def __init__(self):
@@ -13,25 +14,29 @@ class HFShardDownloader(ShardDownloader):
         self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
 
     async def ensure_shard(self, shard: Shard) -> Path:
-        # Cancel any overlapping downloads
-        to_remove = []
+        # If a download on this shard is already in progress, keep that one
         for active_shard, task in self.active_downloads:
-            if shard.overlaps(active_shard):
-                task.cancel()
-                try:
-                    await task
-                except asyncio.CancelledError:
-                    pass  # This is expected when cancelling a task
-                to_remove.append((active_shard, task))
+            if active_shard == shard:
+                return await task
 
-        # Remove cancelled downloads from the list
-        for item in to_remove:
-            self.active_downloads.remove(item)
+        # Cancel any downloads for this model_id on a different shard
+        to_remove = [(active_shard, task) for active_shard, task in self.active_downloads if active_shard.model_id == shard.model_id]
+        for active_shard, task in to_remove:
+            if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
+            task.cancel()
+            try:
+                await task
+            except asyncio.CancelledError:
+                pass  # This is expected when cancelling a task
+            except Exception as e:
+                if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
+                traceback.print_exc()
+        if DEBUG >= 2: print(f"Removing cancelled downloads: {to_remove}")
+        self.active_downloads = [(active_shard, task) for active_shard, task in self.active_downloads if active_shard.model_id != shard.model_id]
 
         # Start new download
         download_task = asyncio.create_task(self._download_shard(shard))
         self.active_downloads.append((shard, download_task))
-
         try:
             return await download_task
         finally:

+ 3 - 0
exo/inference/shard.py

@@ -25,6 +25,9 @@ class Shard:
       "n_layers": self.n_layers,
     }
 
+  def from_dict(data: dict) -> 'Shard':
+    return Shard(**data)
+
   def overlaps(self, other: 'Shard') -> bool:
     return shards_overlap(self, other)
 

+ 14 - 0
main.py

@@ -3,6 +3,7 @@ import asyncio
 import signal
 import json
 import time
+import traceback
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
@@ -11,6 +12,7 @@ from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
 from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info, get_or_create_node_id
+from exo.inference.shard import Shard
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -60,6 +62,18 @@ server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
 api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs)
 node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
+def preemptively_start_download(request_id: str, opaque_status: str):
+    try:
+        status = json.loads(opaque_status)
+        if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
+            current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
+            if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
+            asyncio.create_task(shard_downloader.ensure_shard(current_shard))
+    except Exception as e:
+        if DEBUG >= 2:
+            print(f"Failed to preemptively start download: {e}")
+            traceback.print_exc()
+node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
 if args.prometheus_client_port:
     from exo.stats.metrics import start_metrics_server
     start_metrics_server(node, args.prometheus_client_port)