resolve TE dependency
suiyoubi committed Dec 2, 2024
1 parent 9b11d84 commit 4a7f608
Showing 2 changed files with 203 additions and 112 deletions.
27 changes: 17 additions & 10 deletions nemo/collections/llm/bert/model/base.py
@@ -29,25 +29,23 @@
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
from megatron.core.transformer.utils import get_linear_layer as mcore_get_linear_layer
from megatron.core.models.bert import bert_layer_specs
from megatron.core.utils import make_viewless_tensor
from torch import Tensor, nn

from megatron.core.models.bert.bert_model import BertModel as MCoreBert
from nemo.collections.llm import fn
from nemo.collections.llm.bert.loss import BERTLossReduction
from nemo.collections.llm.bert.model.bert_spec import (
bert_layer_local_spec_postln,
bert_layer_with_transformer_engine_spec_postln,
megatron_layer_local_spec_preln,
get_bert_layer_local_spec_postln,
get_bert_layer_with_transformer_engine_spec_postln,
)
from nemo.lightning import get_vocab_size, io
from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule

HAVE_TE = True
try:
import transformer_engine # pylint: disable=W0611

Check notice — Code scanning / CodeQL: Unused import — Import of 'transformer_engine' is not used.
from megatron.core.models.bert import bert_layer_specs
from megatron.core.models.bert.bert_model import BertModel as MCoreBert
except (ImportError, ModuleNotFoundError):
except (ImportError, ModuleNotFoundError) as e:
HAVE_TE = False

if TYPE_CHECKING:
@@ -114,12 +112,12 @@ def default_layer_spec(config: "BertConfig") -> ModuleSpec:
if bert_type == 'megatron':
return bert_layer_specs.bert_layer_with_transformer_engine_spec
else:
return bert_layer_with_transformer_engine_spec_postln
return get_bert_layer_with_transformer_engine_spec_postln()

if bert_type == 'megatron':
return megatron_layer_local_spec_preln
return bert_layer_specs.bert_layer_local_spec
else:
return bert_layer_local_spec_postln
return get_bert_layer_local_spec_postln()


@dataclass
@@ -542,6 +540,15 @@ def __init__(
tokenizer: Optional["TokenizerSpec"] = None,
model_transform: Optional[Callable[[nn.Module], nn.Module]] = None,
):
# Megatron-LM's BERT implementation depends heavily on TE, and the MCore BERT
# cannot be instantiated without the TE package.
# Two issues remain: 1. bert_layer_specs.py is not TE dependency-free.
# 2. In bert_model.py, _sanity_check_attention_and_get_attn_mask_dimension() checks whether
# transformer_layer_spec is identical to bert_layer_local_spec to decide whether TE is
# required; since NeMo uses a customized bert layer spec, it always assumes TE is in use.
# Both issues need to be addressed to enable a TE-free NeMo BERT.
assert HAVE_TE, "NeMo BERT requires Transformer Engine to be installed."
super().__init__()
self.config = config
self.tokenizer = tokenizer
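For reference, the guard added above follows the usual optional-dependency pattern: probe the import once at module load, record the result in a HAVE_TE flag, and fail with a clear message at construction time rather than deep inside Megatron-LM. A minimal, self-contained sketch of that pattern (ExampleBertModule is an illustrative stand-in, not the actual NeMo class):

HAVE_TE = True
try:
    import transformer_engine  # noqa: F401  (imported only to probe availability)
except (ImportError, ModuleNotFoundError):
    HAVE_TE = False


class ExampleBertModule:
    """Toy stand-in for the NeMo BERT module guarded above."""

    def __init__(self, config):
        # Fail here, at construction time, with an actionable message instead of
        # letting an ImportError surface later inside MCore's BERT implementation.
        assert HAVE_TE, "ExampleBertModule requires Transformer Engine to be installed."
        self.config = config

The hard assert matches the comment above: until bert_layer_specs.py and the spec-identity check in bert_model.py are TE-free, NeMo BERT cannot run without TE, so failing fast is the safer behavior.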
288 changes: 186 additions & 102 deletions nemo/collections/llm/bert/model/bert_spec.py
@@ -11,139 +11,223 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults
from dataclasses import dataclass

try:
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.custom_layers.transformer_engine import (
TEColumnParallelLinear,
TEDotProductAttention,
TENorm,
TERowParallelLinear,
)
from megatron.core.transformer.spec_utils import build_module
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
from megatron.core.utils import make_viewless_tensor

HAVE_MEGATRON_CORE = True

except (ImportError, ModuleNotFoundError):
except (ImportError, ModuleNotFoundError) as e:
TransformerConfig = ApexGuardDefaults
HAVE_MEGATRON_CORE = False

try:
import apex # pylint: disable=unused-import
from megatron.core.transformer.custom_layers.transformer_engine import (
TEColumnParallelLinear,
TEDotProductAttention,
TENorm,
TERowParallelLinear,

from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
)

HAVE_TE = True
except (ImportError, ModuleNotFoundError) as e:
HAVE_TE = False

@dataclass
class TransformerLayerSubmodulesWithPostLNSupport(TransformerLayerSubmodules):
"""TransformerLayerSubmodules with post layer norm"""
def __init__(self, post_att_layernorm, post_mlp_layernorm, **kwargs):
super(TransformerLayerSubmodulesWithPostLNSupport, self).__init__(**kwargs)
self.post_att_layernorm = post_att_layernorm
self.post_mlp_layernorm = post_mlp_layernorm


class TransformerLayerWithPostLNSupport(TransformerLayer):
"""TransformerLayer with post layer norm."""
def __init__(self, *args, **kwargs):
super(TransformerLayerWithPostLNSupport, self).__init__(*args, **kwargs)
## [Module add: Post attention LN]
self.post_att_layernorm = build_module(
self.submodules_config.post_att_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
## [Module add: Post MLP LN]
self.post_mlp_layernorm = build_module(
self.submodules_config.post_mlp_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)

def forward(
self,
hidden_states,
attention_mask,
context=None,
context_mask=None,
rotary_pos_emb=None,
rotary_pos_cos=None,
rotary_pos_sin=None,
attention_bias=None,
inference_params=None,
packed_seq_params=None,
):
"""Copy from megatron/core/transformer/transformer_layer.py with modification of applying
extra post layer norm if needed."""
# hidden_states: [s, b, h]

# Residual connection.
residual = hidden_states

HAVE_APEX = True
LNImpl = FusedLayerNorm
except ImportError:
import warnings
# Optional Input Layer norm
input_layernorm_output = self.input_layernorm(hidden_states)

from megatron.core.transformer.torch_layer_norm import WrappedTorchLayerNorm
# Self attention.
attention_output_with_bias = self.self_attention(
input_layernorm_output,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
packed_seq_params=packed_seq_params,
)

warnings.warn(f'Apex is not installed. Falling back to Torch LayerNorm')
LNImpl = WrappedTorchLayerNorm
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)(
attention_output_with_bias, residual, self.hidden_dropout
)

# Residual connection.
residual = hidden_states

# Post-LN after Self Attention
hidden_states = self.post_att_layernorm(hidden_states)

# Optional Layer norm after self-attention
pre_cross_attn_layernorm_output = self.pre_cross_attn_layernorm(hidden_states)

# Cross attention.
attention_output_with_bias = self.cross_attention(
pre_cross_attn_layernorm_output,
attention_mask=context_mask,
key_value_states=context,
inference_params=inference_params,
)

if isinstance(attention_output_with_bias, dict) and "context" in attention_output_with_bias:
context = attention_output_with_bias["context"]

# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
hidden_states = self.cross_attn_bda(self.training, self.config.bias_dropout_fusion)(
attention_output_with_bias, residual, self.hidden_dropout
)

# Residual connection.
residual = hidden_states

# Optional Layer norm post the cross-attention.
pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)

# MLP.
mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)

# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
mlp_output_with_bias, residual, self.hidden_dropout
)

# Post-LN after MLP
hidden_states = self.post_mlp_layernorm(hidden_states)

# Jit compiled function creates 'view' tensor. This tensor
# potentially gets saved in the MPU checkpoint function context,
# which rejects view tensors. While making a viewless tensor here
# won't result in memory savings (like the data loader, or
# p2p_communication), it serves to document the origin of this
# 'view' tensor.
output = make_viewless_tensor(inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True)

return output, context
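
As a side note, the forward() above implements a post-LN layer: each sub-block adds the residual first and normalizes afterwards, whereas a pre-LN layer normalizes its input before attention and MLP. A simplified, plain-PyTorch sketch of that ordering (ToyPostLNBlock and its shapes are illustrative, not the Megatron-Core layer):

import torch
from torch import nn

class ToyPostLNBlock(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_size, num_heads=4)
        self.mlp = nn.Sequential(nn.Linear(hidden_size, 4 * hidden_size),
                                 nn.GELU(),
                                 nn.Linear(4 * hidden_size, hidden_size))
        self.post_att_layernorm = nn.LayerNorm(hidden_size)
        self.post_mlp_layernorm = nn.LayerNorm(hidden_size)

    def forward(self, x):
        # Self-attention sub-block: residual add first, then post-LN.
        attn_out, _ = self.attn(x, x, x)
        x = self.post_att_layernorm(x + attn_out)
        # MLP sub-block: residual add first, then post-LN.
        x = self.post_mlp_layernorm(x + self.mlp(x))
        return x

x = torch.randn(8, 2, 64)           # [sequence, batch, hidden]
print(ToyPostLNBlock(64)(x).shape)  # torch.Size([8, 2, 64])

The real layer above additionally handles bias-dropout-add fusion, cross attention, and viewless-tensor bookkeeping, which this toy omits.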

from nemo.collections.nlp.models.language_modeling.megatron.bert.bert_model import (
TransformerLayerSubmodulesWithPostLNSupport,
TransformerLayerWithPostLNSupport,
)

# Use this spec to use lower level Transformer Engine modules (required for fp8 training)
bert_layer_with_transformer_engine_spec_postln = ModuleSpec(
module=TransformerLayerWithPostLNSupport,
submodules=TransformerLayerSubmodulesWithPostLNSupport(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.padding},
submodules=SelfAttentionSubmodules(
linear_qkv=TEColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
def get_bert_layer_with_transformer_engine_spec_postln():
"""Retrieve the Layer Spec when using Transformer Engine"""
return ModuleSpec(
module=TransformerLayerWithPostLNSupport,
submodules=TransformerLayerSubmodulesWithPostLNSupport(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.padding},
submodules=SelfAttentionSubmodules(
linear_qkv=TEColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
),
self_attn_bda=get_bias_dropout_add,
post_att_layernorm=TENorm,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TEColumnParallelLinear,
linear_fc2=TERowParallelLinear,
self_attn_bda=get_bias_dropout_add,
post_att_layernorm=TENorm,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TEColumnParallelLinear,
linear_fc2=TERowParallelLinear,
),
),
mlp_bda=get_bias_dropout_add,
post_mlp_layernorm=TENorm,
),
mlp_bda=get_bias_dropout_add,
post_mlp_layernorm=TENorm,
),
)
)

# Use this spec for an implementation using only modules in megatron core
bert_layer_local_spec_postln = ModuleSpec(
module=TransformerLayerWithPostLNSupport,
submodules=TransformerLayerSubmodulesWithPostLNSupport(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.padding},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
def get_bert_layer_local_spec_postln():
"""Retrieve the Layer Spec when using MCore Engine"""
return ModuleSpec(
module=TransformerLayerWithPostLNSupport,
submodules=TransformerLayerSubmodulesWithPostLNSupport(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.padding},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
),
self_attn_bda=get_bias_dropout_add,
post_att_layernorm=FusedLayerNorm,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=ColumnParallelLinear,
linear_fc2=RowParallelLinear,
),
),
mlp_bda=get_bias_dropout_add,
post_mlp_layernorm=FusedLayerNorm,
),
)

# We copy MCore's local spec here to avoid the TE dependency issue:
# Megatron-LM's core/models/bert/bert_layer_specs.py always requires
# TE to load, so we copy the local spec into NeMo instead.
megatron_layer_local_spec_preln = ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=LNImpl,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.padding},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
self_attn_bda=get_bias_dropout_add,
post_att_layernorm=FusedLayerNorm,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=ColumnParallelLinear,
linear_fc2=RowParallelLinear,
),
),
mlp_bda=get_bias_dropout_add,
post_mlp_layernorm=FusedLayerNorm,
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=LNImpl,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear),
),
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map={
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
},
),
)
)
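
Stepping back, the move from module-level spec constants to get_bert_layer_*_postln() factories is what actually removes the import-time TE dependency: a module-level ModuleSpec evaluates TE-backed names (TENorm, TEColumnParallelLinear, ...) as soon as bert_spec.py is imported, while a function body resolves them only when called. A small self-contained illustration of the difference, using stand-in names rather than the real Megatron-LM/TE symbols:

# Stand-in illustration; TENormStandIn and get_bert_layer_spec_postln are
# hypothetical names, not the real Megatron-LM/TE symbols.
HAVE_TE = False           # pretend Transformer Engine failed to import
if HAVE_TE:
    class TENormStandIn:  # would normally come from transformer_engine
        pass

# Old pattern: a module-level spec touches the TE-backed name at import time,
# so merely importing this file on a TE-less machine raises NameError:
#
#     bert_layer_spec_postln = {"post_att_layernorm": TENormStandIn}

def get_bert_layer_spec_postln():
    # New pattern: the TE-backed name is looked up only when the factory runs,
    # so the module imports cleanly and any failure is deferred to call time.
    return {"post_att_layernorm": TENormStandIn}

if __name__ == "__main__":
    print("module imported fine without TE")
    try:
        get_bert_layer_spec_postln()
    except NameError as err:
        print(f"spec construction deferred to call time: {err}")

Combined with the HAVE_TE assert in base.py, this means a TE-less install can still import the NeMo BERT modules and only fails, with a clear message, when a BERT model is actually constructed.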
