Browse Source

Merge pull request #542 from wbic16/fix-issue-458

Support CPU-Only CLANG
Alex Cheema 4 months ago
parent
commit
2f74ea112e
1 changed file with 6 additions and 0 deletions
  1. 6 0
      exo/inference/tinygrad/models/llama.py

+ 6 - 0
exo/inference/tinygrad/models/llama.py

@@ -314,6 +314,12 @@ def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_h
 
 
 def fix_bf16(weights: Dict[Any, Tensor]):
+  if Device.DEFAULT == "CLANG":
+    # CLANG (CPU-only) backend: bfloat16 weights are converted to float32 via llvm_bf16_cast (not float16 as in the SUPPORT_BF16 path below).
+    return {
+      k: (v.llvm_bf16_cast(dtypes.float32).to(v.device) if v.dtype == dtypes.bfloat16 else v) 
+      for k, v in weights.items()
+    }
   if getenv("SUPPORT_BF16", 1):
     # TODO: without casting to float16, 70B llama OOM on tinybox.
     return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}