sharded_inference_engine.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. import numpy as np
  2. import mlx.core as mx
  3. import mlx.nn as nn
  4. from mlx_lm.sample_utils import top_p_sampling
  5. import mlx.optimizers as optim
  6. from ..inference_engine import InferenceEngine
  7. from .sharded_utils import load_shard, get_image_from_str
  8. from .losses import loss_fns
  9. from ..shard import Shard
  10. from typing import Dict, Optional, Tuple
  11. from exo.download.shard_download import ShardDownloader
  12. import asyncio
  13. from concurrent.futures import ThreadPoolExecutor
  14. from functools import partial
  15. from collections import OrderedDict
  16. from mlx_lm.models.cache import make_prompt_cache
  def sample_logits(
      logits: mx.array,
      temp: float = 0.0,
      top_p: float = 1.0,
      logit_bias: Optional[Dict[int, float]] = None,
  ) -> mx.array:
      """Select the next token id(s) from ``logits``.

      Args:
          logits: Logits to choose from. Indexed as ``logits[:, token_ids]`` when a
              bias is applied and reduced over ``axis=-1`` for greedy decoding, so
              presumably (batch, vocab) -- TODO confirm with callers.
          temp: Sampling temperature. ``0`` means greedy selection via ``argmax``.
          top_p: Nucleus threshold. Used only when ``temp != 0`` and
              ``0 < top_p < 1``; otherwise plain categorical sampling at
              temperature ``temp`` is used.
          logit_bias: Optional mapping of token id -> additive bias. When given,
              the bias is added to ``logits`` in place (the caller's array is
              modified) before selection.

      Returns:
          The selected token id(s) as an ``mx.array``.
      """
      if logit_bias:
          bias_ids = mx.array(list(logit_bias.keys()))
          bias_vals = mx.array(list(logit_bias.values()))
          # Augmented subscript assignment writes into the caller's array.
          logits[:, bias_ids] += bias_vals

      if temp == 0:
          return mx.argmax(logits, axis=-1)
      if 0 < top_p < 1.0:
          return top_p_sampling(logits, top_p, temp)
      return mx.random.categorical(logits * (1 / temp))
  class MLXDynamicShardInferenceEngine(InferenceEngine):
      """MLX inference/training engine that holds one model shard at a time.

      ``ensure_shard`` swaps the loaded shard, model and tokenizer whenever a
      call targets a different shard. Blocking MLX and tokenizer work is handed
      to a single-worker ``ThreadPoolExecutor`` so it runs off the event loop,
      one job at a time.
      """

      def __init__(self, shard_downloader: ShardDownloader):
          self.shard = None
          self.shard_downloader = shard_downloader
          # A single worker thread, so submitted MLX/tokenizer jobs never overlap.
          self.executor = ThreadPoolExecutor(max_workers=1)
          # request_id -> prompt cache, ordered least- to most-recently used.
          self.caches = OrderedDict()

      async def poll_state(self, request_id: str, max_caches=2):
          """Return the model state for ``request_id``, creating a prompt cache if needed.

          Caches are kept in LRU order. Before a new cache is inserted, the least
          recently used ones are evicted so that at most ``max_caches`` remain
          afterwards (never fewer than one: the requested cache is always kept).

          Args:
              request_id: Key identifying the request.
              max_caches: Upper bound on the number of retained prompt caches.

          Returns:
              ``{"cache": <prompt cache>}``, passed as ``**kwargs`` to the model.
          """
          if request_id in self.caches:
              self.caches.move_to_end(request_id)
          else:
              new_cache = await asyncio.get_running_loop().run_in_executor(
                  self.executor, make_prompt_cache, self.model
              )
              # Evict before inserting so the total after the insert is <= max_caches.
              while self.caches and len(self.caches) >= max_caches:
                  self.caches.popitem(last=False)
              self.caches[request_id] = new_cache
          return {"cache": self.caches[request_id]}

      async def sample(self, x, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
          """Pick next token id(s) from the last position of model output ``x``.

          ``x`` is indexed as ``x[:, -1, :]``, so it is presumably
          (batch, seq, vocab) logits -- TODO confirm with callers.
          """
          logits = mx.array(x)[:, -1, :]
          return np.array(sample_logits(logits, temp=temp, top_p=top_p), dtype=int)

      async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
          """Tokenize ``prompt`` with ``shard``'s tokenizer; return ids as a NumPy array."""
          await self.ensure_shard(shard)
          tokens = await asyncio.get_running_loop().run_in_executor(
              self.executor, self.tokenizer.encode, prompt
          )
          return np.array(tokens)

      async def decode(self, shard: Shard, tokens) -> str:
          """Detokenize ``tokens`` with ``shard``'s tokenizer."""
          await self.ensure_shard(shard)
          return await asyncio.get_running_loop().run_in_executor(
              self.executor, self.tokenizer.decode, tokens
          )

      async def save_checkpoint(self, shard: Shard, path: str):
          """Write the loaded model's weights to ``path`` via ``model.save_weights``."""
          await self.ensure_shard(shard)
          await asyncio.get_running_loop().run_in_executor(
              self.executor, self.model.save_weights, path
          )

      async def load_checkpoint(self, shard: Shard, path: str):
          """Load weights from ``path`` into the model via ``model.load_weights``."""
          await self.ensure_shard(shard)
          await asyncio.get_running_loop().run_in_executor(
              self.executor, self.model.load_weights, path
          )

      async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
          """Run the shard's model on ``input_data`` using this request's prompt cache.

          Returns:
              The model output converted to a NumPy array.
          """
          await self.ensure_shard(shard)
          loop = asyncio.get_running_loop()
          state = await self.poll_state(request_id)
          x = mx.array(input_data)
          output_data: np.ndarray = np.array(
              await loop.run_in_executor(self.executor, lambda: self.model(x, **state))
          )
          return output_data

      async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
          """Compute the ``loss`` named in ``loss_fns`` on one batch, without updating weights.

          The loss function is stored in the session under ``'loss'`` and called as
          ``fn(model, inputs, targets, lengths)`` on the executor thread.

          Returns:
              Whatever the selected loss function returns.
          """
          await self.ensure_shard(shard)
          await self.save_session('loss', loss_fns[loss])
          loop = asyncio.get_running_loop()
          x = mx.array(inputs)
          y = mx.array(targets)
          lens = mx.array(lengths)
          return await loop.run_in_executor(self.executor, self.session['loss'], self.model, x, y, lens)

      async def ensure_train(self, shard: Shard, loss: str, opt=optim.SGD, lr=1e-5,
                             trainable_layers=('input_layernorm', 'gate_proj')):
          """Prepare the loaded shard for training, reusing session state when possible.

          - Freezes the model, then unfreezes every module whose dotted key ends
            with one of ``trainable_layers`` (redone only when the list changes).
          - Builds ``nn.value_and_grad(model, loss_fns[loss])`` when ``loss`` changes.
          - Creates the optimizer ``opt(lr)`` once; later calls do not rebuild it,
            even if ``opt``/``lr`` differ.

          Returns:
              Always ``True``.
          """
          await self.ensure_shard(shard)
          # Normalize to a list so the stored/compared session value has one type
          # (and the default is immutable).
          trainable_layers = list(trainable_layers)
          if self.session.get('train_layers') != trainable_layers:
              await self.save_session('train_layers', trainable_layers)
              self.model.freeze()

              def _unfreeze_selected(key, module):
                  # Test the suffix itself; wrapping it in a lambda would make
                  # every element truthy and unfreeze the whole model.
                  if any(key.endswith(name) for name in trainable_layers):
                      module.unfreeze()

              self.model.apply_to_modules(_unfreeze_selected)
          if self.session.get('lossname') != loss or 'LVaG' not in self.session:
              await self.save_session('lossname', loss)
              await self.save_session('LVaG', nn.value_and_grad(self.model, loss_fns[loss]))
          if 'opt' not in self.session:
              await self.save_session('opt', opt(lr))
          return True

      async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.SGD, lr=1e-5):
          """Run one optimizer step on a batch and report the loss and a gradient.

          Returns:
              ``(loss_value, grad)`` where ``grad`` is the NumPy copy of the
              ``input_layernorm`` weight gradient of the first layer that has any
              trainable gradients.

          Raises:
              IndexError: if no layer produced gradients.
              KeyError: if that layer has no ``input_layernorm`` weight gradient.
          """
          loop = asyncio.get_running_loop()
          await self.ensure_train(shard, loss, opt, lr)

          def train_step(inp, tar, lng):
              lval, grad = self.session['LVaG'](self.model, inp, tar, lng)
              grad_layers = grad['model']['layers']
              self.session['opt'].update(self.model, grad)
              # Materialize the updated parameters, optimizer state and loss.
              mx.eval(self.model.parameters(), self.session['opt'].state, lval)
              return lval, grad_layers

          x = mx.array(inputs)
          y = mx.array(targets)
          lens = mx.array(lengths)
          score, gradients = await loop.run_in_executor(self.executor, train_step, x, y, lens)
          # Per non-empty layer, keep the 'weight' gradient of each direct child that has one.
          layers = [
              {name: g["weight"] for name, g in layer.items() if 'weight' in g}
              for layer in gradients
              if layer
          ]
          return score, np.array(layers[0]['input_layernorm'])

      async def ensure_shard(self, shard: Shard):
          """Make ``shard`` the loaded shard, downloading and loading it if necessary.

          Switching to a different shard replaces the model and tokenizer and
          resets the per-request caches and the session dict.
          """
          if self.shard == shard:
              return
          model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
          # Re-check after the await: another coroutine may have loaded it meanwhile.
          if self.shard != shard:
              def load_shard_wrapper():
                  # load_shard is a coroutine; run it on a private event loop in
                  # the worker thread so the load happens off this event loop.
                  return asyncio.run(load_shard(model_path, shard))

              model_shard, self.tokenizer = await asyncio.get_running_loop().run_in_executor(
                  self.executor, load_shard_wrapper
              )
              self.shard = shard
              self.model = model_shard
              self.caches = OrderedDict()
              self.session = {}