|
@@ -17,6 +17,7 @@ from aiofiles import os as aios
|
|
|
|
|
|
T = TypeVar("T")
|
|
|
|
|
|
+
|
|
|
async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
|
|
|
refs_dir = get_repo_root(repo_id)/"refs"
|
|
|
refs_file = refs_dir/revision
|
|
@@ -69,6 +70,8 @@ def _add_wildcard_to_directories(pattern: str) -> str:
|
|
|
return pattern + "*"
|
|
|
return pattern
|
|
|
|
|
|
+def get_hf_endpoint() -> str:
|
|
|
+ return os.environ.get('HF_ENDPOINT', "https://huggingface.co")
|
|
|
|
|
|
def get_hf_home() -> Path:
|
|
|
"""Get the Hugging Face home directory."""
|
|
@@ -99,7 +102,7 @@ def get_repo_root(repo_id: str) -> Path:
|
|
|
|
|
|
|
|
|
async def fetch_file_list(session, repo_id, revision, path=""):
|
|
|
- api_url = f"https://huggingface.co/api/models/{repo_id}/tree/{revision}"
|
|
|
+ api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
|
|
|
url = f"{api_url}/{path}" if path else api_url
|
|
|
|
|
|
headers = await get_auth_headers()
|
|
@@ -124,7 +127,7 @@ async def fetch_file_list(session, repo_id, revision, path=""):
|
|
|
async def download_file(
|
|
|
session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True
|
|
|
):
|
|
|
- base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
|
|
|
+ base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
|
|
|
url = urljoin(base_url, file_path)
|
|
|
local_path = os.path.join(save_directory, file_path)
|
|
|
|
|
@@ -214,7 +217,7 @@ async def resolve_revision_to_commit_hash(repo_id: str, revision: str) -> str:
|
|
|
|
|
|
# Fetch the commit hash for the given revision
|
|
|
async with aiohttp.ClientSession() as session:
|
|
|
- api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
|
|
|
+ api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/revision/{revision}"
|
|
|
headers = await get_auth_headers()
|
|
|
async with session.get(api_url, headers=headers) as response:
|
|
|
if response.status != 200:
|