Alex Cheema пре 10 месеци
родитељ
комит
cb575f5dc3
1 измењених фајлова са 1 додато и 1 уклоњено
  1. 1 1
      exo/inference/mlx/models/llama.py

+ 1 - 1
exo/inference/mlx/models/llama.py

@@ -54,7 +54,7 @@ class LlamaModel(nn.Module):
       h = inputs
 
     mask = None
-    if h.shape[1] > 1:
+    if h.ndim > 1 and h.shape[1] > 1:
       mask = create_attention_mask(h, cache)
 
     if cache is None: