Bläddra i källkod

pr suggestions fix

josh 5 månader sedan
förälder
incheckning
e991438e72
5 ändrade filer med 25 tillägg och 34 borttagningar
  1. 1 7
      exo/api/chatgpt_api.py
  2. 2 7
      exo/download/hf/hf_helpers.py
  3. 20 0
      exo/helpers.py
  4. 1 15
      exo/main.py
  5. 1 5
      scripts/build_exo.py

+ 1 - 7
exo/api/chatgpt_api.py

@@ -13,7 +13,7 @@ import signal
 import sys
 import sys
 from exo import DEBUG, VERSION
 from exo import DEBUG, VERSION
 from exo.download.download_progress import RepoProgressEvent
 from exo.download.download_progress import RepoProgressEvent
-from exo.helpers import PrefixDict
+from exo.helpers import PrefixDict, shutdown
 from exo.inference.inference_engine import inference_engine_classes
 from exo.inference.inference_engine import inference_engine_classes
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.inference.tokenizers import resolve_tokenizer
@@ -148,11 +148,6 @@ class PromptSession:
     self.timestamp = timestamp
     self.timestamp = timestamp
     self.prompt = prompt
     self.prompt = prompt
 
 
-def is_frozen():
-  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
-    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
-    or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
-
 class ChatGPTAPI:
 class ChatGPTAPI:
   def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
   def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
     self.node = node
     self.node = node
@@ -193,7 +188,6 @@ class ChatGPTAPI:
   
   
   async def handle_quit(self, request):
   async def handle_quit(self, request):
     print("Received quit signal")
     print("Received quit signal")
-    from exo.main import shutdown
     response = web.json_response({"detail": "Quit signal received"}, status=200)
     response = web.json_response({"detail": "Quit signal received"}, status=200)
     await response.prepare(request)
     await response.prepare(request)
     await response.write_eof()
     await response.write_eof()

+ 2 - 7
exo/download/hf/hf_helpers.py

@@ -10,7 +10,7 @@ from fnmatch import fnmatch
 from pathlib import Path
 from pathlib import Path
 from typing import Generator, Iterable, TypeVar, TypedDict
 from typing import Generator, Iterable, TypeVar, TypedDict
 from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
 from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
-from exo.helpers import DEBUG
+from exo.helpers import DEBUG, is_frozen
 from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
 from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
 import aiofiles
 import aiofiles
@@ -18,11 +18,6 @@ from aiofiles import os as aios
 
 
 T = TypeVar("T")
 T = TypeVar("T")
 
 
-def is_frozen():
-  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
-    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
-    or ('__compiled__' in globals())
-
 async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
 async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
   refs_dir = get_repo_root(repo_id)/"refs"
   refs_dir = get_repo_root(repo_id)/"refs"
   refs_file = refs_dir/revision
   refs_file = refs_dir/revision
@@ -105,7 +100,7 @@ async def get_auth_headers():
 def get_repo_root(repo_id: str) -> Path:
 def get_repo_root(repo_id: str) -> Path:
   """Get the root directory for a given repo ID in the Hugging Face cache."""
   """Get the root directory for a given repo ID in the Hugging Face cache."""
   sanitized_repo_id = str(repo_id).replace("/", "--")
   sanitized_repo_id = str(repo_id).replace("/", "--")
-  if "Qwen2.5-0.5B-Instruct-4bit" in str(repo_id) and is_frozen():
+  if is_frozen():
     repo_root = Path(sys.argv[0]).parent/f"models--{sanitized_repo_id}"
     repo_root = Path(sys.argv[0]).parent/f"models--{sanitized_repo_id}"
     return repo_root
     return repo_root
   return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
   return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"

+ 20 - 0
exo/helpers.py

@@ -1,4 +1,5 @@
 import os
 import os
+import sys
 import asyncio
 import asyncio
 from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List
 from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List
 import socket
 import socket
@@ -234,3 +235,22 @@ def get_all_ip_addresses():
   except:
   except:
     if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
     if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
     return ["localhost"]
     return ["localhost"]
+
+
+async def shutdown(signal, loop):
+  """Gracefully shutdown the server and close the asyncio loop."""
+  print(f"Received exit signal {signal.name}...")
+  print("Thank you for using exo.")
+  print_yellow_exo()
+  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
+  [task.cancel() for task in server_tasks]
+  print(f"Cancelling {len(server_tasks)} outstanding tasks")
+  await asyncio.gather(*server_tasks, return_exceptions=True)
+  await server.stop()
+  loop.stop()
+
+
+def is_frozen():
+  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
+    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
+    or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)

+ 1 - 15
exo/main.py

@@ -20,7 +20,7 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe
 from exo.api import ChatGPTAPI
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.dummy_inference_engine import DummyInferenceEngine
 from exo.inference.dummy_inference_engine import DummyInferenceEngine
@@ -163,20 +163,6 @@ def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
 
 
 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 
 
-
-async def shutdown(signal, loop):
-  """Gracefully shutdown the server and close the asyncio loop."""
-  print(f"Received exit signal {signal.name}...")
-  print("Thank you for using exo.")
-  print_yellow_exo()
-  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
-  [task.cancel() for task in server_tasks]
-  print(f"Cancelling {len(server_tasks)} outstanding tasks")
-  await asyncio.gather(*server_tasks, return_exceptions=True)
-  await server.stop()
-  loop.stop()
-
-
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
   inference_class = inference_engine.__class__.__name__
   inference_class = inference_engine.__class__.__name__
   shard = build_base_shard(model_name, inference_class)
   shard = build_base_shard(model_name, inference_class)

+ 1 - 5
scripts/build_exo.py

@@ -21,11 +21,7 @@ def run():
             "--macos-app-name=exo",
             "--macos-app-name=exo",
             "--macos-app-mode=gui",
             "--macos-app-mode=gui",
             "--macos-app-version=0.0.1",
             "--macos-app-version=0.0.1",
-            "--include-module=exo.inference.mlx.models.llama",
-            "--include-module=exo.inference.mlx.models.deepseek_v2",
-            "--include-module=exo.inference.mlx.models.base",
-            "--include-module=exo.inference.mlx.models.llava",
-            "--include-module=exo.inference.mlx.models.qwen2",
+            "--include-module=exo.inference.mlx.models.*",
             "--include-distribution-meta=mlx",
             "--include-distribution-meta=mlx",
             "--include-module=mlx._reprlib_fix",
             "--include-module=mlx._reprlib_fix",
             "--include-module=mlx._os_warning",
             "--include-module=mlx._os_warning",