|
@@ -23,10 +23,10 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
self.sampler = make_sampler(*self.sampler_params)
|
|
self.sampler = make_sampler(*self.sampler_params)
|
|
self._mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")
|
|
self._mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")
|
|
self._tokenizer_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="tokenizer")
|
|
self._tokenizer_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="tokenizer")
|
|
|
|
+ self.session = {}
|
|
|
|
|
|
async def _eval_mlx(self, *args):
|
|
async def _eval_mlx(self, *args):
|
|
- loop = asyncio.get_running_loop()
|
|
|
|
- await loop.run_in_executor(self._mlx_thread, mx.eval, *args)
|
|
|
|
|
|
+ await asyncio.get_running_loop().run_in_executor(self._mlx_thread, mx.eval, *args)
|
|
|
|
|
|
async def poll_state(self, request_id: str, max_caches=2):
|
|
async def poll_state(self, request_id: str, max_caches=2):
|
|
if request_id in self.caches:
|
|
if request_id in self.caches:
|
|
@@ -51,31 +51,48 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
|
|
|
|
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
|
|
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
|
|
await self.ensure_shard(shard)
|
|
await self.ensure_shard(shard)
|
|
- loop = asyncio.get_running_loop()
|
|
|
|
- return np.asarray(await loop.run_in_executor(self._tokenizer_thread, self.tokenizer.encode, prompt))
|
|
|
|
|
|
+ return np.asarray(
|
|
|
|
+ await asyncio.get_running_loop().run_in_executor(
|
|
|
|
+ self._tokenizer_thread,
|
|
|
|
+ self.tokenizer.encode,
|
|
|
|
+ prompt
|
|
|
|
+ )
|
|
|
|
+ )
|
|
|
|
|
|
async def decode(self, shard: Shard, tokens) -> str:
|
|
async def decode(self, shard: Shard, tokens) -> str:
|
|
await self.ensure_shard(shard)
|
|
await self.ensure_shard(shard)
|
|
- loop = asyncio.get_running_loop()
|
|
|
|
- return await loop.run_in_executor(self._tokenizer_thread, self.tokenizer.decode, tokens)
|
|
|
|
|
|
+ return await asyncio.get_running_loop().run_in_executor(
|
|
|
|
+ self._tokenizer_thread,
|
|
|
|
+ self.tokenizer.decode,
|
|
|
|
+ tokens
|
|
|
|
+ )
|
|
|
|
|
|
async def save_checkpoint(self, shard: Shard, path: str):
|
|
async def save_checkpoint(self, shard: Shard, path: str):
|
|
await self.ensure_shard(shard)
|
|
await self.ensure_shard(shard)
|
|
- self.model.save_weights(path)
|
|
|
|
|
|
+ await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: self.model.save_weights(path))
|
|
|
|
|
|
async def load_checkpoint(self, shard: Shard, path: str):
|
|
async def load_checkpoint(self, shard: Shard, path: str):
|
|
await self.ensure_shard(shard)
|
|
await self.ensure_shard(shard)
|
|
- self.model.load_weights(path)
|
|
|
|
-
|
|
|
|
|
|
+ await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: self.model.load_weights(path))
|
|
|
|
+
|
|
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
|
|
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
|
|
await self.ensure_shard(shard)
|
|
await self.ensure_shard(shard)
|
|
- loop = asyncio.get_running_loop()
|
|
|
|
state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
|
|
state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
|
|
x = mx.array(input_data)
|
|
x = mx.array(input_data)
|
|
|
|
+
|
|
if self.model.model_type != 'StableDiffusionPipeline':
|
|
if self.model.model_type != 'StableDiffusionPipeline':
|
|
- output_data = self.model(x, **state, **(inference_state or {}))
|
|
|
|
|
|
+ output_data = await asyncio.get_running_loop().run_in_executor(
|
|
|
|
+ self._mlx_thread,
|
|
|
|
+ lambda: self.model(x, **state, **(inference_state or {}))
|
|
|
|
+ )
|
|
|
|
+ inference_state = None
|
|
else:
|
|
else:
|
|
- output_data, inference_state = self.model(x, **state, **(inference_state or {}))
|
|
|
|
|
|
+ result = await asyncio.get_running_loop().run_in_executor(
|
|
|
|
+ self._mlx_thread,
|
|
|
|
+ lambda: self.model(x, **state, **(inference_state or {}))
|
|
|
|
+ )
|
|
|
|
+ output_data, inference_state = result
|
|
|
|
+
|
|
output_data = np.array(output_data, copy=False)
|
|
output_data = np.array(output_data, copy=False)
|
|
return output_data, inference_state
|
|
return output_data, inference_state
|
|
|
|
|
|
@@ -85,18 +102,29 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
x = mx.array(inputs)
|
|
x = mx.array(inputs)
|
|
y = mx.array(targets)
|
|
y = mx.array(targets)
|
|
l = mx.array(lengths)
|
|
l = mx.array(lengths)
|
|
- score = self.session['loss'](self.model, x, y, l)
|
|
|
|
|
|
+
|
|
|
|
+ score = await asyncio.get_running_loop().run_in_executor(
|
|
|
|
+ self._mlx_thread,
|
|
|
|
+ lambda: self.session['loss'](self.model, x, y, l)
|
|
|
|
+ )
|
|
return score
|
|
return score
|
|
|
|
|
|
async def ensure_train(self, shard: Shard, loss: str, opt=optim.SGD, lr=1e-5, trainable_layers=['input_layernorm', 'gate_proj']):
|
|
async def ensure_train(self, shard: Shard, loss: str, opt=optim.SGD, lr=1e-5, trainable_layers=['input_layernorm', 'gate_proj']):
|
|
await self.ensure_shard(shard)
|
|
await self.ensure_shard(shard)
|
|
|
|
+
|
|
if 'train_layers' not in self.session or self.session['train_layers'] != trainable_layers:
|
|
if 'train_layers' not in self.session or self.session['train_layers'] != trainable_layers:
|
|
await self.save_session('train_layers', trainable_layers)
|
|
await self.save_session('train_layers', trainable_layers)
|
|
- self.model.freeze()
|
|
|
|
- self.model.apply_to_modules(lambda k, v: v.unfreeze() if any(lambda: k.endswith(i) for i in trainable_layers) else None)
|
|
|
|
|
|
+ def freeze_unfreeze():
|
|
|
|
+ self.model.freeze()
|
|
|
|
+ self.model.apply_to_modules(
|
|
|
|
+ lambda k, v: v.unfreeze() if any(k.endswith(layer_name) for layer_name in trainable_layers) else None
|
|
|
|
+ )
|
|
|
|
+ await asyncio.get_running_loop().run_in_executor(self._mlx_thread, freeze_unfreeze)
|
|
|
|
+
|
|
if 'lossname' not in self.session or 'LVaG' not in self.session or self.session['lossname'] != loss:
|
|
if 'lossname' not in self.session or 'LVaG' not in self.session or self.session['lossname'] != loss:
|
|
await self.save_session('lossname', loss)
|
|
await self.save_session('lossname', loss)
|
|
await self.save_session('LVaG', nn.value_and_grad(self.model, loss_fns[loss]))
|
|
await self.save_session('LVaG', nn.value_and_grad(self.model, loss_fns[loss]))
|
|
|
|
+
|
|
if 'opt' not in self.session:
|
|
if 'opt' not in self.session:
|
|
await self.save_session('opt', opt(lr))
|
|
await self.save_session('opt', opt(lr))
|
|
return True
|
|
return True
|
|
@@ -113,11 +141,13 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
x = mx.array(inputs)
|
|
x = mx.array(inputs)
|
|
y = mx.array(targets)
|
|
y = mx.array(targets)
|
|
l = mx.array(lengths)
|
|
l = mx.array(lengths)
|
|
-
|
|
|
|
- score, gradients, eval_args = train_step(x, y, l)
|
|
|
|
|
|
+ score, gradients, eval_args = await asyncio.get_running_loop().run_in_executor(
|
|
|
|
+ self._mlx_thread,
|
|
|
|
+ lambda: train_step(x, y, l)
|
|
|
|
+ )
|
|
await self._eval_mlx(*eval_args)
|
|
await self._eval_mlx(*eval_args)
|
|
|
|
|
|
- layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
|
|
|
|
|
|
+ layers = [{k: v["weight"] for k, v in layer.items() if 'weight' in v} for layer in gradients if layer]
|
|
first_layer = np.array(layers[0]['input_layernorm'], copy=False)
|
|
first_layer = np.array(layers[0]['input_layernorm'], copy=False)
|
|
await self._eval_mlx(first_layer)
|
|
await self._eval_mlx(first_layer)
|
|
return score, first_layer
|
|
return score, first_layer
|
|
@@ -125,9 +155,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
async def ensure_shard(self, shard: Shard):
|
|
async def ensure_shard(self, shard: Shard):
|
|
if self.shard == shard:
|
|
if self.shard == shard:
|
|
return
|
|
return
|
|
-
|
|
|
|
model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
|
|
model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
|
|
-
|
|
|
|
if self.shard != shard:
|
|
if self.shard != shard:
|
|
model_shard, self.tokenizer = await load_shard(model_path, shard)
|
|
model_shard, self.tokenizer = await load_shard(model_path, shard)
|
|
self.shard = shard
|
|
self.shard = shard
|