diff --git a/deepmd/dpmodel/descriptor/dpa3.py b/deepmd/dpmodel/descriptor/dpa3.py index 26b276d70f..789dff131d 100644 --- a/deepmd/dpmodel/descriptor/dpa3.py +++ b/deepmd/dpmodel/descriptor/dpa3.py @@ -40,6 +40,9 @@ def __init__( n_attn_head: int = 4, pre_ln: bool = False, only_e_ln: bool = False, + pre_bn: bool = False, + only_e_bn: bool = False, + bn_moment: float = 0.1, n_update_has_a: bool = False, n_update_has_a_first_sum: bool = False, ) -> None: @@ -124,6 +127,9 @@ def __init__( self.n_attn_head = n_attn_head self.pre_ln = pre_ln self.only_e_ln = only_e_ln + self.pre_bn = pre_bn + self.only_e_bn = only_e_bn + self.bn_moment = bn_moment self.a_norm_use_max_v = a_norm_use_max_v self.e_norm_use_max_v = e_norm_use_max_v self.e_a_reduce_use_sqrt = e_a_reduce_use_sqrt diff --git a/deepmd/pt/model/descriptor/dpa3.py b/deepmd/pt/model/descriptor/dpa3.py index 0cfd0f12b6..78f7ea2f7f 100644 --- a/deepmd/pt/model/descriptor/dpa3.py +++ b/deepmd/pt/model/descriptor/dpa3.py @@ -179,6 +179,9 @@ def init_subclass_params(sub_data, sub_class): h1_dim=self.repflow_args.h1_dim, pre_ln=self.repflow_args.pre_ln, only_e_ln=self.repflow_args.only_e_ln, + pre_bn=self.repflow_args.pre_bn, + only_e_bn=self.repflow_args.only_e_bn, + bn_moment=self.repflow_args.bn_moment, skip_stat=self.repflow_args.skip_stat, exclude_types=exclude_types, env_protection=env_protection, diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index 9c83aeb371..e9613e2c39 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -73,6 +73,9 @@ def __init__( n_update_has_a_first_sum: bool = False, pre_ln: bool = False, only_e_ln: bool = False, + pre_bn: bool = False, + only_e_bn: bool = False, + bn_moment: float = 0.1, activation_function: str = "silu", update_style: str = "res_residual", update_residual: float = 0.1, @@ -130,6 +133,9 @@ def __init__( self.prec = PRECISION_DICT[precision] self.pre_ln = pre_ln self.only_e_ln = only_e_ln + self.pre_bn = pre_bn + self.only_e_bn = only_e_bn + self.bn_moment = bn_moment self.a_norm_use_max_v = a_norm_use_max_v self.e_norm_use_max_v = e_norm_use_max_v self.e_a_reduce_use_sqrt = e_a_reduce_use_sqrt @@ -175,6 +181,41 @@ def __init__( self.angle_layernorm = None self.h1_layernorm = None + if self.pre_bn: + self.node_batchnorm = nn.BatchNorm1d( + self.n_dim, + affine=False, + device=env.DEVICE, + dtype=self.prec, + momentum=self.bn_moment, + ) + self.edge_batchnorm = nn.BatchNorm1d( + self.e_dim, + affine=False, + device=env.DEVICE, + dtype=self.prec, + momentum=self.bn_moment, + ) + self.angle_batchnorm = nn.BatchNorm1d( + self.a_dim, + affine=False, + device=env.DEVICE, + dtype=self.prec, + momentum=self.bn_moment, + ) + self.h1_batchnorm = nn.BatchNorm1d( + self.h1_dim, + affine=False, + device=env.DEVICE, + dtype=self.prec, + momentum=self.bn_moment, + ) + else: + self.node_batchnorm = None + self.edge_batchnorm = None + self.angle_batchnorm = None + self.h1_batchnorm = None + self.update_residual = update_residual self.update_residual_init = update_residual_init self.n_residual = [] @@ -642,6 +683,22 @@ def forward( angle_ebd = self.angle_layernorm(angle_ebd) edge_ebd = self.edge_layernorm(edge_ebd) + if self.pre_bn: + assert self.node_batchnorm is not None + assert self.edge_batchnorm is not None + assert self.angle_batchnorm is not None + if not self.only_e_bn: + node_ebd_ext = self.node_batchnorm( + node_ebd_ext.view(nb * nall, self.n_dim) + ).view(nb, nall, self.n_dim) + node_ebd, _ = torch.split(node_ebd_ext, [nloc, nall - nloc], dim=1) + angle_ebd = self.angle_batchnorm( + angle_ebd.view(nb * nloc * self.a_sel * self.a_sel, self.a_dim) + ).view(nb, nloc, self.a_sel, self.a_sel, self.a_dim) + edge_ebd = self.edge_batchnorm( + edge_ebd.view(nb * nloc * self.nnei, self.e_dim) + ).view(nb, nloc, self.nnei, self.e_dim) + # only norm angle with max absolute value if self.a_norm_use_max_v: angle_ebd = angle_ebd / (angle_ebd.abs().max(-1)[0] + 1e-5).unsqueeze(-1) diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py index bf6e8b124f..0a8d3e42e5 100644 --- a/deepmd/pt/model/descriptor/repflows.py +++ b/deepmd/pt/model/descriptor/repflows.py @@ -113,6 +113,9 @@ def __init__( skip_stat: bool = True, pre_ln: bool = False, only_e_ln: bool = False, + pre_bn: bool = False, + only_e_bn: bool = False, + bn_moment: float = 0.1, a_norm_use_max_v: bool = False, e_norm_use_max_v: bool = False, e_a_reduce_use_sqrt: bool = True, @@ -262,14 +265,17 @@ def __init__( self.seed = seed self.pre_ln = pre_ln self.only_e_ln = only_e_ln - self.out_ln = None - if self.pre_ln: - self.out_ln = torch.nn.LayerNorm( - self.n_dim, - device=env.DEVICE, - dtype=self.prec, - elementwise_affine=False, - ) + self.pre_bn = pre_bn + self.only_e_bn = only_e_bn + self.bn_moment = bn_moment + # self.out_ln = None + # if self.pre_ln: + # self.out_ln = torch.nn.LayerNorm( + # self.n_dim, + # device=env.DEVICE, + # dtype=self.prec, + # elementwise_affine=False, + # ) self.edge_embd = MLPLayer( 1, self.e_dim, precision=precision, seed=child_seed(seed, 0) @@ -334,6 +340,9 @@ def __init__( precision=precision, pre_ln=self.pre_ln, only_e_ln=self.only_e_ln, + pre_bn=self.pre_bn, + only_e_bn=self.only_e_bn, + bn_moment=self.bn_moment, seed=child_seed(child_seed(seed, 1), ii), ) ) @@ -634,9 +643,9 @@ def forward( h2g2 = RepFlowLayer._cal_hg(edge_ebd, h2, nlist_mask, sw) # (nb x nloc) x e_dim x 3 rot_mat = torch.permute(h2g2, (0, 1, 3, 2)) - if self.pre_ln: - assert self.out_ln is not None - node_ebd = self.out_ln(node_ebd) + # if self.pre_ln: + # assert self.out_ln is not None + # node_ebd = self.out_ln(node_ebd) return node_ebd, edge_ebd, h2, rot_mat.view(nframes, nloc, self.dim_emb, 3), sw diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index a163eb7f2d..20e96e5359 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1625,6 +1625,24 @@ def dpa3_repflow_args(): optional=True, default=False, ), + Argument( + "pre_bn", + bool, + optional=True, + default=False, + ), + Argument( + "only_e_bn", + bool, + optional=True, + default=False, + ), + Argument( + "bn_moment", + float, + optional=True, + default=0.1, + ), Argument( "a_norm_use_max_v", bool,