Browse Source

test clean ups

josh 8 tháng trước cách đây
mục cha
commit
5396f080c5
5 tập tin đã thay đổi với 21 bổ sung18 xóa
  1. 1 1
      exo/api/chatgpt_api.py
  2. 13 10
      exo/download/hf/hf_helpers.py
  3. 2 3
      exo/helpers.py
  4. 3 3
      exo/main.py
  5. 2 1
      setup.py

+ 1 - 1
exo/api/chatgpt_api.py

@@ -189,7 +189,7 @@ class ChatGPTAPI:
     response = web.json_response({"detail": "Quit signal received"}, status=200)
     await response.prepare(request)
     await response.write_eof()
-    await shutdown(signal.SIGINT, asyncio.get_event_loop())
+    await shutdown(signal.SIGINT, asyncio.get_event_loop(), self.node)
 
   async def timeout_middleware(self, app, handler):
     async def middleware(request):

+ 13 - 10
exo/download/hf/hf_helpers.py

@@ -2,6 +2,7 @@ import aiofiles.os as aios
 from typing import Union
 import asyncio
 import aiohttp
+from anyio import Path as AsyncPath
 import json
 import os
 import sys
@@ -107,17 +108,19 @@ def get_repo_root(repo_id: str) -> Path:
 
 async def move_models_to_hf(seed_dir: Union[str, Path]):
   """Move model in resources folder of app to .cache/huggingface/hub"""
-  source_dir = Path(seed_dir)
-  dest_dir = get_hf_home()/"hub"
-  await aios.makedirs(dest_dir, exist_ok=True)
-  async for path in async_iterdir(source_dir):  
-    if path.is_dir() and path.name.startswith("models--"):
+  source_dir = AsyncPath(seed_dir)
+  if DEBUG>=1: print("moving files")
+  dest_dir = AsyncPath(get_hf_home()/"hub")
+  await aios.makedirs(dest_dir, exist_ok=True)   
+  async for path in source_dir.iterdir():
+    if await path.is_dir() and path.name.startswith("models--"):
+      if DEBUG>=1: print("moving files")
       dest_path = dest_dir / path.name
-      if await async_exists(dest_path): 
-          if DEBUG >= 1: print(f"skipping moving {dest_path}. File already exists")
-      else:
-          await aios.rename(str(path), str(dest_path))
-      
+      try:
+        await aios.rename(str(path), str(dest_path))
+      except Exception as e:
+        print(e)
+    
 
 async def fetch_file_list(session, repo_id, revision, path=""):
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"

+ 2 - 3
exo/helpers.py

@@ -237,7 +237,7 @@ def get_all_ip_addresses():
     return ["localhost"]
 
 
-async def shutdown(signal, loop):
+async def shutdown(signal, loop, node):
   """Gracefully shutdown the server and close the asyncio loop."""
   print(f"Received exit signal {signal.name}...")
   print("Thank you for using exo.")
@@ -246,8 +246,7 @@ async def shutdown(signal, loop):
   [task.cancel() for task in server_tasks]
   print(f"Cancelling {len(server_tasks)} outstanding tasks")
   await asyncio.gather(*server_tasks, return_exceptions=True)
-  await server.stop()
-  loop.stop()
+  await node.server.stop()
 
 
 def is_frozen():

+ 3 - 3
exo/main.py

@@ -194,7 +194,7 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
     node.on_token.deregister(callback_id)
 
 def clean_path(path):
-    """Clean and resolve given path."""
+    """Clean and resolve path"""
     if path.startswith("Optional("):
         path = path.strip('Optional("').rstrip('")')
     return os.path.expanduser(path)
@@ -223,7 +223,7 @@ async def main():
 
   # Use a more direct approach to handle signals
   def handle_exit():
-    asyncio.ensure_future(shutdown(signal.SIGTERM, loop))
+    asyncio.ensure_future(shutdown(signal.SIGTERM, loop, node))
 
   if platform.system() != "Windows":
     for s in [signal.SIGINT, signal.SIGTERM]:
@@ -250,7 +250,7 @@ def run():
   except KeyboardInterrupt:
     print("Received keyboard interrupt. Shutting down...")
   finally:
-    loop.run_until_complete(shutdown(signal.SIGTERM, loop))
+    loop.run_until_complete(shutdown(signal.SIGTERM, loop, node))
     loop.close()
 
 

+ 2 - 1
setup.py

@@ -8,6 +8,7 @@ install_requires = [
   "aiohttp==3.10.11",
   "aiohttp_cors==0.7.0",
   "aiofiles==24.1.0",
+  "anyio==4.6.2",
   "grpcio==1.68.0",
   "grpcio-tools==1.68.0",
   "Jinja2==3.1.4",
@@ -24,7 +25,7 @@ install_requires = [
   "rich==13.7.1",
   "tenacity==9.0.0",
   "tqdm==4.66.4",
-  "transformers==4.46.3",
+  "transformers==4.43.3",
   "uuid==1.30",
   "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad",
 ]