Browse Source

Merge pull request #549 from exo-explore/fixtokenencode

Fix token encode endpoint
Alex Cheema 6 months ago
parent
commit
db9de97fa6
1 changed files with 15 additions and 2 deletions
  1. 15 2
      exo/api/chatgpt_api.py

+ 15 - 2
exo/api/chatgpt_api.py

@@ -278,10 +278,23 @@ class ChatGPTAPI:
 
   async def handle_post_chat_token_encode(self, request):
     data = await request.json()
-    shard = build_base_shard(self.default_model, self.inference_engine_classname)
+    model = data.get("model", self.default_model)
+    if model and model.startswith("gpt-"):  # Handle gpt- model requests
+      model = self.default_model
+    if not model or model not in model_cards:
+      if DEBUG >= 1: print(f"Invalid model: {model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
+      model = self.default_model
+    shard = build_base_shard(model, self.inference_engine_classname)
     messages = [parse_message(msg) for msg in data.get("messages", [])]
     tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
-    return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
+    prompt = build_prompt(tokenizer, messages)
+    tokens = tokenizer.encode(prompt)
+    return web.json_response({
+      "length": len(prompt),
+      "num_tokens": len(tokens),
+      "encoded_tokens": tokens,
+      "encoded_prompt": prompt,
+    })
 
   async def handle_get_download_progress(self, request):
     progress_data = {}