Browse Source

Merge branch 'main' into bench_inference_engines

Alex Cheema 5 months ago
parent
commit
8bf70379da

+ 30 - 31
.circleci/config.yml

@@ -20,6 +20,12 @@ commands:
           command: |
             source env/bin/activate
 
+            # Set CLANG=1 for tinygrad only
+            if [ "<<parameters.inference_engine>>" = "tinygrad" ]; then
+              pip install llvmlite
+              export TOKENIZERS_PARALLELISM=true SUPPORT_BF16=0 CLANG=1
+            fi
+
             # Start first instance
             HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout 900 2>&1 | tee output1.log &
             PID1=$!
@@ -48,13 +54,6 @@ commands:
             # Check processes before proceeding
             check_processes
 
-            # Special handling for dummy engine
-            if [ "<<parameters.inference_engine>>" = "dummy" ]; then
-              expected_content="This is a dummy response"
-            else
-              expected_content="Michael Jackson"
-            fi
-
             echo "Sending request to first instance..."
             response_1=$(curl -s http://localhost:8000/v1/chat/completions \
               -H "Content-Type: application/json" \
@@ -223,29 +222,29 @@ jobs:
       - checkout
       - run: system_profiler SPHardwareDataType
 
-  # chatgpt_api_integration_test_tinygrad:
-  #   macos:
-  #     xcode: "16.0.0"
-  #   resource_class: m2pro.large
-  #   steps:
-  #     - checkout
-  #     - run:
-  #         name: Set up Python
-  #         command: |
-  #           brew install python@3.12
-  #           python3.12 -m venv env
-  #           source env/bin/activate
-  #     - run:
-  #         name: Install dependencies
-  #         command: |
-  #           source env/bin/activate
-  #           pip install --upgrade pip
-  #           pip install .
-  #     - run_chatgpt_api_test:
-  #         inference_engine: tinygrad
-  #         model_id: llama-3-8b
-  #         prompt: "Keep responses concise. Who was the king of pop?"
-  #         expected_output: "Michael Jackson"
+  chatgpt_api_integration_test_tinygrad:
+    macos:
+      xcode: "16.0.0"
+    resource_class: m2pro.large
+    steps:
+      - checkout
+      - run:
+          name: Set up Python
+          command: |
+            brew install python@3.12
+            python3.12 -m venv env
+            source env/bin/activate
+      - run:
+          name: Install dependencies
+          command: |
+            source env/bin/activate
+            pip install --upgrade pip
+            pip install .
+      - run_chatgpt_api_test:
+          inference_engine: tinygrad
+          model_id: llama-3.2-1b
+          prompt: "Keep responses concise. Who was the king of pop?"
+          expected_output: "Michael Jackson"
 
 workflows:
   version: 2
@@ -254,6 +253,6 @@ workflows:
       - unit_test
       - discovery_integration_test
       - chatgpt_api_integration_test_mlx
+      - chatgpt_api_integration_test_tinygrad
       - chatgpt_api_integration_test_dummy
       - test_macos_m1
-      # - chatgpt_api_integration_test_tinygrad

+ 15 - 4
exo/inference/tinygrad/inference.py

@@ -22,7 +22,17 @@ TOP_P = 0.9
 ALPHA_F = 0.1
 ALPHA_P = 0.0
 MODEL_PARAMS = {
-  "8B": {"args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336}, "files": 1},
+  "1B": {
+    "args": {
+      "dim": 2048, "n_heads": 32, "n_kv_heads": 8, "n_layers": 16, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192,
+      "rope_scaling": {"factor": 32.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3"}, "tie_word_embeddings": True
+    }, "files": 1
+  }, "3B": {
+    "args": {
+      "dim": 3072, "n_heads": 24, "n_kv_heads": 8, "n_layers": 28, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192,
+      "rope_scaling": {"factor": 32.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3"}, "tie_word_embeddings": True
+    }, "files": 1
+  }, "8B": {"args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336}, "files": 1},
   "70B": {"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 28672}, "files": 8}
 }
 
@@ -55,7 +65,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> tuple[np.ndarray, str, bool]:
     await self.ensure_shard(shard)
     start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
     n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
@@ -72,7 +82,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
       n_captured_toks = len(toks)
       return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
 
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> tuple[np.ndarray, str, bool]:
     await self.ensure_shard(shard)
     start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
     n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
@@ -94,7 +104,8 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     model_path = await self.shard_downloader.ensure_shard(shard)
 
     if self.shard != shard:
-      self.model = await asyncio.get_event_loop().run_in_executor(self.executor, build_transformer, model_path, shard, "8B" if "8b" in shard.model_id.lower() else "70B")
+      parameters = "1B" if "1b" in shard.model_id.lower() else "3B" if "3b" in shard.model_id.lower() else "8B" if "8b" in shard.model_id.lower() else "70B"
+      self.model = await asyncio.get_event_loop().run_in_executor(self.executor, build_transformer, model_path, shard, parameters)
 
       tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
       self.tokenizer = await resolve_tokenizer(tokenizer_path)

+ 18 - 3
exo/inference/tinygrad/models/llama.py

@@ -4,8 +4,19 @@ from tinygrad.helpers import getenv
 
 
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half, rope_scaling: Optional[Dict[str, float]] = None) -> Tensor:
   freqs = 1.0/(theta**(Tensor.arange(0, dim, 2)[:(dim // 2)]/dim))
+
+  if rope_scaling:
+    factor = rope_scaling.get('factor', 1.0)
+    low_freq_factor = rope_scaling.get('low_freq_factor', 1.0)
+    high_freq_factor = rope_scaling.get('high_freq_factor', 1.0)
+    original_max_pos_emb = rope_scaling.get('original_max_position_embeddings', end)
+
+    freqs[:dim // 4] *= low_freq_factor
+    freqs[dim // 4:] = freqs[dim // 4:].contiguous()*high_freq_factor
+    freqs *= (original_max_pos_emb/end)**(1.0/factor)
+
   freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
   # TODO: move dtype outside this
   return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)
@@ -176,14 +187,18 @@ class Transformer:
     rope_theta=10000,
     max_context=1024,
     jit=True,
-    feed_forward=FeedForward
+    feed_forward=FeedForward,
+    rope_scaling: Optional[Dict[str, float]] = None,
+    tie_word_embeddings=False
   ):
     self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
     self.norm = nn.RMSNorm(dim, norm_eps)
     self.tok_embeddings = nn.Embedding(vocab_size, dim)
     self.output = nn.Linear(dim, vocab_size, bias=False)
+    if tie_word_embeddings:
+      self.output.weight = self.tok_embeddings.weight
     self.max_context = max_context
-    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta).contiguous()
+    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta, rope_scaling=rope_scaling).contiguous()
     self.forward_jit = TinyJit(self.forward) if jit else None
     self.shard = shard
 

+ 9 - 3
exo/models.py

@@ -2,8 +2,14 @@ from exo.inference.shard import Shard
 
 model_base_shards = {
   ### llama
-  "llama-3.2-1b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=16),},
-  "llama-3.2-3b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
+  "llama-3.2-1b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=16),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="unsloth/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=16),
+  },
+  "llama-3.2-3b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="unsloth/Llama-3.2-3B-Instruct", start_layer=0, end_layer=0, n_layers=28),
+  },
   "llama-3.1-8b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
     "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
@@ -17,7 +23,7 @@ model_base_shards = {
     "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
   },
   "llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),},
-  "llama-3.1-405b-8bit": {"MLXDynamicShardInferenceEngine": Shard(model_id="IntuitIntel/Meta-Llama-3.1-405B-Instruct-8bit", start_layer=0, end_layer=0, n_layers=126),},
+  "llama-3.1-405b-8bit": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", start_layer=0, end_layer=0, n_layers=126),},
   "llama-3-8b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
     "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),

+ 23 - 0
exo/tinychat/index.css

@@ -414,4 +414,27 @@ p {
   max-width: 100%;
   max-height: 100%;
   object-fit: contain;
+}
+
+.clear-history-button {
+  background-color: var(--red-color);
+  color: white;
+  padding: 10px 20px;
+  border-radius: 5px;
+  display: flex;
+  align-items: center;
+  gap: 8px;
+  transition: all 0.3s ease;
+  margin: 1rem auto;
+  border: none;
+  cursor: pointer;
+}
+
+.clear-history-button:hover {
+  opacity: 0.8;
+  transform: scale(1.05);
+}
+
+.clear-history-button i {
+  font-size: 14px;
 }

+ 7 - 0
exo/tinychat/index.html

@@ -71,6 +71,13 @@
       if (home === -1) setTimeout(() =&gt; home = 0, 100);
     " x-show="home === 0" x-transition="">
 <h1 class="title megrim-regular">tinychat</h1>
+<template x-if="histories.length">
+  <button 
+    @click="if(confirm('Are you sure you want to clear all history?')) clearAllHistory();" 
+    class="clear-history-button">
+    <i class="fas fa-trash"></i> Clear All History
+  </button>
+</template>
 <div class="histories-container-container">
 <template x-if="histories.length">
 <div class="histories-start"></div>

+ 6 - 0
exo/tinychat/index.js

@@ -47,6 +47,12 @@ document.addEventListener("alpine:init", () => {
         localStorage.setItem("histories", JSON.stringify(this.histories));
       }
     },
+
+    clearAllHistory() {
+      this.histories = [];
+      localStorage.setItem("histories", JSON.stringify([]));
+    },
+
     // Utility functions
     formatBytes(bytes) {
       if (bytes === 0) return '0 B';

+ 4 - 2
exo/topology/device_capabilities.py

@@ -52,9 +52,11 @@ CHIP_FLOPS = {
   "Apple M2 Max": DeviceFlops(fp32=13.49*TFLOPS, fp16=26.98*TFLOPS, int8=53.96*TFLOPS),
   "Apple M2 Ultra": DeviceFlops(fp32=26.98*TFLOPS, fp16=53.96*TFLOPS, int8=107.92*TFLOPS),
   "Apple M3": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
-  "Apple M3 Max": DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS),
   "Apple M3 Pro": DeviceFlops(fp32=4.97*TFLOPS, fp16=9.94*TFLOPS, int8=19.88*TFLOPS),
-  "Apple M4": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
+  "Apple M3 Max": DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS),
+  "Apple M4": DeviceFlops(fp32=4.26*TFLOPS, fp16=8.52*TFLOPS, int8=17.04*TFLOPS),
+  "Apple M4 Pro": DeviceFlops(fp32=5.72*TFLOPS, fp16=11.44*TFLOPS, int8=22.88*TFLOPS),
+  "Apple M4 Max": DeviceFlops(fp32=18.03*TFLOPS, fp16=36.07*TFLOPS, int8=72.14*TFLOPS),
   ### A chips
   "Apple A13 Bionic": DeviceFlops(fp32=0.69*TFLOPS, fp16=1.38*TFLOPS, int8=2.76*TFLOPS),
   "Apple A14 Bionic": DeviceFlops(fp32=0.75*TFLOPS, fp16=1.50*TFLOPS, int8=3.00*TFLOPS),

+ 1 - 1
test/test_tokenizers.py

@@ -24,7 +24,7 @@ def test_tokenizer(name, tokenizer, verbose=False):
     strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id]))
     assert text == strip_tokens(decoded) == strip_tokens(reconstructed)
 
-ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "IntuitIntel/*"]
+ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit"]
 ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")")
 models = [shard.model_id for shards in model_base_shards.values() for shard in shards.values() if not ignore_pattern.match(shard.model_id)]