浏览代码

Merge pull request #470 from josh1593/package-exo-app

package exo as installable
Alex Cheema 8 月之前
父节点
当前提交
0501efa6cd
共有 10 个文件被更改,包括 135 次插入31 次删除
  1. 1 1
      .gitignore
  2. 二进制
      docs/exo-rounded.png
  3. 1 1
      exo/__init__.py
  4. 16 6
      exo/api/chatgpt_api.py
  5. 19 3
      exo/download/hf/hf_helpers.py
  6. 20 0
      exo/helpers.py
  7. 2 1
      exo/inference/mlx/sharded_utils.py
  8. 15 19
      exo/main.py
  9. 60 0
      scripts/build_exo.py
  10. 1 0
      setup.py

+ 1 - 1
.gitignore

@@ -4,6 +4,7 @@ test_weights.npz
 .exo_used_ports
 .exo_node_id
 .idea
+.DS_Store
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
@@ -15,7 +16,6 @@ __pycache__/
 
 # Distribution / packaging
 /.Python
-/build/
 /develop-eggs/
 /dist/
 /downloads/

二进制
docs/exo-rounded.png


+ 1 - 1
exo/__init__.py

@@ -1 +1 @@
-from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION
+from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION

+ 16 - 6
exo/api/chatgpt_api.py

@@ -8,15 +8,16 @@ from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
 import traceback
+import os
+import sys
 from exo import DEBUG, VERSION
 from exo.download.download_progress import RepoProgressEvent
-from exo.helpers import PrefixDict
+from exo.helpers import PrefixDict, shutdown
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
 from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
 from typing import Callable, Optional
 
-
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
     self.role = role
@@ -26,6 +27,7 @@ class Message:
     return {"role": self.role, "content": self.content}
 
 
+
 class ChatCompletionRequest:
   def __init__(self, model: str, messages: List[Message], temperature: float):
     self.model = model
@@ -143,7 +145,6 @@ class PromptSession:
     self.timestamp = timestamp
     self.prompt = prompt
 
-
 class ChatGPTAPI:
   def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
     self.node = node
@@ -172,13 +173,22 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
     cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
     cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
+    cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
 
-    self.static_dir = Path(__file__).parent.parent/"tinychat"
-    self.app.router.add_get("/", self.handle_root)
-    self.app.router.add_static("/", self.static_dir, name="static")
+    if "__compiled__" not in globals():
+      self.static_dir = Path(__file__).parent.parent/"tinychat"
+      self.app.router.add_get("/", self.handle_root)
+      self.app.router.add_static("/", self.static_dir, name="static")
 
     self.app.middlewares.append(self.timeout_middleware)
     self.app.middlewares.append(self.log_request)
+  
+  async def handle_quit(self, request):
+    if DEBUG>=1: print("Received quit signal")
+    response = web.json_response({"detail": "Quit signal received"}, status=200)
+    await response.prepare(request)
+    await response.write_eof()
+    await shutdown(signal.SIGINT, asyncio.get_event_loop())
 
   async def timeout_middleware(self, app, handler):
     async def middleware(request):

+ 19 - 3
exo/download/hf/hf_helpers.py

@@ -1,7 +1,11 @@
+import aiofiles.os as aios
+from typing import Union
 import asyncio
 import aiohttp
 import json
 import os
+import sys
+import shutil
 from urllib.parse import urljoin
 from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
 from datetime import datetime, timedelta
@@ -9,7 +13,7 @@ from fnmatch import fnmatch
 from pathlib import Path
 from typing import Generator, Iterable, TypeVar, TypedDict
 from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
-from exo.helpers import DEBUG
+from exo.helpers import DEBUG, is_frozen
 from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
 from exo.inference.shard import Shard
 import aiofiles
@@ -17,7 +21,6 @@ from aiofiles import os as aios
 
 T = TypeVar("T")
 
-
 async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
   refs_dir = get_repo_root(repo_id)/"refs"
   refs_file = refs_dir/revision
@@ -99,9 +102,22 @@ async def get_auth_headers():
 
 def get_repo_root(repo_id: str) -> Path:
   """Get the root directory for a given repo ID in the Hugging Face cache."""
-  sanitized_repo_id = repo_id.replace("/", "--")
+  sanitized_repo_id = str(repo_id).replace("/", "--")
   return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
 
+async def move_models_to_hf(seed_dir: Union[str, Path]):
+  """Move model in resources folder of app to .cache/huggingface/hub"""
+  source_dir = Path(seed_dir)
+  dest_dir = get_hf_home()/"hub"
+  await aios.makedirs(dest_dir, exist_ok=True)
+  async for path in source_dir.iterdir():
+    if path.is_dir() and path.startswith("models--"):
+      dest_path = dest_dir / path.name
+      if dest_path.exists():
+        if DEBUG>=1: print(f"skipping moving {dest_path}. File already exists")
+      else:
+        await aios.rename(str(path), str(dest_path))
+        
 
 async def fetch_file_list(session, repo_id, revision, path=""):
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"

+ 20 - 0
exo/helpers.py

@@ -1,4 +1,5 @@
 import os
+import sys
 import asyncio
 from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List
 import socket
@@ -234,3 +235,22 @@ def get_all_ip_addresses():
   except:
     if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
     return ["localhost"]
+
+
+async def shutdown(signal, loop):
+  """Gracefully shutdown the server and close the asyncio loop."""
+  print(f"Received exit signal {signal.name}...")
+  print("Thank you for using exo.")
+  print_yellow_exo()
+  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
+  [task.cancel() for task in server_tasks]
+  print(f"Cancelling {len(server_tasks)} outstanding tasks")
+  await asyncio.gather(*server_tasks, return_exceptions=True)
+  await server.stop()
+  loop.stop()
+
+
+def is_frozen():
+  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
+    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
+    or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)

+ 2 - 1
exo/inference/mlx/sharded_utils.py

@@ -21,6 +21,7 @@ from transformers import AutoProcessor
 from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
 
 from exo import DEBUG
+from exo.inference.tokenizers import resolve_tokenizer
 from ..shard import Shard
 
 
@@ -183,7 +184,7 @@ async def load_shard(
     processor.encode = processor.tokenizer.encode
     return model, processor
   else:
-    tokenizer = load_tokenizer(model_path, tokenizer_config)
+    tokenizer = await resolve_tokenizer(model_path)
     return model, tokenizer
 
 

+ 15 - 19
exo/main.py

@@ -3,6 +3,9 @@ import asyncio
 import signal
 import json
 import logging
+import platform
+import os
+import sys
 import time
 import traceback
 import uuid
@@ -17,14 +20,14 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration.node import Node
 from exo.models import build_base_shard, get_repo
 from exo.viz.topology_viz import TopologyViz
-from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home
+from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -34,6 +37,7 @@ parser.add_argument("--default-model", type=str, default=None, help="Default mod
 parser.add_argument("--node-id", type=str, default=None, help="Node ID")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
 parser.add_argument("--node-port", type=int, default=None, help="Node port")
+parser.add_argument("--models-seed-dir", type=str, default=None, help="Model seed directory")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
 parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
@@ -129,7 +133,6 @@ node.on_token.register("update_topology_viz").on_next(
   lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
 )
 
-
 def preemptively_start_download(request_id: str, opaque_status: str):
   try:
     status = json.loads(opaque_status)
@@ -162,20 +165,6 @@ def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
 
 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 
-
-async def shutdown(signal, loop):
-  """Gracefully shutdown the server and close the asyncio loop."""
-  print(f"Received exit signal {signal.name}...")
-  print("Thank you for using exo.")
-  print_yellow_exo()
-  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
-  [task.cancel() for task in server_tasks]
-  print(f"Cancelling {len(server_tasks)} outstanding tasks")
-  await asyncio.gather(*server_tasks, return_exceptions=True)
-  await server.stop()
-  loop.stop()
-
-
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
   inference_class = inference_engine.__class__.__name__
   shard = build_base_shard(model_name, inference_class)
@@ -219,13 +208,20 @@ async def main():
           {"❌ No read access" if not has_read else ""}
           {"❌ No write access" if not has_write else ""}
           """)
+    
+  if not args.models_seed_dir is None:
+    try:
+      await move_models_to_hf(args.models_seed_dir)
+    except Exception as e:
+      print(f"Error moving models to .cache/huggingface: {e}")
 
   # Use a more direct approach to handle signals
   def handle_exit():
     asyncio.ensure_future(shutdown(signal.SIGTERM, loop))
 
-  for s in [signal.SIGINT, signal.SIGTERM]:
-    loop.add_signal_handler(s, handle_exit)
+  if platform.system() != "Windows":
+    for s in [signal.SIGINT, signal.SIGTERM]:
+      loop.add_signal_handler(s, handle_exit)
 
   await node.start(wait_for_peers=args.wait_for_peers)
 

+ 60 - 0
scripts/build_exo.py

@@ -0,0 +1,60 @@
+import site
+import subprocess
+import sys
+import os 
+import pkgutil
+
+def run():
+    site_packages = site.getsitepackages()[0]
+    command = [
+        f"{sys.executable}", "-m", "nuitka", "exo/main.py",
+        "--company-name=exolabs",
+        "--product-name=exo",
+        "--output-dir=dist",
+        "--follow-imports",
+        "--standalone",
+        "--output-filename=exo",
+        "--onefile",
+        "--python-flag=no_site"
+    ]
+
+    if sys.platform == "darwin": 
+        command.extend([
+            "--macos-app-name=exo",
+            "--macos-app-mode=gui",
+            "--macos-app-version=0.0.1",
+            "-macos-signed-app-name=com.exolabs.exo",
+            "--macos-sign-identity=auto",
+            "--macos-sign-notarization",
+            "--include-distribution-meta=mlx",
+            "--include-module=mlx._reprlib_fix",
+            "--include-module=mlx._os_warning",
+            f"--include-data-files={site_packages}/mlx/lib/mlx.metallib=mlx/lib/mlx.metallib",
+            f"--include-data-files={site_packages}/mlx/lib/mlx.metallib=./mlx.metallib",
+            "--include-distribution-meta=pygments",
+            "--nofollow-import-to=tinygrad"
+        ])
+        inference_modules = [
+            name for _, name, _ in pkgutil.iter_modules(['exo/inference/mlx/models'])
+        ]
+        for module in inference_modules:
+            command.append(f"--include-module=exo.inference.mlx.models.{module}")
+    elif sys.platform == "win32":  
+        command.extend([
+            "--windows-icon-from-ico=docs/exo-logo-win.ico",
+            "--file-version=0.0.1",
+            "--product-version=0.0.1"
+        ])
+    elif sys.platform.startswith("linux"):  
+        command.extend([
+            "--include-distribution-metadata=pygments",
+            "--linux-icon=docs/exo-rounded.png"
+        ])
+    try:
+        subprocess.run(command, check=True)
+        print("Build completed!")
+    except subprocess.CalledProcessError as e:
+        print(f"An error occurred: {e}")
+
+if __name__ == "__main__":
+    run()

+ 1 - 0
setup.py

@@ -13,6 +13,7 @@ install_requires = [
   "Jinja2==3.1.4",
   "netifaces==0.11.0",
   "numpy==2.0.0",
+  "nuitka==2.4.10",
   "nvidia-ml-py==12.560.30",
   "pillow==10.4.0",
   "prometheus-client==0.20.0",