Browse Source

moving models

josh 8 months ago
parent
commit
867f348e71
2 changed files with 20 additions and 4 deletions
  1. 12 2
      exo/download/hf/hf_helpers.py
  2. 8 2
      scripts/build_exo.py

+ 12 - 2
exo/download/hf/hf_helpers.py

@@ -3,6 +3,7 @@ import aiohttp
 import json
 import json
 import os
 import os
 import sys
 import sys
+import shutil
 from urllib.parse import urljoin
 from urllib.parse import urljoin
 from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
 from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
@@ -101,10 +102,19 @@ def get_repo_root(repo_id: str) -> Path:
   """Get the root directory for a given repo ID in the Hugging Face cache."""
   """Get the root directory for a given repo ID in the Hugging Face cache."""
   sanitized_repo_id = str(repo_id).replace("/", "--")
   sanitized_repo_id = str(repo_id).replace("/", "--")
   if is_frozen():
   if is_frozen():
-    repo_root = Path(sys.argv[0]).parent/f"models--{sanitized_repo_id}"
-    return repo_root
+    exec_root = Path(sys.argv[0]).parent
+    asyncio.run(move_models_to_hf)
   return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
   return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
 
 
+async def move_models_to_hf():
+  """Move model in resources folder of app to .cache/huggingface/hub"""
+  source_dir = Path(sys.argv[0]).parent
+  dest_dir = get_hf_home()/"hub"
+  await aios.makedirs(dest_dir, exist_ok=True)
+  for path in source_dir.iterdir():
+    if path.is_dir() and path.startswith("models--"):
+      dest_path = dest_dir / path.name
+      shutil.move(str(path), str(dest_path))
 
 
 async def fetch_file_list(session, repo_id, revision, path=""):
 async def fetch_file_list(session, repo_id, revision, path=""):
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"

+ 8 - 2
scripts/build_exo.py

@@ -2,6 +2,8 @@ import site
 import subprocess
 import subprocess
 import sys
 import sys
 import os 
 import os 
+import pkgutil
+
 def run():
 def run():
     site_packages = site.getsitepackages()[0]
     site_packages = site.getsitepackages()[0]
     command = [
     command = [
@@ -34,6 +36,11 @@ def run():
             "--include-distribution-meta=pygments",
             "--include-distribution-meta=pygments",
             "--nofollow-import-to=tinygrad"
             "--nofollow-import-to=tinygrad"
         ])
         ])
+        inference_modules = [
+            name for _, name, _ in pkgutil.iter_modules(['exo/inference/mlx/models'])
+        ]
+        for module in inference_modules:
+            command.append(f"--include-module=exo.inference.mlx.models.{module}")
     elif sys.platform == "win32":  
     elif sys.platform == "win32":  
         command.extend([
         command.extend([
             "--windows-icon-from-ico=docs/exo-logo-win.ico",
             "--windows-icon-from-ico=docs/exo-logo-win.ico",
@@ -45,9 +52,8 @@ def run():
             "--include-distribution-metadata=pygments",
             "--include-distribution-metadata=pygments",
             "--linux-icon=docs/exo-rounded.png"
             "--linux-icon=docs/exo-rounded.png"
         ])
         ])
-
     try:
     try:
-        subprocess.run(command, check=True)
+        # subprocess.run(command, check=True)
         print("Build completed!")
         print("Build completed!")
     except subprocess.CalledProcessError as e:
     except subprocess.CalledProcessError as e:
         print(f"An error occurred: {e}")
         print(f"An error occurred: {e}")