Selaa lähdekoodia

only stream results for the same request id. this allows multiple concurrent requests on the same LLM without overlapping interference in the streamed outputs

Alex Cheema 10 kuukautta sitten
vanhempi
commit
9b9f40d470
1 muutettua tiedostoa jossa 4 lisäystä ja 4 poistoa
  1. 4 4
      exo/api/chatgpt_api.py

+ 4 - 4
exo/api/chatgpt_api.py

@@ -285,9 +285,9 @@ class ChatGPTAPI:
         )
         await response.prepare(request)
 
-        async def stream_result(request_id: str, tokens: List[int], is_finished: bool):
-          prev_last_tokens_len = self.prev_token_lens.get(request_id, 0)
-          self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
+        async def stream_result(_request_id: str, tokens: List[int], is_finished: bool):
+          prev_last_tokens_len = self.prev_token_lens.get(_request_id, 0)
+          self.prev_token_lens[_request_id] = max(prev_last_tokens_len, len(tokens))
           new_tokens = tokens[prev_last_tokens_len:]
           finish_reason = None
           eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer,
@@ -317,7 +317,7 @@ class ChatGPTAPI:
             if DEBUG >= 2: traceback.print_exc()
 
         def on_result(_request_id: str, tokens: List[int], is_finished: bool):
-          self.stream_tasks[request_id] = asyncio.create_task(stream_result(request_id, tokens, is_finished))
+          if _request_id == request_id: self.stream_tasks[_request_id] = asyncio.create_task(stream_result(_request_id, tokens, is_finished))
 
           return _request_id == request_id and is_finished