@@ -1,6 +1,7 @@
 import uuid
 import time
 import asyncio
+from transformers import AutoTokenizer
 from typing import List
 from aiohttp import web
 from exo import DEBUG
@@ -8,8 +9,14 @@ from exo.inference.shard import Shard
 from exo.orchestration import Node
 
 shard_mappings = {
-    "llama-3-8b": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "llama-3-70b": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+    "llama-3-8b": {
+        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+        "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
+    },
+    "llama-3-70b": {
+        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+        "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80),
+    },
 }
 
 class Message:
@@ -23,25 +30,54 @@ class ChatCompletionRequest:
         self.messages = messages
         self.temperature = temperature
 
+def resolve_tinygrad_tokenizer(model_id: str):
+    if model_id == "llama3-8b-sfr":
+        return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
+    elif model_id == "llama3-70b-sfr":
+        return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
+    else:
+        raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {model_id}")
+
+def resolve_tokenizer(model_id: str):
+    try:
+        if DEBUG >= 2: print(f"Trying AutoTokenizer for {model_id}")
+        return AutoTokenizer.from_pretrained(model_id)
+    except:
+        import traceback
+        if DEBUG >= 2: print(traceback.format_exc())
+        if DEBUG >= 2: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer")
+
+    try:
+        if DEBUG >= 2: print(f"Trying tinygrad tokenizer for {model_id}")
+        return resolve_tinygrad_tokenizer(model_id)
+    except:
+        import traceback
+        if DEBUG >= 2: print(traceback.format_exc())
+        if DEBUG >= 2: print(f"Failed again to load tokenizer for {model_id}. Falling back to mlx tokenizer")
+
+    if DEBUG >= 2: print(f"Trying mlx tokenizer for {model_id}")
+    from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
+    return load_tokenizer(get_model_path(model_id))
+
 class ChatGPTAPI:
-    def __init__(self, node: Node):
+    def __init__(self, node: Node, inference_engine_classname: str):
         self.node = node
         self.app = web.Application()
         self.app.router.add_post('/v1/chat/completions', self.handle_post)
+        self.inference_engine_classname = inference_engine_classname
 
     async def handle_post(self, request):
         data = await request.json()
         messages = [Message(**msg) for msg in data['messages']]
         chat_request = ChatCompletionRequest(data['model'], messages, data['temperature'])
         prompt = " ".join([msg.content for msg in chat_request.messages if msg.role == "user"])
-        shard = shard_mappings.get(chat_request.model)
+        shard = shard_mappings.get(chat_request.model, {}).get(self.inference_engine_classname)
         if not shard:
             return web.json_response({'detail': f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}"}, status=400)
         request_id = str(uuid.uuid4())
 
-        # TODO equivalent for non-mlx since the user can't install this on non-macs even though the tokenizer itself
-        from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
-        tokenizer = load_tokenizer(get_model_path(shard.model_id))
+        tokenizer = resolve_tokenizer(shard.model_id)
+        if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
         prompt = tokenizer.apply_chat_template(
             chat_request.messages, tokenize=False, add_generation_prompt=True
         )
@@ -63,7 +99,9 @@ class ChatGPTAPI:
                 continue
             await asyncio.sleep(0.1)
             if is_finished:
-                if result[-1] == tokenizer._tokenizer.eos_token_id:
+                eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
+                if DEBUG >= 2: print(f"Checking if end of result {result[-1]=} is {eos_token_id=}")
+                if result[-1] == eos_token_id:
                     result = result[:-1]
                 return web.json_response({
                     "id": f"chatcmpl-{request_id}",