Browse Source

Merge pull request #449 from blindcrone/remove_inference_state

Remove inference state
Alex Cheema 5 months ago
parent
commit
11c5e9e1af

+ 11 - 12
exo/inference/debug_inference_engine.py

@@ -13,32 +13,31 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   _tokenizer = Tokenizer(str(Path(model_id)/"tokenizer.model"))
 
   prompt = "In a single word only, what is the last name of the president of the United States? "
-  resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
-  next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
+  resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
+  token_full = await inference_engine_1.sample(resp_full)
+
+  next_resp_full = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
-    input_data=resp_full,
-    inference_state=inference_state_full,
+    input_data=token_full,
   )
 
-  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
-  resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
+  resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
+  resp2 = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     input_data=resp1,
-    inference_state=inference_state_1,
   )
-  resp3, inference_state_3, _ = await inference_engine_1.infer_tensor(
+  token2 = await inference_engine_2.sample(resp2)
+  resp3 = await inference_engine_1.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32),
-    input_data=resp2,
-    inference_state=inference_state_2,
+    input_data=token2,
   )
-  resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
+  resp4 = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     input_data=resp3,
-    inference_state=inference_state_3,
   )
 
   print(f"{resp2=}")

+ 1 - 1
exo/inference/dummy_inference_engine.py

@@ -28,7 +28,7 @@ class DummyInferenceEngine(InferenceEngine):
   async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
     return ' '.join([random_string(np.random.randint(1, 34)) for token in tokens])
 
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
     sequence_length = input_data.shape[0 if self.shard.is_first_layer() else 1]
     output = np.random.random(size=(1, sequence_length, self.vocab_size if self.shard.is_last_layer() else self.hidden_size))

+ 3 - 3
exo/inference/inference_engine.py

@@ -21,12 +21,12 @@ class InferenceEngine(ABC):
     pass
 
   @abstractmethod
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     pass
   
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> np.ndarray:
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str) -> np.ndarray:
     tokens = await self.encode(shard, prompt)
-    output_data = await self.infer_tensor(request_id, shard, tokens, inference_state)
+    output_data = await self.infer_tensor(request_id, shard, tokens)
     return output_data 
 
 

+ 1 - 1
exo/inference/mlx/sharded_inference_engine.py

@@ -53,7 +53,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
     return tokens
     
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
     output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.model, mx.array(input_data), request_id))
     return output_data

+ 2 - 3
exo/inference/tinygrad/inference.py

@@ -82,10 +82,9 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
     return tokens
 
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
-    start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
-    output_data = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), start_pos, request_id).realize())
+    output_data = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize())
     return output_data.numpy()
 
   async def ensure_shard(self, shard: Shard):

+ 7 - 4
exo/inference/tinygrad/models/llama.py

@@ -196,10 +196,10 @@ class Transformer:
       self.output.weight = self.tok_embeddings.weight
     self.max_context = max_context
     self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta, rope_scaling=rope_scaling).contiguous()
-    self.forward_jit = TinyJit(self.forward) if jit else None
+    self.forward_jit = TinyJit(self.forward_base) if jit else None
     self.shard = shard
 
-  def forward(self, x: Tensor, start_pos: Union[Variable, int], cache: Optional[List[Tensor]] = None):
+  def forward_base(self, x: Tensor, start_pos: Union[Variable, int], cache: Optional[List[Tensor]] = None):
     seqlen = x.shape[1]
     freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
     mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
@@ -225,11 +225,14 @@ class Transformer:
       h = inputs
     return h
 
+  def forward(self, x: Tensor, start_pos: Variable, cache: Optional[List[Tensor]] = None):
+    if x.shape[0:2] == (1, 1) and self.forward_jit is not None:
+      return self.forward_jit(x, Variable("start_pos", 0, self.max_context).bind(start_pos), cache=cache)
+    return self.forward_base(x, start_pos, cache=cache)
+
   def __call__(self, tokens: Tensor, start_pos: Variable, cache: Optional[List[Tensor]] = None):
     # TODO: better way to handle the first call v.s. the rest?
     h = self.embed(x)
-    if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None:
-      return self.forward_jit(h, Variable("start_pos", 0, self.max_context).bind(start_pos), cache=cache)
     return self.forward(h, start_pos, cache=cache)
 
 

+ 20 - 12
exo/inference/tinygrad/stateful_model.py

@@ -1,5 +1,6 @@
 from tinygrad import Tensor, Variable 
 from collections import OrderedDict
+from typing import List
 
 def create_kv_cache(x: Tensor, max_context: int, n_kv_heads: int, head_dim: int):
   cache_kv = Tensor.zeros(2, x.shape[0], max_context, n_kv_heads, head_dim, dtype=x.dtype).contiguous().realize()
@@ -8,27 +9,34 @@ def create_kv_cache(x: Tensor, max_context: int, n_kv_heads: int, head_dim: int)
     cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
   return cache_kv.realize()
 
+class ModelState:
+  cache: List[Tensor]
+  start: int 
+  def __init__(self, cache: List[Tensor], start: int = 0):
+    self.cache = cache
+    self.start = start
+
 class StatefulModel:
-  def __init__(self, model, max_caches: int = 2):
+  def __init__(self, model, max_states: int = 2):
     super().__init__()
     self.model = model
-    self.max_caches = max_caches
-    self.caches = OrderedDict()
+    self.max_states = max_states
+    self.states = OrderedDict()
  
   def init_cache(self, x: Tensor, request_id: str):
     cache = [create_kv_cache(x, self.model.layers[i].attention.max_context, self.model.layers[i].attention.n_kv_heads, self.model.layers[i].attention.head_dim) for i in range(self.model.shard.start_layer, self.model.shard.end_layer + 1)]
-    if len(self.caches) >= self.max_caches:
-      self.caches.popitem(last=False)
+    if len(self.states) >= self.max_states:
+      self.states.popitem(last=False)
 
-    self.caches[request_id] = cache
+    self.states[request_id] = ModelState(cache)
 
-  def __call__(self, x: Tensor, start_pos: Variable, request_id: str): 
+  def __call__(self, x: Tensor, request_id: str): 
     h = self.model.embed(x)
-    if request_id not in self.caches:
+    if request_id not in self.states:
       self.init_cache(h, request_id)
     else:
-      self.caches.move_to_end(request_id)
-    if h.shape[0:2] == (1, 1) and self.model.forward_jit is not None:
-      return self.model.forward_jit(h, Variable("start_pos", 0, self.model.max_context).bind(start_pos), cache=self.caches[request_id])
-    return self.model.forward(h, start_pos, cache=self.caches[request_id])
+      self.states.move_to_end(request_id)
+    out = self.model.forward(h, self.states[request_id].start, cache=self.states[request_id].cache)
+    self.states[request_id].start += h.shape[1]
+    return out
 

+ 2 - 4
exo/networking/grpc/grpc_peer_handle.py

@@ -67,7 +67,7 @@ class GRPCPeerHandle(PeerHandle):
         traceback.print_exc()
       return False
 
-  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.PromptRequest(
       prompt=prompt,
       shard=node_service_pb2.Shard(
@@ -77,7 +77,6 @@ class GRPCPeerHandle(PeerHandle):
         n_layers=shard.n_layers,
       ),
       request_id=request_id,
-      inference_state=inference_state,
     )
     response = await self.stub.SendPrompt(request)
 
@@ -86,7 +85,7 @@ class GRPCPeerHandle(PeerHandle):
 
     return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
 
-  async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.TensorRequest(
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
@@ -96,7 +95,6 @@ class GRPCPeerHandle(PeerHandle):
       ),
       tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
       request_id=request_id,
-      inference_state=inference_state,
     )
     response = await self.stub.SendTensor(request)
 

+ 1 - 2
exo/networking/grpc/grpc_server.py

@@ -64,9 +64,8 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     )
     tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
     request_id = request.request_id
-    inference_state = request.inference_state
 
-    result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
+    result = await self.node.process_tensor(shard, tensor, request_id)
     if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()

+ 0 - 2
exo/networking/grpc/node_service.proto

@@ -23,14 +23,12 @@ message PromptRequest {
   Shard shard = 1;
   string prompt = 2;
   optional string request_id = 3;
-  optional string inference_state = 4;
 }
 
 message TensorRequest {
   Shard shard = 1;
   Tensor tensor = 2;
   optional string request_id = 3;
-  optional string inference_state = 4;
 }
 
 message GetInferenceResultRequest {

File diff suppressed because it is too large
+ 0 - 0
exo/networking/grpc/node_service_pb2.py


+ 2 - 2
exo/networking/peer_handle.py

@@ -36,11 +36,11 @@ class PeerHandle(ABC):
     pass
 
   @abstractmethod
-  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
     pass
 
   @abstractmethod
-  async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None) -> Optional[np.array]:
     pass
 
   @abstractmethod

+ 2 - 2
exo/orchestration/node.py

@@ -16,11 +16,11 @@ class Node(ABC):
     pass
 
   @abstractmethod
-  async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
     pass
 
   @abstractmethod
-  async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.ndarray]:
     pass
 
   @abstractmethod

+ 13 - 22
exo/orchestration/standard_node.py

@@ -111,7 +111,6 @@ class StandardNode(Node):
     shard,
     result: np.ndarray,
     request_id: Optional[str] = None,
-    inference_state: Optional[str] = None,
   ):
     if request_id not in self.buffered_token_output:
       self.buffered_token_output[request_id] = ([], False)
@@ -123,7 +122,6 @@ class StandardNode(Node):
 
     if shard.is_last_layer():
       result = await self.inference_engine.sample(result)
-      inference_state = json.dumps({"start_pos": len(self.buffered_logits[request_id]) + 1})
     
     await self.inference_engine.ensure_shard(shard)
     is_finished = result.size == 1 and result.item() == self.inference_engine.tokenizer.eos_token_id or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
@@ -139,7 +137,7 @@ class StandardNode(Node):
     if is_finished:
       self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
     else:
-      asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
+      asyncio.create_task(self.forward_to_next_shard(shard, result, request_id))
 
     return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
 
@@ -148,7 +146,6 @@ class StandardNode(Node):
     base_shard: Shard,
     prompt: str,
     request_id: Optional[str] = None,
-    inference_state: Optional[str] = None
   ) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(
@@ -161,13 +158,12 @@ class StandardNode(Node):
           "base_shard": base_shard.to_dict(),
           "shard": shard.to_dict(),
           "prompt": prompt,
-          "inference_state": inference_state,
           "request_id": request_id,
         }),
       )
     )
     start_time = time.perf_counter_ns()
-    resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
+    resp = await self._process_prompt(base_shard, prompt, request_id)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
     asyncio.create_task(
@@ -180,7 +176,6 @@ class StandardNode(Node):
           "base_shard": base_shard.to_dict(),
           "shard": shard.to_dict(),
           "prompt": prompt,
-          "inference_state": inference_state,
           "request_id": request_id,
           "elapsed_time_ns": elapsed_time_ns,
           "result_size": resp.size if resp is not None else 0,
@@ -189,7 +184,7 @@ class StandardNode(Node):
     )
     return resp
 
-  async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
     shard = self.get_current_shard(base_shard)
@@ -197,11 +192,11 @@ class StandardNode(Node):
     if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
     if shard.start_layer != 0:
       if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
-      await self.forward_to_next_shard(shard, prompt, request_id, inference_state=inference_state)
+      await self.forward_to_next_shard(shard, prompt, request_id)
       return None
     else:
-      result = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state=inference_state)
-      ret = await self.process_result(shard, result, request_id, inference_state=inference_state) 
+      result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
+      ret = await self.process_result(shard, result, request_id) 
       return result
 
   async def process_tensor(
@@ -209,7 +204,6 @@ class StandardNode(Node):
     base_shard: Shard,
     tensor: np.ndarray,
     request_id: Optional[str] = None,
-    inference_state: Optional[str] = None,
   ) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(
@@ -224,12 +218,11 @@ class StandardNode(Node):
           "tensor_size": tensor.size,
           "tensor_shape": tensor.shape,
           "request_id": request_id,
-          "inference_state": inference_state,
         }),
       )
     )
     start_time = time.perf_counter_ns()
-    resp = await self._process_tensor(shard, tensor, request_id, inference_state)
+    resp = await self._process_tensor(shard, tensor, request_id)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
     asyncio.create_task(
@@ -254,7 +247,6 @@ class StandardNode(Node):
     base_shard: Shard,
     tensor: np.ndarray,
     request_id: Optional[str] = None,
-    inference_state: Optional[str] = None,
   ) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
@@ -262,8 +254,8 @@ class StandardNode(Node):
 
     if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
     try:
-      result = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state=inference_state)
-      ret = await self.process_result(shard, result, request_id, inference_state=inference_state) 
+      result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
+      ret = await self.process_result(shard, result, request_id) 
       return ret
     except Exception as e:
       print(f"Error processing tensor for shard {shard}: {e}")
@@ -275,7 +267,6 @@ class StandardNode(Node):
     base_shard: Shard,
     tensor_or_prompt: Union[np.ndarray, str],
     request_id: str,
-    inference_state: Optional[str] = None,
   ) -> None:
     if not self.partitioning_strategy:
       if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
@@ -290,18 +281,18 @@ class StandardNode(Node):
       is_tensor = isinstance(tensor_or_prompt, np.ndarray)
       if target_id == self.id:
         if is_tensor:
-          await self.process_tensor(next_shard, tensor_or_prompt, request_id, inference_state=inference_state)
+          await self.process_tensor(next_shard, tensor_or_prompt, request_id)
         else:
-          await self.process_prompt(next_shard, tensor_or_prompt, request_id, inference_state=inference_state)
+          await self.process_prompt(next_shard, tensor_or_prompt, request_id)
       else:
         target_peer = next((p for p in self.peers if p.id() == target_id), None)
         if not target_peer:
           raise ValueError(f"Peer for {next_partition_index} not found")
         if is_tensor:
           if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor_or_prompt}")
-          await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
+          await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id)
         else:
-          await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
+          await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id=request_id)
 
   def get_partition_index(self, offset: int = 0):
     partitions = self.partitioning_strategy.partition(self.topology)

Some files were not shown because too many files changed in this diff