@@ -24,7 +24,7 @@ shard_mappings = {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
   },
   "llama-3.1-405b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=126),
+    "MLXDynamicShardInferenceEngine": Shard(model_id="/Users/alex/405b-instruct-4bit", start_layer=0, end_layer=0, n_layers=126),
   },
   "llama-3-70b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
@@ -124,6 +124,17 @@ def build_prompt(tokenizer, messages: List[Message]):
     messages, tokenize=False, add_generation_prompt=True
   )
 
+def parse_message(data: dict):
+  if 'role' not in data or 'content' not in data:
+    raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
+  return Message(data['role'], data['content'])
+
+def parse_chat_request(data: dict):
+  return ChatCompletionRequest(
+    data.get('model', 'llama-3.1-8b'),
+    [parse_message(msg) for msg in data['messages']],
+    data.get('temperature', 0.0)
+  )
+
 class ChatGPTAPI:
   def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90):
@@ -156,15 +167,14 @@ class ChatGPTAPI:
   async def handle_post_chat_token_encode(self, request):
     data = await request.json()
     shard = shard_mappings.get(data.get('model', 'llama-3.1-8b'), {}).get(self.inference_engine_classname)
-    messages = data.get('messages', [])
+    messages = [parse_message(msg) for msg in data.get('messages', [])]
     tokenizer = await resolve_tokenizer(shard.model_id)
     return web.json_response({'length': len(build_prompt(tokenizer, messages))})
 
   async def handle_post_chat_completions(self, request):
     data = await request.json()
     stream = data.get('stream', False)
-    messages = [Message(**msg) for msg in data['messages']]
-    chat_request = ChatCompletionRequest(data.get('model', 'llama-3.1-8b'), messages, data.get('temperature', 0.0))
+    chat_request = parse_chat_request(data)
     if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
       chat_request.model = "llama-3.1-8b"
     shard = shard_mappings.get(chat_request.model, {}).get(self.inference_engine_classname)
@@ -175,7 +185,7 @@ class ChatGPTAPI:
     tokenizer = await resolve_tokenizer(shard.model_id)
     if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
 
-    prompt = build_prompt(tokenizer, messages)
+    prompt = build_prompt(tokenizer, chat_request.messages)
     callback_id = f"chatgpt-api-wait-response-{request_id}"
     callback = self.node.on_token.register(callback_id)
 
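
For reference, a minimal, self-contained sketch of how the new request-parsing helpers behave. Message and ChatCompletionRequest are simplified dataclass stand-ins for the real classes defined elsewhere in chatgpt_api.py (not shown in this diff), and the payload values are purely illustrative.

# Illustrative sketch only: stand-in types, made-up payload values.
from dataclasses import dataclass
from typing import List


@dataclass
class Message:
  role: str
  content: str


@dataclass
class ChatCompletionRequest:
  model: str
  messages: List[Message]
  temperature: float


def parse_message(data: dict):
  # Reject messages that are missing either required field.
  if 'role' not in data or 'content' not in data:
    raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
  return Message(data['role'], data['content'])


def parse_chat_request(data: dict):
  # Fall back to llama-3.1-8b and temperature 0.0 when the client omits them.
  return ChatCompletionRequest(
    data.get('model', 'llama-3.1-8b'),
    [parse_message(msg) for msg in data['messages']],
    data.get('temperature', 0.0)
  )


if __name__ == "__main__":
  # A well-formed ChatGPT-style payload parses into a ChatCompletionRequest.
  payload = {
    "model": "llama-3.1-8b",
    "messages": [{"role": "user", "content": "Hello"}],
    "temperature": 0.7,
  }
  req = parse_chat_request(payload)
  print(req.model, req.temperature, req.messages[0].content)

  # A message missing 'content' is rejected with ValueError instead of
  # failing later with an opaque TypeError from Message(**msg).
  try:
    parse_message({"role": "user"})
  except ValueError as e:
    print(f"rejected: {e}")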