Browse Source

moving models

josh 8 months ago
parent
commit
867f348e71
2 changed files with 20 additions and 4 deletions
  1. 12 2
      exo/download/hf/hf_helpers.py
  2. 8 2
      scripts/build_exo.py

+ 12 - 2
exo/download/hf/hf_helpers.py

@@ -3,6 +3,7 @@ import aiohttp
 import json
 import os
 import sys
+import shutil
 from urllib.parse import urljoin
 from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
 from datetime import datetime, timedelta
@@ -101,10 +102,19 @@ def get_repo_root(repo_id: str) -> Path:
   """Get the root directory for a given repo ID in the Hugging Face cache."""
   sanitized_repo_id = str(repo_id).replace("/", "--")
   if is_frozen():
-    repo_root = Path(sys.argv[0]).parent/f"models--{sanitized_repo_id}"
-    return repo_root
+    exec_root = Path(sys.argv[0]).parent
+    asyncio.run(move_models_to_hf)
   return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
 
+async def move_models_to_hf():
+  """Move model in resources folder of app to .cache/huggingface/hub"""
+  source_dir = Path(sys.argv[0]).parent
+  dest_dir = get_hf_home()/"hub"
+  await aios.makedirs(dest_dir, exist_ok=True)
+  for path in source_dir.iterdir():
+    if path.is_dir() and path.startswith("models--"):
+      dest_path = dest_dir / path.name
+      shutil.move(str(path), str(dest_path))
 
 async def fetch_file_list(session, repo_id, revision, path=""):
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"

+ 8 - 2
scripts/build_exo.py

@@ -2,6 +2,8 @@ import site
 import subprocess
 import sys
 import os 
+import pkgutil
+
 def run():
     site_packages = site.getsitepackages()[0]
     command = [
@@ -34,6 +36,11 @@ def run():
             "--include-distribution-meta=pygments",
             "--nofollow-import-to=tinygrad"
         ])
+        inference_modules = [
+            name for _, name, _ in pkgutil.iter_modules(['exo/inference/mlx/models'])
+        ]
+        for module in inference_modules:
+            command.append(f"--include-module=exo.inference.mlx.models.{module}")
     elif sys.platform == "win32":  
         command.extend([
             "--windows-icon-from-ico=docs/exo-logo-win.ico",
@@ -45,9 +52,8 @@ def run():
             "--include-distribution-metadata=pygments",
             "--linux-icon=docs/exo-rounded.png"
         ])
-
     try:
-        subprocess.run(command, check=True)
+        # subprocess.run(command, check=True)
         print("Build completed!")
     except subprocess.CalledProcessError as e:
         print(f"An error occurred: {e}")