Pārlūkot izejas kodu

load mlx model shard on mlx thread so it doesnt block

Alex Cheema 3 mēneši atpakaļ
vecāks
revīzija
6662d5668c
1 mainītis faili ar 6 papildinājumiem un 2 dzēšanām
  1. 6 2
      exo/inference/mlx/sharded_inference_engine.py

+ 6 - 2
exo/inference/mlx/sharded_inference_engine.py

@@ -4,7 +4,7 @@ import mlx.nn as nn
 from mlx_lm.sample_utils import top_p_sampling, make_sampler
 import mlx.optimizers as optim
 from ..inference_engine import InferenceEngine
-from .sharded_utils import load_shard, get_image_from_str
+from .sharded_utils import load_shard, load_model_shard, resolve_tokenizer
 from .losses import loss_fns
 from ..shard import Shard
 from typing import Dict, Optional, Tuple
@@ -157,7 +157,11 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
       return
     model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
     if self.shard != shard:
-      model_shard, self.tokenizer = await load_shard(model_path, shard)
+      model_shard = await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: load_model_shard(model_path, shard, lazy=False))
+      if hasattr(model_shard, "tokenizer"):
+        self.tokenizer = model_shard.tokenizer
+      else:
+        self.tokenizer = await resolve_tokenizer(model_path)
       self.shard = shard
       self.model = model_shard
       self.caches = OrderedDict()