@@ -176,6 +176,7 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
     cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
     cors.add(self.app.router.add_post("/create_animation", self.handle_create_animation), {"*": cors_options})
+    cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})

     if "__compiled__" not in globals():
       self.static_dir = Path(__file__).parent.parent/"tinychat"
@@ -410,6 +411,24 @@ class ChatGPTAPI:
       if DEBUG >= 2: traceback.print_exc()
       return web.json_response({"error": str(e)}, status=500)

+  async def handle_post_download(self, request):
+    try:
+      data = await request.json()
+      model_name = data.get("model")
+      if not model_name: return web.json_response({"error": "model parameter is required"}, status=400)
+      if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
+      shard = build_base_shard(model_name, self.inference_engine_classname)
+      if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
+      asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
+
+      return web.json_response({
+        "status": "success",
+        "message": f"Download started for model: {model_name}"
+      })
+    except Exception as e:
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response({"error": str(e)}, status=500)
+
   async def run(self, host: str = "0.0.0.0", port: int = 52415):
     runner = web.AppRunner(self.app)
     await runner.setup()
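
For context, a minimal client sketch that exercises the new endpoint. It assumes the API server is running on the default host/port from run() (reachable at localhost:52415) and uses the third-party requests library; the model name shown is a placeholder and must match a key in model_cards on your node.

import requests

# Ask the node to start downloading a model in the background via the new /download route.
# "llama-3.2-1b" is a placeholder model name; use any key actually present in model_cards.
resp = requests.post(
  "http://localhost:52415/download",
  json={"model": "llama-3.2-1b"},
)

body = resp.json()
if resp.status_code == 200:
  print(body["message"])  # "Download started for model: ..."
else:
  print("download request failed:", body.get("error"))

The handler only schedules ensure_shard as a background task, so a 200 response means the download was started, not that it finished.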