瀏覽代碼

clean branch

josh 9 月之前
父節點
當前提交
fea1c0fc29
共有 9 個文件被更改,包括 102 次插入和 12 次删除
  1. 1 1
      .gitignore
  2. 二進制
      docs/exo-rounded.png
  3. 1 1
      exo/__init__.py
  4. 23 4
      exo/api/chatgpt_api.py
  5. 9 1
      exo/download/hf/hf_helpers.py
  6. 2 1
      exo/inference/mlx/sharded_utils.py
  7. 7 2
      exo/main.py
  8. 56 0
      scripts/build_exo.py
  9. 3 2
      setup.py

+ 1 - 1
.gitignore

@@ -4,6 +4,7 @@ test_weights.npz
 .exo_used_ports
 .exo_used_ports
 .exo_node_id
 .exo_node_id
 .idea
 .idea
+.DS_Store
 
 
 # Byte-compiled / optimized / DLL files
 # Byte-compiled / optimized / DLL files
 __pycache__/
 __pycache__/
@@ -15,7 +16,6 @@ __pycache__/
 
 
 # Distribution / packaging
 # Distribution / packaging
 /.Python
 /.Python
-/build/
 /develop-eggs/
 /develop-eggs/
 /dist/
 /dist/
 /downloads/
 /downloads/

二進制
docs/exo-rounded.png


+ 1 - 1
exo/__init__.py

@@ -1 +1 @@
-from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION
+from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION

+ 23 - 4
exo/api/chatgpt_api.py

@@ -8,6 +8,9 @@ from typing import List, Literal, Union, Dict
 from aiohttp import web
 from aiohttp import web
 import aiohttp_cors
 import aiohttp_cors
 import traceback
 import traceback
+import os
+import signal
+import sys
 from exo import DEBUG, VERSION
 from exo import DEBUG, VERSION
 from exo.download.download_progress import RepoProgressEvent
 from exo.download.download_progress import RepoProgressEvent
 from exo.helpers import PrefixDict
 from exo.helpers import PrefixDict
@@ -18,7 +21,6 @@ from exo.orchestration import Node
 from exo.models import build_base_shard, model_cards, get_repo, pretty_name
 from exo.models import build_base_shard, model_cards, get_repo, pretty_name
 from typing import Callable
 from typing import Callable
 
 
-
 class Message:
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
     self.role = role
     self.role = role
@@ -28,6 +30,7 @@ class Message:
     return {"role": self.role, "content": self.content}
     return {"role": self.role, "content": self.content}
 
 
 
 
+
 class ChatCompletionRequest:
 class ChatCompletionRequest:
   def __init__(self, model: str, messages: List[Message], temperature: float):
   def __init__(self, model: str, messages: List[Message], temperature: float):
     self.model = model
     self.model = model
@@ -145,6 +148,10 @@ class PromptSession:
     self.timestamp = timestamp
     self.timestamp = timestamp
     self.prompt = prompt
     self.prompt = prompt
 
 
+def is_frozen():
+  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
+    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
+    or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
 
 
 class ChatGPTAPI:
 class ChatGPTAPI:
   def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
   def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
@@ -174,13 +181,25 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
     cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
     cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
     cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
     cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
     cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
+    cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
 
 
-    self.static_dir = Path(__file__).parent.parent/"tinychat"
-    self.app.router.add_get("/", self.handle_root)
-    self.app.router.add_static("/", self.static_dir, name="static")
+    if "__compiled__" not in globals():
+      self.static_dir = Path(__file__).parent.parent/"tinychat"
+      self.app.router.add_get("/", self.handle_root)
+      self.app.router.add_static("/", self.static_dir, name="static")
 
 
     self.app.middlewares.append(self.timeout_middleware)
     self.app.middlewares.append(self.timeout_middleware)
     self.app.middlewares.append(self.log_request)
     self.app.middlewares.append(self.log_request)
+  
+  async def handle_quit(self, request):
+    print("Received quit signal")
+    from exo.main import shutdown
+    response = web.json_response({"detail": "Quit signal received"}, status=200)
+    await response.prepare(request)
+    await response.write_eof()
+
+    await shutdown(signal.SIGINT, asyncio.get_event_loop())
+    return response
 
 
   async def timeout_middleware(self, app, handler):
   async def timeout_middleware(self, app, handler):
     async def middleware(request):
     async def middleware(request):

+ 9 - 1
exo/download/hf/hf_helpers.py

@@ -2,6 +2,7 @@ import asyncio
 import aiohttp
 import aiohttp
 import json
 import json
 import os
 import os
+import sys
 from urllib.parse import urljoin
 from urllib.parse import urljoin
 from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
 from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
@@ -17,6 +18,10 @@ from aiofiles import os as aios
 
 
 T = TypeVar("T")
 T = TypeVar("T")
 
 
+def is_frozen():
+  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
+    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
+    or ('__compiled__' in globals())
 
 
 async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
 async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
   refs_dir = get_repo_root(repo_id)/"refs"
   refs_dir = get_repo_root(repo_id)/"refs"
@@ -99,7 +104,10 @@ async def get_auth_headers():
 
 
 def get_repo_root(repo_id: str) -> Path:
 def get_repo_root(repo_id: str) -> Path:
   """Get the root directory for a given repo ID in the Hugging Face cache."""
   """Get the root directory for a given repo ID in the Hugging Face cache."""
-  sanitized_repo_id = repo_id.replace("/", "--")
+  sanitized_repo_id = str(repo_id).replace("/", "--")
+  if "Qwen2.5-0.5B-Instruct-4bit" in str(repo_id) and is_frozen():
+    repo_root = Path(sys.argv[0]).parent/f"models--{sanitized_repo_id}"
+    return repo_root
   return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
   return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
 
 
 
 

+ 2 - 1
exo/inference/mlx/sharded_utils.py

@@ -21,6 +21,7 @@ from transformers import AutoProcessor
 from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
 from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
 
 
 from exo import DEBUG
 from exo import DEBUG
+from exo.inference.tokenizers import resolve_tokenizer
 from ..shard import Shard
 from ..shard import Shard
 
 
 
 
@@ -183,7 +184,7 @@ async def load_shard(
     processor.encode = processor.tokenizer.encode
     processor.encode = processor.tokenizer.encode
     return model, processor
     return model, processor
   else:
   else:
-    tokenizer = load_tokenizer(model_path, tokenizer_config)
+    tokenizer = await resolve_tokenizer(model_path)
     return model, tokenizer
     return model, tokenizer
 
 
 
 

+ 7 - 2
exo/main.py

@@ -3,7 +3,11 @@ import asyncio
 import signal
 import signal
 import json
 import json
 import logging
 import logging
+import platform
+import os
+import sys
 import time
 import time
+import subprocess
 import traceback
 import traceback
 import uuid
 import uuid
 from exo.networking.manual.manual_discovery import ManualDiscovery
 from exo.networking.manual.manual_discovery import ManualDiscovery
@@ -210,8 +214,9 @@ async def main():
   def handle_exit():
   def handle_exit():
     asyncio.ensure_future(shutdown(signal.SIGTERM, loop))
     asyncio.ensure_future(shutdown(signal.SIGTERM, loop))
 
 
-  for s in [signal.SIGINT, signal.SIGTERM]:
-    loop.add_signal_handler(s, handle_exit)
+  if platform.system() != "Windows":
+    for s in [signal.SIGINT, signal.SIGTERM]:
+      loop.add_signal_handler(s, handle_exit)
 
 
   await node.start(wait_for_peers=args.wait_for_peers)
   await node.start(wait_for_peers=args.wait_for_peers)
 
 

+ 56 - 0
scripts/build_exo.py

@@ -0,0 +1,56 @@
+import site
+import subprocess
+import sys
+import os 
+def run():
+    site_packages = site.getsitepackages()[0]
+    command = [
+        f"{sys.executable}", "-m", "nuitka", "exo/main.py",
+        "--company-name=exolabs",
+        "--product-name=exo",
+        "--output-dir=dist",
+        "--follow-imports",
+        "--standalone",
+        "--output-filename=exo",
+        "--onefile",
+        "--python-flag=no_site"
+    ]
+
+    if sys.platform == "darwin": 
+        command.extend([
+            "--macos-app-name=exo",
+            "--macos-app-mode=gui",
+            "--macos-app-version=0.0.1",
+            "--include-module=exo.inference.mlx.models.llama",
+            "--include-module=exo.inference.mlx.models.deepseek_v2",
+            "--include-module=exo.inference.mlx.models.base",
+            "--include-module=exo.inference.mlx.models.llava",
+            "--include-module=exo.inference.mlx.models.qwen2",
+            "--include-distribution-meta=mlx",
+            "--include-module=mlx._reprlib_fix",
+            "--include-module=mlx._os_warning",
+            f"--include-data-files={site_packages}/mlx/lib/mlx.metallib=mlx/lib/mlx.metallib",
+            f"--include-data-files={site_packages}/mlx/lib/mlx.metallib=./mlx.metallib",
+            "--include-distribution-meta=pygments",
+            "--nofollow-import-to=tinygrad"
+        ])
+    elif sys.platform == "win32":  
+        command.extend([
+            "--windows-icon-from-ico=docs/exo-logo-win.ico",
+            "--file-version=0.0.1",
+            "--product-version=0.0.1"
+        ])
+    elif sys.platform.startswith("linux"):  
+        command.extend([
+            "--include-distribution-metadata=pygments",
+            "--linux-icon=docs/exo-rounded.png"
+        ])
+
+    try:
+        subprocess.run(command, check=True)
+        print("Build completed!")
+    except subprocess.CalledProcessError as e:
+        print(f"An error occurred: {e}")
+
+if __name__ == "__main__":
+    run()

+ 3 - 2
setup.py

@@ -13,6 +13,7 @@ install_requires = [
   "Jinja2==3.1.4",
   "Jinja2==3.1.4",
   "netifaces==0.11.0",
   "netifaces==0.11.0",
   "numpy==2.0.0",
   "numpy==2.0.0",
+  "nuitka==2.4.10",
   "nvidia-ml-py==12.560.30",
   "nvidia-ml-py==12.560.30",
   "pillow==10.4.0",
   "pillow==10.4.0",
   "prometheus-client==0.20.0",
   "prometheus-client==0.20.0",
@@ -34,8 +35,8 @@ extras_require = {
     "yapf==0.40.2",
     "yapf==0.40.2",
   ],
   ],
   "apple_silicon": [
   "apple_silicon": [
-    "mlx==0.20.0",
-    "mlx-lm==0.19.3",
+    "mlx==0.18.0",
+    "mlx-lm==0.18.2",
   ],
   ],
 }
 }