Quellcode durchsuchen

more robust grpc discovery with asyncio and proper error handling, add flops to device capabilities. fixes #23 and progress on #33

Alex Cheema vor 9 Monaten
Ursprung
Commit
54c98607ef

+ 2 - 2
examples/llama3_distributed.py

@@ -6,7 +6,7 @@ from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
 from exo.inference.shard import Shard
 from exo.networking.peer_handle import PeerHandle
 from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
-from exo.topology.device_capabilities import DeviceCapabilities
+from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from typing import List
 import asyncio
 import argparse
@@ -32,7 +32,7 @@ tokenizer = load_tokenizer(model_path, tokenizer_config)
 peer2 = GRPCPeerHandle(
     "node2",
     "localhost:8081",
-    DeviceCapabilities(model="placeholder", chip="placeholder", memory=0)
+    DeviceCapabilities(model="placeholder", chip="placeholder", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0))
 )
 shard = models[path_or_hf_repo]
 request_id = str(uuid.uuid4())

+ 50 - 41
exo/networking/grpc/grpc_discovery.py

@@ -2,15 +2,28 @@ import asyncio
 import json
 import socket
 import time
-from typing import List, Dict
+from typing import List, Dict, Callable, Tuple, Coroutine
 from ..discovery import Discovery
 from ..peer_handle import PeerHandle
 from .grpc_peer_handle import GRPCPeerHandle
-from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities
+from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
 from exo import DEBUG_DISCOVERY
 
+class ListenProtocol(asyncio.DatagramProtocol):
+    def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
+        super().__init__()
+        self.on_message = on_message
+        self.loop = asyncio.get_event_loop()
+
+    def connection_made(self, transport):
+        self.transport = transport
+
+    def datagram_received(self, data, addr):
+        asyncio.create_task(self.on_message(data, addr))
+
+
 class GRPCDiscovery(Discovery):
-    def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1, device_capabilities=None):
+    def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1, device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES):
         self.node_id = node_id
         self.node_port = node_port
         self.device_capabilities = device_capabilities
@@ -24,9 +37,10 @@ class GRPCDiscovery(Discovery):
         self.cleanup_task = None
 
     async def start(self):
-        self.broadcast_task = asyncio.create_task(self._broadcast_presence())
-        self.listen_task = asyncio.create_task(self._listen_for_peers())
-        self.cleanup_task = asyncio.create_task(self._cleanup_peers())
+        self.device_capabilities = device_capabilities()
+        self.broadcast_task = asyncio.create_task(self.task_broadcast_presence())
+        self.listen_task = asyncio.create_task(self.task_listen_for_peers())
+        self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
 
     async def stop(self):
         if self.broadcast_task:
@@ -62,54 +76,49 @@ class GRPCDiscovery(Discovery):
 
         return list(self.known_peers.values())
 
-    async def _broadcast_presence(self):
-        if not self.device_capabilities:
-            self.device_capabilities = device_capabilities()
-
-        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
+    async def task_broadcast_presence(self):
+        transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
+                    lambda: asyncio.DatagramProtocol(),
+                    local_addr=('0.0.0.0', 0),
+                    family=socket.AF_INET)
+        sock = transport.get_extra_info('socket')
         sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
-        sock.settimeout(0.5)
+
         message = json.dumps({
             "type": "discovery",
             "node_id": self.node_id,
             "grpc_port": self.node_port,
-            "device_capabilities": {
-                "model": self.device_capabilities.model,
-                "chip": self.device_capabilities.chip,
-                "memory": self.device_capabilities.memory
-            }
+            "device_capabilities": self.device_capabilities.to_dict()
         }).encode('utf-8')
 
-        while True:
-            sock.sendto(message, ('<broadcast>', self.broadcast_port))
-            await asyncio.sleep(self.broadcast_interval)
-
-    async def _listen_for_peers(self):
-        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-        sock.bind(('', self.listen_port))
-        sock.setblocking(False)
-
         while True:
             try:
-                data, addr = await asyncio.get_event_loop().sock_recvfrom(sock, 1024)
-                message = json.loads(data.decode('utf-8'))
-                if DEBUG_DISCOVERY >= 2: print(f"received from peer {addr}: {message}")
-                if message['type'] == 'discovery' and message['node_id'] != self.node_id:
-                    peer_id = message['node_id']
-                    peer_host = addr[0]
-                    peer_port = message['grpc_port']
-                    device_capabilities = DeviceCapabilities(**message['device_capabilities'])
-                    if peer_id not in self.known_peers:
-                        self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
-                        if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
-                    self.peer_last_seen[peer_id] = time.time()
+                if DEBUG_DISCOVERY >= 3: print(f"Broadcast presence: {message}")
+                transport.sendto(message, ('<broadcast>', self.broadcast_port))
+                await asyncio.sleep(self.broadcast_interval)
             except Exception as e:
-                print(f"Error in peer discovery: {e}")
+                print(f"Error in broadcast presence: {e}")
                 import traceback
                 print(traceback.format_exc())
-                await asyncio.sleep(self.broadcast_interval / 2)
 
-    async def _cleanup_peers(self):
+    async def on_listen_message(self, data, addr):
+        message = json.loads(data.decode('utf-8'))
+        if DEBUG_DISCOVERY >= 2: print(f"received from peer {addr}: {message}")
+        if message['type'] == 'discovery' and message['node_id'] != self.node_id:
+            peer_id = message['node_id']
+            peer_host = addr[0]
+            peer_port = message['grpc_port']
+            device_capabilities = DeviceCapabilities(**message['device_capabilities'])
+            if peer_id not in self.known_peers:
+                self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
+                if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
+            self.peer_last_seen[peer_id] = time.time()
+
+    async def task_listen_for_peers(self):
+        await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=('0.0.0.0', self.listen_port))
+        if DEBUG_DISCOVERY >= 2: print("Started listen task")
+
+    async def task_cleanup_peers(self):
         while True:
             current_time = time.time()
             timeout = 15 * self.broadcast_interval

+ 1 - 1
exo/networking/grpc/grpc_peer_handle.py

@@ -82,7 +82,7 @@ class GRPCPeerHandle(PeerHandle):
         response = await self.stub.CollectTopology(request)
         topology = Topology()
         for node_id, capabilities in response.nodes.items():
-            device_capabilities = DeviceCapabilities(model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory)
+            device_capabilities = DeviceCapabilities(model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=capabilities.flops)
             topology.update_node(node_id, device_capabilities)
         for node_id, peers in response.peer_graph.items():
             for peer_id in peers.peer_ids:

+ 1 - 1
exo/networking/grpc/grpc_server.py

@@ -68,7 +68,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
         max_depth = request.max_depth
         visited = set(request.visited)
         topology = await self.node.collect_topology(visited, max_depth)
-        nodes = {node_id: node_service_pb2.DeviceCapabilities(model=cap.model, chip=cap.chip, memory=cap.memory) for node_id, cap in topology.nodes.items()}
+        nodes = {node_id: node_service_pb2.DeviceCapabilities(model=cap.model, chip=cap.chip, memory=cap.memory, flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8)) for node_id, cap in topology.nodes.items()}
         peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
         if DEBUG >= 2: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
         return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)

+ 7 - 0
exo/networking/grpc/node_service.proto

@@ -72,10 +72,17 @@ message Peers {
     repeated string peer_ids = 1;
 }
 
+message DeviceFlops {
+  float fp32 = 1;
+  float fp16 = 2;
+  float int8 = 3;
+}
+
 message DeviceCapabilities {
   string model = 1;
   string chip = 2;
   int32 memory = 3;
+  DeviceFlops flops = 4;
 }
 
 message SendResultRequest {

Datei-Diff unterdrückt, da er zu groß ist
+ 0 - 0
exo/networking/grpc/node_service_pb2.py


+ 1 - 0
exo/orchestration/standard_node.py

@@ -250,4 +250,5 @@ class StandardNode(Node):
                 import traceback
                 traceback.print_exc()
 
+        print(f"Broadcast result: {request_id=} {result=} {is_finished=}")
         await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)

+ 66 - 6
exo/topology/device_capabilities.py

@@ -1,13 +1,73 @@
 from exo import DEBUG
-from dataclasses import dataclass
+from dataclasses import dataclass, asdict
 import subprocess
 import psutil
 
+TFLOPS = 1.00
+
+@dataclass
+class DeviceFlops:
+    # units of TFLOPS
+    fp32: float
+    fp16: float
+    int8: float
+
+    def __str__(self):
+        return f"fp32: {self.fp32 / TFLOPS:.2f} TFLOPS, fp16: {self.fp16 / TFLOPS:.2f} TFLOPS, int8: {self.int8 / TFLOPS:.2f} TFLOPS"
+
+    def to_dict(self):
+        return asdict(self)
+
 @dataclass
 class DeviceCapabilities:
     model: str
     chip: str
     memory: int
+    flops: DeviceFlops
+
+    def __str__(self):
+        return f"Model: {self.model}. Chip: {self.chip}. Memory: {self.memory}MB. Flops: {self.flops}"
+
+    def __post_init__(self):
+        if isinstance(self.flops, dict):
+            self.flops = DeviceFlops(**self.flops)
+
+    def to_dict(self):
+        return {
+            'model': self.model,
+            'chip': self.chip,
+            'memory': self.memory,
+            'flops': self.flops.to_dict()
+        }
+
+UNKNOWN_DEVICE_CAPABILITIES = DeviceCapabilities(model="Unknown Model", chip="Unknown Chip", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0))
+
+CHIP_FLOPS = {
+    # Source: https://www.cpu-monkey.com
+    # Note: currently no distinction between variants of M3 Max and M3 Pro, we pick the lower one to be conservative
+    ### M chips
+    "Apple M1": DeviceFlops(fp32=2.29*TFLOPS, fp16=4.58*TFLOPS, int8=9.16*TFLOPS),
+    "Apple M1 Max": DeviceFlops(fp32=10.60*TFLOPS, fp16=21.20*TFLOPS, int8=42.40*TFLOPS),
+    "Apple M1 Pro": DeviceFlops(fp32=5.30*TFLOPS, fp16=10.60*TFLOPS, int8=21.20*TFLOPS),
+    "Apple M1 Ultra": DeviceFlops(fp32=21.20*TFLOPS, fp16=42.40*TFLOPS, int8=84.80*TFLOPS),
+    "Apple M2": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
+    "Apple M2 Max": DeviceFlops(fp32=13.49*TFLOPS, fp16=26.98*TFLOPS, int8=53.96*TFLOPS),
+    "Apple M2 Pro": DeviceFlops(fp32=5.68*TFLOPS, fp16=11.36*TFLOPS, int8=22.72*TFLOPS),
+    "Apple M2 Ultra": DeviceFlops(fp32=10.60*TFLOPS, fp16=21.30*TFLOPS, int8=42.60*TFLOPS),
+    "Apple M3": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
+    "Apple M3 Max": DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS),
+    "Apple M3 Pro": DeviceFlops(fp32=4.97*TFLOPS, fp16=9.94*TFLOPS, int8=19.88*TFLOPS),
+    "Apple M4": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
+    ### A chips
+    "Apple A13 Bionic": DeviceFlops(fp32=0.69*TFLOPS, fp16=1.38*TFLOPS, int8=2.76*TFLOPS),
+    "Apple A14 Bionic": DeviceFlops(fp32=0.75*TFLOPS, fp16=1.50*TFLOPS, int8=3.00*TFLOPS),
+    "Apple A15 Bionic": DeviceFlops(fp32=1.37*TFLOPS, fp16=2.74*TFLOPS, int8=5.48*TFLOPS),
+    "Apple A16 Bionic": DeviceFlops(fp32=1.79*TFLOPS, fp16=3.58*TFLOPS, int8=7.16*TFLOPS),
+    "Apple A17 Pro": DeviceFlops(fp32=2.15*TFLOPS, fp16=4.30*TFLOPS, int8=8.60*TFLOPS),
+    ### NVIDIA GPUs: TODO
+    ### AMD GPUs: TODO
+    ### Qualcomm embedded chips: TODO
+}
 
 def device_capabilities() -> DeviceCapabilities:
     if psutil.MACOS:
@@ -15,7 +75,7 @@ def device_capabilities() -> DeviceCapabilities:
     elif psutil.LINUX:
         return linux_device_capabilities()
     else:
-        return DeviceCapabilities(model=f"Unknown Device", chip=f"Unknown Chip", memory=psutil.virtual_memory().total // 2**20)
+        return DeviceCapabilities(model=f"Unknown Device", chip=f"Unknown Chip", memory=psutil.virtual_memory().total // 2**20, flops=DeviceFlops(fp32=0, fp16=0, int8=0))
 
 def mac_device_capabilities() -> DeviceCapabilities:
     # Fetch the model of the Mac using system_profiler
@@ -34,7 +94,7 @@ def mac_device_capabilities() -> DeviceCapabilities:
         memory = memory_value
 
     # Assuming static values for other attributes for demonstration
-    return DeviceCapabilities(model=model_id, chip=chip_id, memory=memory)
+    return DeviceCapabilities(model=model_id, chip=chip_id, memory=memory, flops=CHIP_FLOPS.get(chip_id, DeviceFlops(fp32=0, fp16=0, int8=0)))
 
 def linux_device_capabilities() -> DeviceCapabilities:
     import psutil
@@ -50,9 +110,9 @@ def linux_device_capabilities() -> DeviceCapabilities:
 
         print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
 
-        return DeviceCapabilities(model=f"Linux Box ({gpu_name})", chip=gpu_name, memory=gpu_memory_info.total // 2**20)
+        return DeviceCapabilities(model=f"Linux Box ({gpu_name})", chip=gpu_name, memory=gpu_memory_info.total // 2**20, flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)))
     elif Device.DEFAULT == "AMD":
         # TODO AMD support
-        return DeviceCapabilities(model="Linux Box (AMD)", chip="Unknown AMD", memory=psutil.virtual_memory().total // 2**20)
+        return DeviceCapabilities(model="Linux Box (AMD)", chip="Unknown AMD", memory=psutil.virtual_memory().total // 2**20, flops=DeviceFlops(fp32=0, fp16=0, int8=0))
     else:
-        return DeviceCapabilities(model=f"Linux Box (Device: {Device.DEFAULT})", chip=f"Unknown Chip (Device: {Device.DEFAULT})", memory=psutil.virtual_memory().total // 2**20)
+        return DeviceCapabilities(model=f"Linux Box (Device: {Device.DEFAULT})", chip=f"Unknown Chip (Device: {Device.DEFAULT})", memory=psutil.virtual_memory().total // 2**20, flops=DeviceFlops(fp32=0, fp16=0, int8=0))

+ 4 - 1
exo/topology/test_device_capabilities.py

@@ -1,6 +1,6 @@
 import unittest
 from unittest.mock import patch
-from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapabilities
+from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapabilities, DeviceFlops, TFLOPS
 
 class TestMacDeviceCapabilities(unittest.TestCase):
     @patch('subprocess.check_output')
@@ -33,6 +33,7 @@ Hardware:
         self.assertEqual(result.model, "MacBook Pro")
         self.assertEqual(result.chip, "Apple M3 Max")
         self.assertEqual(result.memory, 131072)  # 16 GB in MB
+        self.assertEqual(str(result), "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS")
 
     @patch('subprocess.check_output')
     def test_mac_device_capabilities(self, mock_check_output):
@@ -75,6 +76,8 @@ Hardware:
         self.assertEqual(result.model, "MacBook Pro")
         self.assertEqual(result.chip, "Apple M3 Max")
         self.assertEqual(result.memory, 131072)  # 128 GB in MB
+        self.assertEqual(result.flops, DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS))
+        self.assertEqual(str(result), "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS")
 
 if __name__ == '__main__':
     unittest.main()

+ 4 - 5
exo/topology/test_ring_memory_weighted_partitioning_strategy.py

@@ -1,7 +1,6 @@
 import unittest
-from unittest.mock import MagicMock
 from .ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
-from .topology import Topology, DeviceCapabilities
+from .topology import Topology, DeviceCapabilities, DeviceFlops
 from .partitioning_strategy import Partition
 
 class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
@@ -9,9 +8,9 @@ class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
         # triangle
         # node1 -> node2 -> node3 -> node1
         topology = Topology()
-        topology.update_node('node1', DeviceCapabilities(model="test1", chip="test1", memory=3000))
-        topology.update_node('node2', DeviceCapabilities(model="test2", chip="test2", memory=1000))
-        topology.update_node('node3', DeviceCapabilities(model="test3", chip="test3", memory=6000))
+        topology.update_node('node1', DeviceCapabilities(model="test1", chip="test1", memory=3000, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))
+        topology.update_node('node2', DeviceCapabilities(model="test2", chip="test2", memory=1000, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))
+        topology.update_node('node3', DeviceCapabilities(model="test3", chip="test3", memory=6000, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))
         topology.add_edge('node1', 'node2')
         topology.add_edge('node2', 'node3')
         topology.add_edge('node3', 'node1')

Einige Dateien werden nicht angezeigt, da zu viele Dateien in diesem Diff geändert wurden.