Преглед изворни кода

Merge pull request #1 from josh1593/package-exo-app

Package exo app
josh пре 5 месеци
родитељ
комит
eca596396d

+ 1 - 1
.gitignore

@@ -4,6 +4,7 @@ test_weights.npz
 .exo_used_ports
 .exo_node_id
 .idea
+.DS_Store
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
@@ -15,7 +16,6 @@ __pycache__/
 
 # Distribution / packaging
 /.Python
-/build/
 /develop-eggs/
 /dist/
 /downloads/

BIN
docs/exo-rounded.png


+ 1 - 1
exo/__init__.py

@@ -1 +1 @@
-from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION
+from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION

+ 16 - 6
exo/api/chatgpt_api.py

@@ -8,15 +8,16 @@ from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
 import traceback
+import os
+import sys
 from exo import DEBUG, VERSION
 from exo.download.download_progress import RepoProgressEvent
-from exo.helpers import PrefixDict
+from exo.helpers import PrefixDict, shutdown
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
 from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
 from typing import Callable, Optional
 
-
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
     self.role = role
@@ -26,6 +27,7 @@ class Message:
     return {"role": self.role, "content": self.content}
 
 
+
 class ChatCompletionRequest:
   def __init__(self, model: str, messages: List[Message], temperature: float):
     self.model = model
@@ -143,7 +145,6 @@ class PromptSession:
     self.timestamp = timestamp
     self.prompt = prompt
 
-
 class ChatGPTAPI:
   def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
     self.node = node
@@ -172,13 +173,22 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
     cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
     cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
+    cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
 
-    self.static_dir = Path(__file__).parent.parent/"tinychat"
-    self.app.router.add_get("/", self.handle_root)
-    self.app.router.add_static("/", self.static_dir, name="static")
+    if "__compiled__" not in globals():
+      self.static_dir = Path(__file__).parent.parent/"tinychat"
+      self.app.router.add_get("/", self.handle_root)
+      self.app.router.add_static("/", self.static_dir, name="static")
 
     self.app.middlewares.append(self.timeout_middleware)
     self.app.middlewares.append(self.log_request)
+  
+  async def handle_quit(self, request):
+    if DEBUG>=1: print("Received quit signal")
+    response = web.json_response({"detail": "Quit signal received"}, status=200)
+    await response.prepare(request)
+    await response.write_eof()
+    await shutdown(signal.SIGINT, asyncio.get_event_loop())
 
   async def timeout_middleware(self, app, handler):
     async def middleware(request):

+ 20 - 4
exo/download/hf/hf_helpers.py

@@ -1,7 +1,11 @@
+import aiofiles.os as aios
+from typing import Union
 import asyncio
 import aiohttp
 import json
 import os
+import sys
+import shutil
 from urllib.parse import urljoin
 from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
 from datetime import datetime, timedelta
@@ -9,7 +13,7 @@ from fnmatch import fnmatch
 from pathlib import Path
 from typing import Generator, Iterable, TypeVar, TypedDict
 from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
-from exo.helpers import DEBUG
+from exo.helpers import DEBUG, is_frozen
 from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
 from exo.inference.shard import Shard
 import aiofiles
@@ -17,7 +21,6 @@ from aiofiles import os as aios
 
 T = TypeVar("T")
 
-
 async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
   refs_dir = get_repo_root(repo_id)/"refs"
   refs_file = refs_dir/revision
@@ -99,9 +102,22 @@ async def get_auth_headers():
 
 def get_repo_root(repo_id: str) -> Path:
   """Get the root directory for a given repo ID in the Hugging Face cache."""
-  sanitized_repo_id = repo_id.replace("/", "--")
+  sanitized_repo_id = str(repo_id).replace("/", "--")
   return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
 
+async def move_models_to_hf(seed_dir: Union[str, Path]):
+  """Move model in resources folder of app to .cache/huggingface/hub"""
+  source_dir = Path(seed_dir)
+  dest_dir = get_hf_home()/"hub"
+  await aios.makedirs(dest_dir, exist_ok=True)
+  async for path in source_dir.iterdir():
+    if path.is_dir() and path.startswith("models--"):
+      dest_path = dest_dir / path.name
+      if dest_path.exists():
+        if DEBUG>=1: print(f"skipping moving {dest_path}. File already exists")
+      else:
+        await aios.rename(str(path), str(dest_path))
+        
 
 async def fetch_file_list(session, repo_id, revision, path=""):
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
@@ -409,7 +425,7 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
     elif shard.is_last_layer():
       shard_specific_patterns.add(sorted_file_names[-1])
   else:
-    shard_specific_patterns = set("*.safetensors")
+    shard_specific_patterns = set(["*.safetensors"])
   if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
   return list(default_patterns | shard_specific_patterns)
 

+ 20 - 0
exo/helpers.py

@@ -1,4 +1,5 @@
 import os
+import sys
 import asyncio
 from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List
 import socket
@@ -234,3 +235,22 @@ def get_all_ip_addresses():
   except:
     if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
     return ["localhost"]
+
+
+async def shutdown(signal, loop):
+  """Gracefully shutdown the server and close the asyncio loop."""
+  print(f"Received exit signal {signal.name}...")
+  print("Thank you for using exo.")
+  print_yellow_exo()
+  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
+  [task.cancel() for task in server_tasks]
+  print(f"Cancelling {len(server_tasks)} outstanding tasks")
+  await asyncio.gather(*server_tasks, return_exceptions=True)
+  await server.stop()
+  loop.stop()
+
+
+def is_frozen():
+  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
+    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
+    or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)

+ 2 - 1
exo/inference/mlx/sharded_utils.py

@@ -21,6 +21,7 @@ from transformers import AutoProcessor
 from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
 
 from exo import DEBUG
+from exo.inference.tokenizers import resolve_tokenizer
 from ..shard import Shard
 
 
@@ -183,7 +184,7 @@ async def load_shard(
     processor.encode = processor.tokenizer.encode
     return model, processor
   else:
-    tokenizer = load_tokenizer(model_path, tokenizer_config)
+    tokenizer = await resolve_tokenizer(model_path)
     return model, tokenizer
 
 

+ 5 - 9
exo/inference/tinygrad/inference.py

@@ -7,7 +7,6 @@ from exo.inference.tokenizers import resolve_tokenizer
 from tinygrad.nn.state import load_state_dict
 from tinygrad import Tensor, nn, Context
 from exo.inference.inference_engine import InferenceEngine
-from typing import Optional, Tuple
 import numpy as np
 from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
 from exo.download.shard_download import ShardDownloader
@@ -68,24 +67,21 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
   async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
     logits = x[:, -1, :]
     def sample_wrapper():
-      return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize()
-    out = await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
-    return out.numpy().astype(int)
+      return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize().numpy().astype(int)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
 
   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     await self.ensure_shard(shard)
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
-    return np.array(tokens)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, np.array, tokens)
   
   async def decode(self, shard: Shard, tokens) -> str:
     await self.ensure_shard(shard)
-    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
-    return tokens
+    return await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
 
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
-    output_data = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize())
-    return output_data.numpy()
+    return await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize().numpy())
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:

+ 15 - 19
exo/main.py

@@ -3,6 +3,9 @@ import asyncio
 import signal
 import json
 import logging
+import platform
+import os
+import sys
 import time
 import traceback
 import uuid
@@ -17,14 +20,14 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration.node import Node
 from exo.models import build_base_shard, get_repo
 from exo.viz.topology_viz import TopologyViz
-from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home
+from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -34,6 +37,7 @@ parser.add_argument("--default-model", type=str, default=None, help="Default mod
 parser.add_argument("--node-id", type=str, default=None, help="Node ID")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
 parser.add_argument("--node-port", type=int, default=None, help="Node port")
+parser.add_argument("--models-seed-dir", type=str, default=None, help="Model seed directory")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
 parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
@@ -129,7 +133,6 @@ node.on_token.register("update_topology_viz").on_next(
   lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
 )
 
-
 def preemptively_start_download(request_id: str, opaque_status: str):
   try:
     status = json.loads(opaque_status)
@@ -162,20 +165,6 @@ def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
 
 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 
-
-async def shutdown(signal, loop):
-  """Gracefully shutdown the server and close the asyncio loop."""
-  print(f"Received exit signal {signal.name}...")
-  print("Thank you for using exo.")
-  print_yellow_exo()
-  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
-  [task.cancel() for task in server_tasks]
-  print(f"Cancelling {len(server_tasks)} outstanding tasks")
-  await asyncio.gather(*server_tasks, return_exceptions=True)
-  await server.stop()
-  loop.stop()
-
-
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
   inference_class = inference_engine.__class__.__name__
   shard = build_base_shard(model_name, inference_class)
@@ -219,13 +208,20 @@ async def main():
           {"❌ No read access" if not has_read else ""}
           {"❌ No write access" if not has_write else ""}
           """)
+    
+  if not args.models_seed_dir is None:
+    try:
+      await move_models_to_hf(args.models_seed_dir)
+    except Exception as e:
+      print(f"Error moving models to .cache/huggingface: {e}")
 
   # Use a more direct approach to handle signals
   def handle_exit():
     asyncio.ensure_future(shutdown(signal.SIGTERM, loop))
 
-  for s in [signal.SIGINT, signal.SIGTERM]:
-    loop.add_signal_handler(s, handle_exit)
+  if platform.system() != "Windows":
+    for s in [signal.SIGINT, signal.SIGTERM]:
+      loop.add_signal_handler(s, handle_exit)
 
   await node.start(wait_for_peers=args.wait_for_peers)
 

+ 60 - 0
scripts/build_exo.py

@@ -0,0 +1,60 @@
+import site
+import subprocess
+import sys
+import os 
+import pkgutil
+
+def run():
+    site_packages = site.getsitepackages()[0]
+    command = [
+        f"{sys.executable}", "-m", "nuitka", "exo/main.py",
+        "--company-name=exolabs",
+        "--product-name=exo",
+        "--output-dir=dist",
+        "--follow-imports",
+        "--standalone",
+        "--output-filename=exo",
+        "--onefile",
+        "--python-flag=no_site"
+    ]
+
+    if sys.platform == "darwin": 
+        command.extend([
+            "--macos-app-name=exo",
+            "--macos-app-mode=gui",
+            "--macos-app-version=0.0.1",
+            "--macos-signed-app-name=com.exolabs.exo",
+            "--macos-sign-identity=auto",
+            "--macos-sign-notarization",
+            "--include-distribution-meta=mlx",
+            "--include-module=mlx._reprlib_fix",
+            "--include-module=mlx._os_warning",
+            f"--include-data-files={site_packages}/mlx/lib/mlx.metallib=mlx/lib/mlx.metallib",
+            f"--include-data-files={site_packages}/mlx/lib/mlx.metallib=./mlx.metallib",
+            "--include-distribution-meta=pygments",
+            "--nofollow-import-to=tinygrad"
+        ])
+        inference_modules = [
+            name for _, name, _ in pkgutil.iter_modules(['exo/inference/mlx/models'])
+        ]
+        for module in inference_modules:
+            command.append(f"--include-module=exo.inference.mlx.models.{module}")
+    elif sys.platform == "win32":  
+        command.extend([
+            "--windows-icon-from-ico=docs/exo-logo-win.ico",
+            "--file-version=0.0.1",
+            "--product-version=0.0.1"
+        ])
+    elif sys.platform.startswith("linux"):  
+        command.extend([
+            "--include-distribution-metadata=pygments",
+            "--linux-icon=docs/exo-rounded.png"
+        ])
+    try:
+        subprocess.run(command, check=True)
+        print("Build completed!")
+    except subprocess.CalledProcessError as e:
+        print(f"An error occurred: {e}")
+
+if __name__ == "__main__":
+    run()

+ 4 - 4
setup.py

@@ -8,11 +8,12 @@ install_requires = [
   "aiohttp==3.10.11",
   "aiohttp_cors==0.7.0",
   "aiofiles==24.1.0",
-  "grpcio==1.64.1",
-  "grpcio-tools==1.64.1",
+  "grpcio==1.68.0",
+  "grpcio-tools==1.68.0",
   "Jinja2==3.1.4",
   "netifaces==0.11.0",
   "numpy==2.0.0",
+  "nuitka==2.4.10",
   "nvidia-ml-py==12.560.30",
   "pillow==10.4.0",
   "prometheus-client==0.20.0",
@@ -21,10 +22,9 @@ install_requires = [
   "pydantic==2.9.2",
   "requests==2.32.3",
   "rich==13.7.1",
-  "safetensors==0.4.3",
   "tenacity==9.0.0",
   "tqdm==4.66.4",
-  "transformers==4.43.3",
+  "transformers==4.46.3",
   "uuid==1.30",
   "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad",
 ]