add web chat URL and ChatGPT API endpoint to the panel (fixes #43); fix a rounding error in the partition-to-shard mapping implementation
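The rounding error, in brief: the old mapping in get_current_shard truncated both partition boundaries with int(), so a final partition ending just below 1.0 could leave the last layer unassigned. A minimal sketch of the failure mode, using the boundary values from the new test in exo/topology/test_map_partitions.py (variable names here are illustrative only):

    num_layers = 32
    last_partition_end = 0.99999                      # final boundary after rounding to 5 decimals
    old_end_layer = int(last_partition_end * num_layers) - 1
    print(old_end_layer)                              # 30 -- layer 31 was silently dropped
    # The new map_partitions_to_shards clamps the final shard to num_layers - 1,
    # so all 32 layers are covered.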

Alex Cheema committed 1 year ago
parent commit a342e1abd8

+ 24 - 26
exo/orchestration/standard_node.py

@@ -3,20 +3,19 @@ import json
 import asyncio
 import uuid
 import time
-from typing import List, Dict, Optional, Callable, Tuple, Union
+from typing import List, Dict, Optional, Tuple, Union
 from exo.networking import Discovery, PeerHandle, Server
 from exo.inference.inference_engine import InferenceEngine, Shard
 from .node import Node
 from exo.topology.topology import Topology
 from exo.topology.device_capabilities import device_capabilities
-from exo.topology.partitioning_strategy import PartitioningStrategy
-from exo.topology.partitioning_strategy import Partition
+from exo.topology.partitioning_strategy import Partition, PartitioningStrategy, map_partitions_to_shards
 from exo import DEBUG
-from exo.helpers import AsyncCallback, AsyncCallbackSystem
+from exo.helpers import AsyncCallbackSystem
 from exo.viz.topology_viz import TopologyViz
 
 class StandardNode(Node):
-    def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, max_generate_tokens: int = 256):
+    def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, max_generate_tokens: int = 256, chatgpt_api_endpoint: Optional[str] = None, web_chat_url: Optional[str] = None):
         self.id = id
         self.inference_engine = inference_engine
         self.server = server
@@ -26,7 +25,7 @@ class StandardNode(Node):
         self.topology: Topology = Topology()
         self.device_capabilities = device_capabilities()
         self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
-        self.topology_viz = TopologyViz()
+        self.topology_viz = TopologyViz(chatgpt_api_endpoint=chatgpt_api_endpoint, web_chat_url=web_chat_url)
         self.max_generate_tokens = max_generate_tokens
         self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
         self._on_opaque_status = AsyncCallbackSystem[str, str]()
@@ -57,28 +56,29 @@ class StandardNode(Node):
         await self.discovery.stop()
         await self.server.stop()
 
-    async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+    async def process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
         asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "start_process_prompt", "shard": shard.to_dict(), "prompt": prompt, "inference_state": inference_state, "request_id": request_id})))
         start_time = time.perf_counter_ns()
-        resp = await self._process_prompt(shard, prompt, request_id, inference_state)
+        resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
         end_time = time.perf_counter_ns()
         elapsed_time_ns = end_time - start_time
         asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "end_process_prompt", "shard": shard.to_dict(), "prompt": prompt, "inference_state": inference_state, "request_id": request_id, "elapsed_time_ns": elapsed_time_ns, "result_size": resp.size if resp is not None else 0})))
         return resp
 
-    async def _process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+    async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
         if request_id is None:
             request_id = str(uuid.uuid4())
         if request_id not in self.buffered_token_output:
             self.buffered_token_output[request_id] = ([], False)
+        shard = self.get_current_shard(base_shard)
 
-        if DEBUG >= 2: print(f"[{request_id}] process prompt: {shard=} {prompt=}")
-        if self.get_current_shard(shard).start_layer != 0:
-            if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {shard=} {prompt=}")
+        if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
+        if shard.start_layer != 0:
+            if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
             await self.forward_to_next_shard(shard, prompt, request_id)
             return
 
-        result, inference_state, is_finished = await self.inference_engine.infer_prompt(self.get_current_shard(shard), prompt, inference_state=inference_state)
+        result, inference_state, is_finished = await self.inference_engine.infer_prompt(shard, prompt, inference_state=inference_state)
         is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
         if is_finished:
             self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
@@ -104,15 +104,16 @@ class StandardNode(Node):
         asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "end_process_tensor", "shard": shard.to_dict(), "request_id": request_id, "elapsed_time_ns": elapsed_time_ns, "result_size": resp.size if resp is not None else 0})))
         return resp
 
-    async def _process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+    async def _process_tensor(self, base_shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
         if request_id is None:
             request_id = str(uuid.uuid4())
         if request_id not in self.buffered_token_output:
             self.buffered_token_output[request_id] = ([], False)
+        shard = self.get_current_shard(base_shard)
 
         try:
             if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
-            result, inference_state, is_finished = await self.inference_engine.infer_tensor(self.get_current_shard(shard), tensor, inference_state=inference_state)
+            result, inference_state, is_finished = await self.inference_engine.infer_tensor(shard, tensor, inference_state=inference_state)
             is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
             if is_finished:
                 self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
@@ -169,23 +170,19 @@ class StandardNode(Node):
                 else:
                     await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id)
 
-    def get_current_shard(self, shard: Shard) -> Shard:
+    def get_current_shard(self, base_shard: Shard) -> Shard:
         partitions = self.partitioning_strategy.partition(self.topology)
+        shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
         current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
         if current_partition_index is None:
             raise ValueError(f"No current partition found for node: {self.id}")
+        return shards[current_partition_index]
 
-        current_partition = partitions[current_partition_index]
-        start_layer = int(current_partition.start * shard.n_layers)
-        end_layer = int(current_partition.end * shard.n_layers) - 1
-        return Shard(shard.model_id, start_layer, end_layer, shard.n_layers)
-
-
-    async def reset_shard(self, shard: Shard) -> None:
+    async def reset_shard(self, base_shard: Shard) -> None:
         # Implement shard reset logic
-        if DEBUG >= 2: print(f"Resetting shard: {shard}")
+        if DEBUG >= 2: print(f"Resetting shard: {base_shard}")
         self.buffered_token_output = {}
-        await self.inference_engine.reset_shard(self.get_current_shard(shard))
+        await self.inference_engine.reset_shard(self.get_current_shard(base_shard))
 
     async def update_peers(self, wait_for_peers: int = 0) -> None:
         self.peers = await self.discovery.discover_peers(wait_for_peers)
@@ -250,7 +247,8 @@ class StandardNode(Node):
 
     # TODO: unify this and collect_topology as global actions
     async def global_reset(self, base_shard: Shard, visited: set[str] = set(), max_depth: int = 2) -> None:
-        await self.reset_shard(self.get_current_shard(base_shard))
+        shard = self.get_current_shard(base_shard)
+        await self.reset_shard(shard)
 
         if DEBUG >= 2: print(f"Global reset {base_shard=} {max_depth=} {visited=}")
 

+ 22 - 1
exo/topology/partitioning_strategy.py

@@ -1,7 +1,8 @@
 from abc import ABC, abstractmethod
-from typing import List
+from typing import List, Tuple
 from dataclasses import dataclass
 from .topology import Topology
+from exo.inference.shard import Shard
 
 # Partitions shard-space into pieces of contiguous shards, represented by floating point range [start, end) between 0 and 1
 @dataclass
@@ -14,3 +15,23 @@ class PartitioningStrategy(ABC):
     @abstractmethod
     def partition(self, topology: Topology) -> List[Partition]:
         pass
+
+def map_partitions_to_shards(partitions: List[Partition], num_layers: int, model_id: str) -> List[Shard]:
+    shards = []
+    for i, partition in enumerate(partitions):
+        start_layer = int(partition.start * num_layers)
+        end_layer = int(partition.end * num_layers) - 1
+
+        # Ensure the last partition covers up to num_layers - 1
+        if i == len(partitions) - 1:
+            end_layer = num_layers - 1
+
+        # Ensure no empty shards
+        if start_layer <= end_layer:
+            shards.append(Shard(model_id, start_layer, end_layer, num_layers))
+
+    # Ensure full coverage
+    if shards and shards[-1].end_layer < num_layers - 1:
+        shards[-1] = Shard(model_id, shards[-1].start_layer, num_layers - 1, num_layers)
+
+    return shards

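For reference, a quick usage sketch of the new map_partitions_to_shards helper (the values mirror the first case in the test file below; 'model' is a placeholder model_id):

    from exo.topology.partitioning_strategy import Partition, map_partitions_to_shards

    partitions = [
        Partition('node1', 0.0, 0.42857),
        Partition('node2', 0.42857, 0.71428),
        Partition('node3', 0.71428, 0.99999),
    ]
    shards = map_partitions_to_shards(partitions, 32, 'model')
    # -> [Shard('model', 0, 12, 32), Shard('model', 13, 21, 32), Shard('model', 22, 31, 32)]
    # The last shard now ends at layer 31; the old truncation-only mapping gave 30.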
+ 68 - 0
exo/topology/test_map_partitions.py

@@ -0,0 +1,68 @@
+import unittest
+from typing import List
+from exo.topology.partitioning_strategy import Partition, map_partitions_to_shards
+from exo.inference.shard import Shard
+
+class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
+    def test_map_partitions_to_shards(self):
+        partitions = [
+            Partition('node1', 0.0, 0.42857),
+            Partition('node2', 0.42857, 0.71428),
+            Partition('node3', 0.71428, 0.99999),
+        ]
+        shards = map_partitions_to_shards(partitions, 32, 'model')
+        self.assertEqual(shards, [
+            Shard('model', 0, 12, 32),
+            Shard('model', 13, 21, 32),
+            Shard('model', 22, 31, 32),
+        ])
+
+        partitions = [
+            Partition('node1', 0.0, 0.1),
+            Partition('node2', 0.1, 0.2),
+            Partition('node3', 0.2, 1.0),
+        ]
+        shards = map_partitions_to_shards(partitions, 32, 'model')
+        self.assertEqual(shards, [
+            Shard('model', 0, 2, 32),
+            Shard('model', 3, 5, 32),
+            Shard('model', 6, 31, 32),
+        ])
+
+        partitions = [
+            Partition('node1', 0.0, 1.0),
+        ]
+        shards = map_partitions_to_shards(partitions, 32, 'model')
+        self.assertEqual(shards, [
+            Shard('model', 0, 31, 32),
+        ])
+
+        partitions = []
+        shards = map_partitions_to_shards(partitions, 32, 'model')
+        self.assertEqual(shards, [])
+
+    def test_broken_map_partitions_to_shards(self):
+        # this was an old broken implementation that sometimes had rounding errors!
+        def _broken_map_partitions_to_shards(partitions: List[Partition], num_layers, model_id: str):
+            shards = []
+            for i, partition in enumerate(partitions):
+                start_layer = int(partition.start * num_layers)
+                end_layer = int(partition.end * num_layers) - 1
+                shards.append(Shard(model_id, start_layer, end_layer, num_layers))
+            return shards
+
+        partitions = [
+            Partition('node1', 0.0, 0.42857),
+            Partition('node2', 0.42857, 0.71428),
+            Partition('node3', 0.71428, 0.99999),
+        ]
+        shards = _broken_map_partitions_to_shards(partitions, 32, 'model')
+        self.assertEqual(shards, [
+            Shard('model', 0, 12, 32),
+            Shard('model', 13, 21, 32),
+            Shard('model', 22, 30, 32),
+        ])
+
+if __name__ == '__main__':
+    unittest.main()
+

+ 22 - 3
exo/topology/test_ring_memory_weighted_partitioning_strategy.py

@@ -1,7 +1,8 @@
 import unittest
-from .ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
-from .topology import Topology, DeviceCapabilities, DeviceFlops
-from .partitioning_strategy import Partition
+from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
+from exo.topology.topology import Topology
+from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
+from exo.topology.partitioning_strategy import Partition
 
 class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
     def test_partition(self):
@@ -26,5 +27,23 @@ class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
             Partition('node2', 0.9, 1.0),
         ])
 
+    def test_partition_rounding(self):
+        # triangle
+        # node1 -> node2 -> node3 -> node1
+        topology = Topology()
+        topology.update_node('node1', DeviceCapabilities(model="MacBook Pro", chip="test1", memory=128*1024*1024*1024, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))
+        topology.update_node('node2', DeviceCapabilities(model="Mac Studio", chip="test2", memory=192*1024*1024*1024, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))
+        topology.update_node('node3', DeviceCapabilities(model="MacBook Pro", chip="test3", memory=128*1024*1024*1024, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))
+
+        strategy = RingMemoryWeightedPartitioningStrategy()
+        partitions = strategy.partition(topology)
+
+        self.assertEqual(len(partitions), 3)
+        self.assertEqual(partitions, [
+            Partition('node3', 0.0, 0.42857),
+            Partition('node1', 0.6, 0.9),
+            Partition('node2', 0.9, 1.0),
+        ])
+
 if __name__ == '__main__':
     unittest.main()

+ 18 - 2
exo/viz/topology_viz.py

@@ -12,10 +12,13 @@ from rich.style import Style
 from exo.topology.device_capabilities import DeviceCapabilities, UNKNOWN_DEVICE_CAPABILITIES
 
 class TopologyViz:
-    def __init__(self):
-        self.console = Console()
+    def __init__(self, chatgpt_api_endpoint: str = None, web_chat_url: str = None):
+        self.chatgpt_api_endpoint = chatgpt_api_endpoint
+        self.web_chat_url = web_chat_url
         self.topology = Topology()
         self.partitions: List[Partition] = []
+
+        self.console = Console()
         self.panel = Panel(self._generate_layout(), title=f"Exo Cluster (0 nodes)", border_style="bright_yellow")
         self.live_panel = Live(self.panel, auto_refresh=False, console=self.console)
         self.live_panel.start()
@@ -53,6 +56,19 @@ class TopologyViz:
                 if 0 <= start_x + j < 90 and i < len(visualization):  # Ensure we don't exceed the width and height
                     visualization[i][start_x + j] = char
 
+        # Display chatgpt_api_endpoint and web_chat_url if set
+        info_lines = []
+        if self.web_chat_url:
+            info_lines.append(f"Web Chat URL (tinychat): {self.web_chat_url}")
+        if self.chatgpt_api_endpoint:
+            info_lines.append(f"ChatGPT API endpoint: {self.chatgpt_api_endpoint}")
+
+        for i, line in enumerate(info_lines):
+            start_x = 0
+            for j, char in enumerate(line):
+                if j < 90 and i + len(exo_lines) < 45:  # Ensure we don't exceed the width and height
+                    visualization[i + len(exo_lines)][j] = char
+
         for i, partition in enumerate(self.partitions):
             device_capabilities = self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES)
 

+ 1 - 4
main.py

@@ -5,10 +5,8 @@ import uuid
 import platform
 import psutil
 import os
-import json
 from typing import List
 from exo.orchestration.standard_node import StandardNode
-from exo.viz.topology_viz import TopologyViz
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
@@ -56,12 +54,11 @@ if args.node_port is None:
     args.node_port = find_available_port(args.node_host)
     if DEBUG >= 1: print(f"Using available port: {args.node_port}")
 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
-node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy())
+node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions", web_chat_url=f"http://localhost:{args.chatgpt_api_port}")
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
 api = ChatGPTAPI(node, inference_engine.__class__.__name__)
 
-topology_viz = TopologyViz()
 node.on_token.register("main_log").on_next(lambda _, tokens , __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
 
 async def shutdown(signal, loop):