# sharded_inference_engine.py
# (removed extraction artifacts: a file-size note and a run of concatenated
# viewer line numbers that were not part of the source.)
  1. import numpy as np
  2. import mlx.core as mx
  3. import mlx.nn as nn
  4. from mlx_lm.sample_utils import make_sampler
  5. import mlx.optimizers as optim
  6. from ..inference_engine import InferenceEngine
  7. from .sharded_utils import load_model_shard, resolve_tokenizer
  8. from .losses import loss_fns
  9. from ..shard import Shard
  10. from typing import Dict, Optional, Tuple
  11. from exo.download.shard_download import ShardDownloader
  12. import asyncio
  13. from collections import OrderedDict
  14. from mlx_lm.models.cache import make_prompt_cache
  15. from concurrent.futures import ThreadPoolExecutor
  16. class MLXDynamicShardInferenceEngine(InferenceEngine):
  17. def __init__(self, shard_downloader: ShardDownloader):
  18. self.shard = None
  19. self.shard_downloader = shard_downloader
  20. self.caches = OrderedDict()
  21. self.sampler_params: tuple[float, float] = (0.0, 0.0, 0.0, 1)
  22. self.sampler = make_sampler(*self.sampler_params)
  23. self._mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")
  24. self._tokenizer_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="tokenizer")
  25. self.session = {}
  async def _eval_mlx(self, *args):
    """Force evaluation of lazy MLX arrays on the dedicated MLX thread."""
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(self._mlx_thread, mx.eval, *args)
  async def poll_state(self, request_id: str, max_caches=2):
    """Return the prompt cache for `request_id`, creating it if needed.

    Caches live in an LRU-ordered dict bounded at `max_caches` entries;
    the least-recently-used entry is evicted to make room for a new one.
    Returns a kwargs dict suitable for splatting into the model call.
    """
    if request_id in self.caches:
      # Refresh LRU position on reuse.
      self.caches.move_to_end(request_id)
    else:
      newcache = make_prompt_cache(self.model)
      # Evict BEFORE inserting so the dict never exceeds max_caches.
      # (Fixed off-by-one: the previous `>` check allowed max_caches + 1
      # entries to accumulate.)
      if len(self.caches) >= max_caches:
        self.caches.popitem(last=False)
      self.caches[request_id] = newcache
    return {"cache": self.caches[request_id]}
  async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
    """Sample next-token ids from model output logits.

    `x` is expected to be a logits array whose last two axes are
    (sequence, vocab) — TODO confirm against callers. Only the final
    sequence position is sampled.
    """
    if (temp, top_p, 0.0, 1) != self.sampler_params:
      # Rebuild the sampler only when sampling settings actually change.
      self.sampler_params = (temp, top_p, 0.0, 1)
      self.sampler = make_sampler(*self.sampler_params)
    logits = mx.array(x)
    logits = logits[:, -1, :]  # keep only the last position's logits
    # Normalize per row: without axis=-1, logsumexp reduced over the WHOLE
    # array, which is only correct for batch size 1. axis=-1 makes each
    # batch entry a proper log-distribution over its own vocabulary.
    logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
    result = self.sampler(logprobs)
    await self._eval_mlx(result)
    return np.asarray(result, dtype=int)
  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
    """Tokenize `prompt` with the shard's tokenizer, off the event loop."""
    await self.ensure_shard(shard)
    loop = asyncio.get_running_loop()
    tokens = await loop.run_in_executor(
      self._tokenizer_thread, self.tokenizer.encode, prompt
    )
    return np.asarray(tokens)
  async def decode(self, shard: Shard, tokens) -> str:
    """Detokenize `tokens` with the shard's tokenizer, off the event loop."""
    await self.ensure_shard(shard)
    loop = asyncio.get_running_loop()
    text = await loop.run_in_executor(self._tokenizer_thread, self.tokenizer.decode, tokens)
    return text
  async def save_checkpoint(self, shard: Shard, path: str):
    """Persist the current model weights to `path` (runs on the MLX thread)."""
    await self.ensure_shard(shard)
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(self._mlx_thread, lambda: self.model.save_weights(path))
  async def load_checkpoint(self, shard: Shard, path: str):
    """Load model weights from `path` (runs on the MLX thread)."""
    await self.ensure_shard(shard)
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(self._mlx_thread, lambda: self.model.load_weights(path))
  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
    """Run one forward pass of the shard on `input_data`.

    For regular LLM shards a per-request prompt cache is attached and the
    returned inference_state is None; for StableDiffusionPipeline models
    the model itself returns (output, new_inference_state).
    """
    await self.ensure_shard(shard)
    is_diffusion = self.model.model_type == 'StableDiffusionPipeline'
    state = {} if is_diffusion else await self.poll_state(request_id)
    x = mx.array(input_data)
    loop = asyncio.get_running_loop()
    # Same call in both cases — only the return shape differs.
    result = await loop.run_in_executor(
      self._mlx_thread,
      lambda: self.model(x, **state, **(inference_state or {}))
    )
    if is_diffusion:
      output_data, inference_state = result
    else:
      output_data, inference_state = result, None
    await self._eval_mlx(output_data)
    output_data = await loop.run_in_executor(
      self._mlx_thread,
      lambda: np.array(output_data, copy=False)
    )
    return output_data, inference_state
  async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
    """Compute the named loss for one batch on the MLX thread and return it."""
    await self.ensure_shard(shard)
    await self.save_session('loss', loss_fns[loss])
    inp = mx.array(inputs)
    tgt = mx.array(targets)
    lens = mx.array(lengths)
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
      self._mlx_thread,
      lambda: self.session['loss'](self.model, inp, tgt, lens)
    )
  async def ensure_train(self, shard: Shard, loss: str, opt=optim.SGD, lr=1e-5, trainable_layers=None):
    """Prepare session state for training: freeze/unfreeze layers, build
    the value-and-grad wrapper for the named loss, and create the optimizer.

    Fixed: `trainable_layers` used a mutable (list) default argument —
    shared across calls — replaced with a None sentinel; the effective
    default is unchanged.
    """
    if trainable_layers is None:
      trainable_layers = ['input_layernorm', 'gate_proj']
    await self.ensure_shard(shard)
    if 'train_layers' not in self.session or self.session['train_layers'] != trainable_layers:
      await self.save_session('train_layers', trainable_layers)
      def freeze_unfreeze():
        # Freeze everything, then unfreeze only modules whose dotted path
        # ends with one of the requested layer names.
        self.model.freeze()
        self.model.apply_to_modules(
          lambda k, v: v.unfreeze() if any(k.endswith(layer_name) for layer_name in trainable_layers) else None
        )
      await asyncio.get_running_loop().run_in_executor(self._mlx_thread, freeze_unfreeze)
    if 'lossname' not in self.session or 'LVaG' not in self.session or self.session['lossname'] != loss:
      await self.save_session('lossname', loss)
      await self.save_session('LVaG', nn.value_and_grad(self.model, loss_fns[loss]))
    if 'opt' not in self.session:
      # NOTE(review): optimizer is created once and never rebuilt, so later
      # calls with a different `opt`/`lr` are ignored — confirm intended.
      await self.save_session('opt', opt(lr))
    return True
  async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.SGD, lr=1e-5):
    """Run one training step and return (loss score, first-layer weight grads)."""
    await self.ensure_train(shard, loss, opt, lr)
    def train_step(inp, tar, lng):
      # Loss + gradients in a single pass, then apply the optimizer update.
      lval, grad = self.session['LVaG'](self.model, inp, tar, lng)
      # Assumes the gradient tree nests as grad['model']['layers'] — TODO
      # confirm this holds for every supported model architecture.
      gradlayers = grad['model']['layers']
      self.session['opt'].update(self.model, grad)
      # Third element: lazy arrays to force-evaluate after the update.
      return lval, gradlayers, (self.model.parameters(), self.session['opt'].state, lval)
    x = mx.array(inputs)
    y = mx.array(targets)
    l = mx.array(lengths)
    score, gradients, eval_args = await asyncio.get_running_loop().run_in_executor(
      self._mlx_thread,
      lambda: train_step(x, y, l)
    )
    # Materialize updated params / optimizer state / loss before reading them.
    await self._eval_mlx(*eval_args)
    # Keep only the 'weight' gradient of each non-empty per-layer dict.
    layers = [{k: v["weight"] for k, v in layer.items() if 'weight' in v} for layer in gradients if layer]
    # NOTE(review): hard-codes 'input_layernorm' — coupled to the default
    # trainable_layers in ensure_train; breaks if that layer isn't trained.
    first_layer = np.array(layers[0]['input_layernorm'], copy=False)
    await self._eval_mlx(first_layer)
    return score, first_layer
  async def ensure_shard(self, shard: Shard):
    """Make sure `shard` is downloaded and loaded; no-op if already active.

    On a shard change, also resets per-request caches and session state.
    """
    if self.shard == shard:
      return
    model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
    # Re-check after the await: another task may have finished loading this
    # same shard while we waited on the download — presumably this guards
    # that race (a lock would make it airtight); confirm intended.
    if self.shard != shard:
      model_shard = await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: load_model_shard(model_path, shard, lazy=False))
      if hasattr(model_shard, "tokenizer"):
        self.tokenizer = model_shard.tokenizer
      else:
        self.tokenizer = await resolve_tokenizer(model_path)
      self.shard = shard
      self.model = model_shard
      self.caches = OrderedDict()  # drop prompt caches from the previous shard
      self.session = {}  # reset training/eval state for the new model
  async def cleanup(self):
    """Shut down the worker executors; the engine is unusable afterwards.

    Fixed: the tokenizer executor was previously leaked — only the MLX
    thread was shut down.
    """
    self._mlx_thread.shutdown(wait=True)
    self._tokenizer_thread.shutdown(wait=True)