|
@@ -4,7 +4,6 @@ import mlx.nn as nn
|
|
|
from mlx_lm.sample_utils import top_p_sampling
|
|
|
import mlx.optimizers as optim
|
|
|
from ..inference_engine import InferenceEngine
|
|
|
-from .stateful_model import StatefulModel
|
|
|
from .sharded_utils import load_shard, get_image_from_str
|
|
|
from .losses import loss_fns
|
|
|
from ..shard import Shard
|
|
@@ -12,6 +11,9 @@ from typing import Dict, Optional, Tuple
|
|
|
from exo.download.shard_download import ShardDownloader
|
|
|
import asyncio
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
+from functools import partial
|
|
|
+from collections import OrderedDict
|
|
|
+from mlx_lm.models.cache import make_prompt_cache
|
|
|
|
|
|
def sample_logits(
|
|
|
logits: mx.array,
|
|
@@ -39,8 +41,19 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
|
self.shard = None
|
|
|
self.shard_downloader = shard_downloader
|
|
|
self.executor = ThreadPoolExecutor(max_workers=1)
|
|
|
+ self.caches = OrderedDict()
|
|
|
self.session = {}
|
|
|
|
|
|
+ async def poll_cache(self, request_id: str, max_caches=2):
|
|
|
+ if request_id in self.caches:
|
|
|
+ self.caches.move_to_end(request_id)
|
|
|
+ else:
|
|
|
+ newcache = await asyncio.get_running_loop().run_in_executor(self.executor, make_prompt_cache, self.model)
|
|
|
+ if len(self.caches) > max_caches:
|
|
|
+ self.caches.popitem(last=False)
|
|
|
+ self.caches[request_id] = newcache
|
|
|
+ return self.caches[request_id]
|
|
|
+
|
|
|
async def sample(self, x, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
|
|
|
y = mx.array(x)
|
|
|
logits = y[:, -1, :]
|
|
@@ -57,54 +70,72 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
|
tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
|
|
|
return tokens
|
|
|
|
|
|
- async def save_checkpoint(self, path: Path):
|
|
|
+ async def save_checkpoint(self, path: str):
|
|
|
await self.ensure_shard(shard)
|
|
|
await asyncio.get_running_loop().run_in_executor(self.executor, self.model.save_weights, path)
|
|
|
|
|
|
- async def load_checkpoint(self, path: Path):
|
|
|
+ async def load_checkpoint(self, path: str):
|
|
|
await self.ensure_shard(shard)
|
|
|
await asyncio.get_running_loop().run_in_executor(self.executor, self.model.load_weights, path)
|
|
|
|
|
|
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
|
|
|
await self.ensure_shard(shard)
|
|
|
#print(f"infer_tensor in <- {input_data}")
|
|
|
- output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.model, mx.array(input_data), request_id))
|
|
|
+ loop = asyncio.get_running_loop()
|
|
|
+ cache = await self.poll_cache(request_id)
|
|
|
+ x = mx.array(input_data).astype(mx.int64) if self.shard.is_first_layer() else mx.array(input_data)
|
|
|
+ #print(f"Infer Tensor: {x=}")
|
|
|
+ output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, lambda: self.model(x, cache=cache)))
|
|
|
#print(f"infer_tensor out -> {output_data}")
|
|
|
return output_data
|
|
|
-
|
|
|
+
|
|
|
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
|
|
|
await self.ensure_shard(shard)
|
|
|
- await self.ensure_session('loss', lambda: loss_fns[loss])
|
|
|
- await self.ensure_session('task', lambda: ('eval', self.model.eval()))
|
|
|
+ await self.save_session('loss', loss_fns[loss])
|
|
|
+ loop = asyncio.get_running_loop()
|
|
|
#print(f"evaluate in <- {inputs}")
|
|
|
x = mx.array(inputs).astype(mx.int64) if self.shard.is_first_layer() else mx.array(inputs)
|
|
|
y = mx.array(targets)
|
|
|
l = mx.array(lengths)
|
|
|
- score = await asyncio.get_running_loop().run_in_executor(self.executor, self.session['loss'], self.model, x, y, l)
|
|
|
+ score = await loop.run_in_executor(self.executor, self.session['loss'], self.model, x, y, l)
|
|
|
#print(f"evaluate out -> {score}")
|
|
|
return np.array(score)
|
|
|
|
|
|
- async def update_model(self, grad, lval):
|
|
|
+ async def ensure_train(self, shard: Shard, loss: str, opt=optim.SGD, lr=1e-5, trainable_layers=['input_layernorm', 'gate_proj']):
|
|
|
await self.ensure_shard(shard)
|
|
|
- self.session['opt'].update(self.model, grad)
|
|
|
- mx.eval(self.model.parameters(), self.session['opt'].state, lval)
|
|
|
-
|
|
|
- async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.Adam, lr=1e-5):
|
|
|
- await self.ensure_shard(shard)
|
|
|
- await self.ensure_session('loss', lambda: loss_fns[loss])
|
|
|
- await self.ensure_session('LVaG', lambda: nn.value_and_grad(self.model, self.session['loss']))
|
|
|
- await self.ensure_session('opt', lambda: opt(lr))
|
|
|
- await self.ensure_session('task', lambda: ('train', self.model.train()))
|
|
|
+ if 'train_layers' not in self.session or self.session['train_layers'] != trainable_layers:
|
|
|
+ await self.save_session('train_layers', trainable_layers)
|
|
|
+ self.model.freeze()
|
|
|
+ self.model.apply_to_modules(lambda k, v: v.unfreeze() if any(lambda: k.endswith(i) for i in trainable_layers) else None)
|
|
|
+ if 'lossname' not in self.session or 'LVaG' not in self.session or self.session['lossname'] != loss:
|
|
|
+ await self.save_session('lossname', loss)
|
|
|
+ await self.save_session('LVaG', nn.value_and_grad(self.model, loss_fns[loss]))
|
|
|
+ if 'opt' not in self.session:
|
|
|
+ await self.save_session('opt', opt(lr))
|
|
|
+ return True
|
|
|
+
|
|
|
+ async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.SGD, lr=1e-5):
|
|
|
+ loop = asyncio.get_running_loop()
|
|
|
+ nothin = await self.ensure_train(shard, loss, opt, lr)
|
|
|
+ def train_step(inp, tar, lng):
|
|
|
+ lval, grad = self.session['LVaG'](self.model, inp, tar, lng)
|
|
|
+ gradlayers = grad['model']['layers']
|
|
|
+ self.session['opt'].update(self.model, grad)
|
|
|
+ mx.eval(self.model.parameters(), self.session['opt'].state, lval)
|
|
|
+ return lval, gradlayers
|
|
|
|
|
|
x = mx.array(inputs).astype(mx.int64) if self.shard.is_first_layer() else mx.array(inputs)
|
|
|
y = mx.array(targets)
|
|
|
l = mx.array(lengths)
|
|
|
- loop = asyncio.get_running_loop()
|
|
|
- score, grad = await loop.run_in_executor(self.executor, self.session['LVaG'], self.model, x, y, l)
|
|
|
- layers = [{k: v["weight"].shape for k,v in l.items() if 'weight' in v} for l in grad['model']['model']['layers'] if l]
|
|
|
- await loop.run_in_executor(self.executor, self.update_model, grad, score)
|
|
|
+
|
|
|
+ score, gradients = await loop.run_in_executor(self.executor, train_step, x, y, l)
|
|
|
+ #print(f"{score=}")
|
|
|
+
|
|
|
+ layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
|
|
|
+ #print(layers[0])
|
|
|
|
|
|
return np.array(score).reshape(inputs.shape[0], -1), np.array(layers[0]['input_layernorm']).reshape(inputs.shape[0], -1)
|
|
|
+ return 0, 0
|
|
|
|
|
|
async def ensure_shard(self, shard: Shard):
|
|
|
if self.shard == shard:
|
|
@@ -113,12 +144,13 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
|
model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
|
|
|
|
|
|
if self.shard != shard:
|
|
|
- loop = asyncio.get_running_loop()
|
|
|
|
|
|
def load_shard_wrapper():
|
|
|
return asyncio.run(load_shard(model_path, shard))
|
|
|
|
|
|
- model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
|
|
|
+ model_shard, self.tokenizer = await asyncio.get_running_loop().run_in_executor(self.executor, load_shard_wrapper)
|
|
|
self.shard = shard
|
|
|
- self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard)
|
|
|
+ self.model = model_shard
|
|
|
+ self.caches = OrderedDict()
|
|
|
+ self.session = {}
|
|
|
|