@@ -138,9 +138,9 @@ def parse_message(data: dict):
   return Message(data["role"], data["content"])


-def parse_chat_request(data: dict):
+def parse_chat_request(data: dict, default_model: str):
   return ChatCompletionRequest(
-    data.get("model", "llama-3.1-8b"),
+    data.get("model", default_model),
     [parse_message(msg) for msg in data["messages"]],
     data.get("temperature", 0.0),
   )
@@ -163,6 +163,8 @@ class ChatGPTAPI:
     self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
     self.prev_token_lens: Dict[str, int] = {}
     self.stream_tasks: Dict[str, asyncio.Task] = {}
+    self.default_model = "llama-3.1-8b"
+
     cors = aiohttp_cors.setup(self.app)
     cors_options = aiohttp_cors.ResourceOptions(
       allow_credentials=True,
@@ -209,7 +211,7 @@ class ChatGPTAPI:

   async def handle_post_chat_token_encode(self, request):
     data = await request.json()
-    shard = model_base_shards.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname)
+    shard = model_base_shards.get(data.get("model", self.default_model), {}).get(self.inference_engine_classname)
     messages = [parse_message(msg) for msg in data.get("messages", [])]
     tokenizer = await resolve_tokenizer(shard.model_id)
     return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
@@ -227,12 +229,12 @@ class ChatGPTAPI:
     data = await request.json()
     if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
     stream = data.get("stream", False)
-    chat_request = parse_chat_request(data)
+    chat_request = parse_chat_request(data, self.default_model)
     if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
-      chat_request.model = "llama-3.1-8b"
+      chat_request.model = self.default_model if self.default_model.startswith("llama") else "llama-3.1-8b"
     if not chat_request.model or chat_request.model not in model_base_shards:
-      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_base_shards.keys())}. Defaulting to llama-3.1-8b")
-      chat_request.model = "llama-3.1-8b"
+      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_base_shards.keys())}. Defaulting to {self.default_model}")
+      chat_request.model = self.default_model
     shard = model_base_shards[chat_request.model].get(self.inference_engine_classname, None)
     if not shard:
       supported_models = [model for model, engines in model_base_shards.items() if self.inference_engine_classname in engines]
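A minimal sketch (not part of the diff) of how the reworked parse_chat_request is expected to behave with the new default_model argument; the payload and the "some-other-model" name are invented for illustration:

# Hypothetical usage sketch: the payload below is made up, and "some-other-model"
# is a placeholder, not necessarily a key in model_base_shards.
data = {"messages": [{"role": "user", "content": "Hello"}], "temperature": 0.7}

# With no "model" key in the payload, the caller-supplied default is used.
req = parse_chat_request(data, "llama-3.1-8b")
assert req.model == "llama-3.1-8b"

# An explicit "model" in the payload still takes precedence over the default.
data["model"] = "some-other-model"
req = parse_chat_request(data, "llama-3.1-8b")
assert req.model == "some-other-model"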