Browse source code

Some initial inference engine refactors for enabling training

Only on MLX for now; breaks Tinygrad (it doesn't fulfill the new interface yet)
Nel Nibcord 6 months ago
commit 82cce4408e
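
This commit splits the old monolithic infer_prompt (which tokenized, ran the model, sampled, and checked for EOS in one call) into separate encode, infer_tensor, sample, and decode steps on the InferenceEngine interface. A minimal sketch of how a caller might drive the new interface for single-shard generation; the loop, the "req-0" id, and the engine/shard objects are illustrative and not part of the commit, and it assumes the engine exposes self.tokenizer after its first ensure_shard, as the MLX engine does:

import numpy as np

async def generate(engine, shard, prompt: str, max_tokens: int = 64) -> str:
  # encode() now owns tokenization, previously the first half of infer_prompt.
  x = np.array(await engine.encode(shard, prompt))
  out_tokens = []
  for _ in range(max_tokens):
    # infer_tensor() returns the raw model output (logits on the last shard)
    # instead of the old (output, inference_state, is_finished) tuple.
    logits = await engine.infer_tensor("req-0", shard, x)
    # sample() collapses the last position's logits to a single token id.
    token = await engine.sample(logits)
    if token.item() == engine.tokenizer.eos_token_id:
      break
    out_tokens.append(token.item())
    x = token  # feed the new token back through the model
  return await engine.decode(shard, out_tokens)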

+ 4 - 12
exo/api/chatgpt_api.py

@@ -117,19 +117,11 @@ def remap_messages(messages: List[Message]) -> List[Message]:
 def build_prompt(tokenizer, _messages: List[Message]):
   messages = remap_messages(_messages)
   prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True)
-  image_str = None
   for message in messages:
     if not isinstance(message.content, list):
       continue
 
-    for content in message.content:
-      # note: we only support one image at a time right now. Multiple is possible. See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41
-      # follows the convention in https://platform.openai.com/docs/guides/vision
-      if isinstance(content, dict) and content.get("type", None) == "image":
-        image_str = content.get("image", None)
-        break
-
-  return prompt, image_str
+  return prompt
 
 
 def parse_message(data: dict):
@@ -246,7 +238,7 @@ class ChatGPTAPI:
     tokenizer = await resolve_tokenizer(shard.model_id)
     if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
 
-    prompt, image_str = build_prompt(tokenizer, chat_request.messages)
+    prompt = build_prompt(tokenizer, chat_request.messages)
     request_id = str(uuid.uuid4())
     if self.on_chat_completion_request:
       try:
@@ -269,10 +261,10 @@ class ChatGPTAPI:
     callback_id = f"chatgpt-api-wait-response-{request_id}"
     callback = self.node.on_token.register(callback_id)
 
-    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=} {image_str=}")
+    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
 
     try:
-      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, image_str, request_id=request_id))), timeout=self.response_timeout)
+      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)
 
       if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")
 

+ 1 - 1
exo/inference/dummy_inference_engine.py

@@ -14,7 +14,7 @@ class DummyInferenceEngine(InferenceEngine):
     self.latency_mean = 0.1
     self.latency_stddev = 0.02
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     try:
       await self.ensure_shard(shard)
 

+ 13 - 3
exo/inference/inference_engine.py

@@ -9,13 +9,23 @@ from .shard import Shard
 
 class InferenceEngine(ABC):
   @abstractmethod
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
+    pass
+  
+  async def sample(self, x: np.ndarray) -> np.ndarray:
+    pass
+
+  @abstractmethod
+  async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
     pass
 
   @abstractmethod
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> np.ndarray:
     pass
 
+  @abstractmethod
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> np.ndarray:
+    pass
 
 def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
   if DEBUG >= 2:
@@ -33,4 +43,4 @@ def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDow
   elif inference_engine_name == "dummy":
     from exo.inference.dummy_inference_engine import DummyInferenceEngine
     return DummyInferenceEngine()
-  raise ValueError(f"Unsupported inference engine: {inference_engine_name}")
+  raise ValueError(f"Unsupported inference engine: {inference_engine_name}")
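
The commit message notes Tinygrad doesn't fulfill this interface yet. A conforming engine now has to provide encode, decode, infer_prompt, and infer_tensor (sample has a default body and is not abstract); a stubbed skeleton under those assumptions:

import numpy as np
from typing import Optional
from exo.inference.inference_engine import InferenceEngine
from exo.inference.shard import Shard

class SkeletonEngine(InferenceEngine):
  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
    raise NotImplementedError  # tokenize the prompt for this shard's model

  async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
    raise NotImplementedError  # detokenize generated ids

  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> np.ndarray:
    raise NotImplementedError  # typically: infer_tensor over await encode(...)

  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> np.ndarray:
    raise NotImplementedError  # run this shard's layers, return raw output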

+ 44 - 19
exo/inference/mlx/sharded_inference_engine.py

@@ -1,15 +1,35 @@
 import numpy as np
 import mlx.core as mx
+import mlx.nn as nn
 from ..inference_engine import InferenceEngine
-from .sharded_model import StatefulShardedModel
+from .sharded_model import StatefulModel
 from .sharded_utils import load_shard, get_image_from_str
 from ..shard import Shard
-from typing import Optional
+from typing import Dict, Optional, Tuple
 from exo.download.shard_download import ShardDownloader
 import asyncio
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
+def sample_logits(
+  logits: mx.array,
+  temp: float = 0.0,
+  top_p: float = 1.0,
+  logit_bias: Optional[Dict[int, float]] = None
+) -> Tuple[mx.array, float]:
+  if logit_bias:
+    indices = mx.array(list(logit_bias.keys()))
+    values = mx.array(list(logit_bias.values()))
+    logits[:, indices] += values
 
+  if temp == 0:
+    token = mx.argmax(logits, axis=-1)
+  else:
+    if top_p > 0 and top_p < 1.0:
+      token = top_p_sampling(logits, top_p, temp)
+    else:
+      token = mx.random.categorical(logits*(1/temp))
+
+  return token
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
@@ -17,25 +37,30 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def sample(self, x):
+    y = mx.array(x)
+    logits = y[:, -1, :]
+    out = np.array(sample_logits(logits))
+    return out
+
+  async def encode(self, shard: Shard, prompt: str):
     await self.ensure_shard(shard)
-    loop = asyncio.get_running_loop()
-    if image_str:
-      image = await get_image_from_str(image_str)
-      tokenize = partial(self.tokenizer, prompt, image, return_tensors="np")
-      inputs = await loop.run_in_executor(self.executor, tokenize)
-      pixel_values = mx.array(inputs["pixel_values"])
-      input_ids = mx.array(inputs["input_ids"])
-      output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids, pixel_values))
-    else:
-      input_ids = mx.array(await loop.run_in_executor(self.executor, self.tokenizer.encode, prompt))
-      output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids))
-    return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
+    return tokens
+
+  async def decode(self, shard: Shard, tokens):
+    await self.ensure_shard(shard)
+    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
+    return tokens
+    
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, bool):
+    output_data = await self.infer_tensor(request_id, shard, await self.encode(shard, prompt), inference_state)
+    return output_data 
 
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, bool):
     await self.ensure_shard(shard)
-    output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, mx.array(input_data)))
-    return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+    output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.model, mx.array(input_data), request_id))
+    return output_data
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
@@ -50,5 +75,5 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
         return asyncio.run(load_shard(model_path, shard))
 
       model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
-      self.stateful_sharded_model = await loop.run_in_executor(self.executor, StatefulShardedModel, shard, model_shard)
       self.shard = shard
+      self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard) 
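
The new module-level sample_logits reproduces the sampling previously embedded in StatefulShardedModel.step: greedy argmax at temp == 0, nucleus sampling when 0 < top_p < 1.0, and a plain temperature-scaled categorical otherwise. Note the hunk doesn't show an import for top_p_sampling in this file; the sketch below assumes it is still available from mlx_lm.sample_utils, as it was in sharded_model.py:

import mlx.core as mx
from mlx_lm.sample_utils import top_p_sampling  # assumed import, not shown in the diff

logits = mx.array([[1.0, 2.0, 4.0, 0.5]])
greedy  = mx.argmax(logits, axis=-1)               # temp == 0 branch: always token 2
scaled  = mx.random.categorical(logits * (1/0.8))  # temp > 0, top_p == 1.0 branch
nucleus = top_p_sampling(logits, 0.9, 0.8)         # 0 < top_p < 1.0 branch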

+ 16 - 62
exo/inference/mlx/sharded_model.py

@@ -8,72 +8,14 @@ from mlx_lm.sample_utils import top_p_sampling
 
 from ..shard import Shard
 
-
-# TODO: support a speculative model so we can parallelise compute across devices
-class StatefulShardedModel:
-  def __init__(self, shard: Shard, model: nn.Module, max_kv_size: int = 1024, max_caches: int = 2):
-    self.shard = shard
+class StatefulModel(nn.Module):
+  def __init__(self, model, max_kv_size: int = 1024, max_caches: int = 2):
+    super().__init__()
     self.model = model
     self.max_kv_size = max_kv_size
     self.max_caches = max_caches
     self.caches = OrderedDict()
-
-  def step(
-    self,
-    request_id: str,
-    x,
-    pixel_values=None,
-    temp: float = 0.0,
-    top_p: float = 1.0,
-    logit_bias: Optional[Dict[int, float]] = None,
-  ) -> Generator[Tuple[mx.array, mx.array], None, None]:
-    def sample(logits: mx.array) -> Tuple[mx.array, float]:
-      if logit_bias:
-        indices = mx.array(list(logit_bias.keys()))
-        values = mx.array(list(logit_bias.values()))
-        logits[:, indices] += values
-
-      if temp == 0:
-        token = mx.argmax(logits, axis=-1)
-      else:
-        if top_p > 0 and top_p < 1.0:
-          token = top_p_sampling(logits, top_p, temp)
-        else:
-          token = mx.random.categorical(logits*(1/temp))
-
-      return token
-
-    y = x
-
-    if request_id not in self.caches:
-      self.init_cache(request_id)
-    else:
-      self.caches.move_to_end(request_id)
-
-    cache = self.caches[request_id]
-
-    if pixel_values is None:
-      output = self.model(y[None] if self.shard.is_first_layer() else y, cache=cache)
-    else:
-      output = self.model(y, pixel_values=pixel_values, cache=cache)
-
-    if self.shard.is_last_layer():
-      logits = output[:, -1, :]
-      y = sample(logits)
-      return y
-    else:
-      return output
-
-  def __call__(
-    self,
-    request_id: str,
-    x,
-    temp: float = 0.0,
-    top_p: float = 1.0,
-    logit_bias: Optional[Dict[int, float]] = None,
-  ) -> Generator[Tuple[mx.array, mx.array], None, None]:
-    return self.step(request_id, x, temp=temp, top_p=top_p, logit_bias=logit_bias)
-
+  
   def init_cache(self, request_id: str):
     kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
     # if self.max_kv_size is not None:
@@ -87,3 +29,15 @@ class StatefulShardedModel:
       self.caches.popitem(last=False)
 
     self.caches[request_id] = cache
+
+  def __call__(self, x, request_id: str):
+    if request_id not in self.caches:
+      self.init_cache(request_id)
+    else:
+      self.caches.move_to_end(request_id)
+
+    cache = self.caches[request_id]
+
+    y = self.model(x, cache=cache)
+    return y
+    
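
StatefulModel keeps per-request KV caches in an OrderedDict and evicts the least recently used entry once max_caches is reached. A self-contained sketch of that bookkeeping (mirroring __call__ plus init_cache), with the real cache construction stubbed out:

from collections import OrderedDict

caches = OrderedDict()
max_caches = 2

def touch(request_id: str):
  if request_id not in caches:
    if len(caches) >= max_caches:
      caches.popitem(last=False)   # evict the least recently used request
    caches[request_id] = object()  # stands in for the real KV cache
  else:
    caches.move_to_end(request_id)  # mark as most recently used

for rid in ["a", "b", "a", "c"]:
  touch(rid)
print(list(caches))  # ['a', 'c']: 'b' was evicted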

+ 10 - 3
exo/inference/mlx/sharded_utils.py

@@ -68,7 +68,6 @@ def load_config(model_path: Path) -> dict:
     raise
   return config
 
-
 def load_model_shard(
   model_path: Path,
   shard: Shard,
@@ -131,8 +130,17 @@ def load_model_shard(
 
   model_class, model_args_class = _get_classes(config=config)
 
+  class ShardedModel(model_class):
+    def __init__(self, args):
+      super().__init__(args)
+      self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
+
+    def __call__(self, x, *args, **kwargs):
+      y = super().__call__(x[None] if self.shard.is_first_layer() else x, *args, **kwargs)
+      return y
+
   model_args = model_args_class.from_dict(config)
-  model = model_class(model_args)
+  model = ShardedModel(model_args)
 
   if hasattr(model, "sanitize"):
     weights = model.sanitize(weights)
@@ -158,7 +166,6 @@ def load_model_shard(
   model.eval()
   return model
 
-
 async def load_shard(
   model_path: str,
   shard: Shard,
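
The ShardedModel subclass moves the x[None] batching out of StatefulShardedModel.step and into the model itself: the first shard receives raw token ids and gains a batch axis, while later shards pass hidden states through unchanged. A toy illustration of that rule, with a stand-in shard object:

import mlx.core as mx

class FakeShard:  # stand-in for exo's Shard, illustration only
  def __init__(self, start_layer): self.start_layer = start_layer
  def is_first_layer(self): return self.start_layer == 0

x = mx.array([1, 2, 3])  # token ids entering the first shard
y = x[None] if FakeShard(0).is_first_layer() else x
print(y.shape)  # (1, 3): batch axis added only on the first shard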

+ 1 - 1
exo/inference/tinygrad/inference.py

@@ -65,7 +65,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> tuple[np.ndarray, str, bool]:
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
     start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
     n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)

+ 2 - 2
exo/main.py

@@ -189,7 +189,7 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
 
   try:
     print(f"Processing prompt: {prompt}")
-    await node.process_prompt(shard, prompt, None, request_id=request_id)
+    await node.process_prompt(shard, prompt, request_id=request_id)
 
     _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
 
@@ -238,4 +238,4 @@ def run():
 
 
 if __name__ == "__main__":
-  run()
+  run()

+ 1 - 2
exo/networking/grpc/grpc_peer_handle.py

@@ -63,10 +63,9 @@ class GRPCPeerHandle(PeerHandle):
         traceback.print_exc()
       return False
 
-  async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.PromptRequest(
       prompt=prompt,
-      image_str=image_str,
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
         start_layer=shard.start_layer,

+ 2 - 3
exo/networking/grpc/grpc_server.py

@@ -49,10 +49,9 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
       n_layers=request.shard.n_layers,
     )
     prompt = request.prompt
-    image_str = request.image_str
     request_id = request.request_id
-    result = await self.node.process_prompt(shard, prompt, image_str, request_id)
-    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {image_str=} {request_id=} result: {result}")
+    result = await self.node.process_prompt(shard, prompt, request_id)
+    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
 

+ 3 - 4
exo/networking/grpc/node_service.proto

@@ -22,9 +22,8 @@ message Shard {
 message PromptRequest {
   Shard shard = 1;
   string prompt = 2;
-  optional string image_str = 3;
-  optional string request_id = 4;
-  optional string inference_state = 5;
+  optional string request_id = 3;
+  optional string inference_state = 4;
 }
 
 message TensorRequest {
@@ -93,4 +92,4 @@ message HealthCheckResponse {
   bool is_healthy = 1;
 }
 
-message Empty {}
+message Empty {}
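
With image_str gone, request_id and inference_state shift down to field numbers 3 and 4. Renumbering proto fields is wire-incompatible with peers running the old schema, which is consistent with the pb2/pb2_grpc stubs being regenerated in the same commit. A sketch of building the new request from the generated module (the model values are placeholders):

from exo.networking.grpc import node_service_pb2

request = node_service_pb2.PromptRequest(
  prompt="What is 2+2?",
  shard=node_service_pb2.Shard(model_id="some-model", start_layer=0, end_layer=31, n_layers=32),
  request_id="req-123",  # now field 3 (was 4); no image_str between prompt and this
)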

+ 0 - 1
exo/networking/grpc/node_service_pb2.py

File diff suppressed because it is too large


+ 314 - 263
exo/networking/grpc/node_service_pb2_grpc.py

@@ -12,298 +12,349 @@ SCHEDULED_RELEASE_DATE = 'June 25, 2024'
 _version_not_supported = False
 
 try:
-  from grpc._utilities import first_version_is_lower
-  _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
+    from grpc._utilities import first_version_is_lower
+    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
 except ImportError:
-  _version_not_supported = True
+    _version_not_supported = True
 
 if _version_not_supported:
-  warnings.warn(
-    f'The grpc package installed is at version {GRPC_VERSION},' + f' but the generated code in node_service_pb2_grpc.py depends on' + f' grpcio>={GRPC_GENERATED_VERSION}.' +
-    f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' +
-    f' This warning will become an error in {EXPECTED_ERROR_RELEASE},' + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.', RuntimeWarning
-  )
+    warnings.warn(
+        f'The grpc package installed is at version {GRPC_VERSION},'
+        + f' but the generated code in node_service_pb2_grpc.py depends on'
+        + f' grpcio>={GRPC_GENERATED_VERSION}.'
+        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
+        + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
+        + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
+        RuntimeWarning
+    )
 
 
 class NodeServiceStub(object):
-  """Missing associated documentation comment in .proto file."""
-  def __init__(self, channel):
-    """Constructor.
+    """Missing associated documentation comment in .proto file."""
+
+    def __init__(self, channel):
+        """Constructor.
 
         Args:
             channel: A grpc.Channel.
         """
-    self.SendPrompt = channel.unary_unary(
-      '/node_service.NodeService/SendPrompt',
-      request_serializer=node__service__pb2.PromptRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Tensor.FromString,
-      _registered_method=True
-    )
-    self.SendTensor = channel.unary_unary(
-      '/node_service.NodeService/SendTensor',
-      request_serializer=node__service__pb2.TensorRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Tensor.FromString,
-      _registered_method=True
-    )
-    self.GetInferenceResult = channel.unary_unary(
-      '/node_service.NodeService/GetInferenceResult',
-      request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
-      response_deserializer=node__service__pb2.InferenceResult.FromString,
-      _registered_method=True
-    )
-    self.CollectTopology = channel.unary_unary(
-      '/node_service.NodeService/CollectTopology',
-      request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Topology.FromString,
-      _registered_method=True
-    )
-    self.SendResult = channel.unary_unary(
-      '/node_service.NodeService/SendResult',
-      request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Empty.FromString,
-      _registered_method=True
-    )
-    self.SendOpaqueStatus = channel.unary_unary(
-      '/node_service.NodeService/SendOpaqueStatus',
-      request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Empty.FromString,
-      _registered_method=True
-    )
-    self.HealthCheck = channel.unary_unary(
-      '/node_service.NodeService/HealthCheck',
-      request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
-      response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
-      _registered_method=True
-    )
+        self.SendPrompt = channel.unary_unary(
+                '/node_service.NodeService/SendPrompt',
+                request_serializer=node__service__pb2.PromptRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Tensor.FromString,
+                _registered_method=True)
+        self.SendTensor = channel.unary_unary(
+                '/node_service.NodeService/SendTensor',
+                request_serializer=node__service__pb2.TensorRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Tensor.FromString,
+                _registered_method=True)
+        self.GetInferenceResult = channel.unary_unary(
+                '/node_service.NodeService/GetInferenceResult',
+                request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
+                response_deserializer=node__service__pb2.InferenceResult.FromString,
+                _registered_method=True)
+        self.CollectTopology = channel.unary_unary(
+                '/node_service.NodeService/CollectTopology',
+                request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Topology.FromString,
+                _registered_method=True)
+        self.SendResult = channel.unary_unary(
+                '/node_service.NodeService/SendResult',
+                request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
+                _registered_method=True)
+        self.SendOpaqueStatus = channel.unary_unary(
+                '/node_service.NodeService/SendOpaqueStatus',
+                request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
+                _registered_method=True)
+        self.HealthCheck = channel.unary_unary(
+                '/node_service.NodeService/HealthCheck',
+                request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
+                response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
+                _registered_method=True)
 
 
 class NodeServiceServicer(object):
-  """Missing associated documentation comment in .proto file."""
-  def SendPrompt(self, request, context):
     """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
 
-  def SendTensor(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def SendPrompt(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
-  def GetInferenceResult(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def SendTensor(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
-  def CollectTopology(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def GetInferenceResult(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
-  def SendResult(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def CollectTopology(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
-  def SendOpaqueStatus(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def SendResult(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
-  def HealthCheck(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def SendOpaqueStatus(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def HealthCheck(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
 
 def add_NodeServiceServicer_to_server(servicer, server):
-  rpc_method_handlers = {
-    'SendPrompt':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.SendPrompt,
-        request_deserializer=node__service__pb2.PromptRequest.FromString,
-        response_serializer=node__service__pb2.Tensor.SerializeToString,
-      ),
-    'SendTensor':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.SendTensor,
-        request_deserializer=node__service__pb2.TensorRequest.FromString,
-        response_serializer=node__service__pb2.Tensor.SerializeToString,
-      ),
-    'GetInferenceResult':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.GetInferenceResult,
-        request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
-        response_serializer=node__service__pb2.InferenceResult.SerializeToString,
-      ),
-    'CollectTopology':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.CollectTopology,
-        request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
-        response_serializer=node__service__pb2.Topology.SerializeToString,
-      ),
-    'SendResult':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.SendResult,
-        request_deserializer=node__service__pb2.SendResultRequest.FromString,
-        response_serializer=node__service__pb2.Empty.SerializeToString,
-      ),
-    'SendOpaqueStatus':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.SendOpaqueStatus,
-        request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
-        response_serializer=node__service__pb2.Empty.SerializeToString,
-      ),
-    'HealthCheck':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.HealthCheck,
-        request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
-        response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
-      ),
-  }
-  generic_handler = grpc.method_handlers_generic_handler('node_service.NodeService', rpc_method_handlers)
-  server.add_generic_rpc_handlers((generic_handler,))
-  server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
+    rpc_method_handlers = {
+            'SendPrompt': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendPrompt,
+                    request_deserializer=node__service__pb2.PromptRequest.FromString,
+                    response_serializer=node__service__pb2.Tensor.SerializeToString,
+            ),
+            'SendTensor': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendTensor,
+                    request_deserializer=node__service__pb2.TensorRequest.FromString,
+                    response_serializer=node__service__pb2.Tensor.SerializeToString,
+            ),
+            'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
+                    servicer.GetInferenceResult,
+                    request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
+                    response_serializer=node__service__pb2.InferenceResult.SerializeToString,
+            ),
+            'CollectTopology': grpc.unary_unary_rpc_method_handler(
+                    servicer.CollectTopology,
+                    request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
+                    response_serializer=node__service__pb2.Topology.SerializeToString,
+            ),
+            'SendResult': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendResult,
+                    request_deserializer=node__service__pb2.SendResultRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
+            ),
+            'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendOpaqueStatus,
+                    request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
+            ),
+            'HealthCheck': grpc.unary_unary_rpc_method_handler(
+                    servicer.HealthCheck,
+                    request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
+                    response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
+            ),
+    }
+    generic_handler = grpc.method_handlers_generic_handler(
+            'node_service.NodeService', rpc_method_handlers)
+    server.add_generic_rpc_handlers((generic_handler,))
+    server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
 
 
-# This class is part of an EXPERIMENTAL API.
+ # This class is part of an EXPERIMENTAL API.
 class NodeService(object):
-  """Missing associated documentation comment in .proto file."""
-  @staticmethod
-  def SendPrompt(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/SendPrompt',
-      node__service__pb2.PromptRequest.SerializeToString,
-      node__service__pb2.Tensor.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    """Missing associated documentation comment in .proto file."""
 
-  @staticmethod
-  def SendTensor(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/SendTensor',
-      node__service__pb2.TensorRequest.SerializeToString,
-      node__service__pb2.Tensor.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def SendPrompt(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendPrompt',
+            node__service__pb2.PromptRequest.SerializeToString,
+            node__service__pb2.Tensor.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
 
-  @staticmethod
-  def GetInferenceResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/GetInferenceResult',
-      node__service__pb2.GetInferenceResultRequest.SerializeToString,
-      node__service__pb2.InferenceResult.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def SendTensor(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendTensor',
+            node__service__pb2.TensorRequest.SerializeToString,
+            node__service__pb2.Tensor.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
 
-  @staticmethod
-  def CollectTopology(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/CollectTopology',
-      node__service__pb2.CollectTopologyRequest.SerializeToString,
-      node__service__pb2.Topology.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def GetInferenceResult(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/GetInferenceResult',
+            node__service__pb2.GetInferenceResultRequest.SerializeToString,
+            node__service__pb2.InferenceResult.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
 
-  @staticmethod
-  def SendResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/SendResult',
-      node__service__pb2.SendResultRequest.SerializeToString,
-      node__service__pb2.Empty.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def CollectTopology(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/CollectTopology',
+            node__service__pb2.CollectTopologyRequest.SerializeToString,
+            node__service__pb2.Topology.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
 
-  @staticmethod
-  def SendOpaqueStatus(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/SendOpaqueStatus',
-      node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-      node__service__pb2.Empty.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def SendResult(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendResult',
+            node__service__pb2.SendResultRequest.SerializeToString,
+            node__service__pb2.Empty.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
 
-  @staticmethod
-  def HealthCheck(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/HealthCheck',
-      node__service__pb2.HealthCheckRequest.SerializeToString,
-      node__service__pb2.HealthCheckResponse.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def SendOpaqueStatus(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendOpaqueStatus',
+            node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+            node__service__pb2.Empty.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
+
+    @staticmethod
+    def HealthCheck(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/HealthCheck',
+            node__service__pb2.HealthCheckRequest.SerializeToString,
+            node__service__pb2.HealthCheckResponse.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)

+ 1 - 1
exo/networking/peer_handle.py

@@ -36,7 +36,7 @@ class PeerHandle(ABC):
     pass
 
   @abstractmethod
-  async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
     pass
 
   @abstractmethod

+ 1 - 1
exo/orchestration/node.py

@@ -16,7 +16,7 @@ class Node(ABC):
     pass
 
   @abstractmethod
-  async def process_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     pass
 
   @abstractmethod

+ 92 - 76
exo/orchestration/standard_node.py

@@ -18,7 +18,6 @@ from exo.download.hf.hf_helpers import RepoProgressEvent
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.download.hf.hf_shard_download import HFShardDownloader
 
-
 class StandardNode(Node):
   def __init__(
     self,
@@ -40,6 +39,7 @@ class StandardNode(Node):
     self.topology: Topology = Topology()
     self.device_capabilities = device_capabilities()
     self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
+    self.buffered_logits: Dict[str, Tuple[List[np.ndarray], bool]] = {}
     self.max_generate_tokens = max_generate_tokens
     self.topology_viz = topology_viz
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
@@ -100,8 +100,56 @@ class StandardNode(Node):
 
   def get_topology_inference_engines(self) -> List[List[str]]:
     return self.topology_inference_engines_pool
+  
+  async def encode_prompt(self, shard: Shard, prompt):
+    toks = await self.inference_engine.encode(shard, prompt)
+    return toks
+  
+  async def process_result(
+    self,
+    shard,
+    result,
+    request_id: Optional[str] = None,
+    inference_state: Optional[str] = None,
+  ):
+    if request_id not in self.buffered_token_output:
+      self.buffered_token_output[request_id] = ([], False)
+    
+    if request_id not in self.buffered_logits:
+      self.buffered_logits[request_id] = ([], False)
+
+    for i in np.reshape(result, (-1, 1, result.shape[-1])):
+      self.buffered_logits[request_id][0].append(i)
+
+    if shard.is_last_layer():
+      result = await self.inference_engine.sample(result)
+    
+    await self.inference_engine.ensure_shard(shard)
+    is_finished = result.size == 1 and result.item() == self.inference_engine.tokenizer.eos_token_id or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+
+    asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))  # TODO: this is n^2 communication complexity
+
+    if result.size == 1:  # we got a new token out
+      self.buffered_token_output[request_id][0].append(result.item())
+      self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
+    
+    if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
 
-  async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+    if is_finished:
+      self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
+      self.buffered_logits[request_id] = (self.buffered_logits[request_id][0], True)
+    else:
+      asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
+
+    return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+
+  async def process_prompt(
+    self,
+    base_shard: Shard,
+    prompt: str,
+    request_id: Optional[str] = None,
+    inference_state: Optional[str] = None
+  ) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(
       self.broadcast_opaque_status(
@@ -113,14 +161,13 @@ class StandardNode(Node):
           "base_shard": base_shard.to_dict(),
           "shard": shard.to_dict(),
           "prompt": prompt,
-          "image_str": image_str,
           "inference_state": inference_state,
           "request_id": request_id,
         }),
       )
     )
     start_time = time.perf_counter_ns()
-    resp = await self._process_prompt(base_shard, prompt, image_str, request_id, inference_state)
+    resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
     asyncio.create_task(
@@ -133,7 +180,6 @@ class StandardNode(Node):
           "base_shard": base_shard.to_dict(),
           "shard": shard.to_dict(),
           "prompt": prompt,
-          "image_str": image_str,
           "inference_state": inference_state,
           "request_id": request_id,
           "elapsed_time_ns": elapsed_time_ns,
@@ -143,35 +189,20 @@ class StandardNode(Node):
     )
     return resp
 
-  async def _process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
-    if request_id not in self.buffered_token_output:
-      self.buffered_token_output[request_id] = ([], False)
     shard = self.get_current_shard(base_shard)
 
-    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {image_str=}")
+    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
     if shard.start_layer != 0:
-      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=} {image_str=}")
-      await self.forward_to_next_shard(shard, prompt, request_id, image_str=image_str, inference_state=inference_state)
-      return
-
-    result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, image_str, inference_state=inference_state)
-    is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-    if is_finished:
-      self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
-    asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))  # TODO: this is n^2 communication complexity
-
-    if result.size == 1:
-      self.buffered_token_output[request_id][0].append(result.item())
-      self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-
-    if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-
-    if not is_finished:
-      asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, image_str=image_str, inference_state=inference_state))
-
-    return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
+      await self.forward_to_next_shard(shard, prompt, request_id, inference_state=inference_state)
+      return None
+    else:
+      result = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state=inference_state)
+      ret = await self.process_result(shard, result, request_id, inference_state=inference_state) 
+      return result
 
   async def process_tensor(
     self,
@@ -227,27 +258,13 @@ class StandardNode(Node):
   ) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
-    if request_id not in self.buffered_token_output:
-      self.buffered_token_output[request_id] = ([], False)
     shard = self.get_current_shard(base_shard)
 
+    if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
     try:
-      if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
-      result, inference_state, is_finished = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state=inference_state)
-      is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-      if is_finished:
-        self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
-      asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))  # TODO: this is n^2 communication complexity
-
-      if result.size == 1:  # we got a new token out
-        self.buffered_token_output[request_id][0].append(result.item())
-        self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-      if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-
-      if not is_finished:
-        asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
-
-      return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+      result = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state=inference_state)
+      ret = await self.process_result(shard, result, request_id, inference_state=inference_state) 
+      return ret
     except Exception as e:
       print(f"Error processing tensor for shard {shard}: {e}")
       traceback.print_exc()
@@ -258,49 +275,48 @@ class StandardNode(Node):
     base_shard: Shard,
     tensor_or_prompt: Union[np.ndarray, str],
     request_id: str,
-    image_str: Optional[str] = None,
     inference_state: Optional[str] = None,
   ) -> None:
     if not self.partitioning_strategy:
       if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
       return
-    shard = self.get_current_shard(base_shard)
 
-    partitions = self.partitioning_strategy.partition(self.topology)
-    shards = map_partitions_to_shards(self.partitioning_strategy.partition(self.topology), base_shard.n_layers, base_shard.model_id)
-    current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
+    next_partition_index = self.get_partition_index(offset = 1)
     if DEBUG >= 1: print(f"Current partition index: {current_partition_index}")
-    if current_partition_index is not None:
-      next_partition_index = (current_partition_index+1) % len(partitions)
-      next_partition: Partition = partitions[next_partition_index]
-      next_shard = shards[next_partition_index]
+    if next_partition_index is not None:
+      target_id = self.partitioning_strategy.partition(self.topology)[next_partition_index].node_id
+      next_shard = self.get_current_shard(base_shard, next_partition_index)
       if DEBUG >= 2: print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}")
-
-      if next_partition.node_id == self.id:
-        if isinstance(tensor_or_prompt, np.ndarray):
-          await self.process_tensor(shard, tensor_or_prompt, request_id, inference_state=inference_state)
+      is_tensor = isinstance(tensor_or_prompt, np.ndarray)
+      if target_id == self.id:
+        if is_tensor:
+          await self.process_tensor(next_shard, tensor_or_prompt, request_id, inference_state=inference_state)
         else:
-          await self.process_prompt(shard, tensor_or_prompt, image_str, request_id, inference_state=inference_state)
-        return
-
-      target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
-      if not target_peer:
-        raise ValueError(f"Peer for {next_partition} not found")
-
-      if DEBUG >= 1: print(f"Sending tensor_or_prompt to {target_peer.id()}: {tensor_or_prompt}")
-
-      if isinstance(tensor_or_prompt, np.ndarray):
-        await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
+          await self.process_prompt(next_shard, tensor_or_prompt, request_id, inference_state=inference_state)
       else:
-        await target_peer.send_prompt(next_shard, tensor_or_prompt, image_str=image_str, request_id=request_id, inference_state=inference_state)
+        target_peer = next((p for p in self.peers if p.id() == target_id), None)
+        if not target_peer:
+          raise ValueError(f"Peer for {target_id} not found")
+        
+        if is_tensor:
+          if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor_or_prompt}")
+          await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
+        else:
+          await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
 
-  def get_current_shard(self, base_shard: Shard) -> Shard:
+  def get_partition_index(self, offset: int = 0):
     partitions = self.partitioning_strategy.partition(self.topology)
-    shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
     current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
     if current_partition_index is None:
       raise ValueError(f"No current partition found for node: {self.id}")
-    return shards[current_partition_index]
+    return (current_partition_index + offset) % len(partitions)
+
+  def get_current_shard(self, base_shard: Shard, index: Optional[int] = None) -> Shard:
+    if index is None:
+      index = self.get_partition_index()
+    partitions = self.partitioning_strategy.partition(self.topology)
+    shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
+    return shards[index]
 
   async def update_peers(self, wait_for_peers: int = 0) -> bool:
     next_peers = await self.discovery.discover_peers(wait_for_peers)
@@ -428,7 +444,7 @@ class StandardNode(Node):
   def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
     if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
     self.on_token.trigger_all(request_id, tokens, is_finished)
-
+  
   async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
     async def send_result_to_peer(peer):
       try:
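
forward_to_next_shard now computes its target via get_partition_index(offset=1), which wraps modulo the partition count, so the last shard in the ring forwards back to the first. The wrap-around in isolation, with hypothetical node ids:

partitions = ["node-a", "node-b", "node-c"]  # ordered by the partitioning strategy
current = partitions.index("node-c")         # this node is last in the ring
next_index = (current + 1) % len(partitions)
print(partitions[next_index])                # node-a: wraps back to the first shard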

Some files were not shown because too many files changed in this diff