Browse Source

add a test for hf get_weight_map

Alex Cheema 9 months ago
parent
commit
84187113de
1 changed files with 26 additions and 0 deletions
  1. 26 0
      test/test_hf.py

+ 26 - 0
test/test_hf.py

@@ -0,0 +1,26 @@
+import os
+import sys
+
+# Add the project root to the Python path
+project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.insert(0, project_root)
+
+import asyncio
+from exo.download.hf.hf_helpers import get_weight_map
+
+async def test_get_weight_map():
+  repo_ids = [
+    "mlx-community/quantized-gemma-2b",
+    "mlx-community/Meta-Llama-3.1-8B-4bit",
+    "mlx-community/Meta-Llama-3.1-70B-4bit",
+    "mlx-community/Meta-Llama-3.1-405B-4bit",
+  ]
+  for repo_id in repo_ids:
+    weight_map = await get_weight_map(repo_id)
+    assert weight_map is not None, "Weight map should not be None"
+    assert isinstance(weight_map, dict), "Weight map should be a dictionary"
+    assert len(weight_map) > 0, "Weight map should not be empty"
+    print(f"OK: {repo_id}")
+
+if __name__ == "__main__":
+  asyncio.run(test_get_weight_map())