Jelajahi Sumber

only stream results for the same request id. this allows multiple concurrent requests on the same LLM without overlapping interference in the streamed outputs

Alex Cheema 10 bulan lalu
induk
melakukan
9b9f40d470
1 mengubah file dengan 4 tambahan dan 4 penghapusan
  1. 4 4
      exo/api/chatgpt_api.py

+ 4 - 4
exo/api/chatgpt_api.py

@@ -285,9 +285,9 @@ class ChatGPTAPI:
         )
         await response.prepare(request)
 
-        async def stream_result(request_id: str, tokens: List[int], is_finished: bool):
-          prev_last_tokens_len = self.prev_token_lens.get(request_id, 0)
-          self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
+        async def stream_result(_request_id: str, tokens: List[int], is_finished: bool):
+          prev_last_tokens_len = self.prev_token_lens.get(_request_id, 0)
+          self.prev_token_lens[_request_id] = max(prev_last_tokens_len, len(tokens))
           new_tokens = tokens[prev_last_tokens_len:]
           finish_reason = None
           eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer,
@@ -317,7 +317,7 @@ class ChatGPTAPI:
             if DEBUG >= 2: traceback.print_exc()
 
         def on_result(_request_id: str, tokens: List[int], is_finished: bool):
-          self.stream_tasks[request_id] = asyncio.create_task(stream_result(request_id, tokens, is_finished))
+          if _request_id == request_id: self.stream_tasks[_request_id] = asyncio.create_task(stream_result(_request_id, tokens, is_finished))
 
           return _request_id == request_id and is_finished