Browse Source

Merge pull request #627 from exo-explore/deepseek

Deepseek, tinychat group models, latex formatting, thinking boxes
Alex Cheema 5 months ago
parent
commit
66f73768cc
7 changed files with 534 additions and 110 deletions
  1. 135 0
      exo/inference/mlx/models/deepseek_v3.py
  2. 66 0
      exo/models.py
  3. 87 0
      exo/tinychat/index.css
  4. 184 96
      exo/tinychat/index.html
  5. 57 10
      exo/tinychat/index.js
  6. 2 2
      setup.py
  7. 3 2
      test/test_tokenizers.py

+ 135 - 0
exo/inference/mlx/models/deepseek_v3.py

@@ -0,0 +1,135 @@
+from dataclasses import dataclass, field
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from mlx_lm.models.cache import KVCache
+from mlx_lm.models.deepseek_v3 import (
+  ModelArgs as V3ModelArgs,
+  DeepseekV3DecoderLayer,
+)
+from .base import IdentityBlock
+from exo.inference.shard import Shard
+
+
+@dataclass
+class ModelArgs(V3ModelArgs):
+  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+  def __post_init__(self):
+    super().__post_init__()
+    if isinstance(self.shard, Shard):
+      return
+    if not isinstance(self.shard, dict):
+      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+    self.shard = Shard(**self.shard)
+
+
+class DeepseekV3Model(nn.Module):
+  def __init__(self, config: ModelArgs):
+    super().__init__()
+    self.args = config
+    self.num_hidden_layers = config.num_hidden_layers
+    self.vocab_size = config.vocab_size
+    if self.args.shard.is_first_layer():
+      self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+
+    self.layers = []
+    for i in range(self.num_hidden_layers):
+      if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
+        self.layers.append(DeepseekV3DecoderLayer(config, i))
+      else:
+        self.layers.append(IdentityBlock())
+
+    if self.args.shard.is_last_layer():
+      self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+  def __call__(
+    self,
+    x: mx.array,
+    cache: Optional[KVCache] = None,
+  ) -> mx.array:
+    if self.args.shard.is_first_layer():
+      h = self.embed_tokens(x)
+    else:
+      h = x
+
+    mask = None
+    T = h.shape[1]
+    if T > 1:
+      mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
+      mask = mask.astype(h.dtype)
+
+    if cache is None:
+      cache = [None]*len(self.layers)
+
+    for layer, c in zip(self.layers, cache):
+      h = layer(h, mask, c)
+
+    if self.args.shard.is_last_layer():
+      h = self.norm(h)
+    return h
+
+
+class Model(nn.Module):
+  def __init__(self, config: ModelArgs):
+    super().__init__()
+    self.args = config
+    self.model_type = config.model_type
+    self.model = DeepseekV3Model(config)
+    if self.args.shard.is_last_layer():
+      self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache: Optional[KVCache] = None,
+  ):
+    out = self.model(inputs, cache)
+    if self.args.shard.is_last_layer():
+      return self.lm_head(out)
+    return out
+
+  def sanitize(self, weights):
+    shard_state_dict = {}
+
+    for key, value in weights.items():
+      if key.startswith('model.layers.'):
+        layer_num = int(key.split('.')[2])
+        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
+          shard_state_dict[key] = value
+      elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
+        shard_state_dict[key] = value
+      elif self.args.shard.is_last_layer() and (key.startswith('model.norm') or key.startswith('lm_head')):
+        shard_state_dict[key] = value
+
+    for l in range(self.args.num_hidden_layers):
+      prefix = f"model.layers.{l}"
+      for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
+        for k in ["weight", "scales", "biases"]:
+          expert_key = f"{prefix}.mlp.experts.0.{m}.{k}"
+          if expert_key in shard_state_dict:
+            to_join = [
+              shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
+              for e in range(self.args.n_routed_experts)
+            ]
+            shard_state_dict[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
+
+    return shard_state_dict
+
+  @property
+  def layers(self):
+    return self.model.layers
+
+  @property
+  def head_dim(self):
+    return (
+      self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
+      self.args.v_head_dim,
+    )
+
+  @property
+  def n_kv_heads(self):
+    return self.args.num_key_value_heads

+ 66 - 0
exo/models.py

@@ -88,6 +88,38 @@ model_cards = {
   ### deepseek
   "deepseek-coder-v2-lite": { "layers": 27, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", }, },
   "deepseek-coder-v2.5": { "layers": 60, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", }, },
+  "deepseek-v3": { "layers": 61, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V3-4bit", }, },
+  "deepseek-r1": { "layers": 61, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-4bit", }, },
+  ### deepseek distills
+  "deepseek-r1-distill-qwen-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/deepseek-r1-distill-qwen-1.5b", }, },
+  "deepseek-r1-distill-qwen-1.5b-3bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-3bit", }, },
+  "deepseek-r1-distill-qwen-1.5b-6bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-6bit", }, },
+  "deepseek-r1-distill-qwen-1.5b-8bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-8bit", }, },
+  "deepseek-r1-distill-qwen-1.5b-bf16": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-bf16", }, },
+  "deepseek-r1-distill-qwen-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit", }, },
+  "deepseek-r1-distill-qwen-7b-3bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-3bit", }, },
+  "deepseek-r1-distill-qwen-7b-6bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-6bit", }, },
+  "deepseek-r1-distill-qwen-7b-8bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-8bit", }, },
+  "deepseek-r1-distill-qwen-7b-bf16": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-bf16", }, },
+  "deepseek-r1-distill-qwen-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-4bit", }, },
+  "deepseek-r1-distill-qwen-14b-3bit": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-3bit", }, },
+  "deepseek-r1-distill-qwen-14b-6bit": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-6bit", }, },
+  "deepseek-r1-distill-qwen-14b-8bit": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-8bit", }, },
+  "deepseek-r1-distill-qwen-14b-bf16": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-bf16", }, },
+  "deepseek-r1-distill-qwen-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-4bit", }, },
+  "deepseek-r1-distill-qwen-32b-3bit": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-3bit", }, },
+  "deepseek-r1-distill-qwen-32b-6bit": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-6bit", }, },
+  "deepseek-r1-distill-qwen-32b-8bit": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-MLX-8Bit", }, },
+  "deepseek-r1-distill-qwen-32b-bf16": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-bf16", }, },
+  "deepseek-r1-distill-llama-8b": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-4bit", }, },
+  "deepseek-r1-distill-llama-8b-3bit": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-3bit", }, },
+  "deepseek-r1-distill-llama-8b-6bit": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-6bit", }, },
+  "deepseek-r1-distill-llama-8b-8bit": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-8bit", }, },
+  "deepseek-r1-distill-llama-8b-bf16": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-bf16", }, },
+  "deepseek-r1-distill-llama-70b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-4bit", }, },
+  "deepseek-r1-distill-llama-70b-3bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-3bit", }, },
+  "deepseek-r1-distill-llama-70b-6bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-6bit", }, },
+  "deepseek-r1-distill-llama-70b-8bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-8bit", }, },
   ### llava
   "llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
   ### qwen
@@ -140,6 +172,8 @@ pretty_name = {
   "mistral-large": "Mistral Large",
   "deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
   "deepseek-coder-v2.5": "Deepseek Coder V2.5",
+  "deepseek-v3": "Deepseek V3",
+  "deepseek-r1": "Deepseek R1",
   "llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
   "qwen-2.5-1.5b": "Qwen 2.5 1.5B",
   "qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
@@ -159,6 +193,38 @@ pretty_name = {
   "llama-3-8b": "Llama 3 8B",
   "llama-3-70b": "Llama 3 70B",
   "stable-diffusion-2-1-base": "Stable Diffusion 2.1",
+  "deepseek-r1-distill-qwen-1.5b": "DeepSeek R1 Distill Qwen 1.5B",
+  "deepseek-r1-distill-qwen-1.5b-3bit": "DeepSeek R1 Distill Qwen 1.5B (3-bit)",
+  "deepseek-r1-distill-qwen-1.5b-6bit": "DeepSeek R1 Distill Qwen 1.5B (6-bit)",
+  "deepseek-r1-distill-qwen-1.5b-8bit": "DeepSeek R1 Distill Qwen 1.5B (8-bit)",
+  "deepseek-r1-distill-qwen-1.5b-bf16": "DeepSeek R1 Distill Qwen 1.5B (BF16)",
+  "deepseek-r1-distill-qwen-7b": "DeepSeek R1 Distill Qwen 7B",
+  "deepseek-r1-distill-qwen-7b-3bit": "DeepSeek R1 Distill Qwen 7B (3-bit)",
+  "deepseek-r1-distill-qwen-7b-6bit": "DeepSeek R1 Distill Qwen 7B (6-bit)",
+  "deepseek-r1-distill-qwen-7b-8bit": "DeepSeek R1 Distill Qwen 7B (8-bit)",
+  "deepseek-r1-distill-qwen-7b-bf16": "DeepSeek R1 Distill Qwen 7B (BF16)",
+  "deepseek-r1-distill-qwen-14b": "DeepSeek R1 Distill Qwen 14B",
+  "deepseek-r1-distill-qwen-14b-3bit": "DeepSeek R1 Distill Qwen 14B (3-bit)",
+  "deepseek-r1-distill-qwen-14b-6bit": "DeepSeek R1 Distill Qwen 14B (6-bit)",
+  "deepseek-r1-distill-qwen-14b-8bit": "DeepSeek R1 Distill Qwen 14B (8-bit)",
+  "deepseek-r1-distill-qwen-14b-bf16": "DeepSeek R1 Distill Qwen 14B (BF16)",
+  "deepseek-r1-distill-qwen-32b": "DeepSeek R1 Distill Qwen 32B",
+  "deepseek-r1-distill-qwen-32b-3bit": "DeepSeek R1 Distill Qwen 32B (3-bit)",
+  "deepseek-r1-distill-qwen-32b-8bit": "DeepSeek R1 Distill Qwen 32B (8-bit)",
+  "deepseek-r1-distill-qwen-32b-bf16": "DeepSeek R1 Distill Qwen 32B (BF16)",
+  "deepseek-r1-distill-llama-8b-8bit": "DeepSeek R1 Distill Llama 8B (8-bit)",
+  "deepseek-r1-distill-llama-70b-6bit": "DeepSeek R1 Distill Llama 70B (6-bit)",
+  "deepseek-r1-distill-llama-70b-8bit": "DeepSeek R1 Distill Llama 70B (8-bit)",
+  "deepseek-r1-distill-llama-8b": "DeepSeek R1 Distill Llama 8B",
+  "deepseek-r1-distill-llama-8b-3bit": "DeepSeek R1 Distill Llama 8B (3-bit)",
+  "deepseek-r1-distill-llama-8b-6bit": "DeepSeek R1 Distill Llama 8B (6-bit)",
+  "deepseek-r1-distill-llama-8b-8bit": "DeepSeek R1 Distill Llama 8B (8-bit)",
+  "deepseek-r1-distill-llama-8b-bf16": "DeepSeek R1 Distill Llama 8B (BF16)",
+  "deepseek-r1-distill-llama-70b": "DeepSeek R1 Distill Llama 70B",
+  "deepseek-r1-distill-llama-70b-3bit": "DeepSeek R1 Distill Llama 70B (3-bit)",
+  "deepseek-r1-distill-llama-70b-6bit": "DeepSeek R1 Distill Llama 70B (6-bit)",
+  "deepseek-r1-distill-llama-70b-8bit": "DeepSeek R1 Distill Llama 70B (8-bit)",
+  "deepseek-r1-distill-qwen-32b-6bit": "DeepSeek R1 Distill Qwen 32B (6-bit)",
 }
 
 def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:

+ 87 - 0
exo/tinychat/index.css

@@ -742,4 +742,91 @@ main {
 .peer-connection i {
   font-size: 0.8em;
   color: #666;
+}
+
+.thinking-block {
+  background-color: rgba(255, 255, 255, 0.05);
+  border-radius: 8px;
+  margin: 8px 0;
+  overflow: hidden;
+}
+
+.thinking-header {
+  background-color: rgba(255, 255, 255, 0.1);
+  padding: 8px 12px;
+  font-size: 0.9em;
+  color: #a0a0a0;
+  display: flex;
+  align-items: center;
+  gap: 8px;
+}
+
+.thinking-content {
+  padding: 12px;
+  white-space: pre-wrap;
+}
+
+@keyframes thinking-spin {
+  to { transform: rotate(360deg); }
+}
+
+.thinking-header.thinking::before {
+  content: '';
+  width: 12px;
+  height: 12px;
+  border: 2px solid #a0a0a0;
+  border-top-color: transparent;
+  border-radius: 50%;
+  animation: thinking-spin 1s linear infinite;
+}
+
+.model-group {
+  margin-bottom: 12px;
+}
+
+.model-group-header,
+.model-subgroup-header {
+  display: flex;
+  justify-content: space-between;
+  align-items: center;
+  padding: 8px 12px;
+  background-color: var(--primary-bg-color);
+  border-radius: 6px;
+  cursor: pointer;
+  transition: all 0.2s ease;
+  margin-bottom: 8px;
+}
+
+.model-group-header:hover,
+.model-subgroup-header:hover {
+  background-color: var(--secondary-color-transparent);
+}
+
+.model-group-content {
+  padding-left: 12px;
+}
+
+.model-subgroup {
+  margin-bottom: 8px;
+}
+
+.model-subgroup-header {
+  font-size: 0.9em;
+  background-color: rgba(255, 255, 255, 0.05);
+}
+
+.model-subgroup-content {
+  padding-left: 12px;
+}
+
+.group-header-content {
+  display: flex;
+  align-items: center;
+  gap: 8px;
+}
+
+.model-count {
+  font-size: 0.8em;
+  color: var(--secondary-color-transparent);
+  font-family: monospace;
 }

+ 184 - 96
exo/tinychat/index.html

@@ -22,6 +22,7 @@
 <link href="/static/unpkg.com/@highlightjs/cdn-assets@11.9.0/styles/vs2015.min.css" rel="stylesheet"/>
 <link href="/index.css" rel="stylesheet"/>
 <link href="/common.css" rel="stylesheet"/>
+<script src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
 </head>
 <body>
 <main x-data="state" x-init="console.log(endpoint)">
@@ -49,50 +50,78 @@
         <span>Loading models...</span>
     </div>
 
-    <template x-for="(model, key) in models" :key="key">
-        <div class="model-option"
-             :class="{ 'selected': cstate.selectedModel === key }"
-             @click="cstate.selectedModel = key">
-            <div class="model-header">
-                <div class="model-name" x-text="model.name"></div>
-                <button
-                    @click.stop="deleteModel(key, model)"
-                    class="model-delete-button"
-                    x-show="model.download_percentage > 0">
-                    <i class="fas fa-trash"></i>
-                </button>
-            </div>
-            <div class="model-info">
-                <div class="model-progress">
-                    <template x-if="model.loading">
-                        <span><i class="fas fa-spinner fa-spin"></i> Checking download status...</span>
-                    </template>
-                    <div class="model-progress-info">
-                        <template x-if="!model.loading && model.download_percentage != null">
-                            <span>
-                                <!-- Check if there's an active download for this model -->
-                                <template x-if="downloadProgress?.some(p =>
-                                    p.repo_id && p.repo_id.toLowerCase().includes(key.toLowerCase()) && !p.isComplete
-                                )">
-                                    <i class="fas fa-circle-notch fa-spin"></i>
-                                </template>
-                                <span x-text="model.downloaded ? 'Downloaded' : `${Math.round(model.download_percentage)}% downloaded`"></span>
-                            </span>
-                        </template>
-                        <template x-if="!model.loading && (model.download_percentage === null || model.download_percentage < 100) && !downloadProgress?.some(p => !p.isComplete)">
-                            <button
-                                @click.stop="handleDownload(key)"
-                                class="model-download-button">
-                                <i class="fas fa-download"></i>
-                                <span x-text="(model.download_percentage > 0 && model.download_percentage < 100) ? 'Continue Downloading' : 'Download'"></span>
-                            </button>
-                        </template>
-                    </div>
+    <!-- Group models by prefix -->
+    <template x-for="[mainPrefix, subGroups] in Object.entries(groupModelsByPrefix(models))" :key="mainPrefix">
+        <div class="model-group">
+            <div class="model-group-header" @click="toggleGroup(mainPrefix)">
+                <div class="group-header-content">
+                    <span x-text="mainPrefix"></span>
+                    <span class="model-count" x-text="getGroupCounts(Object.values(subGroups).flatMap(group => Object.values(group)))"></span>
                 </div>
-                <template x-if="model.total_size">
-                    <div class="model-size" x-text="model.total_downloaded ?
-                        `${formatBytes(model.total_downloaded)} / ${formatBytes(model.total_size)}` :
-                        formatBytes(model.total_size)">
+                <i class="fas" :class="isGroupExpanded(mainPrefix) ? 'fa-chevron-down' : 'fa-chevron-right'"></i>
+            </div>
+            
+            <div class="model-group-content" x-show="isGroupExpanded(mainPrefix)" x-transition>
+                <template x-for="[subPrefix, groupModels] in Object.entries(subGroups)" :key="subPrefix">
+                    <div class="model-subgroup">
+                        <div class="model-subgroup-header" @click.stop="toggleGroup(mainPrefix, subPrefix)">
+                            <div class="group-header-content">
+                                <span x-text="subPrefix"></span>
+                                <span class="model-count" x-text="getGroupCounts(groupModels)"></span>
+                            </div>
+                            <i class="fas" :class="isGroupExpanded(mainPrefix, subPrefix) ? 'fa-chevron-down' : 'fa-chevron-right'"></i>
+                        </div>
+                        
+                        <div class="model-subgroup-content" x-show="isGroupExpanded(mainPrefix, subPrefix)" x-transition>
+                            <template x-for="(model, key) in groupModels" :key="key">
+                                <div class="model-option"
+                                     :class="{ 'selected': cstate.selectedModel === key }"
+                                     @click="cstate.selectedModel = key">
+                                    <div class="model-header">
+                                        <div class="model-name" x-text="model.name"></div>
+                                        <button
+                                            @click.stop="deleteModel(key, model)"
+                                            class="model-delete-button"
+                                            x-show="model.download_percentage > 0">
+                                            <i class="fas fa-trash"></i>
+                                        </button>
+                                    </div>
+                                    <div class="model-info">
+                                        <div class="model-progress">
+                                            <template x-if="model.loading">
+                                                <span><i class="fas fa-spinner fa-spin"></i> Checking download status...</span>
+                                            </template>
+                                            <div class="model-progress-info">
+                                                <template x-if="!model.loading && model.download_percentage != null">
+                                                    <span>
+                                                        <template x-if="downloadProgress?.some(p =>
+                                                            p.repo_id && p.repo_id.toLowerCase().includes(key.toLowerCase()) && !p.isComplete
+                                                        )">
+                                                            <i class="fas fa-circle-notch fa-spin"></i>
+                                                        </template>
+                                                        <span x-text="model.downloaded ? 'Downloaded' : `${Math.round(model.download_percentage)}% downloaded`"></span>
+                                                    </span>
+                                                </template>
+                                                <template x-if="!model.loading && (model.download_percentage === null || model.download_percentage < 100) && !downloadProgress?.some(p => !p.isComplete)">
+                                                    <button
+                                                        @click.stop="handleDownload(key)"
+                                                        class="model-download-button">
+                                                        <i class="fas fa-download"></i>
+                                                        <span x-text="(model.download_percentage > 0 && model.download_percentage < 100) ? 'Continue Downloading' : 'Download'"></span>
+                                                    </button>
+                                                </template>
+                                            </div>
+                                        </div>
+                                        <template x-if="model.total_size">
+                                            <div class="model-size" x-text="model.total_downloaded ?
+                                                `${formatBytes(model.total_downloaded)} / ${formatBytes(model.total_size)}` :
+                                                formatBytes(model.total_size)">
+                                            </div>
+                                        </template>
+                                    </div>
+                                </div>
+                            </template>
+                        </div>
                     </div>
                 </template>
             </div>
@@ -177,6 +206,7 @@
 </template>
 </div>
 </div>
+</div>
 <button
     @click="
         home = 0;
@@ -190,67 +220,87 @@
     <i class="fas fa-arrow-left"></i>
     Back to Chats
 </button>
-<div class="messages" x-init="
-      $watch('cstate', value =&gt; {
-        $el.innerHTML = '';
-        value.messages.forEach(({ role, content }) =&gt; {
-          const div = document.createElement('div');
-          div.className = `message message-role-${role}`;
-          try {
-              if (content.includes('![Generated Image]')) {
-                const imageUrl = content.match(/\((.*?)\)/)[1];
-                const img = document.createElement('img');
-                img.src = imageUrl;
-                img.alt = 'Generated Image';
-                img.onclick = async () => {
-                  try {
-                    const response = await fetch(img.src);
-                    const blob = await response.blob();
-                    const file = new File([blob], 'image.png', { type: 'image/png' });
-                    handleImageUpload({ target: { files: [file] } });
-                  } catch (error) {
-                    console.error('Error fetching image:', error);
-                  }
-                };
-                div.appendChild(img);
-              } else {
-                div.innerHTML = DOMPurify.sanitize(marked.parse(content));
-              }
-          } catch (e) {
-            console.log(content);
-            console.error(e);
+<div class="messages"
+  x-init="
+    $watch('cstate', (value) => {
+      $el.innerHTML = '';
+
+      value.messages.forEach((msg) => {
+        const div = document.createElement('div');
+        div.className = `message message-role-${msg.role}`;
+
+        try {
+          // If there's an embedded generated image
+          if (msg.content.includes('![Generated Image]')) {
+            const imageUrlMatch = msg.content.match(/\((.*?)\)/);
+            if (imageUrlMatch) {
+              const imageUrl = imageUrlMatch[1];
+              const img = document.createElement('img');
+              img.src = imageUrl;
+              img.alt = 'Generated Image';
+
+              img.onclick = async () => {
+                try {
+                  const response = await fetch(img.src);
+                  const blob = await response.blob();
+                  const file = new File([blob], 'image.png', { type: 'image/png' });
+                  handleImageUpload({ target: { files: [file] } });
+                } catch (error) {
+                  console.error('Error fetching image:', error);
+                }
+              };
+              div.appendChild(img);
+            } else {
+              // fallback if markdown is malformed
+              div.textContent = msg.content;
+            }
+          } else {
+            // Otherwise, transform message text (including streamed think blocks).
+            div.innerHTML = transformMessageContent(msg);
+            // Render math after content is inserted
+            MathJax.typesetPromise([div]);
           }
+        } catch (e) {
+          console.error('Error rendering message:', e);
+          div.textContent = msg.content; // fallback
+        }
+
+        // Add a clipboard button to code blocks
+        const codeBlocks = div.querySelectorAll('.hljs');
+        codeBlocks.forEach((codeBlock) => {
+          const button = document.createElement('button');
+          button.className = 'clipboard-button';
+          button.innerHTML = '<i class=\'fas fa-clipboard\'></i>';
 
-          // add a clipboard button to all code blocks
-          const codeBlocks = div.querySelectorAll('.hljs');
-          codeBlocks.forEach(codeBlock =&gt; {
-            const button = document.createElement('button');
-            button.className = 'clipboard-button';
-            button.innerHTML = '&lt;i class=\'fas fa-clipboard\'&gt;&lt;/i&gt;';
-            button.onclick = () =&gt; {
-              // navigator.clipboard.writeText(codeBlock.textContent);
-              const range = document.createRange();
-              range.setStartBefore(codeBlock);
-              range.setEndAfter(codeBlock);
-              window.getSelection()?.removeAllRanges();
-              window.getSelection()?.addRange(range);
-              document.execCommand('copy');
-              window.getSelection()?.removeAllRanges();
+          button.onclick = () => {
+            const range = document.createRange();
+            range.setStartBefore(codeBlock);
+            range.setEndAfter(codeBlock);
+            window.getSelection()?.removeAllRanges();
+            window.getSelection()?.addRange(range);
+            document.execCommand('copy');
+            window.getSelection()?.removeAllRanges();
 
-              button.innerHTML = '&lt;i class=\'fas fa-check\'&gt;&lt;/i&gt;';
-              setTimeout(() =&gt; button.innerHTML = '&lt;i class=\'fas fa-clipboard\'&gt;&lt;/i&gt;', 1000);
-            };
-            codeBlock.appendChild(button);
-          });
+            button.innerHTML = '<i class=\'fas fa-check\'></i>';
+            setTimeout(() => {
+              button.innerHTML = '<i class=\'fas fa-clipboard\'></i>';
+            }, 1000);
+          };
 
-          $el.appendChild(div);
+          codeBlock.appendChild(button);
         });
 
-        $el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' });
+        $el.appendChild(div);
       });
-    " x-intersect="
+
+      // Scroll to bottom after rendering
       $el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' });
-    " x-ref="messages" x-show="home === 2" x-transition="">
+    });
+  "
+  x-ref="messages"
+  x-show="home === 2"
+  x-transition=""
+>
 </div>
 
 <!-- Download Progress Section -->
@@ -353,4 +403,42 @@
 </div>
 </div>
 </main>
+
+<script>
+  /**
+   * Transform a single message's content into HTML, preserving <think> blocks.
+   * Ensure LaTeX expressions are properly delimited for MathJax.
+   */
+  function transformMessageContent(message) {
+    let text = message.content;
+    console.log('Processing message content:', text);
+
+    // First replace think blocks
+    text = text.replace(
+      /<think>([\s\S]*?)(?:<\/think>|$)/g,
+      (match, body) => {
+        console.log('Found think block with content:', body);
+        const isComplete = match.includes('</think>');
+        const spinnerClass = isComplete ? '' : ' thinking';
+        const parsedBody = DOMPurify.sanitize(marked.parse(body));
+        return `
+<div class='thinking-block'>
+  <div class='thinking-header${spinnerClass}'>Thinking...</div>
+  <div class='thinking-content'>${parsedBody}</div>
+</div>`;
+      }
+    );
+
+    // Add backslashes to parentheses and brackets for LaTeX
+    text = text
+      .replace(/\((?=\s*[\d\\])/g, '\\(')  // Add backslash before opening parentheses
+      .replace(/\)(?!\w)/g, '\\)')          // Add backslash before closing parentheses
+      .replace(/\[(?=\s*[\d\\])/g, '\\[')   // Add backslash before opening brackets
+      .replace(/\](?!\w)/g, '\\]')          // Add backslash before closing brackets
+      .replace(/\[[\s\n]*\\boxed/g, '\\[\\boxed') // Ensure boxed expressions are properly delimited
+      .replace(/\\!/g, '\\\\!');  // Preserve LaTeX spacing commands
+
+    return DOMPurify.sanitize(marked.parse(text));
+  }
+</script>
 </body>

+ 57 - 10
exo/tinychat/index.js

@@ -42,6 +42,9 @@ document.addEventListener("alpine:init", () => {
     topology: null,
     topologyInterval: null,
 
+    // Add these new properties
+    expandedGroups: {},
+
     init() {
       // Clean up any pending messages
       localStorage.removeItem("pendingMessage");
@@ -393,8 +396,6 @@ document.addEventListener("alpine:init", () => {
     },
 
     async *openaiChatCompletion(model, messages) {
-      // stream response
-      console.log("model", model)
       const response = await fetch(`${this.endpoint}/chat/completions`, {
         method: "POST",
         headers: {
@@ -417,19 +418,17 @@ document.addEventListener("alpine:init", () => {
 
       const reader = response.body.pipeThrough(new TextDecoderStream())
         .pipeThrough(new EventSourceParserStream()).getReader();
+      
       while (true) {
         const { done, value } = await reader.read();
-        if (done) {
-          break;
-        }
+        if (done) break;
+        
         if (value.type === "event") {
           const json = JSON.parse(value.data);
           if (json.choices) {
             const choice = json.choices[0];
-            if (choice.finish_reason === "stop") {
-              break;
-            }
-            yield choice.delta.content;
+            if (choice.finish_reason === "stop") break;
+            if (choice.delta.content) yield choice.delta.content;
           }
         }
       }
@@ -668,7 +667,55 @@ document.addEventListener("alpine:init", () => {
         `;
         vizElement.appendChild(nodeElement);
       });
-    }
+    },
+
+    // Add these helper methods
+    countDownloadedModels(models) {
+      return Object.values(models).filter(model => model.downloaded).length;
+    },
+
+    getGroupCounts(groupModels) {
+      const total = Object.keys(groupModels).length;
+      const downloaded = this.countDownloadedModels(groupModels);
+      return `[${downloaded}/${total}]`;
+    },
+
+    // Update the existing groupModelsByPrefix method to include counts
+    groupModelsByPrefix(models) {
+      const groups = {};
+      Object.entries(models).forEach(([key, model]) => {
+        const parts = key.split('-');
+        const mainPrefix = parts[0].toUpperCase();
+        
+        let subPrefix;
+        if (parts.length === 2) {
+          subPrefix = parts[1].toUpperCase();
+        } else if (parts.length > 2) {
+          subPrefix = parts[1].toUpperCase();
+        } else {
+          subPrefix = 'OTHER';
+        }
+        
+        if (!groups[mainPrefix]) {
+          groups[mainPrefix] = {};
+        }
+        if (!groups[mainPrefix][subPrefix]) {
+          groups[mainPrefix][subPrefix] = {};
+        }
+        groups[mainPrefix][subPrefix][key] = model;
+      });
+      return groups;
+    },
+
+    toggleGroup(prefix, subPrefix = null) {
+      const key = subPrefix ? `${prefix}-${subPrefix}` : prefix;
+      this.expandedGroups[key] = !this.expandedGroups[key];
+    },
+
+    isGroupExpanded(prefix, subPrefix = null) {
+      const key = subPrefix ? `${prefix}-${subPrefix}` : prefix;
+      return this.expandedGroups[key] || false;
+    },
   }));
 });
 

+ 2 - 2
setup.py

@@ -35,8 +35,8 @@ install_requires = [
 extras_require = {
   "formatting": ["yapf==0.40.2",],
   "apple_silicon": [
-    "mlx==0.21.1",
-    "mlx-lm==0.20.4",
+    "mlx==0.22.0",
+    "mlx-lm==0.21.1",
   ],
   "windows": ["pywin32==308",],
   "nvidia-gpu": ["nvidia-ml-py==12.560.30",],

+ 3 - 2
test/test_tokenizers.py

@@ -37,5 +37,6 @@ verbose = os.environ.get("VERBOSE", "0").lower() == "1"
 for m in models:
     # TODO: figure out why use_fast=False is giving inconsistent behaviour (no spaces decoding invididual tokens) for Mistral-Large-Instruct-2407-4bit
     # test_tokenizer(m, AutoProcessor.from_pretrained(m, use_fast=False), verbose)
-    test_tokenizer(m, AutoProcessor.from_pretrained(m, use_fast=True), verbose)
-    test_tokenizer(m, AutoTokenizer.from_pretrained(m), verbose)
+    if m not in ["mlx-community/DeepSeek-R1-4bit", "mlx-community/DeepSeek-V3-4bit"]:
+      test_tokenizer(m, AutoProcessor.from_pretrained(m, use_fast=True, trust_remote_code=True), verbose)
+    test_tokenizer(m, AutoTokenizer.from_pretrained(m, trust_remote_code=True), verbose)