浏览代码

Merge pull request #217 from jshield/feat/support_hf_endpoint

feat: support HF_ENDPOINT base url ENV VAR
Alex Cheema 10 月之前
父节点
当前提交
57745e4f02
共有 1 个文件被更改,包括 6 次插入3 次删除
  1. 6 3
      exo/download/hf/hf_helpers.py

+ 6 - 3
exo/download/hf/hf_helpers.py

@@ -17,6 +17,7 @@ from aiofiles import os as aios
 
 T = TypeVar("T")
 
+
 async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
   refs_dir = get_repo_root(repo_id)/"refs"
   refs_file = refs_dir/revision
@@ -69,6 +70,8 @@ def _add_wildcard_to_directories(pattern: str) -> str:
     return pattern + "*"
   return pattern
 
+def get_hf_endpoint() -> str:
+    return os.environ.get('HF_ENDPOINT', "https://huggingface.co")
 
 def get_hf_home() -> Path:
   """Get the Hugging Face home directory."""
@@ -99,7 +102,7 @@ def get_repo_root(repo_id: str) -> Path:
 
 
 async def fetch_file_list(session, repo_id, revision, path=""):
-  api_url = f"https://huggingface.co/api/models/{repo_id}/tree/{revision}"
+  api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
   url = f"{api_url}/{path}" if path else api_url
 
   headers = await get_auth_headers()
@@ -124,7 +127,7 @@ async def fetch_file_list(session, repo_id, revision, path=""):
 async def download_file(
   session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True
 ):
-  base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
+  base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
   url = urljoin(base_url, file_path)
   local_path = os.path.join(save_directory, file_path)
 
@@ -214,7 +217,7 @@ async def resolve_revision_to_commit_hash(repo_id: str, revision: str) -> str:
 
   # Fetch the commit hash for the given revision
   async with aiohttp.ClientSession() as session:
-    api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
+    api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/revision/{revision}"
     headers = await get_auth_headers()
     async with session.get(api_url, headers=headers) as response:
       if response.status != 200: