@@ -278,10 +278,23 @@ class ChatGPTAPI:
   async def handle_post_chat_token_encode(self, request):
     data = await request.json()
-    shard = build_base_shard(self.default_model, self.inference_engine_classname)
+    model = data.get("model", self.default_model)
+    if model and model.startswith("gpt-"):  # Handle gpt- model requests
+      model = self.default_model
+    if not model or model not in model_cards:
+      if DEBUG >= 1: print(f"Invalid model: {model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
+      model = self.default_model
+    shard = build_base_shard(model, self.inference_engine_classname)
     messages = [parse_message(msg) for msg in data.get("messages", [])]
     tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
-    return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
+    prompt = build_prompt(tokenizer, messages)
+    tokens = tokenizer.encode(prompt)
+    return web.json_response({
+      "length": len(prompt),
+      "num_tokens": len(tokens),
+      "encoded_tokens": tokens,
+      "encoded_prompt": prompt,
+    })


   async def handle_get_download_progress(self, request):
     progress_data = {}
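For reference, a minimal sketch of calling the updated handler from a client. It assumes the handler is registered at POST /v1/chat/token/encode and that the API listens on localhost:52415; the route, port, and model name are assumptions about the surrounding setup rather than part of this hunk.

# Minimal sketch (assumptions noted above): POST a chat payload to the
# token-encode endpoint and read back the fields added by this change.
import requests  # any HTTP client works; requests is assumed to be installed

resp = requests.post(
  "http://localhost:52415/v1/chat/token/encode",  # assumed host, port, and route
  json={
    "model": "llama-3.2-1b",  # illustrative; unknown or "gpt-" prefixed names fall back to the default model
    "messages": [{"role": "user", "content": "Hello!"}],
  },
)
body = resp.json()
print(body["length"])          # character length of the rendered prompt
print(body["num_tokens"])      # token count from tokenizer.encode(prompt)
print(body["encoded_tokens"])  # the token ids themselves
print(body["encoded_prompt"])  # the exact prompt string that was tokenized

The gpt- fallback means clients that hardcode OpenAI model names still get a response computed against the node's default model instead of an error.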