فهرست منبع

ndim check in llama

Alex Cheema 7 ماه پیش
والد
کامیت
cb575f5dc3
1فایلهای تغییر یافته به همراه1 افزوده شده و 1 حذف شده
  1. 1 1
      exo/inference/mlx/models/llama.py

+ 1 - 1
exo/inference/mlx/models/llama.py

@@ -54,7 +54,7 @@ class LlamaModel(nn.Module):
       h = inputs
 
     mask = None
-    if h.shape[1] > 1:
+    if h.ndim > 1 and h.shape[1] > 1:
       mask = create_attention_mask(h, cache)
 
     if cache is None: