sharded_utils.py

# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py

import glob
import importlib
import json
import logging
import asyncio
import aiohttp
from functools import partial
from pathlib import Path
from typing import Optional, Tuple, Union, List, Callable
from PIL import Image
from io import BytesIO
import base64

import mlx.core as mx
import mlx.nn as nn
from transformers import AutoProcessor

from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
from mlx_lm.tuner.utils import apply_lora_layers

from exo import DEBUG
from ..shard import Shard


class ModelNotFoundError(Exception):
  def __init__(self, message):
    self.message = message
    super().__init__(self.message)


MODEL_REMAPPING = {
  "mistral": "llama",  # mistral is compatible with llama
  "phi-msft": "phixtral",
}


def _get_classes(config: dict):
  """
  Retrieve the model and model args classes based on the configuration.

  Args:
    config (dict): The model configuration.

  Returns:
    A tuple containing the Model class and the ModelArgs class.
  """
  model_type = config["model_type"]
  model_type = MODEL_REMAPPING.get(model_type, model_type)
  try:
    arch = importlib.import_module(f"exo.inference.mlx.models.{model_type}")
  except ImportError:
    msg = f"Model type {model_type} not supported."
    logging.error(msg)
    raise ValueError(msg)

  return arch.Model, arch.ModelArgs
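
# Minimal usage sketch for _get_classes (the config dict here is hypothetical):
# "mistral" is remapped to "llama" via MODEL_REMAPPING, so the import resolves
# to exo.inference.mlx.models.llama.
#
#   Model, ModelArgs = _get_classes({"model_type": "mistral"})
#   # Model is llama.Model, ModelArgs is llama.ModelArgs
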
def load_config(model_path: Path) -> dict:
  try:
    with open(model_path/"config.json", "r") as f:
      config = json.load(f)
  except FileNotFoundError:
    logging.error(f"Config file not found in {model_path}")
    raise
  return config
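
# Illustrative shape of the config.json fields this module actually reads
# (real configs carry many more architecture parameters; the values below are
# placeholders):
#
#   {
#     "model_type": "llama",                          # read by _get_classes
#     "quantization": {"group_size": 64, "bits": 4}   # optional, read by load_model_shard
#   }
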
def load_model_shard(
  model_path: Path,
  shard: Shard,
  lazy: bool = False,
  model_config: dict = {},
) -> nn.Module:
  """
  Load and initialize a shard of the model from a given path.

  Args:
    model_path (Path): The path to load the model from.
    shard (Shard): The range of layers this node is responsible for.
    lazy (bool): If False eval the model parameters to make sure they are
      loaded in memory before returning, otherwise they will be loaded
      when needed. Default: ``False``
    model_config (dict, optional): Configuration parameters for the model.
      Defaults to an empty dictionary.

  Returns:
    nn.Module: The loaded and initialized model.

  Raises:
    FileNotFoundError: If the weight files (.safetensors) are not found.
    ValueError: If the model class or args class are not found or cannot be instantiated.
  """
  config = load_config(model_path)
  config.update(model_config)

  # TODO hack: pass the shard bounds to the model via its config
  config["shard"] = {
    "model_id": model_path.name,
    "start_layer": shard.start_layer,
    "end_layer": shard.end_layer,
    "n_layers": shard.n_layers,
  }

  weight_files = glob.glob(str(model_path/"model*.safetensors"))

  if not weight_files:
    # Try weight for back-compat
    weight_files = glob.glob(str(model_path/"weight*.safetensors"))

  if not weight_files:
    logging.error(f"No safetensors found in {model_path}")
    raise FileNotFoundError(f"No safetensors found in {model_path}")

  weights = {}
  for wf in sorted(weight_files):
    if DEBUG >= 8:
      # Print which layer numbers each weight file covers
      layer_nums = set()
      for k in mx.load(wf):
        if k.startswith("model.layers."):
          layer_num = int(k.split(".")[2])
          layer_nums.add(layer_num)
        if k.startswith("language_model.model.layers."):
          layer_num = int(k.split(".")[3])
          layer_nums.add(layer_num)
      print(f"\"{wf.split('/')[-1]}\": {sorted(layer_nums)},")

    weights.update(mx.load(wf))

  model_class, model_args_class = _get_classes(config=config)

  model_args = model_args_class.from_dict(config)
  model = model_class(model_args)

  if hasattr(model, "sanitize"):
    weights = model.sanitize(weights)

  if (quantization := config.get("quantization", None)) is not None:
    # Handle legacy models which may not have everything quantized:
    # only quantize modules whose scales are present in the checkpoint
    def class_predicate(p, m):
      if not hasattr(m, "to_quantized"):
        return False
      return f"{p}.scales" in weights

    nn.quantize(
      model,
      **quantization,
      class_predicate=class_predicate,
    )

  model.load_weights(list(weights.items()), strict=True)

  if not lazy:
    mx.eval(model.parameters())

  model.eval()
  return model
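
# Usage sketch (hypothetical path and shard bounds): load the first 16 layers
# of a 32-layer model on this node. The Shard fields mirror the dict injected
# into the config above.
#
#   shard = Shard(model_id="my-model", start_layer=0, end_layer=15, n_layers=32)
#   model = load_model_shard(Path("/path/to/model"), shard)
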
async def load_shard(
  model_path: str,
  shard: Shard,
  tokenizer_config={},
  model_config={},
  adapter_path: Optional[str] = None,
  lazy: bool = False,
) -> Tuple[nn.Module, TokenizerWrapper]:
  model = load_model_shard(model_path, shard, lazy, model_config)
  if adapter_path is not None:
    model = apply_lora_layers(model, adapter_path)
    model.eval()

  # TODO: figure out a generic solution
  if model.model_type == "llava":
    processor = AutoProcessor.from_pretrained(model_path)
    processor.eos_token_id = processor.tokenizer.eos_token_id
    processor.encode = processor.tokenizer.encode
    return model, processor
  else:
    tokenizer = load_tokenizer(model_path, tokenizer_config)
    return model, tokenizer
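
# Usage sketch (hypothetical model path): load_shard is a coroutine, so it
# must be awaited, e.g. from a synchronous entry point:
#
#   model, tokenizer = asyncio.run(load_shard("/path/to/model", shard))
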
async def get_image_from_str(_image_str: str):
  image_str = _image_str.strip()

  if image_str.startswith("http"):
    async with aiohttp.ClientSession() as session:
      async with session.get(image_str, timeout=10) as response:
        content = await response.read()
        return Image.open(BytesIO(content)).convert("RGB")
  elif image_str.startswith("data:image/"):
    # Extract the image format and base64 data
    format_prefix, base64_data = image_str.split(";base64,")
    image_format = format_prefix.split("/")[1].lower()
    if DEBUG >= 2: print(f"{image_str=} {image_format=}")
    imgdata = base64.b64decode(base64_data)
    img = Image.open(BytesIO(imgdata))

    # Convert to RGB if not already
    if img.mode != "RGB":
      img = img.convert("RGB")

    return img
  else:
    raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")
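
# Usage sketch: the two accepted input forms. The URL is hypothetical and the
# base64 payload is a truncated placeholder, not a decodable image.
#
#   img = await get_image_from_str("https://example.com/cat.png")
#   img = await get_image_from_str("data:image/png;base64,iVBORw0KGgo...")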