@@ -15,44 +15,28 @@ from exo.orchestration import Node
 shard_mappings = {
   ### llama
   "llama-3.1-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(
-      model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32
-    ),
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
   },
   "llama-3.1-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(
-      model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80
-    ),
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
   },
   "llama-3.1-405b": {
-    "MLXDynamicShardInferenceEngine": Shard(
-      model_id="/Users/alex/405b-instruct-4bit", start_layer=0, end_layer=0, n_layers=126
-    ),
+    "MLXDynamicShardInferenceEngine": Shard(model_id="/Users/alex/405b-instruct-4bit", start_layer=0, end_layer=0, n_layers=126),
   },
   "llama-3-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(
-      model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32
-    ),
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
     "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
   },
   "llama-3-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(
-      model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80
-    ),
-    "TinygradDynamicShardInferenceEngine": Shard(
-      model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80
-    ),
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80),
   },
   ### mistral
   "mistral-nemo": {
-    "MLXDynamicShardInferenceEngine": Shard(
-      model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40
-    ),
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),
   },
   "mistral-large": {
-    "MLXDynamicShardInferenceEngine": Shard(
-      model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88
-    ),
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),
   },
   ### deepseek v2
   "deepseek-coder-v2-lite": {
@@ -82,9 +66,7 @@ def resolve_tinygrad_tokenizer(model_id: str):
   elif model_id == "llama3-70b-sfr":
     return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
   else:
-    raise ValueError(
-      f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {model_id}"
-    )
+    raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {model_id}")


 async def resolve_tokenizer(model_id: str):
@@ -190,12 +172,8 @@ class ChatGPTAPI:
       allow_headers="*",
       allow_methods="*",
     )
-    cors.add(
-      self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options}
-    )
-    cors.add(
-      self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options}
-    )
+    cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
+    cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
     self.static_dir = Path(__file__).parent.parent.parent / "tinychat/examples/tinychat"
     self.app.router.add_get("/", self.handle_root)
     self.app.router.add_static("/", self.static_dir, name="static")
@@ -226,22 +204,16 @@ class ChatGPTAPI:
     if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
     stream = data.get("stream", False)
     chat_request = parse_chat_request(data)
-    if chat_request.model and chat_request.model.startswith(
-      "gpt-"
-    ): # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
+    if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
       chat_request.model = "llama-3.1-8b"
     if not chat_request.model or chat_request.model not in shard_mappings:
       if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}. Defaulting to llama-3.1-8b")
       chat_request.model = "llama-3.1-8b"
     shard = shard_mappings[chat_request.model].get(self.inference_engine_classname, None)
     if not shard:
-      supported_models = [
-        model for model, engines in shard_mappings.items() if self.inference_engine_classname in engines
-      ]
+      supported_models = [model for model, engines in shard_mappings.items() if self.inference_engine_classname in engines]
       return web.json_response(
-        {
-          "detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"
-        },
+        {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
         status=400,
       )
     request_id = str(uuid.uuid4())
@@ -261,9 +233,7 @@ class ChatGPTAPI:
         import traceback

         traceback.print_exc()
-      return web.json_response(
-        {"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500
-      )
+      return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)

     try:
       if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout_secs}s")
@@ -284,11 +254,7 @@ class ChatGPTAPI:
         self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
         new_tokens = tokens[prev_last_tokens_len:]
         finish_reason = None
-        eos_token_id = (
-          tokenizer.special_tokens_map.get("eos_token_id")
-          if isinstance(tokenizer._tokenizer, AutoTokenizer)
-          else tokenizer.eos_token_id
-        )
+        eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
         if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
           new_tokens = new_tokens[:-1]
           if is_finished:
@@ -315,9 +281,7 @@ class ChatGPTAPI:
         return _request_id == request_id and is_finished

       _, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout_secs)
-      if (
-        request_id in self.stream_tasks
-      ): # in case there is still a stream task running, wait for it to complete
+      if request_id in self.stream_tasks: # in case there is still a stream task running, wait for it to complete
         if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
         try:
           await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
@@ -332,21 +296,13 @@ class ChatGPTAPI:
         )

       finish_reason = "length"
-      eos_token_id = (
-        tokenizer.special_tokens_map.get("eos_token_id")
-        if isinstance(tokenizer._tokenizer, AutoTokenizer)
-        else tokenizer.eos_token_id
-      )
+      eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
       if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
       if tokens[-1] == eos_token_id:
         tokens = tokens[:-1]
         finish_reason = "stop"

-      return web.json_response(
-        generate_completion(
-          chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"
-        )
-      )
+      return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
     except asyncio.TimeoutError:
       return web.json_response({"detail": "Response generation timed out"}, status=408)
     finally:
@@ -359,7 +315,5 @@ class ChatGPTAPI:
     site = web.TCPSite(runner, host, port)
     await site.start()
     if DEBUG >= 0:
-      print(
-        f"Chat interface started. Open this link in your browser: {terminal_link(f'http://localhost:{port}')}"
-      )
+      print(f"Chat interface started. Open this link in your browser: {terminal_link(f'http://localhost:{port}')}")
       print(f"ChatGPT API endpoint served at {terminal_link(f'http://localhost:{port}/v1/chat/completions')}")