Browse Source

remove unneeded prints

Alex Cheema 1 year ago
parent
commit
54e8cad2d6
1 changed file with 1 addition and 4 deletions
  1. 1 4
      exo/inference/mlx/models/sharded_llama.py

+ 1 - 4
exo/inference/mlx/models/sharded_llama.py

@@ -90,8 +90,6 @@ class Attention(nn.Module):
     ) -> mx.array:
         B, L, D = x.shape
 
-        print("q_proj: ", self.q_proj)
-        print("x: ", x.shape)
         queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
 
         # Prepare the queries, keys and values for the attention computation
@@ -192,8 +190,7 @@ class LlamaModel(nn.Module):
         if cache is None:
             cache = [None] * len(self.layers)
 
-        for i, (layer, c) in enumerate(zip(self.layers, cache)):
-            print(f"layer: {i}")
+        for layer, c in zip(self.layers, cache):
             h = layer(h, mask, cache=c)
 
         if self.args.shard.is_last_layer():