diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index ed0eb5d919..f39172d979 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -129,10 +129,10 @@ def __init__( "const", ], "'update_residual_init' only support 'norm' or 'const'!" - if self.pre_ln: - assert self.update_style == "res_layer" + # if self.pre_ln: + # assert self.update_style == "res_layer" - if self.update_style == "res_layer": + if self.update_style == "res_layer" or self.pre_ln: self.node_layernorm = nn.LayerNorm( self.n_dim, device=env.DEVICE, diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py index bb5e5be96c..0e88f4b779 100644 --- a/deepmd/pt/model/descriptor/repflows.py +++ b/deepmd/pt/model/descriptor/repflows.py @@ -236,14 +236,14 @@ def __init__( self.epsilon = 1e-4 self.seed = seed self.pre_ln = pre_ln - self.out_ln = None - if self.pre_ln: - self.out_ln = torch.nn.LayerNorm( - self.n_dim, - device=env.DEVICE, - dtype=self.prec, - elementwise_affine=False, - ) + # self.out_ln = None + # if self.pre_ln: + # self.out_ln = torch.nn.LayerNorm( + # self.n_dim, + # device=env.DEVICE, + # dtype=self.prec, + # elementwise_affine=False, + # ) self.edge_embd = MLPLayer( 1, self.e_dim, precision=precision, seed=child_seed(seed, 0) @@ -600,9 +600,9 @@ def forward( h2g2 = RepFlowLayer._cal_hg(edge_ebd, h2, nlist_mask, sw) # (nb x nloc) x e_dim x 3 rot_mat = torch.permute(h2g2, (0, 1, 3, 2)) - if self.pre_ln: - assert self.out_ln is not None - node_ebd = self.out_ln(node_ebd) + # if self.pre_ln: + # assert self.out_ln is not None + # node_ebd = self.out_ln(node_ebd) return node_ebd, edge_ebd, h2, rot_mat.view(nframes, nloc, self.dim_emb, 3), sw