Ver Fonte

parallelise model loading

Alex Cheema há 3 meses atrás
pai
commit
4887be5103
2 ficheiros alterados com 20 adições e 22 exclusões
  1. 17 13
      exo/inference/mlx/sharded_inference_engine.py
  2. 3 9
      exo/main.py

+ 17 - 13
exo/inference/mlx/sharded_inference_engine.py

@@ -24,6 +24,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     self._mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")
     self._tokenizer_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="tokenizer")
     self.session = {}
+    self._shard_lock = asyncio.Lock()
 
   async def _eval_mlx(self, *args):
     await asyncio.get_running_loop().run_in_executor(self._mlx_thread, mx.eval, *args)
@@ -157,19 +158,22 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     return score, first_layer
 
   async def ensure_shard(self, shard: Shard):
-    if self.shard == shard:
-      return
-    model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
-    if self.shard != shard:
-      model_shard = await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: load_model_shard(model_path, shard, lazy=False))
-      if hasattr(model_shard, "tokenizer"):
-        self.tokenizer = model_shard.tokenizer
-      else:
-        self.tokenizer = await resolve_tokenizer(model_path)
-      self.shard = shard
-      self.model = model_shard
-      self.caches = OrderedDict()
-      self.session = {}
+    async with self._shard_lock:
+      if self.shard == shard: return
+      model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
+      if self.shard != shard:
+        model_shard = await asyncio.get_running_loop().run_in_executor(
+          self._mlx_thread,
+          lambda: load_model_shard(model_path, shard, lazy=False)
+        )
+        if hasattr(model_shard, "tokenizer"):
+          self.tokenizer = model_shard.tokenizer
+        else:
+          self.tokenizer = await resolve_tokenizer(model_path)
+        self.shard = shard
+        self.model = model_shard
+        self.caches = OrderedDict()
+        self.session = {}
 
   async def cleanup(self):
     self._mlx_thread.shutdown(wait=True)

+ 3 - 9
exo/main.py

@@ -193,24 +193,20 @@ def update_prompt_viz(request_id, opaque_status: str):
       traceback.print_exc()
 node.on_opaque_status.register("update_prompt_viz").on_next(update_prompt_viz)
 
-def preemptively_start_download(request_id: str, opaque_status: str):
+def preemptively_load_shard(request_id: str, opaque_status: str):
   try:
     status = json.loads(opaque_status)
     if status.get("type") != "node_status" or status.get("status") != "start_process_prompt": return
     current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
     if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
-    asyncio.create_task(shard_downloader.ensure_shard(current_shard, node.inference_engine.__class__.__name__))
+    asyncio.create_task(node.inference_engine.ensure_shard(current_shard))
   except Exception as e:
     if DEBUG >= 2:
       print(f"Failed to preemptively start download: {e}")
       traceback.print_exc()
-
-
-node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
+node.on_opaque_status.register("preemptively_load_shard").on_next(preemptively_load_shard)
 
 last_broadcast_time = 0
-
-
 def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
   global last_broadcast_time
   current_time = time.time()
@@ -218,8 +214,6 @@ def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
   if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
     last_broadcast_time = current_time
     asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
-
-
 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 
 async def run_model_cli(node: Node, model_name: str, prompt: str):