|
@@ -1,11 +1,12 @@
|
|
import asyncio
|
|
import asyncio
|
|
|
|
+import traceback
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
from typing import Dict, List, Tuple
|
|
from typing import Dict, List, Tuple
|
|
from exo.inference.shard import Shard
|
|
from exo.inference.shard import Shard
|
|
from exo.download.shard_download import ShardDownloader
|
|
from exo.download.shard_download import ShardDownloader
|
|
from exo.download.download_progress import RepoProgressEvent
|
|
from exo.download.download_progress import RepoProgressEvent
|
|
from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_repo_root, get_weight_map, extract_layer_num
|
|
from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_repo_root, get_weight_map, extract_layer_num
|
|
-from exo.helpers import AsyncCallbackSystem
|
|
|
|
|
|
+from exo.helpers import AsyncCallbackSystem, DEBUG
|
|
|
|
|
|
class HFShardDownloader(ShardDownloader):
|
|
class HFShardDownloader(ShardDownloader):
|
|
def __init__(self):
|
|
def __init__(self):
|
|
@@ -13,25 +14,29 @@ class HFShardDownloader(ShardDownloader):
|
|
self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
|
|
self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
|
|
|
|
|
|
async def ensure_shard(self, shard: Shard) -> Path:
|
|
async def ensure_shard(self, shard: Shard) -> Path:
|
|
- # Cancel any overlapping downloads
|
|
|
|
- to_remove = []
|
|
|
|
|
|
+ # If a download on this shard is already in progress, keep that one
|
|
for active_shard, task in self.active_downloads:
|
|
for active_shard, task in self.active_downloads:
|
|
- if shard.overlaps(active_shard):
|
|
|
|
- task.cancel()
|
|
|
|
- try:
|
|
|
|
- await task
|
|
|
|
- except asyncio.CancelledError:
|
|
|
|
- pass # This is expected when cancelling a task
|
|
|
|
- to_remove.append((active_shard, task))
|
|
|
|
|
|
+ if active_shard == shard:
|
|
|
|
+ return await task
|
|
|
|
|
|
- # Remove cancelled downloads from the list
|
|
|
|
- for item in to_remove:
|
|
|
|
- self.active_downloads.remove(item)
|
|
|
|
|
|
+ # Cancel any downloads for this model_id on a different shard
|
|
|
|
+ to_remove = [(active_shard, task) for active_shard, task in self.active_downloads if active_shard.model_id == shard.model_id]
|
|
|
|
+ for active_shard, task in to_remove:
|
|
|
|
+ if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
|
|
|
|
+ task.cancel()
|
|
|
|
+ try:
|
|
|
|
+ await task
|
|
|
|
+ except asyncio.CancelledError:
|
|
|
|
+ pass # This is expected when cancelling a task
|
|
|
|
+ except Exception as e:
|
|
|
|
+ if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
|
|
|
|
+ traceback.print_exc()
|
|
|
|
+ if DEBUG >= 2: print(f"Removing cancelled downloads: {to_remove}")
|
|
|
|
+ self.active_downloads = [(active_shard, task) for active_shard, task in self.active_downloads if active_shard.model_id != shard.model_id]
|
|
|
|
|
|
# Start new download
|
|
# Start new download
|
|
download_task = asyncio.create_task(self._download_shard(shard))
|
|
download_task = asyncio.create_task(self._download_shard(shard))
|
|
self.active_downloads.append((shard, download_task))
|
|
self.active_downloads.append((shard, download_task))
|
|
-
|
|
|
|
try:
|
|
try:
|
|
return await download_task
|
|
return await download_task
|
|
finally:
|
|
finally:
|