|
@@ -5,7 +5,7 @@ from typing import Dict, List, Tuple
|
|
|
from exo.inference.shard import Shard
|
|
|
from exo.download.shard_download import ShardDownloader
|
|
|
from exo.download.download_progress import RepoProgressEvent
|
|
|
-from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_repo_root, get_weight_map, extract_layer_num
|
|
|
+from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns
|
|
|
from exo.helpers import AsyncCallbackSystem, DEBUG
|
|
|
|
|
|
class HFShardDownloader(ShardDownloader):
|
|
@@ -57,7 +57,7 @@ class HFShardDownloader(ShardDownloader):
|
|
|
self._on_progress.trigger_all(shard, event)
|
|
|
|
|
|
weight_map = await get_weight_map(shard.model_id)
|
|
|
- allow_patterns = self._get_allow_patterns(weight_map, shard.start_layer, shard.end_layer)
|
|
|
+ allow_patterns = get_allow_patterns(weight_map, shard)
|
|
|
|
|
|
return await download_repo_files(
|
|
|
repo_id=shard.model_id,
|
|
@@ -65,25 +65,6 @@ class HFShardDownloader(ShardDownloader):
|
|
|
allow_patterns=allow_patterns
|
|
|
)
|
|
|
|
|
|
- @staticmethod
|
|
|
- def _get_allow_patterns(weight_map: Dict[str, str], start_layer: int, end_layer: int) -> List[str]:
|
|
|
- default_patterns = [
|
|
|
- "*.json",
|
|
|
- "*.py",
|
|
|
- "tokenizer.model",
|
|
|
- "*.tiktoken",
|
|
|
- "*.txt",
|
|
|
- ]
|
|
|
- shard_specific_patterns = []
|
|
|
- if weight_map:
|
|
|
- for tensor_name, filename in weight_map.items():
|
|
|
- layer_num = extract_layer_num(tensor_name)
|
|
|
- if layer_num is not None and start_layer <= layer_num <= end_layer:
|
|
|
- shard_specific_patterns.append(filename)
|
|
|
- else:
|
|
|
- shard_specific_patterns = ["*.safetensors"]
|
|
|
- return list(set(default_patterns + shard_specific_patterns)) # Remove duplicates
|
|
|
-
|
|
|
@property
|
|
|
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
|
|
|
return self._on_progress
|