from typing import Dict, Optional

import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.base import KVCache
from mlx_lm.sample_utils import top_p_sampling

from ..shard import Shard
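

# A stateful wrapper around one shard (a contiguous slice of the model's
# layers). Per-layer KV caches are created in reset() and passed to every
# forward call, so attention state persists across successive step() calls.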
class StatefulShardedModel:
    def __init__(self, shard: Shard, model: nn.Module):
        self.shard = shard
        self.model = model
        self.reset()
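
    # step() runs this node's slice of the layer stack once: the first shard
    # takes raw token ids and adds a batch dimension, the last shard reduces
    # the final hidden state to logits and samples a token, and any other
    # shard returns hidden states for the next node to consume.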
    def step(
        self,
        x,
        temp: float = 0.0,
        top_p: float = 1.0,
        logit_bias: Optional[Dict[int, float]] = None,
    ) -> mx.array:
        def sample(logits: mx.array) -> mx.array:
            # Nudge selected token logits up or down before sampling.
            if logit_bias:
                indices = mx.array(list(logit_bias.keys()))
                values = mx.array(list(logit_bias.values()))
                logits[:, indices] += values

            if temp == 0:
                # Greedy decoding.
                token = mx.argmax(logits, axis=-1)
            elif 0 < top_p < 1.0:
                # Nucleus (top-p) sampling.
                token = top_p_sampling(logits, top_p, temp)
            else:
                # Plain temperature sampling.
                token = mx.random.categorical(logits * (1 / temp))
            return token

        y = x
        output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.cache)

        if self.shard.is_last_layer():
            logits = output[:, -1, :]
            return sample(logits)
        else:
            return output

    def __call__(
        self,
        x,
        temp: float = 0.0,
        top_p: float = 1.0,
        logit_bias: Optional[Dict[int, float]] = None,
    ) -> mx.array:
        return self.step(x, temp, top_p, logit_bias)
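
    # Rebuilding the caches discards all attention state, starting a fresh
    # generation.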
    def reset(self):
        # n_kv_heads may be one int shared by every layer or a per-layer list.
        kv_heads = (
            [self.model.n_kv_heads] * len(self.model.layers)
            if isinstance(self.model.n_kv_heads, int)
            else self.model.n_kv_heads
        )
        self.cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
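

# A minimal usage sketch (hypothetical, not part of this module). It assumes a
# Shard spanning every layer on a single node and an mlx_lm model that exposes
# layers / n_kv_heads / head_dim; the model id and layer counts below are
# illustrative only:
#
#   from mlx_lm import load
#
#   model, tokenizer = load("mlx-community/Meta-Llama-3-8B-Instruct-4bit")
#   shard = Shard(model_id="llama-3-8b", start_layer=0, end_layer=31, n_layers=32)
#   sharded_model = StatefulShardedModel(shard, model)
#
#   tokens = mx.array(tokenizer.encode("Hello"))
#   token = sharded_model.step(tokens, temp=0.7)   # prefill: fills the KV cache
#   token = sharded_model.step(token, temp=0.7)    # decode: one token per call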