소스 검색

ndim check in llama

Alex Cheema 10 달 전
부모
커밋
cb575f5dc3
1개의 변경된 파일1개의 추가작업 그리고 1개의 파일을 삭제
  1. 1 1
      exo/inference/mlx/models/llama.py

+ 1 - 1
exo/inference/mlx/models/llama.py

@@ -54,7 +54,7 @@ class LlamaModel(nn.Module):
       h = inputs
 
     mask = None
-    if h.shape[1] > 1:
+    if h.ndim > 1 and h.shape[1] > 1:
       mask = create_attention_mask(h, cache)
 
     if cache is None: