sharded_utils.py

# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py

import glob
import importlib
import json
import logging
import asyncio
import aiohttp
from functools import partial
from pathlib import Path
from typing import Optional, Tuple, Union, List, Callable
from PIL import Image
from io import BytesIO
import base64
import traceback

import mlx.core as mx
import mlx.nn as nn
from transformers import AutoProcessor

from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper

from exo import DEBUG
from exo.inference.tokenizers import resolve_tokenizer
from ..shard import Shard

class ModelNotFoundError(Exception):
  def __init__(self, message):
    self.message = message
    super().__init__(self.message)

MODEL_REMAPPING = {
  "mistral": "llama",  # mistral is compatible with llama
  "phi-msft": "phixtral",
}

def _get_classes(config: dict):
  """
  Retrieve the model and model args classes based on the configuration.

  Args:
    config (dict): The model configuration.

  Returns:
    A tuple containing the Model class and the ModelArgs class.
  """
  model_type = config["model_type"]
  model_type = MODEL_REMAPPING.get(model_type, model_type)
  try:
    arch = importlib.import_module(f"exo.inference.mlx.models.{model_type}")
  except ImportError:
    msg = f"Model type {model_type} not supported."
    logging.error(msg)
    traceback.print_exc()
    raise ValueError(msg)

  return arch.Model, arch.ModelArgs
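
# Example usage (illustrative only; assumes a "llama" module exists under
# exo.inference.mlx.models and that `config` is a parsed config.json dict):
#
#   Model, ModelArgs = _get_classes(config)
#   model = Model(ModelArgs.from_dict(config))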

def load_config(model_path: Path) -> dict:
  try:
    with open(model_path/"config.json", "r") as f:
      config = json.load(f)
  except FileNotFoundError:
    logging.error(f"Config file not found in {model_path}")
    raise
  return config

def load_model_shard(
  model_path: Path,
  shard: Shard,
  lazy: bool = False,
  model_config: dict = {},
) -> nn.Module:
  """
  Load and initialize the model from a given path.

  Args:
    model_path (Path): The path to load the model from.
    shard (Shard): The layer range of the full model that this node is responsible for.
    lazy (bool): If False eval the model parameters to make sure they are
      loaded in memory before returning, otherwise they will be loaded
      when needed. Default: ``False``
    model_config (dict, optional): Configuration parameters for the model.
      Defaults to an empty dictionary.

  Returns:
    nn.Module: The loaded and initialized model.

  Raises:
    FileNotFoundError: If the weight files (.safetensors) are not found.
    ValueError: If the model class or args class are not found or cannot be instantiated.
  """
  config = load_config(model_path)
  config.update(model_config)

  # TODO hack: pass the shard's layer range through the config so that the
  # model args built via from_dict(config) pick up start_layer and end_layer.
  config["shard"] = {
    "model_id": model_path.name,
    "start_layer": shard.start_layer,
    "end_layer": shard.end_layer,
    "n_layers": shard.n_layers,
  }

  weight_files = glob.glob(str(model_path/"model*.safetensors"))

  if not weight_files:
    # Try weight for back-compat
    weight_files = glob.glob(str(model_path/"weight*.safetensors"))

  if not weight_files:
    logging.error(f"No safetensors found in {model_path}")
    raise FileNotFoundError(f"No safetensors found in {model_path}")

  weights = {}
  for wf in sorted(weight_files):
    if DEBUG >= 8:
      layer_nums = set()
      for k in mx.load(wf):
        if k.startswith("model.layers."):
          layer_num = int(k.split(".")[2])
          layer_nums.add(layer_num)
        if k.startswith("language_model.model.layers."):
          layer_num = int(k.split(".")[3])
          layer_nums.add(layer_num)
      print(f"\"{wf.split('/')[-1]}\": {sorted(layer_nums)},")

    weights.update(mx.load(wf))

  model_class, model_args_class = _get_classes(config=config)
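
  # Wrap the resolved model class so the loaded model also records which Shard
  # (layer range) of the full model this instance holds.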
  class ShardedModel(model_class):
    def __init__(self, args):
      super().__init__(args)
      self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)

    def __call__(self, x, *args, **kwargs):
      y = super().__call__(x, *args, **kwargs)
      return y

  model_args = model_args_class.from_dict(config)
  model = ShardedModel(model_args)

  if hasattr(model, "sanitize"):
    weights = model.sanitize(weights)

  if DEBUG >= 8:
    print(f"\n|| {config=} ||\n")

  if (quantization := config.get("quantization", None)) is not None:
    # Handle legacy models which may not have everything quantized
    def class_predicate(p, m):
      if not hasattr(m, "to_quantized"):
        return False
      return f"{p}.scales" in weights

    nn.quantize(
      model,
      **quantization,
      class_predicate=class_predicate,
    )

  model.load_weights(list(weights.items()), strict=True)

  if not lazy:
    mx.eval(model.parameters())

  model.eval()
  return model
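
# Example usage (illustrative only; the model path and layer split are assumptions):
#
#   shard = Shard(model_id="llama-3-8b", start_layer=0, end_layer=15, n_layers=32)
#   model = load_model_shard(Path("/path/to/llama-3-8b"), shard)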

async def load_shard(
  model_path: str,
  shard: Shard,
  tokenizer_config={},
  model_config={},
  adapter_path: Optional[str] = None,
  lazy: bool = False,
) -> Tuple[nn.Module, TokenizerWrapper]:
  model = load_model_shard(model_path, shard, lazy, model_config)

  # TODO: figure out a generic solution
  if model.model_type == "llava":
    processor = AutoProcessor.from_pretrained(model_path)
    processor.eos_token_id = processor.tokenizer.eos_token_id
    processor.encode = processor.tokenizer.encode
    return model, processor
  else:
    tokenizer = await resolve_tokenizer(model_path)
    return model, tokenizer
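
# Example usage (illustrative only; the path is an assumption, and a Path is passed
# because load_model_shard performs Path operations on model_path):
#
#   model, tokenizer = asyncio.run(load_shard(Path("/path/to/llama-3-8b"), shard))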

async def get_image_from_str(_image_str: str):
  image_str = _image_str.strip()

  if image_str.startswith("http"):
    async with aiohttp.ClientSession() as session:
      async with session.get(image_str, timeout=10) as response:
        content = await response.read()
        return Image.open(BytesIO(content)).convert("RGB")
  elif image_str.startswith("data:image/"):
    # Extract the image format and base64 data
    format_prefix, base64_data = image_str.split(";base64,")
    image_format = format_prefix.split("/")[1].lower()
    if DEBUG >= 2: print(f"{image_str=} {image_format=}")
    imgdata = base64.b64decode(base64_data)
    img = Image.open(BytesIO(imgdata))

    # Convert to RGB if not already
    if img.mode != "RGB":
      img = img.convert("RGB")

    return img
  else:
    raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")
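
# Example usage (illustrative only; the URL and data URI are placeholders):
#
#   img = await get_image_from_str("https://example.com/cat.png")
#   img = await get_image_from_str("data:image/png;base64,<BASE64 DATA>")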