sharded_model.py

from typing import Dict, Optional

import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.base import KVCache
from mlx_lm.sample_utils import top_p_sampling

from ..shard import Shard


class StatefulShardedModel:
  """Wraps one shard of a model and keeps per-layer KV caches between steps."""

  def __init__(self, shard: Shard, model: nn.Module):
    self.shard = shard
    self.model = model
    self.reset()

  def step(
    self,
    x,
    temp: float = 0.0,
    top_p: float = 1.0,
    logit_bias: Optional[Dict[int, float]] = None,
  ) -> mx.array:
    def sample(logits: mx.array) -> mx.array:
      # Apply an additive bias to selected token logits before sampling.
      if logit_bias:
        indices = mx.array(list(logit_bias.keys()))
        values = mx.array(list(logit_bias.values()))
        logits[:, indices] += values

      if temp == 0:
        # Greedy decoding: pick the most likely token.
        token = mx.argmax(logits, axis=-1)
      else:
        if 0 < top_p < 1.0:
          token = top_p_sampling(logits, top_p, temp)
        else:
          # Plain temperature sampling.
          token = mx.random.categorical(logits * (1 / temp))

      return token

    y = x
    # The first shard receives raw token ids, so add a batch dimension;
    # later shards receive already-batched hidden states from the
    # previous shard and pass them through unchanged.
    output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.cache)

    if self.shard.is_last_layer():
      # Only the last shard produces logits; sample the next token from
      # the final position.
      logits = output[:, -1, :]
      y = sample(logits)
      return y
    else:
      # Intermediate shards hand their hidden states to the next shard.
      return output
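
  # Sketch of how two shards would chain (illustrative only; the node and
  # network layer that actually moves tensors between shards lives
  # elsewhere in the project, and the `shard_a`/`shard_b`/`prompt_tokens`
  # names are assumptions):
  #
  #   hidden = StatefulShardedModel(shard_a, model_a).step(prompt_tokens)
  #   token = StatefulShardedModel(shard_b, model_b).step(hidden)
  #
  # Each shard runs only its own layers: the first adds the batch
  # dimension and the last samples a token, so the hidden states in
  # between can be handed from one process (or machine) to the next.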

  def __call__(
    self,
    x,
    temp: float = 0.0,
    top_p: float = 1.0,
    logit_bias: Optional[Dict[int, float]] = None,
  ) -> mx.array:
    return self.step(x, temp, top_p, logit_bias)

  def reset(self):
    # n_kv_heads may be a single int (shared by every layer) or a
    # per-layer list; normalize to one entry per layer before building
    # the caches.
    kv_heads = (
      [self.model.n_kv_heads] * len(self.model.layers)
      if isinstance(self.model.n_kv_heads, int)
      else self.model.n_kv_heads
    )
    self.cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
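

# Usage sketch (illustrative, not part of the original module). Assumes a
# single shard that spans the whole model, that `Shard(model_id,
# start_layer, end_layer, n_layers)` matches this project's Shard
# dataclass, and that `model`/`tokenizer` were loaded elsewhere (e.g. with
# mlx_lm):
#
#   shard = Shard("llama", 0, len(model.layers) - 1, len(model.layers))
#   m = StatefulShardedModel(shard, model)
#
#   # Prefill with the whole prompt, then decode token by token; the KV
#   # cache persists across calls, so each step feeds only the new token.
#   token = m.step(mx.array(tokenizer.encode("hello")), temp=0.7)
#   out = [token.item()]
#   for _ in range(32):
#     token = m.step(token, temp=0.7)
#     out.append(token.item())
#   print(tokenizer.decode(out))
#
#   m.reset()  # clear the caches before starting an unrelated prompt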