|
@@ -12,6 +12,7 @@ from exo.download.shard_download import ShardDownloader
|
|
|
import asyncio
|
|
|
from collections import OrderedDict
|
|
|
from mlx_lm.models.cache import make_prompt_cache
|
|
|
+from concurrent.futures import ThreadPoolExecutor
|
|
|
|
|
|
class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
|
def __init__(self, shard_downloader: ShardDownloader):
|
|
@@ -20,6 +21,12 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
|
self.caches = OrderedDict()
|
|
|
self.sampler_params: tuple[float, float] = (0.0, 0.0, 0.0, 1)
|
|
|
self.sampler = make_sampler(*self.sampler_params)
|
|
|
+ self._mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")
|
|
|
+ self._tokenizer_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="tokenizer")
|
|
|
+
|
|
|
+ async def _eval_mlx(self, *args):
|
|
|
+ loop = asyncio.get_running_loop()
|
|
|
+ await loop.run_in_executor(self._mlx_thread, mx.eval, *args)
|
|
|
|
|
|
async def poll_state(self, request_id: str, max_caches=2):
|
|
|
if request_id in self.caches:
|
|
@@ -38,16 +45,19 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
|
logits = mx.array(x)
|
|
|
logits = logits[:, -1, :]
|
|
|
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
|
|
- return np.asarray(self.sampler(logprobs), dtype=int)
|
|
|
+ result = self.sampler(logprobs)
|
|
|
+ await self._eval_mlx(result)
|
|
|
+ return np.asarray(result, dtype=int)
|
|
|
|
|
|
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
|
|
|
await self.ensure_shard(shard)
|
|
|
- tokens = self.tokenizer.encode(prompt)
|
|
|
- return np.asarray(tokens)
|
|
|
+ loop = asyncio.get_running_loop()
|
|
|
+ return np.asarray(await loop.run_in_executor(self._tokenizer_thread, self.tokenizer.encode, prompt))
|
|
|
|
|
|
async def decode(self, shard: Shard, tokens) -> str:
|
|
|
await self.ensure_shard(shard)
|
|
|
- return self.tokenizer.decode(tokens)
|
|
|
+ loop = asyncio.get_running_loop()
|
|
|
+ return await loop.run_in_executor(self._tokenizer_thread, self.tokenizer.decode, tokens)
|
|
|
|
|
|
async def save_checkpoint(self, shard: Shard, path: str):
|
|
|
await self.ensure_shard(shard)
|
|
@@ -61,8 +71,9 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
|
await self.ensure_shard(shard)
|
|
|
state = await self.poll_state(request_id)
|
|
|
x = mx.array(input_data)
|
|
|
- output_data = np.array(self.model(x, **state), copy=False)
|
|
|
- return output_data
|
|
|
+ output = self.model(x, **state)
|
|
|
+ await self._eval_mlx(output)
|
|
|
+ return np.array(output, copy=False)
|
|
|
|
|
|
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
|
|
|
await self.ensure_shard(shard)
|
|
@@ -87,26 +98,25 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
|
return True
|
|
|
|
|
|
async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.SGD, lr=1e-5):
|
|
|
- loop = asyncio.get_running_loop()
|
|
|
- nothin = await self.ensure_train(shard, loss, opt, lr)
|
|
|
+ await self.ensure_train(shard, loss, opt, lr)
|
|
|
+
|
|
|
def train_step(inp, tar, lng):
|
|
|
lval, grad = self.session['LVaG'](self.model, inp, tar, lng)
|
|
|
gradlayers = grad['model']['layers']
|
|
|
self.session['opt'].update(self.model, grad)
|
|
|
- mx.eval(self.model.parameters(), self.session['opt'].state, lval)
|
|
|
- return lval, gradlayers
|
|
|
+ return lval, gradlayers, (self.model.parameters(), self.session['opt'].state, lval)
|
|
|
|
|
|
x = mx.array(inputs)
|
|
|
y = mx.array(targets)
|
|
|
l = mx.array(lengths)
|
|
|
|
|
|
- score, gradients = await loop.run_in_executor(self.executor, train_step, x, y, l)
|
|
|
- #print(f"{score=}")
|
|
|
+ score, gradients, eval_args = train_step(x, y, l)
|
|
|
+ await self._eval_mlx(*eval_args)
|
|
|
|
|
|
layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
|
|
|
- #print(layers[0])
|
|
|
-
|
|
|
- return score, np.array(layers[0]['input_layernorm'], copy=False)
|
|
|
+ first_layer = np.array(layers[0]['input_layernorm'], copy=False)
|
|
|
+ await self._eval_mlx(first_layer)
|
|
|
+ return score, first_layer
|
|
|
|
|
|
async def ensure_shard(self, shard: Shard):
|
|
|
if self.shard == shard:
|
|
@@ -121,3 +131,6 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
|
|
self.caches = OrderedDict()
|
|
|
self.session = {}
|
|
|
|
|
|
+ async def cleanup(self):
|
|
|
+ self._mlx_thread.shutdown(wait=True)
|
|
|
+
|