
Refactors to simplify messaging and properly batch inputs

Nel Nibcord · 7 months ago · commit 8f78c7819e

exo/inference/inference_engine.py (+2 -1)

@@ -26,7 +26,8 @@ class InferenceEngine(ABC):
   
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str) -> np.ndarray:
     tokens = await self.encode(shard, prompt)
-    output_data = await self.infer_tensor(request_id, shard, tokens)
+    x = tokens.reshape(1, -1)
+    output_data = await self.infer_tensor(request_id, shard, x)
     return output_data 
 
 inference_engine_classes = {
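The first change gives the encoded prompt an explicit batch dimension before it reaches `infer_tensor`, so every engine receives input shaped (batch, seq_len). A minimal sketch of the shape change, assuming `encode` returns a flat 1-D array of token ids:

    import numpy as np

    # Hypothetical token ids as returned by `encode` for one prompt.
    tokens = np.array([101, 2023, 2003, 102])
    print(tokens.shape)        # (4,)

    # The refactor reshapes to (batch, seq_len) before inference.
    x = tokens.reshape(1, -1)
    print(x.shape)             # (1, 4)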

exo/inference/mlx/sharded_utils.py (+1 -1)

@@ -137,7 +137,7 @@ def load_model_shard(
       self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
 
     def __call__(self, x, *args, **kwargs):
-      y = super().__call__(x[None] if self.shard.is_first_layer() else x, *args, **kwargs)
+      y = super().__call__(x, *args, **kwargs)
       return y
 
   model_args = model_args_class.from_dict(config)
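Because batching now happens once at the engine entry point, the MLX shard wrapper no longer adds a leading axis on the first layer; keeping the old `x[None]` special case would double-batch the input. A rough illustration with numpy shapes (MLX arrays index the same way here):

    import numpy as np

    tokens = np.array([101, 2023, 2003, 102])
    x = tokens.reshape(1, -1)   # batched once upstream: shape (1, 4)
    y = x[None]                 # the removed special case would batch again: shape (1, 1, 4)
    print(x.shape, y.shape)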

exo/orchestration/standard_node.py (+55 -56)

@@ -102,11 +102,7 @@ class StandardNode(Node):
   def get_topology_inference_engines(self) -> List[List[str]]:
     return self.topology_inference_engines_pool
   
-  async def encode_prompt(self, shard: Shard, prompt):
-    toks = await self.inference_engine.encode(shard, prompt)
-    return toks
-  
-  async def process_result(
+  async def process_inference_result(
     self,
     shard,
     result: np.ndarray,
@@ -114,32 +110,24 @@ class StandardNode(Node):
   ):
     if request_id not in self.buffered_token_output:
       self.buffered_token_output[request_id] = ([], False)
-    
-    if request_id not in self.buffered_logits:
-      self.buffered_logits[request_id] = []
-
-    self.buffered_logits[request_id] += [i for i in np.reshape(result, (-1, 1, result.shape[-1]))]
-
-    if shard.is_last_layer():
-      result = await self.inference_engine.sample(result)
-    
-    await self.inference_engine.ensure_shard(shard)
-    is_finished = result.size == 1 and result.item() == self.inference_engine.tokenizer.eos_token_id or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-
-    asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))  # TODO: this is n^2 communication complexity
-
-    if result.size == 1:  # we got a new token out
-      self.buffered_token_output[request_id][0].append(result.item())
+    is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+    if shard.is_last_layer() and not is_finished:
+      token = await self.inference_engine.sample(result)
+      await self.inference_engine.ensure_shard(shard)
+      self.buffered_token_output[request_id][0].append(token.item())
       self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-    
-    if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
+      if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
+      is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id
+      forward = token.reshape(1, -1)
+    else:
+      forward = result
 
     if is_finished:
       self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
     else:
-      asyncio.create_task(self.forward_to_next_shard(shard, result, request_id))
+      asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))
 
-    return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+    return np.array(self.buffered_token_output[request_id][0])
 
   async def process_prompt(
     self,
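Rewritten, the result handler makes one decision: the last shard samples a fresh token and forwards it as the next step's (1, 1) input, while every other shard forwards its raw activations unchanged. A simplified standalone sketch of that control flow (the engine, shard, and buffer objects are assumed; names mirror the diff):

    import numpy as np

    async def handle_result(engine, shard, result, buffered, max_tokens):
        # Stop once the request has produced its token budget.
        is_finished = len(buffered) >= max_tokens
        if shard.is_last_layer() and not is_finished:
            token = await engine.sample(result)   # logits -> one token id
            buffered.append(token.item())
            is_finished = token.item() == engine.tokenizer.eos_token_id
            forward = token.reshape(1, -1)        # shape (1, 1) for the next pass
        else:
            forward = result                      # intermediate shard: pass activations on
        return forward, is_finished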
@@ -190,13 +178,13 @@ class StandardNode(Node):
     shard = self.get_current_shard(base_shard)
 
     if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
-    if shard.start_layer != 0:
+    if not shard.is_first_layer():
       if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
-      await self.forward_to_next_shard(shard, prompt, request_id)
+      resp = await self.forward_prompt(shard, prompt, request_id, 0)
       return None
     else:
       result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
-      ret = await self.process_result(shard, result, request_id) 
+      ret = await self.process_inference_result(shard, result, request_id) 
       return result
 
   async def process_tensor(
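The prompt path above now keys off `shard.is_first_layer()`: only the node holding layer 0 tokenizes and runs the prompt, and any other node routes the raw prompt to partition index 0. Condensed to its routing decision (a sketch, not the full method):

    async def run_prompt(node, shard, prompt, request_id):
        if not shard.is_first_layer():
            # Not our layers: hand the raw prompt to the partition
            # owning layer 0 (target_index=0 in forward_prompt).
            await node.forward_prompt(shard, prompt, request_id, 0)
            return None
        result = await node.inference_engine.infer_prompt(request_id, shard, prompt)
        return await node.process_inference_result(shard, result, request_id)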
@@ -255,46 +243,57 @@ class StandardNode(Node):
     if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
     try:
       result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
-      ret = await self.process_result(shard, result, request_id) 
+      ret = await self.process_inference_result(shard, result, request_id) 
       return ret
     except Exception as e:
       print(f"Error processing tensor for shard {shard}: {e}")
       traceback.print_exc()
       return None
 
-  async def forward_to_next_shard(
+  async def forward_prompt(
     self,
     base_shard: Shard,
-    tensor_or_prompt: Union[np.ndarray, str],
+    prompt: str,
     request_id: str,
+    target_index: int,
   ) -> None:
-    if not self.partitioning_strategy:
-      if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
-      return
-
-    next_partition_index = self.get_partition_index(offset = 1)
-    if DEBUG >= 1: print(f"Next partition index: {next_partition_index}")
-    if next_partition_index is not None:
-      target_id = self.partitioning_strategy.partition(self.topology)[next_partition_index].node_id
-      next_shard = self.get_current_shard(base_shard, next_partition_index)
-      if DEBUG >= 2: print(f"Computed next from: {base_shard} {next_partition_index}, {self.topology}. Next shard: {next_shard}")
-      is_tensor = isinstance(tensor_or_prompt, np.ndarray)
-      if target_id == self.id:
-        if is_tensor:
-          await self.process_tensor(next_shard, tensor_or_prompt, request_id)
-        else:
-          await self.process_prompt(next_shard, tensor_or_prompt, request_id)
-      else:
-        target_peer = next((p for p in self.peers if p.id() == target_id), None)
-        if not target_peer:
-          raise ValueError(f"Peer for {next_partition_index} not found")
-        if is_tensor:
-          if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor_or_prompt}")
-          await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id)
-        else:
-          await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id=request_id)
+    if DEBUG >= 1: print(f"target partition index: {target_index}")
+    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+    next_shard = self.get_current_shard(base_shard, target_index)
+    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {target_shard}")
+    if target_id == self.id:
+      await self.process_prompt(next_shard, prompt, request_id)
+    else:
+      target_peer = next((p for p in self.peers if p.id() == target_id), None)
+      if not target_peer:
+        raise ValueError(f"Peer for {target_index} not found")
+      if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
+      await target_peer.send_prompt(next_shard, prompt, request_id=request_id)
+  
+  async def forward_tensor(
+    self,
+    base_shard: Shard,
+    tensor: np.ndarray,
+    request_id: str,
+    target_index: int,
+  ) -> None:
+    if DEBUG >= 1: print(f"target partition index: {target_index}")
+    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+    next_shard = self.get_current_shard(base_shard, target_index)
+    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {target_shard}")
+    if target_id == self.id:
+      await self.process_tensor(next_shard, tensor, request_id)
+    else:
+      target_peer = next((p for p in self.peers if p.id() == target_id), None)
+      if not target_peer:
+        raise ValueError(f"Peer for {target_index} not found")
+      if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
+      await target_peer.send_tensor(next_shard, tensor, request_id=request_id)
 
   def get_partition_index(self, offset: int = 0):
+    if not self.partitioning_strategy:
+      if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
+      return None
     partitions = self.partitioning_strategy.partition(self.topology)
     current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
     if current_partition_index is None:
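forward_to_next_shard previously mixed prompt and tensor handling behind isinstance checks; the split into forward_prompt and forward_tensor leaves both with the same routing skeleton, while get_partition_index takes over the no-strategy guard and returns None. A condensed sketch of the shared lookup (partition objects exposing node_id, as in the diff):

    def resolve_target(node, base_shard, target_index):
        # Map a partition index to the peer that owns it and the
        # shard that peer should run; used by both forward_* methods.
        partitions = node.partitioning_strategy.partition(node.topology)
        target_id = partitions[target_index].node_id
        next_shard = node.get_current_shard(base_shard, target_index)
        return target_id, next_shard

Callers then either process locally when target_id matches the node's own id, or look up the peer and send the prompt or tensor over the wire.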