
add a default_model to ChatGPTAPI

Mukund Mauji · commit f523713417 · 7 months ago
1 changed file with 9 additions and 7 deletions

+ 9 - 7
exo/api/chatgpt_api.py

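This change replaces the hardcoded "llama-3.1-8b" fallback used throughout the ChatGPT-compatible API with a single self.default_model attribute on ChatGPTAPI: parse_chat_request now takes the default as a parameter, and the token-encode and chat-completions handlers read it from the instance instead of repeating the literal.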
@@ -138,9 +138,9 @@ def parse_message(data: dict):
   return Message(data["role"], data["content"])
 
 
-def parse_chat_request(data: dict):
+def parse_chat_request(data: dict, default_model: str):
   return ChatCompletionRequest(
-    data.get("model", "llama-3.1-8b"),
+    data.get("model", default_model),
     [parse_message(msg) for msg in data["messages"]],
     data.get("temperature", 0.0),
   )
@@ -163,6 +163,8 @@ class ChatGPTAPI:
     self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
     self.prev_token_lens: Dict[str, int] = {}
     self.stream_tasks: Dict[str, asyncio.Task] = {}
+    self.default_model = "llama-3.1-8b"
+
     cors = aiohttp_cors.setup(self.app)
     cors_options = aiohttp_cors.ResourceOptions(
       allow_credentials=True,
@@ -209,7 +211,7 @@ class ChatGPTAPI:
 
   async def handle_post_chat_token_encode(self, request):
     data = await request.json()
-    shard = model_base_shards.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname)
+    shard = model_base_shards.get(data.get("model", self.default_model), {}).get(self.inference_engine_classname)
     messages = [parse_message(msg) for msg in data.get("messages", [])]
     tokenizer = await resolve_tokenizer(shard.model_id)
     return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
@@ -227,12 +229,12 @@ class ChatGPTAPI:
     data = await request.json()
     if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
     stream = data.get("stream", False)
-    chat_request = parse_chat_request(data)
+    chat_request = parse_chat_request(data, self.default_model)
     if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
-      chat_request.model = "llama-3.1-8b"
+      chat_request.model = self.default_model if self.default_model.startswith("llama") else "llama-3.1-8b"
     if not chat_request.model or chat_request.model not in model_base_shards:
-      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_base_shards.keys())}. Defaulting to llama-3.1-8b")
-      chat_request.model = "llama-3.1-8b"
+      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_base_shards.keys())}. Defaulting to {self.default_model}")
+      chat_request.model = self.default_model
     shard = model_base_shards[chat_request.model].get(self.inference_engine_classname, None)
     if not shard:
       supported_models = [model for model, engines in model_base_shards.items() if self.inference_engine_classname in engines]
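For reference, the model-selection path after this commit behaves roughly like the standalone sketch below. This is a simplified illustration, not the actual exo code: model_base_shards is stubbed as a plain dict, ChatCompletionRequest is reduced to the fields used here, engine names and shard values are placeholders, and the aiohttp request handling is omitted.

# Simplified, self-contained sketch of the fallback logic after this commit.
# The names mirror the diff, but the bodies and data are stand-ins, not the
# real exo implementations; engine names and shard values are placeholders.
from dataclasses import dataclass, field
from typing import Dict, List

# Stub: in exo this maps model names to per-inference-engine shard configurations.
model_base_shards: Dict[str, Dict[str, str]] = {
  "llama-3.1-8b": {"ExampleEngine": "llama-3.1-8b-shard"},
  "llama-3.1-70b": {"ExampleEngine": "llama-3.1-70b-shard"},
}

@dataclass
class ChatCompletionRequest:
  model: str
  messages: List[dict] = field(default_factory=list)
  temperature: float = 0.0

def parse_chat_request(data: dict, default_model: str) -> ChatCompletionRequest:
  # Fall back to the configurable default instead of a hardcoded "llama-3.1-8b".
  return ChatCompletionRequest(
    data.get("model", default_model),
    data.get("messages", []),
    data.get("temperature", 0.0),
  )

def select_model(data: dict, default_model: str = "llama-3.1-8b") -> str:
  chat_request = parse_chat_request(data, default_model)
  # gpt-* requests are remapped so ChatGPT-compatible clients keep working.
  if chat_request.model and chat_request.model.startswith("gpt-"):
    chat_request.model = default_model if default_model.startswith("llama") else "llama-3.1-8b"
  # Missing or unknown models also fall back to the default.
  if not chat_request.model or chat_request.model not in model_base_shards:
    chat_request.model = default_model
  return chat_request.model

if __name__ == "__main__":
  print(select_model({"messages": []}))                            # llama-3.1-8b (no model given)
  print(select_model({"model": "gpt-4o", "messages": []}))         # llama-3.1-8b (gpt-* remapped)
  print(select_model({"model": "llama-3.1-70b", "messages": []}))  # llama-3.1-70b

Keeping the fallback in a single attribute means a different default could later be injected (for example via the constructor or a CLI flag) without touching the handlers again.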