|
@@ -37,10 +37,10 @@ def _get_classes(config: dict):
|
|
Retrieve the model and model args classes based on the configuration.
|
|
Retrieve the model and model args classes based on the configuration.
|
|
|
|
|
|
Args:
|
|
Args:
|
|
- config (dict): The model configuration.
|
|
|
|
|
|
+ config (dict): The model configuration.
|
|
|
|
|
|
Returns:
|
|
Returns:
|
|
- A tuple containing the Model class and the ModelArgs class.
|
|
|
|
|
|
+ A tuple containing the Model class and the ModelArgs class.
|
|
"""
|
|
"""
|
|
model_type = config["model_type"]
|
|
model_type = config["model_type"]
|
|
model_type = MODEL_REMAPPING.get(model_type, model_type)
|
|
model_type = MODEL_REMAPPING.get(model_type, model_type)
|
|
@@ -74,19 +74,19 @@ def load_model_shard(
|
|
Load and initialize the model from a given path.
|
|
Load and initialize the model from a given path.
|
|
|
|
|
|
Args:
|
|
Args:
|
|
- model_path (Path): The path to load the model from.
|
|
|
|
- lazy (bool): If False eval the model parameters to make sure they are
|
|
|
|
- loaded in memory before returning, otherwise they will be loaded
|
|
|
|
- when needed. Default: ``False``
|
|
|
|
- model_config(dict, optional): Configuration parameters for the model.
|
|
|
|
- Defaults to an empty dictionary.
|
|
|
|
|
|
+ model_path (Path): The path to load the model from.
|
|
|
|
+ lazy (bool): If False eval the model parameters to make sure they are
|
|
|
|
+ loaded in memory before returning, otherwise they will be loaded
|
|
|
|
+ when needed. Default: ``False``
|
|
|
|
+ model_config(dict, optional): Configuration parameters for the model.
|
|
|
|
+ Defaults to an empty dictionary.
|
|
|
|
|
|
Returns:
|
|
Returns:
|
|
- nn.Module: The loaded and initialized model.
|
|
|
|
|
|
+ nn.Module: The loaded and initialized model.
|
|
|
|
|
|
Raises:
|
|
Raises:
|
|
- FileNotFoundError: If the weight files (.safetensors) are not found.
|
|
|
|
- ValueError: If the model class or args class are not found or cannot be instantiated.
|
|
|
|
|
|
+ FileNotFoundError: If the weight files (.safetensors) are not found.
|
|
|
|
+ ValueError: If the model class or args class are not found or cannot be instantiated.
|
|
"""
|
|
"""
|
|
config = load_config(model_path)
|
|
config = load_config(model_path)
|
|
config.update(model_config)
|
|
config.update(model_config)
|
|
@@ -148,11 +148,11 @@ async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -
|
|
it is downloaded from the Hugging Face Hub.
|
|
it is downloaded from the Hugging Face Hub.
|
|
|
|
|
|
Args:
|
|
Args:
|
|
- path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
|
|
|
|
- revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
|
|
|
|
|
|
+ path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
|
|
|
|
+ revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
|
|
|
|
|
|
Returns:
|
|
Returns:
|
|
- Path: The path to the model.
|
|
|
|
|
|
+ Path: The path to the model.
|
|
"""
|
|
"""
|
|
model_path = Path(path_or_hf_repo)
|
|
model_path = Path(path_or_hf_repo)
|
|
if not model_path.exists():
|
|
if not model_path.exists():
|
|
@@ -194,22 +194,22 @@ async def load_shard(
|
|
Load the model and tokenizer from a given path or a huggingface repository.
|
|
Load the model and tokenizer from a given path or a huggingface repository.
|
|
|
|
|
|
Args:
|
|
Args:
|
|
- path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
|
|
|
|
- tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
|
|
|
|
- Defaults to an empty dictionary.
|
|
|
|
- model_config(dict, optional): Configuration parameters specifically for the model.
|
|
|
|
- Defaults to an empty dictionary.
|
|
|
|
- adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
|
|
|
|
- to the model. Default: ``None``.
|
|
|
|
- lazy (bool): If False eval the model parameters to make sure they are
|
|
|
|
- loaded in memory before returning, otherwise they will be loaded
|
|
|
|
- when needed. Default: ``False``
|
|
|
|
|
|
+ path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
|
|
|
|
+ tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
|
|
|
|
+ Defaults to an empty dictionary.
|
|
|
|
+ model_config(dict, optional): Configuration parameters specifically for the model.
|
|
|
|
+ Defaults to an empty dictionary.
|
|
|
|
+ adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
|
|
|
|
+ to the model. Default: ``None``.
|
|
|
|
+ lazy (bool): If False eval the model parameters to make sure they are
|
|
|
|
+ loaded in memory before returning, otherwise they will be loaded
|
|
|
|
+ when needed. Default: ``False``
|
|
Returns:
|
|
Returns:
|
|
- Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
|
|
|
|
|
|
+ Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
|
|
|
|
|
|
Raises:
|
|
Raises:
|
|
- FileNotFoundError: If config file or safetensors are not found.
|
|
|
|
- ValueError: If model class or args class are not found.
|
|
|
|
|
|
+ FileNotFoundError: If config file or safetensors are not found.
|
|
|
|
+ ValueError: If model class or args class are not found.
|
|
"""
|
|
"""
|
|
model_path = await get_model_path(path_or_hf_repo)
|
|
model_path = await get_model_path(path_or_hf_repo)
|
|
|
|
|