|
@@ -204,7 +204,7 @@ class ChatGPTAPI:
|
|
|
|
|
|
# Get the callback system and register our handler
|
|
|
self.token_callback = node.on_token.register("chatgpt-api-token-handler")
|
|
|
- self.token_callback.on_next(lambda _request_id, token, is_finished: asyncio.create_task(self.handle_token(_request_id, token, is_finished)))
|
|
|
+ self.token_callback.on_next(lambda _request_id, tokens, is_finished: asyncio.create_task(self.handle_tokens(_request_id, tokens, is_finished)))
|
|
|
self.system_prompt = system_prompt
|
|
|
|
|
|
cors = aiohttp_cors.setup(self.app)
|
|
@@ -463,17 +463,17 @@ class ChatGPTAPI:
|
|
|
else:
|
|
|
tokens = []
|
|
|
while True:
|
|
|
- token, is_finished = await asyncio.wait_for(self.token_queues[request_id].get(), timeout=self.response_timeout)
|
|
|
- tokens.append(token)
|
|
|
+ _tokens, is_finished = await asyncio.wait_for(self.token_queues[request_id].get(), timeout=self.response_timeout)
|
|
|
+ tokens.extend(_tokens)
|
|
|
if is_finished:
|
|
|
break
|
|
|
finish_reason = "length"
|
|
|
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
|
|
|
- if DEBUG >= 2: print(f"Checking if end of tokens result {token=} is {eos_token_id=}")
|
|
|
- if token == eos_token_id:
|
|
|
+ if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
|
|
|
+ if tokens[-1] == eos_token_id:
|
|
|
finish_reason = "stop"
|
|
|
|
|
|
- return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, [token], stream, finish_reason, "chat.completion"))
|
|
|
+ return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
|
|
|
except asyncio.TimeoutError:
|
|
|
return web.json_response({"detail": "Response generation timed out"}, status=408)
|
|
|
except Exception as e:
|
|
@@ -678,8 +678,8 @@ class ChatGPTAPI:
|
|
|
if DEBUG >= 2: traceback.print_exc()
|
|
|
return web.json_response({"detail": f"Error getting topology: {str(e)}"}, status=500)
|
|
|
|
|
|
- async def handle_token(self, request_id: str, token: int, is_finished: bool):
|
|
|
- await self.token_queues[request_id].put((token, is_finished))
|
|
|
+ async def handle_tokens(self, request_id: str, tokens: List[int], is_finished: bool):
|
|
|
+ await self.token_queues[request_id].put((tokens, is_finished))
|
|
|
|
|
|
async def run(self, host: str = "0.0.0.0", port: int = 52415):
|
|
|
runner = web.AppRunner(self.app)
|