|
@@ -314,6 +314,12 @@ def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_h
|
|
|
|
|
|
|
|
|
def fix_bf16(weights: Dict[Any, Tensor]):
  """Rewrite bfloat16 entries of *weights* into a dtype the active backend can run.

  Tensors whose dtype is not bfloat16 are passed through untouched. On the
  CLANG backend, bfloat16 tensors go through ``llvm_bf16_cast`` to float32 and
  are then moved back to their original device. Otherwise, when SUPPORT_BF16
  is enabled (default 1), bfloat16 tensors are cast to float16. If neither
  branch applies, the function falls through and returns None.

  NOTE(review): the CLANG branch carried a comment about casting to float16,
  but the code casts to float32 — presumably the comment was copy-pasted from
  the branch below; confirm which dtype is intended.
  """
  if Device.DEFAULT == "CLANG":
    converted = {}
    for name, tensor in weights.items():
      if tensor.dtype == dtypes.bfloat16:
        # route through the LLVM bf16 cast, then restore the tensor's device
        converted[name] = tensor.llvm_bf16_cast(dtypes.float32).to(tensor.device)
      else:
        converted[name] = tensor
    return converted
  if getenv("SUPPORT_BF16", 1):
    # TODO: without casting to float16, 70B llama OOM on tinybox.
    shrunk = {}
    for name, tensor in weights.items():
      shrunk[name] = tensor.cast(dtypes.float16) if tensor.dtype == dtypes.bfloat16 else tensor
    return shrunk
|