
pr suggestions fix

josh, 5 months ago
commit e991438e72
5 changed files with 25 additions and 34 deletions
  1. +1 -7   exo/api/chatgpt_api.py
  2. +2 -7   exo/download/hf/hf_helpers.py
  3. +20 -0  exo/helpers.py
  4. +1 -15  exo/main.py
  5. +1 -5   scripts/build_exo.py

+1 -7  exo/api/chatgpt_api.py

@@ -13,7 +13,7 @@ import signal
 import sys
 from exo import DEBUG, VERSION
 from exo.download.download_progress import RepoProgressEvent
-from exo.helpers import PrefixDict
+from exo.helpers import PrefixDict, shutdown
 from exo.inference.inference_engine import inference_engine_classes
 from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
@@ -148,11 +148,6 @@ class PromptSession:
     self.timestamp = timestamp
     self.prompt = prompt
 
-def is_frozen():
-  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
-    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
-    or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
-
 class ChatGPTAPI:
   def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
     self.node = node
@@ -193,7 +188,6 @@ class ChatGPTAPI:
   
   async def handle_quit(self, request):
     print("Received quit signal")
-    from exo.main import shutdown
     response = web.json_response({"detail": "Quit signal received"}, status=200)
     await response.prepare(request)
     await response.write_eof()
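
A minimal sketch of what the quit handler can now do, with `shutdown` importable at module level instead of deferred inside the function to dodge the `main -> api` import cycle (the handler body here is illustrative, not the commit's code):

```python
# Sketch only: schedule the shared shutdown coroutine from a quit handler.
import asyncio
import signal

from exo.helpers import shutdown  # replaces the deferred `from exo.main import shutdown`

async def handle_quit_sketch():
    loop = asyncio.get_running_loop()
    # Fire-and-forget, so the HTTP response can be flushed before the
    # loop starts cancelling tasks.
    asyncio.create_task(shutdown(signal.SIGINT, loop))
```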

+2 -7  exo/download/hf/hf_helpers.py

@@ -10,7 +10,7 @@ from fnmatch import fnmatch
 from pathlib import Path
 from typing import Generator, Iterable, TypeVar, TypedDict
 from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
-from exo.helpers import DEBUG
+from exo.helpers import DEBUG, is_frozen
 from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
 from exo.inference.shard import Shard
 import aiofiles
@@ -18,11 +18,6 @@ from aiofiles import os as aios
 
 T = TypeVar("T")
 
-def is_frozen():
-  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
-    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
-    or ('__compiled__' in globals())
-
 async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
   refs_dir = get_repo_root(repo_id)/"refs"
   refs_file = refs_dir/revision
@@ -105,7 +100,7 @@ async def get_auth_headers():
 def get_repo_root(repo_id: str) -> Path:
   """Get the root directory for a given repo ID in the Hugging Face cache."""
   sanitized_repo_id = str(repo_id).replace("/", "--")
-  if "Qwen2.5-0.5B-Instruct-4bit" in str(repo_id) and is_frozen():
+  if is_frozen():
     repo_root = Path(sys.argv[0]).parent/f"models--{sanitized_repo_id}"
     return repo_root
   return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
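
Before this hunk, only the bundled `Qwen2.5-0.5B-Instruct-4bit` repo was redirected next to the executable in frozen builds; now every repo is. A rough sketch of the two resolution paths (the `~/.cache/huggingface` fallback is the usual `get_hf_home()` default and is an assumption here):

```python
# Illustrative resolution logic mirroring get_repo_root after the change.
import sys
from pathlib import Path

def repo_root_sketch(repo_id: str, frozen: bool) -> Path:
    sanitized = repo_id.replace("/", "--")
    if frozen:
        # Frozen builds: model snapshots ship beside the binary.
        return Path(sys.argv[0]).parent / f"models--{sanitized}"
    # Normal installs: standard Hugging Face cache layout (assumed default).
    return Path.home() / ".cache" / "huggingface" / "hub" / f"models--{sanitized}"
```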

+20 -0  exo/helpers.py

@@ -1,4 +1,5 @@
 import os
+import sys
 import asyncio
 from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List
 import socket
@@ -234,3 +235,22 @@ def get_all_ip_addresses():
   except:
     if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
     return ["localhost"]
+
+
+async def shutdown(signal, loop):
+  """Gracefully shutdown the server and close the asyncio loop."""
+  print(f"Received exit signal {signal.name}...")
+  print("Thank you for using exo.")
+  print_yellow_exo()
+  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
+  [task.cancel() for task in server_tasks]
+  print(f"Cancelling {len(server_tasks)} outstanding tasks")
+  await asyncio.gather(*server_tasks, return_exceptions=True)
+  await server.stop()
+  loop.stop()
+
+
+def is_frozen():
+  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
+    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
+    or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
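
Note that `shutdown` is moved verbatim, so it still reads a free `server` name that `exo/helpers.py` never defines; it only resolves if something injects `server` into this module's globals before a signal arrives. A hypothetical dependency-injected variant (not part of this commit) that sidesteps the `NameError`:

```python
# Hypothetical sketch, not the commit's code: pass the server explicitly
# instead of relying on a module-global `server`.
import asyncio

async def shutdown_sketch(sig, loop, server):
    print(f"Received exit signal {sig.name}...")
    tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
    for task in tasks:
        task.cancel()
    print(f"Cancelling {len(tasks)} outstanding tasks")
    # Let cancellations propagate before tearing the loop down.
    await asyncio.gather(*tasks, return_exceptions=True)
    await server.stop()
    loop.stop()
```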

+1 -15  exo/main.py

@@ -20,7 +20,7 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.dummy_inference_engine import DummyInferenceEngine
@@ -163,20 +163,6 @@ def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
 
 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 
-
-async def shutdown(signal, loop):
-  """Gracefully shutdown the server and close the asyncio loop."""
-  print(f"Received exit signal {signal.name}...")
-  print("Thank you for using exo.")
-  print_yellow_exo()
-  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
-  [task.cancel() for task in server_tasks]
-  print(f"Cancelling {len(server_tasks)} outstanding tasks")
-  await asyncio.gather(*server_tasks, return_exceptions=True)
-  await server.stop()
-  loop.stop()
-
-
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
   inference_class = inference_engine.__class__.__name__
   shard = build_base_shard(model_name, inference_class)
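
The signal wiring that invokes `shutdown` is not shown in this diff, but with the helper now importable from `exo.helpers`, registering it on the loop would look roughly like this (a sketch under that assumption; Unix event loops only):

```python
# Sketch: install the shared shutdown coroutine as a signal handler.
import asyncio
import signal

from exo.helpers import shutdown

def install_signal_handlers(loop: asyncio.AbstractEventLoop) -> None:
    for sig in (signal.SIGINT, signal.SIGTERM):
        # Bind `sig` per iteration so each handler reports its own signal.
        loop.add_signal_handler(sig, lambda s=sig: asyncio.create_task(shutdown(s, loop)))
```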

+1 -5  scripts/build_exo.py

@@ -21,11 +21,7 @@ def run():
             "--macos-app-name=exo",
             "--macos-app-mode=gui",
             "--macos-app-version=0.0.1",
-            "--include-module=exo.inference.mlx.models.llama",
-            "--include-module=exo.inference.mlx.models.deepseek_v2",
-            "--include-module=exo.inference.mlx.models.base",
-            "--include-module=exo.inference.mlx.models.llava",
-            "--include-module=exo.inference.mlx.models.qwen2",
+            "--include-module=exo.inference.mlx.models.*",
             "--include-distribution-meta=mlx",
             "--include-module=mlx._reprlib_fix",
             "--include-module=mlx._os_warning",