1
0

test_hf.py 846 B

Lines 1–26
  1. import os
  2. import sys
  3. # Add the project root to the Python path
  4. project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  5. sys.path.insert(0, project_root)
  6. import asyncio
  7. from exo.download.hf.hf_helpers import get_weight_map
  8. async def test_get_weight_map():
  9. repo_ids = [
  10. "mlx-community/quantized-gemma-2b",
  11. "mlx-community/Meta-Llama-3.1-8B-4bit",
  12. "mlx-community/Meta-Llama-3.1-70B-4bit",
  13. "mlx-community/Meta-Llama-3.1-405B-4bit",
  14. ]
  15. for repo_id in repo_ids:
  16. weight_map = await get_weight_map(repo_id)
  17. assert weight_map is not None, "Weight map should not be None"
  18. assert isinstance(weight_map, dict), "Weight map should be a dictionary"
  19. assert len(weight_map) > 0, "Weight map should not be empty"
  20. print(f"OK: {repo_id}")
  21. if __name__ == "__main__":
  22. asyncio.run(test_get_weight_map())