@@ -90,8 +90,6 @@ class Attention(nn.Module):
     ) -> mx.array:
         B, L, D = x.shape
 
-        print("q_proj: ", self.q_proj)
-        print("x: ", x.shape)
         queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
 
         # Prepare the queries, keys and values for the attention computation
@@ -192,8 +190,7 @@ class LlamaModel(nn.Module):
         if cache is None:
             cache = [None] * len(self.layers)
 
-        for i, (layer, c) in enumerate(zip(self.layers, cache)):
-            print(f"layer: {i}")
+        for layer, c in zip(self.layers, cache):
             h = layer(h, mask, cache=c)
 
         if self.args.shard.is_last_layer():