浏览代码

code-breaking typo

oops
divinity76 3 月之前
父节点
当前提交
5fe241ec61
共有 1 个文件被更改,包括 1 次插入1 次删除
  1. 1 1
      exo/inference/tinygrad/models/llama.py

+ 1 - 1
exo/inference/tinygrad/models/llama.py

@@ -322,6 +322,6 @@ def fix_bf16(weights: Dict[Any, Tensor]):
     }
   if getenv("SUPPORT_BF16", 1):
     # TODO: without casting to float16, 70B llama OOM on tinybox.
-    return {k: v.cast(dtypes.float32).cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()
+    return {k: v.cast(dtypes.float32).cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
   # TODO: check if device supports bf16
   return {k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}