From a9bb0be8cdf690e18ac9f1b2603ffe7736e93953 Mon Sep 17 00:00:00 2001
From: xke123502
Date: Wed, 17 Sep 2025 14:36:45 +0000
Subject: [PATCH] add MAD regularization for DPA3

---
 deepmd/pt/loss/__init__.py                  |   2 +
 deepmd/pt/loss/ener.py                      | 152 +++++
 deepmd/pt/model/descriptor/dpa3.py          |  85 ++-
 deepmd/pt/model/descriptor/env_mat.py       |   6 +-
 deepmd/pt/model/descriptor/repflow_layer.py | 706 ++++++++++++++------
 deepmd/pt/model/descriptor/repflows.py      | 312 ++++++---
 deepmd/pt/model/model/__init__.py           |   8 +-
 deepmd/pt/model/network/mlp.py              | 255 +++++--
 deepmd/pt/model/network/utils.py            | 171 +++--
 deepmd/pt/train/training.py                 |  11 +-
 deepmd/utils/argcheck.py                    | 187 +++++-
 examples/water/dpa3/input_torch.json        |   2 +-
 12 files changed, 1468 insertions(+), 429 deletions(-)

diff --git a/deepmd/pt/loss/__init__.py b/deepmd/pt/loss/__init__.py
index 1d25c1e52f..3111834117 100644
--- a/deepmd/pt/loss/__init__.py
+++ b/deepmd/pt/loss/__init__.py
@@ -8,6 +8,7 @@
 from .ener import (
     EnergyHessianStdLoss,
     EnergyStdLoss,
+    EnergyStdLossMAD,
 )
 from .ener_spin import (
     EnergySpinLoss,
@@ -28,6 +29,7 @@
     "EnergyHessianStdLoss",
     "EnergySpinLoss",
     "EnergyStdLoss",
+    "EnergyStdLossMAD",
     "PropertyLoss",
     "TaskLoss",
     "TensorLoss",
diff --git a/deepmd/pt/loss/ener.py b/deepmd/pt/loss/ener.py
index 10e2bf9971..64d163755b 100644
--- a/deepmd/pt/loss/ener.py
+++ b/deepmd/pt/loss/ener.py
@@ -595,3 +595,155 @@ def label_requirement(self) -> list[DataRequirementItem]:
             )
         )
         return label_requirement
+
+
+# newly added: energy loss with MAD regularization
+class EnergyStdLossMAD(EnergyStdLoss):
+    def __init__(
+        self,
+        starter_learning_rate=1.0,
+        start_pref_e=0.0,
+        limit_pref_e=0.0,
+        start_pref_f=0.0,
+        limit_pref_f=0.0,
+        start_pref_v=0.0,
+        limit_pref_v=0.0,
+        start_pref_ae: float = 0.0,
+        limit_pref_ae: float = 0.0,
+        start_pref_pf: float = 0.0,
+        limit_pref_pf: float = 0.0,
+        relative_f: Optional[float] = None,
+        enable_atom_ener_coeff: bool = False,
+        start_pref_gf: float = 0.0,
+        limit_pref_gf: float = 0.0,
+        numb_generalized_coord: int = 0,
+        use_l1_all: bool = False,
+        inference=False,
+        use_huber=False,
+        huber_delta=0.01,
+        mad_reg_coeff: float = 0.0,  # newly added
+        **kwargs,
+    ) -> None:
+        r"""Construct a layer to compute loss on energy, force and virial with MAD regularization.
+
+        Parameters
+        ----------
+        mad_reg_coeff : float
+            The coefficient for MAD (Mean Average Distance) regularization. Set to 0.0 to disable MAD regularization.
+        **kwargs
+            Other keyword arguments passed to EnergyStdLoss.
+        """
+        super().__init__(
+            starter_learning_rate=starter_learning_rate,
+            start_pref_e=start_pref_e,
+            limit_pref_e=limit_pref_e,
+            start_pref_f=start_pref_f,
+            limit_pref_f=limit_pref_f,
+            start_pref_v=start_pref_v,
+            limit_pref_v=limit_pref_v,
+            start_pref_ae=start_pref_ae,
+            limit_pref_ae=limit_pref_ae,
+            start_pref_pf=start_pref_pf,
+            limit_pref_pf=limit_pref_pf,
+            relative_f=relative_f,
+            enable_atom_ener_coeff=enable_atom_ener_coeff,
+            start_pref_gf=start_pref_gf,
+            limit_pref_gf=limit_pref_gf,
+            numb_generalized_coord=numb_generalized_coord,
+            use_l1_all=use_l1_all,
+            inference=inference,
+            use_huber=use_huber,
+            huber_delta=huber_delta,
+            **kwargs,
+        )
+        self.mad_reg_coeff = mad_reg_coeff
+
+    def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
+        """Return loss on energy and force with MAD regularization.
+
+        Parameters
+        ----------
+        input_dict : dict[str, torch.Tensor]
+            Model inputs.
+        model : torch.nn.Module
+            Model to be used to output the predictions.
+        label : dict[str, torch.Tensor]
+            Labels.
+        natoms : int
+            The local atom number.
+
+        Returns
+        -------
+        model_pred: dict[str, torch.Tensor]
+            Model predictions.
+        loss: torch.Tensor
+            Loss for model to minimize.
+        more_loss: dict[str, torch.Tensor]
+            Other losses for display.
+        """
+        # call the parent class to obtain the base loss
+        model_pred, loss, more_loss = super().forward(
+            input_dict, model, label, natoms, learning_rate, mae=mae
+        )
+        # after the base loss is obtained, add the MAD regularization term
+        if self.mad_reg_coeff > 0:
+            descriptor = model.get_descriptor()
+            # the attribute is still named last_mad_gap, but it now stores the MAD value
+            mad_value = descriptor.last_mad_gap
+            # guard: MAD regularization requires enable_mad=True on the descriptor
+            assert mad_value is not None, "mad_reg_coeff > 0 requires enable_mad=True in the descriptor"
+
+            # target MAD value: a cosine distance of 1 means orthogonal embeddings,
+            # which is the desired (non-over-smoothed) state
+            target_mad = 1.0
+
+            # option 1 (recommended): target-MAD regularization, encouraging the
+            # MAD to stay close to the target value
+            mad_reg_loss = self.mad_reg_coeff * torch.abs(mad_value - target_mad)
+
+            # option 2 (optional): penalize only overly small MAD values to
+            # prevent over-smoothing
+            # min_mad = 0.5
+            # mad_reg_loss = self.mad_reg_coeff * torch.relu(min_mad - mad_value)
+
+            loss += mad_reg_loss
+
+            # always record the MAD-related terms (needed for both training and validation)
+            more_loss["mad_reg_loss"] = self.display_if_exist(
+                mad_reg_loss.detach(), 1.0
+            )
+            more_loss["mad_value"] = self.display_if_exist(
+                mad_value.detach(), 1.0
+            )
+
+        return model_pred, loss, more_loss
+
+    def serialize(self) -> dict:
+        """Serialize the loss module.
+
+        Returns
+        -------
+        dict
+            The serialized loss module.
+        """
+        data = super().serialize()
+        data.update({
+            "@class": "EnergyLossMAD",
+            "mad_reg_coeff": self.mad_reg_coeff,
+        })
+        return data
+
+    @classmethod
+    def deserialize(cls, data: dict) -> "TaskLoss":
+        """Deserialize the loss module.
+
+        Parameters
+        ----------
+        data : dict
+            The serialized loss module.
+
+        Returns
+        -------
+        TaskLoss
+            The deserialized loss module.
+        """
+        data = data.copy()
+        check_version_compatibility(data.pop("@version"), 2, 1)
+        data.pop("@class")
+        return cls(**data)
\ No newline at end of file
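The regularizer added above boils down to a one-line penalty on the descriptor's MAD value. A minimal standalone sketch, assuming a scalar MAD tensor is available (coeff plays the role of mad_reg_coeff; names here are illustrative, not part of the patch):

    import torch

    def mad_penalty(mad_value: torch.Tensor, coeff: float, target: float = 1.0) -> torch.Tensor:
        # |MAD - target|: zero when the embeddings are, on average, orthogonal
        return coeff * torch.abs(mad_value - target)

    # e.g. total_loss = base_loss + mad_penalty(descriptor.last_mad_gap, coeff=0.01)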
diff --git a/deepmd/pt/model/descriptor/dpa3.py b/deepmd/pt/model/descriptor/dpa3.py
index dd2da9a3c8..c3cae56841 100644
--- a/deepmd/pt/model/descriptor/dpa3.py
+++ b/deepmd/pt/model/descriptor/dpa3.py
@@ -94,6 +94,10 @@ class DescrptDPA3(BaseDescriptor, torch.nn.Module):
         When True, local indexing and mapping are applied to neighbor lists and embeddings during descriptor computation.
     type_map : list[str], Optional
         A list of strings. Give the name to each type of atoms.
+    enable_mad : bool, Optional
+        Whether to enable MAD (Mean Average Distance) computation. Set to True to compute MAD values for regularization use.
+    mad_cutoff_ratio : float, Optional
+        The ratio to distinguish neighbor and remote nodes for MAD calculation.
+        (Reserved for future extensions.)

     References
     ----------
@@ -119,6 +123,8 @@ def __init__(
         use_tebd_bias: bool = False,
         use_loc_mapping: bool = True,
         type_map: Optional[list[str]] = None,
+        enable_mad: bool = False,  # newly added
+        mad_cutoff_ratio: float = 0.5,  # newly added (reserved for future extensions)
     ) -> None:
         super().__init__()
@@ -134,7 +140,7 @@ def init_subclass_params(sub_data, sub_class):
         self.repflow_args = init_subclass_params(repflow, RepFlowArgs)
         self.activation_function = activation_function
-
+        # the repflow descriptor block is constructed here
         self.repflows = DescrptBlockRepflows(
             self.repflow_args.e_rcut,
             self.repflow_args.e_rcut_smth,
@@ -211,6 +217,12 @@ def init_subclass_params(sub_data, sub_class):
                 param.requires_grad = trainable
         self.compress = False

+        # storage of the MAD-related parameters
+        self.enable_mad = enable_mad
+        self.mad_cutoff_ratio = mad_cutoff_ratio
+        # store the latest MAD value for the loss function
+        # (the name last_mad_gap is kept for compatibility with the loss)
+        self.last_mad_gap = None
+
     def get_rcut(self) -> float:
         """Returns the cut-off radius."""
         return self.rcut
@@ -392,6 +404,8 @@ def serialize(self) -> dict:
             "use_loc_mapping": self.use_loc_mapping,
             "type_map": self.type_map,
             "type_embedding": self.type_embedding.embedding.serialize(),
+            "enable_mad": self.enable_mad,  # newly added
+            "mad_cutoff_ratio": self.mad_cutoff_ratio,
         }
         repflow_variable = {
             "edge_embd": repflows.edge_embd.serialize(),
@@ -492,8 +506,8 @@ def forward(
         if not parallel_mode and self.use_loc_mapping:
             node_ebd_ext = self.type_embedding(extended_atype[:, :nloc])
         else:
-            node_ebd_ext = self.type_embedding(extended_atype)
-        node_ebd_inp = node_ebd_ext[:, :nloc, :]
+            node_ebd_ext = self.type_embedding(extended_atype)  # node type embedding, [nf, nall, tebd_dim]
+        node_ebd_inp = node_ebd_ext[:, :nloc, :]  # initial type embedding, [nframes, nloc, n_dim]
         # repflows
         node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows(
             nlist,
@@ -503,16 +517,67 @@ def forward(
             mapping,
             comm_dict=comm_dict,
         )
-        if self.concat_output_tebd:
-            node_ebd = torch.cat([node_ebd, node_ebd_inp], dim=-1)
+        if self.concat_output_tebd:  # whether to concatenate the initial type embedding to the output
+            # Keep the raw type information: this ensures the initial atom-type
+            # information is not completely lost across the stacked RepFlow layers.
+            # Providing both the raw type features and the environment-learned
+            # features is a residual-connection idea, similar to the skip
+            # connections in ResNet.
+            node_ebd = torch.cat([node_ebd, node_ebd_inp], dim=-1)
+
+        # MAD computation (always performed when enabled, not only during training)
+        if self.enable_mad:
+            self.last_mad_gap = self._compute_mad(node_ebd)
+        else:
+            self.last_mad_gap = None
         return (
-            node_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
-            rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
-            edge_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
-            h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
-            sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
+            node_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),  # 1. node embedding, [nframes, nloc, n_dim]
+            rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),  # 2. rotationally equivariant matrix, [nframes, nloc, e_dim, 3]
+            edge_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),  # 3. edge embedding, [nframes, nloc, nnei, e_dim]
+            h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),  # 4. direction vectors, [nframes, nloc, nnei, 3]
+            sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),  # 5. smooth switch function, [nframes, nloc, nnei]
         )
+    # the simplified _compute_mad helper added below
+    def _compute_mad(self, node_ebd: torch.Tensor) -> torch.Tensor:
+        """Compute the basic MAD (Mean Average Distance) used for regularization.
+
+        MAD measures the average cosine distance between node embeddings:
+        cosine distance = 1 - cosine similarity = 1 - (Hi . Hj) / (|Hi| * |Hj|)
+
+        Parameters
+        ----------
+        node_ebd : torch.Tensor
+            Node embeddings, with shape [nframes, nloc, embed_dim].
+
+        Returns
+        -------
+        torch.Tensor
+            The average cosine distance over all node pairs.
+        """
+        import torch.nn.functional as F
+
+        nframes, nloc, embed_dim = node_ebd.shape
+        if nloc <= 1:
+            return torch.tensor(0.0, device=node_ebd.device)
+        node_ebd_norm = F.normalize(node_ebd, p=2, dim=-1)  # [nf, nloc, embed_dim]
+
+        # cosine similarity matrix
+        cosine_sim = torch.bmm(node_ebd_norm, node_ebd_norm.transpose(-1, -2))
+
+        # cosine distance = 1 - cosine similarity
+        # (dissimilar embeddings -> similarity ~ 0 -> distance ~ 1)
+        cosine_dist = 1.0 - cosine_sim
+
+        # The diagonal (the distance of a node to itself) is exactly zero, so summing
+        # the full matrix and dividing by nframes * nloc * (nloc - 1) averages over
+        # all off-diagonal node pairs without an explicit mask.
+        mad_global = cosine_dist.sum() / (nframes * nloc * (nloc - 1))
+
+        return mad_global
+
     @classmethod
     def update_sel(
         cls,
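As a sanity check for _compute_mad, the same quantity can be computed with an explicit off-diagonal mask. A minimal reference sketch in plain PyTorch (toy shapes, not part of the patch):

    import torch
    import torch.nn.functional as F

    def mad_reference(node_ebd: torch.Tensor) -> torch.Tensor:
        """Mean cosine distance over all ordered pairs i != j (O(nloc^2) reference)."""
        nf, nloc, _ = node_ebd.shape
        h = F.normalize(node_ebd, p=2, dim=-1)
        dist = 1.0 - torch.bmm(h, h.transpose(-1, -2))  # [nf, nloc, nloc]
        off_diag = ~torch.eye(nloc, dtype=torch.bool).unsqueeze(0).expand(nf, -1, -1)
        return dist[off_diag].mean()

    # agrees with _compute_mad because the diagonal of `dist` is exactly zero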
diff --git a/deepmd/pt/model/descriptor/env_mat.py b/deepmd/pt/model/descriptor/env_mat.py
index c57ae209fd..9c721dc70d 100644
--- a/deepmd/pt/model/descriptor/env_mat.py
+++ b/deepmd/pt/model/descriptor/env_mat.py
@@ -29,7 +29,7 @@ def _make_env_mat(
     coord_pad = torch.concat([coord, coord[:, -1:, :] + rcut], dim=1)
     coord_r = torch.gather(coord_pad, 1, index)
     coord_r = coord_r.view(bsz, natoms, nnei, 3)
-    diff = coord_r - coord_l
+    diff = coord_r - coord_l  # relative displacement
     length = torch.linalg.norm(diff, dim=-1, keepdim=True)
     # for index 0 nloc atom
     length = length + ~mask.unsqueeze(-1)
@@ -39,8 +39,8 @@ def _make_env_mat(
         compute_smooth_weight(length, ruct_smth, rcut)
         if not use_exp_switch
         else compute_exp_sw(length, ruct_smth, rcut)
-    )
-    weight = weight * mask.unsqueeze(-1)
+    )  # switch-weight computation
+    weight = weight * mask.unsqueeze(-1)  # apply the neighbor mask to the weights
     if radial_only:
         env_mat = t0 * weight
     else:
diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py
index 37d4f07bb4..3e05e263b0 100644
--- a/deepmd/pt/model/descriptor/repflow_layer.py
+++ b/deepmd/pt/model/descriptor/repflow_layer.py
@@ -1,4 +1,34 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+"""
+Implementation of the RepFlow layer.
+
+RepFlow (Representation Flow) is the core message-passing layer of the DPA3 model. It implements:
+1. updates of the node, edge and angle representations;
+2. rotationally equivariant and permutation-invariant message passing;
+3. modeling of many-body interactions;
+4. preservation of the physical constraints.
+
+This module contains the RepFlowLayer class, the basic building block of the DPA3 descriptor.
+
+Notes on the residual-connection strategies:
+A RepFlow layer supports three residual-connection strategies, selected by the update_style parameter:
+
+1. "res_avg" (residual average):
+   u = (u_0 + u_1 + u_2 + ... + u_n) / sqrt(n+1)
+   - all update terms are summed with equal weight
+   - simple and stable; suitable when the update terms are of similar importance
+
+2. "res_incr" (residual increment):
+   u = u_0 + (u_1 + u_2 + ... + u_n) / sqrt(n)
+   - the original representation keeps its full weight; the updates act as increments
+   - closest to the ResNet idea; suitable when the updates are corrections
+
+3. "res_residual" (residual weights, default):
+   u = u_0 + r_1*u_1 + r_2*u_2 + ... + r_n*u_n
+   - every update term has its own learnable weight
+   - most flexible; the model learns the weight assignment automatically
+   - needs more parameters; suitable for complex tasks
+"""
 from typing import (
     Optional,
     Union,
@@ -11,59 +41,68 @@
     child_seed,
 )
 from deepmd.pt.model.descriptor.repformer_layer import (
-    _apply_nlist_mask,
-    _apply_switch,
-    _make_nei_g1,
-    get_residual,
+    _apply_nlist_mask,  # apply the neighbor-list mask
+    _apply_switch,  # apply the switch function
+    _make_nei_g1,  # build the neighbor node features
+    get_residual,  # get the residual-connection weights
 )
 from deepmd.pt.model.network.mlp import (
-    MLPLayer,
+    MLPLayer,  # basic MLP layer
 )
 from deepmd.pt.model.network.utils import (
-    aggregate,
+    aggregate,  # aggregation function
 )
 from deepmd.pt.utils.env import (
-    PRECISION_DICT,
+    PRECISION_DICT,  # precision dictionary
 )
 from deepmd.pt.utils.utils import (
-    ActivationFn,
-    to_numpy_array,
-    to_torch_tensor,
+    ActivationFn,  # activation function
+    to_numpy_array,  # convert to a numpy array
+    to_torch_tensor,  # convert to a torch tensor
 )
 from deepmd.utils.version import (
-    check_version_compatibility,
+    check_version_compatibility,  # version compatibility check
 )


 class RepFlowLayer(torch.nn.Module):
+    """RepFlow layer: the core message-passing layer of the DPA3 model.
+
+    The RepFlow layer implements the message-passing mechanism of a graph neural
+    network:
+    1. node updates: self update, symmetrization, and edge message passing;
+    2. edge updates: self update and angle message passing;
+    3. angle updates: self update (when enabled).
+
+    The layer preserves rotational equivariance and permutation invariance and is
+    the basic building block of the DPA3 descriptor.
+    """
+
     def __init__(
         self,
-        e_rcut: float,
-        e_rcut_smth: float,
-        e_sel: int,
-        a_rcut: float,
-        a_rcut_smth: float,
-        a_sel: int,
-        ntypes: int,
-        n_dim: int = 128,
-        e_dim: int = 16,
-        a_dim: int = 64,
-        a_compress_rate: int = 0,
-        a_compress_use_split: bool = False,
-        a_compress_e_rate: int = 1,
-        n_multi_edge_message: int = 1,
-        axis_neuron: int = 4,
-        update_angle: bool = True,
-        optim_update: bool = True,
-        use_dynamic_sel: bool = False,
-        sel_reduce_factor: float = 10.0,
-        smooth_edge_update: bool = False,
-        activation_function: str = "silu",
-        update_style: str = "res_residual",
-        update_residual: float = 0.1,
-        update_residual_init: str = "const",
-        precision: str = "float64",
-        seed: Optional[Union[int, list[int]]] = None,
+        e_rcut: float,  # edge cutoff radius
+        e_rcut_smth: float,  # edge smooth cutoff radius
+        e_sel: int,  # number of selected edge neighbors
+        a_rcut: float,  # angle cutoff radius
+        a_rcut_smth: float,  # angle smooth cutoff radius
+        a_sel: int,  # number of selected angle neighbors
+        ntypes: int,  # number of atom types
+        n_dim: int = 128,  # node representation dimension
+        e_dim: int = 16,  # edge representation dimension
+        a_dim: int = 64,  # angle representation dimension
+        a_compress_rate: int = 0,  # angle compression rate
+        a_compress_use_split: bool = False,  # whether to use split compression
+        a_compress_e_rate: int = 1,  # extra edge compression rate for angles
+        n_multi_edge_message: int = 1,  # number of multi-head edge messages
+        axis_neuron: int = 4,  # number of axis neurons
+        update_angle: bool = True,  # whether to update angle representations
+        optim_update: bool = True,  # whether to use the optimized update path
+        use_dynamic_sel: bool = False,  # whether to use dynamic selection
+        sel_reduce_factor: float = 10.0,  # selection reduction factor
+        smooth_edge_update: bool = False,  # whether to smooth the edge update
+        activation_function: str = "silu",  # activation function
+        update_style: str = "res_residual",  # residual update style
+        update_residual: float = 0.1,  # residual update magnitude
+        update_residual_init: str = "const",  # residual weight initialization
+        precision: str = "float64",  # numerical precision
+        seed: Optional[Union[int, list[int]]] = None,  # random seed
     ) -> None:
         super().__init__()
         self.epsilon = 1e-4  # protection of 1./nnei
@@ -115,6 +154,7 @@ def __init__(
         self.update_residual = update_residual
         self.update_residual_init = update_residual_init

+        # lists of the residual-connection weights
         self.n_residual = []
         self.e_residual = []
         self.a_residual = []
@@ -127,6 +167,7 @@ def __init__(
             precision=precision,
             seed=child_seed(seed, 0),
         )
+        # when using the "res_residual" style, register a residual weight
         if self.update_style == "res_residual":
             self.n_residual.append(
                 get_residual(
@@ -138,7 +179,8 @@ def __init__(
                 )
             )

-        # node sym (grrg + drrd)
+        # node sym (grrg + drrd): symmetrization MLP handling the GRRG invariants
+        # input dimension: n_dim * axis + e_dim * axis, e.g. 128 * 4 + 64 * 4 = 768
         self.n_sym_dim = n_dim * self.axis_neuron + e_dim * self.axis_neuron
         self.node_sym_linear = MLPLayer(
             self.n_sym_dim,
@@ -156,14 +198,16 @@ def __init__(
             seed=child_seed(seed, 3),
         )
     )
-
+        # node-edge message-passing MLP
+        # input: edge information (e.g. 320 dims) -> output: multi-head node updates (n_multi_edge_message x n_dim)
         # node edge message
         self.node_edge_linear = MLPLayer(
             self.edge_info_dim,
             self.n_multi_edge_message * n_dim,
             precision=precision,
             seed=child_seed(seed, 4),
-        )
+        )
+        # when using the "res_residual" style, register a residual weight per head
         if self.update_style == "res_residual":
             for head_index in range(self.n_multi_edge_message):
                 self.n_residual.append(
@@ -175,7 +219,11 @@ def __init__(
                     seed=child_seed(child_seed(seed, 5), head_index),
                 )
             )
-
+        # =============================================================================
+        # 3. edge-related network layers
+        # =============================================================================
+
+        # edge self-update MLP: edge information (e.g. 320 dims) -> edge representation (e_dim)
         # edge self message
         self.edge_self_linear = MLPLayer(
             self.edge_info_dim,
@@ -193,33 +241,36 @@ def __init__(
             seed=child_seed(seed, 7),
         )
     )
-
+        # =============================================================================
+        # 4. angle-related network layers (when angle updates are enabled)
+        # =============================================================================
         if self.update_angle:
-            self.angle_dim = self.a_dim
-            if self.a_compress_rate == 0:
+            self.angle_dim = self.a_dim  # angle dimension
+            if self.a_compress_rate == 0:  # no compression
                 # angle + node + edge * 2
-                self.angle_dim += self.n_dim + 2 * self.e_dim
+                self.angle_dim += self.n_dim + 2 * self.e_dim  # angle_dim += n_dim + 2 * e_dim
                 self.a_compress_n_linear = None
                 self.a_compress_e_linear = None
                 self.e_a_compress_dim = e_dim
                 self.n_a_compress_dim = n_dim
-            else:
+            else:  # compression mode
                 # angle + a_dim/c + a_dim/2c * 2 * e_rate
                 self.angle_dim += (1 + self.a_compress_e_rate) * (
                     self.a_dim // self.a_compress_rate
-                )
+                )  # angle_dim += (1 + e_rate) * (a_dim // c)
                 self.e_a_compress_dim = (
                     self.a_dim // (2 * self.a_compress_rate) * self.a_compress_e_rate
-                )
-                self.n_a_compress_dim = self.a_dim // self.a_compress_rate
-                if not self.a_compress_use_split:
+                )  # compressed edge dim = (a_dim // 2c) * e_rate
+                self.n_a_compress_dim = self.a_dim // self.a_compress_rate  # compressed node dim = a_dim // c
+                if not self.a_compress_use_split:  # compress via linear mappings
                     self.a_compress_n_linear = MLPLayer(
                         self.n_dim,
                         self.n_a_compress_dim,
                         precision=precision,
                         bias=False,
                         seed=child_seed(seed, 8),
-                    )
+                    )
                     self.a_compress_e_linear = MLPLayer(
                         self.e_dim,
                         self.e_a_compress_dim,
@@ -232,13 +283,15 @@ def __init__(
                 self.a_compress_e_linear = None

             # edge angle message
+            # edge-angle message passing implemented as a two-layer MLP:
+            # first layer: angle information -> edge representation
             self.edge_angle_linear1 = MLPLayer(
                 self.angle_dim,
                 self.e_dim,
                 precision=precision,
                 seed=child_seed(seed, 10),
             )
-            self.edge_angle_linear2 = MLPLayer(
+            self.edge_angle_linear2 = MLPLayer(  # second layer: edge representation -> edge representation
                 self.e_dim,
                 self.e_dim,
                 precision=precision,
@@ -255,7 +308,7 @@ def __init__(
                 )
             )

-            # angle self message
+            # angle self message: angle self-update MLP
             self.angle_self_linear = MLPLayer(
                 self.angle_dim,
                 self.a_dim,
@@ -279,10 +332,11 @@ def __init__(
             self.a_compress_n_linear = None
             self.a_compress_e_linear = None
             self.angle_dim = 0
-
-        self.n_residual = nn.ParameterList(self.n_residual)
-        self.e_residual = nn.ParameterList(self.e_residual)
-        self.a_residual = nn.ParameterList(self.a_residual)
+
+        # wrap the residual weights into PyTorch parameter lists; these weights are
+        # used by the "res_residual" style, where every update term has its own
+        # learnable weight
+        self.n_residual = nn.ParameterList(self.n_residual)  # node residual weights
+        self.e_residual = nn.ParameterList(self.e_residual)  # edge residual weights
+        self.a_residual = nn.ParameterList(self.a_residual)  # angle residual weights

     @staticmethod
     def _cal_hg(
         edge_ebd: torch.Tensor,
         h2: torch.Tensor,
         nlist_mask: torch.Tensor,
         sw: torch.Tensor,
     ) -> torch.Tensor:
-        """
-        Calculate the transposed rotation matrix.
+ """计算转置旋转矩阵 + + 这个函数计算用于对称化操作的转置旋转矩阵,是GRRG不变量计算的关键步骤。 + 通过边嵌入和旋转等变张量的矩阵乘法,生成旋转等变的几何信息。 Parameters ---------- - edge_ebd - Neighbor-wise/Pair-wise edge embeddings, with shape nb x nloc x nnei x e_dim. - h2 - Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3. - nlist_mask - Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. - sw - The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, - and remains 0 beyond rcut, with shape nb x nloc x nnei. + edge_ebd : torch.Tensor + 邻居/原子对边嵌入,形状为 nb x nloc x nnei x e_dim + h2 : torch.Tensor + 邻居/原子对等变表征张量,形状为 nb x nloc x nnei x 3 + nlist_mask : torch.Tensor + 邻居列表掩码,0表示无邻居,形状为 nb x nloc x nnei + sw : torch.Tensor + 开关函数,在rcut_smth范围内为1,在rcut_smth到rcut之间平滑衰减到0, + 在rcut之外为0,形状为 nb x nloc x nnei Returns ------- - hg - The transposed rotation matrix, with shape nb x nloc x 3 x e_dim. + hg : torch.Tensor + 转置旋转矩阵,形状为 nb x nloc x 3 x e_dim """ + # 获取张量形状信息 # edge_ebd: nb x nloc x nnei x e_dim # h2: nb x nloc x nnei x 3 # msk: nb x nloc x nnei nb, nloc, nnei, _ = edge_ebd.shape e_dim = edge_ebd.shape[-1] - # nb x nloc x nnei x e_dim - edge_ebd = _apply_nlist_mask(edge_ebd, nlist_mask) - edge_ebd = _apply_switch(edge_ebd, sw) + + # 应用邻居列表掩码和开关函数 + # 形状: nb x nloc x nnei x e_dim + edge_ebd = _apply_nlist_mask(edge_ebd, nlist_mask) # 将无效邻居的边嵌入置零 + edge_ebd = _apply_switch(edge_ebd, sw) # 应用开关函数进行平滑截断 + + # 计算邻居数量的逆平方根,用于归一化 invnnei = torch.rsqrt( float(nnei) * torch.ones((nb, nloc, 1, 1), dtype=edge_ebd.dtype, device=edge_ebd.device) ) - # nb x nloc x 3 x e_dim + + # 计算转置旋转矩阵:h2^T * edge_ebd + # h2转置后形状: nb x nloc x 3 x nnei + # edge_ebd形状: nb x nloc x nnei x e_dim + # 结果形状: nb x nloc x 3 x e_dim h2g2 = torch.matmul(torch.transpose(h2, -1, -2), edge_ebd) * invnnei return h2g2 @@ -338,47 +403,55 @@ def _cal_hg_dynamic( nloc: int, scale_factor: float, ) -> torch.Tensor: - """ - Calculate the transposed rotation matrix. + """计算转置旋转矩阵(动态选择版本) + + 这是_cal_hg函数的动态选择版本,用于处理变长邻居列表的情况。 + 通过聚合函数将扁平化的边信息聚合到节点,生成旋转等变的几何信息。 Parameters ---------- - flat_edge_ebd - Flatted neighbor-wise/pair-wise invariant rep tensors, with shape n_edge x e_dim. - flat_h2 - Flatted neighbor-wise/pair-wise equivariant rep tensors, with shape n_edge x 3. - flat_sw - Flatted switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, - and remains 0 beyond rcut, with shape n_edge. - owner - The owner index of the neighbor to reduce on. + flat_edge_ebd : torch.Tensor + 扁平化的邻居/原子对不变表征张量,形状为 n_edge x e_dim + flat_h2 : torch.Tensor + 扁平化的邻居/原子对等变表征张量,形状为 n_edge x 3 + flat_sw : torch.Tensor + 扁平化的开关函数,在rcut_smth范围内为1,在rcut_smth到rcut之间平滑衰减到0, + 在rcut之外为0,形状为 n_edge + owner : torch.Tensor + 邻居归约的所有者索引 num_owner : int - The total number of the owner. + 所有者的总数 nb : int - The number of batches. + 批次数 nloc : int - The number of local atoms. + 局部原子数 scale_factor : float - The scale factor to apply after reduce. + 归约后应用的缩放因子 Returns ------- - hg - The transposed rotation matrix, with shape nf x nloc x 3 x e_dim. 
@@ -338,47 +403,55 @@ def _cal_hg_dynamic(
         nloc: int,
         scale_factor: float,
     ) -> torch.Tensor:
-        """
-        Calculate the transposed rotation matrix.
+        """Compute the transposed rotation matrix (dynamic-selection version).
+
+        This is the dynamic-selection variant of _cal_hg, used for variable-length
+        neighbor lists. The flattened edge information is aggregated back to the
+        nodes, yielding rotationally equivariant geometric information.

         Parameters
         ----------
-        flat_edge_ebd
-            Flatted neighbor-wise/pair-wise invariant rep tensors, with shape n_edge x e_dim.
-        flat_h2
-            Flatted neighbor-wise/pair-wise equivariant rep tensors, with shape n_edge x 3.
-        flat_sw
-            Flatted switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut,
-            and remains 0 beyond rcut, with shape n_edge.
-        owner
-            The owner index of the neighbor to reduce on.
+        flat_edge_ebd : torch.Tensor
+            Flattened neighbor-wise/pair-wise invariant rep tensors, with shape n_edge x e_dim.
+        flat_h2 : torch.Tensor
+            Flattened neighbor-wise/pair-wise equivariant rep tensors, with shape n_edge x 3.
+        flat_sw : torch.Tensor
+            Flattened switch function, which equals 1 within the rcut_smth range, smoothly
+            decays from 1 to 0 between rcut_smth and rcut, and remains 0 beyond rcut,
+            with shape n_edge.
+        owner : torch.Tensor
+            The owner index of the neighbor to reduce on.
         num_owner : int
             The total number of the owner.
         nb : int
-            The number of batches.
+            The number of batches.
         nloc : int
-            The number of local atoms.
+            The number of local atoms.
         scale_factor : float
-            The scale factor to apply after reduce.
+            The scale factor applied after the reduction.

         Returns
         -------
-        hg
-            The transposed rotation matrix, with shape nf x nloc x 3 x e_dim.
+        hg : torch.Tensor
+            The transposed rotation matrix, with shape nf x nloc x 3 x e_dim.
         """
         n_edge, e_dim = flat_edge_ebd.shape
-        # n_edge x e_dim
+
+        # apply the switch function to the edge embeddings
+        # shape: n_edge x e_dim
         flat_edge_ebd = flat_edge_ebd * flat_sw.unsqueeze(-1)
-        # n_edge x 3 x e_dim
+
+        # outer product flat_h2[..., None] * flat_edge_ebd[:, None, :]
+        # shape: n_edge x 3 x e_dim
         flat_h2g2 = (flat_h2[..., None] * flat_edge_ebd[:, None, :]).reshape(
             -1, 3 * e_dim
         )
-        # nf x nloc x 3 x e_dim
+
+        # aggregate the edge contributions onto the nodes
+        # shape: nf x nloc x 3 x e_dim
         h2g2 = (
             aggregate(flat_h2g2, owner, average=False, num_owner=num_owner).reshape(
                 nb, nloc, 3, e_dim
             )
-            * scale_factor
+            * scale_factor  # apply the scale factor
         )
         return h2g2
@@ -697,88 +770,113 @@ def forward(
         edge_index: torch.Tensor,  # n_edge x 2
         angle_index: torch.Tensor,  # n_angle x 3
     ):
-        """
+        """Forward pass of the RepFlow layer.
+
+        This is the core function of the RepFlow layer. It implements the complete
+        message-passing scheme:
+        1. node update: self update + symmetrization + edge message passing;
+        2. edge update: self update + angle message passing;
+        3. angle update: self update (when enabled).
+
+        The whole procedure preserves rotational equivariance and permutation
+        invariance.
+
         Parameters
         ----------
         node_ebd_ext : nf x nall x n_dim
-            Extended node embedding.
+            Extended node embedding, containing the representations of all atoms.
         edge_ebd : nf x nloc x nnei x e_dim
-            Edge embedding.
+            Edge embedding, representing the pairwise interactions.
         h2 : nf x nloc x nnei x 3
-            Pair-atom channel, equivariant.
+            Pair-atom channel; equivariant, used to preserve rotational equivariance.
         angle_ebd : nf x nloc x a_nnei x a_nnei x a_dim
-            Angle embedding.
+            Angle embedding, representing the three-body interactions.
         nlist : nf x nloc x nnei
-            Neighbor list. (padded neis are set to 0)
+            Neighbor list (padded neighbors are set to 0).
         nlist_mask : nf x nloc x nnei
-            Masks of the neighbor list. real nei 1 otherwise 0
+            Mask of the neighbor list; 1 for a real neighbor, 0 otherwise.
         sw : nf x nloc x nnei
-            Switch function.
+            Switch function, used for the smooth cutoff.
         a_nlist : nf x nloc x a_nnei
-            Neighbor list for angle. (padded neis are set to 0)
+            Neighbor list for angles (padded neighbors are set to 0).
         a_nlist_mask : nf x nloc x a_nnei
-            Masks of the neighbor list for angle. real nei 1 otherwise 0
+            Mask of the angle neighbor list; 1 for a real neighbor, 0 otherwise.
         a_sw : nf x nloc x a_nnei
-            Switch function for angle.
-        edge_index : Optional for dynamic sel, n_edge x 2
-            n2e_index : n_edge
-                Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i).
-            n_ext2e_index : n_edge
-                Broadcast indices from extended node(j) to edge(ij).
-        angle_index : Optional for dynamic sel, n_angle x 3
-            n2a_index : n_angle
-                Broadcast indices from extended node(j) to angle(ijk).
-            eij2a_index : n_angle
-                Broadcast indices from extended edge(ij) to angle(ijk), or reduction indices from angle(ijk) to edge(ij).
-            eik2a_index : n_angle
-                Broadcast indices from extended edge(ik) to angle(ijk).
+            Switch function for angles, used for the smooth cutoff.
+        edge_index : n_edge x 2 (used with dynamic selection)
+            n2e_index : n_edge - broadcast indices from node(i) to edge(ij),
+                or reduction indices from edge(ij) to node(i)
+            n_ext2e_index : n_edge - broadcast indices from extended node(j) to edge(ij)
+        angle_index : n_angle x 3 (used with dynamic selection)
+            n2a_index : n_angle - broadcast indices from extended node(j) to angle(ijk)
+            eij2a_index : n_angle - broadcast indices from edge(ij) to angle(ijk),
+                or reduction indices from angle(ijk) to edge(ij)
+            eik2a_index : n_angle - broadcast indices from edge(ik) to angle(ijk)

         Returns
         -------
-        n_updated: nf x nloc x n_dim
-            Updated node embedding.
-        e_updated: nf x nloc x nnei x e_dim
-            Updated edge embedding.
-        a_updated : nf x nloc x a_nnei x a_nnei x a_dim
-            Updated angle embedding.
+        n_updated : nf x nloc x n_dim
+            Updated node embedding.
+        e_updated : nf x nloc x nnei x e_dim
+            Updated edge embedding.
+        a_updated : nf x nloc x a_nnei x a_nnei x a_dim
+            Updated angle embedding.
         """
-        nb, nloc, nnei = nlist.shape
-        nall = node_ebd_ext.shape[1]
-        node_ebd = node_ebd_ext[:, :nloc, :]
-        n_edge = int(nlist_mask.sum().item())
+        # =============================================================================
+        # 1. input preprocessing and shape checks
+        # =============================================================================
+        nb, nloc, nnei = nlist.shape  # batch size, number of local atoms, number of neighbors
+        nall = node_ebd_ext.shape[1]  # number of extended atoms
+        node_ebd = node_ebd_ext[:, :nloc, :]  # local node representations
+        n_edge = int(nlist_mask.sum().item())  # actual number of edges
         assert (nb, nloc) == node_ebd.shape[:2]
-        if not self.use_dynamic_sel:
+
+        # check the shape of h2 (depends on whether dynamic selection is used)
+        if not self.use_dynamic_sel:  # static selection
             assert (nb, nloc, nnei, 3) == h2.shape
-        else:
+        else:  # dynamic selection
             assert (n_edge, 3) == h2.shape
         del a_nlist  # may be used in the future

-        n2e_index, n_ext2e_index = edge_index[:, 0], edge_index[:, 1]
+        # extract the index information
+        n2e_index, n_ext2e_index = edge_index[:, 0], edge_index[:, 1]  # node-to-edge and extended-node-to-edge indices
         n2a_index, eij2a_index, eik2a_index = (
-            angle_index[:, 0],
-            angle_index[:, 1],
-            angle_index[:, 2],
+            angle_index[:, 0],  # node-to-angle index
+            angle_index[:, 1],  # edge ij-to-angle index
+            angle_index[:, 2],  # edge ik-to-angle index
         )

-        # nb x nloc x nnei x n_dim [OR] n_edge x n_dim
+        # =============================================================================
+        # 2. build the neighbor node embeddings
+        # =============================================================================
+        # neighbor node embeddings: the neighbor node features for every edge
+        # shape: nb x nloc x nnei x n_dim [OR] n_edge x n_dim
         nei_node_ebd = (
-            _make_nei_g1(node_ebd_ext, nlist)
+            _make_nei_g1(node_ebd_ext, nlist)  # static mode: build via the neighbor list
             if not self.use_dynamic_sel
-            else torch.index_select(
+            else torch.index_select(  # dynamic mode: select via indices
                 node_ebd_ext.reshape(-1, self.n_dim), 0, n_ext2e_index
             )
         )

-        n_update_list: list[torch.Tensor] = [node_ebd]
-        e_update_list: list[torch.Tensor] = [edge_ebd]
-        a_update_list: list[torch.Tensor] = [angle_ebd]
-
-        # node self mlp
+        # =============================================================================
+        # 3. initialize the update lists (used for the residual connections)
+        # =============================================================================
+        n_update_list: list[torch.Tensor] = [node_ebd]  # node updates, starting from the original nodes
+        e_update_list: list[torch.Tensor] = [edge_ebd]  # edge updates, starting from the original edges
+        a_update_list: list[torch.Tensor] = [angle_ebd]  # angle updates, starting from the original angles
+
+        # =============================================================================
+        # 4. node representation update
+        # =============================================================================
+
+        # 4.1 node self update: process the node features through an MLP
         node_self_mlp = self.act(self.node_self_mlp(node_ebd))
         n_update_list.append(node_self_mlp)
-
-        # node sym (grrg + drrd)
+
+        # 4.2 node symmetrization update: geometric information based on the GRRG
+        #     invariants (aggregated from the edges onto the nodes); this realizes
+        #     rotational invariance via the symmetrization operation
         node_sym_list: list[torch.Tensor] = []
+
+        # GRRG invariant of the edge embeddings
+        # symmetrization_op: edge embeddings -> symmetrized invariants [nf, nloc, axis*e_dim]
         node_sym_list.append(
             self.symmetrization_op(
                 edge_ebd,
@@ -788,7 +886,7 @@ def forward(
                 self.axis_neuron,
             )
             if not self.use_dynamic_sel
-            else self.symmetrization_op_dynamic(
+            else self.symmetrization_op_dynamic(  # dynamic-selection mode
                 edge_ebd,
                 h2,
                 sw,
@@ -800,6 +898,8 @@ def forward(
                 axis_neuron=self.axis_neuron,
             )
         )
+
+        # GRRG invariant of the neighbor node embeddings
         node_sym_list.append(
             self.symmetrization_op(
                 nei_node_ebd,
@@ -809,7 +909,7 @@ def forward(
                 self.axis_neuron,
             )
             if not self.use_dynamic_sel
-            else self.symmetrization_op_dynamic(
+            else self.symmetrization_op_dynamic(  # dynamic-selection mode
                 nei_node_ebd,
                 h2,
                 sw,
@@ -821,43 +921,53 @@ def forward(
                 axis_neuron=self.axis_neuron,
             )
         )
+
+        # concatenate the two GRRG invariants and process them through an MLP
         node_sym = self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1)))
         n_update_list.append(node_sym)
+        # =============================================================================
+        # 5. node-edge message passing
+        # =============================================================================
+
+        # 5.1 build the edge information (used for the message passing)
         if not self.optim_update:
             if not self.use_dynamic_sel:
-                # nb x nloc x nnei x (n_dim * 2 + e_dim)
+                # static mode: concatenate node, neighbor-node and edge information
+                # shape: nb x nloc x nnei x (n_dim * 2 + e_dim)
                 edge_info = torch.cat(
                     [
-                        torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]),
-                        nei_node_ebd,
-                        edge_ebd,
+                        torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]),  # central-node features
+                        nei_node_ebd,  # neighbor-node features
+                        edge_ebd,  # edge features
                     ],
                     dim=-1,
                 )
             else:
-                # n_edge x (n_dim * 2 + e_dim)
+                # dynamic mode: select via indices
+                # shape: n_edge x (n_dim * 2 + e_dim)
                 edge_info = torch.cat(
                     [
                         torch.index_select(
                             node_ebd.reshape(-1, self.n_dim), 0, n2e_index
-                        ),
-                        nei_node_ebd,
-                        edge_ebd,
+                        ),  # central-node features
+                        nei_node_ebd,  # neighbor-node features
+                        edge_ebd,  # edge features
                     ],
                     dim=-1,
                 )
         else:
-            edge_info = None
+            edge_info = None  # the optimized path does not pre-build the edge information

-        # node edge message
-        # nb x nloc x nnei x (h * n_dim)
+        # 5.2 node-edge message passing: update the node representations from the
+        #     edge information
         if not self.optim_update:
             assert edge_info is not None
             node_edge_update = self.act(
                 self.node_edge_linear(edge_info)
-            ) * sw.unsqueeze(-1)
+            ) * sw.unsqueeze(-1)  # apply the switch function
         else:
+            # optimized path: compute directly, avoiding large intermediate tensors
             node_edge_update = self.act(
                 self.optim_edge_update(
                     node_ebd,
@@ -875,11 +985,14 @@ def forward(
                    n_ext2e_index,
                    "node",
                )
-            ) * sw.unsqueeze(-1)
+            ) * sw.unsqueeze(-1)  # apply the switch function
+
+        # 5.3 aggregate the edge messages onto the nodes: gather the messages from
+        #     all neighbors onto the central node
         node_edge_update = (
-            (torch.sum(node_edge_update, dim=-2) / self.nnei)
+            (torch.sum(node_edge_update, dim=-2) / self.nnei)  # static mode: mean aggregation
             if not self.use_dynamic_sel
-            else (
+            else (  # dynamic mode: use the aggregate helper
                 aggregate(
                     node_edge_update,
                     n2e_index,
@@ -890,23 +1003,32 @@ def forward(
                )
            )

+        # 5.4 handle multi-head edge messages (when enabled)
         if self.n_multi_edge_message > 1:
-            # nb x nloc x h x n_dim
+            # reshape the edge messages into the multi-head format
+            # shape: nb x nloc x h x n_dim
             node_edge_update_mul_head = node_edge_update.view(
                 nb, nloc, self.n_multi_edge_message, self.n_dim
             )
+            # treat every head as an independent update term
             for head_index in range(self.n_multi_edge_message):
                 n_update_list.append(node_edge_update_mul_head[..., head_index, :])
         else:
             n_update_list.append(node_edge_update)
-        # update node_ebd
+
+        # 5.5 update the node representations (with residual connections)
         n_updated = self.list_update(n_update_list, "node")

-        # edge self message
+        # =============================================================================
+        # 6. edge representation update
+        # =============================================================================
+
+        # 6.1 edge self update: update the edge representations from the edge information
         if not self.optim_update:
             assert edge_info is not None
             edge_self_update = self.act(self.edge_self_linear(edge_info))
         else:
+            # optimized path: compute the edge update directly
             edge_self_update = self.act(
                 self.optim_edge_update(
                     node_ebd,
@@ -927,34 +1049,48 @@ def forward(
                )
            )
         e_update_list.append(edge_self_update)
+        # =============================================================================
+        # 7. angle representation update (when enabled)
+        # =============================================================================
+
         if self.update_angle:
             assert self.angle_self_linear is not None
             assert self.edge_angle_linear1 is not None
             assert self.edge_angle_linear2 is not None
-            # get angle info
+
+            # 7.1 angle-information compression (when enabled): to reduce the
+            #     computational cost, the node and edge features can be compressed
             if self.a_compress_rate != 0:
                 if not self.a_compress_use_split:
+                    # compress via linear MLP mappings
                     assert self.a_compress_n_linear is not None
                     assert self.a_compress_e_linear is not None
                     node_ebd_for_angle = self.a_compress_n_linear(node_ebd)
                     edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd)
                 else:
-                    # use the first a_compress_dim dim for node and edge
+                    # split mode: take the first a_compress_dim dimensions of the node and edge features
                     node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim]
                     edge_ebd_for_angle = edge_ebd[..., : self.e_a_compress_dim]
             else:
+                # no compression: use the raw features directly
                 node_ebd_for_angle = node_ebd
                 edge_ebd_for_angle = edge_ebd

+            # 7.2 prepare the edge features for the angle channel
             if not self.use_dynamic_sel:
-                # nb x nloc x a_nnei x e_dim
+                # static mode: truncate to the angle neighbors and apply the mask
+                # shape: nb x nloc x a_nnei x e_dim
                 edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :]
                 edge_ebd_for_angle = torch.where(
                     a_nlist_mask.unsqueeze(-1), edge_ebd_for_angle, 0.0
                 )
+
+            # 7.3 build the angle information (used for the angle message passing)
             if not self.optim_update:
-                # nb x nloc x a_nnei x a_nnei x n_dim [OR] n_angle x n_dim
+                # standard mode: build the full angle-information tensor
+
+                # node information: tile the node features for every angle pair
+                # shape: nb x nloc x a_nnei x a_nnei x n_dim [OR] n_angle x n_dim
                 node_for_angle_info = (
                     torch.tile(
                         node_ebd_for_angle.unsqueeze(2).unsqueeze(2),
@@ -968,7 +1104,8 @@ def forward(
                    )
                )

-                # nb x nloc x (a_nnei) x a_nnei x e_dim [OR] n_angle x e_dim
+                # edge information k: tile the edge ik features for every angle pair
+                # shape: nb x nloc x (a_nnei) x a_nnei x e_dim [OR] n_angle x e_dim
                 edge_for_angle_k = (
                     torch.tile(
                         edge_ebd_for_angle.unsqueeze(2), (1, 1, self.a_sel, 1, 1)
                     )
                     if not self.use_dynamic_sel
                     else torch.index_select(edge_ebd_for_angle, 0, eik2a_index)
                 )
-                # nb x nloc x a_nnei x (a_nnei) x e_dim [OR] n_angle x e_dim
+
+                # edge information j: tile the edge ij features for every angle pair
+                # shape: nb x nloc x a_nnei x (a_nnei) x e_dim [OR] n_angle x e_dim
                 edge_for_angle_j = (
                     torch.tile(
                         edge_ebd_for_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1)
                     )
                     if not self.use_dynamic_sel
                     else torch.index_select(edge_ebd_for_angle, 0, eij2a_index)
                 )
-                # nb x nloc x a_nnei x a_nnei x (e_dim + e_dim) [OR] n_angle x (e_dim + e_dim)
+
+                # concatenate the edge information of edges ik and ij
+                # shape: nb x nloc x a_nnei x a_nnei x (e_dim + e_dim) [OR] n_angle x (e_dim + e_dim)
                 edge_for_angle_info = torch.cat(
                     [edge_for_angle_k, edge_for_angle_j], dim=-1
                 )
+
+                # assemble the complete angle-information list
                 angle_info_list = [angle_ebd]
                 angle_info_list.append(node_for_angle_info)
                 angle_info_list.append(edge_for_angle_info)
-                # nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2) or (a + a/c + a/c)
-                # [OR]
-                # n_angle x (a + n_dim + e_dim*2) or (a + a/c + a/c)
+
+                # concatenate all angle information
+                # shape: nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2) or (a + a/c + a/c)
+                # [OR] n_angle x (a + n_dim + e_dim*2) or (a + a/c + a/c)
                 angle_info = torch.cat(angle_info_list, dim=-1)
             else:
-                angle_info = None
+                angle_info = None  # the optimized path does not pre-build the angle information

-            # edge angle message
-            # nb x nloc x a_nnei x a_nnei x e_dim [OR] n_angle x e_dim
+            # 7.4 edge-angle message passing: update the edge representations from
+            #     the angle information, modeling the three-body interactions
             if not self.optim_update:
                 assert angle_info is not None
                 edge_angle_update = self.act(self.edge_angle_linear1(angle_info))
             else:
+                # optimized path: compute the angle message directly
                 edge_angle_update = self.act(
                     self.optim_angle_update(
                         angle_ebd,
@@ -1023,16 +1168,22 @@ def forward(
                    )
                )

+            # 7.5 weight and aggregate the angle messages
             if not self.use_dynamic_sel:
-                # nb x nloc x a_nnei x a_nnei x e_dim
+                # static mode: apply the angle switch function and aggregate
+                # shape: nb x nloc x a_nnei x a_nnei x e_dim
                 weighted_edge_angle_update = (
                     a_sw[..., None, None] * a_sw[..., None, :, None] * edge_angle_update
                 )
-                # nb x nloc x a_nnei x e_dim
+
+                # reduce along one angle dimension: from angle pairs to edges
+                # shape: nb x nloc x a_nnei x e_dim
                 reduced_edge_angle_update = torch.sum(
                     weighted_edge_angle_update, dim=-2
                 ) / (self.a_sel**0.5)
-                # nb x nloc x nnei x e_dim
+
+                # pad to the full edge dimension: extend the angle edges to all edges
+                # shape: nb x nloc x nnei x e_dim
                 padding_edge_angle_update = torch.concat(
                     [
                         reduced_edge_angle_update,
@@ -1045,9 +1196,12 @@ def forward(
                     dim=2,
                 )
             else:
-                # n_angle x e_dim
+                # dynamic mode: use the aggregate helper
+                # shape: n_angle x e_dim
                 weighted_edge_angle_update = edge_angle_update * a_sw.unsqueeze(-1)
-                # n_edge x e_dim
+
+                # aggregate the angle messages onto the edges
+                # shape: n_edge x e_dim
                 padding_edge_angle_update = aggregate(
                     weighted_edge_angle_update,
                     eij2a_index,
                     average=False,
                     num_owner=n_edge,
                 ) / (self.dynamic_a_sel**0.5)

+            # 7.6 non-smooth edge update (kept for backward compatibility)
             if not self.smooth_edge_update:
-                # will be deprecated in the future
-                # not support dynamic index, will pass anyway
+                # note: this branch will be deprecated in the future;
+                # it does not support dynamic indexing
                 if self.use_dynamic_sel:
                     raise NotImplementedError(
                         "smooth_edge_update must be True when use_dynamic_sel is True!"
                     )
+                # build the full mask
                 full_mask = torch.concat(
                     [
                         a_nlist_mask,
@@ -1073,21 +1229,26 @@ def forward(
                     ],
                     dim=-1,
                 )
+                # apply the mask: keep the original edge features where no angle edge exists
                 padding_edge_angle_update = torch.where(
                     full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd
                 )
+
+            # 7.7 further processing of the edge-angle message
             e_update_list.append(
                 self.act(self.edge_angle_linear2(padding_edge_angle_update))
             )
-            # update edge_ebd
+
+            # 7.8 update the edge representations (with residual connections)
             e_updated = self.list_update(e_update_list, "edge")

-            # angle self message
-            # nb x nloc x a_nnei x a_nnei x dim_a
+            # 7.9 angle self message: update the angle representations from the
+            #     angle information
             if not self.optim_update:
                 assert angle_info is not None
                 angle_self_update = self.act(self.angle_self_linear(angle_info))
             else:
+                # optimized path: compute the angle self update directly
                 angle_self_update = self.act(
                     self.optim_angle_update(
                         angle_ebd,
@@ -1108,11 +1269,17 @@ def forward(
                    )
                )
             a_update_list.append(angle_self_update)
         else:
-            # update edge_ebd
+            # when angle updates are disabled, only update the edge representations
             e_updated = self.list_update(e_update_list, "edge")

-        # update angle_ebd
+        # =============================================================================
+        # 8. final updates and return
+        # =============================================================================
+
+        # update the angle representations (with residual connections)
         a_updated = self.list_update(a_update_list, "angle")
+
+        # return all updated representations
         return n_updated, e_updated, a_updated

     @torch.jit.export
     def list_update_res_avg(
         self,
         update_list: list[torch.Tensor],
     ) -> torch.Tensor:
+        """Residual-average update strategy.
+
+        One of the three residual-connection strategies:
+            u = (u_0 + u_1 + u_2 + ... + u_n) / sqrt(n+1)
+
+        where u_0 is the original representation (the first list element),
+        u_1 ... u_n are the update terms, and n+1 is the total number of terms
+        (including the original one).
+
+        Characteristics:
+        1. all update terms are summed with equal weight;
+        2. the sqrt(n+1) normalization keeps the result numerically stable;
+        3. suitable when the update terms are of similar importance.
+
+        Parameters
+        ----------
+        update_list : list[torch.Tensor]
+            The update list; the first element is the original representation and
+            the remaining elements are the update terms, e.g.
+            [node_ebd, node_self_update, node_sym_update, node_edge_update].
+
+        Returns
+        -------
+        torch.Tensor
+            The updated representation, with the same shape as the original one.
+        """
         nitem = len(update_list)
-        uu = update_list[0]
+        uu = update_list[0]  # start from the original representation
         for ii in range(1, nitem):
-            uu = uu + update_list[ii]
-        return uu / (float(nitem) ** 0.5)
+            uu = uu + update_list[ii]  # accumulate all update terms
+        return uu / (float(nitem) ** 0.5)  # normalize by sqrt(nitem)

     @torch.jit.export
     def list_update_res_incr(self, update_list: list[torch.Tensor]) -> torch.Tensor:
+        """Residual-increment update strategy.
+
+        One of the three residual-connection strategies:
+            u = u_0 + (u_1 + u_2 + ... + u_n) / sqrt(n)
+
+        where u_0 is the original representation (the first list element),
+        u_1 ... u_n are the update terms, and n is the number of update terms
+        (excluding the original one).
+
+        Characteristics:
+        1. the original representation keeps its full weight;
+        2. the update terms act as increments, normalized by sqrt(n);
+        3. suitable when the updates are corrections to the original representation;
+        4. closest to the residual connections of ResNet.
+
+        Parameters
+        ----------
+        update_list : list[torch.Tensor]
+            The update list; the first element is the original representation and
+            the remaining elements are the update terms.
+
+        Returns
+        -------
+        torch.Tensor
+            The updated representation, with the same shape as the original one.
+        """
         nitem = len(update_list)
-        uu = update_list[0]
+        uu = update_list[0]  # start from the original representation
+        # normalization factor for the update terms: 1/sqrt(nitem - 1),
+        # where nitem - 1 is the number of update terms
         scale = 1.0 / (float(nitem - 1) ** 0.5) if nitem > 1 else 0.0
         for ii in range(1, nitem):
-            uu = uu + scale * update_list[ii]
+            uu = uu + scale * update_list[ii]  # add the normalized update terms
         return uu

     @torch.jit.export
     def list_update_res_residual(
         self, update_list: list[torch.Tensor], update_name: str = "node"
     ) -> torch.Tensor:
+        """Residual-weight update strategy (the most flexible one).
+
+        One of the three residual-connection strategies:
+            u = u_0 + r_1*u_1 + r_2*u_2 + ... + r_n*u_n
+
+        where u_0 is the original representation, u_1 ... u_n are the update terms,
+        and r_1 ... r_n are learnable residual weights.
+
+        Characteristics:
+        1. every update term has its own learnable weight;
+        2. the model learns automatically which update terms matter most;
+        3. the weights adapt during training;
+        4. maximal expressiveness and flexibility;
+        5. requires more parameters and compute.
+
+        Weight initialization:
+        - the initial standard deviation is controlled by update_residual (default 0.1);
+        - the initialization scheme is controlled by update_residual_init ("const" or "norm").
+
+        Parameters
+        ----------
+        update_list : list[torch.Tensor]
+            The update list; the first element is the original representation and
+            the remaining elements are the update terms.
+        update_name : str
+            The update type, which selects the residual-weight group:
+            - "node": use the self.n_residual weights
+            - "edge": use the self.e_residual weights
+            - "angle": use the self.a_residual weights
+
+        Returns
+        -------
+        torch.Tensor
+            The updated representation, with the same shape as the original one.
+        """
         nitem = len(update_list)
-        uu = update_list[0]
-        # make jit happy
+        uu = update_list[0]  # start from the original representation
+
+        # select the residual-weight group by update type
+        # (written out explicitly, instead of a dynamic attribute lookup, to keep jit happy)
         if update_name == "node":
             for ii, vv in enumerate(self.n_residual):
-                uu = uu + vv * update_list[ii + 1]
+                uu = uu + vv * update_list[ii + 1]  # u_0 + r_1*u_1 + r_2*u_2 + ...
         elif update_name == "edge":
             for ii, vv in enumerate(self.e_residual):
-                uu = uu + vv * update_list[ii + 1]
+                uu = uu + vv * update_list[ii + 1]
         elif update_name == "angle":
             for ii, vv in enumerate(self.a_residual):
-                uu = uu + vv * update_list[ii + 1]
+                uu = uu + vv * update_list[ii + 1]
         else:
-            raise NotImplementedError
+            raise NotImplementedError(f"Unknown update_name: {update_name}")
         return uu

     @torch.jit.export
     def list_update(
         self, update_list: list[torch.Tensor], update_name: str = "node"
     ) -> torch.Tensor:
+        """Unified entry point for the residual update strategies.
+
+        Depending on the configured update_style, one of the residual-connection
+        strategies is selected:
+
+        1. "res_avg": u = (u_0 + u_1 + ... + u_n) / sqrt(n+1)
+           - all terms weighted equally; simple and stable
+        2. "res_incr": u = u_0 + (u_1 + ... + u_n) / sqrt(n)
+           - the original term keeps full weight; the updates act as increments
+        3. "res_residual" (default): u = u_0 + r_1*u_1 + r_2*u_2 + ... + r_n*u_n
+           - one learnable weight per update term; maximal flexibility
+
+        Suggested choice: "res_avg" or "res_incr" for simple tasks,
+        "res_residual" (default) for complex tasks; avoid "res_residual" when
+        parameters or compute are at a premium.
+
+        Parameters
+        ----------
+        update_list : list[torch.Tensor]
+            The update list; the first element is the original representation and
+            the remaining elements are the update terms.
+        update_name : str
+            The update type ("node", "edge" or "angle"), used by the
+            "res_residual" strategy to select the matching weight group.
+
+        Returns
+        -------
+        torch.Tensor
+            The updated representation, with the same shape as the original one.
+        """
         if self.update_style == "res_avg":
             return self.list_update_res_avg(update_list)
         elif self.update_style == "res_incr":
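To make the three strategies concrete, a small self-contained sketch (toy tensors; `weights` stands in for the learned n_residual/e_residual/a_residual parameters; illustrative, not the patched code):

    import torch

    def res_avg(updates: list[torch.Tensor]) -> torch.Tensor:
        # all terms, including the original one, weighted equally
        return sum(updates) / float(len(updates)) ** 0.5

    def res_incr(updates: list[torch.Tensor]) -> torch.Tensor:
        # original term at full weight, updates as a normalized increment
        uu = updates[0]
        scale = 1.0 / (len(updates) - 1) ** 0.5 if len(updates) > 1 else 0.0
        for upd in updates[1:]:
            uu = uu + scale * upd
        return uu

    def res_residual(updates: list[torch.Tensor], weights: list[torch.Tensor]) -> torch.Tensor:
        # one learnable weight per update term
        uu = updates[0]
        for ww, upd in zip(weights, updates[1:]):
            uu = uu + ww * upd
        return uu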
diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py
index 5408c49482..3782e09dbb 100644
--- a/deepmd/pt/model/descriptor/repflows.py
+++ b/deepmd/pt/model/descriptor/repflows.py
@@ -1,4 +1,16 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+"""
+Implementation of the RepFlow descriptor block.
+
+This module implements the DescrptBlockRepflows class, the core component of the
+DPA3 descriptor, which is responsible for:
+1. managing the stack of RepFlow layers;
+2. building the environment matrices and their statistics;
+3. coordinating the node, edge and angle embedding computations;
+4. handling dynamic neighbor selection and graph-index management.
+
+The RepFlow descriptor block is the main descriptor implementation of the DPA3
+model; it realizes the message passing of a graph neural network through several
+stacked RepFlow layers.
+"""
 from typing import (
     Callable,
     Optional,
@@ -8,47 +20,47 @@
 import torch

 from deepmd.dpmodel.utils.seed import (
-    child_seed,
+    child_seed,  # child-seed generator
 )
 from deepmd.pt.model.descriptor.descriptor import (
-    DescriptorBlock,
+    DescriptorBlock,  # descriptor-block base class
 )
 from deepmd.pt.model.descriptor.env_mat import (
-    prod_env_mat,
+    prod_env_mat,  # environment-matrix builder
 )
 from deepmd.pt.model.network.mlp import (
-    MLPLayer,
+    MLPLayer,  # MLP layer
 )
 from deepmd.pt.model.network.utils import (
-    get_graph_index,
+    get_graph_index,  # graph-index builder
 )
 from deepmd.pt.utils import (
-    env,
+    env,  # environment configuration
 )
 from deepmd.pt.utils.env import (
-    PRECISION_DICT,
+    PRECISION_DICT,  # precision dictionary
 )
 from deepmd.pt.utils.env_mat_stat import (
-    EnvMatStatSe,
+    EnvMatStatSe,  # environment-matrix statistics
 )
 from deepmd.pt.utils.exclude_mask import (
-    PairExcludeMask,
+    PairExcludeMask,  # atom-pair exclusion mask
 )
 from deepmd.pt.utils.spin import (
-    concat_switch_virtual,
+    concat_switch_virtual,  # virtual-atom concatenation
 )
 from deepmd.pt.utils.utils import (
-    ActivationFn,
+    ActivationFn,  # activation function
 )
 from deepmd.utils.env_mat_stat import (
-    StatItem,
+    StatItem,  # statistics item
 )
 from deepmd.utils.path import (
-    DPPath,
+    DPPath,  # path handling
 )

 from .repflow_layer import (
-    RepFlowLayer,
+    RepFlowLayer,  # RepFlow layer
 )

 if not hasattr(torch.ops.deepmd, "border_op"):
@@ -75,41 +87,41 @@ def border_op(

 @DescriptorBlock.register("se_repflow")
 class DescrptBlockRepflows(DescriptorBlock):
-    r"""
-    The repflow descriptor block.
+    r"""The repflow descriptor block.
+
+    This is the core descriptor block of the DPA3 model. It implements a
+    RepFlow-based graph-neural-network descriptor: by stacking several RepFlow
+    layers it realizes the message passing used to model interatomic interactions.

     Parameters
     ----------
     n_dim : int, optional
-        The dimension of node representation.
+        The dimension of the node representation (default 128).
     e_dim : int, optional
-        The dimension of edge representation.
+        The dimension of the edge representation (default 64).
     a_dim : int, optional
-        The dimension of angle representation.
+        The dimension of the angle representation (default 64).
     nlayers : int, optional
-        Number of repflow layers.
+        The number of RepFlow layers (default 6).
     e_rcut : float, optional
         The edge cut-off radius.
     e_rcut_smth : float, optional
         Where to start smoothing for edge. For example the 1/r term is smoothed from rcut to rcut_smth.
     e_sel : int, optional
         Maximally possible number of selected edge neighbors.
     a_rcut : float, optional
         The angle cut-off radius.
     a_rcut_smth : float, optional
         Where to start smoothing for angle. For example the 1/r term is smoothed from rcut to rcut_smth.
     a_sel : int, optional
         Maximally possible number of selected angle neighbors.
     a_compress_rate : int, optional
         The compression rate for angular messages. The default value is 0, indicating no compression.
         If a non-zero integer c is provided, the node and edge dimensions will be compressed
         to a_dim/c and a_dim/2c, respectively, within the angular message.
     a_compress_e_rate : int, optional
         The extra compression rate for edge in angular message compression. The default value is 1.
         When using angular message compression with a_compress_rate c and a_compress_e_rate c_e,
         the edge dimension will be compressed to (c_e * a_dim / 2c) within the angular message.
     a_compress_use_split : bool, optional
         Whether to split first sub-vectors instead of linear mapping during angular message compression.
         The default value is False.
     n_multi_edge_message : int, optional
         The head number of multiple edge messages to update node feature.
@@ -284,10 +296,10 @@ def __init__(

         self.edge_embd = MLPLayer(
             1, self.e_dim, precision=precision, seed=child_seed(seed, 0)
-        )
+        )  # build the edge embedding net
         self.angle_embd = MLPLayer(
             1, self.a_dim, precision=precision, bias=False, seed=child_seed(seed, 1)
-        )
+        )  # build the angle embedding net
         layers = []
         for ii in range(nlayers):
             layers.append(
@@ -319,7 +331,7 @@ def __init__(
                     smooth_edge_update=self.smooth_edge_update,
                     seed=child_seed(child_seed(seed, 1), ii),
                 )
-            )
+            )  # build the stacked RepFlow layers
         self.layers = torch.nn.ModuleList(layers)

         wanted_shape = (self.ntypes, self.nnei, 4)
@@ -426,16 +438,64 @@ def forward(
         mapping: Optional[torch.Tensor] = None,
         comm_dict: Optional[dict[str, torch.Tensor]] = None,
     ):
+        """Forward pass of the RepFlow descriptor block.
+
+        This is the core function of the block. It:
+        1. builds the environment matrices and the neighbor information;
+        2. computes the initial edge and angle embeddings;
+        3. performs message passing through the stacked RepFlow layers;
+        4. produces the final node, edge and angle representations.
+
+        Parameters
+        ----------
+        nlist : torch.Tensor
+            The neighbor list, with shape nf x nloc x nnei.
+        extended_coord : torch.Tensor
+            The extended coordinates, with shape nf x (nall * 3).
+        extended_atype : torch.Tensor
+            The extended atom types, with shape nf x nall.
+        extended_atype_embd : Optional[torch.Tensor]
+            The extended atom-type embedding, with shape nf x nall x n_dim.
+        mapping : Optional[torch.Tensor]
+            The index mapping from the extended region to the local region.
+        comm_dict : Optional[dict[str, torch.Tensor]]
+            The data needed for communication in parallel inference.
+
+        Returns
+        -------
+        node_ebd : torch.Tensor
+            The node embedding, with shape nf x nloc x n_dim.
+        edge_ebd : torch.Tensor
+            The edge embedding, with shape nf x nloc x nnei x e_dim.
+        h2 : torch.Tensor
+            The rotationally equivariant representation, with shape nf x nloc x nnei x 3.
+        rot_mat : torch.Tensor
+            The rotation matrix, with shape nf x nloc x e_dim x 3.
+        sw : torch.Tensor
+            The switch function, with shape nf x nloc x nnei.
+        """
+        # =============================================================================
+        # 1. input preprocessing and mode detection
+        # =============================================================================
         parallel_mode = comm_dict is not None
         if not parallel_mode:
             assert mapping is not None
         nframes, nloc, nnei = nlist.shape
         nall = extended_coord.view(nframes, -1).shape[1] // 3
         atype = extended_atype[:, :nloc]
-        # nb x nloc x nnei
+
+        # =============================================================================
+        # 2. handle excluded atom pairs
+        # =============================================================================
+        # apply the exclusion mask: excluded pairs are set to -1
+        # shape: nb x nloc x nnei
         exclude_mask = self.emask(nlist, extended_atype)
         nlist = torch.where(exclude_mask != 0, nlist, -1)
-        # nb x nloc x nnei x 4, nb x nloc x nnei x 3, nb x nloc x nnei x 1
+        # =============================================================================
+        # 3. build the environment matrices
+        # =============================================================================
+        # environment matrices: distance matrix, direction vectors and switch function
+        # shapes: nb x nloc x nnei x 4, nb x nloc x nnei x 3, nb x nloc x nnei x 1
         dmatrix, diff, sw = prod_env_mat(
             extended_coord,
             nlist,
@@ -447,17 +507,24 @@ def forward(
             protection=self.env_protection,
             use_exp_switch=self.use_exp_switch,
         )
-        nlist_mask = nlist != -1
+
+        # neighbor-list mask: 1 for a real neighbor, 0 otherwise (-1 marks padding)
+        nlist_mask = nlist != -1
         sw = torch.squeeze(sw, -1)
-        # beyond the cutoff sw should be 0.0
+        # beyond the cutoff the switch function should be 0.0
         sw = sw.masked_fill(~nlist_mask, 0.0)

-        # get angle nlist (maybe smaller)
+        # =============================================================================
+        # 4. build the angle environment matrices
+        # =============================================================================
+        # the angle neighbor list (possibly smaller than the edge neighbor list)
         a_dist_mask = (torch.linalg.norm(diff, dim=-1) < self.a_rcut)[
             :, :, : self.a_sel
         ]
         a_nlist = nlist[:, :, : self.a_sel]
         a_nlist = torch.where(a_dist_mask, a_nlist, -1)
+
+        # environment matrices for the angle channel: distance matrix, direction vectors, switch function
         _, a_diff, a_sw = prod_env_mat(
             extended_coord,
             a_nlist,
@@ -469,52 +536,69 @@ def forward(
             protection=self.env_protection,
             use_exp_switch=self.use_exp_switch,
         )
+
+        # angle neighbor-list mask and switch function
         a_nlist_mask = a_nlist != -1
         a_sw = torch.squeeze(a_sw, -1)
-        # beyond the cutoff sw should be 0.0
+        # beyond the cutoff the switch function should be 0.0
         a_sw = a_sw.masked_fill(~a_nlist_mask, 0.0)
-        # set all padding positions to index of 0
-        # if the a neighbor is real or not is indicated by nlist_mask
+
+        # set all padding positions to index 0; whether a neighbor is real is
+        # indicated by nlist_mask
         nlist[nlist == -1] = 0
         a_nlist[a_nlist == -1] = 0

-        # get node embedding
-        # [nframes, nloc, tebd_dim]
+        # =============================================================================
+        # 5. get the node embedding
+        # =============================================================================
+        # extract the local atom embedding from the extended atom-type embedding
+        # shape: [nframes, nloc, tebd_dim]
         assert extended_atype_embd is not None
         atype_embd = extended_atype_embd[:, :nloc, :]
         assert list(atype_embd.shape) == [nframes, nloc, self.n_dim]
-        assert isinstance(atype_embd, torch.Tensor)  # for jit
-        node_ebd = self.act(atype_embd)
+        assert isinstance(atype_embd, torch.Tensor)  # for jit compilation
+        node_ebd = self.act(atype_embd)  # apply the activation function
         n_dim = node_ebd.shape[-1]

-        # get edge and angle embedding input
-        # nb x nloc x nnei x 1, nb x nloc x nnei x 3
+        # =============================================================================
+        # 6. build the edge and angle embedding inputs
+        # =============================================================================
+        # split the environment matrix into the edge input and the equivariant rep
+        # shapes: nb x nloc x nnei x 1, nb x nloc x nnei x 3
         edge_input, h2 = torch.split(dmatrix, [1, 3], dim=-1)
+
+        # when the edge features are initialized directly from distances
         if self.edge_init_use_dist:
-            # nb x nloc x nnei x 1
+            # shape: nb x nloc x nnei x 1
             edge_input = torch.linalg.norm(diff, dim=-1, keepdim=True)

-        # nf x nloc x a_nnei x 3
+        # angle input: normalized direction vectors
+        # shape: nf x nloc x a_nnei x 3
         normalized_diff_i = a_diff / (
             torch.linalg.norm(a_diff, dim=-1, keepdim=True) + 1e-6
         )
-        # nf x nloc x 3 x a_nnei
+        # shape: nf x nloc x 3 x a_nnei
         normalized_diff_j = torch.transpose(normalized_diff_i, 2, 3)
-        # nf x nloc x a_nnei x a_nnei
-        # 1 - 1e-6 for torch.acos stability
+
+        # angle input: cosine between the atom pairs
+        # shape: nf x nloc x a_nnei x a_nnei
+        # the factor 1 - 1e-6 keeps torch.acos numerically stable
         cosine_ij = torch.matmul(normalized_diff_i, normalized_diff_j) * (1 - 1e-6)
         angle_input = cosine_ij.unsqueeze(-1) / (torch.pi**0.5)
+        # =============================================================================
+        # 7. index mapping and dynamic selection
+        # =============================================================================
         if not parallel_mode and self.use_loc_mapping:
             assert mapping is not None
-            # convert nlist from nall to nloc index
+            # convert nlist from nall indices to nloc indices
             nlist = torch.gather(
                 mapping,
                 1,
                 index=nlist.reshape(nframes, -1),
             ).reshape(nlist.shape)
+
         if self.use_dynamic_sel:
-            # get graph index
+            # build the graph indices used for dynamic neighbor selection
             edge_index, angle_index = get_graph_index(
                 nlist,
                 nlist_mask,
                 a_nlist_mask,
                 nall,
                 use_loc_mapping=self.use_loc_mapping,
             )
-            # flat all the tensors
-            # n_edge x 1
+
+            # flatten all tensors for the dynamic-selection path
+            # shape: n_edge x 1
             edge_input = edge_input[nlist_mask]
-            # n_edge x 3
+            # shape: n_edge x 3
             h2 = h2[nlist_mask]
-            # n_edge x 1
+            # shape: n_edge x 1
             sw = sw[nlist_mask]
-            # nb x nloc x a_nnei x a_nnei
+
+            # angle mask: both neighbors must be valid
+            # shape: nb x nloc x a_nnei x a_nnei
             a_nlist_mask = a_nlist_mask[:, :, :, None] & a_nlist_mask[:, :, None, :]
-            # n_angle x 1
+            # shape: n_angle x 1
             angle_input = angle_input[a_nlist_mask]
-            # n_angle x 1
+            # shape: n_angle x 1
             a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask]
         else:
-            # avoid jit assertion
+            # avoid jit assertion errors
             edge_index = angle_index = torch.zeros(
                 [1, 3], device=nlist.device, dtype=nlist.dtype
             )
-        # get edge and angle embedding
-        # nb x nloc x nnei x e_dim [OR] n_edge x e_dim
+
+        # =============================================================================
+        # 8. compute the edge and angle embeddings
+        # =============================================================================
+        # edge embedding
+        # shape: nb x nloc x nnei x e_dim [OR] n_edge x e_dim
         if not self.edge_init_use_dist:
-            edge_ebd = self.act(self.edge_embd(edge_input))
+            edge_ebd = self.act(self.edge_embd(edge_input))  # apply the activation function
         else:
-            edge_ebd = self.edge_embd(edge_input)
-        # nf x nloc x a_nnei x a_nnei x a_dim [OR] n_angle x a_dim
+            edge_ebd = self.edge_embd(edge_input)  # raw distances, no activation
+
+        # angle embedding
+        # shape: nf x nloc x a_nnei x a_nnei x a_dim [OR] n_angle x a_dim
         angle_ebd = self.angle_embd(angle_input)
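The dynamic-selection path leans on the aggregate helper, which is essentially an index_add scatter-reduction from flat per-edge rows back to their owning nodes. A minimal sketch of the idea (a simplification, not the actual implementation in network/utils.py):

    import torch

    def aggregate_sketch(values: torch.Tensor, owner: torch.Tensor,
                         num_owner: int, average: bool = True) -> torch.Tensor:
        """Scatter-reduce flat [n_edge, dim] rows onto num_owner owners."""
        out = torch.zeros(num_owner, values.shape[-1], dtype=values.dtype)
        out.index_add_(0, owner, values)  # sum the rows belonging to each owner
        if average:
            cnt = torch.zeros(num_owner, dtype=values.dtype)
            cnt.index_add_(0, owner, torch.ones_like(owner, dtype=values.dtype))
            out = out / cnt.clamp(min=1.0).unsqueeze(-1)
        return out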
+        # =============================================================================
+        # 9. message passing through the stacked RepFlow layers
+        # =============================================================================
+        # prepare the mapping tensor (non-parallel mode)
+        # shape: nb x nall x n_dim
         if not parallel_mode:
             assert mapping is not None
             mapping = (
                 mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.n_dim)
             )
+
+        # iterate over the RepFlow layers for message passing
         for idx, ll in enumerate(self.layers):
+            # derive the extended node embedding for this layer:
+            # node_ebd -> node_ebd_ext -> RepFlowLayer -> new node_ebd -> next layer's node_ebd_ext -> ...
             # node_ebd: nb x nloc x n_dim
             # node_ebd_ext: nb x nall x n_dim [OR] nb x nloc x n_dim when not parallel_mode
             if not parallel_mode:
@@ -564,11 +664,13 @@ def forward(
                 assert mapping is not None
                 node_ebd_ext = (
                     torch.gather(node_ebd, 1, mapping)
                     if not self.use_loc_mapping
                     else node_ebd
                 )
             else:
+                # parallel mode: handle communication and spin
                 assert comm_dict is not None
                 has_spin = "has_spin" in comm_dict
                 if not has_spin:
+                    # no spin: simple padding
                     n_padding = nall - nloc
                     node_ebd = torch.nn.functional.pad(
                         node_ebd.squeeze(0), (0, 0, 0, n_padding), value=0.0
@@ -576,26 +678,30 @@ def forward(
                     real_nloc = nloc
                     real_nall = nall
                 else:
-                    # for spin
+                    # with spin: handle the real and virtual parts
                     real_nloc = nloc // 2
                     real_nall = nall // 2
                     real_n_padding = real_nall - real_nloc
                     node_ebd_real, node_ebd_virtual = torch.split(
                         node_ebd, [real_nloc, real_nloc], dim=1
                     )
-                    # mix_node_ebd: nb x real_nloc x (n_dim * 2)
+                    # mixed node embedding: concatenate the real and virtual parts
+                    # shape: nb x real_nloc x (n_dim * 2)
                     mix_node_ebd = torch.cat([node_ebd_real, node_ebd_virtual], dim=2)
-                    # nb x real_nall x (n_dim * 2)
+                    # shape: nb x real_nall x (n_dim * 2)
                     node_ebd = torch.nn.functional.pad(
                         mix_node_ebd.squeeze(0), (0, 0, 0, real_n_padding), value=0.0
                     )

+                # check the dictionary keys required for the parallel communication
                 assert "send_list" in comm_dict
                 assert "send_proc" in comm_dict
                 assert "recv_proc" in comm_dict
                 assert "send_num" in comm_dict
                 assert "recv_num" in comm_dict
                 assert "communicator" in comm_dict
+
+                # perform the parallel communication
                 ret = torch.ops.deepmd.border_op(
                     comm_dict["send_list"],
                     comm_dict["send_proc"],
@@ -608,14 +714,16 @@ def forward(
                     real_nloc,
                     dtype=torch.int32,
                     device=torch.device("cpu"),
-                ),  # should be int of c++, placed on cpu
+                ),  # should be a C++ int, placed on cpu
                 torch.tensor(
                     real_nall - real_nloc,
                     dtype=torch.int32,
                     device=torch.device("cpu"),
-                ),  # should be int of c++, placed on cpu
+                ),  # should be a C++ int, placed on cpu
                 )
                 node_ebd_ext = ret[0].unsqueeze(0)
+
+                # if spin is present, split the real and virtual parts again
                 if has_spin:
                     node_ebd_real_ext, node_ebd_virtual_ext = torch.split(
                         node_ebd_ext, [n_dim, n_dim], dim=2
@@ -623,22 +731,34 @@ def forward(
                     node_ebd_ext = concat_switch_virtual(
                         node_ebd_real_ext, node_ebd_virtual_ext, real_nloc
                     )
+
+            # call the forward pass of the RepFlow layer (see repflow_layer.py);
+            # this produces the updated node_ebd, edge_ebd and angle_ebd
             node_ebd, edge_ebd, angle_ebd = ll.forward(
-                node_ebd_ext,
-                edge_ebd,
-                h2,
-                angle_ebd,
-                nlist,
-                nlist_mask,
-                sw,
-                a_nlist,
-                a_nlist_mask,
-                a_sw,
-                edge_index=edge_index,
-                angle_index=angle_index,
-            )
+                node_ebd_ext,  # node embedding
+                edge_ebd,  # edge embedding (distance based)
+                h2,  # rotationally equivariant rep (the last three columns of dmatrix)
+                angle_ebd,  # angle embedding
+                nlist,  # neighbor list
+                nlist_mask,  # neighbor-list mask
+                sw,  # switch function
+                a_nlist,  # angle neighbor list
+                a_nlist_mask,  # angle neighbor-list mask
+                a_sw,  # angle switch function
+                edge_index=edge_index,  # edge indices
+                angle_index=angle_index,  # angle indices
+            )
+        # =============================================================================
+        # 10. Compute the final rotation matrix
+        # =============================================================================
+        # build the rotation matrix, used to generate rotationally equivariant geometric information
+        # shape: nb x nloc x 3 x e_dim
         h2g2 = (
             RepFlowLayer._cal_hg(edge_ebd, h2, nlist_mask, sw)
             if not self.use_dynamic_sel
@@ -653,11 +773,15 @@ def forward(
                 scale_factor=(self.nnei / self.sel_reduce_factor) ** (-0.5),
             )
         )
-        # (nb x nloc) x e_dim x 3
+
+        # transpose the rotation matrix: from (nb x nloc) x 3 x e_dim to (nb x nloc) x e_dim x 3
         rot_mat = torch.permute(h2g2, (0, 1, 3, 2))

+        # =============================================================================
+        # 11. Return the final results
+        # =============================================================================
         return node_ebd, edge_ebd, h2, rot_mat.view(nframes, nloc, self.dim_emb, 3), sw
-
+#
     def compute_input_stats(
         self,
         merged: Union[Callable[[], list[dict]], list[dict]],
diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py
index 8d451f087f..df1854a417 100644
--- a/deepmd/pt/model/model/__init__.py
+++ b/deepmd/pt/model/model/__init__.py
@@ -83,9 +83,9 @@ def _get_standard_model_components(model_params, ntypes):
     # descriptor
     model_params["descriptor"]["ntypes"] = ntypes
     model_params["descriptor"]["type_map"] = copy.deepcopy(model_params["type_map"])
-    descriptor = BaseDescriptor(**model_params["descriptor"])
+    descriptor = BaseDescriptor(**model_params["descriptor"])  # e.g. DescrptDPA3; its forward returns node_ebd etc.
     # fitting
-    fitting_net = model_params.get("fitting_net", {})
+    fitting_net = model_params.get("fitting_net", {})  # read directly from the config
     fitting_net["type"] = fitting_net.get("type", "ener")
     fitting_net["ntypes"] = descriptor.get_ntypes()
     fitting_net["type_map"] = copy.deepcopy(model_params["type_map"])
@@ -98,8 +98,8 @@ def _get_standard_model_components(model_params, ntypes):
         fitting_net["out_dim"] = descriptor.get_dim_emb()
     if "ener" in fitting_net["type"]:
         fitting_net["return_energy"] = True
-    fitting = BaseFitting(**fitting_net)
-    return descriptor, fitting, fitting_net["type"]
+    fitting = BaseFitting(**fitting_net)  # e.g. the DPA3 energy fitting net
+    return descriptor, fitting, fitting_net["type"]  # hand back the descriptor, fitting net, and fitting type


 def get_spin_model(model_params):
diff --git a/deepmd/pt/model/network/mlp.py b/deepmd/pt/model/network/mlp.py
index 22675d6163..704f804410 100644
--- a/deepmd/pt/model/network/mlp.py
+++ b/deepmd/pt/model/network/mlp.py
@@ -1,4 +1,16 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+"""
+Multi-layer perceptron (MLP) network module.

+This module implements the neural-network layers and collections used in DeePMD-kit:
+1. MLPLayer: a single MLP layer with optional activation, residual connection, and timestep
+2. MLP: a multi-layer MLP network
+3. EmbeddingNet: embedding network for atom-type embeddings
+4. FittingNet: fitting network mapping descriptors to atomic properties
+5. NetworkCollection: a collection managing multiple network instances

+These networks are the building blocks of the RepFlow and fitting layers in the DPA3 model.
+"""
 from typing import (
     ClassVar,
     Optional,
@@ -44,10 +56,16 @@


 def empty_t(shape, precision):
+    """Create an empty tensor with the given shape and precision."""
     return torch.empty(shape, dtype=precision, device=device)


 class Identity(nn.Module):
+    """Identity layer.
+
+    A simple pass-through layer that returns its input unchanged.
+    Mainly used as a placeholder or skip connection in network structures.
+    """
     def __init__(self) -> None:
         super().__init__()

@@ -55,10 +73,11 @@ def forward(
         self,
         xx: torch.Tensor,
     ) -> torch.Tensor:
-        """The Identity operation layer."""
+        """Identity operation: return the input unchanged."""
         return xx

     def serialize(self) -> dict:
+        """Serialize the layer."""
         return {
             "@class": "Identity",
             "@version": 1,
@@ -66,73 +85,113 @@ def serialize(self) -> dict:

     @classmethod
     def deserialize(cls, data: dict) -> "Identity":
+        """Deserialize the layer from serialized data."""
         return Identity()


 class MLPLayer(nn.Module):
+    """A single multi-layer perceptron (MLP) layer.
+
+    The most basic neural-network layer in DeePMD-kit, implementing:
+    1. linear transform: y = x W + b
+    2. activation: y = activation(y)
+    3. residual connection: y = y + x (when resnet=True)
+    4. timestep scaling: y = y * idt (when use_timestep=True)
+
+    In the DPA3 model these layers are used for:
+    - the message-passing networks in the RepFlow layers
+    - the energy/force prediction networks in the fitting layer
+    - the type-embedding network
+    """
     def __init__(
         self,
-        num_in,
-        num_out,
-        bias: bool = True,
-        use_timestep: bool = False,
-        activation_function: Optional[str] = None,
-        resnet: bool = False,
-        bavg: float = 0.0,
-        stddev: float = 1.0,
-        precision: str = DEFAULT_PRECISION,
-        init: str = "default",
-        seed: Optional[Union[int, list[int]]] = None,
+        num_in,  # input dimension
+        num_out,  # output dimension
+        bias: bool = True,  # whether to use a bias
+        use_timestep: bool = False,  # whether to use timestep scaling
+        activation_function: Optional[str] = None,  # activation function name
+        resnet: bool = False,  # whether to use a residual connection
+        bavg: float = 0.0,  # mean of the bias initialization
+        stddev: float = 1.0,  # std of the weight initialization
+        precision: str = DEFAULT_PRECISION,  # numerical precision
+        init: str = "default",  # initialization method
+        seed: Optional[Union[int, list[int]]] = None,  # random seed
     ) -> None:
         super().__init__()
-        # only use_timestep when skip connection is established.
+
+        # timestep scaling is only used when a skip connection is established,
+        # i.e. the output dim equals the input dim or twice the input dim
         self.use_timestep = use_timestep and (
             num_out == num_in or num_out == num_in * 2
         )
+
+        # basic attributes
         self.num_in = num_in
         self.num_out = num_out
         self.activate_name = activation_function
-        self.activate = ActivationFn(self.activate_name)
+        self.activate = ActivationFn(self.activate_name)  # activation function object
         self.precision = precision
         self.prec = PRECISION_DICT[self.precision]
+
+        # weight matrix: [num_in, num_out]
         self.matrix = nn.Parameter(data=empty_t((num_in, num_out), self.prec))
+
+        # random number generator
         random_generator = get_generator(seed)
+
+        # bias parameter: [num_out]
         if bias:
             self.bias = nn.Parameter(
                 data=empty_t([num_out], self.prec),
             )
         else:
             self.bias = None
+
+        # timestep parameter: [num_out] (for timestep scaling in ResNet)
         if self.use_timestep:
-            self.idt = nn.Parameter(data=empty_t([num_out], self.prec))
+            self.idt = nn.Parameter(data=empty_t([num_out], self.prec))  # a learnable parameter
         else:
             self.idt = None
+
         self.resnet = resnet
+
+        # =============================================================================
+        # Parameter initialization
+        # =============================================================================
         if init == "default":
+            # default normal initialization
            self._default_normal_init(
                 bavg=bavg, stddev=stddev, generator=random_generator
             )
         elif init == "trunc_normal":
+            # truncated normal initialization
             self._trunc_normal_init(1.0, generator=random_generator)
         elif init == "relu":
+            # truncated normal initialization scaled for ReLU
             self._trunc_normal_init(2.0, generator=random_generator)
         elif init == "glorot":
+            # Glorot (Xavier) uniform initialization
             self._glorot_uniform_init(generator=random_generator)
         elif init == "gating":
+            # zero initialization for gating
             self._zero_init(self.use_bias)
         elif init == "kaiming_normal":
+            # Kaiming normal initialization
             self._normal_init(generator=random_generator)
         elif init == "final":
+            # zero initialization for the final layer
             self._zero_init(False)
         else:
             raise ValueError(f"Unknown initialization method: {init}")

     def check_type_consistency(self) -> None:
+        """Check the dtype consistency of all parameters."""
         precision = self.precision

         def check_var(var) -> None:
             if var is not None:
-                # assertion "float64" == "double" would fail
+                # check that the parameter dtype matches the requested precision;
+                # note: asserting "float64" == "double" would fail, hence PRECISION_DICT
                 assert PRECISION_DICT[var.dtype.name] is PRECISION_DICT[precision]

         check_var(self.matrix)
@@ -140,9 +199,11 @@ def check_var(var) -> None:
         check_var(self.idt)

     def dim_in(self) -> int:
+        """Return the input dimension."""
         return self.matrix.shape[0]

     def dim_out(self) -> int:
+        """Return the output dimension."""
         return self.matrix.shape[1]

     def _default_normal_init(
         self,
         bavg: float = 0.0,
         stddev: float = 1.0,
         generator: Optional[torch.Generator] = None,
     ) -> None:
+        """Default normal initialization.
+
+        The weight matrix uses a Xavier-style scale: std = stddev / sqrt(fan_in + fan_out).
+        The bias uses a normal distribution: mean=bavg, std=stddev.
+        The timestep parameter uses a small-variance normal distribution: mean=0.1, std=0.001.
+        """
         normal_(
             self.matrix.data,
             std=stddev / np.sqrt(self.num_out + self.num_in),
@@ -164,7 +231,12 @@ def _default_normal_init(
     def _trunc_normal_init(
         self, scale=1.0, generator: Optional[torch.Generator] = None
     ) -> None:
-        # Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
+        """Truncated normal initialization.
+
+        Initializes the weight matrix from a truncated normal distribution,
+        which helps avoid extreme initial weights.
+        Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
+        """
+        # stddev factor of the truncated normal distribution (from scipy.stats.truncnorm)
         TRUNCATED_NORMAL_STDDEV_FACTOR = 0.87962566103423978
         _, fan_in = self.matrix.shape
         scale = scale / max(1, fan_in)
@@ -172,9 +244,15 @@ def _trunc_normal_init(
         trunc_normal_(self.matrix, mean=0.0, std=std, generator=generator)

     def _glorot_uniform_init(self, generator: Optional[torch.Generator] = None) -> None:
+        """Glorot uniform initialization (Xavier uniform)."""
         xavier_uniform_(self.matrix, gain=1, generator=generator)

     def _zero_init(self, use_bias=True) -> None:
+        """Zero initialization.
+
+        The weight matrix is initialized to 0; the bias is initialized to 1
+        (if use_bias=True). Used for gating mechanisms or final layers.
+        """
         with torch.no_grad():
             self.matrix.fill_(0.0)
             if use_bias and self.bias is not None:
@@ -182,126 +260,201 @@ def _zero_init(self, use_bias=True) -> None:
                 self.bias.fill_(1.0)

     def _normal_init(self, generator: Optional[torch.Generator] = None) -> None:
+        """Kaiming normal initialization.
+
+        Suited for linear activations; helps preserve the variance of the forward pass.
+        """
         kaiming_normal_(self.matrix, nonlinearity="linear", generator=generator)

     def forward(
         self,
         xx: torch.Tensor,
     ) -> torch.Tensor:
-        """One MLP layer used by DP model.
+        """Forward pass of the MLP layer.
+
+        Computation flow:
+        1. linear transform: y = x W + b
+        2. activation: y = activation(y)
+        3. timestep scaling: y = y * idt (if enabled)
+        4. residual connection: y = y + x (if enabled)

         Parameters
         ----------
         xx : torch.Tensor
-            The input.
+            The input tensor of shape [..., num_in].

         Returns
         -------
         yy: torch.Tensor
-            The output.
+            The output tensor of shape [..., num_out].
         """
+        # remember the original precision so it can be restored afterwards
         ori_prec = xx.dtype
+
+        # precision cast (when allowed)
         if not env.DP_DTYPE_PROMOTION_STRICT:
             xx = xx.to(self.prec)
+
+        # 1. linear transform: y = x W + b
+        # note: matrix.t() is passed because PyTorch's F.linear expects the
+        # weight as [out_features, in_features]
         yy = F.linear(xx, self.matrix.t(), self.bias)
+
+        # 2. activation
         yy = self.activate(yy)
+
+        # 3. timestep scaling (timestep control in ResNet)
         yy = yy * self.idt if self.idt is not None else yy
+
+        # 4. residual connection (ResNet skip connection)
         if self.resnet:
             if xx.shape[-1] == yy.shape[-1]:
+                # matching dimensions: add directly
                 yy = yy + xx
             elif 2 * xx.shape[-1] == yy.shape[-1]:
+                # output dim is twice the input dim: duplicate the input, then add
                 yy = yy + torch.concat([xx, xx], dim=-1)
             else:
+                # dimension mismatch: no residual connection
                 yy = yy
+
+        # restore the original precision
         if not env.DP_DTYPE_PROMOTION_STRICT:
             yy = yy.to(ori_prec)
         return yy
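# --- editor's sketch (illustration only, not part of the patch) ---------------
# The forward pass documented above, written in plain torch with illustrative
# (non-API) names and a tanh activation: y = act(x @ W + b) * idt + x when
# num_out == num_in.
import torch
import torch.nn.functional as F

x = torch.randn(5, 16)
W = torch.randn(16, 16) / 16 ** 0.5     # [num_in, num_out]
b = torch.zeros(16)
idt = torch.full((16,), 0.1)            # timestep parameter

y = torch.tanh(F.linear(x, W.t(), b))   # linear + activation (F.linear wants [out, in])
y = y * idt                             # timestep scaling
y = y + x                               # residual connection
# -------------------------------------------------------------------------------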
     def serialize(self) -> dict:
-        """Serialize the layer to a dict.
+        """Serialize the layer parameters to a dict.
+
+        All parameters of the MLP layer (weight, bias, timestep) are
+        serialized to a dict for model saving and loading.

         Returns
         -------
         dict
-            The serialized layer.
+            The serialized layer parameters.
         """
+        # create a NativeLayer object carrying the layer's metadata
         nl = NativeLayer(
-            self.matrix.shape[0],
-            self.matrix.shape[1],
-            bias=self.bias is not None,
-            use_timestep=self.idt is not None,
-            activation_function=self.activate_name,
-            resnet=self.resnet,
-            precision=self.precision,
+            self.matrix.shape[0],  # input dimension
+            self.matrix.shape[1],  # output dimension
+            bias=self.bias is not None,  # whether a bias exists
+            use_timestep=self.idt is not None,  # whether a timestep is used
+            activation_function=self.activate_name,  # activation function name
+            resnet=self.resnet,  # whether a residual connection is used
+            precision=self.precision,  # numerical precision
         )
+
+        # convert the PyTorch tensors to numpy arrays and assign them
         nl.w, nl.b, nl.idt = (
-            to_numpy_array(self.matrix),
-            to_numpy_array(self.bias),
-            to_numpy_array(self.idt),
+            to_numpy_array(self.matrix),  # weight matrix
+            to_numpy_array(self.bias),  # bias vector
+            to_numpy_array(self.idt),  # timestep parameter
         )
         return nl.serialize()

     @classmethod
     def deserialize(cls, data: dict) -> "MLPLayer":
-        """Deserialize the layer from a dict.
+        """Deserialize the layer parameters from a dict.
+
+        Restores all parameters of the MLP layer from the serialized dict.

         Parameters
         ----------
         data : dict
-            The dict to deserialize from.
+            The dict holding the serialized layer parameters.
+
+        Returns
+        -------
+        MLPLayer
+            The restored MLP layer instance.
         """
+        # deserialize the NativeLayer from the dict
         nl = NativeLayer.deserialize(data)
+
+        # create the MLPLayer instance
         obj = cls(
-            nl["matrix"].shape[0],
-            nl["matrix"].shape[1],
-            bias=nl["bias"] is not None,
-            use_timestep=nl["idt"] is not None,
-            activation_function=nl["activation_function"],
-            resnet=nl["resnet"],
-            precision=nl["precision"],
+            nl["matrix"].shape[0],  # input dimension
+            nl["matrix"].shape[1],  # output dimension
+            bias=nl["bias"] is not None,  # whether a bias exists
+            use_timestep=nl["idt"] is not None,  # whether a timestep is used
+            activation_function=nl["activation_function"],  # activation function
+            resnet=nl["resnet"],  # whether a residual connection is used
+            precision=nl["precision"],  # numerical precision
         )
+
+        # fetch the precision type
         prec = PRECISION_DICT[obj.precision]

         def check_load_param(ss):
+            """Check and load a parameter."""
             return (
                 nn.Parameter(data=to_torch_tensor(nl[ss]))
                 if nl[ss] is not None
                 else None
             )

-        obj.matrix = check_load_param("matrix")
-        obj.bias = check_load_param("bias")
-        obj.idt = check_load_param("idt")
+        # load all parameters
+        obj.matrix = check_load_param("matrix")  # weight matrix
+        obj.bias = check_load_param("bias")  # bias vector
+        obj.idt = check_load_param("idt")  # timestep parameter
         return obj


+# =============================================================================
+# Multi-layer network definitions
+# =============================================================================
+
+# create the multi-layer MLP network via make_multilayer_network
 MLP_ = make_multilayer_network(MLPLayer, nn.Module)


 class MLP(MLP_):
+    """Multi-layer perceptron (MLP) network.
+
+    A deep network composed of several MLPLayer instances, used for complex
+    non-linear mappings. In the DPA3 model it serves:
+    - the message-passing networks in the RepFlow layers
+    - the energy/force prediction networks in the fitting layer
+    """
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
+        # convert the layer list to a PyTorch ModuleList so the parameters are registered
         self.layers = torch.nn.ModuleList(self.layers)

-    forward = MLP_.call
+    forward = MLP_.call  # use the parent class's forward implementation


+# =============================================================================
+# Specialized network types
+# =============================================================================

+# embedding network: maps atom-type IDs to high-dimensional type embeddings
 EmbeddingNet = make_embedding_network(MLP, MLPLayer)
+# fitting network: maps descriptors to atomic properties (energy, force, ...)
 FittingNet = make_fitting_network(EmbeddingNet, MLP, MLPLayer)


 class NetworkCollection(DPNetworkCollection, nn.Module):
-    """PyTorch implementation of NetworkCollection."""
-
+    """Network collection.
+
+    PyTorch implementation of NetworkCollection, managing multiple network
+    instances. In the DPA3 model it is used to:
+    - manage the fitting networks of different atom types
+    - manage the message-passing networks inside the RepFlow layers
+    - support both mixed-type and type-split network configurations
+    """
+
+    # network type mapping
     NETWORK_TYPE_MAP: ClassVar[dict[str, type]] = {
-        "network": MLP,
-        "embedding_network": EmbeddingNet,
-        "fitting_network": FittingNet,
+        "network": MLP,  # generic MLP network
+        "embedding_network": EmbeddingNet,  # embedding network
+        "fitting_network": FittingNet,  # fitting network
     }

     def __init__(self, *args, **kwargs) -> None:
-        # init both two base classes
+        # initialize both base classes
         DPNetworkCollection.__init__(self, *args, **kwargs)
         nn.Module.__init__(self)
+        # convert the network list to a PyTorch ModuleList so the parameters are registered
         self.networks = self._networks = torch.nn.ModuleList(self._networks)
diff --git a/deepmd/pt/model/network/utils.py b/deepmd/pt/model/network/utils.py
index 2047efec2b..53b9befe70 100644
--- a/deepmd/pt/model/network/utils.py
+++ b/deepmd/pt/model/network/utils.py
@@ -14,26 +14,39 @@ def aggregate(
     num_owner: Optional[int] = None,
 ) -> torch.Tensor:
     """
-    Aggregate rows in data by specifying the owners.
+    Aggregate rows of data according to their owner indices.
+
+    In the RepFlow layers of DPA3 this is used to aggregate edge or angle
+    features onto the corresponding nodes, e.g. pooling the features of all
+    edges onto their central atoms.

     Parameters
     ----------
-    data : data tensor to aggregate [n_row, feature_dim]
-    owners : specify the owner of each row [n_row, 1]
-    average : if True, average the rows, if False, sum the rows.
-        Default = True
-    num_owner : the number of owners, this is needed if the
-        max idx of owner is not presented in owners tensor
-        Default = None
+    data : torch.Tensor
+        The data tensor to aggregate [n_row, feature_dim],
+        e.g. edge features [n_edge, e_dim] or angle features [n_angle, a_dim].
+    owners : torch.Tensor
+        The owner index of each row [n_row],
+        e.g. the central-atom index of each edge.
+    average : bool, optional
+        If True, average the rows; if False, sum them.
+        Default = True
+    num_owner : Optional[int], optional
+        The number of owners; required when the maximum owner index does not
+        appear in the owners tensor.
+        Default = None

     Returns
     -------
-    output: [num_owner, feature_dim]
+    torch.Tensor
+        The aggregated output [num_owner, feature_dim],
+        e.g. aggregated node features [n_atoms, feature_dim].
     """
+    # count the number of rows per owner (needed for averaging)
     if num_owner is None or average:
-        # requires bincount
+        # bincount gives the number of rows owned by each owner
         bin_count = torch.bincount(owners)
+        # avoid division by zero: replace 0 with 1
         bin_count = bin_count.where(bin_count != 0, bin_count.new_ones(1))
+
+        # pad bin_count if the requested num_owner exceeds its length
         if (num_owner is not None) and (bin_count.shape[0] != num_owner):
             difference = num_owner - bin_count.shape[0]
             bin_count = torch.cat([bin_count, bin_count.new_ones(difference)])
@@ -42,8 +55,11 @@ def aggregate(
     else:
         bin_count = None

+    # initialize the output tensor
     output = data.new_zeros([num_owner, data.shape[1]])
+    # accumulate the data per owner with index_add_
     output = output.index_add_(0, owners, data)
+    # if averaging is requested, divide by the per-owner row counts
     if average:
         assert bin_count is not None
         output = (output.T / bin_count).T
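# --- editor's sketch (illustration only, not part of the patch) ---------------
# What `aggregate` computes is a scatter-mean: rows are summed per owner with
# index_add_ and then divided by per-owner counts, mirroring the code above.
import torch

data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # [n_row, feature_dim]
owners = torch.tensor([0, 0, 1])                           # owner of each row

out = torch.zeros(2, 2).index_add_(0, owners, data)        # per-owner sums
counts = torch.bincount(owners, minlength=2).clamp(min=1)
mean = (out.T / counts).T                                  # [[2., 3.], [5., 6.]]
# -------------------------------------------------------------------------------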
@@ -59,85 +75,116 @@ def get_graph_index(
     nlist: torch.Tensor,
     nlist_mask: torch.Tensor,
     a_nlist_mask: torch.Tensor,
     nall: int,
     use_loc_mapping: bool = True,
 ):
     """
-    Get the index mapping for edge graph and angle graph, ready in `aggregate` or `index_select`.
+    Get the index mappings for the edge graph and the angle graph, ready for
+    `aggregate` or `index_select`.
+
+    In the dynamic-selection mode of DPA3 this function builds the index
+    mappings required by the graph neural network:
+    1. edge-graph indices: edges connecting a central atom to its neighbors
+    2. angle-graph indices: angles formed by a central atom and two neighbors
+
+    These indices enable efficient message passing and feature aggregation.

     Parameters
     ----------
-    nlist : nf x nloc x nnei
-        Neighbor list. (padded neis are set to 0)
-    nlist_mask : nf x nloc x nnei
-        Masks of the neighbor list. real nei 1 otherwise 0
-    a_nlist_mask : nf x nloc x a_nnei
-        Masks of the neighbor list for angle. real nei 1 otherwise 0
-    nall
-        The number of extended atoms.
+    nlist : torch.Tensor
+        The neighbor list [nf, nloc, nnei]; padded neighbors are set to 0.
+    nlist_mask : torch.Tensor
+        Mask of the neighbor list [nf, nloc, nnei]; 1 for real neighbors, 0 otherwise.
+    a_nlist_mask : torch.Tensor
+        Mask of the angular neighbor list [nf, nloc, a_nnei]; 1 for real
+        neighbors used in angle computation, 0 otherwise.
+    nall : int
+        The number of extended atoms.
+    use_loc_mapping : bool, optional
+        Whether to use local index mapping. Default = True

     Returns
     -------
-    edge_index : n_edge x 2
-        n2e_index : n_edge
-            Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i).
-        n_ext2e_index : n_edge
-            Broadcast indices from extended node(j) to edge(ij).
-    angle_index : n_angle x 3
-        n2a_index : n_angle
-            Broadcast indices from extended node(j) to angle(ijk).
-        eij2a_index : n_angle
-            Broadcast indices from extended edge(ij) to angle(ijk), or reduction indices from angle(ijk) to edge(ij).
-        eik2a_index : n_angle
-            Broadcast indices from extended edge(ik) to angle(ijk).
+    edge_index : torch.Tensor
+        Edge-graph indices [n_edge, 2]:
+        - n2e_index: broadcast indices from node(i) to edge(ij), or reduction
+          indices from edge(ij) to node(i)
+        - n_ext2e_index: broadcast indices from extended node(j) to edge(ij)
+    angle_index : torch.Tensor
+        Angle-graph indices [n_angle, 3]:
+        - n2a_index: broadcast indices from node(i) to angle(ijk)
+        - eij2a_index: broadcast indices from edge(ij) to angle(ijk), or
+          reduction indices from angle(ijk) to edge(ij)
+        - eik2a_index: broadcast indices from edge(ik) to angle(ijk)
     """
+    # fetch tensor dimensions
     nf, nloc, nnei = nlist.shape
     _, _, a_nnei = a_nlist_mask.shape
-    # nf x nloc x nnei x nnei
-    # nlist_mask_3d = nlist_mask[:, :, :, None] & nlist_mask[:, :, None, :]
+
+    # build the 3D angle mask of shape a_nnei x a_nnei:
+    # only pairs in which both neighbors exist form a valid angle
     a_nlist_mask_3d = a_nlist_mask[:, :, :, None] & a_nlist_mask[:, :, None, :]
+
+    # count the valid edges
     n_edge = nlist_mask.sum().item()
-    # n_angle = a_nlist_mask_3d.sum().item()
-
-    # following: get n2e_index, n_ext2e_index, n2a_index, eij2a_index, eik2a_index
-
-    # 1. atom graph
-    # node(i) to edge(ij) index_select; edge(ij) to node aggregate
+    # n_angle = a_nlist_mask_3d.sum().item()  # unused; kept for reference
+
+    # =============================================================================
+    # 1. Build the edge-graph indices (atom graph)
+    # =============================================================================
+
+    # 1.1 node(i) to edge(ij) index mapping
+    # create flat local-atom indices: every local atom of every frame gets a unique index
     nlist_loc_index = torch.arange(0, nf * nloc, dtype=nlist.dtype, device=nlist.device)
-    # nf x nloc x nnei
+    # expand to [nf, nloc, nnei]: every neighbor maps to the same central atom
     n2e_index = nlist_loc_index.reshape(nf, nloc, 1).expand(-1, -1, nnei)
-    # n_edge
-    n2e_index = n2e_index[nlist_mask]  # graph node index, atom_graph[:, 0]
+    # keep only the indices of real neighbors
+    n2e_index = n2e_index[nlist_mask]  # shape: [n_edge]

-    # node_ext(j) to edge(ij) index_select
+    # 1.2 extended node(j) to edge(ij) index mapping
+    # compute per-frame offsets: atom indices of each frame are shifted by a frame offset
     frame_shift = torch.arange(0, nf, dtype=nlist.dtype, device=nlist.device) * (
         nall if not use_loc_mapping else nloc
     )
+    # convert the neighbor list to global indices
     shifted_nlist = nlist + frame_shift[:, None, None]
-    # n_edge
-    n_ext2e_index = shifted_nlist[nlist_mask]  # graph neighbor index, atom_graph[:, 1]
-
-    # 2. edge graph
-    # node(i) to angle(ijk) index_select
+    # keep only the indices of real neighbors
+    n_ext2e_index = shifted_nlist[nlist_mask]  # shape: [n_edge]
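# --- editor's sketch (illustration only, not part of the patch) ---------------
# The edge-graph indexing above on a toy neighbor list with nf=1, nloc=2,
# nnei=2, nall=3: every valid (i, j) pair becomes one edge, with i as a flat
# local-atom index and j shifted into the per-frame extended-atom range.
import torch

nf, nloc, nnei, nall = 1, 2, 2, 3
nlist = torch.tensor([[[1, 2], [0, 0]]])                   # padded entries are 0
nlist_mask = torch.tensor([[[True, True], [True, False]]])

n2e_index = torch.arange(nf * nloc).reshape(nf, nloc, 1).expand(-1, -1, nnei)
n2e_index = n2e_index[nlist_mask]                          # tensor([0, 0, 1])

frame_shift = torch.arange(nf) * nall
n_ext2e_index = (nlist + frame_shift[:, None, None])[nlist_mask]  # tensor([1, 2, 0])
# -------------------------------------------------------------------------------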
+    # =============================================================================
+    # 2. Build the angle-graph indices (angle graph)
+    # =============================================================================
+
+    # 2.1 node(i) to angle(ijk) index mapping
+    # expand to [nf, nloc, a_nnei, a_nnei]
     n2a_index = nlist_loc_index.reshape(nf, nloc, 1, 1).expand(-1, -1, a_nnei, a_nnei)
-    # n_angle
-    n2a_index = n2a_index[a_nlist_mask_3d]
+    # keep only the indices of valid angles
+    n2a_index = n2a_index[a_nlist_mask_3d]  # shape: [n_angle]

-    # edge(ij) to angle(ijk) index_select; angle(ijk) to edge(ij) aggregate
+    # 2.2 edge(ij) to angle(ijk) index mapping
+    # assign a unique ID to every edge
     edge_id = torch.arange(0, n_edge, dtype=nlist.dtype, device=nlist.device)
-    # nf x nloc x nnei
+    # create an edge-index tensor with the same shape as nlist
     edge_index = torch.zeros([nf, nloc, nnei], dtype=nlist.dtype, device=nlist.device)
     edge_index[nlist_mask] = edge_id
-    # only cut a_nnei neighbors, to avoid nnei x nnei
+    # keep only the first a_nnei neighbors, to avoid nnei x nnei complexity
     edge_index = edge_index[:, :, :a_nnei]
+
+    # 2.3 edge(ij) to angle(ijk) indices: the j edge
     edge_index_ij = edge_index.unsqueeze(-1).expand(-1, -1, -1, a_nnei)
-    # n_angle
-    eij2a_index = edge_index_ij[a_nlist_mask_3d]
-
-    # edge(ik) to angle(ijk) index_select
+    eij2a_index = edge_index_ij[a_nlist_mask_3d]  # shape: [n_angle]
+
+    # 2.4 edge(ik) to angle(ijk) indices: the k edge
     edge_index_ik = edge_index.unsqueeze(-2).expand(-1, -1, a_nnei, -1)
-    # n_angle
-    eik2a_index = edge_index_ik[a_nlist_mask_3d]
-
-    return torch.cat(
+    eik2a_index = edge_index_ik[a_nlist_mask_3d]  # shape: [n_angle]
+
+    # =============================================================================
+    # 3. Assemble and return the index tensors
+    # =============================================================================
+
+    # edge-graph indices: [n_edge, 2] - [n2e_index, n_ext2e_index]
+    edge_index = torch.cat(
         [n2e_index.unsqueeze(-1), n_ext2e_index.unsqueeze(-1)], dim=-1
-    ), torch.cat(
+    )
+
+    # angle-graph indices: [n_angle, 3] - [n2a_index, eij2a_index, eik2a_index]
+    angle_index = torch.cat(
         [n2a_index.unsqueeze(-1), eij2a_index.unsqueeze(-1), eik2a_index.unsqueeze(-1)],
         dim=-1,
     )
+
+    return edge_index, angle_index
diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index 7a6ff0ebde..a3e9584fa3 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -31,6 +31,7 @@
     EnergyHessianStdLoss,
     EnergySpinLoss,
     EnergyStdLoss,
+    EnergyStdLossMAD,  # newly added
     PropertyLoss,
     TaskLoss,
     TensorLoss,
@@ -1178,7 +1179,12 @@ def print_on_training(
             if valid_results:
                 prop_fmt = "   %11.2e %11.2e"
                 for k in train_keys:
-                    print_str += prop_fmt % (valid_results[k], train_results[k])
+                    if k in valid_results:  # only print keys that also exist in the validation results
+                        print_str += prop_fmt % (valid_results[k], train_results[k])
+                    else:
+                        # the key is missing from the validation results: show
+                        # "N/A" in the validation column and the training value
+                        # in the training column
+                        prop_fmt_single = "   %11s %11.2e"
+                        print_str += prop_fmt_single % ("N/A", train_results[k])
             else:
                 prop_fmt = "   %11.2e"
                 for k in train_keys:
@@ -1241,6 +1247,9 @@ def get_loss(loss_params, start_lr, _ntypes, _model):
     elif loss_type == "ener":
         loss_params["starter_learning_rate"] = start_lr
         return EnergyStdLoss(**loss_params)
+    elif loss_type == "ener_mad":
+        loss_params["starter_learning_rate"] = start_lr
+        return EnergyStdLossMAD(**loss_params)
     elif loss_type == "dos":
         loss_params["starter_learning_rate"] = start_lr
         loss_params["numb_dos"] = _model.model_output_def()["dos"].output_size
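# --- editor's sketch (illustration only, not part of the patch) ---------------
# The quantity behind the `ener_mad` loss registered above: MAD (Mean Average
# Distance) denotes the mean pairwise cosine distance between node embeddings;
# values near 0 indicate over-smoothing. This illustrates the idea only and is
# not the descriptor's actual implementation (which additionally distinguishes
# neighbor and remote atoms via `mad_cutoff_ratio`):
import torch

node_ebd = torch.randn(8, 16)                          # [n_atoms, n_dim]
normed = torch.nn.functional.normalize(node_ebd, dim=-1)
cos_dist = 1.0 - normed @ normed.T                     # pairwise cosine distances
off_diag = ~torch.eye(8, dtype=torch.bool)
mad_value = cos_dist[off_diag].mean()                  # pulled toward 1.0 by the loss
# -------------------------------------------------------------------------------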
diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py
index fb911550dd..e782170d4f 100644
--- a/deepmd/utils/argcheck.py
+++ b/deepmd/utils/argcheck.py
@@ -1355,7 +1355,7 @@ def dpa2_repformer_args():
         ),
     ]

-
+# modified here: the dpa3 descriptor gains MADGap arguments below
 @descrpt_args_plugin.register("dpa3", doc=doc_only_pt_supported)
 def descrpt_dpa3_args():
     # repflow args
@@ -1376,6 +1376,8 @@ def descrpt_dpa3_args():
         "Whether to use local atom index mapping in training or non-parallel inference. "
         "When True, local indexing and mapping are applied to neighbor lists and embeddings during descriptor computation."
     )
+    doc_enable_mad = "Whether to enable MADGap computation. Set to True to compute MADGap values for regularization use."
+    doc_mad_cutoff_ratio = "The ratio used to distinguish neighbor and remote nodes in the MADGap calculation."
     return [
         # doc_repflow args
         Argument("repflow", dict, dpa3_repflow_args(), doc=doc_repflow),
@@ -1432,6 +1434,21 @@ def descrpt_dpa3_args():
             default=True,
             doc=doc_use_loc_mapping,
         ),
+        # MADGap computation arguments
+        Argument(
+            "enable_mad",
+            bool,
+            optional=True,
+            default=False,
+            doc=doc_enable_mad,
+        ),
+        Argument(
+            "mad_cutoff_ratio",
+            float,
+            optional=True,
+            default=0.5,
+            doc=doc_mad_cutoff_ratio,
+        ),
     ]


@@ -2653,6 +2670,174 @@ def loss_ener():
     ]


+@loss_args_plugin.register("ener_mad")
+def loss_ener_mad():
+    doc_start_pref_e = start_pref("energy", abbr="e")
+    doc_limit_pref_e = limit_pref("energy")
+    doc_start_pref_f = start_pref("force", abbr="f")
+    doc_limit_pref_f = limit_pref("force")
+    doc_start_pref_v = start_pref("virial", abbr="v")
+    doc_limit_pref_v = limit_pref("virial")
+    doc_start_pref_h = start_pref("hessian", abbr="h")  # prefactor of hessian
+    doc_limit_pref_h = limit_pref("hessian")
+    doc_start_pref_ae = start_pref("atomic energy", label="atom_ener", abbr="ae")
+    doc_limit_pref_ae = limit_pref("atomic energy")
+    doc_start_pref_pf = start_pref(
+        "atomic prefactor force", label="atom_pref", abbr="pf"
+    )
+    doc_limit_pref_pf = limit_pref("atomic prefactor force")
+    doc_start_pref_gf = start_pref("generalized force", label="drdq", abbr="gf")
+    doc_limit_pref_gf = limit_pref("generalized force")
+    doc_numb_generalized_coord = "The dimension of generalized coordinates. Required when generalized force loss is used."
+    doc_relative_f = "If provided, relative force error will be used in the loss. The difference of force will be normalized by the magnitude of the force in the label with a shift given by `relative_f`, i.e. DF_i / ( || F || + relative_f ) with DF denoting the difference between prediction and label and || F || denoting the L2 norm of the label."
+    doc_enable_atom_ener_coeff = "If true, the energy will be computed as \\sum_i c_i E_i. c_i should be provided by file atom_ener_coeff.npy in each data system, otherwise it's 1."
+    doc_use_huber = (
+        "Enables Huber loss calculation for energy/force/virial terms with user-defined threshold delta (D). "
+        "The loss function smoothly transitions between L2 and L1 loss: \n\n"
+        "- For absolute prediction errors within D: quadratic loss 0.5 * (error**2) \n\n"
+        "- For absolute errors exceeding D: linear loss D * (\\|error\\| - 0.5 * D) \n\n"
+        "Formula: loss = 0.5 * (error**2) if \\|error\\| <= D else D * (\\|error\\| - 0.5 * D). "
+    )
+    doc_huber_delta = "The threshold delta (D) used for Huber loss, controlling transition between L2 and L1 loss."
+    doc_mad_reg_coeff = "The coefficient for MADGap regularization. Set to 0.0 to disable it. MADGap is a regularization method that helps improve representation learning by encouraging the model to distinguish between neighbor and remote atoms in the embedding space."
+    return [
+        Argument(
+            "start_pref_e",
+            [float, int],
+            optional=True,
+            default=0.02,
+            doc=doc_start_pref_e,
+        ),
+        Argument(
+            "limit_pref_e",
+            [float, int],
+            optional=True,
+            default=1.00,
+            doc=doc_limit_pref_e,
+        ),
+        Argument(
+            "start_pref_f",
+            [float, int],
+            optional=True,
+            default=1000,
+            doc=doc_start_pref_f,
+        ),
+        Argument(
+            "limit_pref_f",
+            [float, int],
+            optional=True,
+            default=1.00,
+            doc=doc_limit_pref_f,
+        ),
+        Argument(
+            "start_pref_v",
+            [float, int],
+            optional=True,
+            default=0.00,
+            doc=doc_start_pref_v,
+        ),
+        Argument(
+            "limit_pref_v",
+            [float, int],
+            optional=True,
+            default=0.00,
+            doc=doc_limit_pref_v,
+        ),
+        Argument(
+            "start_pref_h",
+            [float, int],
+            optional=True,
+            default=0.00,
+            doc=doc_start_pref_h,
+        ),
+        Argument(
+            "limit_pref_h",
+            [float, int],
+            optional=True,
+            default=0.00,
+            doc=doc_limit_pref_h,
+        ),
+        Argument(
+            "start_pref_ae",
+            [float, int],
+            optional=True,
+            default=0.00,
+            doc=doc_start_pref_ae,
+        ),
+        Argument(
+            "limit_pref_ae",
+            [float, int],
+            optional=True,
+            default=0.00,
+            doc=doc_limit_pref_ae,
+        ),
+        Argument(
+            "start_pref_pf",
+            [float, int],
+            optional=True,
+            default=0.00,
+            doc=doc_start_pref_pf,
+        ),
+        Argument(
+            "limit_pref_pf",
+            [float, int],
+            optional=True,
+            default=0.00,
+            doc=doc_limit_pref_pf,
+        ),
+        Argument("relative_f", [float, None], optional=True, doc=doc_relative_f),
+        Argument(
+            "enable_atom_ener_coeff",
+            [bool],
+            optional=True,
+            default=False,
+            doc=doc_enable_atom_ener_coeff,
+        ),
+        Argument(
+            "start_pref_gf",
+            float,
+            optional=True,
+            default=0.0,
+            doc=doc_start_pref_gf,
+        ),
+        Argument(
+            "limit_pref_gf",
+            float,
+            optional=True,
+            default=0.0,
+            doc=doc_limit_pref_gf,
+        ),
+        Argument(
+            "numb_generalized_coord",
+            int,
+            optional=True,
+            default=0,
+            doc=doc_numb_generalized_coord,
+        ),
+        Argument(
+            "use_huber",
+            bool,
+            optional=True,
+            default=False,
+            doc=doc_use_huber,
+        ),
+        Argument(
+            "huber_delta",
+            float,
+            optional=True,
+            default=0.01,
+            doc=doc_huber_delta,
+        ),
+        Argument(
+            "mad_reg_coeff",
+            float,
+            optional=True,
+            default=0.0,
+            doc=doc_mad_reg_coeff,
+        ),
+    ]
+
+
 @loss_args_plugin.register("ener_spin")
 def loss_ener_spin():
     doc_start_pref_e = start_pref("energy")
diff --git a/examples/water/dpa3/input_torch.json b/examples/water/dpa3/input_torch.json
index 90e81b5403..1a8da9304a 100644
--- a/examples/water/dpa3/input_torch.json
+++ b/examples/water/dpa3/input_torch.json
@@ -85,7 +85,7 @@
             "batch_size": 1,
             "_comment": "that's all"
         },
-        "numb_steps": 1000000,
+        "numb_steps": 200,
         "warmup_steps": 0,
         "gradient_max_norm": 5.0,
         "seed": 10,
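As a closing usage sketch, enabling the new MAD machinery end to end would combine
the descriptor flag with the new loss type roughly as follows (field values are
illustrative, not recommendations from this patch):

{
    "model": {
        "descriptor": {
            "type": "dpa3",
            "enable_mad": true,
            "mad_cutoff_ratio": 0.5
        }
    },
    "loss": {
        "type": "ener_mad",
        "start_pref_e": 0.02,
        "limit_pref_e": 1.0,
        "start_pref_f": 1000,
        "limit_pref_f": 1.0,
        "mad_reg_coeff": 0.01
    }
}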