
scaffolding for networking, inference and orchestration

Alex Cheema 1 year ago
Current commit
a21f59ff45

+ 2 - 0
.gitignore

@@ -0,0 +1,2 @@
+__pycache__/
+.venv

+ 32 - 0
inference/inference_engine.py

@@ -0,0 +1,32 @@
+import numpy as np
+import mlx.core as mx
+import mlx.nn as nn
+
+from abc import ABC, abstractmethod
+from .shard import Shard
+
+class InferenceEngine(ABC):
+    @abstractmethod
+    async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+        pass
+
+    @abstractmethod
+    async def reset_shard(self, shard: Shard):
+        pass
+
+class MLXFixedShardInferenceEngine(InferenceEngine):
+    def __init__(self, model: nn.Module, shard: Shard):
+        self.model = model
+        self.shard = shard
+
+    async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+        if shard != self.shard:
+            raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
+
+        output_data = np.array(self.model(mx.array(input_data)))  # nn.Module has no .process(); call the module and convert back to numpy
+        print("Processed data through model shard")
+        return output_data
+
+    async def reset_shard(self, shard: Shard):
+        # TODO
+        print(f"Resetting shard: {shard}")

+ 8 - 0
inference/shard.py

@@ -0,0 +1,8 @@
+from dataclasses import dataclass
+
+@dataclass
+class Shard:
+    model_id: str
+    n_layers: int
+    start_layer: int
+    end_layer: int
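
A Shard describes a contiguous slice of a model's layers, so a model can be reconstructed by stitching shards end to end. A small sketch splitting a hypothetical 32-layer model across two nodes:

from inference.shard import Shard

first_half = Shard(model_id="test", n_layers=32, start_layer=0, end_layer=15)
second_half = Shard(model_id="test", n_layers=32, start_layer=16, end_layer=31)

# adjacent shards tile the layer range [0, n_layers - 1] without gaps
assert first_half.end_layer + 1 == second_half.start_layer
assert second_half.end_layer == second_half.n_layers - 1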

+ 71 - 0
main.py

@@ -0,0 +1,71 @@
+import argparse
+import asyncio
+import signal
+import mlx.core as mx
+import mlx.nn as nn
+from orchestration.standard_node import StandardNode
+from networking.grpc.grpc_server import GRPCServer
+from inference.inference_engine import MLXFixedShardInferenceEngine
+from inference.shard import Shard
+from networking.grpc.grpc_discovery import GRPCDiscovery
+
+class SimpleMLXModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(10, 5)  # Example dimensions
+
+    def __call__(self, x):  # MLX modules are invoked via __call__, not forward()
+        return self.linear(x)
+
+
+# parse args
+parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
+parser.add_argument("--node-id", type=str, default="node1", help="Node ID")
+parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
+parser.add_argument("--node-port", type=int, default=8080, help="Node port")
+parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
+parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
+args = parser.parse_args()
+
+mlx_model = SimpleMLXModel()
+inference_engine = MLXFixedShardInferenceEngine(mlx_model, shard=Shard(model_id="test", n_layers=32, start_layer=0, end_layer=31))
+discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
+node = StandardNode(args.node_id, None, inference_engine, discovery)
+server = GRPCServer(node, args.node_host, args.node_port)
+node.server = server
+
+async def shutdown(signal, loop):
+    """Gracefully shutdown the server and close the asyncio loop."""
+    print(f"Received exit signal {signal.name}...")
+    server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
+    [task.cancel() for task in server_tasks]
+    print(f"Cancelling {len(server_tasks)} outstanding tasks")
+    await asyncio.gather(*server_tasks, return_exceptions=True)
+    await server.stop()
+    loop.stop()
+
+async def main():
+    loop = asyncio.get_running_loop()
+
+    # register both SIGINT and SIGTERM so either triggers a graceful shutdown
+    def handle_exit(s):
+        asyncio.ensure_future(shutdown(s, loop))
+
+    for s in [signal.SIGINT, signal.SIGTERM]:
+        loop.add_signal_handler(s, handle_exit, s)
+
+    await node.start()
+
+    await asyncio.sleep(5)
+    print("Sending reset shard request")
+    await node.peers[0].reset_shard(f"regards from {node.id}")
+
+    await asyncio.Event().wait()
+
+if __name__ == "__main__":
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    try:
+        loop.run_until_complete(main())
+    finally:
+        loop.close()
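
To try two nodes on one machine, the listen/broadcast ports can be cross-paired the same way test_grpc_discovery.py does, so each process hears the other's UDP broadcasts (ports are arbitrary):

python main.py --node-id node1 --node-port 8080 --listen-port 5678 --broadcast-port 5679
python main.py --node-id node2 --node-port 8081 --listen-port 5679 --broadcast-port 5678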

+ 5 - 0
networking/__init__.py

@@ -0,0 +1,5 @@
+from .discovery import Discovery
+from .peer_handle import PeerHandle
+from .server import Server
+
+__all__ = ['Discovery', 'PeerHandle', 'Server']

+ 16 - 0
networking/discovery.py

@@ -0,0 +1,16 @@
+from abc import ABC, abstractmethod
+from typing import List
+from .peer_handle import PeerHandle
+
+class Discovery(ABC):
+    @abstractmethod
+    async def start(self) -> None:
+        pass
+
+    @abstractmethod
+    async def stop(self) -> None:
+        pass
+
+    @abstractmethod
+    async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
+        pass

+ 0 - 0
networking/grpc/__init__.py


+ 105 - 0
networking/grpc/grpc_discovery.py

@@ -0,0 +1,105 @@
+import asyncio
+import json
+import socket
+import time
+from typing import List, Dict, Optional
+from ..discovery import Discovery
+from ..peer_handle import PeerHandle
+from .grpc_peer_handle import GRPCPeerHandle
+
+class GRPCDiscovery(Discovery):
+    def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: Optional[int] = None, broadcast_interval: int = 1):
+        self.node_id = node_id
+        self.node_port = node_port
+        self.listen_port = listen_port
+        self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port
+        self.broadcast_interval = broadcast_interval
+        self.known_peers: Dict[str, GRPCPeerHandle] = {}
+        self.peer_last_seen: Dict[str, float] = {}
+        self.broadcast_task = None
+        self.listen_task = None
+        self.cleanup_task = None
+
+    async def start(self):
+        self.broadcast_task = asyncio.create_task(self._broadcast_presence())
+        self.listen_task = asyncio.create_task(self._listen_for_peers())
+        self.cleanup_task = asyncio.create_task(self._cleanup_peers())
+
+    async def stop(self):
+        if self.broadcast_task:
+            self.broadcast_task.cancel()
+        if self.listen_task:
+            self.listen_task.cancel()
+        if self.cleanup_task:
+            self.cleanup_task.cancel()
+        if self.broadcast_task or self.listen_task or self.cleanup_task:
+            await asyncio.gather(self.broadcast_task, self.listen_task, self.cleanup_task, return_exceptions=True)
+
+    async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
+        print("Starting peer discovery process...")
+
+        if wait_for_peers > 0:
+            while not self.known_peers:
+                print("No peers discovered yet, retrying in 1 second...")
+                await asyncio.sleep(1)  # Keep trying to find peers
+            print(f"Discovered first peer: {next(iter(self.known_peers.values()))}")
+
+        grace_period = 5  # seconds
+        while True:
+            initial_peer_count = len(self.known_peers)
+            print(f"Current number of known peers: {initial_peer_count}. Waiting {grace_period} seconds to discover more...")
+            await asyncio.sleep(grace_period)
+            if len(self.known_peers) == initial_peer_count:
+                if wait_for_peers > 0:
+                    print(f"Waiting additional {wait_for_peers} seconds for more peers.")
+                    await asyncio.sleep(wait_for_peers)
+                    wait_for_peers = 0
+                else:
+                    print("No new peers discovered in the last grace period. Ending discovery process.")
+                    break  # No new peers found in the grace period, we are done
+
+        return list(self.known_peers.values())
+
+    async def _broadcast_presence(self):
+        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
+        sock.settimeout(0.5)
+        message = json.dumps({
+            "type": "discovery",
+            "node_id": self.node_id,
+            "grpc_port": self.node_port
+        }).encode('utf-8')
+
+        while True:
+            sock.sendto(message, ('<broadcast>', self.broadcast_port))
+            await asyncio.sleep(self.broadcast_interval)
+
+    async def _listen_for_peers(self):
+        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        sock.bind(('', self.listen_port))
+        sock.setblocking(False)
+
+        while True:
+            try:
+                data, addr = await asyncio.get_running_loop().sock_recvfrom(sock, 1024)  # sock_recvfrom requires Python 3.11+
+                message = json.loads(data.decode('utf-8'))
+                if message['type'] == 'discovery' and message['node_id'] != self.node_id:
+                    peer_id = message['node_id']
+                    peer_host = addr[0]
+                    peer_port = message['grpc_port']
+                    self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}")
+                    self.peer_last_seen[peer_id] = time.time()
+            except Exception as e:
+                print(f"Error in peer discovery: {e}")
+                await asyncio.sleep(self.broadcast_interval / 2)
+
+    async def _cleanup_peers(self):
+        while True:
+            current_time = time.time()
+            timeout = 5 * self.broadcast_interval
+            peers_to_remove = [peer_id for peer_id, last_seen in self.peer_last_seen.items() if current_time - last_seen > timeout]
+            for peer_id in peers_to_remove:
+                del self.known_peers[peer_id]
+                del self.peer_last_seen[peer_id]
+                print(f"Removed peer {peer_id} due to inactivity.")
+            await asyncio.sleep(self.broadcast_interval)
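
GRPCDiscovery can be exercised on its own, without a server or node. A sketch, assuming another instance is broadcasting on the paired ports (as in the two-process example above):

import asyncio
from networking.grpc.grpc_discovery import GRPCDiscovery

async def run():
    discovery = GRPCDiscovery("node1", node_port=50051, listen_port=5678, broadcast_port=5679)
    await discovery.start()
    peers = await discovery.discover_peers(wait_for_peers=1)  # block until at least one peer is seen
    print([peer.id() for peer in peers])
    await discovery.stop()

asyncio.run(run())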

+ 47 - 0
networking/grpc/grpc_peer_handle.py

@@ -0,0 +1,47 @@
+import grpc
+import numpy as np
+from typing import Optional
+
+# Generated from node_service.proto (see networking/grpc/node_service.proto)
+from . import node_service_pb2
+from . import node_service_pb2_grpc
+
+from ..peer_handle import PeerHandle
+
+class GRPCPeerHandle(PeerHandle):
+    def __init__(self, id: str, address: str):
+        self._id = id
+        self.address = address
+
+    def id(self) -> str:
+        return self._id
+
+    async def connect(self):
+        self.channel = grpc.aio.insecure_channel(self.address)
+        self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
+
+    async def disconnect(self):
+        await self.channel.close()
+
+    async def send_prompt(self, prompt: str) -> None:
+        request = node_service_pb2.PromptRequest(prompt=prompt)
+        await self.stub.SendPrompt(request)
+        print(f"Sent prompt to {self.address}: {prompt}")
+
+    async def send_tensor(self, tensor: np.ndarray, target: Optional[str] = None) -> None:
+        request = node_service_pb2.TensorRequest(
+            tensor_data=tensor.tobytes(),
+            shape=tensor.shape,
+            dtype=str(tensor.dtype),
+            target=target
+        )
+        await self.stub.SendTensor(request)
+        if target:
+            print(f"Sent tensor to {self.address} with target {target}: shape {tensor.shape}")
+        else:
+            print(f"Sent tensor to {self.address}: shape {tensor.shape}")
+
+    async def reset_shard(self, shard_id: str) -> None:
+        request = node_service_pb2.ResetShardRequest(shard_id=shard_id)
+        await self.stub.ResetShard(request)
+        print(f"Reset shard {shard_id} on {self.address}")

+ 57 - 0
networking/grpc/grpc_server.py

@@ -0,0 +1,57 @@
+import grpc
+from concurrent import futures
+import numpy as np
+
+from . import node_service_pb2
+from . import node_service_pb2_grpc
+
+from orchestration import Node
+
+class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
+    def __init__(self, node: Node, host: str, port: int):
+        self.node = node
+        self.host = host
+        self.port = port
+        self.server = None
+
+    async def start(self) -> None:
+        self.server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=10))
+        node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
+        listen_addr = f'{self.host}:{self.port}'
+        self.server.add_insecure_port(listen_addr)
+        await self.server.start()
+        print(f"Server started, listening on {listen_addr}")
+
+    async def stop(self) -> None:
+        if self.server:
+            await self.server.stop(5)  # 5 seconds grace period
+            print("Server stopped")
+
+    async def SendPrompt(self, request, context):
+        prompt = request.prompt
+        target = request.target if request.HasField('target') else None
+        if target and target != self.node.id:
+            await self.node.process_prompt(prompt, target)
+        else:
+            # Process the prompt locally
+            # (implemented by StandardNode.process_prompt)
+            await self.node.process_prompt(prompt)
+        return node_service_pb2.Empty()
+
+    async def SendTensor(self, request, context):
+        tensor = np.frombuffer(request.tensor_data, dtype=np.dtype(request.dtype)).reshape(request.shape)
+        target = request.target if request.HasField('target') else None
+        if target and target != self.node.id:
+            await self.node.process_tensor(tensor, target)
+        else:
+            # Process the tensor locally
+            await self.node.process_tensor(tensor)
+        return node_service_pb2.Empty()
+
+    async def ResetShard(self, request, context):
+        print(f"Received ResetShard request: {request}")
+        # TODO
+        # shard_id = request.shard_id
+        # You'd need to implement this method in the Node class
+        # await self.node.reset_shard(shard_id)
+        return node_service_pb2.Empty()
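
SendTensor rebuilds the numpy array from the raw bytes, shape, and dtype fields of TensorRequest; the round trip is lossless for contiguous arrays:

import numpy as np

tensor = np.arange(6, dtype=np.float32).reshape(2, 3)
# the three fields as they travel over the wire
tensor_data, shape, dtype = tensor.tobytes(), tensor.shape, str(tensor.dtype)
restored = np.frombuffer(tensor_data, dtype=np.dtype(dtype)).reshape(shape)
assert (restored == tensor).all()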

+ 27 - 0
networking/grpc/node_service.proto

@@ -0,0 +1,27 @@
+syntax = "proto3";
+
+package node_service;
+
+service NodeService {
+  rpc SendPrompt (PromptRequest) returns (Empty) {}
+  rpc SendTensor (TensorRequest) returns (Empty) {}
+  rpc ResetShard (ResetShardRequest) returns (Empty) {}
+}
+
+message PromptRequest {
+  string prompt = 1;
+  optional string target = 2;
+}
+
+message TensorRequest {
+  bytes tensor_data = 1;
+  repeated int32 shape = 2;
+  string dtype = 3;
+  optional string target = 4;
+}
+
+message ResetShardRequest {
+  string shard_id = 1;
+}
+
+message Empty {}
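
The node_service_pb2*.py modules below are generated from this file. Assuming grpcio-tools is installed, they can be regenerated from the networking/grpc directory with:

python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. node_service.proto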

+ 34 - 0
networking/grpc/node_service_pb2.py

@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler.  DO NOT EDIT!
+# source: node_service.proto
+# Protobuf Python Version: 5.26.1
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"?\n\rPromptRequest\x12\x0e\n\x06prompt\x18\x01 \x01(\t\x12\x13\n\x06target\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_target\"b\n\rTensorRequest\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\x12\x13\n\x06target\x18\x04 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_target\"%\n\x11ResetShardRequest\x12\x10\n\x08shard_id\x18\x01 \x01(\t\"\x07\n\x05\x45mpty2\xd7\x01\n\x0bNodeService\x12@\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x13.node_service.Empty\"\x00\x12@\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x13.node_service.Empty\"\x00\x12\x44\n\nResetShard\x12\x1f.node_service.ResetShardRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3')
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'node_service_pb2', _globals)
+if not _descriptor._USE_C_DESCRIPTORS:
+  DESCRIPTOR._loaded_options = None
+  _globals['_PROMPTREQUEST']._serialized_start=36
+  _globals['_PROMPTREQUEST']._serialized_end=99
+  _globals['_TENSORREQUEST']._serialized_start=101
+  _globals['_TENSORREQUEST']._serialized_end=199
+  _globals['_RESETSHARDREQUEST']._serialized_start=201
+  _globals['_RESETSHARDREQUEST']._serialized_end=238
+  _globals['_EMPTY']._serialized_start=240
+  _globals['_EMPTY']._serialized_end=247
+  _globals['_NODESERVICE']._serialized_start=250
+  _globals['_NODESERVICE']._serialized_end=465
+# @@protoc_insertion_point(module_scope)

+ 188 - 0
networking/grpc/node_service_pb2_grpc.py

@@ -0,0 +1,188 @@
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+"""Client and server classes corresponding to protobuf-defined services."""
+import grpc
+import warnings
+
+from . import node_service_pb2 as node__service__pb2
+
+GRPC_GENERATED_VERSION = '1.64.1'
+GRPC_VERSION = grpc.__version__
+EXPECTED_ERROR_RELEASE = '1.65.0'
+SCHEDULED_RELEASE_DATE = 'June 25, 2024'
+_version_not_supported = False
+
+try:
+    from grpc._utilities import first_version_is_lower
+    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
+except ImportError:
+    _version_not_supported = True
+
+if _version_not_supported:
+    warnings.warn(
+        f'The grpc package installed is at version {GRPC_VERSION},'
+        + f' but the generated code in node_service_pb2_grpc.py depends on'
+        + f' grpcio>={GRPC_GENERATED_VERSION}.'
+        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
+        + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
+        + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
+        RuntimeWarning
+    )
+
+
+class NodeServiceStub(object):
+    """Missing associated documentation comment in .proto file."""
+
+    def __init__(self, channel):
+        """Constructor.
+
+        Args:
+            channel: A grpc.Channel.
+        """
+        self.SendPrompt = channel.unary_unary(
+                '/node_service.NodeService/SendPrompt',
+                request_serializer=node__service__pb2.PromptRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
+                _registered_method=True)
+        self.SendTensor = channel.unary_unary(
+                '/node_service.NodeService/SendTensor',
+                request_serializer=node__service__pb2.TensorRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
+                _registered_method=True)
+        self.ResetShard = channel.unary_unary(
+                '/node_service.NodeService/ResetShard',
+                request_serializer=node__service__pb2.ResetShardRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
+                _registered_method=True)
+
+
+class NodeServiceServicer(object):
+    """Missing associated documentation comment in .proto file."""
+
+    def SendPrompt(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def SendTensor(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def ResetShard(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+
+def add_NodeServiceServicer_to_server(servicer, server):
+    rpc_method_handlers = {
+            'SendPrompt': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendPrompt,
+                    request_deserializer=node__service__pb2.PromptRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
+            ),
+            'SendTensor': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendTensor,
+                    request_deserializer=node__service__pb2.TensorRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
+            ),
+            'ResetShard': grpc.unary_unary_rpc_method_handler(
+                    servicer.ResetShard,
+                    request_deserializer=node__service__pb2.ResetShardRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
+            ),
+    }
+    generic_handler = grpc.method_handlers_generic_handler(
+            'node_service.NodeService', rpc_method_handlers)
+    server.add_generic_rpc_handlers((generic_handler,))
+    server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
+
+
+ # This class is part of an EXPERIMENTAL API.
+class NodeService(object):
+    """Missing associated documentation comment in .proto file."""
+
+    @staticmethod
+    def SendPrompt(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendPrompt',
+            node__service__pb2.PromptRequest.SerializeToString,
+            node__service__pb2.Empty.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
+
+    @staticmethod
+    def SendTensor(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendTensor',
+            node__service__pb2.TensorRequest.SerializeToString,
+            node__service__pb2.Empty.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
+
+    @staticmethod
+    def ResetShard(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/ResetShard',
+            node__service__pb2.ResetShardRequest.SerializeToString,
+            node__service__pb2.Empty.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)

+ 23 - 0
networking/grpc/test_grpc_discovery.py

@@ -0,0 +1,23 @@
+import asyncio
+import unittest
+from .grpc_discovery import GRPCDiscovery
+
+class TestGRPCDiscovery(unittest.IsolatedAsyncioTestCase):
+    async def asyncSetUp(self):
+        self.node1 = GRPCDiscovery("node1", 50051, 5678, 5679)
+        self.node2 = GRPCDiscovery("node2", 50052, 5679, 5678)
+        await self.node1.start()
+        await self.node2.start()
+
+    async def asyncTearDown(self):
+        await self.node1.stop()
+        await self.node2.stop()
+
+    async def test_discovery(self):
+        await asyncio.sleep(4)
+
+        # Check discovered peers
+        print("Node1 Peers:", ', '.join([f"{peer_id}: {peer}" for peer_id, peer in self.node1.known_peers.items()]))
+        print("Node2 Peers:", ', '.join([f"{peer_id}: {peer}" for peer_id, peer in self.node2.known_peers.items()]))
+        self.assertIn("node2", self.node1.known_peers)
+        self.assertIn("node1", self.node2.known_peers)

+ 27 - 0
networking/peer_handle.py

@@ -0,0 +1,27 @@
+from abc import ABC, abstractmethod
+from typing import Any
+
+class PeerHandle(ABC):
+    @abstractmethod
+    def id(self) -> str:
+        pass
+
+    @abstractmethod
+    async def connect(self) -> None:
+        pass
+
+    @abstractmethod
+    async def disconnect(self) -> None:
+        pass
+
+    @abstractmethod
+    async def send_prompt(self, prompt: str) -> None:
+        pass
+
+    @abstractmethod
+    async def send_tensor(self, tensor: Any) -> None:
+        pass
+
+    @abstractmethod
+    async def reset_shard(self, shard_id: str) -> None:
+        pass

+ 10 - 0
networking/server.py

@@ -0,0 +1,10 @@
+from abc import ABC, abstractmethod
+
+class Server(ABC):
+    @abstractmethod
+    async def start(self) -> None:
+        pass
+
+    @abstractmethod
+    async def stop(self) -> None:
+        pass

+ 4 - 0
orchestration/__init__.py

@@ -0,0 +1,4 @@
+from .node import Node
+from .standard_node import StandardNode
+
+__all__ = ["Node", "StandardNode"]

+ 24 - 0
orchestration/node.py

@@ -0,0 +1,24 @@
+from typing import Optional
+import numpy as np
+from abc import ABC, abstractmethod
+
+class Node(ABC):
+    @abstractmethod
+    async def start(self) -> None:
+        pass
+
+    @abstractmethod
+    async def stop(self) -> None:
+        pass
+
+    @abstractmethod
+    async def process_tensor(self, tensor: np.ndarray, target: Optional[str] = None) -> None:
+        pass
+
+    @abstractmethod
+    async def process_prompt(self, prompt: str, target: Optional[str] = None) -> None:
+        pass
+
+    @abstractmethod
+    async def reset_shard(self, shard_id: str) -> None:
+        pass

+ 47 - 0
orchestration/standard_node.py

@@ -0,0 +1,47 @@
+from typing import List, Optional
+import numpy as np
+from networking import Discovery, PeerHandle, Server
+from inference.inference_engine import InferenceEngine, Shard
+from .node import Node
+
+class StandardNode(Node):
+    def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery):
+        self.id = id
+        self.inference_engine = inference_engine
+        self.server = server
+        self.discovery = discovery
+        self.peers: List[PeerHandle] = []
+        self.ring_order: List[str] = []
+
+    async def start(self) -> None:
+        await self.server.start()
+        await self.discovery.start()
+        self.peers = await self.discovery.discover_peers()
+        print(f"Starting with the following peers: {self.peers}")
+        print("Connecting to peers...")
+        for peer in self.peers:
+            await peer.connect()
+            print(f"Connected to {peer.id()}")
+
+    async def stop(self) -> None:
+        await self.discovery.stop()
+        await self.server.stop()
+
+    async def process_tensor(self, tensor: np.ndarray, target: Optional[str] = None) -> None:
+        result = await self.inference_engine.infer_shard(self.inference_engine.shard, tensor)  # assumes a fixed-shard engine
+
+        if target:
+            target_peer = next((p for p in self.peers if p.id() == target), None)
+            if target_peer is None:
+                raise ValueError(f"Peer {target} not found")
+            await target_peer.send_tensor(result)
+
+    async def process_prompt(self, prompt: str, target: Optional[str] = None) -> None:
+        # Implement prompt processing logic
+        print(f"Processing prompt: {prompt}")
+        # You might want to initiate inference here
+
+    async def reset_shard(self, shard: Shard) -> None:
+        # Implement shard reset logic
+        print(f"Resetting shard: {shard}")
+        await self.inference_engine.reset_shard(shard)

+ 56 - 0
orchestration/test_node.py

@@ -0,0 +1,56 @@
+import unittest
+from unittest.mock import Mock, AsyncMock
+import numpy as np
+
+from .standard_node import StandardNode
+from networking.peer_handle import PeerHandle
+
+class TestNode(unittest.IsolatedAsyncioTestCase):
+    def setUp(self):
+        self.mock_inference_engine = AsyncMock()
+        self.mock_server = AsyncMock()
+        self.mock_server.start = AsyncMock()
+        self.mock_server.stop = AsyncMock()
+        self.mock_discovery = AsyncMock()
+        self.mock_discovery.start = AsyncMock()
+        self.mock_discovery.stop = AsyncMock()
+        mock_peer1 = Mock(spec=PeerHandle)
+        mock_peer1.id.return_value = "peer1"
+        mock_peer2 = Mock(spec=PeerHandle)
+        mock_peer2.id.return_value = "peer2"
+        self.mock_discovery.discover_peers = AsyncMock(return_value=[mock_peer1, mock_peer2])
+
+        self.node = StandardNode("test_node", self.mock_server, self.mock_inference_engine, self.mock_discovery)
+
+    async def asyncSetUp(self):
+        await self.node.start()
+
+    async def asyncTearDown(self):
+        await self.node.stop()
+
+    async def test_node_initialization(self):
+        self.assertEqual(self.node.id, "test_node")
+        self.assertIs(self.node.server, self.mock_server)
+        self.assertIs(self.node.discovery, self.mock_discovery)
+
+    async def test_node_start(self):
+        self.mock_server.start.assert_called_once_with()
+
+    async def test_node_stop(self):
+        await self.node.stop()
+        self.mock_server.stop.assert_called_once()
+
+    async def test_discover_and_connect_to_peers(self):
+        # start() in asyncSetUp already discovered and connected to peers
+        self.assertEqual(len(self.node.peers), 2)
+        self.assertIn("peer1", map(lambda p: p.id(), self.node.peers))
+        self.assertIn("peer2", map(lambda p: p.id(), self.node.peers))
+
+    async def test_process_tensor_calls_inference_engine(self):
+        mock_peer = Mock()
+        self.node.peers = [mock_peer]
+
+        input_tensor = np.array([69, 1, 2])
+        await self.node.process_tensor(input_tensor, None)
+
+        self.node.inference_engine.infer_shard.assert_called_once()