
more robust message parsing fixes #81

Alex Cheema, 9 months ago
parent
commit
03fe7a058c
1 changed file with 15 additions and 5 deletions

+ 15 - 5
exo/api/chatgpt_api.py

@@ -24,7 +24,7 @@ shard_mappings = {
         "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
     },
     "llama-3.1-405b": {
-        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=126),
+        "MLXDynamicShardInferenceEngine": Shard(model_id="/Users/alex/405b-instruct-4bit", start_layer=0, end_layer=0, n_layers=126),
     },
     "llama-3-70b": {
         "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
@@ -124,6 +124,17 @@ def build_prompt(tokenizer, messages: List[Message]):
         messages, tokenize=False, add_generation_prompt=True
     )
 
+def parse_message(data: dict):
+    if 'role' not in data or 'content' not in data:
+        raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
+    return Message(data['role'], data['content'])
+
+def parse_chat_request(data: dict):
+    return ChatCompletionRequest(
+        data.get('model', 'llama-3.1-8b'),
+        [parse_message(msg) for msg in data['messages']],
+        data.get('temperature', 0.0)
+    )
 
 class ChatGPTAPI:
     def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90):
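For context, the old parsing path (Message(**msg), removed in the next hunk) surfaced a bare TypeError on a malformed payload; the new parse_message raises a ValueError that echoes the offending dict back to the caller. A minimal usage sketch of the helper added above (hypothetical standalone snippet; Message is assumed to be a simple role/content container):

    from exo.api.chatgpt_api import parse_message

    parse_message({'role': 'user', 'content': 'hi'})
    # -> Message('user', 'hi')

    parse_message({'role': 'user'})
    # -> ValueError: Invalid message: {'role': 'user'}. Must have 'role' and 'content'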
@@ -156,15 +167,14 @@ class ChatGPTAPI:
     async def handle_post_chat_token_encode(self, request):
         data = await request.json()
         shard = shard_mappings.get(data.get('model', 'llama-3.1-8b'), {}).get(self.inference_engine_classname)
-        messages = data.get('messages', [])
+        messages = [parse_message(msg) for msg in data.get('messages', [])]
         tokenizer = await resolve_tokenizer(shard.model_id)
         return web.json_response({'length': len(build_prompt(tokenizer, messages))})
 
     async def handle_post_chat_completions(self, request):
         data = await request.json()
         stream = data.get('stream', False)
-        messages = [Message(**msg) for msg in data['messages']]
-        chat_request = ChatCompletionRequest(data.get('model', 'llama-3.1-8b'), messages, data.get('temperature', 0.0))
+        chat_request = parse_chat_request(data)
         if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
             chat_request.model = "llama-3.1-8b"
         shard = shard_mappings.get(chat_request.model, {}).get(self.inference_engine_classname)
@@ -175,7 +185,7 @@ class ChatGPTAPI:
         tokenizer = await resolve_tokenizer(shard.model_id)
         if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
 
-        prompt = build_prompt(tokenizer, messages)
+        prompt = build_prompt(tokenizer, chat_request.messages)
         callback_id = f"chatgpt-api-wait-response-{request_id}"
         callback = self.node.on_token.register(callback_id)
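parse_chat_request also centralizes the request defaults (model 'llama-3.1-8b', temperature 0.0), so a request that omits those fields still parses. A hedged sketch of the resulting behavior (assuming the import path matches the file's location in the repo):

    from exo.api.chatgpt_api import parse_chat_request

    req = parse_chat_request({'messages': [{'role': 'user', 'content': 'hello'}]})
    # req.model == 'llama-3.1-8b', req.temperature == 0.0

    # Note: a request with no 'messages' key still raises a bare KeyError, since
    # the helper indexes data['messages'] directly; only the per-message shape
    # gets the friendlier ValueError from parse_message.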