Ver código fonte

Fix weight_map resolution. Previously we always defaulted to the allow pattern *.safetensors.

Alex Cheema 8 meses atrás
pai
commit
8f65e1e697
3 arquivos alterados com 34 adições e 22 exclusões
  1. 0 0
      exo/download/__init__.py
  2. 0 0
      exo/download/hf/__init__.py
  3. 34 22
      exo/download/hf/hf_helpers.py

+ 0 - 0
exo/download/__init__.py


+ 0 - 0
exo/download/hf/__init__.py


+ 34 - 22
exo/download/hf/hf_helpers.py

@@ -200,6 +200,36 @@ async def download_file(
     if DEBUG >= 2: print(f"Downloaded: {file_path}")
 
 
+async def resolve_revision_to_commit_hash(repo_id: str, revision: str) -> str:  # revision name -> commit sha, cached on disk
+  repo_root = get_repo_root(repo_id)
+  refs_dir = repo_root/"refs"
+  refs_file = refs_dir/revision
+  # Fast path: reuse the hash stored in refs/<revision>. It is returned as-is and never
+  # refreshed here, so a moving revision such as "main" keeps its first resolved hash.
+  if await aios.path.exists(refs_file):
+    async with aiofiles.open(refs_file, 'r') as f:
+      commit_hash = (await f.read()).strip()
+      if DEBUG >= 2: print(f"Commit hash is already cached at {refs_file}: {commit_hash}")
+      return commit_hash
+
+  # Cache miss: ask the Hugging Face API which commit this revision points to
+  async with aiohttp.ClientSession() as session:
+    api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
+    headers = await get_auth_headers()
+    async with session.get(api_url, headers=headers) as response:
+      if response.status != 200:
+        raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
+      revision_info = await response.json()
+      commit_hash = revision_info['sha']
+
+  # Persist the hash under refs/<revision> so later calls take the cached path above
+  await aios.makedirs(refs_dir, exist_ok=True)
+  async with aiofiles.open(refs_file, 'w') as f:
+    await f.write(commit_hash)
+
+  return commit_hash
+
+
 async def download_repo_files(
   repo_id: str,
   revision: str = "main",
@@ -209,34 +239,15 @@ async def download_repo_files(
   max_parallel_downloads: int = 4
 ) -> Path:
   repo_root = get_repo_root(repo_id)
-  refs_dir = repo_root/"refs"
   snapshots_dir = repo_root/"snapshots"
   cachedreqs_dir = repo_root/"cachedreqs"
 
   # Ensure directories exist
-  await aios.makedirs(refs_dir, exist_ok=True)
   await aios.makedirs(snapshots_dir, exist_ok=True)
   await aios.makedirs(cachedreqs_dir, exist_ok=True)
 
-  # Check if we have a cached commit hash
-  refs_file = refs_dir/revision
-  if await aios.path.exists(refs_file):
-    async with aiofiles.open(refs_file, 'r') as f:
-      commit_hash = (await f.read()).strip()
-      if DEBUG >= 2: print(f"Commit hash is already hashed at {refs_file}: {commit_hash}")
-  else:
-    async with aiohttp.ClientSession() as session:
-      # Fetch the commit hash for the given revision
-      api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
-      headers = await get_auth_headers()
-      async with session.get(api_url, headers=headers) as response:
-        if response.status != 200:
-          raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
-        revision_info = await response.json()
-        commit_hash = revision_info['sha']
-        # Cache the commit hash
-        async with aiofiles.open(refs_file, 'w') as f:
-          await f.write(commit_hash)
+  # Resolve revision to commit hash
+  commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
 
   # Set up the snapshot directory
   snapshot_dir = snapshots_dir/commit_hash
@@ -356,7 +367,8 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[
 
   # Check if the file exists
   repo_root = get_repo_root(repo_id)
-  snapshot_dir = repo_root/"snapshots"
+  commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
+  snapshot_dir = repo_root/"snapshots"/commit_hash
   index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
 
   if index_file: