diff --git a/optimum/neuron/accelerate/accelerator.py b/optimum/neuron/accelerate/accelerator.py
index 05d5a8a3e..dc4915cc3 100644
--- a/optimum/neuron/accelerate/accelerator.py
+++ b/optimum/neuron/accelerate/accelerator.py
@@ -404,9 +404,10 @@ def _prepare_model_for_mp(
             setattr(model, "main_input_name", model_main_input_name)
 
         if isinstance(model, NxDPPModel):
-            model.local_module = self.patch_model_for_neuron(
-                model.local_module, patching_specs=NxDPPMODEL_PATCHING_SPECS
-            )
+            for idx, module in enumerate(model.local_stage_modules):
+                model.local_stage_modules[idx] = self.patch_model_for_neuron(
+                    module, patching_specs=NxDPPMODEL_PATCHING_SPECS
+                )
 
             # Update CPU ids
             original_parameter_names_to_gqa_qkv_names = model._gqa_qkv_metadata["original_names_to_gqa_qkv_names"]
diff --git a/optimum/neuron/accelerate/utils/misc.py b/optimum/neuron/accelerate/utils/misc.py
index 15d094691..95adc3699 100644
--- a/optimum/neuron/accelerate/utils/misc.py
+++ b/optimum/neuron/accelerate/utils/misc.py
@@ -17,6 +17,7 @@
 import functools
 import gc
 import inspect
+import itertools
 from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
 
 import torch
@@ -197,7 +198,7 @@ def apply_activation_checkpointing(model: Union["PreTrainedModel", "NxDPPModel",
         model._prepare_model_for_gradient_checkpointing(model.get_base_model())
 
     if isinstance(model, NxDPPModel):
-        modules = model.local_module.modules()
+        modules = itertools.chain.from_iterable(module.modules() for module in model.local_stage_modules)
     else:
         modules = model.modules()
 
diff --git a/optimum/neuron/utils/model_utils.py b/optimum/neuron/utils/model_utils.py
index 76d42f40a..a74932674 100644
--- a/optimum/neuron/utils/model_utils.py
+++ b/optimum/neuron/utils/model_utils.py
@@ -32,13 +32,15 @@
 def get_tied_parameters_dict(model: Union["torch.nn.Module", "NxDPPModel"]) -> Dict[str, str]:
     from neuronx_distributed.pipeline import NxDPPModel
 
+    if isinstance(model, NxDPPModel):
+        tied_parameters = {}
+        for module in model.local_stage_modules:
+            tied_parameters.update(get_tied_parameters_dict(module))
+        return tied_parameters
+
     unique_parameters = {}
     tied_parameters = {}
-    if isinstance(model, NxDPPModel):
-        module = model.local_module
-    else:
-        module = model
-    for name, param in module.named_parameters(remove_duplicate=False):
+    for name, param in model.named_parameters(remove_duplicate=False):
         if param in unique_parameters:
             tied_parameter_name = unique_parameters[param]
             tied_parameters[name] = tied_parameter_name
@@ -61,19 +63,18 @@ def tie_parameters(model: Union["torch.nn.Module", "NxDPPModel"], tied_parameter
     from neuronx_distributed.pipeline import NxDPPModel
 
     if isinstance(model, NxDPPModel):
-        module = model.local_module
+        for module in model.local_stage_modules:
+            tie_parameters(module, tied_parameters_dict)
     else:
-        module = model
-
-    for param_to_tie_name, param_name in tied_parameters_dict.items():
-        param_to_tie_parent_module, param_to_tie_name = get_parent_module_and_param_name_from_fully_qualified_name(
-            module, param_to_tie_name
-        )
-        param_to_tie = getattr(param_to_tie_parent_module, param_to_tie_name)
-
-        parent_module, param_name = get_parent_module_and_param_name_from_fully_qualified_name(module, param_name)
-        param = getattr(parent_module, param_name)
-
-        if param_to_tie is not param:
-            del param_to_tie
-            setattr(param_to_tie_parent_module, param_to_tie_name, param)
+        for param_to_tie_name, param_name in tied_parameters_dict.items():
+            param_to_tie_parent_module, param_to_tie_name = get_parent_module_and_param_name_from_fully_qualified_name(
+                model, param_to_tie_name
+            )
+            param_to_tie = getattr(param_to_tie_parent_module, param_to_tie_name)
+
+            parent_module, param_name = get_parent_module_and_param_name_from_fully_qualified_name(model, param_name)
+            param = getattr(parent_module, param_name)
+
+            if param_to_tie is not param:
+                del param_to_tie
+                setattr(param_to_tie_parent_module, param_to_tie_name, param)