|
@@ -176,16 +176,23 @@ class ChatGPTAPI:
|
|
|
cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
|
|
|
cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
|
|
|
cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
|
|
|
- # Endpoint for download progress tracking
|
|
|
cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
|
|
|
|
|
|
- self.static_dir = Path(__file__).parent.parent/"tinychat"
|
|
|
+ self.static_dir = Path(__file__).parent.parent / "tinychat"
|
|
|
self.app.router.add_get("/", self.handle_root)
|
|
|
self.app.router.add_static("/", self.static_dir, name="static")
|
|
|
|
|
|
- # Add middleware to log every request
|
|
|
+ self.app.middlewares.append(self.timeout_middleware)
|
|
|
self.app.middlewares.append(self.log_request)
|
|
|
|
|
|
+ async def timeout_middleware(self, app, handler):
|
|
|
+ async def middleware(request):
|
|
|
+ try:
|
|
|
+ return await asyncio.wait_for(handler(request), timeout=self.response_timeout)
|
|
|
+ except asyncio.TimeoutError:
|
|
|
+ return web.json_response({"detail": "Request timed out"}, status=408)
|
|
|
+ return middleware
|
|
|
+
|
|
|
async def log_request(self, app, handler):
|
|
|
async def middleware(request):
|
|
|
if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
|
|
@@ -261,11 +268,7 @@ class ChatGPTAPI:
|
|
|
callback = self.node.on_token.register(callback_id)
|
|
|
|
|
|
if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=} {image_str=}")
|
|
|
- try:
|
|
|
- await self.node.process_prompt(shard, prompt, image_str, request_id=request_id)
|
|
|
- except Exception as e:
|
|
|
- if DEBUG >= 2: traceback.print_exc()
|
|
|
- return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
|
|
|
+ asyncio.create_task(self.node.process_prompt(shard, prompt, image_str, request_id=request_id))
|
|
|
|
|
|
try:
|
|
|
if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")
|