
formatting

Alex Cheema 1 year ago
parent commit 63e51a8270

+ 1 - 3
exo/api/chatgpt_api.py

@@ -40,9 +40,7 @@ shard_mappings = {
  },
  ### deepseek v2
  "deepseek-coder-v2-lite": {
-    "MLXDynamicShardInferenceEngine": Shard(
-      model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27
-    ),
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),
  },
}

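The chatgpt_api.py hunk only collapses the Shard(...) construction onto one line; the mapping's value is unchanged. A minimal sketch of the equivalence, assuming Shard is a plain dataclass with exactly the four fields used here (the real exo definition may differ):

from dataclasses import dataclass

@dataclass
class Shard:
  model_id: str
  start_layer: int
  end_layer: int
  n_layers: int

multi = Shard(
  model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27
)
single = Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27)
assert multi == single  # purely a formatting change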
+ 2 - 5
exo/inference/mlx/models/deepseek_v2.py

@@ -106,15 +106,12 @@ class Model(nn.Module):
      for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
        for k in ["weight", "scales", "biases"]:
          if f"{prefix}.mlp.experts.0.{m}.{k}" in shard_state_dict:
-            to_join = [
-              shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") for e in range(self.args.n_routed_experts)
-            ]
+            to_join = [shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") for e in range(self.args.n_routed_experts)]
            shard_state_dict[
              f"{prefix}.mlp.switch_mlp.{
-              m}.{k}"
+       m}.{k}"
            ] = mx.stack(to_join)

-
    return shard_state_dict

  @property
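In the deepseek_v2.py hunk, sanitize pops the per-expert gate/down/up projection tensors out of the state dict and stacks them into a single switch_mlp tensor per projection; the commit only reflows that code. An illustrative, runnable sketch of the stacking (key names follow the diff; the expert count and shapes below are toy values for illustration):

import mlx.core as mx

n_routed_experts = 4  # toy value; the model config supplies the real count
prefix = "model.layers.0"
state = {f"{prefix}.mlp.experts.{e}.gate_proj.weight": mx.zeros((8, 16)) for e in range(n_routed_experts)}

# Pop the per-expert weights and stack them along a new leading expert axis.
to_join = [state.pop(f"{prefix}.mlp.experts.{e}.gate_proj.weight") for e in range(n_routed_experts)]
state[f"{prefix}.mlp.switch_mlp.gate_proj.weight"] = mx.stack(to_join)
print(state[f"{prefix}.mlp.switch_mlp.gate_proj.weight"].shape)  # (4, 8, 16)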

+ 27 - 27
exo/inference/mlx/sharded_utils.py

@@ -37,10 +37,10 @@ def _get_classes(config: dict):
  Retrieve the model and model args classes based on the configuration.

  Args:
-    config (dict): The model configuration.
+   config (dict): The model configuration.

  Returns:
-    A tuple containing the Model class and the ModelArgs class.
+   A tuple containing the Model class and the ModelArgs class.
   """
   """
   model_type = config["model_type"]
   model_type = config["model_type"]
   model_type = MODEL_REMAPPING.get(model_type, model_type)
   model_type = MODEL_REMAPPING.get(model_type, model_type)
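The hunk above only reindents _get_classes's docstring. For context, the lookup it documents resolves model_type through MODEL_REMAPPING and then imports the matching module; a hedged sketch of that pattern (the module path and the remapping entry are assumptions for illustration):

import importlib

MODEL_REMAPPING = {"mistral": "llama"}  # example entry; the real table may differ

def get_classes_sketch(config: dict):
  # Map aliases onto the module that actually implements the architecture,
  # then pull the Model and ModelArgs classes out of it.
  model_type = config["model_type"]
  model_type = MODEL_REMAPPING.get(model_type, model_type)
  arch = importlib.import_module(f"exo.inference.mlx.models.{model_type}")
  return arch.Model, arch.ModelArgs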
@@ -74,19 +74,19 @@ def load_model_shard(
  Load and initialize the model from a given path.

  Args:
-    model_path (Path): The path to load the model from.
-    lazy (bool): If False eval the model parameters to make sure they are
-      loaded in memory before returning, otherwise they will be loaded
-      when needed. Default: ``False``
-    model_config(dict, optional): Configuration parameters for the model.
-      Defaults to an empty dictionary.
+   model_path (Path): The path to load the model from.
+   lazy (bool): If False eval the model parameters to make sure they are
+    loaded in memory before returning, otherwise they will be loaded
+    when needed. Default: ``False``
+   model_config(dict, optional): Configuration parameters for the model.
+    Defaults to an empty dictionary.

  Returns:
-    nn.Module: The loaded and initialized model.
+   nn.Module: The loaded and initialized model.

  Raises:
-    FileNotFoundError: If the weight files (.safetensors) are not found.
-    ValueError: If the model class or args class are not found or cannot be instantiated.
+   FileNotFoundError: If the weight files (.safetensors) are not found.
+   ValueError: If the model class or args class are not found or cannot be instantiated.
   """
   """
   config = load_config(model_path)
   config = load_config(model_path)
   config.update(model_config)
   config.update(model_config)
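The lazy flag documented in load_model_shard leans on MLX's deferred evaluation: arrays are computed only when first needed, and mx.eval forces them into memory up front, which is what the loader does when lazy=False. A self-contained illustration of those semantics (not exo code):

import mlx.core as mx

x = mx.ones((4, 4)) * 2  # builds a lazy computation graph; nothing is computed yet
mx.eval(x)               # forces materialization, as the loader does for lazy=False
print(x)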
@@ -148,11 +148,11 @@ async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
  it is downloaded from the Hugging Face Hub.

  Args:
-    path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
-    revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
+   path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
+   revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.

  Returns:
-    Path: The path to the model.
+   Path: The path to the model.
   """
   """
   model_path = Path(path_or_hf_repo)
   model_path = Path(path_or_hf_repo)
   if not model_path.exists():
   if not model_path.exists():
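get_model_path's docstring describes a resolve-or-download flow: treat the argument as a local path, and fall back to fetching from the Hugging Face Hub when it doesn't exist. A hedged sketch of that pattern using huggingface_hub.snapshot_download (the real helper may pass extra options such as file-pattern filters):

from pathlib import Path
from typing import Optional
from huggingface_hub import snapshot_download

def get_model_path_sketch(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
  model_path = Path(path_or_hf_repo)
  if not model_path.exists():  # not a local path, so download the snapshot from the Hub
    model_path = Path(snapshot_download(repo_id=path_or_hf_repo, revision=revision))
  return model_path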
@@ -194,22 +194,22 @@ async def load_shard(
  Load the model and tokenizer from a given path or a huggingface repository.

  Args:
-    path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
-    tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
-      Defaults to an empty dictionary.
-    model_config(dict, optional): Configuration parameters specifically for the model.
-      Defaults to an empty dictionary.
-    adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
-      to the model. Default: ``None``.
-    lazy (bool): If False eval the model parameters to make sure they are
-      loaded in memory before returning, otherwise they will be loaded
-      when needed. Default: ``False``
+   path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
+   tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
+    Defaults to an empty dictionary.
+   model_config(dict, optional): Configuration parameters specifically for the model.
+    Defaults to an empty dictionary.
+   adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
+    to the model. Default: ``None``.
+   lazy (bool): If False eval the model parameters to make sure they are
+    loaded in memory before returning, otherwise they will be loaded
+    when needed. Default: ``False``
  Returns:
-    Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
+   Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.

  Raises:
-    FileNotFoundError: If config file or safetensors are not found.
-    ValueError: If model class or args class are not found.
+   FileNotFoundError: If config file or safetensors are not found.
+   ValueError: If model class or args class are not found.
   """
   """
   model_path = await get_model_path(path_or_hf_repo)
   model_path = await get_model_path(path_or_hf_repo)
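Finally, load_shard is a coroutine returning a (model, tokenizer) pair, per the Returns section above. A hedged usage sketch (the single-argument call matches the documented signature; the repo id is just an example, and the real function may accept or require further arguments):

import asyncio

async def main():
  # Assumed call shape, based on the docstring in this hunk.
  model, tokenizer = await load_shard("mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx")
  print(type(model), type(tokenizer))

asyncio.run(main())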