Browse Source

fix issue with eos_token_id

Alex Cheema 5 months ago
parent
commit
3a4bae0dab
1 changed files with 9 additions and 8 deletions
  1. 9 8
      exo/api/chatgpt_api.py

+ 9 - 8
exo/api/chatgpt_api.py

@@ -414,14 +414,13 @@ class ChatGPTAPI:
             )
             if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {tokens=} {is_finished=}")
 
-            finish_reason = None
-            eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") else getattr(tokenizer, "eos_token_id", None)
+            eos_token_id = None
+            if not eos_token_id and hasattr(tokenizer, "eos_token_id"): eos_token_id = tokenizer.eos_token_id
+            if not eos_token_id and hasattr(tokenizer, "_tokenizer"): eos_token_id = tokenizer.special_tokens_map.get("eos_token_id")
 
-            if tokens[-1] == eos_token_id:
-              if is_finished:
-                finish_reason = "stop"
-            if is_finished and not finish_reason:
-              finish_reason = "length"
+            finish_reason = None
+            if is_finished: finish_reason = "stop" if tokens[-1] == eos_token_id else "length"
+            if DEBUG >= 2: print(f"{eos_token_id=} {tokens[-1]=} {finish_reason=}")
 
             completion = generate_completion(
               chat_request,
@@ -468,7 +467,9 @@ class ChatGPTAPI:
           if is_finished:
             break
         finish_reason = "length"
-        eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
+        eos_token_id = None
+        if not eos_token_id and hasattr(tokenizer, "eos_token_id"): eos_token_id = tokenizer.eos_token_id
+        if not eos_token_id and hasattr(tokenizer, "_tokenizer"): eos_token_id = tokenizer.special_tokens_map.get("eos_token_id")
         if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
         if tokens[-1] == eos_token_id:
           finish_reason = "stop"