
topology with partitioning strategy

Alex Cheema 11 months ago
parent
commit
6c8c9ee7b1

+ 1 - 0
inference/inference_engine.py

@@ -9,6 +9,7 @@ class InferenceEngine(ABC):
     async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
         pass
 
+    @abstractmethod
     async def infer_prompt(self, shard: Shard, prompt: str) -> np.ndarray:
         pass
 

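With infer_prompt now marked @abstractmethod, the ABC refuses to instantiate any InferenceEngine subclass that does not implement it. A toy engine, just to illustrate the contract (EchoEngine is illustrative, not part of the codebase):

import numpy as np
from inference.inference_engine import InferenceEngine
from inference.shard import Shard

class EchoEngine(InferenceEngine):
    # Satisfies both abstract methods without doing real inference.
    async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
        return input_data  # pass activations through unchanged

    async def infer_prompt(self, shard: Shard, prompt: str) -> np.ndarray:
        return np.frombuffer(prompt.encode('utf-8'), dtype=np.uint8)  # dummy encoding
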
+ 0 - 1
inference/mlx/sharded_inference_engine.py

@@ -1,4 +1,3 @@
-import mlx.nn as nn
 import numpy as np
 import mlx.core as mx
 from ..inference_engine import InferenceEngine

+ 4 - 1
networking/grpc/grpc_discovery.py

@@ -6,6 +6,7 @@ from typing import List, Dict
 from ..discovery import Discovery
 from ..peer_handle import PeerHandle
 from .grpc_peer_handle import GRPCPeerHandle
+from topology.device_capabilities import DeviceCapabilities, mac_device_capabilities
 
 class GRPCDiscovery(Discovery):
     def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1):
@@ -61,13 +62,15 @@ class GRPCDiscovery(Discovery):
         return list(self.known_peers.values())
 
     async def _broadcast_presence(self):
+        self.device_capabilities: DeviceCapabilities = mac_device_capabilities()
         sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
         sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
         sock.settimeout(0.5)
         message = json.dumps({
             "type": "discovery",
             "node_id": self.node_id,
-            "grpc_port": self.node_port
+            "grpc_port": self.node_port,
+            "device_capabilities": self.device_capabilities.to_dict()
         }).encode('utf-8')
 
         while True:

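The discovery datagram now advertises the sender's hardware profile alongside its ID and gRPC port. A sketch of what a peer listening on the broadcast port might decode (the listener below is illustrative and the port number is made up):

import json
import socket

sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(('', 5678))  # hypothetical broadcast_port

data, addr = sock.recvfrom(4096)
message = json.loads(data.decode('utf-8'))
if message.get('type') == 'discovery':
    caps = message['device_capabilities']  # e.g. {'model': 'MacBook Pro', 'chip': 'Apple M3 Max', 'memory': 131072}
    print(message['node_id'], message['grpc_port'], caps)
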
+ 1 - 6
networking/grpc/grpc_peer_handle.py

@@ -34,7 +34,7 @@ class GRPCPeerHandle(PeerHandle):
 
         return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
 
-    async def send_tensor(self, shard: Shard, tensor: np.ndarray, target: Optional[str] = None) -> Optional[np.array]:
+    async def send_tensor(self, shard: Shard, tensor: np.ndarray) -> Optional[np.array]:
         request = node_service_pb2.TensorRequest(
             shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers),
             tensor = node_service_pb2.Tensor(
@@ -42,13 +42,8 @@ class GRPCPeerHandle(PeerHandle):
                 shape=tensor.shape,
                 dtype=str(tensor.dtype)
             ),
-            target=target
         )
         response = await self.stub.SendTensor(request)
-        if target:
-            print(f"Sent tensor to {self.address} with target {target}: shape {tensor.shape}")
-        else:
-            print(f"Sent tensor to {self.address}: shape {tensor.shape}")
 
         if not response.tensor_data or not response.shape or not response.dtype:
             return None

+ 2 - 4
networking/grpc/grpc_server.py

@@ -31,16 +31,14 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     async def SendPrompt(self, request, context):
         shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
         prompt = request.prompt
-        target = request.target if request.HasField('target') else None
-        result = await self.node.process_prompt(shard, prompt, target)
+        result = await self.node.process_prompt(shard, prompt)
         tensor_data = result.tobytes() if result is not None else None
         return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))
 
     async def SendTensor(self, request, context):
         shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
         tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
-        target = request.target if request.HasField('target') else None
-        result = await self.node.process_tensor(shard, tensor, target)
+        result = await self.node.process_tensor(shard, tensor)
         print("SendTensor tensor result", result)
         tensor_data = result.tobytes() if result is not None else None
         return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))

+ 0 - 2
networking/grpc/node_service.proto

@@ -18,13 +18,11 @@ message Shard {
 message PromptRequest {
   Shard shard = 1;
   string prompt = 2;
-  optional string target = 3;
 }
 
 message TensorRequest {
   Shard shard = 1;
   Tensor tensor = 2;
-  optional string target = 3;
 }
 
 message Tensor {

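The two generated modules that follow reflect this schema change. With grpcio-tools installed, they would typically be regenerated from networking/grpc/ with:

python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. node_service.proto
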
+ 12 - 12
networking/grpc/node_service_pb2.py

@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
 
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"c\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x13\n\x06target\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_target\"y\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x13\n\x06target\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_target\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"7\n\x11ResetShardRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\"\x07\n\x05\x45mpty2\xd9\x01\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12\x44\n\nResetShard\x12\x1f.node_service.ResetShardRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"C\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\"Y\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"7\n\x11ResetShardRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\"\x07\n\x05\x45mpty2\xd9\x01\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12\x44\n\nResetShard\x12\x1f.node_service.ResetShardRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3')
 
 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -24,15 +24,15 @@ if not _descriptor._USE_C_DESCRIPTORS:
   _globals['_SHARD']._serialized_start=36
   _globals['_SHARD']._serialized_end=119
   _globals['_PROMPTREQUEST']._serialized_start=121
-  _globals['_PROMPTREQUEST']._serialized_end=220
-  _globals['_TENSORREQUEST']._serialized_start=222
-  _globals['_TENSORREQUEST']._serialized_end=343
-  _globals['_TENSOR']._serialized_start=345
-  _globals['_TENSOR']._serialized_end=404
-  _globals['_RESETSHARDREQUEST']._serialized_start=406
-  _globals['_RESETSHARDREQUEST']._serialized_end=461
-  _globals['_EMPTY']._serialized_start=463
-  _globals['_EMPTY']._serialized_end=470
-  _globals['_NODESERVICE']._serialized_start=473
-  _globals['_NODESERVICE']._serialized_end=690
+  _globals['_PROMPTREQUEST']._serialized_end=188
+  _globals['_TENSORREQUEST']._serialized_start=190
+  _globals['_TENSORREQUEST']._serialized_end=279
+  _globals['_TENSOR']._serialized_start=281
+  _globals['_TENSOR']._serialized_end=340
+  _globals['_RESETSHARDREQUEST']._serialized_start=342
+  _globals['_RESETSHARDREQUEST']._serialized_end=397
+  _globals['_EMPTY']._serialized_start=399
+  _globals['_EMPTY']._serialized_end=406
+  _globals['_NODESERVICE']._serialized_start=409
+  _globals['_NODESERVICE']._serialized_end=626
 # @@protoc_insertion_point(module_scope)

+ 1 - 1
networking/grpc/node_service_pb2_grpc.py

@@ -3,7 +3,7 @@
 import grpc
 import warnings
 
-from . import node_service_pb2 as node__service__pb2
+import node_service_pb2 as node__service__pb2
 
 GRPC_GENERATED_VERSION = '1.64.1'
 GRPC_VERSION = grpc.__version__

+ 6 - 0
networking/peer_handle.py

@@ -2,11 +2,17 @@ from abc import ABC, abstractmethod
 from typing import Optional
 import numpy as np
 from inference.shard import Shard
+from topology.device_capabilities import DeviceCapabilities
 
 class PeerHandle(ABC):
+    @abstractmethod
     def id(self) -> str:
         pass
 
+    @abstractmethod
+    def device_capabilities(self) -> DeviceCapabilities:
+        pass
+
     @abstractmethod
     async def connect(self) -> None:
         pass

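GRPCPeerHandle is not updated in this diff, so once device_capabilities() is abstract it can no longer be instantiated until it implements the new accessor. A minimal sketch of what that could look like (the constructor wiring is an assumption; the existing connect/send_prompt/send_tensor methods are omitted):

from topology.device_capabilities import DeviceCapabilities

class GRPCPeerHandle(PeerHandle):
    def __init__(self, id: str, address: str, device_capabilities: DeviceCapabilities):
        self._id = id
        self.address = address
        self._device_capabilities = device_capabilities

    def id(self) -> str:
        return self._id

    def device_capabilities(self) -> DeviceCapabilities:
        return self._device_capabilities

    # ... connect/send_prompt/send_tensor as before ...
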
+ 2 - 2
orchestration/node.py

@@ -13,11 +13,11 @@ class Node(ABC):
         pass
 
     @abstractmethod
-    def process_tensor(self, shard: Shard, tensor: np.ndarray, target: Optional[str] = None) -> None:
+    def process_tensor(self, shard: Shard, tensor: np.ndarray) -> None:
         pass
 
     @abstractmethod
-    def process_prompt(self, shard: Shard, prompt: str, target: Optional[str] = None) -> None:
+    def process_prompt(self, shard: Shard, prompt: str) -> None:
         pass
 
     @abstractmethod

+ 7 - 9
orchestration/standard_node.py

@@ -3,6 +3,7 @@ import numpy as np
 from networking import Discovery, PeerHandle, Server
 from inference.inference_engine import InferenceEngine, Shard
 from .node import Node
+from topology.topology import Topology
 
 class StandardNode(Node):
     def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery):
@@ -11,7 +12,8 @@ class StandardNode(Node):
         self.server = server
         self.discovery = discovery
         self.peers: List[PeerHandle] = {}
-        self.ring_order: List[str] = []
+        self.topology: Topology = Topology()
+        self.successor: Optional[PeerHandle] = None
 
     async def start(self, wait_for_peers: int = 0) -> None:
         await self.server.start()
@@ -27,18 +29,14 @@ class StandardNode(Node):
         await self.discovery.stop()
         await self.server.stop()
 
-    async def process_prompt(self, shard: Shard, prompt: str, target: Optional[str] = None) -> Optional[np.array]:
-        print("Process prompt", shard, prompt, target)
+    async def process_prompt(self, shard: Shard, prompt: str) -> Optional[np.array]:
+        print("Process prompt", shard, prompt)
         result = await self.inference_engine.infer_prompt(shard, prompt)
         # Implement prompt processing logic
         print(f"Got result from prompt: {prompt}. Result: {result}")
         # You might want to initiate inference here
-        if target:
-            target_peer = next((p for p in self.peers if p.id() == target), None)
-            if not target_peer:
-                raise ValueError(f"Peer {target} not found")
-
-            await target_peer.send_tensor(result)
+        if self.successor:
+            await self.successor.send_tensor(shard, result)
 
         return result
 

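Nothing in this diff assigns self.successor yet. One plausible wiring, consistent with the ID-sorted ring the partitioning strategy assumes, would be (illustrative only, not part of the commit):

def _update_successor(self) -> None:
    # Sort all known IDs (peers plus self) into a ring, then pick the ID after ours.
    ring = sorted([p.id() for p in self.peers] + [self.id])
    next_id = ring[(ring.index(self.id) + 1) % len(ring)]
    self.successor = next((p for p in self.peers if p.id() == next_id), None)
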
+ 0 - 0
topology/__init__.py


+ 31 - 0
topology/device_capabilities.py

@@ -0,0 +1,31 @@
+from dataclasses import dataclass
+import subprocess
+
+@dataclass
+class DeviceCapabilities:
+    model: str
+    chip: str
+    memory: int
+
+    def to_dict(self) -> dict:
+        # GRPCDiscovery serializes this into the broadcast message via to_dict()
+        return {"model": self.model, "chip": self.chip, "memory": self.memory}
+
+def mac_device_capabilities() -> DeviceCapabilities:
+    # Fetch the model of the Mac using system_profiler
+    model = subprocess.check_output(['system_profiler', 'SPHardwareDataType']).decode('utf-8')
+    model_line = next((line for line in model.split('\n') if "Model Name" in line), None)
+    model_id = model_line.split(': ')[1] if model_line else "Unknown Model"
+    chip_line = next((line for line in model.split('\n') if "Chip" in line), None)
+    chip_id = chip_line.split(': ')[1] if chip_line else "Unknown Chip"
+    memory_line = next((line for line in model.split('\n') if "Memory" in line), None)
+    memory_str = memory_line.split(': ')[1] if memory_line else "Unknown Memory"
+    memory_units = memory_str.split()
+    memory_value = int(memory_units[0])
+    if memory_units[1] == "GB":
+        memory = memory_value * 1024
+    else:
+        memory = memory_value
+
+    # All three fields are parsed above; memory is normalized to megabytes
+    return DeviceCapabilities(model=model_id, chip=chip_id, memory=memory)

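mac_device_capabilities shells out to system_profiler, so it only works on macOS; on other platforms the subprocess call raises FileNotFoundError. A quick smoke test (printed values will vary by machine):

from topology.device_capabilities import mac_device_capabilities

caps = mac_device_capabilities()
print(caps)            # e.g. DeviceCapabilities(model='MacBook Pro', chip='Apple M3 Max', memory=131072)
print(caps.to_dict())  # the payload embedded in the discovery broadcast
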
+ 10 - 0
topology/partitioning_strategy.py

@@ -0,0 +1,10 @@
+from abc import ABC, abstractmethod
+from typing import List
+from inference.shard import Shard
+from networking.peer_handle import PeerHandle
+from .topology import Topology
+
+class PartitioningStrategy(ABC):
+    @abstractmethod
+    def next_shard(self, current_shard: Shard, topology: Topology, node_stats: dict) -> Shard:
+        pass

+ 31 - 0
topology/ring_memory_weighted_partitioning_strategy.py

@@ -0,0 +1,31 @@
+from .partitioning_strategy import PartitioningStrategy
+from inference.shard import Shard
+from .topology import Topology
+
+class RingMemoryWeightedPartitioningStrategy(PartitioningStrategy):
+    def __init__(self, node_id: str):
+        self.id = node_id
+
+    def next_shard(self, current_shard: Shard, topology: Topology, node_stats: dict) -> Shard:
+        # Get all (node_id, stats) pairs from the topology and include the current node
+        nodes = list(topology.all_nodes())
+        nodes.append((self.id, node_stats))
+
+        # Sort nodes by their IDs to fix the ring order
+        nodes.sort(key=lambda x: x[0])
+
+        # Calculate the total memory of all nodes
+        total_memory = sum(stats['memory'] for _, stats in nodes)
+
+        # Assign each node a number of layers proportional to its memory
+        layers_per_node = {node_id: round(stats['memory'] / total_memory * current_shard.n_layers) for node_id, stats in nodes}
+
+        # Find the successor node in the ring
+        node_ids = [node_id for node_id, _ in nodes]
+        current_index = node_ids.index(self.id)
+        successor_id = node_ids[(current_index + 1) % len(node_ids)]
+
+        # Return the successor's slice of layers, continuing from the current shard
+        start_layer = current_shard.end_layer + 1
+        end_layer = min(start_layer + layers_per_node[successor_id] - 1, current_shard.n_layers - 1)
+        return Shard(current_shard.model_id, start_layer, end_layer, current_shard.n_layers)

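For intuition: two nodes reporting 131072 MB and 65536 MB of memory split a 32-layer model roughly 21/11. A usage sketch (node IDs, model name, and stats are made up):

from topology.topology import Topology
from topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
from inference.shard import Shard

topology = Topology()
topology.update_node("node-b", {"memory": 65536})

strategy = RingMemoryWeightedPartitioningStrategy("node-a")
current = Shard(model_id="llama", start_layer=0, end_layer=20, n_layers=32)
print(strategy.next_shard(current, topology, {"memory": 131072}))
# expected: a shard covering layers 21-31 of 32, i.e. node-b's memory-weighted share
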
+ 49 - 0
topology/test_device_capabilities.py

@@ -0,0 +1,49 @@
+import unittest
+from unittest.mock import patch
+from topology.device_capabilities import mac_device_capabilities, DeviceCapabilities
+
+class TestMacDeviceCapabilities(unittest.TestCase):
+    @patch('subprocess.check_output')
+    def test_mac_device_capabilities(self, mock_check_output):
+        # Mock the subprocess output
+        mock_check_output.return_value = b"""
+Hardware:
+
+    Hardware Overview:
+
+        Model Name: MacBook Pro
+        Model Identifier: Mac15,9
+        Model Number: Z1CM000EFB/A
+        Chip: Apple M3 Max
+        Total Number of Cores: 16 (12 performance and 4 efficiency)
+        Memory: 128 GB
+        System Firmware Version: 10000.000.0
+        OS Loader Version: 10000.000.0
+        Serial Number (system): XXXXXXXXXX
+        Hardware UUID: XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX
+        Provisioning UDID: XXXXXXXX-XXXXXXXXXXXXXXXX
+        Activation Lock Status: Enabled
+        """
+
+        # Call the function
+        result = mac_device_capabilities()
+
+        # Check the results
+        self.assertIsInstance(result, DeviceCapabilities)
+        self.assertEqual(result.model, "MacBook Pro")
+        self.assertEqual(result.chip, "Apple M3 Max")
+        self.assertEqual(result.memory, 131072)  # 128 GB in MB
+
+    @unittest.skip("Unskip this test when running on a MacBook Pro, Apple M3 Max, 128GB")
+    def test_mac_device_capabilities_real(self):
+        # Call the function without mocking
+        result = mac_device_capabilities()
+
+        # Check the results
+        self.assertIsInstance(result, DeviceCapabilities)
+        self.assertEqual(result.model, "MacBook Pro")
+        self.assertEqual(result.chip, "Apple M3 Max")
+        self.assertEqual(result.memory, 131072)  # 128 GB in MB
+
+if __name__ == '__main__':
+    unittest.main()

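These tests run with the standard unittest runner from the repo root:

python -m unittest topology.test_device_capabilities
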
+ 12 - 0
topology/topology.py

@@ -0,0 +1,12 @@
+class Topology:
+    def __init__(self):
+        self.nodes = {}  # Maps node IDs to their latest reported stats
+
+    def update_node(self, node_id, stats):
+        self.nodes[node_id] = stats
+
+    def get_node(self, node_id):
+        return self.nodes.get(node_id)
+
+    def all_nodes(self):
+        return self.nodes.items()
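Topology is a thin in-memory registry for now; a minimal interaction sketch:

from topology.topology import Topology

t = Topology()
t.update_node("node-a", {"memory": 131072})
t.update_node("node-b", {"memory": 65536})
print(t.get_node("node-a"))   # {'memory': 131072}
print(list(t.all_nodes()))    # [('node-a', {...}), ('node-b', {...})]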