Browse Source

Merge remote-tracking branch 'origin/main' into runners

Alex Cheema 7 months ago
parent
commit
2ff4638122
2 changed files with 30 additions and 0 deletions
  1. 6 0
      exo/inference/tinygrad/models/llama.py
  2. 24 0
      exo/models.py

+ 6 - 0
exo/inference/tinygrad/models/llama.py

@@ -314,6 +314,12 @@ def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_h
 
 
 
 
 def fix_bf16(weights: Dict[Any, Tensor]):
 def fix_bf16(weights: Dict[Any, Tensor]):
+  if Device.DEFAULT == "CLANG":
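+    # On the CLANG device, bfloat16 weights are converted to float32 via llvm_bf16_cast and
+    # moved back to their original device (presumably CLANG cannot cast bf16 directly -- TODO confirm).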
+    return {
+      k: (v.llvm_bf16_cast(dtypes.float32).to(v.device) if v.dtype == dtypes.bfloat16 else v) 
+      for k, v in weights.items()
+    }
   if getenv("SUPPORT_BF16", 1):
   if getenv("SUPPORT_BF16", 1):
     # TODO: without casting to float16, 70B llama OOM on tinybox.
     # TODO: without casting to float16, 70B llama OOM on tinybox.
     return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
     return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}

+ 24 - 0
exo/models.py

@@ -17,7 +17,28 @@ model_cards = {
       "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
       "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
     },
     },
   },
   },
+  "llama-3.2-1b-8bit": {
+    "layers": 16,
+    "repo": {
+      "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-8bit",
+      "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
+    },
+  },
   "llama-3.2-3b": {
   "llama-3.2-3b": {
+    "layers": 28,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
+    },
+  },
+  "llama-3.2-3b-8bit": {
+    "layers": 28,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct-8bit",
+       "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
+    },
+  },
+  "llama-3.2-3b-bf16": {
     "layers": 28,
     "layers": 28,
     "repo": {
     "repo": {
        "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct",
        "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct",
@@ -94,7 +115,10 @@ model_cards = {
 pretty_name = {
 pretty_name = {
   "llama-3.3-70b": "Llama 3.3 70B",
   "llama-3.3-70b": "Llama 3.3 70B",
   "llama-3.2-1b": "Llama 3.2 1B",
   "llama-3.2-1b": "Llama 3.2 1B",
+  "llama-3.2-1b-8bit": "Llama 3.2 1B (8-bit)",
   "llama-3.2-3b": "Llama 3.2 3B",
   "llama-3.2-3b": "Llama 3.2 3B",
+  "llama-3.2-3b-8bit": "Llama 3.2 3B (8-bit)",
+  "llama-3.2-3b-bf16": "Llama 3.2 3B (BF16)",
   "llama-3.1-8b": "Llama 3.1 8B",
   "llama-3.1-8b": "Llama 3.1 8B",
   "llama-3.1-70b": "Llama 3.1 70B",
   "llama-3.1-70b": "Llama 3.1 70B",
   "llama-3.1-70b-bf16": "Llama 3.1 70B (BF16)",
   "llama-3.1-70b-bf16": "Llama 3.1 70B (BF16)",