Browse Source

new approach to mlx async operations and make tokenizer operations async too

Alex Cheema 6 months ago
parent
commit
b02c0a5be0

+ 28 - 15
exo/inference/mlx/sharded_inference_engine.py

@@ -12,6 +12,7 @@ from exo.download.shard_download import ShardDownloader
 import asyncio
 from collections import OrderedDict
 from mlx_lm.models.cache import make_prompt_cache
+from concurrent.futures import ThreadPoolExecutor
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
@@ -20,6 +21,12 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     self.caches = OrderedDict()
     self.sampler_params: tuple[float, float] = (0.0, 0.0, 0.0, 1)
     self.sampler = make_sampler(*self.sampler_params)
+    self._mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")
+    self._tokenizer_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="tokenizer")
+
+  async def _eval_mlx(self, *args):
+    loop = asyncio.get_running_loop()
+    await loop.run_in_executor(self._mlx_thread, mx.eval, *args)
 
   async def poll_state(self, request_id: str, max_caches=2):
     if request_id in self.caches:
@@ -38,16 +45,19 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     logits = mx.array(x)
     logits = logits[:, -1, :]
     logprobs = logits - mx.logsumexp(logits, keepdims=True)
-    return np.asarray(self.sampler(logprobs), dtype=int)
+    result = self.sampler(logprobs)
+    await self._eval_mlx(result)
+    return np.asarray(result, dtype=int)
 
   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     await self.ensure_shard(shard)
-    tokens = self.tokenizer.encode(prompt)
-    return np.asarray(tokens)
+    loop = asyncio.get_running_loop()
+    return np.asarray(await loop.run_in_executor(self._tokenizer_thread, self.tokenizer.encode, prompt))
 
   async def decode(self, shard: Shard, tokens) -> str:
     await self.ensure_shard(shard)
-    return self.tokenizer.decode(tokens)
+    loop = asyncio.get_running_loop()
+    return await loop.run_in_executor(self._tokenizer_thread, self.tokenizer.decode, tokens)
 
   async def save_checkpoint(self, shard: Shard, path: str):
     await self.ensure_shard(shard)
@@ -61,8 +71,9 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     await self.ensure_shard(shard)
     state = await self.poll_state(request_id)
     x = mx.array(input_data)
-    output_data = np.array(self.model(x, **state), copy=False)
-    return output_data
+    output = self.model(x, **state)
+    await self._eval_mlx(output)
+    return np.array(output, copy=False)
 
   async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
     await self.ensure_shard(shard)
@@ -87,26 +98,25 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     return True
 
   async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.SGD, lr=1e-5):
-    loop = asyncio.get_running_loop()
-    nothin = await self.ensure_train(shard, loss, opt, lr)
+    await self.ensure_train(shard, loss, opt, lr)
+
     def train_step(inp, tar, lng):
       lval, grad = self.session['LVaG'](self.model, inp, tar, lng)
       gradlayers = grad['model']['layers']
       self.session['opt'].update(self.model, grad)
-      mx.eval(self.model.parameters(), self.session['opt'].state, lval)
-      return lval, gradlayers
+      return lval, gradlayers, (self.model.parameters(), self.session['opt'].state, lval)
 
     x = mx.array(inputs)
     y = mx.array(targets)
     l = mx.array(lengths)
 
-    score, gradients = await loop.run_in_executor(self.executor, train_step, x, y, l)
-    #print(f"{score=}")
+    score, gradients, eval_args = train_step(x, y, l)
+    await self._eval_mlx(*eval_args)
 
     layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
-    #print(layers[0])
-
-    return score, np.array(layers[0]['input_layernorm'], copy=False)
+    first_layer = np.array(layers[0]['input_layernorm'], copy=False)
+    await self._eval_mlx(first_layer)
+    return score, first_layer
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
@@ -121,3 +131,6 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
       self.caches = OrderedDict()
       self.session = {}
 
+  async def cleanup(self):
+    self._mlx_thread.shutdown(wait=True)
+

+ 2 - 2
exo/inference/mlx/sharded_utils.py

@@ -164,8 +164,8 @@ def load_model_shard(
 
   model.load_weights(list(weights.items()), strict=True)
 
-  if not lazy:
-    mx.eval(model.parameters())
+  # if not lazy:
+  #   mx.eval(model.parameters())
 
   model.eval()
   return model

+ 81 - 0
exo/inference/mlx/test_non_blocking.py

@@ -0,0 +1,81 @@
+import asyncio
+import time
+import numpy as np
+from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+from exo.download.hf.hf_shard_download import HFShardDownloader
+from exo.inference.shard import Shard
+from exo.models import build_base_shard
+from collections import deque
+from statistics import mean, median
+
+async def test_non_blocking():
+    # Setup
+    shard_downloader = HFShardDownloader()
+    engine = MLXDynamicShardInferenceEngine(shard_downloader)
+    _shard = build_base_shard("llama-3.1-8b", "MLXDynamicShardInferenceEngine")
+    shard = Shard(_shard.model_id, _shard.start_layer, _shard.n_layers - 1, _shard.n_layers)
+    await engine.ensure_shard(shard)
+    
+    queue = asyncio.Queue()
+    measurements = deque(maxlen=1000000)
+    running = True
+
+    async def mlx_worker():
+        try:
+            start_time = time.time()
+            count = 0
+            while running and (time.time() - start_time) < 5:  # Hard time limit
+                start = time.perf_counter_ns()
+                await engine.infer_prompt("req1", shard, "test prompt")
+                duration = (time.perf_counter_ns() - start) / 1_000_000  # Convert to ms
+                count += 1
+                print(f"MLX operation {count} took: {duration:.3f}ms")
+        except asyncio.CancelledError:
+            pass
+        finally:
+            print(f"\nTotal MLX operations completed: {count}")
+            print(f"Average rate: {count/5:.1f} ops/second")
+
+    async def latency_producer():
+        try:
+            start_time = time.perf_counter_ns()
+            count = 0
+            while running:
+                await queue.put(time.perf_counter_ns())
+                count += 1
+                await asyncio.sleep(0)  # Yield to event loop without delay
+            duration = (time.perf_counter_ns() - start_time) / 1e9  # Convert to seconds
+            print(f"\nProducer iterations: {count}")
+            print(f"Producer rate: {count/duration:.1f} iterations/second")
+        except asyncio.CancelledError:
+            pass
+
+    async def latency_consumer():
+        try:
+            while running:
+                timestamp = await queue.get()
+                latency = (time.perf_counter_ns() - timestamp) / 1_000_000  # Convert to ms
+                measurements.append(latency)
+                queue.task_done()
+        except asyncio.CancelledError:
+            pass
+
+    tasks = [
+        asyncio.create_task(mlx_worker()),
+        asyncio.create_task(latency_producer()),
+        asyncio.create_task(latency_consumer())
+    ]
+    
+    try:
+        await asyncio.wait_for(asyncio.gather(*tasks), timeout=6)
+    except asyncio.TimeoutError:
+        print("\nTest timed out")
+    finally:
+        running = False
+        for task in tasks:
+            task.cancel()
+        await asyncio.gather(*tasks, return_exceptions=True)
+        print(f"\nFinal measurement count: {len(measurements)}")
+
+if __name__ == "__main__":
+    asyncio.run(test_non_blocking())

+ 12 - 0
exo/main.py

@@ -235,12 +235,24 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
     print(f"Processing prompt: {prompt}")
     await node.process_prompt(shard, prompt, request_id=request_id)
 
+    first_token_time = time.time()
     tokens = []
+    i = 0
     def on_token(_request_id, _token, _is_finished):
+      nonlocal i
+      i += 1
+      if i % 20 == 0:
+        print(f"TPS: {i / (time.time() - first_token_time)}")
+
       tokens.append(_token)
       return _request_id == request_id and _is_finished
     await callback.wait(on_token, timeout=300)
 
+    print("=== Stats ===")
+    print(f"Total time: {time.time() - first_token_time}")
+    print(f"Total tokens: {len(tokens)}")
+    print(f"Total tokens per second: {len(tokens) / (time.time() - first_token_time)}")
+
     print("\nGenerated response:")
     print(tokenizer.decode(tokens))
   except Exception as e:

+ 5 - 2
exo/orchestration/node.py

@@ -123,6 +123,7 @@ class Node:
       context = TraceContext(request_id=request_id or str(uuid.uuid4()), sequence_number=0)
       tracer.set_context(request_id, context)
 
+    is_finished = False
     try:
       with tracer.start_span(
         f"process_inference_result.{self.get_partition_index()}",
@@ -136,9 +137,10 @@ class Node:
       ):
         if request_id not in self.buffered_token_output:
           self.buffered_token_output[request_id] = ([], False)
-        is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
         
-        if shard.is_last_layer() and not is_finished:
+        if shard.is_last_layer():
+          is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+
           # Add span for sampling
           with tracer.start_span(
             "sample_token",
@@ -203,6 +205,7 @@ class Node:
           self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
           self.outstanding_requests.pop(request_id)
 
+
         return np.array(self.buffered_token_output[request_id][0])
     except Exception as e:
       if request_id in self.outstanding_requests: