Browse Source

Merge pull request #420 from blindcrone/refactor-inference

Some initial inference engine refactors for enabling training
Alex Cheema 8 months ago
parent
commit
b400a442ee

+ 4 - 12
exo/api/chatgpt_api.py

@@ -117,19 +117,11 @@ def remap_messages(messages: List[Message]) -> List[Message]:
 def build_prompt(tokenizer, _messages: List[Message]):
 def build_prompt(tokenizer, _messages: List[Message]):
   messages = remap_messages(_messages)
   messages = remap_messages(_messages)
   prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True)
   prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True)
-  image_str = None
   for message in messages:
   for message in messages:
     if not isinstance(message.content, list):
     if not isinstance(message.content, list):
       continue
       continue
 
 
-    for content in message.content:
-      # note: we only support one image at a time right now. Multiple is possible. See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41
-      # follows the convention in https://platform.openai.com/docs/guides/vision
-      if isinstance(content, dict) and content.get("type", None) == "image":
-        image_str = content.get("image", None)
-        break
-
-  return prompt, image_str
+  return prompt
 
 
 
 
 def parse_message(data: dict):
 def parse_message(data: dict):
@@ -246,7 +238,7 @@ class ChatGPTAPI:
     tokenizer = await resolve_tokenizer(shard.model_id)
     tokenizer = await resolve_tokenizer(shard.model_id)
     if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
     if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
 
 
-    prompt, image_str = build_prompt(tokenizer, chat_request.messages)
+    prompt = build_prompt(tokenizer, chat_request.messages)
     request_id = str(uuid.uuid4())
     request_id = str(uuid.uuid4())
     if self.on_chat_completion_request:
     if self.on_chat_completion_request:
       try:
       try:
@@ -269,10 +261,10 @@ class ChatGPTAPI:
     callback_id = f"chatgpt-api-wait-response-{request_id}"
     callback_id = f"chatgpt-api-wait-response-{request_id}"
     callback = self.node.on_token.register(callback_id)
     callback = self.node.on_token.register(callback_id)
 
 
-    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=} {image_str=}")
+    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
 
 
     try:
     try:
-      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, image_str, request_id=request_id))), timeout=self.response_timeout)
+      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)
 
 
       if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")
       if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")
 
 

+ 17 - 39
exo/inference/dummy_inference_engine.py

@@ -1,60 +1,38 @@
 from typing import Optional, Tuple, TYPE_CHECKING
 from typing import Optional, Tuple, TYPE_CHECKING
 import numpy as np
 import numpy as np
+import random
+import string
 import asyncio
 import asyncio
 import json
 import json
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
-
+def random_string(length: int):
+  return ''.join([random.choice(string.ascii_lowercase) for i in range(length)])
+  
 
 
 class DummyInferenceEngine(InferenceEngine):
 class DummyInferenceEngine(InferenceEngine):
   def __init__(self):
   def __init__(self):
     self.shard = None
     self.shard = None
     self.vocab_size = 1000
     self.vocab_size = 1000
+    self.hidden_size = 256
     self.eos_token_id = 0
     self.eos_token_id = 0
     self.latency_mean = 0.1
     self.latency_mean = 0.1
     self.latency_stddev = 0.02
     self.latency_stddev = 0.02
 
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
-    try:
-      await self.ensure_shard(shard)
-
-      # Generate random tokens
-      output_length = np.random.randint(1, 10)
-      output = np.random.randint(1, self.vocab_size, size=(1, output_length))
-
-      # Simulate latency
-      await asyncio.sleep(max(0, np.random.normal(self.latency_mean, self.latency_stddev)))
-
-      # Randomly decide if finished
-      is_finished = np.random.random() < 0.2
-      if is_finished:
-        output = np.array([[self.eos_token_id]])
-
-      new_state = json.dumps({"dummy_state": "some_value"})
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
+    return np.random.randint(1, self.vocab_size, size=(1, len(prompt.split())))
+  
+  async def sample(self, x: np.ndarray) -> np.ndarray:
+    return np.random.randint(1, self.vocab_size)
 
 
-      return output, new_state, is_finished
-    except Exception as e:
-      print(f"Error in DummyInferenceEngine.infer_prompt: {str(e)}")
-      return np.array([[self.eos_token_id]]), json.dumps({"error": str(e)}), True
+  async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
+    return ' '.join([random_string(np.random.randint(1, 34)) for token in tokens])
 
 
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> np.ndarray:
     await self.ensure_shard(shard)
     await self.ensure_shard(shard)
-    state = json.loads(inference_state or "{}")
-    start_pos = state.get("start_pos", 0)
-
-    output_length = np.random.randint(1, 10)
-    output = np.random.randint(1, self.vocab_size, size=(1, output_length))
-
-    await asyncio.sleep(max(0, np.random.normal(self.latency_mean, self.latency_stddev)))
-
-    is_finished = np.random.random() < 0.2
-    if is_finished:
-      output = np.array([[self.eos_token_id]])
-
-    start_pos += input_data.shape[1] + output_length
-    new_state = json.dumps({"start_pos": start_pos})
-
-    return output, new_state, is_finished
+    sequence_length = input_data.shape[0 if self.shard.is_first_layer() else 1]
+    output = np.random.random(size=(1, sequence_length, self.vocab_size if self.shard.is_last_layer() else self.hidden_size))
+    return output
 
 
   async def ensure_shard(self, shard: Shard):
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
     if self.shard == shard:

+ 16 - 3
exo/inference/inference_engine.py

@@ -9,12 +9,25 @@ from .shard import Shard
 
 
 class InferenceEngine(ABC):
 class InferenceEngine(ABC):
   @abstractmethod
   @abstractmethod
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
+    pass
+  
+  @abstractmethod
+  async def sample(self, x: np.ndarray) -> np.ndarray:
+    pass
+
+  @abstractmethod
+  async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
     pass
     pass
 
 
   @abstractmethod
   @abstractmethod
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> np.ndarray:
     pass
     pass
+  
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> np.ndarray:
+    tokens = await self.encode(shard, prompt)
+    output_data = await self.infer_tensor(request_id, shard, tokens, inference_state)
+    return output_data 
 
 
 
 
 def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
 def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
@@ -33,4 +46,4 @@ def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDow
   elif inference_engine_name == "dummy":
   elif inference_engine_name == "dummy":
     from exo.inference.dummy_inference_engine import DummyInferenceEngine
     from exo.inference.dummy_inference_engine import DummyInferenceEngine
     return DummyInferenceEngine()
     return DummyInferenceEngine()
-  raise ValueError(f"Unsupported inference engine: {inference_engine_name}")
+  raise ValueError(f"Unsupported inference engine: {inference_engine_name}")

+ 40 - 19
exo/inference/mlx/sharded_inference_engine.py

@@ -1,15 +1,35 @@
 import numpy as np
 import numpy as np
 import mlx.core as mx
 import mlx.core as mx
+import mlx.nn as nn
 from ..inference_engine import InferenceEngine
 from ..inference_engine import InferenceEngine
-from .sharded_model import StatefulShardedModel
+from .stateful_model import StatefulModel
 from .sharded_utils import load_shard, get_image_from_str
 from .sharded_utils import load_shard, get_image_from_str
 from ..shard import Shard
 from ..shard import Shard
-from typing import Optional
+from typing import Dict, Optional, Tuple
 from exo.download.shard_download import ShardDownloader
 from exo.download.shard_download import ShardDownloader
 import asyncio
 import asyncio
 from concurrent.futures import ThreadPoolExecutor
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from functools import partial
+def sample_logits(
+  logits: mx.array,
+  temp: float = 0.0,
+  top_p: float = 1.0,
+  logit_bias: Optional[Dict[int, float]] = None
+) -> Tuple[mx.array, float]:
+  if logit_bias:
+    indices = mx.array(list(logit_bias.keys()))
+    values = mx.array(list(logit_bias.values()))
+    logits[:, indices] += values
 
 
+  if temp == 0:
+    token = mx.argmax(logits, axis=-1)
+  else:
+    if top_p > 0 and top_p < 1.0:
+      token = top_p_sampling(logits, top_p, temp)
+    else:
+      token = mx.random.categorical(logits*(1/temp))
+
+  return token
 
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
 class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
   def __init__(self, shard_downloader: ShardDownloader):
@@ -17,25 +37,26 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     self.shard_downloader = shard_downloader
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
     self.executor = ThreadPoolExecutor(max_workers=1)
 
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def sample(self, x, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
+    y = mx.array(x)
+    logits = y[:, -1, :]
+    out = np.array(sample_logits(logits, temp=temp, top_p=top_p))
+    return out
+
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     await self.ensure_shard(shard)
     await self.ensure_shard(shard)
-    loop = asyncio.get_running_loop()
-    if image_str:
-      image = await get_image_from_str(image_str)
-      tokenize = partial(self.tokenizer, prompt, image, return_tensors="np")
-      inputs = await loop.run_in_executor(self.executor, tokenize)
-      pixel_values = mx.array(inputs["pixel_values"])
-      input_ids = mx.array(inputs["input_ids"])
-      output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids, pixel_values))
-    else:
-      input_ids = mx.array(await loop.run_in_executor(self.executor, self.tokenizer.encode, prompt))
-      output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids))
-    return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
+    return np.array(tokens)
 
 
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def decode(self, shard: Shard, tokens) -> str:
+    await self.ensure_shard(shard)
+    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
+    return tokens
+    
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> np.ndarray:
     await self.ensure_shard(shard)
     await self.ensure_shard(shard)
-    output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, mx.array(input_data)))
-    return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+    output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.model, mx.array(input_data), request_id))
+    return output_data
 
 
   async def ensure_shard(self, shard: Shard):
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
     if self.shard == shard:
@@ -50,5 +71,5 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
         return asyncio.run(load_shard(model_path, shard))
         return asyncio.run(load_shard(model_path, shard))
 
 
       model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
       model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
-      self.stateful_sharded_model = await loop.run_in_executor(self.executor, StatefulShardedModel, shard, model_shard)
       self.shard = shard
       self.shard = shard
+      self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard) 

+ 0 - 89
exo/inference/mlx/sharded_model.py

@@ -1,89 +0,0 @@
-from typing import Dict, Generator, Optional, Tuple
-from collections import OrderedDict
-
-import mlx.core as mx
-import mlx.nn as nn
-from mlx_lm.models.cache import make_prompt_cache
-from mlx_lm.sample_utils import top_p_sampling
-
-from ..shard import Shard
-
-
-# TODO: support a speculative model so we can parallelise compute across devices
-class StatefulShardedModel:
-  def __init__(self, shard: Shard, model: nn.Module, max_kv_size: int = 1024, max_caches: int = 2):
-    self.shard = shard
-    self.model = model
-    self.max_kv_size = max_kv_size
-    self.max_caches = max_caches
-    self.caches = OrderedDict()
-
-  def step(
-    self,
-    request_id: str,
-    x,
-    pixel_values=None,
-    temp: float = 0.0,
-    top_p: float = 1.0,
-    logit_bias: Optional[Dict[int, float]] = None,
-  ) -> Generator[Tuple[mx.array, mx.array], None, None]:
-    def sample(logits: mx.array) -> Tuple[mx.array, float]:
-      if logit_bias:
-        indices = mx.array(list(logit_bias.keys()))
-        values = mx.array(list(logit_bias.values()))
-        logits[:, indices] += values
-
-      if temp == 0:
-        token = mx.argmax(logits, axis=-1)
-      else:
-        if top_p > 0 and top_p < 1.0:
-          token = top_p_sampling(logits, top_p, temp)
-        else:
-          token = mx.random.categorical(logits*(1/temp))
-
-      return token
-
-    y = x
-
-    if request_id not in self.caches:
-      self.init_cache(request_id)
-    else:
-      self.caches.move_to_end(request_id)
-
-    cache = self.caches[request_id]
-
-    if pixel_values is None:
-      output = self.model(y[None] if self.shard.is_first_layer() else y, cache=cache)
-    else:
-      output = self.model(y, pixel_values=pixel_values, cache=cache)
-
-    if self.shard.is_last_layer():
-      logits = output[:, -1, :]
-      y = sample(logits)
-      return y
-    else:
-      return output
-
-  def __call__(
-    self,
-    request_id: str,
-    x,
-    temp: float = 0.0,
-    top_p: float = 1.0,
-    logit_bias: Optional[Dict[int, float]] = None,
-  ) -> Generator[Tuple[mx.array, mx.array], None, None]:
-    return self.step(request_id, x, temp=temp, top_p=top_p, logit_bias=logit_bias)
-
-  def init_cache(self, request_id: str):
-    kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
-    # if self.max_kv_size is not None:
-      # cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
-      # cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
-    # else:
-      # cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
-    cache = make_prompt_cache(self.model)
-
-    if len(self.caches) >= self.max_caches:
-      self.caches.popitem(last=False)
-
-    self.caches[request_id] = cache

+ 10 - 3
exo/inference/mlx/sharded_utils.py

@@ -68,7 +68,6 @@ def load_config(model_path: Path) -> dict:
     raise
     raise
   return config
   return config
 
 
-
 def load_model_shard(
 def load_model_shard(
   model_path: Path,
   model_path: Path,
   shard: Shard,
   shard: Shard,
@@ -131,8 +130,17 @@ def load_model_shard(
 
 
   model_class, model_args_class = _get_classes(config=config)
   model_class, model_args_class = _get_classes(config=config)
 
 
+  class ShardedModel(model_class):
+    def __init__(self, args):
+      super().__init__(args)
+      self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
+
+    def __call__(self, x, *args, **kwargs):
+      y = super().__call__(x[None] if self.shard.is_first_layer() else x, *args, **kwargs)
+      return y
+
   model_args = model_args_class.from_dict(config)
   model_args = model_args_class.from_dict(config)
-  model = model_class(model_args)
+  model = ShardedModel(model_args)
 
 
   if hasattr(model, "sanitize"):
   if hasattr(model, "sanitize"):
     weights = model.sanitize(weights)
     weights = model.sanitize(weights)
@@ -158,7 +166,6 @@ def load_model_shard(
   model.eval()
   model.eval()
   return model
   return model
 
 
-
 async def load_shard(
 async def load_shard(
   model_path: str,
   model_path: str,
   shard: Shard,
   shard: Shard,

+ 42 - 0
exo/inference/mlx/stateful_model.py

@@ -0,0 +1,42 @@
+from typing import Dict, Tuple
+from collections import OrderedDict
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm.models.cache import make_prompt_cache
+
+from ..shard import Shard
+
+class StatefulModel(nn.Module):
+  def __init__(self, model, max_kv_size: int = 1024, max_caches: int = 2):
+    super().__init__()
+    self.model = model
+    self.max_kv_size = max_kv_size
+    self.max_caches = max_caches
+    self.caches = OrderedDict()
+  
+  def init_cache(self, request_id: str):
+    kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
+    # if self.max_kv_size is not None:
+      # cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
+      # cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
+    # else:
+      # cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
+    cache = make_prompt_cache(self.model)
+
+    if len(self.caches) >= self.max_caches:
+      self.caches.popitem(last=False)
+
+    self.caches[request_id] = cache
+
+  def __call__(self, x, request_id: str):
+    if request_id not in self.caches:
+      self.init_cache(request_id)
+    else:
+      self.caches.move_to_end(request_id)
+
+    cache = self.caches[request_id]
+
+    y = self.model(x, cache=cache)
+    return y
+    

+ 4 - 4
exo/inference/mlx/test_sharded_llama.py

@@ -1,5 +1,5 @@
 import mlx.core as mx
 import mlx.core as mx
-from exo.inference.mlx.sharded_model import StatefulShardedModel
+from exo.inference.mlx.stateful_model import StatefulModel
 from exo.inference.mlx.sharded_utils import load_shard
 from exo.inference.mlx.sharded_utils import load_shard
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
 
 
@@ -12,9 +12,9 @@ full_model_shard, full_tokenizer = load_shard("mlx-community/Meta-Llama-3-8B-Ins
 model_shard1, tokenizer1 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard1)
 model_shard1, tokenizer1 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard1)
 model_shard2, tokenizer2 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard2)
 model_shard2, tokenizer2 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard2)
 
 
-full = StatefulShardedModel(shard_full, full_model_shard)
-m1 = StatefulShardedModel(shard1, model_shard1)
-m2 = StatefulShardedModel(shard2, model_shard2)
+full = StatefulModel(shard_full, full_model_shard)
+m1 = StatefulModel(shard1, model_shard1)
+m2 = StatefulModel(shard2, model_shard2)
 
 
 prompt = "write a beautiful haiku about a utopia where people own their AI with edge intelligence:"
 prompt = "write a beautiful haiku about a utopia where people own their AI with edge intelligence:"
 prompt_tokens = mx.array(full_tokenizer.encode(prompt))
 prompt_tokens = mx.array(full_tokenizer.encode(prompt))

+ 1 - 1
exo/inference/mlx/test_sharded_llava.py

@@ -7,7 +7,7 @@ from io import BytesIO
 import mlx.core as mx
 import mlx.core as mx
 from mlx_lm.models.cache import KVCache
 from mlx_lm.models.cache import KVCache
 
 
-from exo.inference.mlx.sharded_model import StatefulShardedModel
+from exo.inference.mlx.stateful_model import StatefulModel
 from exo.inference.mlx.sharded_utils import load_shard
 from exo.inference.mlx.sharded_utils import load_shard
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
 
 

+ 23 - 30
exo/inference/tinygrad/inference.py

@@ -1,7 +1,7 @@
 from pathlib import Path
 from pathlib import Path
 import json
 import json
 import os
 import os
-from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
+from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16, sample_logits
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.inference.tokenizers import resolve_tokenizer
 from tinygrad.nn.state import load_state_dict
 from tinygrad.nn.state import load_state_dict
@@ -12,6 +12,7 @@ import numpy as np
 from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
 from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
 from exo.download.shard_download import ShardDownloader
 from exo.download.shard_download import ShardDownloader
 from concurrent.futures import ThreadPoolExecutor
 from concurrent.futures import ThreadPoolExecutor
+from .stateful_model import StatefulModel
 import asyncio
 import asyncio
 
 
 Tensor.no_grad = True
 Tensor.no_grad = True
@@ -58,44 +59,34 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
     load_state_dict(model, weights, strict=False, consume=False)  # consume=True
     load_state_dict(model, weights, strict=False, consume=False)  # consume=True
   return model
   return model
 
 
-
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard = None
     self.shard_downloader = shard_downloader
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
     self.executor = ThreadPoolExecutor(max_workers=1)
 
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> tuple[np.ndarray, str, bool]:
-    await self.ensure_shard(shard)
-    start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
-    n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
-
-    toks = await asyncio.get_event_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
-    h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor([toks]), start_pos, TEMPERATURE).realize())
+  async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
+    logits = x[:, -1, :]
+    def sample_wrapper():
+      return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize()
+    out = await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
+    return out.numpy()
 
 
-    if h.shape == (1,):
-      start_pos += len(toks)
-      start_pos += 1
-      n_captured_toks = 0
-      return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
-    else:
-      n_captured_toks = len(toks)
-      return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
+    await self.ensure_shard(shard)
+    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
+    return np.array(tokens)
+  
+  async def decode(self, shard: Shard, tokens) -> str:
+    await self.ensure_shard(shard)
+    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
+    return tokens
 
 
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> tuple[np.ndarray, str, bool]:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> np.ndarray:
     await self.ensure_shard(shard)
     await self.ensure_shard(shard)
     start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
     start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
-    n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
-
-    h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), start_pos, TEMPERATURE).realize())
-
-    if h.shape == (1,):
-      start_pos += n_captured_toks
-      start_pos += 1
-      n_captured_toks = 0
-      return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
-    else:
-      return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
+    output_data = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), start_pos, request_id).realize())
+    return output_data.numpy()
 
 
   async def ensure_shard(self, shard: Shard):
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
     if self.shard == shard:
@@ -104,9 +95,11 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     model_path = await self.shard_downloader.ensure_shard(shard)
     model_path = await self.shard_downloader.ensure_shard(shard)
 
 
     if self.shard != shard:
     if self.shard != shard:
+      loop = asyncio.get_running_loop()
       parameters = "1B" if "1b" in shard.model_id.lower() else "3B" if "3b" in shard.model_id.lower() else "8B" if "8b" in shard.model_id.lower() else "70B"
       parameters = "1B" if "1b" in shard.model_id.lower() else "3B" if "3b" in shard.model_id.lower() else "8B" if "8b" in shard.model_id.lower() else "70B"
-      self.model = await asyncio.get_event_loop().run_in_executor(self.executor, build_transformer, model_path, shard, parameters)
+      model_shard = await loop.run_in_executor(self.executor, build_transformer, model_path, shard, parameters)
 
 
       tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
       tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
       self.tokenizer = await resolve_tokenizer(tokenizer_path)
       self.tokenizer = await resolve_tokenizer(tokenizer_path)
       self.shard = shard
       self.shard = shard
+      self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard) 

+ 39 - 32
exo/inference/tinygrad/models/llama.py

@@ -1,6 +1,7 @@
-from typing import Tuple, Union, Optional, Dict, Any
+from typing import Tuple, Union, Optional, Dict, Any, List
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad.helpers import getenv
 from tinygrad.helpers import getenv
+from collections import OrderedDict
 
 
 
 
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
@@ -47,7 +48,6 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
   return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads*n_rep, head_dim)
   return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads*n_rep, head_dim)
 
 
-
 class Attention:
 class Attention:
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
     self.n_heads = n_heads
     self.n_heads = n_heads
@@ -61,7 +61,7 @@ class Attention:
     self.wv = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
     self.wv = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
     self.wo = linear(self.n_heads*self.head_dim, dim, bias=False)
     self.wo = linear(self.n_heads*self.head_dim, dim, bias=False)
 
 
-  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor:
+  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor], cache: Optional[Tensor]=None) -> Tensor:
     if getenv("WQKV"):
     if getenv("WQKV"):
       if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
       if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
       xqkv = x @ self.wqkv.T
       xqkv = x @ self.wqkv.T
@@ -76,19 +76,16 @@ class Attention:
     xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
     xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
     bsz, seqlen, _, _ = xq.shape
     bsz, seqlen, _, _ = xq.shape
 
 
-    # create kv cache
-    if not hasattr(self, "cache_kv"):
-      self.cache_kv = Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
-      if isinstance(x.device, tuple):
-        # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
-        self.cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
-
-    # update the cache
-    assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
-    self.cache_kv.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
+    if cache is not None:
+      # update the cache
+      assert xk.dtype == xv.dtype == cache.dtype, f"{xk.dtype=}, {xv.dtype=}, {cache.dtype=}"
+      cache.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
 
 
-    keys = self.cache_kv[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
-    values = self.cache_kv[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
+      keys = cache[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
+      values = cache[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
+    else:
+      keys = xk
+      values = xv
 
 
     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
@@ -114,13 +111,13 @@ class TransformerBlock:
     self.attention_norm = nn.RMSNorm(dim, norm_eps)
     self.attention_norm = nn.RMSNorm(dim, norm_eps)
     self.ffn_norm = nn.RMSNorm(dim, norm_eps)
     self.ffn_norm = nn.RMSNorm(dim, norm_eps)
 
 
-  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
-    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
+  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor], cache: Optional[Tensor]=None):
+    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask, cache=cache)
     return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
     return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
 
 
 
 
 # standard openai sampling
 # standard openai sampling
-def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
+def sample_logits(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   assert logits.ndim == 1, "only works on 1d tensors"
   assert logits.ndim == 1, "only works on 1d tensors"
   assert 0 <= p <= 1, "p must be between 0 and 1"
   assert 0 <= p <= 1, "p must be between 0 and 1"
   assert 0 <= k <= logits.numel(), "k must be between 0 and numel"
   assert 0 <= k <= logits.numel(), "k must be between 0 and numel"
@@ -189,7 +186,7 @@ class Transformer:
     jit=True,
     jit=True,
     feed_forward=FeedForward,
     feed_forward=FeedForward,
     rope_scaling: Optional[Dict[str, float]] = None,
     rope_scaling: Optional[Dict[str, float]] = None,
-    tie_word_embeddings=False
+    tie_word_embeddings=False,
   ):
   ):
     self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
     self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
     self.norm = nn.RMSNorm(dim, norm_eps)
     self.norm = nn.RMSNorm(dim, norm_eps)
@@ -202,31 +199,38 @@ class Transformer:
     self.forward_jit = TinyJit(self.forward) if jit else None
     self.forward_jit = TinyJit(self.forward) if jit else None
     self.shard = shard
     self.shard = shard
 
 
-  def forward(self, x: Tensor, start_pos: Union[Variable, int], temperature: float, top_k: int, top_p: float, alpha_f: float, alpha_p: float):
+  def forward(self, x: Tensor, start_pos: Union[Variable, int], cache: Optional[List[Tensor]] = None):
     seqlen = x.shape[1]
     seqlen = x.shape[1]
     freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
     freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
     mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
     mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
 
 
-    if self.shard.is_first_layer():
-      h = self.tok_embeddings(x)
-    else:
-      h = x
+    h = x
 
 
-    for i in range(self.shard.start_layer, self.shard.end_layer + 1):
+    if cache is None:
+      cache = [None for _ in range(self.shard.start_layer, self.shard.end_layer + 1)]  
+    for i, c in zip(range(self.shard.start_layer, self.shard.end_layer + 1), cache):
       layer = self.layers[i]
       layer = self.layers[i]
-      h = layer(h, start_pos, freqs_cis, mask)
+      h = layer(h, start_pos, freqs_cis, mask, cache=c)
 
 
     if self.shard.is_last_layer():
     if self.shard.is_last_layer():
-      logits = self.output(self.norm(h)).float()[:, -1, :]
-      return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
+      logits = self.output(self.norm(h)).float().realize()
+      return logits
     else:
     else:
       return h
       return h
 
 
-  def __call__(self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0, top_k: int = 0, top_p: float = 0.8, alpha_f: float = 0.0, alpha_p: float = 0.0):
+  def embed(self, inputs: Tensor):
+    if self.shard.is_first_layer():
+      h = self.tok_embeddings(inputs)
+    else:
+      h = inputs
+    return h
+
+  def __call__(self, tokens: Tensor, start_pos: Variable, cache: Optional[List[Tensor]] = None):
     # TODO: better way to handle the first call v.s. the rest?
     # TODO: better way to handle the first call v.s. the rest?
+    h = self.embed(x)
     if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None:
     if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None:
-      return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
-    return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
+      return self.forward_jit(h, Variable("start_pos", 0, self.max_context).bind(start_pos), cache=cache)
+    return self.forward(h, start_pos, cache=cache)
 
 
 
 
 # *** helpers ***
 # *** helpers ***
@@ -260,7 +264,10 @@ def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_h
         v = permute(v, n_heads)
         v = permute(v, n_heads)
       elif "k_proj" in k:
       elif "k_proj" in k:
         v = permute(v, n_kv_heads)
         v = permute(v, n_kv_heads)
-    sd[keymap[k]] = v
+    if k in keymap:
+      sd[keymap[k]] = v
+    else:
+      sd[k] = v
   return sd
   return sd
 
 
 
 

+ 34 - 0
exo/inference/tinygrad/stateful_model.py

@@ -0,0 +1,34 @@
+from tinygrad import Tensor, Variable 
+from collections import OrderedDict
+
+def create_kv_cache(x: Tensor, max_context: int, n_kv_heads: int, head_dim: int):
+  cache_kv = Tensor.zeros(2, x.shape[0], max_context, n_kv_heads, head_dim, dtype=x.dtype).contiguous().realize()
+  if isinstance(x.device, tuple):
+    # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
+    cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
+  return cache_kv.realize()
+
+class StatefulModel:
+  def __init__(self, model, max_caches: int = 2):
+    super().__init__()
+    self.model = model
+    self.max_caches = max_caches
+    self.caches = OrderedDict()
+ 
+  def init_cache(self, x: Tensor, request_id: str):
+    cache = [create_kv_cache(x, self.model.layers[i].attention.max_context, self.model.layers[i].attention.n_kv_heads, self.model.layers[i].attention.head_dim) for i in range(self.model.shard.start_layer, self.model.shard.end_layer + 1)]
+    if len(self.caches) >= self.max_caches:
+      self.caches.popitem(last=False)
+
+    self.caches[request_id] = cache
+
+  def __call__(self, x: Tensor, start_pos: Variable, request_id: str): 
+    h = self.model.embed(x)
+    if request_id not in self.caches:
+      self.init_cache(h, request_id)
+    else:
+      self.caches.move_to_end(request_id)
+    if h.shape[0:2] == (1, 1) and self.model.forward_jit is not None:
+      return self.model.forward_jit(h, Variable("start_pos", 0, self.model.max_context).bind(start_pos), cache=self.caches[request_id])
+    return self.model.forward(h, start_pos, cache=self.caches[request_id])
+

+ 2 - 2
exo/main.py

@@ -189,7 +189,7 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
 
 
   try:
   try:
     print(f"Processing prompt: {prompt}")
     print(f"Processing prompt: {prompt}")
-    await node.process_prompt(shard, prompt, None, request_id=request_id)
+    await node.process_prompt(shard, prompt, request_id=request_id)
 
 
     _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
     _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
 
 
@@ -238,4 +238,4 @@ def run():
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
-  run()
+  run()

+ 1 - 2
exo/networking/grpc/grpc_peer_handle.py

@@ -63,10 +63,9 @@ class GRPCPeerHandle(PeerHandle):
         traceback.print_exc()
         traceback.print_exc()
       return False
       return False
 
 
-  async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.PromptRequest(
     request = node_service_pb2.PromptRequest(
       prompt=prompt,
       prompt=prompt,
-      image_str=image_str,
       shard=node_service_pb2.Shard(
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
         model_id=shard.model_id,
         start_layer=shard.start_layer,
         start_layer=shard.start_layer,

+ 2 - 3
exo/networking/grpc/grpc_server.py

@@ -49,10 +49,9 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
       n_layers=request.shard.n_layers,
       n_layers=request.shard.n_layers,
     )
     )
     prompt = request.prompt
     prompt = request.prompt
-    image_str = request.image_str
     request_id = request.request_id
     request_id = request.request_id
-    result = await self.node.process_prompt(shard, prompt, image_str, request_id)
-    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {image_str=} {request_id=} result: {result}")
+    result = await self.node.process_prompt(shard, prompt, request_id)
+    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     tensor_data = result.tobytes() if result is not None else None
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
 
 

+ 3 - 4
exo/networking/grpc/node_service.proto

@@ -22,9 +22,8 @@ message Shard {
 message PromptRequest {
 message PromptRequest {
   Shard shard = 1;
   Shard shard = 1;
   string prompt = 2;
   string prompt = 2;
-  optional string image_str = 3;
-  optional string request_id = 4;
-  optional string inference_state = 5;
+  optional string request_id = 3;
+  optional string inference_state = 4;
 }
 }
 
 
 message TensorRequest {
 message TensorRequest {
@@ -93,4 +92,4 @@ message HealthCheckResponse {
   bool is_healthy = 1;
   bool is_healthy = 1;
 }
 }
 
 
-message Empty {}
+message Empty {}

File diff suppressed because it is too large
+ 0 - 1
exo/networking/grpc/node_service_pb2.py


+ 314 - 263
exo/networking/grpc/node_service_pb2_grpc.py

@@ -12,298 +12,349 @@ SCHEDULED_RELEASE_DATE = 'June 25, 2024'
 _version_not_supported = False
 _version_not_supported = False
 
 
 try:
 try:
-  from grpc._utilities import first_version_is_lower
-  _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
+    from grpc._utilities import first_version_is_lower
+    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
 except ImportError:
 except ImportError:
-  _version_not_supported = True
+    _version_not_supported = True
 
 
 if _version_not_supported:
 if _version_not_supported:
-  warnings.warn(
-    f'The grpc package installed is at version {GRPC_VERSION},' + f' but the generated code in node_service_pb2_grpc.py depends on' + f' grpcio>={GRPC_GENERATED_VERSION}.' +
-    f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' +
-    f' This warning will become an error in {EXPECTED_ERROR_RELEASE},' + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.', RuntimeWarning
-  )
+    warnings.warn(
+        f'The grpc package installed is at version {GRPC_VERSION},'
+        + f' but the generated code in node_service_pb2_grpc.py depends on'
+        + f' grpcio>={GRPC_GENERATED_VERSION}.'
+        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
+        + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
+        + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
+        RuntimeWarning
+    )
 
 
 
 
 class NodeServiceStub(object):
 class NodeServiceStub(object):
-  """Missing associated documentation comment in .proto file."""
-  def __init__(self, channel):
-    """Constructor.
+    """Missing associated documentation comment in .proto file."""
+
+    def __init__(self, channel):
+        """Constructor.
 
 
         Args:
         Args:
             channel: A grpc.Channel.
             channel: A grpc.Channel.
         """
         """
-    self.SendPrompt = channel.unary_unary(
-      '/node_service.NodeService/SendPrompt',
-      request_serializer=node__service__pb2.PromptRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Tensor.FromString,
-      _registered_method=True
-    )
-    self.SendTensor = channel.unary_unary(
-      '/node_service.NodeService/SendTensor',
-      request_serializer=node__service__pb2.TensorRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Tensor.FromString,
-      _registered_method=True
-    )
-    self.GetInferenceResult = channel.unary_unary(
-      '/node_service.NodeService/GetInferenceResult',
-      request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
-      response_deserializer=node__service__pb2.InferenceResult.FromString,
-      _registered_method=True
-    )
-    self.CollectTopology = channel.unary_unary(
-      '/node_service.NodeService/CollectTopology',
-      request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Topology.FromString,
-      _registered_method=True
-    )
-    self.SendResult = channel.unary_unary(
-      '/node_service.NodeService/SendResult',
-      request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Empty.FromString,
-      _registered_method=True
-    )
-    self.SendOpaqueStatus = channel.unary_unary(
-      '/node_service.NodeService/SendOpaqueStatus',
-      request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-      response_deserializer=node__service__pb2.Empty.FromString,
-      _registered_method=True
-    )
-    self.HealthCheck = channel.unary_unary(
-      '/node_service.NodeService/HealthCheck',
-      request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
-      response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
-      _registered_method=True
-    )
+        self.SendPrompt = channel.unary_unary(
+                '/node_service.NodeService/SendPrompt',
+                request_serializer=node__service__pb2.PromptRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Tensor.FromString,
+                _registered_method=True)
+        self.SendTensor = channel.unary_unary(
+                '/node_service.NodeService/SendTensor',
+                request_serializer=node__service__pb2.TensorRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Tensor.FromString,
+                _registered_method=True)
+        self.GetInferenceResult = channel.unary_unary(
+                '/node_service.NodeService/GetInferenceResult',
+                request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
+                response_deserializer=node__service__pb2.InferenceResult.FromString,
+                _registered_method=True)
+        self.CollectTopology = channel.unary_unary(
+                '/node_service.NodeService/CollectTopology',
+                request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Topology.FromString,
+                _registered_method=True)
+        self.SendResult = channel.unary_unary(
+                '/node_service.NodeService/SendResult',
+                request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
+                _registered_method=True)
+        self.SendOpaqueStatus = channel.unary_unary(
+                '/node_service.NodeService/SendOpaqueStatus',
+                request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
+                _registered_method=True)
+        self.HealthCheck = channel.unary_unary(
+                '/node_service.NodeService/HealthCheck',
+                request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
+                response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
+                _registered_method=True)
 
 
 
 
 class NodeServiceServicer(object):
 class NodeServiceServicer(object):
-  """Missing associated documentation comment in .proto file."""
-  def SendPrompt(self, request, context):
     """Missing associated documentation comment in .proto file."""
     """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
 
 
-  def SendTensor(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def SendPrompt(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
 
-  def GetInferenceResult(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def SendTensor(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
 
-  def CollectTopology(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def GetInferenceResult(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
 
-  def SendResult(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def CollectTopology(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
 
-  def SendOpaqueStatus(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def SendResult(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
 
-  def HealthCheck(self, request, context):
-    """Missing associated documentation comment in .proto file."""
-    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-    context.set_details('Method not implemented!')
-    raise NotImplementedError('Method not implemented!')
+    def SendOpaqueStatus(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def HealthCheck(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
 
 
 
 def add_NodeServiceServicer_to_server(servicer, server):
 def add_NodeServiceServicer_to_server(servicer, server):
-  rpc_method_handlers = {
-    'SendPrompt':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.SendPrompt,
-        request_deserializer=node__service__pb2.PromptRequest.FromString,
-        response_serializer=node__service__pb2.Tensor.SerializeToString,
-      ),
-    'SendTensor':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.SendTensor,
-        request_deserializer=node__service__pb2.TensorRequest.FromString,
-        response_serializer=node__service__pb2.Tensor.SerializeToString,
-      ),
-    'GetInferenceResult':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.GetInferenceResult,
-        request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
-        response_serializer=node__service__pb2.InferenceResult.SerializeToString,
-      ),
-    'CollectTopology':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.CollectTopology,
-        request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
-        response_serializer=node__service__pb2.Topology.SerializeToString,
-      ),
-    'SendResult':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.SendResult,
-        request_deserializer=node__service__pb2.SendResultRequest.FromString,
-        response_serializer=node__service__pb2.Empty.SerializeToString,
-      ),
-    'SendOpaqueStatus':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.SendOpaqueStatus,
-        request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
-        response_serializer=node__service__pb2.Empty.SerializeToString,
-      ),
-    'HealthCheck':
-      grpc.unary_unary_rpc_method_handler(
-        servicer.HealthCheck,
-        request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
-        response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
-      ),
-  }
-  generic_handler = grpc.method_handlers_generic_handler('node_service.NodeService', rpc_method_handlers)
-  server.add_generic_rpc_handlers((generic_handler,))
-  server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
+    rpc_method_handlers = {
+            'SendPrompt': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendPrompt,
+                    request_deserializer=node__service__pb2.PromptRequest.FromString,
+                    response_serializer=node__service__pb2.Tensor.SerializeToString,
+            ),
+            'SendTensor': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendTensor,
+                    request_deserializer=node__service__pb2.TensorRequest.FromString,
+                    response_serializer=node__service__pb2.Tensor.SerializeToString,
+            ),
+            'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
+                    servicer.GetInferenceResult,
+                    request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
+                    response_serializer=node__service__pb2.InferenceResult.SerializeToString,
+            ),
+            'CollectTopology': grpc.unary_unary_rpc_method_handler(
+                    servicer.CollectTopology,
+                    request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
+                    response_serializer=node__service__pb2.Topology.SerializeToString,
+            ),
+            'SendResult': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendResult,
+                    request_deserializer=node__service__pb2.SendResultRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
+            ),
+            'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendOpaqueStatus,
+                    request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
+            ),
+            'HealthCheck': grpc.unary_unary_rpc_method_handler(
+                    servicer.HealthCheck,
+                    request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
+                    response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
+            ),
+    }
+    generic_handler = grpc.method_handlers_generic_handler(
+            'node_service.NodeService', rpc_method_handlers)
+    server.add_generic_rpc_handlers((generic_handler,))
+    server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
 
 
 
 
-# This class is part of an EXPERIMENTAL API.
+ # This class is part of an EXPERIMENTAL API.
 class NodeService(object):
 class NodeService(object):
-  """Missing associated documentation comment in .proto file."""
-  @staticmethod
-  def SendPrompt(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/SendPrompt',
-      node__service__pb2.PromptRequest.SerializeToString,
-      node__service__pb2.Tensor.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    """Missing associated documentation comment in .proto file."""
 
 
-  @staticmethod
-  def SendTensor(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/SendTensor',
-      node__service__pb2.TensorRequest.SerializeToString,
-      node__service__pb2.Tensor.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def SendPrompt(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendPrompt',
+            node__service__pb2.PromptRequest.SerializeToString,
+            node__service__pb2.Tensor.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
 
 
-  @staticmethod
-  def GetInferenceResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/GetInferenceResult',
-      node__service__pb2.GetInferenceResultRequest.SerializeToString,
-      node__service__pb2.InferenceResult.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def SendTensor(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendTensor',
+            node__service__pb2.TensorRequest.SerializeToString,
+            node__service__pb2.Tensor.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
 
 
-  @staticmethod
-  def CollectTopology(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/CollectTopology',
-      node__service__pb2.CollectTopologyRequest.SerializeToString,
-      node__service__pb2.Topology.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def GetInferenceResult(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/GetInferenceResult',
+            node__service__pb2.GetInferenceResultRequest.SerializeToString,
+            node__service__pb2.InferenceResult.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
 
 
-  @staticmethod
-  def SendResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/SendResult',
-      node__service__pb2.SendResultRequest.SerializeToString,
-      node__service__pb2.Empty.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def CollectTopology(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/CollectTopology',
+            node__service__pb2.CollectTopologyRequest.SerializeToString,
+            node__service__pb2.Topology.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
 
 
-  @staticmethod
-  def SendOpaqueStatus(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/SendOpaqueStatus',
-      node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-      node__service__pb2.Empty.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def SendResult(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendResult',
+            node__service__pb2.SendResultRequest.SerializeToString,
+            node__service__pb2.Empty.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
 
 
-  @staticmethod
-  def HealthCheck(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
-    return grpc.experimental.unary_unary(
-      request,
-      target,
-      '/node_service.NodeService/HealthCheck',
-      node__service__pb2.HealthCheckRequest.SerializeToString,
-      node__service__pb2.HealthCheckResponse.FromString,
-      options,
-      channel_credentials,
-      insecure,
-      call_credentials,
-      compression,
-      wait_for_ready,
-      timeout,
-      metadata,
-      _registered_method=True
-    )
+    @staticmethod
+    def SendOpaqueStatus(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendOpaqueStatus',
+            node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+            node__service__pb2.Empty.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
+
+    @staticmethod
+    def HealthCheck(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/HealthCheck',
+            node__service__pb2.HealthCheckRequest.SerializeToString,
+            node__service__pb2.HealthCheckResponse.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)

+ 1 - 1
exo/networking/peer_handle.py

@@ -36,7 +36,7 @@ class PeerHandle(ABC):
     pass
     pass
 
 
   @abstractmethod
   @abstractmethod
-  async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
     pass
     pass
 
 
   @abstractmethod
   @abstractmethod

+ 1 - 1
exo/orchestration/node.py

@@ -16,7 +16,7 @@ class Node(ABC):
     pass
     pass
 
 
   @abstractmethod
   @abstractmethod
-  async def process_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     pass
     pass
 
 
   @abstractmethod
   @abstractmethod

+ 92 - 76
exo/orchestration/standard_node.py

@@ -18,7 +18,6 @@ from exo.download.hf.hf_helpers import RepoProgressEvent
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.download.hf.hf_shard_download import HFShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
 
 
-
 class StandardNode(Node):
 class StandardNode(Node):
   def __init__(
   def __init__(
     self,
     self,
@@ -40,6 +39,8 @@ class StandardNode(Node):
     self.topology: Topology = Topology()
     self.topology: Topology = Topology()
     self.device_capabilities = device_capabilities()
     self.device_capabilities = device_capabilities()
     self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
     self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
+    self.buffered_logits: Dict[str, List[np.ndarray]] = {}
+    self.buffered_inputs: Dict[str, List[np.ndarray]] = {}
     self.max_generate_tokens = max_generate_tokens
     self.max_generate_tokens = max_generate_tokens
     self.topology_viz = topology_viz
     self.topology_viz = topology_viz
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
@@ -100,8 +101,55 @@ class StandardNode(Node):
 
 
   def get_topology_inference_engines(self) -> List[List[str]]:
   def get_topology_inference_engines(self) -> List[List[str]]:
     return self.topology_inference_engines_pool
     return self.topology_inference_engines_pool
+  
+  async def encode_prompt(self, shard: Shard, prompt):
+    toks = await self.inference_engine.encode(shard, prompt)
+    return toks
+  
+  async def process_result(
+    self,
+    shard,
+    result: np.ndarray,
+    request_id: Optional[str] = None,
+    inference_state: Optional[str] = None,
+  ):
+    if request_id not in self.buffered_token_output:
+      self.buffered_token_output[request_id] = ([], False)
+    
+    if request_id not in self.buffered_logits:
+      self.buffered_logits[request_id] = []
+
+    self.buffered_logits[request_id] += [i for i in np.reshape(result, (-1, 1, result.shape[-1]))]
+
+    if shard.is_last_layer():
+      result = await self.inference_engine.sample(result)
+      inference_state = json.dumps({"start_pos": len(self.buffered_logits[request_id]) + 1})
+    
+    await self.inference_engine.ensure_shard(shard)
+    is_finished = result.size == 1 and result.item() == self.inference_engine.tokenizer.eos_token_id or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+
+    asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))  # TODO: this is n^2 communication complexity
+
+    if result.size == 1:  # we got a new token out
+      self.buffered_token_output[request_id][0].append(result.item())
+      self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
+    
+    if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
 
 
-  async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+    if is_finished:
+      self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
+    else:
+      asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
+
+    return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+
+  async def process_prompt(
+    self,
+    base_shard: Shard,
+    prompt: str,
+    request_id: Optional[str] = None,
+    inference_state: Optional[str] = None
+  ) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(
     asyncio.create_task(
       self.broadcast_opaque_status(
       self.broadcast_opaque_status(
@@ -113,14 +161,13 @@ class StandardNode(Node):
           "base_shard": base_shard.to_dict(),
           "base_shard": base_shard.to_dict(),
           "shard": shard.to_dict(),
           "shard": shard.to_dict(),
           "prompt": prompt,
           "prompt": prompt,
-          "image_str": image_str,
           "inference_state": inference_state,
           "inference_state": inference_state,
           "request_id": request_id,
           "request_id": request_id,
         }),
         }),
       )
       )
     )
     )
     start_time = time.perf_counter_ns()
     start_time = time.perf_counter_ns()
-    resp = await self._process_prompt(base_shard, prompt, image_str, request_id, inference_state)
+    resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
     end_time = time.perf_counter_ns()
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
     elapsed_time_ns = end_time - start_time
     asyncio.create_task(
     asyncio.create_task(
@@ -133,7 +180,6 @@ class StandardNode(Node):
           "base_shard": base_shard.to_dict(),
           "base_shard": base_shard.to_dict(),
           "shard": shard.to_dict(),
           "shard": shard.to_dict(),
           "prompt": prompt,
           "prompt": prompt,
-          "image_str": image_str,
           "inference_state": inference_state,
           "inference_state": inference_state,
           "request_id": request_id,
           "request_id": request_id,
           "elapsed_time_ns": elapsed_time_ns,
           "elapsed_time_ns": elapsed_time_ns,
@@ -143,35 +189,20 @@ class StandardNode(Node):
     )
     )
     return resp
     return resp
 
 
-  async def _process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     if request_id is None:
     if request_id is None:
       request_id = str(uuid.uuid4())
       request_id = str(uuid.uuid4())
-    if request_id not in self.buffered_token_output:
-      self.buffered_token_output[request_id] = ([], False)
     shard = self.get_current_shard(base_shard)
     shard = self.get_current_shard(base_shard)
 
 
-    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {image_str=}")
+    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
     if shard.start_layer != 0:
     if shard.start_layer != 0:
-      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=} {image_str=}")
-      await self.forward_to_next_shard(shard, prompt, request_id, image_str=image_str, inference_state=inference_state)
-      return
-
-    result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, image_str, inference_state=inference_state)
-    is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-    if is_finished:
-      self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
-    asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))  # TODO: this is n^2 communication complexity
-
-    if result.size == 1:
-      self.buffered_token_output[request_id][0].append(result.item())
-      self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-
-    if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-
-    if not is_finished:
-      asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, image_str=image_str, inference_state=inference_state))
-
-    return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
+      await self.forward_to_next_shard(shard, prompt, request_id, inference_state=inference_state)
+      return None
+    else:
+      result = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state=inference_state)
+      ret = await self.process_result(shard, result, request_id, inference_state=inference_state) 
+      return result
 
 
   async def process_tensor(
   async def process_tensor(
     self,
     self,
@@ -227,27 +258,13 @@ class StandardNode(Node):
   ) -> Optional[np.ndarray]:
   ) -> Optional[np.ndarray]:
     if request_id is None:
     if request_id is None:
       request_id = str(uuid.uuid4())
       request_id = str(uuid.uuid4())
-    if request_id not in self.buffered_token_output:
-      self.buffered_token_output[request_id] = ([], False)
     shard = self.get_current_shard(base_shard)
     shard = self.get_current_shard(base_shard)
 
 
+    if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
     try:
     try:
-      if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
-      result, inference_state, is_finished = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state=inference_state)
-      is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-      if is_finished:
-        self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
-      asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))  # TODO: this is n^2 communication complexity
-
-      if result.size == 1:  # we got a new token out
-        self.buffered_token_output[request_id][0].append(result.item())
-        self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-      if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-
-      if not is_finished:
-        asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
-
-      return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+      result = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state=inference_state)
+      ret = await self.process_result(shard, result, request_id, inference_state=inference_state) 
+      return ret
     except Exception as e:
     except Exception as e:
       print(f"Error processing tensor for shard {shard}: {e}")
       print(f"Error processing tensor for shard {shard}: {e}")
       traceback.print_exc()
       traceback.print_exc()
@@ -258,49 +275,48 @@ class StandardNode(Node):
     base_shard: Shard,
     base_shard: Shard,
     tensor_or_prompt: Union[np.ndarray, str],
     tensor_or_prompt: Union[np.ndarray, str],
     request_id: str,
     request_id: str,
-    image_str: Optional[str] = None,
     inference_state: Optional[str] = None,
     inference_state: Optional[str] = None,
   ) -> None:
   ) -> None:
     if not self.partitioning_strategy:
     if not self.partitioning_strategy:
       if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
       if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
       return
       return
-    shard = self.get_current_shard(base_shard)
 
 
-    partitions = self.partitioning_strategy.partition(self.topology)
-    shards = map_partitions_to_shards(self.partitioning_strategy.partition(self.topology), base_shard.n_layers, base_shard.model_id)
-    current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
+    next_partition_index = self.get_partition_index(offset = 1)
     if DEBUG >= 1: print(f"Current partition index: {current_partition_index}")
     if DEBUG >= 1: print(f"Current partition index: {current_partition_index}")
-    if current_partition_index is not None:
-      next_partition_index = (current_partition_index+1) % len(partitions)
-      next_partition: Partition = partitions[next_partition_index]
-      next_shard = shards[next_partition_index]
+    if next_partition_index is not None:
+      target_id = self.partitioning_strategy.partition(self.topology)[next_partition_index].node_id
+      next_shard = self.get_current_shard(base_shard, next_partition_index)
       if DEBUG >= 2: print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}")
       if DEBUG >= 2: print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}")
-
-      if next_partition.node_id == self.id:
-        if isinstance(tensor_or_prompt, np.ndarray):
-          await self.process_tensor(shard, tensor_or_prompt, request_id, inference_state=inference_state)
+      is_tensor = isinstance(tensor_or_prompt, np.ndarray)
+      if target_id == self.id:
+        if is_tensor:
+          await self.process_tensor(next_shard, tensor_or_prompt, request_id, inference_state=inference_state)
         else:
         else:
-          await self.process_prompt(shard, tensor_or_prompt, image_str, request_id, inference_state=inference_state)
-        return
-
-      target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
-      if not target_peer:
-        raise ValueError(f"Peer for {next_partition} not found")
-
-      if DEBUG >= 1: print(f"Sending tensor_or_prompt to {target_peer.id()}: {tensor_or_prompt}")
-
-      if isinstance(tensor_or_prompt, np.ndarray):
-        await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
+          await self.process_prompt(next_shard, tensor_or_prompt, request_id, inference_state=inference_state)
       else:
       else:
-        await target_peer.send_prompt(next_shard, tensor_or_prompt, image_str=image_str, request_id=request_id, inference_state=inference_state)
+        target_peer = next((p for p in self.peers if p.id() == target_id), None)
+        if not target_peer:
+          raise ValueError(f"Peer for {next_partition} not found")
+        
+        if is_tensor:
+          if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor_or_prompt}")
+          await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
+        else:
+          await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
 
 
-  def get_current_shard(self, base_shard: Shard) -> Shard:
+  def get_partition_index(self, offset: int = 0):
     partitions = self.partitioning_strategy.partition(self.topology)
     partitions = self.partitioning_strategy.partition(self.topology)
-    shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
     current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
     current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
     if current_partition_index is None:
     if current_partition_index is None:
       raise ValueError(f"No current partition found for node: {self.id}")
       raise ValueError(f"No current partition found for node: {self.id}")
-    return shards[current_partition_index]
+    return (current_partition_index + offset) % len(partitions)
+
+  def get_current_shard(self, base_shard: Shard, index: Optional[int] = None) -> Shard:
+    if index is None:
+      index = self.get_partition_index()
+    partitions = self.partitioning_strategy.partition(self.topology)
+    shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
+    return shards[index]
 
 
   async def update_peers(self, wait_for_peers: int = 0) -> bool:
   async def update_peers(self, wait_for_peers: int = 0) -> bool:
     next_peers = await self.discovery.discover_peers(wait_for_peers)
     next_peers = await self.discovery.discover_peers(wait_for_peers)
@@ -428,7 +444,7 @@ class StandardNode(Node):
   def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
   def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
     if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
     if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
     self.on_token.trigger_all(request_id, tokens, is_finished)
     self.on_token.trigger_all(request_id, tokens, is_finished)
-
+  
   async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
   async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
     async def send_result_to_peer(peer):
     async def send_result_to_peer(peer):
       try:
       try:

Some files were not shown because too many files changed in this diff