Sfoglia il codice sorgente

ndim check in llama

Alex Cheema 7 mesi fa
parent
commit
cb575f5dc3
1 ha cambiato i file con 1 aggiunte e 1 eliminazioni
  1. 1 1
      exo/inference/mlx/models/llama.py

+ 1 - 1
exo/inference/mlx/models/llama.py

@@ -54,7 +54,7 @@ class LlamaModel(nn.Module):
       h = inputs
 
     mask = None
-    if h.shape[1] > 1:
+    if h.ndim > 1 and h.shape[1] > 1:
       mask = create_attention_mask(h, cache)
 
     if cache is None: