# sharded_inference_engine.py
  1. import numpy as np
  2. import mlx.core as mx
  3. import mlx.nn as nn
  4. from ..inference_engine import InferenceEngine
  5. from .stateful_model import StatefulModel
  6. from .sharded_utils import load_shard, get_image_from_str
  7. from ..shard import Shard
  8. from typing import Dict, Optional, Tuple
  9. from exo.download.shard_download import ShardDownloader
  10. import asyncio
  11. from concurrent.futures import ThreadPoolExecutor
  12. from functools import partial
  13. def sample_logits(
  14. logits: mx.array,
  15. temp: float = 0.0,
  16. top_p: float = 1.0,
  17. logit_bias: Optional[Dict[int, float]] = None
  18. ) -> Tuple[mx.array, float]:
  19. if logit_bias:
  20. indices = mx.array(list(logit_bias.keys()))
  21. values = mx.array(list(logit_bias.values()))
  22. logits[:, indices] += values
  23. if temp == 0:
  24. token = mx.argmax(logits, axis=-1)
  25. else:
  26. if top_p > 0 and top_p < 1.0:
  27. token = top_p_sampling(logits, top_p, temp)
  28. else:
  29. token = mx.random.categorical(logits*(1/temp))
  30. return token
  31. class MLXDynamicShardInferenceEngine(InferenceEngine):
  32. def __init__(self, shard_downloader: ShardDownloader):
  33. self.shard = None
  34. self.shard_downloader = shard_downloader
  35. self.executor = ThreadPoolExecutor(max_workers=1)
  36. async def sample(self, x, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
  37. y = mx.array(x)
  38. logits = y[:, -1, :]
  39. out = np.array(sample_logits(logits, temp=temp, top_p=top_p), dtype=int)
  40. return out
  41. async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
  42. await self.ensure_shard(shard)
  43. tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
  44. return np.array(tokens)
  45. async def decode(self, shard: Shard, tokens) -> str:
  46. await self.ensure_shard(shard)
  47. tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
  48. return tokens
  49. async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
  50. await self.ensure_shard(shard)
  51. output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.model, mx.array(input_data), request_id))
  52. return output_data
  53. async def ensure_shard(self, shard: Shard):
  54. if self.shard == shard:
  55. return
  56. model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
  57. if self.shard != shard:
  58. loop = asyncio.get_running_loop()
  59. def load_shard_wrapper():
  60. return asyncio.run(load_shard(model_path, shard))
  61. model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
  62. self.shard = shard
  63. self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard)