@@ -6,7 +6,7 @@ import mlx.optimizers as optim
from ..inference_engine import InferenceEngine
from .stateful_model import StatefulModel
from .sharded_utils import load_shard, get_image_from_str
-from .losses import length_masked_ce_loss
+from .losses import loss_fns
from ..shard import Shard
from typing import Dict, Optional, Tuple
from exo.download.shard_download import ShardDownloader
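
For context on the `loss_fns` change: the registry lives in `.losses`, which this diff does not show. A minimal sketch of what it presumably looks like, assuming each loss callable takes `(model, inputs, targets, lengths)` as used in the hunk below (the masking details are an assumption):

    import mlx.core as mx
    import mlx.nn as nn

    def length_masked_ce_loss(model, inputs, targets, lengths):
      # Token-level cross-entropy, masked so positions past each sequence's
      # true length do not contribute to the mean.
      logits = model(inputs).astype(mx.float32)
      length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
      ce = nn.losses.cross_entropy(logits, targets, reduction='none') * length_mask
      return ce.sum() / length_mask.sum()

    # Name -> callable registry; evaluate()/train() look losses up by key.
    loss_fns = {"length_masked_ce": length_masked_ce_loss}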
@@ -64,33 +64,39 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
#print(f"infer_tensor out -> {output_data}")
|
|
|
return output_data
|
|
|
|
|
|
- async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
|
|
|
+ async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
|
|
|
await self.ensure_shard(shard)
|
|
|
- await self.ensure_session('loss', lambda: loss)
|
|
|
+ await self.ensure_session('loss', lambda: loss_fns[loss])
|
|
|
await self.ensure_session('task', lambda: ('eval', self.model.eval()))
|
|
|
#print(f"evaluate in <- {inputs}")
|
|
|
x = mx.array(inputs).astype(mx.int64) if self.shard.is_first_layer() else mx.array(inputs)
|
|
|
- y = mx.array(targets).astype(mx.int64)
|
|
|
+ y = mx.array(targets)
|
|
|
l = mx.array(lengths)
|
|
|
score = await asyncio.get_running_loop().run_in_executor(self.executor, self.session['loss'], self.model, x, y, l)
|
|
|
#print(f"evaluate out -> {score}")
|
|
|
return np.array(score)
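
With the loss now selected by name rather than passed as a callable, a call to `evaluate` looks like this (the `engine` handle and tensors are illustrative):

    score = await engine.evaluate(request_id, shard, inputs, targets, lengths, loss="length_masked_ce")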
+
+  def update_model(self, grad, lval):
+    # Plain (non-async) so it can run on the executor thread via run_in_executor;
+    # train() has already ensured the shard, so no shard argument is needed here.
+    self.session['opt'].update(self.model, grad)
+    mx.eval(self.model.parameters(), self.session['opt'].state, lval)

-  async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss, opt=optim.Adam, lr=1e-5):
+  async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.Adam, lr=1e-5):
    await self.ensure_shard(shard)
-    await self.ensure_session('loss', lambda: loss)
+    await self.ensure_session('loss', lambda: loss_fns[loss])
    await self.ensure_session('LVaG', lambda: nn.value_and_grad(self.model, self.session['loss']))
    await self.ensure_session('opt', lambda: opt(lr))
    await self.ensure_session('task', lambda: ('train', self.model.train()))

    x = mx.array(inputs).astype(mx.int64) if self.shard.is_first_layer() else mx.array(inputs)
-    y = mx.array(targets).astype(mx.int64)
+    y = mx.array(targets)
    l = mx.array(lengths)
    loop = asyncio.get_running_loop()
-    loss, grad = await loop.run_in_executor(self.executor, self.session['LVaG'], self.model, x, y, l)
-    await loop.run_in_executor(self.executor, lambda: self.session['opt'].update(self.model, grad))
+    score, grad = await loop.run_in_executor(self.executor, self.session['LVaG'], self.model, x, y, l)
+    # Await the optimizer step so the parameter update completes before returning.
+    await loop.run_in_executor(self.executor, self.update_model, grad, score)
+    # Per-layer weight gradients, kept as arrays so one can be returned below.
+    layers = [{k: v["weight"] for k, v in l.items() if 'weight' in v} for l in grad['model']['model']['layers'] if l]

-    return np.array(loss), np.array(grad)
+    return np.array(score).reshape(inputs.shape[0], -1), np.array(layers[0]['input_layernorm']).reshape(inputs.shape[0], -1)

  async def ensure_shard(self, shard: Shard):
    if self.shard == shard:
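
For reference, `train` now returns the batch's loss values plus, as a debugging aid, the first layer's input_layernorm weight gradient. An illustrative driver for the reworked path; the engine construction and tensor shapes are assumptions, and `optim.Adam` is simply the default shown in the signature above:

    # Hypothetical usage sketch.
    engine = MLXDynamicShardInferenceEngine(shard_downloader)
    loss_vals, ln_grads = await engine.train(request_id, shard, inputs, targets, lengths,
                                             loss="length_masked_ce", opt=optim.Adam, lr=1e-5)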