
Merge pull request #588 from exo-explore/betterdl

better download
Alex Cheema 5 months ago
parent commit d9a836f152
1 changed file with 8 additions and 2 deletions
    exo/api/chatgpt_api.py

+ 8 - 2
exo/api/chatgpt_api.py

@@ -245,7 +245,7 @@ class ChatGPTAPI:
         )
         await response.prepare(request)
 
-        for model_name, pretty in pretty_name.items():
+        async def process_model(model_name, pretty):
             if model_name in model_cards:
                 model_info = model_cards[model_name]
 
@@ -273,6 +273,12 @@ class ChatGPTAPI:
 
                         await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
 
+        # Process all models in parallel
+        await asyncio.gather(*[
+            process_model(model_name, pretty)
+            for model_name, pretty in pretty_name.items()
+        ])
+
         await response.write(b"data: [DONE]\n\n")
         return response
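
The first two hunks replace the sequential per-model loop with one coroutine per model, run concurrently via asyncio.gather. A minimal, self-contained sketch of that pattern, assuming hypothetical model_cards / pretty_name tables and a plain writer callback in place of aiohttp's StreamResponse:

import asyncio
import json

# Hypothetical stand-ins for the real model_cards / pretty_name tables.
model_cards = {"llama-3.2-1b": {"layers": 16}, "llama-3.2-3b": {"layers": 28}}
pretty_name = {"llama-3.2-1b": "Llama 3.2 1B", "llama-3.2-3b": "Llama 3.2 3B"}

async def fetch_download_status(model_name):
  # Placeholder for the per-model work (e.g. checking download progress on disk).
  await asyncio.sleep(0.1)
  return {"id": model_name, "downloaded": False}

async def stream_models(write):
  async def process_model(model_name, pretty):
    if model_name in model_cards:
      model_data = await fetch_download_status(model_name)
      model_data["name"] = pretty
      await write(f"data: {json.dumps(model_data)}\n\n".encode())

  # Process all models in parallel instead of one after another.
  await asyncio.gather(*[
    process_model(model_name, pretty)
    for model_name, pretty in pretty_name.items()
  ])
  await write(b"data: [DONE]\n\n")

async def main():
  async def write(chunk):
    print(chunk.decode(), end="")
  await stream_models(write)

asyncio.run(main())

One consequence of gather here is that events are written in completion order rather than pretty_name's iteration order, so SSE consumers should key off the model id rather than position.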
 
@@ -562,7 +568,7 @@ class ChatGPTAPI:
       if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
       shard = build_base_shard(model_name, self.inference_engine_classname)
       if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
-      asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
+      asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))
 
       return web.json_response({
         "status": "success",