mad #4983 (Closed)

2 changes: 2 additions & 0 deletions deepmd/pt/loss/__init__.py
@@ -8,6 +8,7 @@
from .ener import (
EnergyHessianStdLoss,
EnergyStdLoss,
EnergyStdLossMAD,
)
from .ener_spin import (
EnergySpinLoss,
@@ -28,6 +29,7 @@
"EnergyHessianStdLoss",
"EnergySpinLoss",
"EnergyStdLoss",
"EnergyStdLossMAD",
"PropertyLoss",
"TaskLoss",
"TensorLoss",
152 changes: 152 additions & 0 deletions deepmd/pt/loss/ener.py
@@ -595,3 +595,155 @@ def label_requirement(self) -> list[DataRequirementItem]:
)
)
return label_requirement

# newly added: energy loss with MAD regularization
class EnergyStdLossMAD(EnergyStdLoss):
def __init__(
self,
starter_learning_rate=1.0,
start_pref_e=0.0,
limit_pref_e=0.0,
start_pref_f=0.0,
limit_pref_f=0.0,
start_pref_v=0.0,
limit_pref_v=0.0,
start_pref_ae: float = 0.0,
limit_pref_ae: float = 0.0,
start_pref_pf: float = 0.0,
limit_pref_pf: float = 0.0,
relative_f: Optional[float] = None,
enable_atom_ener_coeff: bool = False,
start_pref_gf: float = 0.0,
limit_pref_gf: float = 0.0,
numb_generalized_coord: int = 0,
use_l1_all: bool = False,
inference=False,
use_huber=False,
huber_delta=0.01,
mad_reg_coeff: float = 0.0,  # newly added
**kwargs,
) -> None:
r"""Construct a layer to compute loss on energy, force and virial with MAD regularization.

Parameters
----------
mad_reg_coeff : float
The coefficient for MAD (Mean Average Distance) regularization. Set to 0.0 to disable it; a positive value requires the descriptor to be constructed with enable_mad=True so that its last_mad_gap attribute is populated.
**kwargs
Other keyword arguments passed to EnergyStdLoss.
"""
super().__init__(
starter_learning_rate=starter_learning_rate,
start_pref_e=start_pref_e,
limit_pref_e=limit_pref_e,
start_pref_f=start_pref_f,
limit_pref_f=limit_pref_f,
start_pref_v=start_pref_v,
limit_pref_v=limit_pref_v,
start_pref_ae=start_pref_ae,
limit_pref_ae=limit_pref_ae,
start_pref_pf=start_pref_pf,
limit_pref_pf=limit_pref_pf,
relative_f=relative_f,
enable_atom_ener_coeff=enable_atom_ener_coeff,
start_pref_gf=start_pref_gf,
limit_pref_gf=limit_pref_gf,
numb_generalized_coord=numb_generalized_coord,
use_l1_all=use_l1_all,
inference=inference,
use_huber=use_huber,
huber_delta=huber_delta,
**kwargs,
)
self.mad_reg_coeff = mad_reg_coeff

def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
"""Return loss on energy and force with MAD regularization.

Parameters
----------
input_dict : dict[str, torch.Tensor]
Model inputs.
model : torch.nn.Module
Model to be used to output the predictions.
label : dict[str, torch.Tensor]
Labels.
natoms : int
The local atom number.

Returns
-------
model_pred: dict[str, torch.Tensor]
Model predictions.
loss: torch.Tensor
Loss for model to minimize.
more_loss: dict[str, torch.Tensor]
Other losses for display.
"""
# call the parent class to obtain the base loss
model_pred, loss, more_loss = super().forward(
input_dict, model, label, natoms, learning_rate, mae=mae
)
# add the MAD regularization term on top of the base loss
if self.mad_reg_coeff > 0:
descriptor = model.get_descriptor()
# the attribute is still named last_mad_gap, but it now stores the MAD value
mad_value = descriptor.last_mad_gap

# target MAD value: a cosine distance of 1 means orthogonal embeddings, which is the desired state
target_mad = 1.0

# option 1 (recommended): target-MAD regularization, encouraging MAD to approach the target value
mad_reg_loss = self.mad_reg_coeff * torch.abs(mad_value - target_mad)

# option 2 (optional): prevent over-smoothing by only penalizing MAD values that are too small
# min_mad = 0.5
# mad_reg_loss = self.mad_reg_coeff * torch.relu(min_mad - mad_value)

loss += mad_reg_loss

# always record the MAD-related loss terms (needed for both training and validation)
more_loss["mad_reg_loss"] = self.display_if_exist(
mad_reg_loss.detach(), 1.0
)
more_loss["mad_value"] = self.display_if_exist(
mad_value.detach(), 1.0
)

return model_pred, loss, more_loss

def serialize(self) -> dict:
"""Serialize the loss module.

Returns
-------
dict
The serialized loss module
"""
data = super().serialize()
data.update({
"@class": "EnergyLossMAD",
"mad_reg_coeff": self.mad_reg_coeff,
})
return data

@classmethod
def deserialize(cls, data: dict) -> "TaskLoss":
"""Deserialize the loss module.

Parameters
----------
data : dict
The serialized loss module

Returns
-------
Loss
The deserialized loss module
"""
data = data.copy()
check_version_compatibility(data.pop("@version"), 2, 1)
data.pop("@class")
return cls(**data)
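
For reference, a minimal standalone sketch of the penalty this loss adds, with random embeddings standing in for real descriptor output (the function name mad_penalty and the shapes are illustrative, not part of deepmd): MAD is the mean pairwise cosine distance of the node embeddings, and the regularization term is mad_reg_coeff * |MAD - target_mad|.

import torch
import torch.nn.functional as F

def mad_penalty(node_ebd: torch.Tensor, mad_reg_coeff: float = 0.1, target_mad: float = 1.0) -> torch.Tensor:
    # node_ebd: [nframes, nloc, embed_dim]; returns mad_reg_coeff * |MAD - target_mad|
    nframes, nloc, _ = node_ebd.shape
    normed = F.normalize(node_ebd, p=2, dim=-1)
    cosine_dist = 1.0 - torch.bmm(normed, normed.transpose(-1, -2))
    # the diagonal is zero, so dividing by the number of off-diagonal pairs
    # gives the mean pairwise cosine distance (MAD)
    mad = cosine_dist.sum() / (nframes * nloc * (nloc - 1))
    return mad_reg_coeff * torch.abs(mad - target_mad)

emb = torch.randn(2, 8, 128)    # 2 frames, 8 atoms, 128-dim embeddings (illustrative)
print(mad_penalty(emb).item())  # near zero: random high-dimensional embeddings are almost orthogonal, so MAD is close to 1

A cosine distance near 1 means near-orthogonal embeddings, so pushing MAD toward 1 counteracts over-smoothing, where embeddings collapse toward one another.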
85 changes: 75 additions & 10 deletions deepmd/pt/model/descriptor/dpa3.py
Expand Up @@ -94,6 +94,10 @@ class DescrptDPA3(BaseDescriptor, torch.nn.Module):
When True, local indexing and mapping are applied to neighbor lists and embeddings during descriptor computation.
type_map : list[str], Optional
A list of strings. Give the name to each type of atoms.
enable_mad : bool, Optional
Whether to enable MAD (Mean Average Distance) computation. Set to True to compute MAD values for regularization use.
mad_cutoff_ratio : float, Optional
The ratio to distinguish neighbor and remote nodes for MAD calculation. (Reserved for future extensions)

References
----------
@@ -119,6 +123,8 @@
use_tebd_bias: bool = False,
use_loc_mapping: bool = True,
type_map: Optional[list[str]] = None,
enable_mad: bool = False,  # newly added
mad_cutoff_ratio: float = 0.5,  # newly added (reserved for future extensions)
) -> None:
super().__init__()

@@ -134,7 +140,7 @@ def init_subclass_params(sub_data, sub_class):

self.repflow_args = init_subclass_params(repflow, RepFlowArgs)
self.activation_function = activation_function

# the repflow blocks are defined here
self.repflows = DescrptBlockRepflows(
self.repflow_args.e_rcut,
self.repflow_args.e_rcut_smth,
@@ -211,6 +217,12 @@ def init_subclass_params(sub_data, sub_class):
param.requires_grad = trainable
self.compress = False

# store MAD-related parameters
self.enable_mad = enable_mad
self.mad_cutoff_ratio = mad_cutoff_ratio
# store the MAD value for use by the loss function (the name last_mad_gap is kept for compatibility with the loss)
self.last_mad_gap = None

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
return self.rcut
@@ -392,6 +404,8 @@ def serialize(self) -> dict:
"use_loc_mapping": self.use_loc_mapping,
"type_map": self.type_map,
"type_embedding": self.type_embedding.embedding.serialize(),
"enable_mad": self.enable_mad, # new added
"mad_cutoff_ratio": self.mad_cutoff_ratio,
}
repflow_variable = {
"edge_embd": repflows.edge_embd.serialize(),
@@ -492,8 +506,8 @@ def forward(
if not parallel_mode and self.use_loc_mapping:
node_ebd_ext = self.type_embedding(extended_atype[:, :nloc])
else:
node_ebd_ext = self.type_embedding(extended_atype)
node_ebd_inp = node_ebd_ext[:, :nloc, :]
node_ebd_ext = self.type_embedding(extended_atype)  # node embedding representation [nf, nall, tebd_dim]
node_ebd_inp = node_ebd_ext[:, :nloc, :]  # initial type embedding [nframes, nloc, n_dim] (n_dim=128)
# repflows
node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows(
nlist,
@@ -503,16 +517,67 @@ def forward(
mapping,
comm_dict=comm_dict,
)
if self.concat_output_tebd:
node_ebd = torch.cat([node_ebd, node_ebd_inp], dim=-1)
if self.concat_output_tebd:  # controls whether the initial type embedding is concatenated onto the output
node_ebd = torch.cat([node_ebd, node_ebd_inp], dim=-1)  # keep the original information: the initial atom-type features are not lost across the RepFlow layers
# Providing both the raw type features and the environment-learned features follows the residual-connection idea (similar to skip connections in ResNet), preventing deep networks from losing important base information.

# MAD computation (always performed when enabled, not only during training)
if self.enable_mad:
self.last_mad_gap = self._compute_mad(node_ebd)
else:
self.last_mad_gap = None
return (
node_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
edge_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
node_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),  # 1. node embedding representation [nframes, nloc, n_dim] (n_dim=128)
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),  # 2. rotationally equivariant matrix [nframes, nloc, e_dim, 3] (e_dim=128)
edge_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),  # 3. edge embedding representation [nframes, nloc, nnei, e_dim] (e_dim=128)
h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),  # 4. direction vectors [nframes, nloc, nnei, 3] (xyz)
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),  # 5. smooth switching function [nframes, nloc, nnei]
)

# simplified _compute_mad method (newly added)
def _compute_mad(self, node_ebd: torch.Tensor) -> torch.Tensor:
"""Compute the basic MAD (Mean Average Distance) used for regularization.

MAD measures the average cosine distance between node embeddings:
cosine distance = 1 - cosine similarity = 1 - (Hi · Hj) / (|Hi| · |Hj|)

Parameters
----------
node_ebd : torch.Tensor
Node embeddings of shape [nframes, nloc, embed_dim].

Returns
-------
torch.Tensor
The average cosine distance over all distinct node pairs.
"""
import torch.nn.functional as F

nframes, nloc, _ = node_ebd.shape
if nloc <= 1:
return torch.tensor(0.0, device=node_ebd.device)
node_ebd_norm = F.normalize(node_ebd, p=2, dim=-1)  # [nf, nloc, embed_dim]

# cosine similarity matrix [nf, nloc, nloc]
cosine_sim = torch.bmm(node_ebd_norm, node_ebd_norm.transpose(-1, -2))

# cosine distance = 1 - cosine similarity (dissimilar embeddings -> cosine_sim -> 0)
cosine_dist = 1.0 - cosine_sim

# average over all distinct node pairs; the diagonal (self-distance) is zero,
# so dividing the full sum by nframes * nloc * (nloc - 1) excludes self-pairs
mad_global = cosine_dist.sum() / (nframes * nloc * (nloc - 1))

return mad_global

@classmethod
def update_sel(
cls,
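
As a sanity check on the normalization in _compute_mad: the diagonal of the cosine-distance matrix is zero, so dividing the full sum by nframes * nloc * (nloc - 1) averages only over distinct pairs. Below is a hedged standalone comparison against an explicit double loop (a copy of the logic, not the class method itself):

import torch
import torch.nn.functional as F

def mad_vectorized(node_ebd: torch.Tensor) -> torch.Tensor:
    # same normalization as _compute_mad above
    nframes, nloc, _ = node_ebd.shape
    normed = F.normalize(node_ebd, p=2, dim=-1)
    dist = 1.0 - torch.bmm(normed, normed.transpose(-1, -2))
    return dist.sum() / (nframes * nloc * (nloc - 1))

def mad_loop(node_ebd: torch.Tensor) -> float:
    # explicit average of 1 - cos(H_i, H_j) over all pairs with i != j
    nframes, nloc, _ = node_ebd.shape
    normed = F.normalize(node_ebd, p=2, dim=-1)
    total, count = 0.0, 0
    for f in range(nframes):
        for i in range(nloc):
            for j in range(nloc):
                if i != j:
                    total += 1.0 - torch.dot(normed[f, i], normed[f, j]).item()
                    count += 1
    return total / count

x = torch.randn(2, 5, 16)
print(mad_vectorized(x).item(), mad_loop(x))  # the two values agree up to floating-point error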
6 changes: 3 additions & 3 deletions deepmd/pt/model/descriptor/env_mat.py
@@ -29,7 +29,7 @@ def _make_env_mat(
coord_pad = torch.concat([coord, coord[:, -1:, :] + rcut], dim=1)
coord_r = torch.gather(coord_pad, 1, index)
coord_r = coord_r.view(bsz, natoms, nnei, 3)
diff = coord_r - coord_l
diff = coord_r - coord_l  # relative displacement
length = torch.linalg.norm(diff, dim=-1, keepdim=True)
# for index 0 nloc atom
length = length + ~mask.unsqueeze(-1)
@@ -39,8 +39,8 @@
compute_smooth_weight(length, ruct_smth, rcut)
if not use_exp_switch
else compute_exp_sw(length, ruct_smth, rcut)
)
weight = weight * mask.unsqueeze(-1)
)  # compute the switching weight
weight = weight * mask.unsqueeze(-1)  # apply the neighbor mask to the weight
if radial_only:
env_mat = t0 * weight
else:
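
For context on the switching weight annotated above, a hedged sketch of a smooth switching function: 1 inside rcut_smth and smoothly decaying to 0 at rcut. The fifth-order polynomial below is the standard DeepPot-SE form and is an assumption here; the exact implementations of compute_smooth_weight and compute_exp_sw in deepmd may differ.

import torch

def smooth_weight_sketch(length: torch.Tensor, rcut_smth: float, rcut: float) -> torch.Tensor:
    # 1 for r <= rcut_smth, 0 for r >= rcut, smooth in between
    uu = torch.clamp((length - rcut_smth) / (rcut - rcut_smth), 0.0, 1.0)
    # fifth-order polynomial with vanishing first and second derivatives at both ends
    return uu**3 * (-6 * uu**2 + 15 * uu - 10) + 1

r = torch.tensor([0.5, 2.0, 5.0, 5.9, 6.0])
print(smooth_weight_sketch(r, rcut_smth=0.5, rcut=6.0))  # decreases from 1.0 at rcut_smth to 0.0 at rcut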