# sharded_inference_engine.py
import asyncio
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Optional, Tuple

import numpy as np

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx_lm.models.cache import make_prompt_cache
from mlx_lm.sample_utils import top_p_sampling, make_sampler

from exo.download.shard_download import ShardDownloader
from ..inference_engine import InferenceEngine
from ..shard import Shard
from .losses import loss_fns
from .sharded_utils import load_shard, get_image_from_str
  15. class MLXDynamicShardInferenceEngine(InferenceEngine):
  16. def __init__(self, shard_downloader: ShardDownloader):
  17. self.shard = None
  18. self.shard_downloader = shard_downloader
  19. self.caches = OrderedDict()
  20. self.sampler_params: tuple[float, float] = (0.0, 0.0, 0.0, 1)
  21. self.sampler = make_sampler(*self.sampler_params)
  22. async def poll_state(self, request_id: str, max_caches=2):
  23. if request_id in self.caches:
  24. self.caches.move_to_end(request_id)
  25. else:
  26. newcache = make_prompt_cache(self.model)
  27. if len(self.caches) > max_caches:
  28. self.caches.popitem(last=False)
  29. self.caches[request_id] = newcache
  30. return {"cache": self.caches[request_id]}
  31. async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
  32. if (temp, top_p, 0.0, 1) != self.sampler_params:
  33. self.sampler_params = (temp, top_p, 0.0, 1)
  34. self.sampler = make_sampler(*self.sampler_params)
  35. logits = mx.array(x)
  36. logits = logits[:, -1, :]
  37. logprobs = logits - mx.logsumexp(logits, keepdims=True)
  38. return np.asarray(self.sampler(logprobs), dtype=int)
  39. async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
  40. await self.ensure_shard(shard)
  41. tokens = self.tokenizer.encode(prompt)
  42. return np.asarray(tokens)
  43. async def decode(self, shard: Shard, tokens) -> str:
  44. await self.ensure_shard(shard)
  45. return self.tokenizer.decode(tokens)
  46. async def save_checkpoint(self, shard: Shard, path: str):
  47. await self.ensure_shard(shard)
  48. self.model.save_weights(path)
  49. async def load_checkpoint(self, shard: Shard, path: str):
  50. await self.ensure_shard(shard)
  51. self.model.load_weights(path)
  52. async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
  53. await self.ensure_shard(shard)
  54. state = await self.poll_state(request_id)
  55. x = mx.array(input_data)
  56. output_data = np.array(self.model(x, **state), copy=False)
  57. return output_data
  58. async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
  59. await self.ensure_shard(shard)
  60. await self.save_session('loss', loss_fns[loss])
  61. x = mx.array(inputs)
  62. y = mx.array(targets)
  63. l = mx.array(lengths)
  64. score = self.session['loss'](self.model, x, y, l)
  65. return score
  66. async def ensure_train(self, shard: Shard, loss: str, opt=optim.SGD, lr=1e-5, trainable_layers=['input_layernorm', 'gate_proj']):
  67. await self.ensure_shard(shard)
  68. if 'train_layers' not in self.session or self.session['train_layers'] != trainable_layers:
  69. await self.save_session('train_layers', trainable_layers)
  70. self.model.freeze()
  71. self.model.apply_to_modules(lambda k, v: v.unfreeze() if any(lambda: k.endswith(i) for i in trainable_layers) else None)
  72. if 'lossname' not in self.session or 'LVaG' not in self.session or self.session['lossname'] != loss:
  73. await self.save_session('lossname', loss)
  74. await self.save_session('LVaG', nn.value_and_grad(self.model, loss_fns[loss]))
  75. if 'opt' not in self.session:
  76. await self.save_session('opt', opt(lr))
  77. return True
  78. async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.SGD, lr=1e-5):
  79. loop = asyncio.get_running_loop()
  80. nothin = await self.ensure_train(shard, loss, opt, lr)
  81. def train_step(inp, tar, lng):
  82. lval, grad = self.session['LVaG'](self.model, inp, tar, lng)
  83. gradlayers = grad['model']['layers']
  84. self.session['opt'].update(self.model, grad)
  85. mx.eval(self.model.parameters(), self.session['opt'].state, lval)
  86. return lval, gradlayers
  87. x = mx.array(inputs)
  88. y = mx.array(targets)
  89. l = mx.array(lengths)
  90. score, gradients = await loop.run_in_executor(self.executor, train_step, x, y, l)
  91. #print(f"{score=}")
  92. layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
  93. #print(layers[0])
  94. return score, np.array(layers[0]['input_layernorm'], copy=False)
  95. async def ensure_shard(self, shard: Shard):
  96. if self.shard == shard:
  97. return
  98. model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
  99. if self.shard != shard:
  100. model_shard, self.tokenizer = await load_shard(model_path, shard)
  101. self.shard = shard
  102. self.model = model_shard
  103. self.caches = OrderedDict()
  104. self.session = {}