|
@@ -4,7 +4,7 @@ import asyncio
|
|
import json
|
|
import json
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
from transformers import AutoTokenizer
|
|
from transformers import AutoTokenizer
|
|
-from typing import List, Literal, Union
|
|
|
|
|
|
+from typing import List, Literal, Union, Dict
|
|
from aiohttp import web
|
|
from aiohttp import web
|
|
import aiohttp_cors
|
|
import aiohttp_cors
|
|
from exo import DEBUG, VERSION
|
|
from exo import DEBUG, VERSION
|
|
@@ -122,6 +122,8 @@ class ChatGPTAPI:
|
|
self.inference_engine_classname = inference_engine_classname
|
|
self.inference_engine_classname = inference_engine_classname
|
|
self.response_timeout_secs = 90
|
|
self.response_timeout_secs = 90
|
|
self.app = web.Application()
|
|
self.app = web.Application()
|
|
|
|
+ self.prev_token_lens: Dict[str, int] = {}
|
|
|
|
+ self.stream_tasks: Dict[str, asyncio.Task] = {}
|
|
cors = aiohttp_cors.setup(self.app)
|
|
cors = aiohttp_cors.setup(self.app)
|
|
cors_options = aiohttp_cors.ResourceOptions(
|
|
cors_options = aiohttp_cors.ResourceOptions(
|
|
allow_credentials=True,
|
|
allow_credentials=True,
|
|
@@ -191,12 +193,9 @@ class ChatGPTAPI:
|
|
)
|
|
)
|
|
await response.prepare(request)
|
|
await response.prepare(request)
|
|
|
|
|
|
- stream_task = None
|
|
|
|
- last_tokens_len = 0
|
|
|
|
async def stream_result(request_id: str, tokens: List[int], is_finished: bool):
|
|
async def stream_result(request_id: str, tokens: List[int], is_finished: bool):
|
|
- nonlocal last_tokens_len
|
|
|
|
- prev_last_tokens_len = last_tokens_len
|
|
|
|
- last_tokens_len = len(tokens)
|
|
|
|
|
|
+ prev_last_tokens_len = self.prev_token_lens.get(request_id, 0)
|
|
|
|
+ self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
|
|
new_tokens = tokens[prev_last_tokens_len:]
|
|
new_tokens = tokens[prev_last_tokens_len:]
|
|
finish_reason = None
|
|
finish_reason = None
|
|
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
|
|
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
|
|
@@ -211,15 +210,14 @@ class ChatGPTAPI:
|
|
if DEBUG >= 2: print(f"Streaming completion: {completion}")
|
|
if DEBUG >= 2: print(f"Streaming completion: {completion}")
|
|
await response.write(f"data: {json.dumps(completion)}\n\n".encode())
|
|
await response.write(f"data: {json.dumps(completion)}\n\n".encode())
|
|
def on_result(_request_id: str, tokens: List[int], is_finished: bool):
|
|
def on_result(_request_id: str, tokens: List[int], is_finished: bool):
|
|
- nonlocal stream_task
|
|
|
|
- stream_task = asyncio.create_task(stream_result(request_id, tokens, is_finished))
|
|
|
|
|
|
+ self.stream_tasks[request_id] = asyncio.create_task(stream_result(request_id, tokens, is_finished))
|
|
|
|
|
|
return _request_id == request_id and is_finished
|
|
return _request_id == request_id and is_finished
|
|
_, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout_secs)
|
|
_, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout_secs)
|
|
- if stream_task: # in case there is still a stream task running, wait for it to complete
|
|
|
|
|
|
+ if request_id in self.stream_tasks: # in case there is still a stream task running, wait for it to complete
|
|
if DEBUG >= 2: print(f"Pending stream task. Waiting for stream task to complete.")
|
|
if DEBUG >= 2: print(f"Pending stream task. Waiting for stream task to complete.")
|
|
try:
|
|
try:
|
|
- await asyncio.wait_for(stream_task, timeout=30)
|
|
|
|
|
|
+ await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
|
|
except asyncio.TimeoutError:
|
|
except asyncio.TimeoutError:
|
|
print("WARNING: Stream task timed out. This should not happen.")
|
|
print("WARNING: Stream task timed out. This should not happen.")
|
|
await response.write_eof()
|
|
await response.write_eof()
|