import asyncio
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

import numpy as np
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx_lm.models.cache import make_prompt_cache
from mlx_lm.sample_utils import make_sampler

from exo.download.shard_download import ShardDownloader
from ..inference_engine import InferenceEngine
from ..shard import Shard
from .losses import loss_fns
from .sharded_utils import load_model_shard, resolve_tokenizer


class MLXDynamicShardInferenceEngine(InferenceEngine):
  def __init__(self, shard_downloader: ShardDownloader):
    self.shard = None
    self.shard_downloader = shard_downloader
    self.caches = OrderedDict()
    # (temp, top_p, min_p, min_tokens_to_keep), matching make_sampler's parameters.
    self.sampler_params: tuple[float, float, float, int] = (0.0, 0.0, 0.0, 1)
    self.sampler = make_sampler(*self.sampler_params)
    # Single-worker executors serialize all MLX and tokenizer work off the event loop.
    self._mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")
    self._tokenizer_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="tokenizer")
    self.session = {}
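
  # mx.eval forces lazily built MLX graphs to materialize; it runs on the
  # dedicated MLX thread so the asyncio event loop is never blocked by compute.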
  async def _eval_mlx(self, *args):
    await asyncio.get_running_loop().run_in_executor(self._mlx_thread, mx.eval, *args)
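
  # Per-request KV caches live in an LRU OrderedDict; the least recently used
  # entry is evicted once more than max_caches requests are being tracked.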
  async def poll_state(self, request_id: str, max_caches=2):
    if request_id in self.caches:
      self.caches.move_to_end(request_id)
    else:
      newcache = make_prompt_cache(self.model)
      if len(self.caches) > max_caches:
        self.caches.popitem(last=False)
      self.caches[request_id] = newcache
    return {"cache": self.caches[request_id]}
  async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
    if (temp, top_p, 0.0, 1) != self.sampler_params:
      self.sampler_params = (temp, top_p, 0.0, 1)
      self.sampler = make_sampler(*self.sampler_params)
    logits = mx.array(x)
    logits = logits[:, -1, :]
    # Normalize per row; without axis=-1, logsumexp would reduce over the whole batch.
    logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
    result = self.sampler(logprobs)
    await self._eval_mlx(result)
    return np.asarray(result, dtype=int)
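
  # Tokenization also runs off-loop, on its own single-worker executor, so a
  # long prompt never stalls the event loop or in-flight MLX compute.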
  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
    await self.ensure_shard(shard)
    return np.asarray(
      await asyncio.get_running_loop().run_in_executor(
        self._tokenizer_thread,
        self.tokenizer.encode,
        prompt
      )
    )

  async def decode(self, shard: Shard, tokens) -> str:
    await self.ensure_shard(shard)
    return await asyncio.get_running_loop().run_in_executor(
      self._tokenizer_thread,
      self.tokenizer.decode,
      tokens
    )
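
  # Checkpoint I/O goes through the MLX thread as well, since save/load touch
  # the same model object the inference calls use.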
  async def save_checkpoint(self, shard: Shard, path: str):
    await self.ensure_shard(shard)
    await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: self.model.save_weights(path))

  async def load_checkpoint(self, shard: Shard, path: str):
    await self.ensure_shard(shard)
    await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: self.model.load_weights(path))
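
  # Autoregressive models thread a KV cache in via poll_state and return no
  # state; the Stable Diffusion pipeline instead round-trips its own
  # inference_state dict alongside the output tensor.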
  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
    await self.ensure_shard(shard)
    state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
    x = mx.array(input_data)
    if self.model.model_type != 'StableDiffusionPipeline':
      output_data = await asyncio.get_running_loop().run_in_executor(
        self._mlx_thread,
        lambda: self.model(x, **state, **(inference_state or {}))
      )
      inference_state = None
    else:
      result = await asyncio.get_running_loop().run_in_executor(
        self._mlx_thread,
        lambda: self.model(x, **state, **(inference_state or {}))
      )
      output_data, inference_state = result
    await self._eval_mlx(output_data)
    output_data = await asyncio.get_running_loop().run_in_executor(
      self._mlx_thread,
      lambda: np.array(output_data, copy=False)
    )
    return output_data, inference_state
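
  # Evaluation computes the configured loss (default: length-masked
  # cross-entropy) without touching gradients or optimizer state.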
  async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
    await self.ensure_shard(shard)
    await self.save_session('loss', loss_fns[loss])
    x = mx.array(inputs)
    y = mx.array(targets)
    l = mx.array(lengths)
    score = await asyncio.get_running_loop().run_in_executor(
      self._mlx_thread,
      lambda: self.session['loss'](self.model, x, y, l)
    )
    return score
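
  # Training setup: freeze the whole model, unfreeze only modules whose names
  # end with one of the requested layer names, and cache the value-and-grad
  # function and optimizer in the session so repeated calls are cheap.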
  async def ensure_train(self, shard: Shard, loss: str, opt=optim.SGD, lr=1e-5, trainable_layers=['input_layernorm', 'gate_proj']):
    await self.ensure_shard(shard)
    if 'train_layers' not in self.session or self.session['train_layers'] != trainable_layers:
      await self.save_session('train_layers', trainable_layers)

      def freeze_unfreeze():
        self.model.freeze()
        self.model.apply_to_modules(
          lambda k, v: v.unfreeze() if any(k.endswith(layer_name) for layer_name in trainable_layers) else None
        )

      await asyncio.get_running_loop().run_in_executor(self._mlx_thread, freeze_unfreeze)
    if 'lossname' not in self.session or 'LVaG' not in self.session or self.session['lossname'] != loss:
      await self.save_session('lossname', loss)
      await self.save_session('LVaG', nn.value_and_grad(self.model, loss_fns[loss]))
    if 'opt' not in self.session:
      await self.save_session('opt', opt(lr))
    return True
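
  # One optimizer step per call; the raw per-layer gradients are returned so a
  # caller can inspect or aggregate them across nodes.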
  async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.SGD, lr=1e-5):
    await self.ensure_train(shard, loss, opt, lr)

    def train_step(inp, tar, lng):
      lval, grad = self.session['LVaG'](self.model, inp, tar, lng)
      gradlayers = grad['model']['layers']
      self.session['opt'].update(self.model, grad)
      return lval, gradlayers, (self.model.parameters(), self.session['opt'].state, lval)

    x = mx.array(inputs)
    y = mx.array(targets)
    l = mx.array(lengths)
    score, gradients, eval_args = await asyncio.get_running_loop().run_in_executor(
      self._mlx_thread,
      lambda: train_step(x, y, l)
    )
    await self._eval_mlx(*eval_args)
    layers = [{k: v["weight"] for k, v in layer.items() if 'weight' in v} for layer in gradients if layer]
    first_layer = np.array(layers[0]['input_layernorm'], copy=False)
    await self._eval_mlx(first_layer)
    return score, first_layer
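
  # Lazy (re)loading: the shard is downloaded and loaded only when it differs
  # from the one currently resident. The second check re-validates after the
  # await, in case a concurrent caller already loaded the same shard.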
  async def ensure_shard(self, shard: Shard):
    if self.shard == shard:
      return
    model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
    if self.shard != shard:
      model_shard = await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: load_model_shard(model_path, shard, lazy=False))
      if hasattr(model_shard, "tokenizer"):
        self.tokenizer = model_shard.tokenizer
      else:
        self.tokenizer = await resolve_tokenizer(model_path)
      self.shard = shard
      self.model = model_shard
      self.caches = OrderedDict()
      self.session = {}
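
  # Shut down both worker threads; otherwise pending executor threads can keep
  # the process alive after the engine is discarded.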
  async def cleanup(self):
    self._mlx_thread.shutdown(wait=True)
    self._tokenizer_thread.shutdown(wait=True)
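

# A minimal usage sketch, kept commented out so importing this module stays
# side-effect free. It assumes an HFShardDownloader implementation and a Shard
# dataclass with (model_id, start_layer, end_layer, n_layers) fields exist in
# the surrounding codebase; treat those names and the model id as assumptions,
# not as part of this module's API.
#
#   import asyncio
#   from exo.download.hf.hf_shard_download import HFShardDownloader
#
#   async def main():
#     engine = MLXDynamicShardInferenceEngine(HFShardDownloader())
#     shard = Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=15, n_layers=16)
#     tokens = await engine.encode(shard, "Hello")
#     logits, _ = await engine.infer_tensor("req-0", shard, tokens[None, :])
#     next_token = await engine.sample(logits, temp=0.7, top_p=0.9)
#     print(await engine.decode(shard, next_token.tolist()))
#     await engine.cleanup()
#
#   asyncio.run(main())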