Browse Source

Applied patch idea from https://github.com/exo-explore/exo/issues/458

Will Bickford 4 months ago
parent
commit
89815b16b2
1 changed file with 6 additions and 0 deletions
  1. 6 0
      exo/inference/tinygrad/models/llama.py

+ 6 - 0
exo/inference/tinygrad/models/llama.py

@@ -275,6 +275,12 @@ def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_h
 
 
 
 
 def fix_bf16(weights: Dict[Any, Tensor]):
+  if Device.DEFAULT == "CLANG":
+    # TODO: without casting to float16, 70B llama OOM on tinybox.
+    return {
+      k: (v.llvm_bf16_cast(dtypes.float32).to(v.device) if v.dtype == dtypes.bfloat16 else v)
+      for k, v in weights.items()
+    }
   if getenv("SUPPORT_BF16", 1):
     # TODO: without casting to float16, 70B llama OOM on tinybox.
     return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}