Skip to content

Commit

Permalink
add bn
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Jan 25, 2025
1 parent 1fa5a07 commit 0d17776
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 11 deletions.
6 changes: 6 additions & 0 deletions deepmd/dpmodel/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ def __init__(
n_attn_head: int = 4,
pre_ln: bool = False,
only_e_ln: bool = False,
pre_bn: bool = False,
only_e_bn: bool = False,
bn_moment: float = 0.1,
n_update_has_a: bool = False,
n_update_has_a_first_sum: bool = False,
) -> None:
Expand Down Expand Up @@ -124,6 +127,9 @@ def __init__(
self.n_attn_head = n_attn_head
self.pre_ln = pre_ln
self.only_e_ln = only_e_ln
self.pre_bn = pre_bn
self.only_e_bn = only_e_bn
self.bn_moment = bn_moment
self.a_norm_use_max_v = a_norm_use_max_v
self.e_norm_use_max_v = e_norm_use_max_v
self.e_a_reduce_use_sqrt = e_a_reduce_use_sqrt
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ def init_subclass_params(sub_data, sub_class):
h1_dim=self.repflow_args.h1_dim,
pre_ln=self.repflow_args.pre_ln,
only_e_ln=self.repflow_args.only_e_ln,
pre_bn=self.repflow_args.pre_bn,
only_e_bn=self.repflow_args.only_e_bn,
bn_moment=self.repflow_args.bn_moment,
skip_stat=self.repflow_args.skip_stat,
exclude_types=exclude_types,
env_protection=env_protection,
Expand Down
57 changes: 57 additions & 0 deletions deepmd/pt/model/descriptor/repflow_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def __init__(
n_update_has_a_first_sum: bool = False,
pre_ln: bool = False,
only_e_ln: bool = False,
pre_bn: bool = False,
only_e_bn: bool = False,
bn_moment: float = 0.1,
activation_function: str = "silu",
update_style: str = "res_residual",
update_residual: float = 0.1,
Expand Down Expand Up @@ -130,6 +133,9 @@ def __init__(
self.prec = PRECISION_DICT[precision]
self.pre_ln = pre_ln
self.only_e_ln = only_e_ln
self.pre_bn = pre_bn
self.only_e_bn = only_e_bn
self.bn_moment = bn_moment
self.a_norm_use_max_v = a_norm_use_max_v
self.e_norm_use_max_v = e_norm_use_max_v
self.e_a_reduce_use_sqrt = e_a_reduce_use_sqrt
Expand Down Expand Up @@ -175,6 +181,41 @@ def __init__(
self.angle_layernorm = None
self.h1_layernorm = None

if self.pre_bn:
self.node_batchnorm = nn.BatchNorm1d(
self.n_dim,
affine=False,
device=env.DEVICE,
dtype=self.prec,
momentum=self.bn_moment,
)
self.edge_batchnorm = nn.BatchNorm1d(
self.e_dim,
affine=False,
device=env.DEVICE,
dtype=self.prec,
momentum=self.bn_moment,
)
self.angle_batchnorm = nn.BatchNorm1d(
self.a_dim,
affine=False,
device=env.DEVICE,
dtype=self.prec,
momentum=self.bn_moment,
)
self.h1_batchnorm = nn.BatchNorm1d(
self.h1_dim,
affine=False,
device=env.DEVICE,
dtype=self.prec,
momentum=self.bn_moment,
)
else:
self.node_batchnorm = None
self.edge_batchnorm = None
self.angle_batchnorm = None
self.h1_batchnorm = None

self.update_residual = update_residual
self.update_residual_init = update_residual_init
self.n_residual = []
Expand Down Expand Up @@ -642,6 +683,22 @@ def forward(
angle_ebd = self.angle_layernorm(angle_ebd)
edge_ebd = self.edge_layernorm(edge_ebd)

if self.pre_bn:
assert self.node_batchnorm is not None
assert self.edge_batchnorm is not None
assert self.angle_batchnorm is not None
if not self.only_e_bn:
node_ebd_ext = self.node_batchnorm(
node_ebd_ext.view(nb * nall, self.n_dim)
).view(nb, nall, self.n_dim)
node_ebd, _ = torch.split(node_ebd_ext, [nloc, nall - nloc], dim=1)
angle_ebd = self.angle_batchnorm(
angle_ebd.view(nb * nloc * self.a_sel * self.a_sel, self.a_dim)
).view(nb, nloc, self.a_sel, self.a_sel, self.a_dim)
edge_ebd = self.edge_batchnorm(
edge_ebd.view(nb * nloc * self.nnei, self.e_dim)
).view(nb, nloc, self.nnei, self.e_dim)

# only norm angle with max absolute value
if self.a_norm_use_max_v:
angle_ebd = angle_ebd / (angle_ebd.abs().max(-1)[0] + 1e-5).unsqueeze(-1)
Expand Down
31 changes: 20 additions & 11 deletions deepmd/pt/model/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ def __init__(
skip_stat: bool = True,
pre_ln: bool = False,
only_e_ln: bool = False,
pre_bn: bool = False,
only_e_bn: bool = False,
bn_moment: float = 0.1,
a_norm_use_max_v: bool = False,
e_norm_use_max_v: bool = False,
e_a_reduce_use_sqrt: bool = True,
Expand Down Expand Up @@ -262,14 +265,17 @@ def __init__(
self.seed = seed
self.pre_ln = pre_ln
self.only_e_ln = only_e_ln
self.out_ln = None
if self.pre_ln:
self.out_ln = torch.nn.LayerNorm(
self.n_dim,
device=env.DEVICE,
dtype=self.prec,
elementwise_affine=False,
)
self.pre_bn = pre_bn
self.only_e_bn = only_e_bn
self.bn_moment = bn_moment
# self.out_ln = None
# if self.pre_ln:
# self.out_ln = torch.nn.LayerNorm(
# self.n_dim,
# device=env.DEVICE,
# dtype=self.prec,
# elementwise_affine=False,
# )

self.edge_embd = MLPLayer(
1, self.e_dim, precision=precision, seed=child_seed(seed, 0)
Expand Down Expand Up @@ -334,6 +340,9 @@ def __init__(
precision=precision,
pre_ln=self.pre_ln,
only_e_ln=self.only_e_ln,
pre_bn=self.pre_bn,
only_e_bn=self.only_e_bn,
bn_moment=self.bn_moment,
seed=child_seed(child_seed(seed, 1), ii),
)
)
Expand Down Expand Up @@ -634,9 +643,9 @@ def forward(
h2g2 = RepFlowLayer._cal_hg(edge_ebd, h2, nlist_mask, sw)
# (nb x nloc) x e_dim x 3
rot_mat = torch.permute(h2g2, (0, 1, 3, 2))
if self.pre_ln:
assert self.out_ln is not None
node_ebd = self.out_ln(node_ebd)
# if self.pre_ln:
# assert self.out_ln is not None
# node_ebd = self.out_ln(node_ebd)

return node_ebd, edge_ebd, h2, rot_mat.view(nframes, nloc, self.dim_emb, 3), sw

Expand Down
18 changes: 18 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1625,6 +1625,24 @@ def dpa3_repflow_args():
optional=True,
default=False,
),
Argument(
"pre_bn",
bool,
optional=True,
default=False,
),
Argument(
"only_e_bn",
bool,
optional=True,
default=False,
),
Argument(
"bn_moment",
float,
optional=True,
default=0.1,
),
Argument(
"a_norm_use_max_v",
bool,
Expand Down

0 comments on commit 0d17776

Please sign in to comment.