diff --git a/dptb/entrypoints/emp_sk.py b/dptb/entrypoints/emp_sk.py index d5a684b8..be9c531b 100644 --- a/dptb/entrypoints/emp_sk.py +++ b/dptb/entrypoints/emp_sk.py @@ -8,12 +8,14 @@ import os from dptb.utils.gen_inputs import gen_inputs import json +from collections import OrderedDict log = logging.getLogger(__name__) def to_empsk( INPUT, output='./', basemodel='poly2', + soc= None, **kwargs): """ Convert the model to empirical SK parameters. @@ -23,7 +25,7 @@ def to_empsk( with open(INPUT, 'r') as f: input = json.load(f) common_options = input['common_options'] - EmpSK(common_options, basemodel=basemodel).to_json(outdir=output) + EmpSK(common_options, basemodel=basemodel).to_json(outdir=output, soc=soc) class EmpSK(object): """ @@ -45,17 +47,56 @@ def __init__(self, common_options, basemodel='poly2'): self.model = build_model(model_ckpt, common_options=common_options, no_check=True) - def to_json(self, outdir='./'): + def to_json(self, outdir='./', soc=None): """ Convert the model to json format. 
""" # 判断是否存在输出目录 if not os.path.exists(outdir): os.makedirs(outdir, exist_ok=True) - json_dict = self.model.to_json(basisref=self.basisref) - with open(os.path.join(outdir,'sktb.json'), 'w') as f: - json.dump(json_dict, f, indent=4) + json_dict = self.model.to_json(basisref=self.basisref)\ + if soc is not None: + mp = json_dict.setdefault("model_params", {}) + onsite = mp.get("onsite", {}) + + # build soc block based on onsite + soc_block = {} + for key, val in onsite.items(): + parts = key.split("-") + if len(parts) < 3: + continue + elem, orb = parts[0], parts[1] + # s and * orbitals -> 0, others -> soc value + if orb.lower() == "s" or "*" in orb: + v = 0.0 + else: + v = float(soc) + soc_block[key] = [v] + + # insert soc block after overlap + if isinstance(mp, dict): + new_mp = OrderedDict() + inserted = False + for k, v in mp.items(): + new_mp[k] = v + if k == "overlap": + new_mp["soc"] = soc_block + inserted = True + if not inserted: + new_mp["soc"] = soc_block + json_dict["model_params"] = new_mp + + # update model_options for nnsk.soc.method + mo = json_dict.setdefault("model_options", {}) + nnsk = mo.setdefault("nnsk", {}) + soc_opt = nnsk.setdefault("soc", {}) + soc_opt["method"] = "uniform_noref" + + # write final file + with open(os.path.join(outdir, 'sktb.json'), 'w') as f: + json.dump(json_dict, f, indent=4) + # save input template # input_template = gen_inputs(model=self.model, task='train', mode=mode) diff --git a/dptb/entrypoints/main.py b/dptb/entrypoints/main.py index 528e45eb..337e5ac8 100644 --- a/dptb/entrypoints/main.py +++ b/dptb/entrypoints/main.py @@ -435,6 +435,14 @@ def main_parser() -> argparse.ArgumentParser: default="poly2", help="The base model type can be poly2 or poly4." ) + parser_esk.add_argument( + "--soc", + "--soc_onsite", + nargs="?", + const=0.2, + type=float, + help="Enable SOC, default 0.2 if no value is given. 
Example: --soc or --soc=0.5" + ) return parser def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace: diff --git a/dptb/nn/hr2hk.py b/dptb/nn/hr2hk.py index 7b443364..9239d9cf 100644 --- a/dptb/nn/hr2hk.py +++ b/dptb/nn/hr2hk.py @@ -45,6 +45,9 @@ def __init__( self.node_field = node_field self.out_field = out_field + + + def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: # construct bond wise hamiltonian block from obital pair wise node/edge features @@ -65,13 +68,14 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: if isinstance(soc, torch.Tensor): soc = soc.all() if soc: - if self.overlap: - raise NotImplementedError("Overlap is not implemented for SOC.") - + # if self.overlap: + # print("Overlap for SOC is realized by kronecker product.") + orbpair_soc = data[AtomicDataDict.NODE_SOC_KEY] soc_upup_block = torch.zeros((len(data[AtomicDataDict.ATOM_TYPE_KEY]), self.idp.full_basis_norb, self.idp.full_basis_norb), dtype=self.ctype, device=self.device) soc_updn_block = torch.zeros((len(data[AtomicDataDict.ATOM_TYPE_KEY]), self.idp.full_basis_norb, self.idp.full_basis_norb), dtype=self.ctype, device=self.device) + ist = 0 for i,iorb in enumerate(self.idp.full_basis): jst = 0 @@ -96,6 +100,12 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: # onsite_block[:, ist:ist+2*li+1, jst:jst+2*lj+1] = factor * torch.eye(2*li+1, dtype=self.dtype, device=self.device).reshape(1, 2*li+1, 2*lj+1).repeat(onsite_block.shape[0], 1, 1) if i <= j: onsite_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = factor * orbpair_onsite[:,self.idp.orbpair_maps[orbpair]].reshape(-1, 2*li+1, 2*lj+1) + + if soc and i == j: + soc_updn_tmp = orbpair_soc[:, self.idp.orbpair_soc_maps[orbpair]].reshape(-1, 2*li+1, 2*(2*lj+1)) + # j==i -> 2*lj+1 == 2*li+1 + soc_upup_block[:, ist:ist+2*li+1, jst:jst+2*lj+1] = soc_updn_tmp[:, :2*li+1, :2*lj+1] + soc_updn_block[:, ist:ist+2*li+1, jst:jst+2*lj+1] = soc_updn_tmp[:, :2*li+1, 2*lj+1:] else: if i 
<= j: onsite_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = factor * orbpair_onsite[:,self.idp.orbpair_maps[orbpair]].reshape(-1, 2*li+1, 2*lj+1) @@ -110,6 +120,7 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: self.onsite_block = onsite_block self.bondwise_hopping = bondwise_hopping if soc: + # 先保存已有的 self.soc_upup_block = soc_upup_block self.soc_updn_block = soc_updn_block @@ -159,30 +170,41 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type: block = block.contiguous() if soc: - HK_SOC = torch.zeros(kpoints.shape[0], 2*all_norb, 2*all_norb, dtype=self.ctype, device=self.device) - #HK_SOC[:,:all_norb,:all_norb] = block + block_uu - #HK_SOC[:,:all_norb,all_norb:] = block_ud - #HK_SOC[:,all_norb:,:all_norb] = block_ud.conj() - #HK_SOC[:,all_norb:,all_norb:] = block + block_uu.conj() - ist = 0 - assert len(soc_upup_block) == len(soc_updn_block) - for i in range(len(soc_upup_block)): - assert soc_upup_block[i].shape == soc_updn_block[i].shape - mask = self.idp.mask_to_basis[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()[i]] - masked_soc_upup_block = soc_upup_block[i][mask][:,mask] - masked_soc_updn_block = soc_updn_block[i][mask][:,mask] - HK_SOC[:,ist:ist+masked_soc_upup_block.shape[0],ist:ist+masked_soc_upup_block.shape[1]] = masked_soc_upup_block.squeeze(0) - HK_SOC[:,ist:ist+masked_soc_updn_block.shape[0],ist+all_norb:ist+all_norb+masked_soc_updn_block.shape[1]] = masked_soc_updn_block.squeeze(0) - assert masked_soc_upup_block.shape[0] == masked_soc_upup_block.shape[1] - assert masked_soc_upup_block.shape[0] == masked_soc_updn_block.shape[0] - - ist += masked_soc_upup_block.shape[0] + if self.overlap: + # ========== S_soc = S ⊗ I₂ : N×N S(k) to 2N×2N kronecker product ========== + S_soc = torch.zeros(kpoints.shape[0], 2*all_norb, 2*all_norb, dtype=self.ctype, device=self.device) + S_soc[:, :all_norb, :all_norb] = block + S_soc[:, all_norb:, all_norb:] = block + # Enforce strict Hermitian form to avoid non-positive-definite 
errors during training by torch._C._LinAlgError: linalg.cholesky. + # This issue only occurs when SOC+overlap is active and "overlap" is not frozen. It can be avoided by setting "freeze": ["overlap"]. + S_soc = 0.5 * (S_soc + S_soc.transpose(1, 2).conj()) + + data[self.out_field] = S_soc + else: + HK_SOC = torch.zeros(kpoints.shape[0], 2*all_norb, 2*all_norb, dtype=self.ctype, device=self.device) + #HK_SOC[:,:all_norb,:all_norb] = block + block_uu + #HK_SOC[:,:all_norb,all_norb:] = block_ud + #HK_SOC[:,all_norb:,:all_norb] = block_ud.conj() + #HK_SOC[:,all_norb:,all_norb:] = block + block_uu.conj() + ist = 0 + assert len(soc_upup_block) == len(soc_updn_block) + for i in range(len(soc_upup_block)): + assert soc_upup_block[i].shape == soc_updn_block[i].shape + mask = self.idp.mask_to_basis[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()[i]] + masked_soc_upup_block = soc_upup_block[i][mask][:,mask] + masked_soc_updn_block = soc_updn_block[i][mask][:,mask] + HK_SOC[:,ist:ist+masked_soc_upup_block.shape[0],ist:ist+masked_soc_upup_block.shape[1]] = masked_soc_upup_block.squeeze(0) + HK_SOC[:,ist:ist+masked_soc_updn_block.shape[0],ist+all_norb:ist+all_norb+masked_soc_updn_block.shape[1]] = masked_soc_updn_block.squeeze(0) + assert masked_soc_upup_block.shape[0] == masked_soc_upup_block.shape[1] + assert masked_soc_upup_block.shape[0] == masked_soc_updn_block.shape[0] + + ist += masked_soc_upup_block.shape[0] - HK_SOC[:,all_norb:,:all_norb] = HK_SOC[:,:all_norb,all_norb:].conj() - HK_SOC[:,all_norb:,all_norb:] = HK_SOC[:,:all_norb,:all_norb].conj() + block - HK_SOC[:,:all_norb,:all_norb] = HK_SOC[:,:all_norb,:all_norb] + block + HK_SOC[:,all_norb:,:all_norb] = HK_SOC[:,:all_norb,all_norb:].conj() + HK_SOC[:,all_norb:,all_norb:] = HK_SOC[:,:all_norb,:all_norb].conj() + block + HK_SOC[:,:all_norb,:all_norb] = HK_SOC[:,:all_norb,:all_norb] + block - data[self.out_field] = HK_SOC + data[self.out_field] = HK_SOC else: data[self.out_field] = block diff --git a/dptb/nn/nnsk.py 
b/dptb/nn/nnsk.py index 4f097c3c..45bdb780 100644 --- a/dptb/nn/nnsk.py +++ b/dptb/nn/nnsk.py @@ -979,6 +979,22 @@ def to_json(self, version=2, basisref=None): to_uniform = True else: print("The basisref is not used. since the onsite method is not uniform_noref.") + # add the support for soc uniform_noref when use ['s', 'p', 'd', 'f'] in soc case + if basisref is not None: + if self.model_options['nnsk']['soc']['method'] in ['uniform_noref']: + for atom, orb in self.basis.items(): + new_basis[atom] = [] + if atom not in basisref: + raise ValueError("The atom in the model basis should be in the basisref.") + for o in orb: + if o not in ['s', 'p', 'd', 'f']: + raise ValueError("For uniform_noref mode, the orb in the model basis should be in ['s', 'p', 'd', 'f'].") + if o not in list(basisref[atom].keys()): + raise ValueError("The orb in the model basis should be in the basisref.") + new_basis[atom].append(basisref[atom][o]) + else: + print("The basisref is not used. since the soc method is not uniform_noref.") + ckpt = {} # load hopping params diff --git a/dptb/nn/sktb/soc.py b/dptb/nn/sktb/soc.py index 5e0e8464..a8b50603 100644 --- a/dptb/nn/sktb/soc.py +++ b/dptb/nn/sktb/soc.py @@ -27,6 +27,7 @@ def get_socLs(self, **kwargs): class SOCFormula(BaseSOC): num_paras_dict = { 'uniform': 1, + 'uniform_noref': 1, "none": 0, "custom": None, } @@ -42,6 +43,8 @@ def __init__( pass elif functype == 'uniform': assert hasattr(self, 'uniform') + elif functype == 'uniform_noref': + assert hasattr(self, 'uniform_noref') elif functype == 'custom': assert hasattr(self, 'custom') else: @@ -64,6 +67,8 @@ def get_socLs(self, **kwargs): return self.none(**kwargs) elif self.functype == 'uniform': return self.uniform(**kwargs) + elif self.functype == 'uniform_noref': + return self.uniform_noref(**kwargs) elif self.functype == 'custom': return self.custom(**kwargs) else: @@ -114,4 +119,26 @@ def uniform(self, atomic_numbers: torch.Tensor, nn_soc_paras: torch.Tensor, **kw idx = 
self.idp.transform_atom(atomic_numbers) return nn_soc_paras[idx] + self.none(atomic_numbers=atomic_numbers) - \ No newline at end of file + + def uniform_noref(self, atomic_numbers: torch.Tensor, nn_soc_paras: torch.Tensor, **kwargs): + """The uniform SOC function with no reference, which gives the same SOC parameter for one specific orbital of an atom type. + + Parameters + ---------- + atomic_numbers : torch.Tensor(N) or torch.Tensor(N,1) + The atomic number list. + nn_soc_paras : torch.Tensor(N_atom_type, n_orb) + The nn fitted SOC parameters. + + Returns + ------- + torch.Tensor(N, n_orb) + The nn fitted SOC parameters selected for each atom via `idp.transform_atom`; unlike `uniform`, no reference term (`self.none(...)`) is added. + """ + atomic_numbers = atomic_numbers.reshape(-1) + if nn_soc_paras.shape[-1] == 1: + nn_soc_paras = nn_soc_paras.squeeze(-1) + + idx = self.idp.transform_atom(atomic_numbers) + + return nn_soc_paras[idx] \ No newline at end of file