# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py

import glob
import importlib
import json
import logging
from pathlib import Path
from typing import Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from huggingface_hub.utils._errors import RepositoryNotFoundError
from mlx.utils import tree_flatten
from transformers import PreTrainedTokenizer

from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
from mlx_lm.tuner.utils import apply_lora_layers

from ..shard import Shard

class ModelNotFoundError(Exception):
    def __init__(self, message):
        self.message = message
        super().__init__(self.message)

MODEL_REMAPPING = {
    "mistral": "llama",  # mistral is compatible with llama
    "phi-msft": "phixtral",
}

def _get_classes(config: dict):
    """
    Retrieve the model and model args classes based on the configuration.

    Args:
        config (dict): The model configuration.

    Returns:
        A tuple containing the Model class and the ModelArgs class.
    """
    model_type = config["model_type"]
    model_type = MODEL_REMAPPING.get(model_type, model_type)
    try:
        arch = importlib.import_module(f"inference.mlx.models.{model_type}")
    except ImportError:
        msg = f"Model type {model_type} not supported."
        logging.error(msg)
        raise ValueError(msg)

    return arch.Model, arch.ModelArgs

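# Illustrative example (not executed): for a config whose "model_type" is
# "mistral", _get_classes first remaps it to "llama" via MODEL_REMAPPING, then
# imports inference.mlx.models.llama and returns that module's Model and
# ModelArgs classes. The config values below are hypothetical.
#
#   config = {"model_type": "mistral", "hidden_size": 4096}
#   model_class, model_args_class = _get_classes(config)
#   # model_class is inference.mlx.models.llama.Model
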
def load_config(model_path: Path) -> dict:
    try:
        with open(model_path / "config.json", "r") as f:
            config = json.load(f)
    except FileNotFoundError:
        logging.error(f"Config file not found in {model_path}")
        raise
    return config

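# Illustrative example (not executed): load_config expects a Hugging Face style
# config.json directly inside the model directory. The path and config key
# below are assumptions for illustration only.
#
#   config = load_config(Path("~/models/my-llama").expanduser())
#   n_layers = config["num_hidden_layers"]
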
def load_model_shard(
    model_path: Path,
    shard: Shard,
    lazy: bool = False,
    model_config: dict = {},
) -> nn.Module:
    """
    Load and initialize a shard of the model from a given path.

    Args:
        model_path (Path): The path to load the model from.
        shard (Shard): The contiguous range of layers (start_layer..end_layer,
            out of n_layers total) that this instance is responsible for.
        lazy (bool): If False eval the model parameters to make sure they are
            loaded in memory before returning, otherwise they will be loaded
            when needed. Default: ``False``
        model_config (dict, optional): Configuration parameters for the model.
            Defaults to an empty dictionary.

    Returns:
        nn.Module: The loaded and initialized model.

    Raises:
        FileNotFoundError: If the weight files (.safetensors) are not found.
        ValueError: If the model class or args class are not found or cannot be instantiated.
    """
    config = load_config(model_path)
    config.update(model_config)

    # TODO hack
    config["model_type"] = f"sharded_{config['model_type']}"
    config["shard"] = {
        "model_id": model_path.name,
        "start_layer": shard.start_layer,
        "end_layer": shard.end_layer,
        "n_layers": shard.n_layers,
    }

    weight_files = glob.glob(str(model_path / "model*.safetensors"))

    if not weight_files:
        # Try weight for back-compat
        weight_files = glob.glob(str(model_path / "weight*.safetensors"))

    if not weight_files:
        logging.error(f"No safetensors found in {model_path}")
        raise FileNotFoundError(f"No safetensors found in {model_path}")

    weights = {}
    for wf in weight_files:
        weights.update(mx.load(wf))

    model_class, model_args_class = _get_classes(config=config)

    model_args = model_args_class.from_dict(config)
    model = model_class(model_args)

    if hasattr(model, "sanitize"):
        weights = model.sanitize(weights)

    if (quantization := config.get("quantization", None)) is not None:
        # Handle legacy models which may not have everything quantized
        def class_predicate(p, m):
            if not hasattr(m, "to_quantized"):
                return False
            return f"{p}.scales" in weights

        nn.quantize(
            model,
            **quantization,
            class_predicate=class_predicate,
        )

    # Keep only the transformer layers that belong to this shard and renumber
    # them so the local model's layer indices start at 0; non-layer weights
    # (embeddings, norms, lm head) are passed through unchanged.
    filtered_weights = {}
    for k, v in weights.items():
        if k.startswith("model.layers."):
            layer_num = int(k.split('.')[2])
            if shard.start_layer <= layer_num <= shard.end_layer:
                new_key = f"model.layers.{layer_num - shard.start_layer}." + '.'.join(k.split('.')[3:])
                filtered_weights[new_key] = v
        else:
            filtered_weights[k] = v
    weights = filtered_weights

    model.load_weights(list(weights.items()), strict=False)

    if not lazy:
        mx.eval(model.parameters())

    model.eval()
    return model

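# Illustrative example (not executed): for a shard covering layers 16-31 of a
# 32-layer model, the filtering step above renumbers
#   "model.layers.18.mlp.up_proj.weight"  ->  "model.layers.2.mlp.up_proj.weight"
# (18 - 16 = 2), drops keys for layers outside the range (e.g. "model.layers.3...."),
# and keeps non-layer weights such as "model.embed_tokens.weight" as-is.
# The weight names, path, and Shard constructor arguments are assumptions.
#
#   shard = Shard(model_id="my-llama", start_layer=16, end_layer=31, n_layers=32)
#   model = load_model_shard(Path("~/models/my-llama").expanduser(), shard)
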
def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
    """
    Ensures the model is available locally. If the path does not exist locally,
    it is downloaded from the Hugging Face Hub.

    Args:
        path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
        revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.

    Returns:
        Path: The path to the model.
    """
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        try:
            model_path = Path(
                snapshot_download(
                    repo_id=path_or_hf_repo,
                    revision=revision,
                    allow_patterns=[
                        "*.json",
                        "*.safetensors",
                        "*.py",
                        "tokenizer.model",
                        "*.tiktoken",
                        "*.txt",
                    ],
                )
            )
        except RepositoryNotFoundError:
            raise ModelNotFoundError(
                f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
                "Please make sure you specified the local path or Hugging Face"
                " repo id correctly.\nIf you are trying to access a private or"
                " gated Hugging Face repo, make sure you are authenticated:\n"
                "https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login"
            ) from None
    return model_path

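# Illustrative example (not executed): an existing local directory is returned
# as-is; anything else is treated as a Hugging Face repo id and fetched into the
# local HF cache, restricted to config, weight, and tokenizer files by
# allow_patterns. The paths and repo id below are hypothetical.
#
#   local = get_model_path("./my-converted-model")
#   cached = get_model_path("someorg/some-mlx-model", revision="main")
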
def load_shard(
    path_or_hf_repo: str,
    shard: Shard,
    tokenizer_config={},
    model_config={},
    adapter_path: Optional[str] = None,
    lazy: bool = False,
) -> Tuple[nn.Module, TokenizerWrapper]:
    """
    Load the model shard and tokenizer from a given path or a Hugging Face repository.

    Args:
        path_or_hf_repo (str): The path or the Hugging Face repository to load the model from.
        shard (Shard): The range of layers this instance should load.
        tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
            Defaults to an empty dictionary.
        model_config (dict, optional): Configuration parameters specifically for the model.
            Defaults to an empty dictionary.
        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
            to the model. Default: ``None``.
        lazy (bool): If False eval the model parameters to make sure they are
            loaded in memory before returning, otherwise they will be loaded
            when needed. Default: ``False``

    Returns:
        Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.

    Raises:
        FileNotFoundError: If config file or safetensors are not found.
        ValueError: If model class or args class are not found.
    """
    model_path = get_model_path(path_or_hf_repo)

    model = load_model_shard(model_path, shard, lazy, model_config)
    if adapter_path is not None:
        model = apply_lora_layers(model, adapter_path)
        model.eval()

    tokenizer = load_tokenizer(model_path, tokenizer_config)

    return model, tokenizer

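# Illustrative end-to-end example (not executed): load the first half of a
# hypothetical 32-layer model from the Hugging Face Hub and tokenize a prompt.
# The repo id and Shard constructor arguments are assumptions, not values
# defined in this module.
#
#   shard = Shard(model_id="someorg/some-mlx-model", start_layer=0, end_layer=15, n_layers=32)
#   model, tokenizer = load_shard("someorg/some-mlx-model", shard)
#   tokens = tokenizer.encode("Hello, world!")
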