@@ -3,55 +3,17 @@ import time
 import asyncio
 import json
 from pathlib import Path
-from transformers import AutoTokenizer, AutoProcessor
+from transformers import AutoTokenizer
 from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
 import traceback
 from exo import DEBUG, VERSION
-from exo.helpers import terminal_link, PrefixDict
+from exo.helpers import PrefixDict
 from exo.inference.shard import Shard
+from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
-
-shard_mappings = {
-  ### llama
-  "llama-3.1-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
-  },
-  "llama-3.1-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
-  },
-  "llama-3.1-405b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),
-  },
-  "llama-3-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
-  },
-  "llama-3-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
-  },
-  ### mistral
-  "mistral-nemo": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),
-  },
-  "mistral-large": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),
-  },
-  ### deepseek v2
-  "deepseek-coder-v2-lite": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),
-  },
-  ### llava
-  "llava-1.5-7b-hf": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),
-  },
-}
-
-
+from exo.models import model_base_shards
 
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
@@ -64,7 +26,6 @@ class Message:
       "content": self.content
     }
 
-
 class ChatCompletionRequest:
   def __init__(self, model: str, messages: List[Message], temperature: float):
     self.model = model
@@ -78,33 +39,6 @@ class ChatCompletionRequest:
       "temperature": self.temperature
     }
 
-
-
-async def resolve_tokenizer(model_id: str):
-  try:
-    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id}")
-    processor = AutoProcessor.from_pretrained(model_id, use_fast=False)
-    if not hasattr(processor, 'eos_token_id'):
-      processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
-    if not hasattr(processor, 'encode'):
-      processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode
-    if not hasattr(processor, 'decode'):
-      processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
-    return processor
-  except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load processor for {model_id}. Error: {e}")
-    if DEBUG >= 4: print(traceback.format_exc())
-
-  try:
-    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id}")
-    return AutoTokenizer.from_pretrained(model_id)
-  except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
-    if DEBUG >= 4: print(traceback.format_exc())
-
-  raise ValueError(f"[TODO] Unsupported model: {model_id}")
-
-
 def generate_completion(
   chat_request: ChatCompletionRequest,
   tokenizer,
@@ -257,7 +191,7 @@ class ChatGPTAPI:
 
   async def handle_post_chat_token_encode(self, request):
     data = await request.json()
-    shard = shard_mappings.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname)
+    shard = model_base_shards.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname)
     messages = [parse_message(msg) for msg in data.get("messages", [])]
     tokenizer = await resolve_tokenizer(shard.model_id)
     return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
@@ -269,12 +203,12 @@ class ChatGPTAPI:
     chat_request = parse_chat_request(data)
     if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
       chat_request.model = "llama-3.1-8b"
-    if not chat_request.model or chat_request.model not in shard_mappings:
-      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}. Defaulting to llama-3.1-8b")
+    if not chat_request.model or chat_request.model not in model_base_shards:
+      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_base_shards.keys())}. Defaulting to llama-3.1-8b")
       chat_request.model = "llama-3.1-8b"
-    shard = shard_mappings[chat_request.model].get(self.inference_engine_classname, None)
+    shard = model_base_shards[chat_request.model].get(self.inference_engine_classname, None)
     if not shard:
-      supported_models = [model for model, engines in shard_mappings.items() if self.inference_engine_classname in engines]
+      supported_models = [model for model, engines in model_base_shards.items() if self.inference_engine_classname in engines]
       return web.json_response(
         {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
         status=400,