josh il y a 9 mois
Parent
commit
fea1c0fc29

+ 1 - 1
.gitignore

@@ -4,6 +4,7 @@ test_weights.npz
 .exo_used_ports
 .exo_node_id
 .idea
+.DS_Store
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
@@ -15,7 +16,6 @@ __pycache__/
 
 # Distribution / packaging
 /.Python
-/build/
 /develop-eggs/
 /dist/
 /downloads/

BIN
docs/exo-rounded.png


+ 1 - 1
exo/__init__.py

@@ -1 +1 @@
-from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION
+from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION

+ 23 - 4
exo/api/chatgpt_api.py

@@ -8,6 +8,9 @@ from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
 import traceback
+import os
+import signal
+import sys
 from exo import DEBUG, VERSION
 from exo.download.download_progress import RepoProgressEvent
 from exo.helpers import PrefixDict
@@ -18,7 +21,6 @@ from exo.orchestration import Node
 from exo.models import build_base_shard, model_cards, get_repo, pretty_name
 from typing import Callable
 
-
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
     self.role = role
@@ -28,6 +30,7 @@ class Message:
     return {"role": self.role, "content": self.content}
 
 
+
 class ChatCompletionRequest:
   def __init__(self, model: str, messages: List[Message], temperature: float):
     self.model = model
@@ -145,6 +148,10 @@ class PromptSession:
     self.timestamp = timestamp
     self.prompt = prompt
 
+def is_frozen():
+  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
+    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
+    or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
 
 class ChatGPTAPI:
   def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
@@ -174,13 +181,25 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
     cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
     cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
+    cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
 
-    self.static_dir = Path(__file__).parent.parent/"tinychat"
-    self.app.router.add_get("/", self.handle_root)
-    self.app.router.add_static("/", self.static_dir, name="static")
+    if "__compiled__" not in globals():
+      self.static_dir = Path(__file__).parent.parent/"tinychat"
+      self.app.router.add_get("/", self.handle_root)
+      self.app.router.add_static("/", self.static_dir, name="static")
 
     self.app.middlewares.append(self.timeout_middleware)
     self.app.middlewares.append(self.log_request)
+  
+  async def handle_quit(self, request):
+    print("Received quit signal")
+    from exo.main import shutdown
+    response = web.json_response({"detail": "Quit signal received"}, status=200)
+    await response.prepare(request)
+    await response.write_eof()
+
+    await shutdown(signal.SIGINT, asyncio.get_event_loop())
+    return response
 
   async def timeout_middleware(self, app, handler):
     async def middleware(request):

+ 9 - 1
exo/download/hf/hf_helpers.py

@@ -2,6 +2,7 @@ import asyncio
 import aiohttp
 import json
 import os
+import sys
 from urllib.parse import urljoin
 from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
 from datetime import datetime, timedelta
@@ -17,6 +18,10 @@ from aiofiles import os as aios
 
 T = TypeVar("T")
 
+def is_frozen():
+  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
+    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
+    or ('__compiled__' in globals())
 
 async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
   refs_dir = get_repo_root(repo_id)/"refs"
@@ -99,7 +104,10 @@ async def get_auth_headers():
 
 def get_repo_root(repo_id: str) -> Path:
   """Get the root directory for a given repo ID in the Hugging Face cache."""
-  sanitized_repo_id = repo_id.replace("/", "--")
+  sanitized_repo_id = str(repo_id).replace("/", "--")
+  if "Qwen2.5-0.5B-Instruct-4bit" in str(repo_id) and is_frozen():
+    repo_root = Path(sys.argv[0]).parent/f"models--{sanitized_repo_id}"
+    return repo_root
   return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
 
 

+ 2 - 1
exo/inference/mlx/sharded_utils.py

@@ -21,6 +21,7 @@ from transformers import AutoProcessor
 from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
 
 from exo import DEBUG
+from exo.inference.tokenizers import resolve_tokenizer
 from ..shard import Shard
 
 
@@ -183,7 +184,7 @@ async def load_shard(
     processor.encode = processor.tokenizer.encode
     return model, processor
   else:
-    tokenizer = load_tokenizer(model_path, tokenizer_config)
+    tokenizer = await resolve_tokenizer(model_path)
     return model, tokenizer
 
 

+ 7 - 2
exo/main.py

@@ -3,7 +3,11 @@ import asyncio
 import signal
 import json
 import logging
+import platform
+import os
+import sys
 import time
+import subprocess
 import traceback
 import uuid
 from exo.networking.manual.manual_discovery import ManualDiscovery
@@ -210,8 +214,9 @@ async def main():
   def handle_exit():
     asyncio.ensure_future(shutdown(signal.SIGTERM, loop))
 
-  for s in [signal.SIGINT, signal.SIGTERM]:
-    loop.add_signal_handler(s, handle_exit)
+  if platform.system() != "Windows":
+    for s in [signal.SIGINT, signal.SIGTERM]:
+      loop.add_signal_handler(s, handle_exit)
 
   await node.start(wait_for_peers=args.wait_for_peers)
 

+ 56 - 0
scripts/build_exo.py

@@ -0,0 +1,56 @@
+import site
+import subprocess
+import sys
+import os 
+def run():
+    site_packages = site.getsitepackages()[0]
+    command = [
+        f"{sys.executable}", "-m", "nuitka", "exo/main.py",
+        "--company-name=exolabs",
+        "--product-name=exo",
+        "--output-dir=dist",
+        "--follow-imports",
+        "--standalone",
+        "--output-filename=exo",
+        "--onefile",
+        "--python-flag=no_site"
+    ]
+
+    if sys.platform == "darwin": 
+        command.extend([
+            "--macos-app-name=exo",
+            "--macos-app-mode=gui",
+            "--macos-app-version=0.0.1",
+            "--include-module=exo.inference.mlx.models.llama",
+            "--include-module=exo.inference.mlx.models.deepseek_v2",
+            "--include-module=exo.inference.mlx.models.base",
+            "--include-module=exo.inference.mlx.models.llava",
+            "--include-module=exo.inference.mlx.models.qwen2",
+            "--include-distribution-meta=mlx",
+            "--include-module=mlx._reprlib_fix",
+            "--include-module=mlx._os_warning",
+            f"--include-data-files={site_packages}/mlx/lib/mlx.metallib=mlx/lib/mlx.metallib",
+            f"--include-data-files={site_packages}/mlx/lib/mlx.metallib=./mlx.metallib",
+            "--include-distribution-meta=pygments",
+            "--nofollow-import-to=tinygrad"
+        ])
+    elif sys.platform == "win32":  
+        command.extend([
+            "--windows-icon-from-ico=docs/exo-logo-win.ico",
+            "--file-version=0.0.1",
+            "--product-version=0.0.1"
+        ])
+    elif sys.platform.startswith("linux"):  
+        command.extend([
+            "--include-distribution-metadata=pygments",
+            "--linux-icon=docs/exo-rounded.png"
+        ])
+
+    try:
+        subprocess.run(command, check=True)
+        print("Build completed!")
+    except subprocess.CalledProcessError as e:
+        print(f"An error occurred: {e}")
+
+if __name__ == "__main__":
+    run()

+ 3 - 2
setup.py

@@ -13,6 +13,7 @@ install_requires = [
   "Jinja2==3.1.4",
   "netifaces==0.11.0",
   "numpy==2.0.0",
+  "nuitka==2.4.10",
   "nvidia-ml-py==12.560.30",
   "pillow==10.4.0",
   "prometheus-client==0.20.0",
@@ -34,8 +35,8 @@ extras_require = {
     "yapf==0.40.2",
   ],
   "apple_silicon": [
-    "mlx==0.20.0",
-    "mlx-lm==0.19.3",
+    "mlx==0.18.0",
+    "mlx-lm==0.18.2",
   ],
 }