@@ -1,7 +1,7 @@
 import numpy as np
 import mlx.core as mx
 import mlx.nn as nn
-from mlx_lm.sample_utils import top_p_sampling
+from mlx_lm.sample_utils import top_p_sampling, make_sampler
 import mlx.optimizers as optim
 from ..inference_engine import InferenceEngine
 from .sharded_utils import load_shard, get_image_from_str
@@ -10,8 +10,6 @@ from ..shard import Shard
 from typing import Dict, Optional, Tuple
 from exo.download.shard_download import ShardDownloader
 import asyncio
-from concurrent.futures import ThreadPoolExecutor
-from functools import partial
 from collections import OrderedDict
 from mlx_lm.models.cache import make_prompt_cache

@@ -40,61 +38,60 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard_downloader = shard_downloader
-    self.executor = ThreadPoolExecutor(max_workers=1)
     self.caches = OrderedDict()
+    self.sampler_params: tuple[float, float, float, int] = (0.0, 0.0, 0.0, 1)
+    self.sampler = make_sampler(*self.sampler_params)

   async def poll_state(self, request_id: str, max_caches=2):
     if request_id in self.caches:
       self.caches.move_to_end(request_id)
     else:
-      newcache = await asyncio.get_running_loop().run_in_executor(self.executor, make_prompt_cache, self.model)
+      newcache = make_prompt_cache(self.model)
       if len(self.caches) > max_caches:
         self.caches.popitem(last=False)
       self.caches[request_id] = newcache
     return {"cache": self.caches[request_id]}

-  async def sample(self, x, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
-    y = mx.array(x)
-    logits = y[:, -1, :]
-    out = np.array(sample_logits(logits, temp=temp, top_p=top_p), dtype=int)
-    return out
+  async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
+    if (temp, top_p, 0.0, 1) != self.sampler_params:
+      self.sampler_params = (temp, top_p, 0.0, 1)
+      self.sampler = make_sampler(*self.sampler_params)
+    logits = mx.array(x)
+    logits = logits[:, -1, :]
+    logprobs = logits - mx.logsumexp(logits, keepdims=True)
+    return np.asarray(self.sampler(logprobs), dtype=int)

   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     await self.ensure_shard(shard)
-    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
-    return np.array(tokens)
+    tokens = self.tokenizer.encode(prompt)
+    return np.asarray(tokens)

   async def decode(self, shard: Shard, tokens) -> str:
     await self.ensure_shard(shard)
-    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
-    return tokens
+    return self.tokenizer.decode(tokens)

   async def save_checkpoint(self, shard: Shard, path: str):
     await self.ensure_shard(shard)
-    await asyncio.get_running_loop().run_in_executor(self.executor, self.model.save_weights, path)
+    self.model.save_weights(path)

   async def load_checkpoint(self, shard: Shard, path: str):
     await self.ensure_shard(shard)
-    await asyncio.get_running_loop().run_in_executor(self.executor, self.model.load_weights, path)
+    self.model.load_weights(path)

   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
-    loop = asyncio.get_running_loop()
     state = await self.poll_state(request_id)
     x = mx.array(input_data)
-    output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, lambda: self.model(x, **state)))
+    output_data = np.array(self.model(x, **state), copy=False)
     return output_data

   async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
     await self.ensure_shard(shard)
     await self.save_session('loss', loss_fns[loss])
-    loop = asyncio.get_running_loop()
-    #print(f"evaluate in <- {inputs}")
     x = mx.array(inputs)
     y = mx.array(targets)
     l = mx.array(lengths)
-    score = await loop.run_in_executor(self.executor, self.session['loss'], self.model, x, y, l)
-    #print(f"evaluate out -> {score}")
+    score = self.session['loss'](self.model, x, y, l)
     return score

   async def ensure_train(self, shard: Shard, loss: str, opt=optim.SGD, lr=1e-5, trainable_layers=['input_layernorm', 'gate_proj']):
@@ -130,7 +127,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
     #print(layers[0])

-    return score, np.array(layers[0]['input_layernorm'])
+    return score, np.array(layers[0]['input_layernorm'], copy=False)

   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
@@ -139,11 +136,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)

     if self.shard != shard:
-
-      def load_shard_wrapper():
-        return asyncio.run(load_shard(model_path, shard))
-
-      model_shard, self.tokenizer = await asyncio.get_running_loop().run_in_executor(self.executor, load_shard_wrapper)
+      model_shard, self.tokenizer = await load_shard(model_path, shard)
       self.shard = shard
       self.model = model_shard
       self.caches = OrderedDict()
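
For reference, a minimal sketch of the sampling path this diff switches to, exercised outside the engine. It assumes `mlx` and `mlx_lm` are installed and that `make_sampler(temp, top_p, min_p, min_tokens_to_keep)` accepts the same four positional values the new `sample()` passes via `*self.sampler_params`; the vocabulary size and random logits below are illustrative stand-ins, not values from the real model.

```python
# Hedged sketch of the make_sampler-based path introduced above.
# The (batch, seq, vocab) logits tensor and vocab size of 32 are made-up
# stand-ins for whatever the sharded model actually returns.
import numpy as np
import mlx.core as mx
from mlx_lm.sample_utils import make_sampler

temp, top_p = 0.7, 0.9
sampler = make_sampler(temp, top_p, 0.0, 1)  # mirrors self.sampler_params = (temp, top_p, 0.0, 1)

logits = mx.array(np.random.randn(1, 1, 32).astype(np.float32))  # toy model output
last = logits[:, -1, :]                                          # keep only the final position
logprobs = last - mx.logsumexp(last, keepdims=True)              # same normalization as sample()
token = np.asarray(sampler(logprobs), dtype=int)                 # sampled token id, shape (1,)
print(token)
```

Because the sampler is rebuilt only when `(temp, top_p, 0.0, 1)` differs from the cached `sampler_params`, repeated `sample()` calls with unchanged settings reuse one sampler instead of constructing a new one per token.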