Skip to content

Commit

Permalink
add e_norm_use_max_v
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Jan 20, 2025
1 parent 340c56a commit c1aa456
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 0 deletions.
2 changes: 2 additions & 0 deletions deepmd/dpmodel/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
skip_stat: bool = False,
a_compress_use_split: bool = False,
a_norm_use_max_v: bool = False,
e_norm_use_max_v: bool = False,
update_n_has_attn: bool = False,
n_attn_hidden: int = 64,
n_attn_head: int = 4,
Expand Down Expand Up @@ -119,6 +120,7 @@ def __init__(
self.n_attn_head = n_attn_head
self.pre_ln = pre_ln
self.a_norm_use_max_v = a_norm_use_max_v
self.e_norm_use_max_v = e_norm_use_max_v

def __getitem__(self, key):
if hasattr(self, key):
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def init_subclass_params(sub_data, sub_class):
n_attn_hidden=self.repflow_args.n_attn_hidden,
n_attn_head=self.repflow_args.n_attn_head,
a_norm_use_max_v=self.repflow_args.a_norm_use_max_v,
e_norm_use_max_v=self.repflow_args.e_norm_use_max_v,
h1_dim=self.repflow_args.h1_dim,
pre_ln=self.repflow_args.pre_ln,
skip_stat=self.repflow_args.skip_stat,
Expand Down
6 changes: 6 additions & 0 deletions deepmd/pt/model/descriptor/repflow_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
n_attn_hidden: int = 64,
n_attn_head: int = 4,
a_norm_use_max_v: bool = False,
e_norm_use_max_v: bool = False,
pre_ln: bool = False,
activation_function: str = "silu",
update_style: str = "res_residual",
Expand Down Expand Up @@ -125,6 +126,7 @@ def __init__(
self.prec = PRECISION_DICT[precision]
self.pre_ln = pre_ln
self.a_norm_use_max_v = a_norm_use_max_v
self.e_norm_use_max_v = e_norm_use_max_v

assert update_residual_init in [
"norm",
Expand Down Expand Up @@ -614,6 +616,10 @@ def forward(
if self.a_norm_use_max_v:
angle_ebd = angle_ebd / (angle_ebd.abs().max(-1)[0] + 1e-4).unsqueeze(-1)

# only norm edge with max absolute value
if self.e_norm_use_max_v:
edge_ebd = edge_ebd / (edge_ebd.abs().max(-1)[0] + 1e-4).unsqueeze(-1)

# node self mlp
node_self_mlp = self.act(self.node_self_mlp(node_ebd))
n_update_list.append(node_self_mlp)
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def __init__(
skip_stat: bool = True,
pre_ln: bool = False,
a_norm_use_max_v: bool = False,
e_norm_use_max_v: bool = False,
seed: Optional[Union[int, list[int]]] = None,
) -> None:
r"""
Expand Down Expand Up @@ -230,6 +231,7 @@ def __init__(
self.act = ActivationFn(activation_function)
self.prec = PRECISION_DICT[precision]
self.a_norm_use_max_v = a_norm_use_max_v
self.e_norm_use_max_v = e_norm_use_max_v

# order matters, placed after the assignment of self.ntypes
self.reinit_exclude(exclude_types)
Expand Down Expand Up @@ -297,6 +299,7 @@ def __init__(
n_attn_hidden=self.n_attn_hidden,
n_attn_head=self.n_attn_head,
a_norm_use_max_v=self.a_norm_use_max_v,
e_norm_use_max_v=self.e_norm_use_max_v,
activation_function=self.activation_function,
update_style=self.update_style,
update_residual=self.update_residual,
Expand Down
6 changes: 6 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1625,6 +1625,12 @@ def dpa3_repflow_args():
optional=True,
default=False,
),
Argument(
"e_norm_use_max_v",
bool,
optional=True,
default=False,
),
]


Expand Down

0 comments on commit c1aa456

Please sign in to comment.