sharded_inference_engine.py
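"""MLX-backed sharded inference engine.

Runs a single model shard with MLX, keeping per-request prompt caches
(LRU-evicted) and exposing encode/decode, sampling, inference, and simple
evaluation/training entry points. All MLX work is funneled through a
single-worker ThreadPoolExecutor so the asyncio event loop never blocks.
"""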

import numpy as np
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.sample_utils import top_p_sampling
import mlx.optimizers as optim
from ..inference_engine import InferenceEngine
from .sharded_utils import load_shard, get_image_from_str
from .losses import loss_fns
from ..shard import Shard
from typing import Dict, Optional, Tuple
from exo.download.shard_download import ShardDownloader
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from collections import OrderedDict
from mlx_lm.models.cache import make_prompt_cache

def sample_logits(
  logits: mx.array,
  temp: float = 0.0,
  top_p: float = 1.0,
  logit_bias: Optional[Dict[int, float]] = None
) -> mx.array:
  # Apply an additive bias to selected token logits before sampling.
  if logit_bias:
    indices = mx.array(list(logit_bias.keys()))
    values = mx.array(list(logit_bias.values()))
    logits[:, indices] += values

  if temp == 0:
    # Greedy decoding: pick the highest-probability token.
    token = mx.argmax(logits, axis=-1)
  else:
    if top_p > 0 and top_p < 1.0:
      # Nucleus (top-p) sampling at the given temperature.
      token = top_p_sampling(logits, top_p, temp)
    else:
      # Plain temperature sampling over the full distribution.
      token = mx.random.categorical(logits*(1/temp))

  return token

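# Illustrative usage (made-up logits; vocabulary of 4 tokens):
#   logits = mx.array([[0.1, 2.0, 0.5, -1.0]])   # shape [batch, vocab]
#   sample_logits(logits)                        # temp=0 -> greedy argmax
#   sample_logits(logits, temp=0.7, top_p=0.9)   # nucleus sampling
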
class MLXDynamicShardInferenceEngine(InferenceEngine):
  def __init__(self, shard_downloader: ShardDownloader):
    self.shard = None
    self.shard_downloader = shard_downloader
    # Single-worker executor: serializes all MLX calls off the event loop.
    self.executor = ThreadPoolExecutor(max_workers=1)
    self.caches = OrderedDict()

  async def poll_state(self, request_id: str, max_caches=2):
    if request_id in self.caches:
      # Mark this request's cache as most recently used.
      self.caches.move_to_end(request_id)
    else:
      newcache = await asyncio.get_running_loop().run_in_executor(self.executor, make_prompt_cache, self.model)
      if len(self.caches) >= max_caches:
        # Evict the least recently used cache to stay within max_caches.
        self.caches.popitem(last=False)
      self.caches[request_id] = newcache
    return {"cache": self.caches[request_id]}

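  # Illustrative (hypothetical request ids): each request id maps to its own
  # prompt cache; with max_caches=2, a third id evicts the least recently used.
  #   await engine.poll_state("req-A")  # creates cache for req-A
  #   await engine.poll_state("req-B")  # creates cache for req-B
  #   await engine.poll_state("req-C")  # evicts req-A's cache
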
  async def sample(self, x, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
    y = mx.array(x)
    # Sample from the logits of the final sequence position only.
    logits = y[:, -1, :]
    out = np.array(sample_logits(logits, temp=temp, top_p=top_p), dtype=int)
    return out

  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
    await self.ensure_shard(shard)
    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
    return np.array(tokens)

  async def decode(self, shard: Shard, tokens) -> str:
    await self.ensure_shard(shard)
    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
    return tokens

  async def save_checkpoint(self, shard: Shard, path: str):
    await self.ensure_shard(shard)
    await asyncio.get_running_loop().run_in_executor(self.executor, self.model.save_weights, path)

  async def load_checkpoint(self, shard: Shard, path: str):
    await self.ensure_shard(shard)
    await asyncio.get_running_loop().run_in_executor(self.executor, self.model.load_weights, path)

  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
    await self.ensure_shard(shard)
    loop = asyncio.get_running_loop()
    # Language models get a per-request prompt cache; the diffusion pipeline manages its own state.
    state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
    x = mx.array(input_data)
    if self.model.model_type != 'StableDiffusionPipeline':
      output_data = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **(inference_state or {})))
    else:
      # The diffusion pipeline also returns updated state alongside its output.
      output_data, inference_state = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **(inference_state or {})))
    output_data = np.array(output_data)
    return output_data, inference_state

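  # Rough data-flow note (a sketch, not guaranteed by this file alone): the
  # first shard typically consumes token ids of shape [batch, seq], downstream
  # shards consume the previous shard's hidden activations, and only the last
  # shard emits logits suitable for sample().
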
  async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
    await self.ensure_shard(shard)
    await self.save_session('loss', loss_fns[loss])
    loop = asyncio.get_running_loop()
    x = mx.array(inputs)
    y = mx.array(targets)
    l = mx.array(lengths)
    # Run the selected loss function over the batch on the executor thread.
    score = await loop.run_in_executor(self.executor, self.session['loss'], self.model, x, y, l)
    return score

  async def ensure_train(self, shard: Shard, loss: str, opt=optim.SGD, lr=1e-5, trainable_layers=['input_layernorm', 'gate_proj']):
    await self.ensure_shard(shard)

    if 'train_layers' not in self.session or self.session['train_layers'] != trainable_layers:
      await self.save_session('train_layers', trainable_layers)
      # Freeze everything, then unfreeze only modules whose names end with
      # one of the requested trainable layer suffixes.
      self.model.freeze()
      self.model.apply_to_modules(lambda k, v: v.unfreeze() if any(k.endswith(i) for i in trainable_layers) else None)

    if 'lossname' not in self.session or 'LVaG' not in self.session or self.session['lossname'] != loss:
      await self.save_session('lossname', loss)
      # LVaG: the loss-value-and-gradient function for the chosen loss.
      await self.save_session('LVaG', nn.value_and_grad(self.model, loss_fns[loss]))

    if 'opt' not in self.session:
      await self.save_session('opt', opt(lr))
    return True

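  # Illustrative call (argument values are just the defaults above):
  #   await engine.ensure_train(shard, "length_masked_ce", opt=optim.SGD,
  #                             lr=1e-5, trainable_layers=['input_layernorm', 'gate_proj'])
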
  async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.SGD, lr=1e-5):
    loop = asyncio.get_running_loop()
    await self.ensure_train(shard, loss, opt, lr)

    def train_step(inp, tar, lng):
      lval, grad = self.session['LVaG'](self.model, inp, tar, lng)
      gradlayers = grad['model']['layers']
      self.session['opt'].update(self.model, grad)
      # Force evaluation of the lazy MLX graph: parameters, optimizer state, loss.
      mx.eval(self.model.parameters(), self.session['opt'].state, lval)
      return lval, gradlayers

    x = mx.array(inputs)
    y = mx.array(targets)
    l = mx.array(lengths)
    score, gradients = await loop.run_in_executor(self.executor, train_step, x, y, l)

    # Collect per-layer weight gradients and return the first layer's
    # input_layernorm gradient alongside the loss.
    layers = [{k: v["weight"] for k, v in l.items() if 'weight' in v} for l in gradients if l]
    return score, np.array(layers[0]['input_layernorm'])

  async def ensure_shard(self, shard: Shard):
    if self.shard == shard:
      return

    model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
    if self.shard != shard:
      # load_shard is a coroutine, so run it to completion inside the executor thread.
      def load_shard_wrapper():
        return asyncio.run(load_shard(model_path, shard))

      model_shard, self.tokenizer = await asyncio.get_running_loop().run_in_executor(self.executor, load_shard_wrapper)
      self.shard = shard
      self.model = model_shard
      # Reset per-request caches and the training/eval session for the new shard.
      self.caches = OrderedDict()
      self.session = {}

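# Minimal single-node usage sketch (assumes a concrete ShardDownloader and a
# valid Shard; `downloader` and `shard` are placeholders, not defined here):
#   engine = MLXDynamicShardInferenceEngine(downloader)
#   tokens = await engine.encode(shard, "hello")
#   logits, _ = await engine.infer_tensor("req-1", shard, tokens[None, :])
#   next_tok = await engine.sample(logits, temp=0.7)
#   text = await engine.decode(shard, next_tok.tolist())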