sharded_utils.py

# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py

import glob
import importlib
import json
import logging
from pathlib import Path
from typing import Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from huggingface_hub.utils import RepositoryNotFoundError
from mlx.utils import tree_flatten
from transformers import PreTrainedTokenizer

from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
from mlx_lm.tuner.utils import apply_lora_layers

from ..shard import Shard


class ModelNotFoundError(Exception):
    def __init__(self, message):
        self.message = message
        super().__init__(self.message)


MODEL_REMAPPING = {
    "mistral": "llama",  # mistral is compatible with llama
    "phi-msft": "phixtral",
}


def _get_classes(config: dict):
    """
    Retrieve the model and model args classes based on the configuration.

    Args:
        config (dict): The model configuration.

    Returns:
        A tuple containing the Model class and the ModelArgs class.
    """
    model_type = config["model_type"]
    model_type = MODEL_REMAPPING.get(model_type, model_type)
    try:
        arch = importlib.import_module(f"inference.mlx.models.{model_type}")
    except ImportError:
        msg = f"Model type {model_type} not supported."
        logging.error(msg)
        raise ValueError(msg)

    return arch.Model, arch.ModelArgs
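
# Illustrative note (a sketch, not part of the adapted upstream code): the
# MODEL_REMAPPING table is applied before the dynamic import, so a minimal
# config like the one below would resolve to the llama implementation, i.e.
# importlib.import_module("inference.mlx.models.llama"):
#
#     Model, ModelArgs = _get_classes({"model_type": "mistral"})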
def load_config(model_path: Path) -> dict:
    try:
        with open(model_path / "config.json", "r") as f:
            config = json.load(f)
    except FileNotFoundError:
        logging.error(f"Config file not found in {model_path}")
        raise
    return config


def load_model_shard(
    model_path: Path,
    shard: Shard,
    lazy: bool = False,
    model_config: dict = {},
) -> nn.Module:
    """
    Load and initialize the model from a given path.

    Args:
        model_path (Path): The path to load the model from.
        shard (Shard): The contiguous range of layers to load for this node.
        lazy (bool): If False eval the model parameters to make sure they are
            loaded in memory before returning, otherwise they will be loaded
            when needed. Default: ``False``
        model_config (dict, optional): Configuration parameters for the model.
            Defaults to an empty dictionary.

    Returns:
        nn.Module: The loaded and initialized model.

    Raises:
        FileNotFoundError: If the weight files (.safetensors) are not found.
        ValueError: If the model class or args class are not found or cannot be instantiated.
    """
    config = load_config(model_path)
    config.update(model_config)

    # TODO hack
    config["model_type"] = f"sharded_{config['model_type']}"
    config["shard"] = {
        "model_id": model_path.name,
        "start_layer": shard.start_layer,
        "end_layer": shard.end_layer,
        "n_layers": shard.n_layers
    }

    weight_files = glob.glob(str(model_path / "model*.safetensors"))

    if not weight_files:
        # Try weight for back-compat
        weight_files = glob.glob(str(model_path / "weight*.safetensors"))

    if not weight_files:
        logging.error(f"No safetensors found in {model_path}")
        raise FileNotFoundError(f"No safetensors found in {model_path}")

    weights = {}
    for wf in weight_files:
        weights.update(mx.load(wf))

    model_class, model_args_class = _get_classes(config=config)

    model_args = model_args_class.from_dict(config)
    model = model_class(model_args)

    if hasattr(model, "sanitize"):
        weights = model.sanitize(weights)

    if (quantization := config.get("quantization", None)) is not None:
        # Handle legacy models which may not have everything quantized
        def class_predicate(p, m):
            if not hasattr(m, "to_quantized"):
                return False
            return f"{p}.scales" in weights

        nn.quantize(
            model,
            **quantization,
            class_predicate=class_predicate,
        )

    # Keep only the layers in [shard.start_layer, shard.end_layer] and re-index
    # them so the first layer of the shard becomes model.layers.0; keys that do
    # not belong to a decoder layer pass through unchanged.
    filtered_weights = {}
    for k, v in weights.items():
        if k.startswith("model.layers."):
            layer_num = int(k.split('.')[2])
            if shard.start_layer <= layer_num <= shard.end_layer:
                new_key = f"model.layers.{layer_num - shard.start_layer}." + '.'.join(k.split('.')[3:])
                filtered_weights[new_key] = v
        else:
            filtered_weights[k] = v
    weights = filtered_weights

    model.load_weights(list(weights.items()), strict=False)

    if not lazy:
        mx.eval(model.parameters())

    model.eval()
    return model
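
# For illustration (a sketch, not part of the adapted upstream code): with a
# hypothetical shard covering layers 12-23 of a 32-layer model, the filtering
# step above renames
#
#     "model.layers.14.self_attn.q_proj.weight" -> "model.layers.2.self_attn.q_proj.weight"
#
# drops keys for layers outside 12-23, and leaves keys such as
# "model.embed_tokens.weight" untouched; load_weights(strict=False) tolerates
# entries the sharded model does not define.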
def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
    """
    Ensures the model is available locally. If the path does not exist locally,
    it is downloaded from the Hugging Face Hub.

    Args:
        path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
        revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.

    Returns:
        Path: The path to the model.
    """
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        try:
            model_path = Path(
                snapshot_download(
                    repo_id=path_or_hf_repo,
                    revision=revision,
                    allow_patterns=[
                        "*.json",
                        "*.safetensors",
                        "*.py",
                        "tokenizer.model",
                        "*.tiktoken",
                        "*.txt",
                    ],
                )
            )
        except RepositoryNotFoundError:
            raise ModelNotFoundError(
                f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
                "Please make sure you specified the local path or Hugging Face"
                " repo id correctly.\nIf you are trying to access a private or"
                " gated Hugging Face repo, make sure you are authenticated:\n"
                "https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login"
            ) from None
    return model_path
def load_shard(
    path_or_hf_repo: str,
    shard: Shard,
    tokenizer_config={},
    model_config={},
    adapter_path: Optional[str] = None,
    lazy: bool = False,
) -> Tuple[nn.Module, TokenizerWrapper]:
    """
    Load the model and tokenizer from a given path or a huggingface repository.

    Args:
        path_or_hf_repo (str): The path or the huggingface repository to load the model from.
        shard (Shard): The contiguous range of layers to load for this node.
        tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
            Defaults to an empty dictionary.
        model_config (dict, optional): Configuration parameters specifically for the model.
            Defaults to an empty dictionary.
        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
            to the model. Default: ``None``.
        lazy (bool): If False eval the model parameters to make sure they are
            loaded in memory before returning, otherwise they will be loaded
            when needed. Default: ``False``

    Returns:
        Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.

    Raises:
        FileNotFoundError: If config file or safetensors are not found.
        ValueError: If model class or args class are not found.
    """
    model_path = get_model_path(path_or_hf_repo)

    model = load_model_shard(model_path, shard, lazy, model_config)
    if adapter_path is not None:
        model = apply_lora_layers(model, adapter_path)
        model.eval()

    tokenizer = load_tokenizer(model_path, tokenizer_config)

    return model, tokenizer
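
# Example usage (a minimal sketch, not part of the adapted module). The Shard
# constructor arguments are assumed from the attributes accessed above
# (model_id, start_layer, end_layer, n_layers), and the repo id is only a
# placeholder. Because this module uses a relative import of Shard, the call
# would live in a caller elsewhere in the package rather than under a
# __main__ guard here:
#
#     shard = Shard(model_id="llama-3-8b", start_layer=0, end_layer=15, n_layers=32)
#     model, tokenizer = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard)
#     print(tokenizer.encode("hello"))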