
clean debug logs

Alex Cheema 1 year ago
commit bcd58938de

exo/inference/mlx/models/sharded_llama.py  +4 -1

@@ -90,6 +90,8 @@ class Attention(nn.Module):
     ) -> mx.array:
         B, L, D = x.shape

+        print("q_proj: ", self.q_proj)
+        print("x: ", x.shape)
         queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

         # Prepare the queries, keys and values for the attention computation
@@ -190,7 +192,8 @@ class LlamaModel(nn.Module):
         if cache is None:
             cache = [None] * len(self.layers)

-        for layer, c in zip(self.layers, cache):
+        for i, (layer, c) in enumerate(zip(self.layers, cache)):
+            print(f"layer: {i}")
             h = layer(h, mask, cache=c)

         if self.args.shard.is_last_layer():

exo/inference/mlx/test_sharded_llama.py  +3 -2

@@ -3,6 +3,7 @@ from exo.inference.mlx.sharded_model import StatefulShardedModel
 from exo.inference.mlx.sharded_utils import load_shard
 from exo.inference.shard import Shard

+# 79, 80 for Llama-3-70B
 shard_full = Shard("llama", 0, 31, 32)
 shard1 = Shard("llama", 0, 12, 32)
 shard2 = Shard("llama", 13, 31, 32)
@@ -16,7 +17,7 @@ m1 = StatefulShardedModel(shard1, model_shard1)
 m2 = StatefulShardedModel(shard2, model_shard2)

 prompt = "write a beautiful haiku about a utopia where people own their AI with edge intelligence:"
-prompt_tokens = mx.array(tokenizer1.encode(prompt))
+prompt_tokens = mx.array(full_tokenizer.encode(prompt))
 max_tokens = 50

 resp = prompt_tokens
@@ -25,7 +26,7 @@ for _ in range(max_tokens):
     resp = full.step(resp)
     full_generated_tokens.append(resp.item())

-print("full response: ", tokenizer1.decode(full_generated_tokens))
+print("full response: ", full_tokenizer.decode(full_generated_tokens))


 sharded_generated_tokens = []

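The new `# 79, 80 for Llama-3-70B` comment above hints at the end_layer/n_layers values the test would use for an 80-layer model. As a rough illustration only, a split for that case might look like the sketch below, reusing the `Shard(model_id, start_layer, end_layer, n_layers)` signature already used in the test; the even 40/40 split point is an assumption, not a value from this commit.

from exo.inference.shard import Shard

# Hypothetical 80-layer layout suggested by the "# 79, 80 for Llama-3-70B" comment:
# the full model covers layers 0..79 of 80; an even two-way split is one possible choice.
shard_full_70b = Shard("llama", 0, 79, 80)
shard1_70b = Shard("llama", 0, 39, 80)   # first half of the layers
shard2_70b = Shard("llama", 40, 79, 80)  # second half of the layers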
exo/orchestration/standard_node.py  +18 -17

@@ -7,6 +7,7 @@ from exo.topology.topology import Topology
 from exo.topology.device_capabilities import device_capabilities
 from exo.topology.partitioning_strategy import PartitioningStrategy
 from exo.topology.partitioning_strategy import Partition
+from exo import DEBUG
 import asyncio
 import uuid

@@ -29,7 +30,7 @@ class StandardNode(Node):
         await self.discovery.start()
         await self.update_peers(wait_for_peers)
         await self.collect_topology()
-        print(f"Collected topology: {self.topology}")
+        if DEBUG >= 2: print(f"Collected topology: {self.topology}")
         asyncio.create_task(self.periodic_topology_collection(5))

     async def stop(self) -> None:
@@ -42,7 +43,7 @@ class StandardNode(Node):
         if request_id not in self.buffered_token_output:
             self.buffered_token_output[request_id] = ([], False)

-        print(f"[{request_id}] process prompt: {shard}, {prompt}")
+        if DEBUG >= 2: print(f"[{request_id}] process prompt: {shard}, {prompt}")
         result, is_finished = await self.inference_engine.infer_prompt(self.get_current_shard(shard), prompt)
         is_finished = is_finished or len(self.buffered_token_output[request_id]) >= self.max_generate_tokens
         if is_finished:
@@ -52,7 +53,7 @@ class StandardNode(Node):
             self.buffered_token_output[request_id][0].append(result.item())
             self.on_token(self.buffered_token_output[request_id][0])

-        print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")
+        if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")

         if not is_finished:
             asyncio.create_task(self.forward_tensor_to_next_shard(shard, result, request_id))
@@ -66,7 +67,7 @@ class StandardNode(Node):
             self.buffered_token_output[request_id] = ([], False)

         try:
-            print(f"[{request_id}] process_tensor: {shard}, {tensor}")
+            if DEBUG >= 2: print(f"[{request_id}] process_tensor: {shard}, {tensor}")
             result, is_finished = await self.inference_engine.infer_tensor(self.get_current_shard(shard), tensor)
             is_finished = is_finished or len(self.buffered_token_output[request_id]) >= self.max_generate_tokens
             if is_finished:
@@ -75,7 +76,7 @@ class StandardNode(Node):
             if result.size == 1:  # we got a new token out
                 self.buffered_token_output[request_id][0].append(result.item())
                 self.on_token(self.buffered_token_output[request_id][0])
-            print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")
+            if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")

             if not is_finished:
                 asyncio.create_task(self.forward_tensor_to_next_shard(shard, result, request_id))
@@ -89,16 +90,16 @@ class StandardNode(Node):

     async def forward_tensor_to_next_shard(self, shard: Shard, tensor: np.ndarray, request_id: str) -> None:
         if not self.partitioning_strategy:
-            print("No partitioning strategy found. Skipping forward.")
+            if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
             return

         partitions = self.partitioning_strategy.partition(self.topology)
         current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
-        print(f"Current partition index: {current_partition_index}")
+        if DEBUG >= 2: print(f"Current partition index: {current_partition_index}")
         if current_partition_index is not None:
             next_partition_index = (current_partition_index + 1) % len(partitions)
             next_partition: Partition = partitions[next_partition_index]
-            print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}")
+            if DEBUG >= 2: print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}")

             if next_partition:
                 if next_partition.node_id == self.id:
@@ -113,7 +114,7 @@ class StandardNode(Node):
                 end_layer = int(next_partition.end * shard.n_layers) - 1
                 next_shard = Shard(shard.model_id, start_layer, end_layer, shard.n_layers)

-                print(f"Sending tensor to {target_peer.id()} for shard: {next_shard}: {tensor}")
+                if DEBUG >= 2: print(f"Sending tensor to {target_peer.id()} for shard: {next_shard}: {tensor}")

                 await target_peer.send_tensor(next_shard, tensor, request_id)

@@ -131,20 +132,20 @@ class StandardNode(Node):

     async def reset_shard(self, shard: Shard) -> None:
         # Implement shard reset logic
-        print(f"Resetting shard: {shard}")
+        if DEBUG >= 2: print(f"Resetting shard: {shard}")
         self.buffered_token_output = {}
         await self.inference_engine.reset_shard(self.get_current_shard(shard))

     async def update_peers(self, wait_for_peers: int = 0) -> None:
         self.peers = await self.discovery.discover_peers(wait_for_peers)
-        print(f"Starting with the following peers: {self.peers}")
-        print("Connecting to new peers...")
+        if DEBUG >= 2: print(f"Starting with the following peers: {self.peers}")
+        if DEBUG >= 2: print("Connecting to new peers...")
         for peer in self.peers:
             is_connected = await peer.is_connected()
-            print(f"Connected to {peer.id()}: {is_connected}")
+            if DEBUG >= 2: print(f"Connected to {peer.id()}: {is_connected}")
             if not is_connected:
                 await peer.connect()
-                print(f"Connected to peer {peer.id()}")
+                if DEBUG >= 2: print(f"Connected to peer {peer.id()}")

     async def collect_topology(self, max_depth: int = 4) -> Topology:
         self.topology.update_node(self.id, self.device_capabilities)
@@ -156,7 +157,7 @@ class StandardNode(Node):
             if max_depth > 0:
                 try:
                     other_topology = await peer.collect_topology(max_depth = max_depth - 1)
-                    print(f"Collected topology from: {peer.id()}: {other_topology}")
+                    if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}")
                     self.topology.merge(other_topology)
                 except Exception as e:
                     print(f"Error collecting topology from {peer.id()}: {e}")
@@ -172,8 +173,8 @@ class StandardNode(Node):
             except Exception as e:
                 print(f"Error collecting topology: {e}")

-            print("Topology collection task executed.")
-            print(f"Current topology: {self.topology}")
+            if DEBUG >= 2: print("Topology collection task executed.")
+            if DEBUG >= 2: print(f"Current topology: {self.topology}")

     async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
         if request_id not in self.buffered_token_output:
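
For reference, the `DEBUG` value imported above via `from exo import DEBUG` is used here as an integer verbosity level (>= 1 for warnings worth surfacing, >= 2 for per-request tracing). Its definition is not part of this diff; a minimal sketch, assuming the common convention of reading it from a `DEBUG` environment variable, would be:

import os

# Assumed definition (not shown in this commit): an integer verbosity level,
# 0 by default, so the gated prints above stay silent unless requested.
DEBUG = int(os.environ.get("DEBUG", "0"))

if DEBUG >= 2:
    print("verbose tracing enabled")

Under that assumption, running with something like `DEBUG=2` set in the environment restores the verbose output this commit turns off by default.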