# download_hf.py
# Command-line script for downloading files from a Hugging Face model repository.
import argparse
import asyncio

from exo.download.hf.hf_helpers import download_all_files, RepoProgressEvent
  4. DEFAULT_ALLOW_PATTERNS = [
  5. "*.json",
  6. "*.py",
  7. "tokenizer.model",
  8. "*.tiktoken",
  9. "*.txt",
  10. "*.safetensors",
  11. ]
  12. # Always ignore `.git` and `.cache/huggingface` folders in commits
  13. DEFAULT_IGNORE_PATTERNS = [
  14. ".git",
  15. ".git/*",
  16. "*/.git",
  17. "**/.git/**",
  18. ".cache/huggingface",
  19. ".cache/huggingface/*",
  20. "*/.cache/huggingface",
  21. "**/.cache/huggingface/**",
  22. ]
  23. async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
  24. async def progress_callback(event: RepoProgressEvent):
  25. print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
  26. print(f"Estimated time remaining: {event.overall_eta}")
  27. print("File Progress:")
  28. for file_path, progress in event.file_progress.items():
  29. status_icon = {'not_started': '⚪', 'in_progress': '🔵', 'complete': '✅'}[progress.status]
  30. eta_str = str(progress.eta)
  31. print(f"{status_icon} {file_path}: {progress.downloaded}/{progress.total} bytes, "
  32. f"Speed: {progress.speed:.2f} B/s, ETA: {eta_str}")
  33. print("\n")
  34. await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)
  35. if __name__ == "__main__":
  36. parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
  37. parser.add_argument("--repo-id", required=True, help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
  38. parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
  39. parser.add_argument("--allow-patterns", nargs="*", default=None, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
  40. parser.add_argument("--ignore-patterns", nargs="*", default=None, help="Patterns of files to ignore (e.g., '.*')")
  41. args = parser.parse_args()
  42. asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))