Browse Source

add a cli that can be triggered with --run-model <model> --prompt <prompt>

Alex Cheema, 11 months ago
parent
commit e84304317c
6 changed files with 137 additions and 94 deletions
  1. exo/api/chatgpt_api.py (+9, -75)
  2. exo/helpers.py (+0, -15)
  3. exo/inference/inference_engine.py (+16, -0)
  4. exo/inference/tokenizers.py (+27, -0)
  5. exo/models.py (+40, -0)
  6. main.py (+45, -4)
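
The new flags run a single prompt through a model and exit, instead of starting the ChatGPT-compatible API server. A hypothetical invocation, assuming the entry point is main.py and a model name defined in exo/models.py:

    python main.py --run-model llama-3.1-8b --prompt "What is distributed inference?"

If --run-model is passed without --prompt, main() prints an error and returns.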

+ 9 - 75
exo/api/chatgpt_api.py

@@ -3,55 +3,17 @@ import time
 import asyncio
 import json
 from pathlib import Path
-from transformers import AutoTokenizer, AutoProcessor
+from transformers import AutoTokenizer
 from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
 import traceback
 from exo import DEBUG, VERSION
-from exo.helpers import terminal_link, PrefixDict
+from exo.helpers import PrefixDict
 from exo.inference.shard import Shard
+from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
-
-shard_mappings = {
-  ### llama
-  "llama-3.1-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
-  },
-  "llama-3.1-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
-  },
-  "llama-3.1-405b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),
-  },
-  "llama-3-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
-  },
-  "llama-3-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
-  },
-  ### mistral
-  "mistral-nemo": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),
-  },
-  "mistral-large": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),
-  },
-  ### deepseek v2
-  "deepseek-coder-v2-lite": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),
-  },
-  ### llava
-  "llava-1.5-7b-hf": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),
-  },
-}
-
-
+from exo.models import model_base_shards
 
 class Message:
     def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
@@ -64,7 +26,6 @@ class Message:
             "content": self.content
         }
 
-
 class ChatCompletionRequest:
     def __init__(self, model: str, messages: List[Message], temperature: float):
         self.model = model
@@ -78,33 +39,6 @@ class ChatCompletionRequest:
             "temperature": self.temperature
         }
 
-
-
-async def resolve_tokenizer(model_id: str):
-  try:
-    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id}")
-    processor = AutoProcessor.from_pretrained(model_id, use_fast=False)
-    if not hasattr(processor, 'eos_token_id'):
-      processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
-    if not hasattr(processor, 'encode'):
-      processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode
-    if not hasattr(processor, 'decode'):
-      processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
-    return processor
-  except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load processor for {model_id}. Error: {e}")
-    if DEBUG >= 4: print(traceback.format_exc())
-
-  try:
-    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id}")
-    return AutoTokenizer.from_pretrained(model_id)
-  except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
-    if DEBUG >= 4: print(traceback.format_exc())
-
-  raise ValueError(f"[TODO] Unsupported model: {model_id}")
-
-
 def generate_completion(
   chat_request: ChatCompletionRequest,
   tokenizer,
@@ -257,7 +191,7 @@ class ChatGPTAPI:
 
   async def handle_post_chat_token_encode(self, request):
     data = await request.json()
-    shard = shard_mappings.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname)
+    shard = model_base_shards.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname)
     messages = [parse_message(msg) for msg in data.get("messages", [])]
     tokenizer = await resolve_tokenizer(shard.model_id)
     return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
@@ -269,12 +203,12 @@ class ChatGPTAPI:
     chat_request = parse_chat_request(data)
     if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
       chat_request.model = "llama-3.1-8b"
-    if not chat_request.model or chat_request.model not in shard_mappings:
-      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}. Defaulting to llama-3.1-8b")
+    if not chat_request.model or chat_request.model not in model_base_shards:
+      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_base_shards.keys())}. Defaulting to llama-3.1-8b")
       chat_request.model = "llama-3.1-8b"
-    shard = shard_mappings[chat_request.model].get(self.inference_engine_classname, None)
+    shard = model_base_shards[chat_request.model].get(self.inference_engine_classname, None)
     if not shard:
-      supported_models = [model for model, engines in shard_mappings.items() if self.inference_engine_classname in engines]
+      supported_models = [model for model, engines in model_base_shards.items() if self.inference_engine_classname in engines]
       return web.json_response(
         {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
         status=400,

+ 0 - 15
exo/helpers.py

@@ -32,21 +32,6 @@ def get_system_info():
   return "Non-Mac, non-Linux system"
 
 
-def get_inference_engine(inference_engine_name, shard_downloader: 'ShardDownloader'):
-  if inference_engine_name == "mlx":
-    from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-
-    return MLXDynamicShardInferenceEngine(shard_downloader)
-  elif inference_engine_name == "tinygrad":
-    from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-    import tinygrad.helpers
-    tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-
-    return TinygradDynamicShardInferenceEngine(shard_downloader)
-  else:
-    raise ValueError(f"Inference engine {inference_engine_name} not supported")
-
-
 def find_available_port(host: str = "", min_port: int = 49152, max_port: int = 65535) -> int:
   used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".exo_used_ports")
 

+ 16 - 0
exo/inference/inference_engine.py

@@ -1,4 +1,5 @@
 import numpy as np
+import os
 
 from typing import Tuple, Optional
 from abc import ABC, abstractmethod
@@ -12,3 +13,18 @@ class InferenceEngine(ABC):
   @abstractmethod
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     pass
+
+
+def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
+  if inference_engine_name == "mlx":
+    from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+
+    return MLXDynamicShardInferenceEngine(shard_downloader)
+  elif inference_engine_name == "tinygrad":
+    from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
+    import tinygrad.helpers
+    tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
+
+    return TinygradDynamicShardInferenceEngine(shard_downloader)
+  else:
+    raise ValueError(f"Inference engine {inference_engine_name} not supported")
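
A short sketch of how the relocated get_inference_engine can be used, assuming HFShardDownloader constructs with its defaults (main.py remains the authoritative wiring):

    # Sketch only: resolve an engine by name; unknown names raise ValueError.
    from exo.download.hf.hf_shard_download import HFShardDownloader
    from exo.inference.inference_engine import get_inference_engine

    shard_downloader = HFShardDownloader()  # assumed: default constructor works
    engine = get_inference_engine("tinygrad", shard_downloader)  # or "mlx" on Apple silicon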

+ 27 - 0
exo/inference/tokenizers.py

@@ -0,0 +1,27 @@
+import traceback
+from transformers import AutoTokenizer, AutoProcessor
+from exo.helpers import DEBUG
+
+async def resolve_tokenizer(model_id: str):
+  try:
+    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id}")
+    processor = AutoProcessor.from_pretrained(model_id, use_fast=False)
+    if not hasattr(processor, 'eos_token_id'):
+      processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
+    if not hasattr(processor, 'encode'):
+      processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode
+    if not hasattr(processor, 'decode'):
+      processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
+    return processor
+  except Exception as e:
+    if DEBUG >= 4: print(f"Failed to load processor for {model_id}. Error: {e}")
+    if DEBUG >= 4: print(traceback.format_exc())
+
+  try:
+    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id}")
+    return AutoTokenizer.from_pretrained(model_id)
+  except Exception as e:
+    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
+    if DEBUG >= 4: print(traceback.format_exc())
+
+  raise ValueError(f"[TODO] Unsupported model: {model_id}")
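
resolve_tokenizer now lives in its own module and is async: it tries AutoProcessor first (patching in eos_token_id, encode and decode from the wrapped tokenizer when missing) and falls back to AutoTokenizer. A minimal usage sketch, assuming an asyncio context and a model id that resolves on Hugging Face:

    import asyncio
    from exo.inference.tokenizers import resolve_tokenizer

    async def demo():
        # model id taken from exo/models.py; any id loadable by transformers should behave the same
        tokenizer = await resolve_tokenizer("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
        print(tokenizer.encode("hello"))

    asyncio.run(demo())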

+ 40 - 0
exo/models.py

@@ -0,0 +1,40 @@
+from exo.inference.shard import Shard
+
+model_base_shards = {
+  ### llama
+  "llama-3.1-8b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
+  },
+  "llama-3.1-70b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
+  },
+  "llama-3.1-405b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),
+  },
+  "llama-3-8b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
+  },
+  "llama-3-70b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
+  },
+  ### mistral
+  "mistral-nemo": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),
+  },
+  "mistral-large": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),
+  },
+  ### deepseek v2
+  "deepseek-coder-v2-lite": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),
+  },
+  ### llava
+  "llava-1.5-7b-hf": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),
+  },
+}
+
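
model_base_shards is keyed first by the user-facing model name and then by the inference engine class name, which is exactly the lookup chatgpt_api.py and the new run_model_cli perform. A short sketch of that two-level lookup:

    from exo.models import model_base_shards

    # model name -> inference engine class name -> base Shard
    shard = model_base_shards.get("llama-3.1-8b", {}).get("MLXDynamicShardInferenceEngine")
    if shard:
        print(shard.model_id, shard.n_layers)  # mlx-community/Meta-Llama-3.1-8B-Instruct-4bit 32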

+ 45 - 4
main.py

@@ -11,8 +11,13 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent
 from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
 from exo.inference.shard import Shard
+from exo.inference.inference_engine import get_inference_engine, InferenceEngine
+from exo.inference.tokenizers import resolve_tokenizer
+from exo.orchestration.node import Node
+from exo.models import model_base_shards
+import uuid
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -31,6 +36,8 @@ parser.add_argument("--chatgpt-api-response-timeout-secs", type=int, default=90,
 parser.add_argument("--max-generate-tokens", type=int, default=1024, help="Max tokens to generate in each request")
 parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
+parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
+parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model")
 args = parser.parse_args()
 
 print_yellow_exo()
@@ -110,6 +117,34 @@ async def shutdown(signal, loop):
     await server.stop()
     loop.stop()
 
+async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
+    shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
+    if not shard:
+        print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
+        return
+    tokenizer = await resolve_tokenizer(shard.model_id)
+    request_id = str(uuid.uuid4())
+    callback_id = f"cli-wait-response-{request_id}"
+    callback = node.on_token.register(callback_id)
+    prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
+
+    try:
+        print(f"Processing prompt: {prompt}")
+        await node.process_prompt(shard, prompt, None, request_id=request_id)
+
+        _, tokens, _ = await callback.wait(
+            lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished,
+            timeout=90
+        )
+
+        print("\nGenerated response:")
+        print(tokenizer.decode(tokens))
+    except Exception as e:
+        print(f"Error processing prompt: {str(e)}")
+        traceback.print_exc()
+    finally:
+        node.on_token.deregister(callback_id)
+
 async def main():
     loop = asyncio.get_running_loop()
 
@@ -121,9 +156,15 @@ async def main():
         loop.add_signal_handler(s, handle_exit)
 
     await node.start(wait_for_peers=args.wait_for_peers)
-    asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
 
-    await asyncio.Event().wait()
+    if args.run_model:
+        if not args.prompt:
+            print("Error: --prompt is required when using --run-model")
+            return
+        await run_model_cli(node, inference_engine, args.run_model, args.prompt)
+    else:
+        asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
+        await asyncio.Event().wait()
 
 if __name__ == "__main__":
     loop = asyncio.new_event_loop()
@@ -134,4 +175,4 @@ if __name__ == "__main__":
         print("Received keyboard interrupt. Shutting down...")
     finally:
         loop.run_until_complete(shutdown(signal.SIGTERM, loop))
-        loop.close()
+        loop.close()