@@ -200,6 +200,36 @@ async def download_file(
if DEBUG >= 2: print(f"Downloaded: {file_path}")
+async def resolve_revision_to_commit_hash(repo_id: str, revision: str) -> str:
+ repo_root = get_repo_root(repo_id)
+ refs_dir = repo_root/"refs"
+ refs_file = refs_dir/revision
+
+ # Check if we have a cached commit hash
+ if await aios.path.exists(refs_file):
+ async with aiofiles.open(refs_file, 'r') as f:
+ commit_hash = (await f.read()).strip()
+ if DEBUG >= 2: print(f"Commit hash is already cached at {refs_file}: {commit_hash}")
+ return commit_hash
+
+ # Fetch the commit hash for the given revision
+ async with aiohttp.ClientSession() as session:
+ api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
+ headers = await get_auth_headers()
+ async with session.get(api_url, headers=headers) as response:
+ if response.status != 200:
+ raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
+ revision_info = await response.json()
+ commit_hash = revision_info['sha']
+
+ # Cache the commit hash
+ await aios.makedirs(refs_dir, exist_ok=True)
+ async with aiofiles.open(refs_file, 'w') as f:
+ await f.write(commit_hash)
+
+ return commit_hash
+
+
async def download_repo_files(
repo_id: str,
revision: str = "main",
@@ -209,34 +239,15 @@ async def download_repo_files(
max_parallel_downloads: int = 4
) -> Path:
repo_root = get_repo_root(repo_id)
- refs_dir = repo_root/"refs"
snapshots_dir = repo_root/"snapshots"
cachedreqs_dir = repo_root/"cachedreqs"
# Ensure directories exist
- await aios.makedirs(refs_dir, exist_ok=True)
await aios.makedirs(snapshots_dir, exist_ok=True)
await aios.makedirs(cachedreqs_dir, exist_ok=True)
- # Check if we have a cached commit hash
- refs_file = refs_dir/revision
- if await aios.path.exists(refs_file):
- async with aiofiles.open(refs_file, 'r') as f:
- commit_hash = (await f.read()).strip()
- if DEBUG >= 2: print(f"Commit hash is already hashed at {refs_file}: {commit_hash}")
- else:
- async with aiohttp.ClientSession() as session:
- # Fetch the commit hash for the given revision
- api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
- headers = await get_auth_headers()
- async with session.get(api_url, headers=headers) as response:
- if response.status != 200:
- raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
- revision_info = await response.json()
- commit_hash = revision_info['sha']
- # Cache the commit hash
- async with aiofiles.open(refs_file, 'w') as f:
- await f.write(commit_hash)
+ # Resolve revision to commit hash
+ commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
# Set up the snapshot directory
snapshot_dir = snapshots_dir/commit_hash
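Up to this point the diff extracts the revision lookup into `resolve_revision_to_commit_hash` and has `download_repo_files` call the helper instead of carrying the refs-file and Hugging Face API logic inline. A minimal sketch of the helper's cache behaviour follows; it assumes the patched file is importable as `hf_helpers` (an assumed module name, not stated in the diff) and uses an example repo id.

```python
import asyncio

# Sketch only: exercise the caching behaviour of the new helper.
# `hf_helpers` is an assumed import path for the file being patched;
# get_repo_root and resolve_revision_to_commit_hash are the functions shown in the diff above.
from hf_helpers import get_repo_root, resolve_revision_to_commit_hash

async def demo(repo_id: str, revision: str = "main") -> None:
  refs_file = get_repo_root(repo_id)/"refs"/revision
  first = await resolve_revision_to_commit_hash(repo_id, revision)   # may hit the HF API
  second = await resolve_revision_to_commit_hash(repo_id, revision)  # should be read back from refs_file
  print(f"{revision} -> {first} (cached at {refs_file}, stable across calls: {first == second})")

asyncio.run(demo("mlx-community/Meta-Llama-3-8B-Instruct-4bit"))  # example repo id
```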
@@ -356,7 +367,8 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[
# Check if the file exists
repo_root = get_repo_root(repo_id)
- snapshot_dir = repo_root/"snapshots"
+ commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
+ snapshot_dir = repo_root/"snapshots"/commit_hash
index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
if index_file:
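With all three hunks applied, `download_repo_files` and `get_weight_map` derive the snapshot directory from the same resolved commit hash, rather than `get_weight_map` scanning the bare `snapshots/` root. A rough end-to-end sketch under the same assumptions (`hf_helpers` as the module name, an example repo id) could look like this:

```python
import asyncio

# Sketch only: the module name and repo id are assumptions for illustration;
# the functions themselves are the ones added or modified in the diff.
from hf_helpers import download_repo_files, get_repo_root, get_weight_map, resolve_revision_to_commit_hash

async def main() -> None:
  repo_id = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"  # example repo id
  commit_hash = await resolve_revision_to_commit_hash(repo_id, "main")
  snapshot_dir = await download_repo_files(repo_id, revision="main")  # downloads the repo files; returns a Path
  # Both call sites should now agree on snapshots/<commit_hash>.
  print(f"resolved commit: {commit_hash}")
  print(f"download target: {snapshot_dir}")
  print(f"expected:        {get_repo_root(repo_id)/'snapshots'/commit_hash}")
  weight_map = await get_weight_map(repo_id)  # reads *model.safetensors.index.json from the same snapshot dir
  print(f"{len(weight_map or {})} tensors in weight map")

asyncio.run(main())
```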