@@ -13,10 +13,7 @@ from exo.inference.shard import Shard
 from exo.orchestration import Node
 
 shard_mappings = {
-  "llama-3-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
-  },
+  ### llama
   "llama-3.1-8b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
   },
@@ -24,12 +21,23 @@ shard_mappings = {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
   },
   "llama-3.1-405b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=126),
+    "MLXDynamicShardInferenceEngine": Shard(model_id="/Users/alex/405b-instruct-4bit", start_layer=0, end_layer=0, n_layers=126),
+  },
+  "llama-3-8b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
   },
   "llama-3-70b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
     "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80),
   },
+  ### mistral
+  "mistral-nemo": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),
+  },
+  "mistral-large": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),
+  },
 }
 
 class Message:
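
For reference, the handlers further down consume this mapping by keying first on the model name and then on the inference engine class name. A minimal lookup sketch (the model/engine pair is chosen for illustration):

shard = shard_mappings.get("mistral-nemo", {}).get("MLXDynamicShardInferenceEngine")
if shard:
  # e.g. mlx-community/Mistral-Nemo-Instruct-2407-4bit, 40 layers
  print(shard.model_id, shard.n_layers)
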
@@ -124,6 +132,17 @@ def build_prompt(tokenizer, messages: List[Message]):
     messages, tokenize=False, add_generation_prompt=True
   )
 
+def parse_message(data: dict):
+  if 'role' not in data or 'content' not in data:
+    raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
+  return Message(data['role'], data['content'])
+
+def parse_chat_request(data: dict):
+  return ChatCompletionRequest(
+    data.get('model', 'llama-3.1-8b'),
+    [parse_message(msg) for msg in data['messages']],
+    data.get('temperature', 0.0)
+  )
 
 class ChatGPTAPI:
   def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90):
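
These parsers replace the inline Message(**msg) construction and centralize validation and defaults. A hypothetical usage sketch (the payload below is illustrative, not taken from the diff):

payload = {
  "model": "mistral-nemo",
  "messages": [{"role": "user", "content": "Hello!"}],
  "temperature": 0.7,
}
chat_request = parse_chat_request(payload)
# Missing 'model'/'temperature' fall back to 'llama-3.1-8b' and 0.0;
# a message without 'role' or 'content' raises ValueError.
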
@@ -150,32 +169,46 @@ class ChatGPTAPI:
     self.app.router.add_get('/', self.handle_root)
     self.app.router.add_static('/', self.static_dir, name='static')
 
+    # Add middleware to log every request
+    self.app.middlewares.append(self.log_request)
+
+  async def log_request(self, app, handler):
+    async def middleware(request):
+      if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
+      return await handler(request)
+    return middleware
+
   async def handle_root(self, request):
+    print(f"Handling root request from {request.remote}")
     return web.FileResponse(self.static_dir / 'index.html')
 
   async def handle_post_chat_token_encode(self, request):
     data = await request.json()
     shard = shard_mappings.get(data.get('model', 'llama-3.1-8b'), {}).get(self.inference_engine_classname)
-    messages = data.get('messages', [])
+    messages = [parse_message(msg) for msg in data.get('messages', [])]
     tokenizer = await resolve_tokenizer(shard.model_id)
     return web.json_response({'length': len(build_prompt(tokenizer, messages))})
 
   async def handle_post_chat_completions(self, request):
     data = await request.json()
+    if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
     stream = data.get('stream', False)
-    messages = [Message(**msg) for msg in data['messages']]
-    chat_request = ChatCompletionRequest(data.get('model', 'llama-3.1-8b'), messages, data.get('temperature', 0.0))
+    chat_request = parse_chat_request(data)
     if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
       chat_request.model = "llama-3.1-8b"
-    shard = shard_mappings.get(chat_request.model, {}).get(self.inference_engine_classname)
+    if not chat_request.model or chat_request.model not in shard_mappings:
+      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}. Defaulting to llama-3.1-8b")
+      chat_request.model = "llama-3.1-8b"
+    shard = shard_mappings[chat_request.model].get(self.inference_engine_classname, None)
     if not shard:
-      return web.json_response({'detail': f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}"}, status=400)
+      supported_models = [model for model, engines in shard_mappings.items() if self.inference_engine_classname in engines]
+      return web.json_response({'detail': f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"}, status=400)
     request_id = str(uuid.uuid4())
 
     tokenizer = await resolve_tokenizer(shard.model_id)
     if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
 
-    prompt = build_prompt(tokenizer, messages)
+    prompt = build_prompt(tokenizer, chat_request.messages)
     callback_id = f"chatgpt-api-wait-response-{request_id}"
     callback = self.node.on_token.register(callback_id)
 
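
log_request follows aiohttp's old-style middleware-factory pattern: a coroutine taking (app, handler) and returning a wrapped handler, registered by appending to app.middlewares (aiohttp also offers the newer @web.middleware decorator). A minimal standalone sketch of the same pattern, with a hypothetical handler and the DEBUG guard dropped for brevity:

from aiohttp import web

async def log_request(app, handler):
  async def middleware(request):
    print(f"Received request: {request.method} {request.path}")
    return await handler(request)
  return middleware

async def handle_hello(request):
  return web.Response(text="ok")

app = web.Application()
app.middlewares.append(log_request)
app.router.add_get('/', handle_hello)
# web.run_app(app)  # serves on 0.0.0.0:8080 by default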