Răsfoiți Sursa

Merge pull request #526 from exo-explore/downloadendpoint

Download endpoint
Alex Cheema 5 luni în urmă
părinte
comite
46e202ad04
1 a modificat fișierele cu 19 adăugiri și 0 ștergeri
  1. 19 0
      exo/api/chatgpt_api.py

+ 19 - 0
exo/api/chatgpt_api.py

@@ -176,6 +176,7 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
     cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
     cors.add(self.app.router.add_post("/create_animation", self.handle_create_animation), {"*": cors_options})
+    cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
 
     if "__compiled__" not in globals():
       self.static_dir = Path(__file__).parent.parent/"tinychat"
@@ -410,6 +411,24 @@ class ChatGPTAPI:
       if DEBUG >= 2: traceback.print_exc()
       return web.json_response({"error": str(e)}, status=500)
 
+  async def handle_post_download(self, request):
+    try:
+      data = await request.json()
+      model_name = data.get("model")
+      if not model_name: return web.json_response({"error": "model parameter is required"}, status=400)
+      if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
+      shard = build_base_shard(model_name, self.inference_engine_classname)
+      if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
+      asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
+
+      return web.json_response({
+        "status": "success",
+        "message": f"Download started for model: {model_name}"
+      })
+    except Exception as e:
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response({"error": str(e)}, status=500)
+
   async def run(self, host: str = "0.0.0.0", port: int = 52415):
     runner = web.AppRunner(self.app)
     await runner.setup()