Browse Source

more robust handling of timeouts

Alex Cheema 8 months ago
parent
commit
c3864f5e6f
2 changed files with 19 additions and 14 deletions
  1. 11 8
      exo/api/chatgpt_api.py
  2. 8 6
      exo/tinychat/index.js

+ 11 - 8
exo/api/chatgpt_api.py

@@ -176,16 +176,23 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
     cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
     cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
-    # Endpoint for download progress tracking
     cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
 
-    self.static_dir = Path(__file__).parent.parent/"tinychat"
+    self.static_dir = Path(__file__).parent.parent / "tinychat"
     self.app.router.add_get("/", self.handle_root)
     self.app.router.add_static("/", self.static_dir, name="static")
 
-    # Add middleware to log every request
+    self.app.middlewares.append(self.timeout_middleware)
     self.app.middlewares.append(self.log_request)
 
+  async def timeout_middleware(self, app, handler):
+    async def middleware(request):
+      try:
+        return await asyncio.wait_for(handler(request), timeout=self.response_timeout)
+      except asyncio.TimeoutError:
+        return web.json_response({"detail": "Request timed out"}, status=408)
+    return middleware
+
   async def log_request(self, app, handler):
     async def middleware(request):
       if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
@@ -261,11 +268,7 @@ class ChatGPTAPI:
     callback = self.node.on_token.register(callback_id)
 
     if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=} {image_str=}")
-    try:
-      await self.node.process_prompt(shard, prompt, image_str, request_id=request_id)
-    except Exception as e:
-      if DEBUG >= 2: traceback.print_exc()
-      return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
+    asyncio.create_task(self.node.process_prompt(shard, prompt, image_str, request_id=request_id))
 
     try:
       if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")

+ 8 - 6
exo/tinychat/index.js

@@ -107,15 +107,15 @@ document.addEventListener("alpine:init", () => {
         el.style.height = "auto";
         el.style.height = el.scrollHeight + "px";
 
-        // Proceed to handle the message
+        localStorage.setItem("pendingMessage", value);
         this.processMessage(value);
       } catch (error) {
         console.error('error', error)
-        this.errorMessage = error.message || 'Errore durante l\'invio del messaggio.';
+        this.lastErrorMessage = error.message || 'Unknown error on handleSend';
+        this.errorMessage = error.message || 'Unknown error on handleSend';
         setTimeout(() => {
           this.errorMessage = null;
         }, 5 * 1000)
-        localStorage.setItem("pendingMessage", value);
         this.generating = false;
       }
     },
@@ -233,6 +233,7 @@ document.addEventListener("alpine:init", () => {
         }
       } catch (error) {
         console.error('error', error)
+        this.lastErrorMessage = error;
         this.errorMessage = error;
         setTimeout(() => {
           this.errorMessage = null;
@@ -323,8 +324,11 @@ document.addEventListener("alpine:init", () => {
                   // Clear pendingMessage
                   localStorage.removeItem("pendingMessage");
                   // Call processMessage() with savedMessage
-                  await this.processMessage(savedMessage);
+                  if (this.lastErrorMessage) {
+                    await this.processMessage(savedMessage);
+                  }
                 }
+                this.lastErrorMessage = null;
               }
             } else {
               // Compute human-readable strings
@@ -344,8 +348,6 @@ document.addEventListener("alpine:init", () => {
       } catch (error) {
         console.error("Error fetching download progress:", error);
         this.downloadProgress = null;
-        // Stop polling in case of error
-        this.stopDownloadProgressPolling();
       }
     },