Browse Source

mlx sharded implementation with example of distributed inference

Alex Cheema 1 year ago
parent
commit
563dcb56b0

+ 75 - 0
example_user.py

@@ -0,0 +1,75 @@
+# In this example, a user is running a home cluster with 3 shards.
+# They are prompting the cluster to generate a response to a question.
+# The cluster is given the question, and the user is given the response.
+
+from inference.mlx.sharded_utils import get_model_path, load_tokenizer
+from inference.shard import Shard
+from networking.peer_handle import PeerHandle
+from networking.grpc.grpc_peer_handle import GRPCPeerHandle
+from typing import List
+import asyncio
+import argparse
+
+path_or_hf_repo = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
+model_path = get_model_path(path_or_hf_repo)
+tokenizer_config = {}
+tokenizer = load_tokenizer(model_path, tokenizer_config)
+
+peers: List[PeerHandle] = [
+    GRPCPeerHandle(
+        "node1",
+        "localhost:8080",
+    ),
+    GRPCPeerHandle(
+        "node2",
+        "localhost:8081",
+    )
+]
+shards: List[Shard] = [
+    # Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=15, n_layers=32),
+    # Shard(model_id=path_or_hf_repo, start_layer=16, end_layer=31, n_layers=32),
+    Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=30, n_layers=32),
+    Shard(model_id=path_or_hf_repo, start_layer=31, end_layer=31, n_layers=32),
+]
+
+async def run_prompt(prompt: str):
+    if tokenizer.chat_template is None:
+        tokenizer.chat_template = tokenizer.default_chat_template
+    if (
+        hasattr(tokenizer, "apply_chat_template")
+        and tokenizer.chat_template is not None
+    ):
+        messages = [{"role": "user", "content": prompt}]
+        prompt = tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+
+    for peer, shard in zip(peers, shards):
+        await peer.connect()
+        await peer.reset_shard(shard)
+
+    tokens = []
+    last_output = prompt
+
+    for _ in range(20):
+        for peer, shard in zip(peers, shards):
+            if isinstance(last_output, str):
+                last_output = await peer.send_prompt(shard, last_output)
+                print("prompt output:", last_output)
+            else:
+                last_output = await peer.send_tensor(shard, last_output)
+                print("tensor output:", last_output)
+
+        if not last_output:
+            break
+
+        tokens.append(last_output.item())
+
+    print(tokenizer.decode(tokens))
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Run prompt")
+    parser.add_argument("--prompt", type=str, help="The prompt to run")
+    args = parser.parse_args()
+
+    asyncio.run(run_prompt(args.prompt))

+ 3 - 17
inference/inference_engine.py

@@ -9,23 +9,9 @@ class InferenceEngine(ABC):
     async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
         pass
         pass
 
 
-    @abstractmethod
-    async def reset_shard(self, shard: Shard):
+    async def infer_prompt(self, shard: Shard, prompt: str) -> np.ndarray:
         pass
         pass
 
 
-class MLXFixedShardInferenceEngine(InferenceEngine):
-    def __init__(self, model: nn.Module, shard: Shard):
-        self.model = model
-        self.shard = shard
-
-    async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
-        if shard != self.shard:
-            raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
-
-        output_data = self.model.process(input_data)
-        print("Processed data through model shard")
-        return output_data
-
+    @abstractmethod
     async def reset_shard(self, shard: Shard):
     async def reset_shard(self, shard: Shard):
-        # TODO
-        print(f"Resetting shard: {shard}")
+        pass

+ 244 - 0
inference/mlx/models/sharded_llama.py

@@ -0,0 +1,244 @@
+from dataclasses import dataclass, field
+from typing import Dict, Optional, Tuple, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from mlx_lm.models.base import BaseModelArgs, create_additive_causal_mask
+from ...shard import Shard
+
+
+@dataclass
+class NormalModelArgs(BaseModelArgs):
+    model_type: str
+    hidden_size: int
+    num_hidden_layers: int
+    intermediate_size: int
+    num_attention_heads: int
+    rms_norm_eps: float
+    vocab_size: int
+    num_key_value_heads: int = None
+    attention_bias: bool = False
+    mlp_bias: bool = False
+    rope_theta: float = 10000
+    rope_traditional: bool = False
+    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    tie_word_embeddings: bool = True
+
+    def __post_init__(self):
+        if self.num_key_value_heads is None:
+            self.num_key_value_heads = self.num_attention_heads
+
+        if self.rope_scaling:
+            required_keys = {"factor", "type"}
+            if not all(key in self.rope_scaling for key in required_keys):
+                raise ValueError(f"rope_scaling must contain keys {required_keys}")
+
+            if self.rope_scaling["type"] != "linear":
+                raise ValueError("rope_scaling 'type' currently only supports 'linear'")
+@dataclass
+class ModelArgs(NormalModelArgs):
+    shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+    def __post_init__(self):
+        super().__post_init__()  # Ensure parent initializations are respected
+
+        if isinstance(self.shard, Shard):
+            return
+        if not isinstance(self.shard, dict):
+            raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+        self.shard = Shard(**self.shard)
+
+class Attention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+
+        dim = args.hidden_size
+        self.n_heads = n_heads = args.num_attention_heads
+        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+
+        head_dim = args.hidden_size // n_heads
+        self.scale = head_dim**-0.5
+        if hasattr(args, "attention_bias"):
+            attention_bias = args.attention_bias
+        else:
+            attention_bias = False
+
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
+        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
+
+        rope_scale = (
+            1 / args.rope_scaling["factor"]
+            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
+            else 1
+        )
+        self.rope = nn.RoPE(
+            head_dim,
+            traditional=args.rope_traditional,
+            base=args.rope_theta,
+            scale=rope_scale,
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+        # Prepare the queries, keys and values for the attention computation
+        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+        if cache is not None:
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output)
+
+
+class MLP(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+
+        dim = args.hidden_size
+        hidden_dim = args.intermediate_size
+        if hasattr(args, "mlp_bias"):
+            mlp_bias = args.mlp_bias
+        else:
+            mlp_bias = False
+
+        self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
+        self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
+        self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
+
+    def __call__(self, x) -> mx.array:
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.num_attention_heads = args.num_attention_heads
+        self.hidden_size = args.hidden_size
+        self.self_attn = Attention(args)
+        self.mlp = MLP(args)
+        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
+        self.args = args
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> mx.array:
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
+        h = x + r
+        r = self.mlp(self.post_attention_layernorm(h))
+        out = h + r
+        return out
+
+
+class LlamaModel(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.vocab_size = args.vocab_size
+        self.num_hidden_layers = args.num_hidden_layers
+        assert self.vocab_size > 0
+        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+        self.layers = [
+            TransformerBlock(args=args) for _ in range(args.shard.n_layers)
+        ]
+        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache=None,
+    ):
+        if self.args.shard.is_first_layer():
+            h = self.embed_tokens(inputs)
+        else:
+            h = inputs
+
+        mask = None
+        if h.shape[1] > 1:
+            mask = create_additive_causal_mask(
+                h.shape[1], cache[0].offset if cache is not None else 0
+            )
+            mask = mask.astype(h.dtype)
+
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, cache=c)
+
+        if self.args.shard.is_last_layer():
+            return self.norm(h)
+        else:
+            return h
+
+
+class Model(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.model_type = args.model_type
+        self.model = LlamaModel(args)
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache=None,
+    ):
+        out = self.model(inputs, cache)
+
+        if self.args.shard.is_last_layer():
+            if self.args.tie_word_embeddings:
+                out = self.model.embed_tokens.as_linear(out)
+            else:
+                out = self.lm_head(out)
+
+        return out
+
+
+    def sanitize(self, weights):
+        # Remove unused precomputed rotary freqs
+        return {
+            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
+        }
+
+    @property
+    def layers(self):
+        return self.model.layers
+
+    @property
+    def head_dim(self):
+        return self.args.hidden_size // self.args.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_key_value_heads
+

+ 37 - 0
inference/mlx/sharded_inference_engine.py

@@ -0,0 +1,37 @@
+import mlx.nn as nn
+import numpy as np
+import mlx.core as mx
+from ..inference_engine import InferenceEngine
+from .sharded_model import StatefulShardedModel
+from .sharded_utils import load_shard
+from ..shard import Shard
+
+class MLXFixedShardInferenceEngine(InferenceEngine):
+    def __init__(self, model_path: str, shard: Shard):
+        print("initializing fixed shard inference", shard)
+        self.shard = shard
+        model_shard, self.tokenizer = load_shard(model_path, shard)
+        self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
+
+    async def infer_prompt(self, shard: Shard, prompt: str) -> np.ndarray:
+        if shard != self.shard:
+            raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
+
+        output_data = self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt)))
+        return np.array(output_data)
+
+    async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+        if shard != self.shard:
+            raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
+
+        print("infer_shard", shard, input_data)
+
+        output_data = self.stateful_sharded_model.step(mx.array(input_data))
+        return np.array(output_data)
+
+    async def reset_shard(self, shard: Shard):
+        if shard != self.shard:
+            raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
+
+        print(f"Resetting shard: {shard}")
+        self.stateful_sharded_model.reset()

+ 56 - 0
inference/mlx/sharded_model.py

@@ -0,0 +1,56 @@
+from typing import Dict, Generator, Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm.models.base import KVCache
+from mlx_lm.sample_utils import top_p_sampling
+
+from ..shard import Shard
+
+class StatefulShardedModel:
+    def __init__(self, shard: Shard, model: nn.Module):
+        self.shard = shard
+        self.model = model
+        self.reset()
+
+    def step(
+        self,
+        x,
+        temp: float = 0.0,
+        top_p: float = 1.0,
+        logit_bias: Optional[Dict[int, float]] = None,
+    ) -> Generator[Tuple[mx.array, mx.array], None, None]:
+        def sample(logits: mx.array) -> Tuple[mx.array, float]:
+            if logit_bias:
+                indices = mx.array(list(logit_bias.keys()))
+                values = mx.array(list(logit_bias.values()))
+                logits[:, indices] += values
+
+            if temp == 0:
+                token = mx.argmax(logits, axis=-1)
+            else:
+                if top_p > 0 and top_p < 1.0:
+                    token = top_p_sampling(logits, top_p, temp)
+                else:
+                    token = mx.random.categorical(logits * (1 / temp))
+
+            return token
+
+        y = x
+
+        output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.cache)
+
+        if self.shard.is_last_layer():
+            logits = output[:, -1, :]
+            y = sample(logits)
+            return y
+        else:
+            return output
+
+    def reset(self):
+        kv_heads = (
+            [self.model.n_kv_heads] * len(self.model.layers)
+            if isinstance(self.model.n_kv_heads, int)
+            else self.model.n_kv_heads
+        )
+        self.cache = [KVCache(self.model.head_dim, n) for n in kv_heads]

+ 230 - 0
inference/mlx/sharded_utils.py

@@ -0,0 +1,230 @@
+# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py
+
+import glob
+import importlib
+import json
+import logging
+from pathlib import Path
+from typing import Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+from huggingface_hub import snapshot_download
+from huggingface_hub.utils._errors import RepositoryNotFoundError
+from mlx.utils import tree_flatten
+from transformers import PreTrainedTokenizer
+
+from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
+from mlx_lm.tuner.utils import apply_lora_layers
+
+from ..shard import Shard
+
+class ModelNotFoundError(Exception):
+    def __init__(self, message):
+        self.message = message
+        super().__init__(self.message)
+
+MODEL_REMAPPING = {
+    "mistral": "llama",  # mistral is compatible with llama
+    "phi-msft": "phixtral",
+}
+
+def _get_classes(config: dict):
+    """
+    Retrieve the model and model args classes based on the configuration.
+
+    Args:
+        config (dict): The model configuration.
+
+    Returns:
+        A tuple containing the Model class and the ModelArgs class.
+    """
+    model_type = config["model_type"]
+    model_type = MODEL_REMAPPING.get(model_type, model_type)
+    try:
+        arch = importlib.import_module(f"inference.mlx.models.{model_type}")
+    except ImportError:
+        msg = f"Model type {model_type} not supported."
+        logging.error(msg)
+        raise ValueError(msg)
+
+    return arch.Model, arch.ModelArgs
+
+def load_config(model_path: Path) -> dict:
+    try:
+        with open(model_path / "config.json", "r") as f:
+            config = json.load(f)
+    except FileNotFoundError:
+        logging.error(f"Config file not found in {model_path}")
+        raise
+    return config
+
+def load_model_shard(
+    model_path: Path,
+    shard: Shard,
+    lazy: bool = False,
+    model_config: dict = {},
+) -> nn.Module:
+    """
+    Load and initialize the model from a given path.
+
+    Args:
+        model_path (Path): The path to load the model from.
+        lazy (bool): If False eval the model parameters to make sure they are
+            loaded in memory before returning, otherwise they will be loaded
+            when needed. Default: ``False``
+        model_config(dict, optional): Configuration parameters for the model.
+            Defaults to an empty dictionary.
+
+    Returns:
+        nn.Module: The loaded and initialized model.
+
+    Raises:
+        FileNotFoundError: If the weight files (.safetensors) are not found.
+        ValueError: If the model class or args class are not found or cannot be instantiated.
+    """
+
+    config = load_config(model_path)
+    config.update(model_config)
+
+    # TODO hack
+    config["model_type"] = f"sharded_{config['model_type']}"
+    config["shard"] = {
+        "model_id": model_path.name,
+        "start_layer": shard.start_layer,
+        "end_layer": shard.end_layer,
+        "n_layers": shard.n_layers
+    }
+
+    weight_files = glob.glob(str(model_path / "model*.safetensors"))
+
+    if not weight_files:
+        # Try weight for back-compat
+        weight_files = glob.glob(str(model_path / "weight*.safetensors"))
+
+    if not weight_files:
+        logging.error(f"No safetensors found in {model_path}")
+        raise FileNotFoundError(f"No safetensors found in {model_path}")
+
+    weights = {}
+    for wf in weight_files:
+        weights.update(mx.load(wf))
+
+    model_class, model_args_class = _get_classes(config=config)
+
+    model_args = model_args_class.from_dict(config)
+    model = model_class(model_args)
+
+    if hasattr(model, "sanitize"):
+        weights = model.sanitize(weights)
+
+    if (quantization := config.get("quantization", None)) is not None:
+        # Handle legacy models which may not have everything quantized
+        def class_predicate(p, m):
+            if not hasattr(m, "to_quantized"):
+                return False
+            return f"{p}.scales" in weights
+
+        nn.quantize(
+            model,
+            **quantization,
+            class_predicate=class_predicate,
+        )
+
+    filtered_weights = {}
+    for k, v in weights.items():
+        if k.startswith("model.layers."):
+            layer_num = int(k.split('.')[2])
+            if shard.start_layer <= layer_num <= shard.end_layer:
+                new_key = f"model.layers.{layer_num - shard.start_layer}." + '.'.join(k.split('.')[3:])
+                filtered_weights[new_key] = v
+        else:
+            filtered_weights[k] = v
+    weights = filtered_weights
+
+    model.load_weights(list(weights.items()), strict=False)
+
+    if not lazy:
+        mx.eval(model.parameters())
+
+    model.eval()
+    return model
+
+def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
+    """
+    Ensures the model is available locally. If the path does not exist locally,
+    it is downloaded from the Hugging Face Hub.
+
+    Args:
+        path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
+        revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
+
+    Returns:
+        Path: The path to the model.
+    """
+    model_path = Path(path_or_hf_repo)
+    if not model_path.exists():
+        try:
+            model_path = Path(
+                snapshot_download(
+                    repo_id=path_or_hf_repo,
+                    revision=revision,
+                    allow_patterns=[
+                        "*.json",
+                        "*.safetensors",
+                        "*.py",
+                        "tokenizer.model",
+                        "*.tiktoken",
+                        "*.txt",
+                    ],
+                )
+            )
+        except RepositoryNotFoundError:
+            raise ModelNotFoundError(
+                f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
+                "Please make sure you specified the local path or Hugging Face"
+                " repo id correctly.\nIf you are trying to access a private or"
+                " gated Hugging Face repo, make sure you are authenticated:\n"
+                "https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login"
+            ) from None
+    return model_path
+
+
+def load_shard(
+    path_or_hf_repo: str,
+    shard: Shard,
+    tokenizer_config={},
+    model_config={},
+    adapter_path: Optional[str] = None,
+    lazy: bool = False,
+) -> Tuple[nn.Module, TokenizerWrapper]:
+    """
+    Load the model and tokenizer from a given path or a huggingface repository.
+
+    Args:
+        path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
+        tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
+            Defaults to an empty dictionary.
+        model_config(dict, optional): Configuration parameters specifically for the model.
+            Defaults to an empty dictionary.
+        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
+            to the model. Default: ``None``.
+        lazy (bool): If False eval the model parameters to make sure they are
+            loaded in memory before returning, otherwise they will be loaded
+            when needed. Default: ``False``
+    Returns:
+        Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
+
+    Raises:
+        FileNotFoundError: If config file or safetensors are not found.
+        ValueError: If model class or args class are not found.
+    """
+    model_path = get_model_path(path_or_hf_repo)
+
+    model = load_model_shard(model_path, shard, lazy, model_config)
+    if adapter_path is not None:
+        model = apply_lora_layers(model, adapter_path)
+        model.eval()
+    tokenizer = load_tokenizer(model_path, tokenizer_config)
+
+    return model, tokenizer

+ 7 - 1
inference/shard.py

@@ -3,6 +3,12 @@ from dataclasses import dataclass
 @dataclass
 @dataclass
 class Shard:
 class Shard:
     model_id: str
     model_id: str
-    n_layers: int
     start_layer: int
     start_layer: int
     end_layer: int
     end_layer: int
+    n_layers: int
+
+    def is_first_layer(self) -> bool:
+        return self.start_layer == 0
+
+    def is_last_layer(self) -> bool:
+        return self.end_layer == self.n_layers - 1

+ 7 - 16
main.py

@@ -5,19 +5,10 @@ import mlx.core as mx
 import mlx.nn as nn
 import mlx.nn as nn
 from orchestration.standard_node import StandardNode
 from orchestration.standard_node import StandardNode
 from networking.grpc.grpc_server import GRPCServer
 from networking.grpc.grpc_server import GRPCServer
-from inference.inference_engine import MLXFixedShardInferenceEngine
+from inference.mlx.sharded_inference_engine import MLXFixedShardInferenceEngine
 from inference.shard import Shard
 from inference.shard import Shard
 from networking.grpc.grpc_discovery import GRPCDiscovery
 from networking.grpc.grpc_discovery import GRPCDiscovery
 
 
-class SimpleMLXModel(nn.Module):
-    def __init__(self):
-        super(SimpleMLXModel, self).__init__()
-        self.linear = nn.Linear(10, 5)  # Example dimensions
-
-    def forward(self, x):
-        return self.linear(x)
-
-
 # parse args
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
 parser.add_argument("--node-id", type=str, default="node1", help="Node ID")
 parser.add_argument("--node-id", type=str, default="node1", help="Node ID")
@@ -25,15 +16,19 @@ parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host"
 parser.add_argument("--node-port", type=int, default=8080, help="Node port")
 parser.add_argument("--node-port", type=int, default=8080, help="Node port")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
+parser.add_argument("--model-id", type=str, default="mlx-community/Meta-Llama-3-8B-Instruct-4bit", help="Path to the model")
+parser.add_argument("--n-layers", type=int, default=32, help="Number of layers in the model")
+parser.add_argument("--start-layer", type=int, default=0, help="Start layer index")
+parser.add_argument("--end-layer", type=int, default=31, help="End layer index")
 args = parser.parse_args()
 args = parser.parse_args()
 
 
-mlx_model = SimpleMLXModel()
-inference_engine = MLXFixedShardInferenceEngine(mlx_model, shard=Shard(model_id="test", n_layers=32, start_layer=0, end_layer=31))
+inference_engine = MLXFixedShardInferenceEngine(args.model_id, shard=Shard(model_id=args.model_id, n_layers=args.n_layers, start_layer=args.start_layer, end_layer=args.end_layer))
 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
 node = StandardNode(args.node_id, None, inference_engine, discovery)
 node = StandardNode(args.node_id, None, inference_engine, discovery)
 server = GRPCServer(node, args.node_host, args.node_port)
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
 node.server = server
 
 
+
 async def shutdown(signal, loop):
 async def shutdown(signal, loop):
     """Gracefully shutdown the server and close the asyncio loop."""
     """Gracefully shutdown the server and close the asyncio loop."""
     print(f"Received exit signal {signal.name}...")
     print(f"Received exit signal {signal.name}...")
@@ -56,10 +51,6 @@ async def main():
 
 
     await node.start()
     await node.start()
 
 
-    await asyncio.sleep(5)
-    print("Sending reset shard request")
-    await node.peers[0].reset_shard(f"regards from {node.id}")
-
     await asyncio.Event().wait()
     await asyncio.Event().wait()
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":

+ 25 - 11
networking/grpc/grpc_peer_handle.py

@@ -7,6 +7,7 @@ from . import node_service_pb2
 from . import node_service_pb2_grpc
 from . import node_service_pb2_grpc
 
 
 from ..peer_handle import PeerHandle
 from ..peer_handle import PeerHandle
+from inference.shard import Shard
 
 
 class GRPCPeerHandle(PeerHandle):
 class GRPCPeerHandle(PeerHandle):
     def __init__(self, id: str, address: str):
     def __init__(self, id: str, address: str):
@@ -23,25 +24,38 @@ class GRPCPeerHandle(PeerHandle):
     async def disconnect(self):
     async def disconnect(self):
         await self.channel.close()
         await self.channel.close()
 
 
-    async def send_prompt(self, prompt: str) -> None:
-        request = node_service_pb2.PromptRequest(prompt=prompt)
-        await self.stub.SendPrompt(request)
+    async def send_prompt(self, shard: Shard, prompt: str) -> Optional[np.array]:
+        request = node_service_pb2.PromptRequest(prompt=prompt, shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers))
+        response = await self.stub.SendPrompt(request)
         print(f"Sent prompt to {self.address}: {prompt}")
         print(f"Sent prompt to {self.address}: {prompt}")
 
 
-    async def send_tensor(self, tensor: np.ndarray, target: Optional[str] = None) -> None:
+        if not response.tensor_data or not response.shape or not response.dtype:
+            return None
+
+        return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
+
+    async def send_tensor(self, shard: Shard, tensor: np.ndarray, target: Optional[str] = None) -> Optional[np.array]:
         request = node_service_pb2.TensorRequest(
         request = node_service_pb2.TensorRequest(
-            tensor_data=tensor.tobytes(),
-            shape=tensor.shape,
-            dtype=str(tensor.dtype),
+            shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers),
+            tensor = node_service_pb2.Tensor(
+                tensor_data=tensor.tobytes(),
+                shape=tensor.shape,
+                dtype=str(tensor.dtype)
+            ),
             target=target
             target=target
         )
         )
-        await self.stub.SendTensor(request)
+        response = await self.stub.SendTensor(request)
         if target:
         if target:
             print(f"Sent tensor to {self.address} with target {target}: shape {tensor.shape}")
             print(f"Sent tensor to {self.address} with target {target}: shape {tensor.shape}")
         else:
         else:
             print(f"Sent tensor to {self.address}: shape {tensor.shape}")
             print(f"Sent tensor to {self.address}: shape {tensor.shape}")
 
 
-    async def reset_shard(self, shard_id: str) -> None:
-        request = node_service_pb2.ResetShardRequest(shard_id=shard_id)
+        if not response.tensor_data or not response.shape or not response.dtype:
+            return None
+
+        return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
+
+    async def reset_shard(self, shard: Shard) -> None:
+        request = node_service_pb2.ResetShardRequest(shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers))
         await self.stub.ResetShard(request)
         await self.stub.ResetShard(request)
-        print(f"Reset shard {shard_id} on {self.address}")
+        print(f"Reset shard {shard} on {self.address}")

+ 14 - 19
networking/grpc/grpc_server.py

@@ -4,6 +4,7 @@ import numpy as np
 
 
 from . import node_service_pb2
 from . import node_service_pb2
 from . import node_service_pb2_grpc
 from . import node_service_pb2_grpc
+from inference.shard import Shard
 
 
 from orchestration import Node
 from orchestration import Node
 
 
@@ -28,30 +29,24 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
             print("Server stopped")
             print("Server stopped")
 
 
     async def SendPrompt(self, request, context):
     async def SendPrompt(self, request, context):
+        shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
         prompt = request.prompt
         prompt = request.prompt
         target = request.target if request.HasField('target') else None
         target = request.target if request.HasField('target') else None
-        if target and target != self.node.node_id:
-            await self.node.process_prompt(prompt, target)
-        else:
-            # Process the prompt locally
-            # You'd need to implement this method in the Node class
-            await self.node.process_prompt(prompt)
-        return node_service_pb2.Empty()
+        result = await self.node.process_prompt(shard, prompt, target)
+        tensor_data = result.tobytes() if result is not None else None
+        return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))
 
 
     async def SendTensor(self, request, context):
     async def SendTensor(self, request, context):
-        tensor = np.frombuffer(request.tensor_data, dtype=np.dtype(request.dtype)).reshape(request.shape)
+        shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
+        tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
         target = request.target if request.HasField('target') else None
         target = request.target if request.HasField('target') else None
-        if target and target != self.node.node_id:
-            await self.node.process_tensor(tensor, target)
-        else:
-            # Process the tensor locally
-            await self.node.inference_strategy.process_inference(tensor)
-        return node_service_pb2.Empty()
+        result = await self.node.process_tensor(shard, tensor, target)
+        print("SendTensor tensor result", result)
+        tensor_data = result.tobytes() if result is not None else None
+        return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))
 
 
     async def ResetShard(self, request, context):
     async def ResetShard(self, request, context):
-        print(f"Received ResetShard request: {request}")
-        # TODO
-        # shard_id = request.shard_id
-        # You'd need to implement this method in the Node class
-        # await self.node.reset_shard(shard_id)
+        shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
+        print(f"Received ResetShard request: {shard}")
+        await self.node.reset_shard(shard)
         return node_service_pb2.Empty()
         return node_service_pb2.Empty()

+ 19 - 6
networking/grpc/node_service.proto

@@ -3,25 +3,38 @@ syntax = "proto3";
 package node_service;
 package node_service;
 
 
 service NodeService {
 service NodeService {
-  rpc SendPrompt (PromptRequest) returns (Empty) {}
-  rpc SendTensor (TensorRequest) returns (Empty) {}
+  rpc SendPrompt (PromptRequest) returns (Tensor) {}
+  rpc SendTensor (TensorRequest) returns (Tensor) {}
   rpc ResetShard (ResetShardRequest) returns (Empty) {}
   rpc ResetShard (ResetShardRequest) returns (Empty) {}
 }
 }
 
 
+message Shard {
+  string model_id = 1;
+  int32 start_layer = 2;
+  int32 end_layer = 3;
+  int32 n_layers = 4;
+}
+
 message PromptRequest {
 message PromptRequest {
-  string prompt = 1;
-  optional string target = 2;
+  Shard shard = 1;
+  string prompt = 2;
+  optional string target = 3;
 }
 }
 
 
 message TensorRequest {
 message TensorRequest {
+  Shard shard = 1;
+  Tensor tensor = 2;
+  optional string target = 3;
+}
+
+message Tensor {
   bytes tensor_data = 1;
   bytes tensor_data = 1;
   repeated int32 shape = 2;
   repeated int32 shape = 2;
   string dtype = 3;
   string dtype = 3;
-  optional string target = 4;
 }
 }
 
 
 message ResetShardRequest {
 message ResetShardRequest {
-  string shard_id = 1;
+  Shard shard = 1;
 }
 }
 
 
 message Empty {}
 message Empty {}

+ 15 - 11
networking/grpc/node_service_pb2.py

@@ -14,21 +14,25 @@ _sym_db = _symbol_database.Default()
 
 
 
 
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"?\n\rPromptRequest\x12\x0e\n\x06prompt\x18\x01 \x01(\t\x12\x13\n\x06target\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_target\"b\n\rTensorRequest\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\x12\x13\n\x06target\x18\x04 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_target\"%\n\x11ResetShardRequest\x12\x10\n\x08shard_id\x18\x01 \x01(\t\"\x07\n\x05\x45mpty2\xd7\x01\n\x0bNodeService\x12@\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x13.node_service.Empty\"\x00\x12@\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x13.node_service.Empty\"\x00\x12\x44\n\nResetShard\x12\x1f.node_service.ResetShardRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"c\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x13\n\x06target\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_target\"y\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x13\n\x06target\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_target\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"7\n\x11ResetShardRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\"\x07\n\x05\x45mpty2\xd9\x01\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12\x44\n\nResetShard\x12\x1f.node_service.ResetShardRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3')
 
 
 _globals = globals()
 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
 _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'node_service_pb2', _globals)
 _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'node_service_pb2', _globals)
 if not _descriptor._USE_C_DESCRIPTORS:
 if not _descriptor._USE_C_DESCRIPTORS:
   DESCRIPTOR._loaded_options = None
   DESCRIPTOR._loaded_options = None
-  _globals['_PROMPTREQUEST']._serialized_start=36
-  _globals['_PROMPTREQUEST']._serialized_end=99
-  _globals['_TENSORREQUEST']._serialized_start=101
-  _globals['_TENSORREQUEST']._serialized_end=199
-  _globals['_RESETSHARDREQUEST']._serialized_start=201
-  _globals['_RESETSHARDREQUEST']._serialized_end=238
-  _globals['_EMPTY']._serialized_start=240
-  _globals['_EMPTY']._serialized_end=247
-  _globals['_NODESERVICE']._serialized_start=250
-  _globals['_NODESERVICE']._serialized_end=465
+  _globals['_SHARD']._serialized_start=36
+  _globals['_SHARD']._serialized_end=119
+  _globals['_PROMPTREQUEST']._serialized_start=121
+  _globals['_PROMPTREQUEST']._serialized_end=220
+  _globals['_TENSORREQUEST']._serialized_start=222
+  _globals['_TENSORREQUEST']._serialized_end=343
+  _globals['_TENSOR']._serialized_start=345
+  _globals['_TENSOR']._serialized_end=404
+  _globals['_RESETSHARDREQUEST']._serialized_start=406
+  _globals['_RESETSHARDREQUEST']._serialized_end=461
+  _globals['_EMPTY']._serialized_start=463
+  _globals['_EMPTY']._serialized_end=470
+  _globals['_NODESERVICE']._serialized_start=473
+  _globals['_NODESERVICE']._serialized_end=690
 # @@protoc_insertion_point(module_scope)
 # @@protoc_insertion_point(module_scope)

+ 6 - 6
networking/grpc/node_service_pb2_grpc.py

@@ -42,12 +42,12 @@ class NodeServiceStub(object):
         self.SendPrompt = channel.unary_unary(
         self.SendPrompt = channel.unary_unary(
                 '/node_service.NodeService/SendPrompt',
                 '/node_service.NodeService/SendPrompt',
                 request_serializer=node__service__pb2.PromptRequest.SerializeToString,
                 request_serializer=node__service__pb2.PromptRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Empty.FromString,
+                response_deserializer=node__service__pb2.Tensor.FromString,
                 _registered_method=True)
                 _registered_method=True)
         self.SendTensor = channel.unary_unary(
         self.SendTensor = channel.unary_unary(
                 '/node_service.NodeService/SendTensor',
                 '/node_service.NodeService/SendTensor',
                 request_serializer=node__service__pb2.TensorRequest.SerializeToString,
                 request_serializer=node__service__pb2.TensorRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Empty.FromString,
+                response_deserializer=node__service__pb2.Tensor.FromString,
                 _registered_method=True)
                 _registered_method=True)
         self.ResetShard = channel.unary_unary(
         self.ResetShard = channel.unary_unary(
                 '/node_service.NodeService/ResetShard',
                 '/node_service.NodeService/ResetShard',
@@ -83,12 +83,12 @@ def add_NodeServiceServicer_to_server(servicer, server):
             'SendPrompt': grpc.unary_unary_rpc_method_handler(
             'SendPrompt': grpc.unary_unary_rpc_method_handler(
                     servicer.SendPrompt,
                     servicer.SendPrompt,
                     request_deserializer=node__service__pb2.PromptRequest.FromString,
                     request_deserializer=node__service__pb2.PromptRequest.FromString,
-                    response_serializer=node__service__pb2.Empty.SerializeToString,
+                    response_serializer=node__service__pb2.Tensor.SerializeToString,
             ),
             ),
             'SendTensor': grpc.unary_unary_rpc_method_handler(
             'SendTensor': grpc.unary_unary_rpc_method_handler(
                     servicer.SendTensor,
                     servicer.SendTensor,
                     request_deserializer=node__service__pb2.TensorRequest.FromString,
                     request_deserializer=node__service__pb2.TensorRequest.FromString,
-                    response_serializer=node__service__pb2.Empty.SerializeToString,
+                    response_serializer=node__service__pb2.Tensor.SerializeToString,
             ),
             ),
             'ResetShard': grpc.unary_unary_rpc_method_handler(
             'ResetShard': grpc.unary_unary_rpc_method_handler(
                     servicer.ResetShard,
                     servicer.ResetShard,
@@ -122,7 +122,7 @@ class NodeService(object):
             target,
             target,
             '/node_service.NodeService/SendPrompt',
             '/node_service.NodeService/SendPrompt',
             node__service__pb2.PromptRequest.SerializeToString,
             node__service__pb2.PromptRequest.SerializeToString,
-            node__service__pb2.Empty.FromString,
+            node__service__pb2.Tensor.FromString,
             options,
             options,
             channel_credentials,
             channel_credentials,
             insecure,
             insecure,
@@ -149,7 +149,7 @@ class NodeService(object):
             target,
             target,
             '/node_service.NodeService/SendTensor',
             '/node_service.NodeService/SendTensor',
             node__service__pb2.TensorRequest.SerializeToString,
             node__service__pb2.TensorRequest.SerializeToString,
-            node__service__pb2.Empty.FromString,
+            node__service__pb2.Tensor.FromString,
             options,
             options,
             channel_credentials,
             channel_credentials,
             insecure,
             insecure,

+ 6 - 4
networking/peer_handle.py

@@ -1,5 +1,7 @@
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
-from typing import Any
+from typing import Optional
+import numpy as np
+from inference.shard import Shard
 
 
 class PeerHandle(ABC):
 class PeerHandle(ABC):
     def id(self) -> str:
     def id(self) -> str:
@@ -14,13 +16,13 @@ class PeerHandle(ABC):
         pass
         pass
 
 
     @abstractmethod
     @abstractmethod
-    async def send_prompt(self, prompt: str) -> None:
+    async def send_prompt(self, shard: Shard, prompt: str) -> Optional[np.array]:
         pass
         pass
 
 
     @abstractmethod
     @abstractmethod
-    async def send_tensor(self, tensor: Any) -> None:
+    async def send_tensor(self, shard: Shard, tensor: np.array) -> Optional[np.array]:
         pass
         pass
 
 
     @abstractmethod
     @abstractmethod
-    async def reset_shard(self, shard_id: str) -> None:
+    async def reset_shard(self, shard: Shard) -> None:
         pass
         pass

+ 5 - 4
orchestration/node.py

@@ -1,10 +1,11 @@
 from typing import Optional
 from typing import Optional
 import numpy as np
 import numpy as np
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
+from inference.shard import Shard
 
 
 class Node(ABC):
 class Node(ABC):
     @abstractmethod
     @abstractmethod
-    def start(self) -> None:
+    def start(self, wait_for_peers: int = 0) -> None:
         pass
         pass
 
 
     @abstractmethod
     @abstractmethod
@@ -12,13 +13,13 @@ class Node(ABC):
         pass
         pass
 
 
     @abstractmethod
     @abstractmethod
-    def process_tensor(self, tensor: np.ndarray, target: Optional[str] = None) -> None:
+    def process_tensor(self, shard: Shard, tensor: np.ndarray, target: Optional[str] = None) -> None:
         pass
         pass
 
 
     @abstractmethod
     @abstractmethod
-    def process_prompt(self, prompt: str, target: Optional[str] = None) -> None:
+    def process_prompt(self, shard: Shard, prompt: str, target: Optional[str] = None) -> None:
         pass
         pass
 
 
     @abstractmethod
     @abstractmethod
-    def reset_shard(self, shard_id: str) -> None:
+    def reset_shard(self, shard: Shard) -> None:
         pass
         pass

+ 26 - 10
orchestration/standard_node.py

@@ -13,10 +13,10 @@ class StandardNode(Node):
         self.peers: List[PeerHandle] = {}
         self.peers: List[PeerHandle] = {}
         self.ring_order: List[str] = []
         self.ring_order: List[str] = []
 
 
-    async def start(self) -> None:
+    async def start(self, wait_for_peers: int = 0) -> None:
         await self.server.start()
         await self.server.start()
         await self.discovery.start()
         await self.discovery.start()
-        self.peers = await self.discovery.discover_peers()
+        self.peers = await self.discovery.discover_peers(wait_for_peers)
         print(f"Starting with the following peers: {self.peers}")
         print(f"Starting with the following peers: {self.peers}")
         print("Connecting to peers...")
         print("Connecting to peers...")
         for peer in self.peers:
         for peer in self.peers:
@@ -27,19 +27,35 @@ class StandardNode(Node):
         await self.discovery.stop()
         await self.discovery.stop()
         await self.server.stop()
         await self.server.stop()
 
 
-    async def process_tensor(self, tensor: np.ndarray, target: Optional[str] = None) -> None:
-        result = await self.inference_engine.process_shard(tensor)
-
+    async def process_prompt(self, shard: Shard, prompt: str, target: Optional[str] = None) -> Optional[np.array]:
+        print("Process prompt", shard, prompt, target)
+        result = await self.inference_engine.infer_prompt(shard, prompt)
+        # Implement prompt processing logic
+        print(f"Got result from prompt: {prompt}. Result: {result}")
+        # You might want to initiate inference here
         if target:
         if target:
-            if not filter(lambda p: p.id() == target, self.peers):
+            target_peer = next((p for p in self.peers if p.id() == target), None)
+            if not target_peer:
                 raise ValueError(f"Peer {target} not found")
                 raise ValueError(f"Peer {target} not found")
 
 
-            await self.peers[target].send_tensor(result)
+            await target_peer.send_tensor(result)
 
 
-    async def process_prompt(self, prompt: str) -> None:
+        return result
+
+    async def process_tensor(self, shard: Shard, tensor: np.ndarray, target: Optional[str] = None) -> None:
+        print("Process tensor", shard, tensor)
+        result = await self.inference_engine.infer_shard(shard, tensor)
         # Implement prompt processing logic
         # Implement prompt processing logic
-        print(f"Processing prompt: {prompt}")
-        # You might want to initiate inference here
+        print(f"Got result from prompt: {len(tensor)}. Result: {result}")
+
+        if target:
+            target_peer = next((p for p in self.peers if p.id() == target), None)
+            if not target_peer:
+                raise ValueError(f"Peer {target} not found")
+
+            await target_peer.send_tensor(result)
+
+        return result
 
 
     async def reset_shard(self, shard: Shard) -> None:
     async def reset_shard(self, shard: Shard) -> None:
         # Implement shard reset logic
         # Implement shard reset logic