# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py

import glob
import importlib
import json
import logging
import asyncio
import aiohttp
import base64
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import Optional, Tuple, Union, List, Callable

from PIL import Image

import mlx.core as mx
import mlx.nn as nn
from transformers import AutoProcessor

from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
from mlx_lm.tuner.utils import apply_lora_layers

from exo import DEBUG
from ..shard import Shard

class ModelNotFoundError(Exception):
  def __init__(self, message):
    self.message = message
    super().__init__(self.message)


MODEL_REMAPPING = {
  "mistral": "llama",  # mistral is compatible with llama
  "phi-msft": "phixtral",
}

def _get_classes(config: dict):
  """
  Retrieve the model and model args classes based on the configuration.

  Args:
    config (dict): The model configuration.

  Returns:
    A tuple containing the Model class and the ModelArgs class.
  """
  model_type = config["model_type"]
  model_type = MODEL_REMAPPING.get(model_type, model_type)
  try:
    arch = importlib.import_module(f"exo.inference.mlx.models.{model_type}")
  except ImportError:
    msg = f"Model type {model_type} not supported."
    logging.error(msg)
    raise ValueError(msg)

  return arch.Model, arch.ModelArgs
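

# Illustrative call (a sketch, not executed here): the "mistral" entry in
# MODEL_REMAPPING above resolves to the llama module, so
#
#   Model, ModelArgs = _get_classes({"model_type": "mistral"})
#
# imports exo.inference.mlx.models.llama and returns its Model and ModelArgs
# classes; an unknown model_type raises ValueError instead.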


def load_config(model_path: Path) -> dict:
  try:
    with open(model_path/"config.json", "r") as f:
      config = json.load(f)
  except FileNotFoundError:
    logging.error(f"Config file not found in {model_path}")
    raise
  return config

def load_model_shard(
  model_path: Path,
  shard: Shard,
  lazy: bool = False,
  model_config: dict = {},
) -> nn.Module:
  """
  Load and initialize the model from a given path.

  Args:
    model_path (Path): The path to load the model from.
    shard (Shard): The range of layers this node is responsible for.
    lazy (bool): If False, eval the model parameters to make sure they are
      loaded in memory before returning; otherwise they will be loaded
      when needed. Default: ``False``
    model_config (dict, optional): Configuration parameters for the model.
      Defaults to an empty dictionary.

  Returns:
    nn.Module: The loaded and initialized model.

  Raises:
    FileNotFoundError: If the weight files (.safetensors) are not found.
    ValueError: If the model class or args class are not found or cannot be instantiated.
  """
  config = load_config(model_path)
  config.update(model_config)

  # TODO hack: pass the shard boundaries to the model through its config
  config["shard"] = {
    "model_id": model_path.name,
    "start_layer": shard.start_layer,
    "end_layer": shard.end_layer,
    "n_layers": shard.n_layers,
  }

  weight_files = glob.glob(str(model_path/"model*.safetensors"))

  if not weight_files:
    # Try "weight*" for back-compat with older checkpoint naming
    weight_files = glob.glob(str(model_path/"weight*.safetensors"))

  if not weight_files:
    logging.error(f"No safetensors found in {model_path}")
    raise FileNotFoundError(f"No safetensors found in {model_path}")

  weights = {}
  for wf in sorted(weight_files):
    if DEBUG >= 8:
      # Print which layer numbers each weight file contains
      layer_nums = set()
      for k in mx.load(wf):
        if k.startswith("model.layers."):
          layer_num = int(k.split(".")[2])
          layer_nums.add(layer_num)
        if k.startswith("language_model.model.layers."):
          layer_num = int(k.split(".")[3])
          layer_nums.add(layer_num)
      print(f"\"{wf.split('/')[-1]}\": {sorted(layer_nums)},")

    weights.update(mx.load(wf))

  model_class, model_args_class = _get_classes(config=config)

  model_args = model_args_class.from_dict(config)
  model = model_class(model_args)

  if hasattr(model, "sanitize"):
    weights = model.sanitize(weights)

  if (quantization := config.get("quantization", None)) is not None:
    # Handle legacy models which may not have everything quantized
    def class_predicate(p, m):
      if not hasattr(m, "to_quantized"):
        return False
      return f"{p}.scales" in weights

    nn.quantize(
      model,
      **quantization,
      class_predicate=class_predicate,
    )

  model.load_weights(list(weights.items()), strict=True)

  if not lazy:
    mx.eval(model.parameters())

  model.eval()
  return model
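

# Note on class_predicate above: nn.quantize calls it with each parameter path
# and module, e.g. p="model.layers.0.mlp.gate_proj" (path shown is
# illustrative). A quantized checkpoint stores a companion
# "model.layers.0.mlp.gate_proj.scales" entry, so checking for f"{p}.scales"
# in the loaded weights skips any layers the checkpoint left unquantized.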


async def load_shard(
  model_path: Path,
  shard: Shard,
  tokenizer_config={},
  model_config={},
  adapter_path: Optional[str] = None,
  lazy: bool = False,
) -> Tuple[nn.Module, TokenizerWrapper]:
  model = load_model_shard(model_path, shard, lazy, model_config)
  if adapter_path is not None:
    model = apply_lora_layers(model, adapter_path)
    model.eval()

  # TODO: figure out a generic solution
  if model.model_type == "llava":
    # llava is multimodal, so it needs a processor rather than a plain
    # tokenizer; expose the tokenizer attributes callers rely on.
    processor = AutoProcessor.from_pretrained(model_path)
    processor.eos_token_id = processor.tokenizer.eos_token_id
    processor.encode = processor.tokenizer.encode
    return model, processor
  else:
    tokenizer = load_tokenizer(model_path, tokenizer_config)
    return model, tokenizer
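

# Usage sketch (hypothetical model path and layer split; Shard fields match
# the config dict written in load_model_shard above):
#
#   shard = Shard(model_id="my-model", start_layer=0, end_layer=15, n_layers=32)
#   model, tokenizer = await load_shard(Path("/models/my-model"), shard)
#   tokens = tokenizer.encode("hello")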


async def get_image_from_str(_image_str: str):
  image_str = _image_str.strip()

  if image_str.startswith("http"):
    async with aiohttp.ClientSession() as session:
      async with session.get(image_str, timeout=aiohttp.ClientTimeout(total=10)) as response:
        response.raise_for_status()  # fail fast on HTTP errors instead of trying to decode an error page
        content = await response.read()
        return Image.open(BytesIO(content)).convert("RGB")
  elif image_str.startswith("data:image/"):
    # Extract the image format and base64 data
    format_prefix, base64_data = image_str.split(";base64,")
    image_format = format_prefix.split("/")[1].lower()
    if DEBUG >= 2: print(f"{image_str=} {image_format=}")
    imgdata = base64.b64decode(base64_data)
    img = Image.open(BytesIO(imgdata))

    # Convert to RGB if not already
    if img.mode != "RGB":
      img = img.convert("RGB")

    return img
  else:
    raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")
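

# Usage sketch (example URL is hypothetical; the base64 payload is truncated):
#
#   img = await get_image_from_str("https://example.com/photo.jpg")
#   img = await get_image_from_str("data:image/png;base64,iVBORw0KGgo...")
#   img.size  # PIL returns (width, height); both branches yield RGB images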