Merge branch 'main' into stable-stable-diffusion-mlx

Alex Cheema, 5 months ago
commit 5f3d000a7b

+ 2 - 0
README.md

@@ -18,6 +18,8 @@ exo: Run your own AI cluster at home with everyday devices. Maintained by [exo l
 [![Tests](https://dl.circleci.com/status-badge/img/circleci/TrkofJDoGzdQAeL6yVHKsg/4i5hJuafuwZYZQxbRAWS71/tree/main.svg?style=svg)](https://dl.circleci.com/status-badge/redirect/circleci/TrkofJDoGzdQAeL6yVHKsg/4i5hJuafuwZYZQxbRAWS71/tree/main)
 [![License: GPL v3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0)
 
+<a href="https://trendshift.io/repositories/11849" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11849" alt="exo-explore%2Fexo | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
+
 </div>
 
 ---

+ 16 - 4
exo/api/chatgpt_api.py

@@ -166,7 +166,7 @@ class PromptSession:
     self.prompt = prompt
 
 class ChatGPTAPI:
-  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
+  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None, system_prompt: Optional[str] = None):
     self.node = node
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout = response_timeout
@@ -176,6 +176,7 @@ class ChatGPTAPI:
     self.prev_token_lens: Dict[str, int] = {}
     self.stream_tasks: Dict[str, asyncio.Task] = {}
     self.default_model = default_model or "llama-3.2-1b"
+    self.system_prompt = system_prompt
 
     cors = aiohttp_cors.setup(self.app)
     cors_options = aiohttp_cors.ResourceOptions(
@@ -253,7 +254,7 @@ class ChatGPTAPI:
         )
         await response.prepare(request)
 
-        for model_name, pretty in pretty_name.items():
+        async def process_model(model_name, pretty):
             if model_name in model_cards:
                 model_info = model_cards[model_name]
 
@@ -281,6 +282,12 @@ class ChatGPTAPI:
 
                         await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
 
+        # Process all models in parallel
+        await asyncio.gather(*[
+            process_model(model_name, pretty)
+            for model_name, pretty in pretty_name.items()
+        ])
+
         await response.write(b"data: [DONE]\n\n")
         return response
 
@@ -293,7 +300,8 @@ class ChatGPTAPI:
         )
 
   async def handle_get_models(self, request):
-    return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
+    models_list = [{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()]
+    return web.json_response({"object": "list", "data": models_list})
 
   async def handle_post_chat_token_encode(self, request):
     data = await request.json()
@@ -345,6 +353,10 @@ class ChatGPTAPI:
     tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
     if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
 
+    # Add system prompt if set
+    if self.system_prompt and not any(msg.role == "system" for msg in chat_request.messages):
+      chat_request.messages.insert(0, Message("system", self.system_prompt))
+
     prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
     request_id = str(uuid.uuid4())
     if self.on_chat_completion_request:
@@ -645,7 +657,7 @@ class ChatGPTAPI:
       if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
       shard = build_base_shard(model_name, self.inference_engine_classname)
       if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
-      asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
+      asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))
 
       return web.json_response({
         "status": "success",

+ 117 - 0
exo/inference/mlx/models/phi3.py

@@ -0,0 +1,117 @@
+from dataclasses import dataclass, field
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from mlx_lm.models.base import create_attention_mask
+from mlx_lm.models.phi3 import TransformerBlock, ModelArgs
+
+from ...shard import Shard
+from .base import IdentityBlock
+
+@dataclass
+class ModelArgs(ModelArgs):
+  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+  def __post_init__(self):
+    super().__post_init__()
+
+    if isinstance(self.shard, Shard):
+      return
+    if not isinstance(self.shard, dict):
+      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+    self.shard = Shard(**self.shard)
+
+class Phi3Model(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.vocab_size = args.vocab_size
+    self.num_hidden_layers = args.num_hidden_layers
+    assert self.vocab_size > 0
+    
+    if self.args.shard.is_first_layer():
+      self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+    
+    self.layers = []
+    for i in range(self.num_hidden_layers):
+      if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
+        self.layers.append(TransformerBlock(args=args))
+      else:
+        self.layers.append(IdentityBlock())
+        
+    if self.args.shard.is_last_layer():
+      self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    if self.args.shard.is_first_layer():
+      h = self.embed_tokens(inputs)
+    else:
+      h = inputs
+
+    mask = None
+    if h.shape[1] > 1:
+      mask = create_attention_mask(h, cache)
+
+    if cache is None:
+      cache = [None] * len(self.layers)
+
+    for layer, c in zip(self.layers, cache):
+      h = layer(h, mask, c)
+
+    if self.args.shard.is_last_layer():
+      h = self.norm(h)
+    return h
+
+class Model(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.model_type = args.model_type
+    self.model = Phi3Model(args)
+    if self.args.shard.is_last_layer():
+      self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    out = self.model(inputs, cache)
+    if self.args.shard.is_last_layer():
+      out = self.lm_head(out)
+    return out
+
+  def sanitize(self, weights):
+    shard_state_dict = {}
+
+    for key, value in weights.items():
+      if "self_attn.rope.inv_freq" in key:
+        continue
+      if key.startswith('model.layers.'):
+        layer_num = int(key.split('.')[2])
+        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
+          shard_state_dict[key] = value
+      elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
+        shard_state_dict[key] = value
+      elif self.args.shard.is_last_layer() and (key.startswith('lm_head') or key.startswith('model.norm')):
+        shard_state_dict[key] = value
+
+    return shard_state_dict
+
+  @property
+  def layers(self):
+    return self.model.layers
+
+  @property
+  def head_dim(self):
+    return self.args.hidden_size // self.args.num_attention_heads
+
+  @property
+  def n_kv_heads(self):
+    return self.args.num_key_value_heads
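
For readers new to the sharding pattern used in phi3.py (and the other MLX model files): a node only instantiates the TransformerBlocks whose indices fall inside its shard, substitutes IdentityBlock elsewhere, and sanitize() drops weights outside that range so they are never loaded. A rough dependency-free sketch of the layer-selection rule; the Shard stand-in below mirrors the fields phi3.py relies on, and its is_first_layer/is_last_layer semantics are an assumption:

    from dataclasses import dataclass

    @dataclass
    class Shard:  # stand-in for exo/inference/shard.py
      model_id: str
      start_layer: int
      end_layer: int
      n_layers: int

      def is_first_layer(self) -> bool:  # assumed semantics: shard starts at layer 0
        return self.start_layer == 0

      def is_last_layer(self) -> bool:  # assumed semantics: shard ends at the final layer
        return self.end_layer == self.n_layers - 1

    def plan_layers(shard: Shard) -> list:
      # Same selection rule as Phi3Model.__init__: real blocks inside
      # [start_layer, end_layer], identity pass-throughs everywhere else.
      return [
        "TransformerBlock" if shard.start_layer <= i <= shard.end_layer else "IdentityBlock"
        for i in range(shard.n_layers)
      ]

    # Splitting phi-3.5-mini's 32 layers across two nodes:
    print(plan_layers(Shard("phi-3.5-mini", 0, 15, 32)).count("TransformerBlock"))   # 16
    print(plan_layers(Shard("phi-3.5-mini", 16, 31, 32)).count("TransformerBlock"))  # 16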

+ 4 - 3
exo/inference/mlx/models/qwen2.py

@@ -9,13 +9,12 @@ from mlx_lm.models.qwen2 import TransformerBlock, ModelArgs
 from ...shard import Shard
 from .base import IdentityBlock
 
-
 @dataclass
 class ModelArgs(ModelArgs):
   shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
 
   def __post_init__(self):
-    super().__post_init__()  # Ensure parent initializations are respected
+    super().__post_init__()
 
     if isinstance(self.shard, Shard):
       return
@@ -24,7 +23,6 @@ class ModelArgs(ModelArgs):
 
     self.shard = Shard(**self.shard)
 
-
 class Qwen2Model(nn.Module):
   def __init__(self, args: ModelArgs):
     super().__init__()
@@ -32,14 +30,17 @@ class Qwen2Model(nn.Module):
     self.vocab_size = args.vocab_size
     self.num_hidden_layers = args.num_hidden_layers
     assert self.vocab_size > 0
+
     if self.args.shard.is_first_layer():
       self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+
     self.layers = []
     for i in range(self.num_hidden_layers):
       if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
         self.layers.append(TransformerBlock(args=args))
       else:
         self.layers.append(IdentityBlock())
+
     if self.args.shard.is_last_layer():
       self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
 

+ 3 - 1
exo/main.py

@@ -69,6 +69,7 @@ parser.add_argument("--default-temp", type=float, help="Default token sampling t
 parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
 parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
 parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
+parser.add_argument("--system-prompt", type=str, default=None, help="System prompt for the ChatGPT API")
 args = parser.parse_args()
 print(f"Selected inference engine: {args.inference_engine}")
 
@@ -146,7 +147,8 @@ api = ChatGPTAPI(
   inference_engine.__class__.__name__,
   response_timeout=args.chatgpt_api_response_timeout,
   on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
-  default_model=args.default_model
+  default_model=args.default_model,
+  system_prompt=args.system_prompt
 )
 node.on_token.register("update_topology_viz").on_next(
   lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") and inference_engine.shard.model_id != 'stable-diffusion-2-1-base' else None

+ 5 - 0
exo/models.py

@@ -113,6 +113,9 @@ model_cards = {
   "gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
   # stable diffusion
   "stable-diffusion-2-1-base": { "layers": 31, "repo": { "MLXDynamicShardInferenceEngine": "stabilityai/stable-diffusion-2-1-base" } },
+  # phi
+  "phi-3.5-mini": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Phi-3.5-mini-instruct-4bit", }, },
+  "phi-4": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/phi-4-4bit", }, },
   # dummy
   "dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
 }
@@ -151,6 +154,8 @@ pretty_name = {
   "qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
   "qwen-2.5-72b": "Qwen 2.5 72B",
   "qwen-2.5-math-72b": "Qwen 2.5 72B (Math)",
+  "phi-3.5-mini": "Phi-3.5 Mini",
+  "phi-4": "Phi-4",
   "llama-3-8b": "Llama 3 8B",
   "llama-3-70b": "Llama 3 70B",
   "stable-diffusion-2-1-base": "Stable Diffusion 2.1",

+ 58 - 13
exo/viz/topology_viz.py

@@ -91,25 +91,70 @@ class TopologyViz:
     content = []
     requests = list(self.requests.values())[-3:]  # Get the 3 most recent requests
     max_width = self.console.width - 6  # Full width minus padding and icon
-    max_lines = 13  # Maximum number of lines for the entire panel content
+
+    # Calculate available height for content
+    panel_height = 15  # Fixed panel height
+    available_lines = panel_height - 2  # Subtract 2 for panel borders
+    lines_per_entry = available_lines // len(requests) if requests else 0
 
     for (prompt, output) in reversed(requests):
       prompt_icon, output_icon = "💬️", "🤖"
 
+      # Calculate max lines for prompt and output
+      max_prompt_lines = lines_per_entry // 3  # Allocate 1/3 for prompt
+      max_output_lines = lines_per_entry - max_prompt_lines - 1  # Remaining space minus spacing
+
       # Process prompt
-      prompt_lines = prompt.split('\n')
-      if len(prompt_lines) > max_lines // 2:
-        prompt_lines = prompt_lines[:max_lines//2 - 1] + ['...']
+      prompt_lines = []
+      for line in prompt.split('\n'):
+        words = line.split()
+        current_line = []
+        current_length = 0
+
+        for word in words:
+          if current_length + len(word) + 1 <= max_width:
+            current_line.append(word)
+            current_length += len(word) + 1
+          else:
+            if current_line:
+              prompt_lines.append(' '.join(current_line))
+            current_line = [word]
+            current_length = len(word)
+
+        if current_line:
+          prompt_lines.append(' '.join(current_line))
+
+      if len(prompt_lines) > max_prompt_lines:
+        prompt_lines = prompt_lines[:max_prompt_lines - 1] + ['...']
+
       prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
-      prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")
+      prompt_text.append('\n'.join(prompt_lines), style="white")
+
+      # Process output - same word-aware wrapping
+      output_lines = []
+      for line in output.split('\n'):
+        words = line.split()
+        current_line = []
+        current_length = 0
+
+        for word in words:
+          if current_length + len(word) + 1 <= max_width:
+            current_line.append(word)
+            current_length += len(word) + 1
+          else:
+            if current_line:
+              output_lines.append(' '.join(current_line))
+            current_line = [word]
+            current_length = len(word)
+
+        if current_line:
+          output_lines.append(' '.join(current_line))
+
+      if len(output_lines) > max_output_lines:
+        output_lines = output_lines[:max_output_lines - 1] + ['...']
 
-      # Process output
-      output_lines = output.split('\n')
-      remaining_lines = max_lines - len(prompt_lines) - 2  # -2 for spacing
-      if len(output_lines) > remaining_lines:
-        output_lines = output_lines[:remaining_lines - 1] + ['...']
       output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
-      output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white")
+      output_text.append('\n'.join(output_lines), style="white")
 
       content.append(prompt_text)
       content.append(output_text)
@@ -119,8 +164,8 @@ class TopologyViz:
       Group(*content),
       title="",
       border_style="cyan",
-      height=15,  # Increased height to accommodate multiple lines
-      expand=True  # Allow the panel to expand to full width
+      height=panel_height,
+      expand=True
     )
 
   def _generate_main_layout(self) -> str:
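
The word-aware wrapping introduced above is inlined twice, once for the prompt and once for the output. Read in isolation, the core logic is a greedy word wrap; a minimal sketch of it, not part of the diff:

    def wrap_words(text: str, max_width: int) -> list[str]:
      # Greedy word wrap matching the inlined logic above: words are packed
      # onto a line until adding the next one (plus a space) would exceed max_width.
      wrapped = []
      for line in text.split('\n'):
        current_line, current_length = [], 0
        for word in line.split():
          if current_length + len(word) + 1 <= max_width:
            current_line.append(word)
            current_length += len(word) + 1
          else:
            if current_line:
              wrapped.append(' '.join(current_line))
            current_line, current_length = [word], len(word)
        if current_line:
          wrapped.append(' '.join(current_line))
      return wrapped

    print(wrap_words("exo runs your own AI cluster at home", 12))
    # ['exo runs', 'your own AI', 'cluster at', 'home']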

+ 1 - 1
test/test_tokenizers.py

@@ -24,7 +24,7 @@ def test_tokenizer(name, tokenizer, verbose=False):
     strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id]))
     assert text == strip_tokens(decoded) == strip_tokens(reconstructed)
 
-ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit"]
+ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", "mlx-community/Phi-3.5-mini-instruct-4bit", "mlx-community/phi-4-4bit"]
 ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")")
 models = []
 for model_id in model_cards: