Browse Source

Fix an issue with the ChatGPT API where it would generate overly long output. Avoid nonlocal.

Alex Cheema 1 year ago
parent
commit
dd09c59719
1 changed file with 8 additions and 10 deletions
  1. 8 10
      exo/api/chatgpt_api.py

+ 8 - 10
exo/api/chatgpt_api.py

@@ -4,7 +4,7 @@ import asyncio
 import json
 from pathlib import Path
 from transformers import AutoTokenizer
-from typing import List, Literal, Union
+from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
 from exo import DEBUG, VERSION
@@ -122,6 +122,8 @@ class ChatGPTAPI:
         self.inference_engine_classname = inference_engine_classname
         self.response_timeout_secs = 90
         self.app = web.Application()
+        self.prev_token_lens: Dict[str, int] = {}
+        self.stream_tasks: Dict[str, asyncio.Task] = {}
         cors = aiohttp_cors.setup(self.app)
         cors_options = aiohttp_cors.ResourceOptions(
             allow_credentials=True,
@@ -191,12 +193,9 @@ class ChatGPTAPI:
                 )
                 await response.prepare(request)
 
-                stream_task = None
-                last_tokens_len = 0
                 async def stream_result(request_id: str, tokens: List[int], is_finished: bool):
-                    nonlocal last_tokens_len
-                    prev_last_tokens_len = last_tokens_len
-                    last_tokens_len = len(tokens)
+                    prev_last_tokens_len = self.prev_token_lens.get(request_id, 0)
+                    self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
                     new_tokens = tokens[prev_last_tokens_len:]
                     finish_reason = None
                     eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
@@ -211,15 +210,14 @@ class ChatGPTAPI:
                     if DEBUG >= 2: print(f"Streaming completion: {completion}")
                     await response.write(f"data: {json.dumps(completion)}\n\n".encode())
                 def on_result(_request_id: str, tokens: List[int], is_finished: bool):
-                    nonlocal stream_task
-                    stream_task = asyncio.create_task(stream_result(request_id, tokens, is_finished))
+                    self.stream_tasks[request_id] = asyncio.create_task(stream_result(request_id, tokens, is_finished))
 
                     return _request_id == request_id and is_finished
                 _, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout_secs)
-                if stream_task: # in case there is still a stream task running, wait for it to complete
+                if request_id in self.stream_tasks: # in case there is still a stream task running, wait for it to complete
                     if DEBUG >= 2: print(f"Pending stream task. Waiting for stream task to complete.")
                     try:
-                        await asyncio.wait_for(stream_task, timeout=30)
+                        await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
                     except asyncio.TimeoutError:
                         print("WARNING: Stream task timed out. This should not happen.")
                 await response.write_eof()