stateful_model.py 949 B

1234567891011121314151617181920212223242526272829303132
  1. from typing import Dict, Tuple, Optional
  2. from collections import OrderedDict
  3. import mlx.core as mx
  4. import mlx.nn as nn
  5. from mlx_lm.models.cache import make_prompt_cache
  6. import numpy as np
  7. from ..shard import Shard
  8. class StatefulModel(nn.Module):
  9. def __init__(self, model, max_kv_size: int = 1024, max_caches: int = 2):
  10. super().__init__()
  11. self.model = model
  12. self.max_kv_size = max_kv_size
  13. self.max_caches = max_caches
  14. self.caches = OrderedDict()
  15. def __call__(self, x, request_id: Optional[str] = None, use_cache: bool = True):
  16. #print(f"StatefulModel in <- {x}")
  17. if use_cache and request_id is not None:
  18. if request_id not in self.caches:
  19. self.init_cache(request_id)
  20. else:
  21. self.caches.move_to_end(request_id)
  22. cache = mx.array(self.caches[request_id])
  23. y = self.model(x, cache=cache)
  24. else:
  25. y = self.model(x)
  26. #print(f"StatefulModel out -> {y}")
  27. return y