@@ -54,7 +54,7 @@ class LlamaModel(nn.Module):
h = inputs
mask = None
- if h.shape[1] > 1:
+ if h.ndim > 1 and h.shape[1] > 1:
mask = create_attention_mask(h, cache)
if cache is None:
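
For context, a minimal sketch of what the added `h.ndim` guard changes. The shapes and the `needs_mask` helper below are illustrative only (not part of the model's real call path); the point is that the old expression `h.shape[1] > 1` raises an `IndexError` on a 1-D input, while the new condition short-circuits on `ndim` first and simply skips mask creation:

```python
import mlx.core as mx

def needs_mask(h: mx.array) -> bool:
    # Old check: h.shape[1] > 1 raises IndexError when h is 1-D.
    # New check tests ndim first, so 1-D inputs fall through safely.
    return h.ndim > 1 and h.shape[1] > 1

# Prompt prefill: (batch, seq_len, dims) with seq_len > 1 -> mask needed.
prefill = mx.zeros((1, 8, 4096))
print(needs_mask(prefill))  # True

# Single-token decode step: seq_len == 1 -> no mask.
decode = mx.zeros((1, 1, 4096))
print(needs_mask(decode))   # False

# 1-D input: the old expression would have raised IndexError here.
flat = mx.zeros((8,))
print(needs_mask(flat))     # False
```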