|
@@ -278,7 +278,13 @@ class ChatGPTAPI:
|
|
|
|
|
|
async def handle_post_chat_token_encode(self, request):
|
|
|
data = await request.json()
|
|
|
- shard = build_base_shard(self.default_model, self.inference_engine_classname)
|
|
|
+ model = data.get("model", self.default_model)
|
|
|
+ if model and model.startswith("gpt-"): # Handle gpt- model requests
|
|
|
+ model = self.default_model
|
|
|
+ if not model or model not in model_cards:
|
|
|
+ if DEBUG >= 1: print(f"Invalid model: {model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
|
|
|
+ model = self.default_model
|
|
|
+ shard = build_base_shard(model, self.inference_engine_classname)
|
|
|
messages = [parse_message(msg) for msg in data.get("messages", [])]
|
|
|
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
|
|
|
return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
|