
remove redundant sample_logits, put back opaque status for process_prompt so we have a way of preemptively starting downloads

Alex Cheema · 6 months ago
commit 063964aab3

2 changed files, 36 additions and 25 deletions:
  1. exo/inference/mlx/sharded_inference_engine.py (+0, -21)
  2. exo/orchestration/node.py (+36, -4)

exo/inference/mlx/sharded_inference_engine.py (+0, -21)

@@ -13,27 +13,6 @@ import asyncio
 from collections import OrderedDict
 from mlx_lm.models.cache import make_prompt_cache
 
-def sample_logits(
-  logits: mx.array,
-  temp: float = 0.0,
-  top_p: float = 1.0,
-  logit_bias: Optional[Dict[int, float]] = None
-) -> Tuple[mx.array, float]:
-  if logit_bias:
-    indices = mx.array(list(logit_bias.keys()))
-    values = mx.array(list(logit_bias.values()))
-    logits[:, indices] += values
-
-  if temp == 0:
-    token = mx.argmax(logits, axis=-1)
-  else:
-    if top_p > 0 and top_p < 1.0:
-      token = top_p_sampling(logits, top_p, temp)
-    else:
-      token = mx.random.categorical(logits*(1/temp))
-
-  return token
-
 class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None

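Note: the deleted sample_logits helper duplicated sampling logic that ships with mlx_lm itself. Below is a rough sketch of the equivalent temp/top_p/logit_bias handling using mlx_lm's sampling utilities; make_sampler and make_logits_processors are assumed to live in mlx_lm.sample_utils and are not part of this diff, so verify against the installed mlx_lm version.

# Sketch only: the removed helper's behavior expressed with mlx_lm's own
# utilities (assumed API, not taken from this commit).
import mlx.core as mx
from mlx_lm.sample_utils import make_sampler, make_logits_processors

# temp == 0 degenerates to argmax; otherwise categorical / top-p sampling.
sampler = make_sampler(temp=0.7, top_p=0.9)
# logit_bias is applied as a logits processor before sampling.
processors = make_logits_processors(logit_bias={50256: -100.0})

logits = mx.random.normal((1, 32000))   # placeholder vocab-sized logits
tokens = mx.array([[0]])                # generated-so-far tokens (placeholder)
for proc in processors:
    logits = proc(tokens, logits)       # mlx_lm processors take (tokens, logits)
token = sampler(logits)                 # sampled token id(s) as an mx.array
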
exo/orchestration/node.py (+36, -4)

@@ -70,25 +70,28 @@ class Node:
   def on_node_status(self, request_id, opaque_status):
     try:
       status_data = json.loads(opaque_status)
-      if status_data.get("type", "") == "supported_inference_engines":
+      status_type = status_data.get("type", "")
+      if status_type == "supported_inference_engines":
         node_id = status_data.get("node_id")
         engines = status_data.get("engines", [])
         self.topology_inference_engines_pool.append(engines)
-      if status_data.get("type", "") == "node_status":
+      elif status_type == "node_status":
         if status_data.get("status", "").startswith("start_"):
           self.current_topology.active_node_id = status_data.get("node_id")
         elif status_data.get("status", "").startswith("end_"):
           if status_data.get("node_id") == self.current_topology.active_node_id:
             self.current_topology.active_node_id = None
+
       download_progress = None
-      if status_data.get("type", "") == "download_progress":
+      if status_type == "download_progress":
         if DEBUG >= 8: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('progress')}")
         download_progress = RepoProgressEvent.from_dict(status_data.get('progress'))
         self.node_download_progress[status_data.get('node_id')] = download_progress
+
       if self.topology_viz:
         self.topology_viz.update_visualization(self.topology, self.partitioning_strategy.partition(self.topology), self.id, self.node_download_progress)
     except Exception as e:
-      if DEBUG >= 1: print(f"Error updating visualization: {e}")
+      if DEBUG >= 1: print(f"Error on_node_status: {e}")
       if DEBUG >= 1: traceback.print_exc()
 
   def get_supported_inference_engines(self):
@@ -153,10 +156,39 @@ class Node:
     request_id: Optional[str] = None,
   ) -> None:
     shard = self.get_current_shard(base_shard)
+    asyncio.create_task(
+      self.broadcast_opaque_status(
+        request_id,
+        json.dumps({
+          "type": "node_status",
+          "node_id": self.id,
+          "status": "start_process_prompt",
+          "base_shard": base_shard.to_dict(),
+          "shard": shard.to_dict(),
+          "prompt": prompt,
+          "request_id": request_id,
+        }),
+      )
+    )
     start_time = time.perf_counter_ns()
     await self._process_prompt(base_shard, prompt, request_id)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
+    asyncio.create_task(
+      self.broadcast_opaque_status(
+        request_id,
+        json.dumps({
+          "type": "node_status",
+          "node_id": self.id,
+          "status": "end_process_prompt",
+          "base_shard": base_shard.to_dict(),
+          "shard": shard.to_dict(),
+          "prompt": prompt,
+          "request_id": request_id,
+          "elapsed_time_ns": elapsed_time_ns,
+        }),
+      )
+    )
     if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {elapsed_time_ns=}")
 
   async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
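
Per the commit message, re-broadcasting start_process_prompt gives peers a signal to begin fetching their model shards before any work reaches them. A minimal sketch of what such a consumer could look like on the receiving node follows; Shard.from_dict, get_current_shard on the broadcast base_shard, and shard_downloader.ensure_shard are assumptions for illustration and are not shown in this diff.

# Hypothetical receiving-side handler: only the JSON payload shape matches
# what process_prompt broadcasts above; everything else is assumed.
import asyncio
import json

from exo.inference.shard import Shard  # assumed import path for Shard.from_dict

class Node:  # sketch only; the real handler lives in exo/orchestration/node.py
    def on_node_status(self, request_id, opaque_status):
        status_data = json.loads(opaque_status)
        if (status_data.get("type") == "node_status"
                and status_data.get("status") == "start_process_prompt"):
            # A peer has started on this prompt: resolve which layers this node
            # owns and kick off the (assumed idempotent) download right away,
            # instead of waiting for the activation to arrive over the network.
            base_shard = Shard.from_dict(status_data["base_shard"])
            shard = self.get_current_shard(base_shard)
            asyncio.create_task(self.shard_downloader.ensure_shard(shard))

How exo actually consumes this status to start downloads is not shown in this commit; the on_node_status changes above only refactor the type dispatch and visualization update. The sketch just illustrates the preemptive-download hook the commit message refers to.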