@@ -24,6 +24,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     self._mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")
     self._tokenizer_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="tokenizer")
     self.session = {}
+    self._shard_lock = asyncio.Lock()
 
   async def _eval_mlx(self, *args):
     await asyncio.get_running_loop().run_in_executor(self._mlx_thread, mx.eval, *args)
@@ -157,19 +158,22 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     return score, first_layer
 
   async def ensure_shard(self, shard: Shard):
-    if self.shard == shard:
-      return
-    model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
-    if self.shard != shard:
-      model_shard = await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: load_model_shard(model_path, shard, lazy=False))
-      if hasattr(model_shard, "tokenizer"):
-        self.tokenizer = model_shard.tokenizer
-      else:
-        self.tokenizer = await resolve_tokenizer(model_path)
-      self.shard = shard
-      self.model = model_shard
-      self.caches = OrderedDict()
-      self.session = {}
+    async with self._shard_lock:
+      if self.shard == shard: return
+      model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
+      if self.shard != shard:
+        model_shard = await asyncio.get_running_loop().run_in_executor(
+          self._mlx_thread,
+          lambda: load_model_shard(model_path, shard, lazy=False)
+        )
+        if hasattr(model_shard, "tokenizer"):
+          self.tokenizer = model_shard.tokenizer
+        else:
+          self.tokenizer = await resolve_tokenizer(model_path)
+        self.shard = shard
+        self.model = model_shard
+        self.caches = OrderedDict()
+        self.session = {}
 
   async def cleanup(self):
     self._mlx_thread.shutdown(wait=True)