|
@@ -245,7 +245,7 @@ class ChatGPTAPI:
|
|
|
)
|
|
|
await response.prepare(request)
|
|
|
|
|
|
- for model_name, pretty in pretty_name.items():
|
|
|
+ async def process_model(model_name, pretty):
|
|
|
if model_name in model_cards:
|
|
|
model_info = model_cards[model_name]
|
|
|
|
|
@@ -273,6 +273,12 @@ class ChatGPTAPI:
|
|
|
|
|
|
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
|
|
|
|
|
|
+ # Process all models in parallel
|
|
|
+ await asyncio.gather(*[
|
|
|
+ process_model(model_name, pretty)
|
|
|
+ for model_name, pretty in pretty_name.items()
|
|
|
+ ])
|
|
|
+
|
|
|
await response.write(b"data: [DONE]\n\n")
|
|
|
return response
|
|
|
|
|
@@ -562,7 +568,7 @@ class ChatGPTAPI:
|
|
|
if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
|
|
|
shard = build_base_shard(model_name, self.inference_engine_classname)
|
|
|
if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
|
|
|
- asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
|
|
|
+ asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))
|
|
|
|
|
|
return web.json_response({
|
|
|
"status": "success",
|