Alex Cheema 6 months ago
Parent
Commit
87d1271d33
1 changed file with 8 additions and 8 deletions

+ 8 - 8
exo/api/chatgpt_api.py

@@ -204,7 +204,7 @@ class ChatGPTAPI:
 
     # Get the callback system and register our handler
     self.token_callback = node.on_token.register("chatgpt-api-token-handler")
-    self.token_callback.on_next(lambda _request_id, token, is_finished: asyncio.create_task(self.handle_token(_request_id, token, is_finished)))
+    self.token_callback.on_next(lambda _request_id, tokens, is_finished: asyncio.create_task(self.handle_tokens(_request_id, tokens, is_finished)))
     self.system_prompt = system_prompt
 
     cors = aiohttp_cors.setup(self.app)
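
The registered callback now receives a batch of token ids per event instead of a single id, and still just schedules the async handler with asyncio.create_task. A minimal, self-contained sketch of that wiring follows (the Emitter class and names are assumptions for illustration, not exo's actual on_token implementation):

import asyncio
from typing import Callable, List

class Emitter:
  # Stand-in for the callback object returned by node.on_token.register(...).
  def __init__(self):
    self._on_next: Callable[[str, List[int], bool], None] = lambda *_: None

  def on_next(self, cb: Callable[[str, List[int], bool], None]) -> None:
    self._on_next = cb

  def emit(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
    self._on_next(request_id, tokens, is_finished)

async def handle_tokens(request_id: str, tokens: List[int], is_finished: bool) -> None:
  print(f"{request_id=} received batch {tokens=} {is_finished=}")

async def main() -> None:
  emitter = Emitter()
  # Same pattern as the changed line: a sync observer schedules the async handler.
  emitter.on_next(lambda rid, toks, fin: asyncio.create_task(handle_tokens(rid, toks, fin)))
  emitter.emit("req-1", [101, 102], False)
  await asyncio.sleep(0)  # give the scheduled task a chance to run

asyncio.run(main())
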
@@ -463,17 +463,17 @@ class ChatGPTAPI:
       else:
         tokens = []
         while True:
-          token, is_finished = await asyncio.wait_for(self.token_queues[request_id].get(), timeout=self.response_timeout)
-          tokens.append(token)
+          _tokens, is_finished = await asyncio.wait_for(self.token_queues[request_id].get(), timeout=self.response_timeout)
+          tokens.extend(_tokens)
           if is_finished:
             break
         finish_reason = "length"
         eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
-        if DEBUG >= 2: print(f"Checking if end of tokens result {token=} is {eos_token_id=}")
-        if token == eos_token_id:
+        if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
+        if tokens[-1] == eos_token_id:
           finish_reason = "stop"
 
-        return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, [token], stream, finish_reason, "chat.completion"))
+        return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
     except asyncio.TimeoutError:
       return web.json_response({"detail": "Response generation timed out"}, status=408)
     except Exception as e:
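
The drain loop above now collects batches with extend instead of append, so the end-of-sequence check moves to the last id in the accumulated list. A standalone sketch of the same pattern, assuming a plain asyncio.Queue and a made-up eos id:

import asyncio
from typing import List, Tuple

async def drain(queue: asyncio.Queue, eos_token_id: int, timeout: float = 5.0) -> Tuple[List[int], str]:
  tokens: List[int] = []
  while True:
    _tokens, is_finished = await asyncio.wait_for(queue.get(), timeout=timeout)
    tokens.extend(_tokens)  # each queue item is now a batch of ids, hence extend instead of append
    if is_finished:
      break
  finish_reason = "stop" if tokens and tokens[-1] == eos_token_id else "length"
  return tokens, finish_reason

async def demo() -> None:
  q: asyncio.Queue = asyncio.Queue()
  await q.put(([1, 2, 3], False))
  await q.put(([4, 2], True))  # 2 plays the role of the eos token id here
  print(await drain(q, eos_token_id=2))  # -> ([1, 2, 3, 4, 2], 'stop')

asyncio.run(demo())
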
@@ -678,8 +678,8 @@ class ChatGPTAPI:
       if DEBUG >= 2: traceback.print_exc()
       return web.json_response({"detail": f"Error getting topology: {str(e)}"}, status=500)
 
-  async def handle_token(self, request_id: str, token: int, is_finished: bool):
-    await self.token_queues[request_id].put((token, is_finished))
+  async def handle_tokens(self, request_id: str, tokens: List[int], is_finished: bool):
+    await self.token_queues[request_id].put((tokens, is_finished))
 
   async def run(self, host: str = "0.0.0.0", port: int = 52415):
     runner = web.AppRunner(self.app)
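
On the producer side, handle_tokens enqueues the whole batch as a single queue item keyed by request_id. A minimal sketch, assuming token_queues is a per-request mapping of asyncio.Queue objects (the defaultdict here is an illustration; the actual construction of self.token_queues is not shown in this diff):

import asyncio
from collections import defaultdict
from typing import DefaultDict, List

# Per-request queues, mirroring how self.token_queues[request_id] is used above.
token_queues: DefaultDict[str, asyncio.Queue] = defaultdict(asyncio.Queue)

async def handle_tokens(request_id: str, tokens: List[int], is_finished: bool) -> None:
  # The whole batch travels through the queue as one item.
  await token_queues[request_id].put((tokens, is_finished))

async def demo() -> None:
  await handle_tokens("req-1", [7, 8, 9], False)
  await handle_tokens("req-1", [10], True)
  print(token_queues["req-1"].qsize())  # -> 2

asyncio.run(demo())
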