|
@@ -4,7 +4,7 @@ import mlx.nn as nn
|
|
|
from mlx_lm.sample_utils import top_p_sampling, make_sampler
|
|
|
import mlx.optimizers as optim
|
|
|
from ..inference_engine import InferenceEngine
|
|
|
-from .sharded_utils import load_shard, get_image_from_str
|
|
|
+from .sharded_utils import load_shard, load_model_shard, resolve_tokenizer
|
|
|
from .losses import loss_fns
|
|
|
from ..shard import Shard
|
|
|
from typing import Dict, Optional, Tuple
|
|
@@ -157,7 +157,11 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
|
return
|
|
|
model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
|
|
|
if self.shard != shard:
|
|
|
- model_shard, self.tokenizer = await load_shard(model_path, shard)
|
|
|
+ model_shard = await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: load_model_shard(model_path, shard, lazy=False))
|
|
|
+ if hasattr(model_shard, "tokenizer"):
|
|
|
+ self.tokenizer = model_shard.tokenizer
|
|
|
+ else:
|
|
|
+ self.tokenizer = await resolve_tokenizer(model_path)
|
|
|
self.shard = shard
|
|
|
self.model = model_shard
|
|
|
self.caches = OrderedDict()
|