|
@@ -65,6 +65,7 @@ class ChatGPTAPI:
|
|
|
self.app = web.Application()
|
|
|
self.app.router.add_post('/v1/chat/completions', self.handle_post)
|
|
|
self.inference_engine_classname = inference_engine_classname
|
|
|
+ self.response_timeout_secs = 90
|
|
|
|
|
|
async def handle_post(self, request):
|
|
|
data = await request.json()
|
|
@@ -84,49 +85,52 @@ class ChatGPTAPI:
|
|
|
|
|
|
if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
|
|
|
try:
|
|
|
- result = await self.node.process_prompt(shard, prompt, request_id=request_id)
|
|
|
+ await self.node.process_prompt(shard, prompt, request_id=request_id)
|
|
|
except Exception as e:
|
|
|
- pass # TODO
|
|
|
- # return web.json_response({'detail': str(e)}, status=500)
|
|
|
-
|
|
|
- # poll for the response. TODO: implement callback for specific request id
|
|
|
- timeout = 90
|
|
|
- start_time = time.time()
|
|
|
- while time.time() - start_time < timeout:
|
|
|
- try:
|
|
|
- result, is_finished = await self.node.get_inference_result(request_id)
|
|
|
- except Exception as e:
|
|
|
- continue
|
|
|
- await asyncio.sleep(0.1)
|
|
|
- if is_finished:
|
|
|
- eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
|
|
|
- if DEBUG >= 2: print(f"Checking if end of result {result[-1]=} is {eos_token_id=}")
|
|
|
- if result[-1] == eos_token_id:
|
|
|
- result = result[:-1]
|
|
|
- return web.json_response({
|
|
|
- "id": f"chatcmpl-{request_id}",
|
|
|
- "object": "chat.completion",
|
|
|
- "created": int(time.time()),
|
|
|
- "model": chat_request.model,
|
|
|
- "usage": {
|
|
|
- "prompt_tokens": len(tokenizer.encode(prompt)),
|
|
|
- "completion_tokens": len(result),
|
|
|
- "total_tokens": len(tokenizer.encode(prompt)) + len(result)
|
|
|
- },
|
|
|
- "choices": [
|
|
|
- {
|
|
|
- "message": {
|
|
|
- "role": "assistant",
|
|
|
- "content": tokenizer.decode(result)
|
|
|
- },
|
|
|
- "logprobs": None,
|
|
|
- "finish_reason": "stop",
|
|
|
- "index": 0
|
|
|
- }
|
|
|
- ]
|
|
|
- })
|
|
|
-
|
|
|
- return web.json_response({'detail': "Response generation timed out"}, status=408)
|
|
|
+ if DEBUG >= 2:
|
|
|
+ import traceback
|
|
|
+ traceback.print_exc()
|
|
|
+ return web.json_response({'detail': f"Error processing prompt (see logs): {str(e)}"}, status=500)
|
|
|
+
|
|
|
+ callback_id = f"chatgpt-api-wait-response-{request_id}"
|
|
|
+ callback = self.node.on_token.register(callback_id)
|
|
|
+
|
|
|
+ try:
|
|
|
+ if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout_secs}s")
|
|
|
+ _, result, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=self.response_timeout_secs)
|
|
|
+
|
|
|
+ eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
|
|
|
+ if DEBUG >= 2: print(f"Checking if end of result {result[-1]=} is {eos_token_id=}")
|
|
|
+ if result[-1] == eos_token_id:
|
|
|
+ result = result[:-1]
|
|
|
+
|
|
|
+ return web.json_response({
|
|
|
+ "id": f"chatcmpl-{request_id}",
|
|
|
+ "object": "chat.completion",
|
|
|
+ "created": int(time.time()),
|
|
|
+ "model": chat_request.model,
|
|
|
+ "usage": {
|
|
|
+ "prompt_tokens": len(tokenizer.encode(prompt)),
|
|
|
+ "completion_tokens": len(result),
|
|
|
+ "total_tokens": len(tokenizer.encode(prompt)) + len(result)
|
|
|
+ },
|
|
|
+ "choices": [
|
|
|
+ {
|
|
|
+ "message": {
|
|
|
+ "role": "assistant",
|
|
|
+ "content": tokenizer.decode(result)
|
|
|
+ },
|
|
|
+ "logprobs": None,
|
|
|
+ "finish_reason": "stop",
|
|
|
+ "index": 0
|
|
|
+ }
|
|
|
+ ]
|
|
|
+ })
|
|
|
+ except asyncio.TimeoutError:
|
|
|
+ return web.json_response({'detail': "Response generation timed out"}, status=408)
|
|
|
+ finally:
|
|
|
+ deregistered_callback = self.node.on_token.deregister(callback_id)
|
|
|
+ if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
|
|
|
|
|
|
async def run(self, host: str = "0.0.0.0", port: int = 8000):
|
|
|
runner = web.AppRunner(self.app)
|