Просмотр исходного кода

only stream results for the same request id. this allows multiple concurrent requests on the same LLM without overlapping interference in the streamed outputs

Alex Cheema 10 месяцев назад
Родитель
Сommit
9b9f40d470
1 измененных файлов с 4 добавлено и 4 удалено
  1. 4 4
      exo/api/chatgpt_api.py

+ 4 - 4
exo/api/chatgpt_api.py

@@ -285,9 +285,9 @@ class ChatGPTAPI:
         )
         await response.prepare(request)
 
-        async def stream_result(request_id: str, tokens: List[int], is_finished: bool):
-          prev_last_tokens_len = self.prev_token_lens.get(request_id, 0)
-          self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
+        async def stream_result(_request_id: str, tokens: List[int], is_finished: bool):
+          prev_last_tokens_len = self.prev_token_lens.get(_request_id, 0)
+          self.prev_token_lens[_request_id] = max(prev_last_tokens_len, len(tokens))
           new_tokens = tokens[prev_last_tokens_len:]
           finish_reason = None
           eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer,
@@ -317,7 +317,7 @@ class ChatGPTAPI:
             if DEBUG >= 2: traceback.print_exc()
 
         def on_result(_request_id: str, tokens: List[int], is_finished: bool):
-          self.stream_tasks[request_id] = asyncio.create_task(stream_result(request_id, tokens, is_finished))
+          if _request_id == request_id: self.stream_tasks[_request_id] = asyncio.create_task(stream_result(_request_id, tokens, is_finished))
 
           return _request_id == request_id and is_finished