浏览代码

preemptively start downloads when any node starts processing a prompt. this fixes #104

Alex Cheema 9 月之前
父节点
当前提交
f29963f41e
共有 3 个文件被更改,包括 36 次插入和 14 次删除
  1. 19 14
      exo/download/hf/hf_shard_download.py
  2. 3 0
      exo/inference/shard.py
  3. 14 0
      main.py

+ 19 - 14
exo/download/hf/hf_shard_download.py

@@ -1,11 +1,12 @@
 import asyncio
 import asyncio
+import traceback
 from pathlib import Path
 from pathlib import Path
 from typing import Dict, List, Tuple
 from typing import Dict, List, Tuple
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
 from exo.download.shard_download import ShardDownloader
 from exo.download.shard_download import ShardDownloader
 from exo.download.download_progress import RepoProgressEvent
 from exo.download.download_progress import RepoProgressEvent
 from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_repo_root, get_weight_map, extract_layer_num
 from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_repo_root, get_weight_map, extract_layer_num
-from exo.helpers import AsyncCallbackSystem
+from exo.helpers import AsyncCallbackSystem, DEBUG
 
 
 class HFShardDownloader(ShardDownloader):
 class HFShardDownloader(ShardDownloader):
     def __init__(self):
     def __init__(self):
@@ -13,25 +14,29 @@ class HFShardDownloader(ShardDownloader):
         self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
         self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
 
 
     async def ensure_shard(self, shard: Shard) -> Path:
     async def ensure_shard(self, shard: Shard) -> Path:
-        # Cancel any overlapping downloads
-        to_remove = []
+        # If a download on this shard is already in progress, keep that one
         for active_shard, task in self.active_downloads:
         for active_shard, task in self.active_downloads:
-            if shard.overlaps(active_shard):
-                task.cancel()
-                try:
-                    await task
-                except asyncio.CancelledError:
-                    pass  # This is expected when cancelling a task
-                to_remove.append((active_shard, task))
+            if active_shard == shard:
+                return await task
 
 
-        # Remove cancelled downloads from the list
-        for item in to_remove:
-            self.active_downloads.remove(item)
+        # Cancel any downloads for this model_id on a different shard
+        to_remove = [(active_shard, task) for active_shard, task in self.active_downloads if active_shard.model_id == shard.model_id]
+        for active_shard, task in to_remove:
+            if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
+            task.cancel()
+            try:
+                await task
+            except asyncio.CancelledError:
+                pass  # This is expected when cancelling a task
+            except Exception as e:
+                if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
+                traceback.print_exc()
+        if DEBUG >= 2: print(f"Removing cancelled downloads: {to_remove}")
+        self.active_downloads = [(active_shard, task) for active_shard, task in self.active_downloads if active_shard.model_id != shard.model_id]
 
 
         # Start new download
         # Start new download
         download_task = asyncio.create_task(self._download_shard(shard))
         download_task = asyncio.create_task(self._download_shard(shard))
         self.active_downloads.append((shard, download_task))
         self.active_downloads.append((shard, download_task))
-
         try:
         try:
             return await download_task
             return await download_task
         finally:
         finally:

+ 3 - 0
exo/inference/shard.py

@@ -25,6 +25,9 @@ class Shard:
       "n_layers": self.n_layers,
       "n_layers": self.n_layers,
     }
     }
 
 
+  def from_dict(data: dict) -> 'Shard':
+    return Shard(**data)
+
   def overlaps(self, other: 'Shard') -> bool:
   def overlaps(self, other: 'Shard') -> bool:
     return shards_overlap(self, other)
     return shards_overlap(self, other)
 
 

+ 14 - 0
main.py

@@ -3,6 +3,7 @@ import asyncio
 import signal
 import signal
 import json
 import json
 import time
 import time
+import traceback
 from exo.orchestration.standard_node import StandardNode
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
@@ -11,6 +12,7 @@ from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader
 from exo.download.shard_download import ShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
 from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info, get_or_create_node_id
 from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info, get_or_create_node_id
+from exo.inference.shard import Shard
 
 
 # parse args
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -60,6 +62,18 @@ server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
 node.server = server
 api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs)
 api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs)
 node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
 node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
+def preemptively_start_download(request_id: str, opaque_status: str):
+    try:
+        status = json.loads(opaque_status)
+        if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
+            current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
+            if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
+            asyncio.create_task(shard_downloader.ensure_shard(current_shard))
+    except Exception as e:
+        if DEBUG >= 2:
+            print(f"Failed to preemptively start download: {e}")
+            traceback.print_exc()
+node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
 if args.prometheus_client_port:
 if args.prometheus_client_port:
     from exo.stats.metrics import start_metrics_server
     from exo.stats.metrics import start_metrics_server
     start_metrics_server(node, args.prometheus_client_port)
     start_metrics_server(node, args.prometheus_client_port)