
fix token encode to use the right model

Alex Cheema committed 6 months ago
commit 9f86737a94

1 changed file with 7 additions and 1 deletion:

    exo/api/chatgpt_api.py  +7 -1

exo/api/chatgpt_api.py  +7 -1

@@ -278,7 +278,13 @@ class ChatGPTAPI:
 
 
  async def handle_post_chat_token_encode(self, request):
    data = await request.json()
-    shard = build_base_shard(self.default_model, self.inference_engine_classname)
+    model = data.get("model", self.default_model)
+    if model and model.startswith("gpt-"):  # Handle gpt- model requests
+      model = self.default_model
+    if not model or model not in model_cards:
+      if DEBUG >= 1: print(f"Invalid model: {model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
+      model = self.default_model
+    shard = build_base_shard(model, self.inference_engine_classname)
    messages = [parse_message(msg) for msg in data.get("messages", [])]
    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
    return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
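
Before this change, the handler always built its shard from self.default_model, so a request naming any other model had its prompt tokenized with the wrong tokenizer and returned the wrong length. With the fix, the requested model is honored, a "gpt-" name (common from clients hardcoded for OpenAI) is remapped to the default, and any name missing from model_cards falls back to the default with a debug warning rather than erroring. A minimal client sketch of the fixed behavior; the /v1/chat/token/encode path, the port 52415, and the "llama-3.2-1b" model card key are assumptions based on exo's usual defaults, not facts stated by this commit:

# Hypothetical client call illustrating the fixed behavior; the endpoint
# path, port, and model name below are assumptions, not part of the commit.
import requests

resp = requests.post(
  "http://localhost:52415/v1/chat/token/encode",
  json={
    "model": "llama-3.2-1b",  # assumed model_cards key; now actually used for tokenization
    "messages": [{"role": "user", "content": "Hello, world!"}],
  },
)
print(resp.json())  # e.g. {"length": <count under the requested model's tokenizer>}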