Browse Source

Merge pull request #282 from exo-explore/unsilence_errors

Unsilence errors
Alex Cheema 8 months ago
parent
commit
2b9dec20eb
1 changed files with 9 additions and 5 deletions
  1. 9 5
      exo/api/chatgpt_api.py

+ 9 - 5
exo/api/chatgpt_api.py

@@ -268,9 +268,10 @@ class ChatGPTAPI:
     callback = self.node.on_token.register(callback_id)
 
     if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=} {image_str=}")
-    asyncio.create_task(self.node.process_prompt(shard, prompt, image_str, request_id=request_id))
 
     try:
+      await asyncio.wait_for(self.node.process_prompt(shard, prompt, image_str, request_id=request_id), timeout=self.response_timeout)
+
       if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")
 
       if stream:
@@ -284,9 +285,9 @@ class ChatGPTAPI:
         )
         await response.prepare(request)
 
-        async def stream_result(request_id: str, tokens: List[int], is_finished: bool):
-          prev_last_tokens_len = self.prev_token_lens.get(request_id, 0)
-          self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
+        async def stream_result(_request_id: str, tokens: List[int], is_finished: bool):
+          prev_last_tokens_len = self.prev_token_lens.get(_request_id, 0)
+          self.prev_token_lens[_request_id] = max(prev_last_tokens_len, len(tokens))
           new_tokens = tokens[prev_last_tokens_len:]
           finish_reason = None
           eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer,
@@ -316,7 +317,7 @@ class ChatGPTAPI:
             if DEBUG >= 2: traceback.print_exc()
 
         def on_result(_request_id: str, tokens: List[int], is_finished: bool):
-          self.stream_tasks[request_id] = asyncio.create_task(stream_result(request_id, tokens, is_finished))
+          if _request_id == request_id: self.stream_tasks[_request_id] = asyncio.create_task(stream_result(_request_id, tokens, is_finished))
 
           return _request_id == request_id and is_finished
 
@@ -345,6 +346,9 @@ class ChatGPTAPI:
         return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
     except asyncio.TimeoutError:
       return web.json_response({"detail": "Response generation timed out"}, status=408)
+    except Exception as e:
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
     finally:
       deregistered_callback = self.node.on_token.deregister(callback_id)
       if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")