
Some initial inference engine refactors for enabling training

Only on MLX for now; breaks Tinygrad (doesn't fulfill the interface yet)
Nel Nibcord 10 months ago
parent
commit
82cce4408e

+ 4 - 12
exo/api/chatgpt_api.py

@@ -117,19 +117,11 @@ def remap_messages(messages: List[Message]) -> List[Message]:
 def build_prompt(tokenizer, _messages: List[Message]):
   messages = remap_messages(_messages)
   prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True)
-  image_str = None
   for message in messages:
     if not isinstance(message.content, list):
       continue
 
-    for content in message.content:
-      # note: we only support one image at a time right now. Multiple is possible. See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41
-      # follows the convention in https://platform.openai.com/docs/guides/vision
-      if isinstance(content, dict) and content.get("type", None) == "image":
-        image_str = content.get("image", None)
-        break
-
-  return prompt, image_str
+  return prompt
 
 
 def parse_message(data: dict):
@@ -246,7 +238,7 @@ class ChatGPTAPI:
     tokenizer = await resolve_tokenizer(shard.model_id)
     if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
 
-    prompt, image_str = build_prompt(tokenizer, chat_request.messages)
+    prompt = build_prompt(tokenizer, chat_request.messages)
     request_id = str(uuid.uuid4())
     if self.on_chat_completion_request:
       try:
@@ -269,10 +261,10 @@ class ChatGPTAPI:
     callback_id = f"chatgpt-api-wait-response-{request_id}"
     callback = self.node.on_token.register(callback_id)
 
-    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=} {image_str=}")
+    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
 
     try:
-      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, image_str, request_id=request_id))), timeout=self.response_timeout)
+      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)
 
       if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")
 

+ 1 - 1
exo/inference/dummy_inference_engine.py

@@ -14,7 +14,7 @@ class DummyInferenceEngine(InferenceEngine):
     self.latency_mean = 0.1
     self.latency_stddev = 0.02
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     try:
       await self.ensure_shard(shard)
 

+ 13 - 3
exo/inference/inference_engine.py

@@ -9,13 +9,23 @@ from .shard import Shard
 
 class InferenceEngine(ABC):
   @abstractmethod
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
+    pass
+  
+  async def sample(self, x: np.ndarray) -> np.ndarray:
+    pass
+
+  @abstractmethod
+  async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
     pass
 
   @abstractmethod
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> np.ndarray:
     pass
 
+  @abstractmethod
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> np.ndarray:
+    pass
 
 def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
   if DEBUG >= 2:
@@ -33,4 +43,4 @@ def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDow
   elif inference_engine_name == "dummy":
     from exo.inference.dummy_inference_engine import DummyInferenceEngine
     return DummyInferenceEngine()
-  raise ValueError(f"Unsupported inference engine: {inference_engine_name}")
+  raise ValueError(f"Unsupported inference engine: {inference_engine_name}")
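
For orientation, a minimal sketch (not part of this commit) of how the split interface above could compose into a single-node generation loop. The generate name, the "example-request" id, the eos stopping check, and the assumption that the engine exposes its tokenizer are all illustrative.

import numpy as np

from exo.inference.inference_engine import InferenceEngine
from exo.inference.shard import Shard


async def generate(engine: InferenceEngine, shard: Shard, prompt: str, max_tokens: int = 64) -> str:
  # Illustrative driver loop only; not part of this diff.
  tokens = list(await engine.encode(shard, prompt))                  # prompt -> token ids
  x = np.array(tokens)                                               # first pass: the whole prompt
  for _ in range(max_tokens):
    logits = await engine.infer_tensor("example-request", shard, x)  # forward pass through the shard
    next_token = int((await engine.sample(logits)).item())           # logits -> next token id
    tokens.append(next_token)
    if next_token == engine.tokenizer.eos_token_id:                  # assumes the engine exposes its tokenizer
      break
    x = np.array([next_token])                                       # later passes: only the new token (KV cache holds the rest)
  return await engine.decode(shard, np.array(tokens))                # token ids -> text

The first call feeds the whole prompt; subsequent calls feed only the newly sampled token, relying on the per-request KV cache kept by the MLX model wrapper below.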

+ 44 - 19
exo/inference/mlx/sharded_inference_engine.py

@@ -1,15 +1,35 @@
 import numpy as np
 import mlx.core as mx
+import mlx.nn as nn
 from ..inference_engine import InferenceEngine
-from .sharded_model import StatefulShardedModel
+from .sharded_model import StatefulModel
 from .sharded_utils import load_shard, get_image_from_str
 from ..shard import Shard
-from typing import Optional
+from typing import Dict, Optional, Tuple
 from exo.download.shard_download import ShardDownloader
 import asyncio
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
+def sample_logits(
+  logits: mx.array,
+  temp: float = 0.0,
+  top_p: float = 1.0,
+  logit_bias: Optional[Dict[int, float]] = None
+) -> Tuple[mx.array, float]:
+  if logit_bias:
+    indices = mx.array(list(logit_bias.keys()))
+    values = mx.array(list(logit_bias.values()))
+    logits[:, indices] += values
 
+  if temp == 0:
+    token = mx.argmax(logits, axis=-1)
+  else:
+    if top_p > 0 and top_p < 1.0:
+      token = top_p_sampling(logits, top_p, temp)
+    else:
+      token = mx.random.categorical(logits*(1/temp))
+
+  return token
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
@@ -17,25 +37,30 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def sample(self, x):
+    y = mx.array(x)
+    logits = y[:, -1, :]
+    out = np.array(sample_logits(logits))
+    return out
+
+  async def encode(self, shard: Shard, prompt: str):
     await self.ensure_shard(shard)
-    loop = asyncio.get_running_loop()
-    if image_str:
-      image = await get_image_from_str(image_str)
-      tokenize = partial(self.tokenizer, prompt, image, return_tensors="np")
-      inputs = await loop.run_in_executor(self.executor, tokenize)
-      pixel_values = mx.array(inputs["pixel_values"])
-      input_ids = mx.array(inputs["input_ids"])
-      output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids, pixel_values))
-    else:
-      input_ids = mx.array(await loop.run_in_executor(self.executor, self.tokenizer.encode, prompt))
-      output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids))
-    return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
+    return tokens
+
+  async def decode(self, shard: Shard, tokens):
+    await self.ensure_shard(shard)
+    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
+    return tokens
+    
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, bool):
+    output_data = await self.infer_tensor(request_id, shard, await self.encode(shard, prompt), inference_state)
+    return output_data 
 
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, bool):
     await self.ensure_shard(shard)
-    output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, mx.array(input_data)))
-    return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+    output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.model, mx.array(input_data), request_id))
+    return output_data
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
@@ -50,5 +75,5 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
         return asyncio.run(load_shard(model_path, shard))
 
       model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
-      self.stateful_sharded_model = await loop.run_in_executor(self.executor, StatefulShardedModel, shard, model_shard)
       self.shard = shard
+      self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard) 
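
A small illustration of the sampling behaviour introduced above, assuming a (batch, vocab) mx.array of logits; with 0 < top_p < 1 the function defers to top_p_sampling from mlx_lm.sample_utils, which this sketch deliberately avoids.

import mlx.core as mx

toy_logits = mx.array([[0.1, 2.0, 0.3]])          # toy (batch=1, vocab=3) logits
greedy = sample_logits(toy_logits, temp=0.0)       # temp == 0: argmax, picks token id 1
stochastic = sample_logits(toy_logits, temp=0.8)   # temp > 0: draw from categorical(logits / 0.8)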

+ 16 - 62
exo/inference/mlx/sharded_model.py

@@ -8,72 +8,14 @@ from mlx_lm.sample_utils import top_p_sampling
 
 from ..shard import Shard
 
-
-# TODO: support a speculative model so we can parallelise compute across devices
-class StatefulShardedModel:
-  def __init__(self, shard: Shard, model: nn.Module, max_kv_size: int = 1024, max_caches: int = 2):
-    self.shard = shard
+class StatefulModel(nn.Module):
+  def __init__(self, model, max_kv_size: int = 1024, max_caches: int = 2):
+    super().__init__()
     self.model = model
     self.max_kv_size = max_kv_size
     self.max_caches = max_caches
     self.caches = OrderedDict()
-
-  def step(
-    self,
-    request_id: str,
-    x,
-    pixel_values=None,
-    temp: float = 0.0,
-    top_p: float = 1.0,
-    logit_bias: Optional[Dict[int, float]] = None,
-  ) -> Generator[Tuple[mx.array, mx.array], None, None]:
-    def sample(logits: mx.array) -> Tuple[mx.array, float]:
-      if logit_bias:
-        indices = mx.array(list(logit_bias.keys()))
-        values = mx.array(list(logit_bias.values()))
-        logits[:, indices] += values
-
-      if temp == 0:
-        token = mx.argmax(logits, axis=-1)
-      else:
-        if top_p > 0 and top_p < 1.0:
-          token = top_p_sampling(logits, top_p, temp)
-        else:
-          token = mx.random.categorical(logits*(1/temp))
-
-      return token
-
-    y = x
-
-    if request_id not in self.caches:
-      self.init_cache(request_id)
-    else:
-      self.caches.move_to_end(request_id)
-
-    cache = self.caches[request_id]
-
-    if pixel_values is None:
-      output = self.model(y[None] if self.shard.is_first_layer() else y, cache=cache)
-    else:
-      output = self.model(y, pixel_values=pixel_values, cache=cache)
-
-    if self.shard.is_last_layer():
-      logits = output[:, -1, :]
-      y = sample(logits)
-      return y
-    else:
-      return output
-
-  def __call__(
-    self,
-    request_id: str,
-    x,
-    temp: float = 0.0,
-    top_p: float = 1.0,
-    logit_bias: Optional[Dict[int, float]] = None,
-  ) -> Generator[Tuple[mx.array, mx.array], None, None]:
-    return self.step(request_id, x, temp=temp, top_p=top_p, logit_bias=logit_bias)
-
+  
   def init_cache(self, request_id: str):
     kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
     # if self.max_kv_size is not None:
@@ -87,3 +29,15 @@ class StatefulShardedModel:
       self.caches.popitem(last=False)
 
     self.caches[request_id] = cache
+
+  def __call__(self, x, request_id: str):
+    if request_id not in self.caches:
+      self.init_cache(request_id)
+    else:
+      self.caches.move_to_end(request_id)
+
+    cache = self.caches[request_id]
+
+    y = self.model(x, cache=cache)
+    return y
+    
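
The cache handling above keeps at most max_caches per-request KV caches in an OrderedDict and evicts the least recently used request. A standalone sketch of just that policy (illustration only; the real StatefulModel stores mlx-lm KV cache objects rather than arbitrary values):

from collections import OrderedDict


class RequestCacheLRU:
  # Sketch of the per-request cache policy used by StatefulModel above.
  def __init__(self, max_caches: int = 2):
    self.max_caches = max_caches
    self.caches = OrderedDict()

  def get_or_create(self, request_id: str, make_cache):
    if request_id in self.caches:
      self.caches.move_to_end(request_id)      # mark this request as most recently used
    else:
      if len(self.caches) >= self.max_caches:
        self.caches.popitem(last=False)        # evict the least recently used request
      self.caches[request_id] = make_cache()
    return self.caches[request_id]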

+ 10 - 3
exo/inference/mlx/sharded_utils.py

@@ -68,7 +68,6 @@ def load_config(model_path: Path) -> dict:
     raise
   return config
 
-
 def load_model_shard(
   model_path: Path,
   shard: Shard,
@@ -131,8 +130,17 @@ def load_model_shard(
 
   model_class, model_args_class = _get_classes(config=config)
 
+  class ShardedModel(model_class):
+    def __init__(self, args):
+      super().__init__(args)
+      self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
+
+    def __call__(self, x, *args, **kwargs):
+      y = super().__call__(x[None] if self.shard.is_first_layer() else x, *args, **kwargs)
+      return y
+
   model_args = model_args_class.from_dict(config)
-  model = model_class(model_args)
+  model = ShardedModel(model_args)
 
   if hasattr(model, "sanitize"):
     weights = model.sanitize(weights)
@@ -158,7 +166,6 @@ def load_model_shard(
   model.eval()
   return model
 
-
 async def load_shard(
   model_path: str,
   shard: Shard,

+ 1 - 1
exo/inference/tinygrad/inference.py

@@ -65,7 +65,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> tuple[np.ndarray, str, bool]:
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
     start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
     n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)

+ 2 - 2
exo/main.py

@@ -189,7 +189,7 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
 
   try:
     print(f"Processing prompt: {prompt}")
-    await node.process_prompt(shard, prompt, None, request_id=request_id)
+    await node.process_prompt(shard, prompt, request_id=request_id)
 
     _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
 
@@ -238,4 +238,4 @@ def run():
 
 
 if __name__ == "__main__":
-  run()
+  run()

+ 1 - 2
exo/networking/grpc/grpc_peer_handle.py

@@ -63,10 +63,9 @@ class GRPCPeerHandle(PeerHandle):
         traceback.print_exc()
       return False
 
-  async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.PromptRequest(
       prompt=prompt,
-      image_str=image_str,
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
         start_layer=shard.start_layer,

+ 2 - 3
exo/networking/grpc/grpc_server.py

@@ -49,10 +49,9 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
       n_layers=request.shard.n_layers,
     )
     prompt = request.prompt
-    image_str = request.image_str
     request_id = request.request_id
-    result = await self.node.process_prompt(shard, prompt, image_str, request_id)
-    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {image_str=} {request_id=} result: {result}")
+    result = await self.node.process_prompt(shard, prompt, request_id)
+    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
 

+ 3 - 4
exo/networking/grpc/node_service.proto

@@ -22,9 +22,8 @@ message Shard {
 message PromptRequest {
   Shard shard = 1;
   string prompt = 2;
-  optional string image_str = 3;
-  optional string request_id = 4;
-  optional string inference_state = 5;
+  optional string request_id = 3;
+  optional string inference_state = 4;
 }
 
 message TensorRequest {
@@ -93,4 +92,4 @@ message HealthCheckResponse {
   bool is_healthy = 1;
 }
 
-message Empty {}
+message Empty {}
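
Because image_str was dropped and the remaining fields renumbered, the generated Python stubs have to be regenerated from this .proto, and all nodes must run the same proto version, since renumbered fields are not wire-compatible with old peers. A typical grpcio-tools invocation is sketched below; the paths are assumptions rather than taken from the repo's own tooling.

from grpc_tools import protoc

# Hypothetical regeneration step, run from the repository root after editing node_service.proto.
protoc.main([
  "grpc_tools.protoc",
  "-Iexo/networking/grpc",
  "--python_out=exo/networking/grpc",
  "--grpc_python_out=exo/networking/grpc",
  "exo/networking/grpc/node_service.proto",
])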

File diff suppressed because it is too large
+ 0 - 1
exo/networking/grpc/node_service_pb2.py


+ 314 - 263
exo/networking/grpc/node_service_pb2_grpc.py

@@ -12,298 +12,349 @@ SCHEDULED_RELEASE_DATE = 'June 25, 2024'
 _version_not_supported = False
 
 try:
-  from grpc._utilities import first_version_is_lower
-  _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
+    from grpc._utilities import first_version_is_lower
+    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
 except ImportError:
-  _version_not_supported = True
+    _version_not_supported = True
 
 if _version_not_supported:
-  warnings.warn(
-    f'The grpc package installed is at version {GRPC_VERSION},' + f' but the generated code in node_service_pb2_grpc.py depends on' + f' grpcio>={GRPC_GENERATED_VERSION}.' +
-    f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' +
-    f' This warning will become an error in {EXPECTED_ERROR_RELEASE},' + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.', RuntimeWarning
-  )
+    warnings.warn(
+        f'The grpc package installed is at version {GRPC_VERSION},'
+        + f' but the generated code in node_service_pb2_grpc.py depends on'
+        + f' grpcio>={GRPC_GENERATED_VERSION}.'
+        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
+        + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
+        + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
+        RuntimeWarning
+    )
 
 
 class NodeServiceStub(object):
-  """Missing associated documentation comment in .proto file."""
-  def __init__(self, channel):
-    """Constructor.
+    """Missing associated documentation comment in .proto file."""
+
+    def __init__(self, channel):
+        """Constructor.
 
         Args:
             channel: A grpc.Channel.
         """
-    self.SendPrompt = channel.unary_unary(
-      '/node_service.NodeService/SendPrompt',
-      request_serializer=node__service__pb2.PromptRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Tensor.FromString,
-      _registered_method=True
-    )
-    self.SendTensor = channel.unary_unary(
-      '/node_service.NodeService/SendTensor',
-      request_serializer=node__service__pb2.TensorRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Tensor.FromString,
-      _registered_method=True
-    )
-    self.GetInferenceResult = channel.unary_unary(
-      '/node_service.NodeService/GetInferenceResult',
-      request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
-      response_deserializer=node__service__pb2.InferenceResult.FromString,
-      _registered_method=True
-    )
-    self.CollectTopology = channel.unary_unary(
-      '/node_service.NodeService/CollectTopology',
-      request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Topology.FromString,
-      _registered_method=True
-    )
-    self.SendResult = channel.unary_unary(
-      '/node_service.NodeService/SendResult',
-      request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Empty.FromString,
-      _registered_method=True
-    )
-    self.SendOpaqueStatus = channel.unary_unary(
-      '/node_service.NodeService/SendOpaqueStatus',
-      request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Empty.FromString,
-      _registered_method=True
-    )
-    self.HealthCheck = channel.unary_unary(
-      '/node_service.NodeService/HealthCheck',
-      request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
-      response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
-      _registered_method=True
-    )
+        self.SendPrompt = channel.unary_unary(
+                '/node_service.NodeService/SendPrompt',
+                request_serializer=node__service__pb2.PromptRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Tensor.FromString,
+                _registered_method=True)
+        self.SendTensor = channel.unary_unary(
+                '/node_service.NodeService/SendTensor',
+                request_serializer=node__service__pb2.TensorRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Tensor.FromString,
+                _registered_method=True)
+        self.GetInferenceResult = channel.unary_unary(
+                '/node_service.NodeService/GetInferenceResult',
+                request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
+                response_deserializer=node__service__pb2.InferenceResult.FromString,
+                _registered_method=True)
+        self.CollectTopology = channel.unary_unary(
+                '/node_service.NodeService/CollectTopology',
+                request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Topology.FromString,
+                _registered_method=True)
+        self.SendResult = channel.unary_unary(
+                '/node_service.NodeService/SendResult',
+                request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
+                _registered_method=True)
+        self.SendOpaqueStatus = channel.unary_unary(
+                '/node_service.NodeService/SendOpaqueStatus',
+                request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
+                _registered_method=True)
+        self.HealthCheck = channel.unary_unary(
+                '/node_service.NodeService/HealthCheck',
+                request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
+                response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
+                _registered_method=True)
 
 
 class NodeServiceServicer(object):
-  """Missing associated documentation comment in .proto file."""
-  def SendPrompt(self, request, context):
     """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
 
-  def SendTensor(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def SendPrompt(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
-  def GetInferenceResult(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def SendTensor(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
-  def CollectTopology(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def GetInferenceResult(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
-  def SendResult(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def CollectTopology(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
-  def SendOpaqueStatus(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def SendResult(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
-  def HealthCheck(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def SendOpaqueStatus(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def HealthCheck(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
 
 def add_NodeServiceServicer_to_server(servicer, server):
-  rpc_method_handlers = {
-    'SendPrompt':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.SendPrompt,
-        request_deserializer=node__service__pb2.PromptRequest.FromString,
-        response_serializer=node__service__pb2.Tensor.SerializeToString,
-      ),
-    'SendTensor':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.SendTensor,
-        request_deserializer=node__service__pb2.TensorRequest.FromString,
-        response_serializer=node__service__pb2.Tensor.SerializeToString,
-      ),
-    'GetInferenceResult':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.GetInferenceResult,
-        request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
-        response_serializer=node__service__pb2.InferenceResult.SerializeToString,
-      ),
-    'CollectTopology':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.CollectTopology,
-        request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
-        response_serializer=node__service__pb2.Topology.SerializeToString,
-      ),
-    'SendResult':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.SendResult,
-        request_deserializer=node__service__pb2.SendResultRequest.FromString,
-        response_serializer=node__service__pb2.Empty.SerializeToString,
-      ),
-    'SendOpaqueStatus':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.SendOpaqueStatus,
-        request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
-        response_serializer=node__service__pb2.Empty.SerializeToString,
-      ),
-    'HealthCheck':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.HealthCheck,
-        request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
-        response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
-      ),
-  }
-  generic_handler = grpc.method_handlers_generic_handler('node_service.NodeService', rpc_method_handlers)
-  server.add_generic_rpc_handlers((generic_handler,))
-  server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
+    rpc_method_handlers = {
+            'SendPrompt': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendPrompt,
+                    request_deserializer=node__service__pb2.PromptRequest.FromString,
+                    response_serializer=node__service__pb2.Tensor.SerializeToString,
+            ),
+            'SendTensor': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendTensor,
+                    request_deserializer=node__service__pb2.TensorRequest.FromString,
+                    response_serializer=node__service__pb2.Tensor.SerializeToString,
+            ),
+            'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
+                    servicer.GetInferenceResult,
+                    request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
+                    response_serializer=node__service__pb2.InferenceResult.SerializeToString,
+            ),
+            'CollectTopology': grpc.unary_unary_rpc_method_handler(
+                    servicer.CollectTopology,
+                    request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
+                    response_serializer=node__service__pb2.Topology.SerializeToString,
+            ),
+            'SendResult': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendResult,
+                    request_deserializer=node__service__pb2.SendResultRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
+            ),
+            'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendOpaqueStatus,
+                    request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
+            ),
+            'HealthCheck': grpc.unary_unary_rpc_method_handler(
+                    servicer.HealthCheck,
+                    request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
+                    response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
+            ),
+    }
+    generic_handler = grpc.method_handlers_generic_handler(
+            'node_service.NodeService', rpc_method_handlers)
+    server.add_generic_rpc_handlers((generic_handler,))
+    server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
 
 
-# This class is part of an EXPERIMENTAL API.
+ # This class is part of an EXPERIMENTAL API.
 class NodeService(object):
-  """Missing associated documentation comment in .proto file."""
-  @staticmethod
-  def SendPrompt(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/SendPrompt',
-      node__service__pb2.PromptRequest.SerializeToString,
-      node__service__pb2.Tensor.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    """Missing associated documentation comment in .proto file."""
 
-  @staticmethod
-  def SendTensor(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/SendTensor',
-      node__service__pb2.TensorRequest.SerializeToString,
-      node__service__pb2.Tensor.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def SendPrompt(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendPrompt',
+            node__service__pb2.PromptRequest.SerializeToString,
+            node__service__pb2.Tensor.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
 
-  @staticmethod
-  def GetInferenceResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/GetInferenceResult',
-      node__service__pb2.GetInferenceResultRequest.SerializeToString,
-      node__service__pb2.InferenceResult.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def SendTensor(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendTensor',
+            node__service__pb2.TensorRequest.SerializeToString,
+            node__service__pb2.Tensor.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
 
-  @staticmethod
-  def CollectTopology(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/CollectTopology',
-      node__service__pb2.CollectTopologyRequest.SerializeToString,
-      node__service__pb2.Topology.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def GetInferenceResult(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/GetInferenceResult',
+            node__service__pb2.GetInferenceResultRequest.SerializeToString,
+            node__service__pb2.InferenceResult.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
 
-  @staticmethod
-  def SendResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/SendResult',
-      node__service__pb2.SendResultRequest.SerializeToString,
-      node__service__pb2.Empty.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def CollectTopology(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/CollectTopology',
+            node__service__pb2.CollectTopologyRequest.SerializeToString,
+            node__service__pb2.Topology.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
 
-  @staticmethod
-  def SendOpaqueStatus(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/SendOpaqueStatus',
-      node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-      node__service__pb2.Empty.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def SendResult(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendResult',
+            node__service__pb2.SendResultRequest.SerializeToString,
+            node__service__pb2.Empty.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
 
-  @staticmethod
-  def HealthCheck(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/HealthCheck',
-      node__service__pb2.HealthCheckRequest.SerializeToString,
-      node__service__pb2.HealthCheckResponse.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def SendOpaqueStatus(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendOpaqueStatus',
+            node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+            node__service__pb2.Empty.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
+
+    @staticmethod
+    def HealthCheck(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/HealthCheck',
+            node__service__pb2.HealthCheckRequest.SerializeToString,
+            node__service__pb2.HealthCheckResponse.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)

+ 1 - 1
exo/networking/peer_handle.py

@@ -36,7 +36,7 @@ class PeerHandle(ABC):
     pass
 
   @abstractmethod
-  async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
     pass
 
   @abstractmethod

+ 1 - 1
exo/orchestration/node.py

@@ -16,7 +16,7 @@ class Node(ABC):
     pass
 
   @abstractmethod
-  async def process_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     pass
 
   @abstractmethod

+ 92 - 76
exo/orchestration/standard_node.py

@@ -18,7 +18,6 @@ from exo.download.hf.hf_helpers import RepoProgressEvent
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.download.hf.hf_shard_download import HFShardDownloader
 
-
 class StandardNode(Node):
   def __init__(
     self,
@@ -40,6 +39,7 @@ class StandardNode(Node):
     self.topology: Topology = Topology()
     self.device_capabilities = device_capabilities()
     self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
+    self.buffered_logits: Dict[str, Tuple[List[np.ndarray], bool]] = {}
     self.max_generate_tokens = max_generate_tokens
     self.topology_viz = topology_viz
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
@@ -100,8 +100,56 @@ class StandardNode(Node):
 
   def get_topology_inference_engines(self) -> List[List[str]]:
     return self.topology_inference_engines_pool
+  
+  async def encode_prompt(self, shard: Shard, prompt):
+    toks = await self.inference_engine.encode(shard, prompt)
+    return toks
+  
+  async def process_result(
+    self,
+    shard,
+    result,
+    request_id: Optional[str] = None,
+    inference_state: Optional[str] = None,
+  ):
+    if request_id not in self.buffered_token_output:
+      self.buffered_token_output[request_id] = ([], False)
+    
+    if request_id not in self.buffered_logits:
+      self.buffered_logits[request_id] = ([], False)
+
+    for i in np.reshape(result, (-1, 1, result.shape[-1])):
+      self.buffered_logits[request_id][0].append(i)
+
+    if shard.is_last_layer():
+      result = await self.inference_engine.sample(result)
+    
+    await self.inference_engine.ensure_shard(shard)
+    is_finished = result.size == 1 and result.item() == self.inference_engine.tokenizer.eos_token_id or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+
+    asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))  # TODO: this is n^2 communication complexity
+
+    if result.size == 1:  # we got a new token out
+      self.buffered_token_output[request_id][0].append(result.item())
+      self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
+    
+    if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
 
-  async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+    if is_finished:
+      self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
+      self.buffered_logits[request_id] = (self.buffered_logits[request_id][0], True)
+    else:
+      asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
+
+    return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+
+  async def process_prompt(
+    self,
+    base_shard: Shard,
+    prompt: str,
+    request_id: Optional[str] = None,
+    inference_state: Optional[str] = None
+  ) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(
       self.broadcast_opaque_status(
@@ -113,14 +161,13 @@ class StandardNode(Node):
           "base_shard": base_shard.to_dict(),
           "shard": shard.to_dict(),
           "prompt": prompt,
-          "image_str": image_str,
           "inference_state": inference_state,
           "request_id": request_id,
         }),
       )
     )
     start_time = time.perf_counter_ns()
-    resp = await self._process_prompt(base_shard, prompt, image_str, request_id, inference_state)
+    resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
     asyncio.create_task(
@@ -133,7 +180,6 @@ class StandardNode(Node):
           "base_shard": base_shard.to_dict(),
           "shard": shard.to_dict(),
           "prompt": prompt,
-          "image_str": image_str,
           "inference_state": inference_state,
           "request_id": request_id,
           "elapsed_time_ns": elapsed_time_ns,
@@ -143,35 +189,20 @@ class StandardNode(Node):
     )
     return resp
 
-  async def _process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
-    if request_id not in self.buffered_token_output:
-      self.buffered_token_output[request_id] = ([], False)
     shard = self.get_current_shard(base_shard)
 
-    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {image_str=}")
+    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
     if shard.start_layer != 0:
-      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=} {image_str=}")
-      await self.forward_to_next_shard(shard, prompt, request_id, image_str=image_str, inference_state=inference_state)
-      return
-
-    result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, image_str, inference_state=inference_state)
-    is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-    if is_finished:
-      self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
-    asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))  # TODO: this is n^2 communication complexity
-
-    if result.size == 1:
-      self.buffered_token_output[request_id][0].append(result.item())
-      self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-
-    if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-
-    if not is_finished:
-      asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, image_str=image_str, inference_state=inference_state))
-
-    return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
+      await self.forward_to_next_shard(shard, prompt, request_id, inference_state=inference_state)
+      return None
+    else:
+      result = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state=inference_state)
+      ret = await self.process_result(shard, result, request_id, inference_state=inference_state) 
+      return result
 
   async def process_tensor(
     self,
@@ -227,27 +258,13 @@ class StandardNode(Node):
   ) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
-    if request_id not in self.buffered_token_output:
-      self.buffered_token_output[request_id] = ([], False)
     shard = self.get_current_shard(base_shard)
 
+    if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
     try:
-      if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
-      result, inference_state, is_finished = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state=inference_state)
-      is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-      if is_finished:
-        self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
-      asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))  # TODO: this is n^2 communication complexity
-
-      if result.size == 1:  # we got a new token out
-        self.buffered_token_output[request_id][0].append(result.item())
-        self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-      if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-
-      if not is_finished:
-        asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
-
-      return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+      result = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state=inference_state)
+      ret = await self.process_result(shard, result, request_id, inference_state=inference_state) 
+      return ret
     except Exception as e:
       print(f"Error processing tensor for shard {shard}: {e}")
       traceback.print_exc()
@@ -258,49 +275,48 @@ class StandardNode(Node):
     base_shard: Shard,
     tensor_or_prompt: Union[np.ndarray, str],
     request_id: str,
-    image_str: Optional[str] = None,
     inference_state: Optional[str] = None,
   ) -> None:
     if not self.partitioning_strategy:
       if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
       return
-    shard = self.get_current_shard(base_shard)
 
-    partitions = self.partitioning_strategy.partition(self.topology)
-    shards = map_partitions_to_shards(self.partitioning_strategy.partition(self.topology), base_shard.n_layers, base_shard.model_id)
-    current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
+    next_partition_index = self.get_partition_index(offset = 1)
     if DEBUG >= 1: print(f"Current partition index: {current_partition_index}")
-    if current_partition_index is not None:
-      next_partition_index = (current_partition_index+1) % len(partitions)
-      next_partition: Partition = partitions[next_partition_index]
-      next_shard = shards[next_partition_index]
+    if next_partition_index is not None:
+      target_id = self.partitioning_strategy.partition(self.topology)[next_partition_index].node_id
+      next_shard = self.get_current_shard(base_shard, next_partition_index)
       if DEBUG >= 2: print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}")
-
-      if next_partition.node_id == self.id:
-        if isinstance(tensor_or_prompt, np.ndarray):
-          await self.process_tensor(shard, tensor_or_prompt, request_id, inference_state=inference_state)
+      is_tensor = isinstance(tensor_or_prompt, np.ndarray)
+      if target_id == self.id:
+        if is_tensor:
+          await self.process_tensor(next_shard, tensor_or_prompt, request_id, inference_state=inference_state)
         else:
-          await self.process_prompt(shard, tensor_or_prompt, image_str, request_id, inference_state=inference_state)
-        return
-
-      target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
-      if not target_peer:
-        raise ValueError(f"Peer for {next_partition} not found")
-
-      if DEBUG >= 1: print(f"Sending tensor_or_prompt to {target_peer.id()}: {tensor_or_prompt}")
-
-      if isinstance(tensor_or_prompt, np.ndarray):
-        await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
+          await self.process_prompt(next_shard, tensor_or_prompt, request_id, inference_state=inference_state)
       else:
-        await target_peer.send_prompt(next_shard, tensor_or_prompt, image_str=image_str, request_id=request_id, inference_state=inference_state)
+        target_peer = next((p for p in self.peers if p.id() == target_id), None)
+        if not target_peer:
+          raise ValueError(f"Peer for {next_partition} not found")
+        
+        if is_tensor:
+          if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor_or_prompt}")
+          await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
+        else:
+          await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
 
-  def get_current_shard(self, base_shard: Shard) -> Shard:
+  def get_partition_index(self, offset: int = 0):
     partitions = self.partitioning_strategy.partition(self.topology)
-    shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
     current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
     if current_partition_index is None:
       raise ValueError(f"No current partition found for node: {self.id}")
-    return shards[current_partition_index]
+    return (current_partition_index + offset) % len(partitions)
+
+  def get_current_shard(self, base_shard: Shard, index: Optional[int] = None) -> Shard:
+    if index is None:
+      index = self.get_partition_index()
+    partitions = self.partitioning_strategy.partition(self.topology)
+    shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
+    return shards[index]
 
   async def update_peers(self, wait_for_peers: int = 0) -> bool:
     next_peers = await self.discovery.discover_peers(wait_for_peers)
@@ -428,7 +444,7 @@ class StandardNode(Node):
   def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
     if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
     self.on_token.trigger_all(request_id, tokens, is_finished)
-
+  
   async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
     async def send_result_to_peer(peer):
       try:
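
forward_to_next_shard now derives the next hop from get_partition_index(offset=1), treating the ordered partition list as a ring. A minimal sketch of that arithmetic, with partitions and node_id as stand-in names:

def next_partition_index(partitions, node_id: str, offset: int = 1) -> int:
  # Ring arithmetic mirroring StandardNode.get_partition_index: find this node's slot
  # in the ordered partition list, then step `offset` places, wrapping around.
  current = next((i for i, p in enumerate(partitions) if p.node_id == node_id), None)
  if current is None:
    raise ValueError(f"No current partition found for node: {node_id}")
  return (current + offset) % len(partitions)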

Some files were not shown because too many files changed in this diff