Skip to content

Commit

Permalink
Fixes NxDPPModel for Neuron SDK 2.19 (#663)
Browse files — browse the repository at this point in the history
  • Loading branch information
michaelbenayoun authored Jul 23, 2024
1 parent 278d76c commit 18c6ab4
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 24 deletions.
7 changes: 4 additions & 3 deletions optimum/neuron/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,9 +404,10 @@ def _prepare_model_for_mp(
setattr(model, "main_input_name", model_main_input_name)

if isinstance(model, NxDPPModel):
model.local_module = self.patch_model_for_neuron(
model.local_module, patching_specs=NxDPPMODEL_PATCHING_SPECS
)
for idx, module in enumerate(model.local_stage_modules):
model.local_stage_modules[idx] = self.patch_model_for_neuron(
module, patching_specs=NxDPPMODEL_PATCHING_SPECS
)

# Update CPU ids
original_parameter_names_to_gqa_qkv_names = model._gqa_qkv_metadata["original_names_to_gqa_qkv_names"]
Expand Down
3 changes: 2 additions & 1 deletion optimum/neuron/accelerate/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import functools
import gc
import inspect
import itertools
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union

import torch
Expand Down Expand Up @@ -197,7 +198,7 @@ def apply_activation_checkpointing(model: Union["PreTrainedModel", "NxDPPModel",
model._prepare_model_for_gradient_checkpointing(model.get_base_model())

if isinstance(model, NxDPPModel):
modules = model.local_module.modules()
modules = itertools.chain.from_iterable(module.modules() for module in model.local_stage_modules)
else:
modules = model.modules()

Expand Down
41 changes: 21 additions & 20 deletions optimum/neuron/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@
def get_tied_parameters_dict(model: Union["torch.nn.Module", "NxDPPModel"]) -> Dict[str, str]:
from neuronx_distributed.pipeline import NxDPPModel

if isinstance(model, NxDPPModel):
tied_parameters = {}
for module in model.local_stage_modules:
tied_parameters.update(get_tied_parameters_dict(module))
return tied_parameters

unique_parameters = {}
tied_parameters = {}
if isinstance(model, NxDPPModel):
module = model.local_module
else:
module = model
for name, param in module.named_parameters(remove_duplicate=False):
for name, param in model.named_parameters(remove_duplicate=False):
if param in unique_parameters:
tied_parameter_name = unique_parameters[param]
tied_parameters[name] = tied_parameter_name
Expand All @@ -61,19 +63,18 @@ def tie_parameters(model: Union["torch.nn.Module", "NxDPPModel"], tied_parameter
from neuronx_distributed.pipeline import NxDPPModel

if isinstance(model, NxDPPModel):
module = model.local_module
for module in model.local_stage_modules:
tie_parameters(module, tied_parameters_dict)
else:
module = model

for param_to_tie_name, param_name in tied_parameters_dict.items():
param_to_tie_parent_module, param_to_tie_name = get_parent_module_and_param_name_from_fully_qualified_name(
module, param_to_tie_name
)
param_to_tie = getattr(param_to_tie_parent_module, param_to_tie_name)

parent_module, param_name = get_parent_module_and_param_name_from_fully_qualified_name(module, param_name)
param = getattr(parent_module, param_name)

if param_to_tie is not param:
del param_to_tie
setattr(param_to_tie_parent_module, param_to_tie_name, param)
for param_to_tie_name, param_name in tied_parameters_dict.items():
param_to_tie_parent_module, param_to_tie_name = get_parent_module_and_param_name_from_fully_qualified_name(
model, param_to_tie_name
)
param_to_tie = getattr(param_to_tie_parent_module, param_to_tie_name)

parent_module, param_name = get_parent_module_and_param_name_from_fully_qualified_name(model, param_name)
param = getattr(parent_module, param_name)

if param_to_tie is not param:
del param_to_tie
setattr(param_to_tie_parent_module, param_to_tie_name, param)

0 comments on commit 18c6ab4

Please sign in to comment.