Skip to content

Commit

Permalink
add unet
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Jan 27, 2025
1 parent cb7f995 commit 04f03ae
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 4 deletions.
8 changes: 8 additions & 0 deletions deepmd/dpmodel/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def __init__(
only_e_ln: bool = False,
pre_bn: bool = False,
only_e_bn: bool = False,
use_unet: bool = False,
use_unet_n: bool = True,
use_unet_e: bool = True,
use_unet_a: bool = True,
bn_moment: float = 0.1,
n_update_has_a: bool = False,
n_update_has_a_first_sum: bool = False,
Expand Down Expand Up @@ -135,6 +139,10 @@ def __init__(
self.e_a_reduce_use_sqrt = e_a_reduce_use_sqrt
self.n_update_has_a = n_update_has_a
self.n_update_has_a_first_sum = n_update_has_a_first_sum
self.use_unet = use_unet
self.use_unet_n = use_unet_n
self.use_unet_e = use_unet_e
self.use_unet_a = use_unet_a

def __getitem__(self, key):
if hasattr(self, key):
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ def init_subclass_params(sub_data, sub_class):
only_e_ln=self.repflow_args.only_e_ln,
pre_bn=self.repflow_args.pre_bn,
only_e_bn=self.repflow_args.only_e_bn,
use_unet=self.repflow_args.use_unet,
use_unet_n=self.repflow_args.use_unet_n,
use_unet_e=self.repflow_args.use_unet_e,
use_unet_a=self.repflow_args.use_unet_a,
bn_moment=self.repflow_args.bn_moment,
skip_stat=self.repflow_args.skip_stat,
exclude_types=exclude_types,
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/descriptor/repflow_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,9 @@ def __init__(
seed=child_seed(seed, 9),
)
else:
# use split
assert self.n_a_compress_dim <= self.n_dim
assert self.e_a_compress_dim <= self.e_dim
self.a_compress_n_linear = None
self.a_compress_e_linear = None

Expand Down
69 changes: 65 additions & 4 deletions deepmd/pt/model/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ def __init__(
only_e_ln: bool = False,
pre_bn: bool = False,
only_e_bn: bool = False,
use_unet: bool = False,
use_unet_n: bool = True,
use_unet_e: bool = True,
use_unet_a: bool = True,
bn_moment: float = 0.1,
a_norm_use_max_v: bool = False,
e_norm_use_max_v: bool = False,
Expand Down Expand Up @@ -268,6 +272,10 @@ def __init__(
self.pre_bn = pre_bn
self.only_e_bn = only_e_bn
self.bn_moment = bn_moment
self.use_unet = use_unet
self.use_unet_n = use_unet_n
self.use_unet_e = use_unet_e
self.use_unet_a = use_unet_a
# self.out_ln = None
# if self.pre_ln:
# self.out_ln = torch.nn.LayerNorm(
Expand Down Expand Up @@ -296,6 +304,15 @@ def __init__(
else:
self.h1_embd = None
layers = []
self.unet_scale = [1.0 for _ in range(self.nlayers)]
self.unet_first_half = int((self.nlayers + 1) / 2)
self.unet_rest_half = int(self.nlayers / 2)
if self.use_unet:
self.unet_scale = [(0.5**i) for i in range(self.unet_first_half)] + [
(0.5 ** (self.unet_rest_half - 1 - i))
for i in range(self.unet_rest_half)
]

for ii in range(nlayers):
layers.append(
RepFlowLayer(
Expand All @@ -306,9 +323,15 @@ def __init__(
a_rcut_smth=self.a_rcut_smth,
a_sel=self.a_sel,
ntypes=self.ntypes,
n_dim=self.n_dim,
e_dim=self.e_dim,
a_dim=self.a_dim,
n_dim=self.n_dim
if (not self.use_unet or not self.use_unet_n)
else int(self.n_dim * self.unet_scale[ii]),
e_dim=self.e_dim
if (not self.use_unet or not self.use_unet_e)
else int(self.e_dim * self.unet_scale[ii]),
a_dim=self.a_dim
if (not self.use_unet or not self.use_unet_a)
else int(self.a_dim * self.unet_scale[ii]),
a_compress_rate=self.a_compress_rate,
a_mess_has_n=self.a_mess_has_n,
a_use_e_mess=self.a_use_e_mess,
Expand Down Expand Up @@ -555,13 +578,19 @@ def forward(
else:
mapping3 = None

unet_list_node = []
unet_list_edge = []
unet_list_angle = []

for idx, ll in enumerate(self.layers):
# node_ebd: nb x nloc x n_dim
# node_ebd_ext: nb x nall x n_dim
if comm_dict is None:
assert mapping is not None
assert mapping3 is not None
node_ebd_ext = torch.gather(node_ebd, 1, mapping)
node_ebd_ext = torch.gather(
node_ebd, 1, mapping[:, :, : node_ebd.shape[-1]]
)
if self.has_h1:
assert h1 is not None
h1_ext = torch.gather(h1, 1, mapping3)
Expand Down Expand Up @@ -639,6 +668,38 @@ def forward(
h1_ext,
)

if self.use_unet:
if idx < self.unet_first_half - 1:
# stack half output
tmp_n_dim = int(self.n_dim * self.unet_scale[idx + 1])
tmp_e_dim = int(self.e_dim * self.unet_scale[idx + 1])
tmp_a_dim = int(self.a_dim * self.unet_scale[idx + 1])
if self.use_unet_n:
stack_node_ebd, node_ebd = torch.split(
node_ebd, [tmp_n_dim, tmp_n_dim], dim=-1
)
unet_list_node.append(stack_node_ebd)
if self.use_unet_e:
stack_edge_ebd, edge_ebd = torch.split(
edge_ebd, [tmp_e_dim, tmp_e_dim], dim=-1
)
unet_list_edge.append(stack_edge_ebd)
if self.use_unet_a:
stack_angle_ebd, angle_ebd = torch.split(
angle_ebd, [tmp_a_dim, tmp_a_dim], dim=-1
)
unet_list_angle.append(stack_angle_ebd)
elif self.unet_rest_half - 1 < idx < self.nlayers - 1:
# skip connection, concat the half output
if self.use_unet_n:
stack_node_ebd = unet_list_node.pop()
node_ebd = torch.cat([stack_node_ebd, node_ebd], dim=-1)
if self.use_unet_e:
stack_edge_ebd = unet_list_edge.pop()
edge_ebd = torch.cat([stack_edge_ebd, edge_ebd], dim=-1)
if self.use_unet_a:
stack_angle_ebd = unet_list_angle.pop()
angle_ebd = torch.cat([stack_angle_ebd, angle_ebd], dim=-1)
# nb x nloc x 3 x e_dim
h2g2 = RepFlowLayer._cal_hg(edge_ebd, h2, nlist_mask, sw)
# (nb x nloc) x e_dim x 3
Expand Down
24 changes: 24 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1673,6 +1673,30 @@ def dpa3_repflow_args():
optional=True,
default=False,
),
Argument(
"use_unet",
bool,
optional=True,
default=False,
),
Argument(
"use_unet_n",
bool,
optional=True,
default=True,
),
Argument(
"use_unet_e",
bool,
optional=True,
default=True,
),
Argument(
"use_unet_a",
bool,
optional=True,
default=True,
),
]


Expand Down

0 comments on commit 04f03ae

Please sign in to comment.