Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 46 additions & 5 deletions dptb/entrypoints/emp_sk.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import os
from dptb.utils.gen_inputs import gen_inputs
import json
from collections import OrderedDict
log = logging.getLogger(__name__)

def to_empsk(
INPUT,
output='./',
basemodel='poly2',
soc= None,
**kwargs):
"""
Convert the model to empirical SK parameters.
Expand All @@ -23,7 +25,7 @@ def to_empsk(
with open(INPUT, 'r') as f:
input = json.load(f)
common_options = input['common_options']
EmpSK(common_options, basemodel=basemodel).to_json(outdir=output)
EmpSK(common_options, basemodel=basemodel).to_json(outdir=output, soc=soc)

class EmpSK(object):
"""
Expand All @@ -45,17 +47,56 @@ def __init__(self, common_options, basemodel='poly2'):

self.model = build_model(model_ckpt, common_options=common_options, no_check=True)

def to_json(self, outdir='./'):
def to_json(self, outdir='./', soc=None):
"""
Convert the model to json format.
"""
        # check whether the output directory exists; create it if not
if not os.path.exists(outdir):
os.makedirs(outdir, exist_ok=True)
json_dict = self.model.to_json(basisref=self.basisref)
with open(os.path.join(outdir,'sktb.json'), 'w') as f:
json.dump(json_dict, f, indent=4)
json_dict = self.model.to_json(basisref=self.basisref)\

if soc is not None:
mp = json_dict.setdefault("model_params", {})
onsite = mp.get("onsite", {})

# build soc block based on onsite
soc_block = {}
for key, val in onsite.items():
parts = key.split("-")
if len(parts) < 3:
continue
elem, orb = parts[0], parts[1]
# s and * orbitals -> 0, others -> soc value
if orb.lower() == "s" or "*" in orb:
v = 0.0
else:
v = float(soc)
soc_block[key] = [v]

# insert soc block after overlap
if isinstance(mp, dict):
new_mp = OrderedDict()
inserted = False
for k, v in mp.items():
new_mp[k] = v
if k == "overlap":
new_mp["soc"] = soc_block
inserted = True
if not inserted:
new_mp["soc"] = soc_block
json_dict["model_params"] = new_mp

# update model_options for nnsk.soc.method
mo = json_dict.setdefault("model_options", {})
nnsk = mo.setdefault("nnsk", {})
soc_opt = nnsk.setdefault("soc", {})
soc_opt["method"] = "uniform_noref"

# write final file
with open(os.path.join(outdir, 'sktb.json'), 'w') as f:
json.dump(json_dict, f, indent=4)

# save input template
# input_template = gen_inputs(model=self.model, task='train', mode=mode)

Expand Down
8 changes: 8 additions & 0 deletions dptb/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,14 @@ def main_parser() -> argparse.ArgumentParser:
default="poly2",
help="The base model type can be poly2 or poly4."
)
parser_esk.add_argument(
"--soc",
"--soc_onsite",
nargs="?",
const=0.2,
type=float,
help="Enable SOC, default 0.2 if no value is given. Example: --soc or --soc=0.5"
)
return parser

def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
Expand Down
72 changes: 47 additions & 25 deletions dptb/nn/hr2hk.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ def __init__(
self.node_field = node_field
self.out_field = out_field




def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:

# construct bond wise hamiltonian block from obital pair wise node/edge features
Expand All @@ -65,13 +68,14 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
if isinstance(soc, torch.Tensor):
soc = soc.all()
if soc:
if self.overlap:
raise NotImplementedError("Overlap is not implemented for SOC.")
# if self.overlap:
# print("Overlap for SOC is realized by kronecker product.")

orbpair_soc = data[AtomicDataDict.NODE_SOC_KEY]
soc_upup_block = torch.zeros((len(data[AtomicDataDict.ATOM_TYPE_KEY]), self.idp.full_basis_norb, self.idp.full_basis_norb), dtype=self.ctype, device=self.device)
soc_updn_block = torch.zeros((len(data[AtomicDataDict.ATOM_TYPE_KEY]), self.idp.full_basis_norb, self.idp.full_basis_norb), dtype=self.ctype, device=self.device)


ist = 0
for i,iorb in enumerate(self.idp.full_basis):
jst = 0
Expand All @@ -96,6 +100,12 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
# onsite_block[:, ist:ist+2*li+1, jst:jst+2*lj+1] = factor * torch.eye(2*li+1, dtype=self.dtype, device=self.device).reshape(1, 2*li+1, 2*lj+1).repeat(onsite_block.shape[0], 1, 1)
if i <= j:
onsite_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = factor * orbpair_onsite[:,self.idp.orbpair_maps[orbpair]].reshape(-1, 2*li+1, 2*lj+1)

if soc and i == j:
soc_updn_tmp = orbpair_soc[:, self.idp.orbpair_soc_maps[orbpair]].reshape(-1, 2*li+1, 2*(2*lj+1))
# j==i -> 2*lj+1 == 2*li+1
soc_upup_block[:, ist:ist+2*li+1, jst:jst+2*lj+1] = soc_updn_tmp[:, :2*li+1, :2*lj+1]
soc_updn_block[:, ist:ist+2*li+1, jst:jst+2*lj+1] = soc_updn_tmp[:, :2*li+1, 2*lj+1:]
else:
if i <= j:
onsite_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = factor * orbpair_onsite[:,self.idp.orbpair_maps[orbpair]].reshape(-1, 2*li+1, 2*lj+1)
Expand All @@ -110,6 +120,7 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
self.onsite_block = onsite_block
self.bondwise_hopping = bondwise_hopping
if soc:
            # save the already-computed blocks first
self.soc_upup_block = soc_upup_block
self.soc_updn_block = soc_updn_block

Expand Down Expand Up @@ -159,30 +170,41 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
block = block.contiguous()

if soc:
HK_SOC = torch.zeros(kpoints.shape[0], 2*all_norb, 2*all_norb, dtype=self.ctype, device=self.device)
#HK_SOC[:,:all_norb,:all_norb] = block + block_uu
#HK_SOC[:,:all_norb,all_norb:] = block_ud
#HK_SOC[:,all_norb:,:all_norb] = block_ud.conj()
#HK_SOC[:,all_norb:,all_norb:] = block + block_uu.conj()
ist = 0
assert len(soc_upup_block) == len(soc_updn_block)
for i in range(len(soc_upup_block)):
assert soc_upup_block[i].shape == soc_updn_block[i].shape
mask = self.idp.mask_to_basis[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()[i]]
masked_soc_upup_block = soc_upup_block[i][mask][:,mask]
masked_soc_updn_block = soc_updn_block[i][mask][:,mask]
HK_SOC[:,ist:ist+masked_soc_upup_block.shape[0],ist:ist+masked_soc_upup_block.shape[1]] = masked_soc_upup_block.squeeze(0)
HK_SOC[:,ist:ist+masked_soc_updn_block.shape[0],ist+all_norb:ist+all_norb+masked_soc_updn_block.shape[1]] = masked_soc_updn_block.squeeze(0)
assert masked_soc_upup_block.shape[0] == masked_soc_upup_block.shape[1]
assert masked_soc_upup_block.shape[0] == masked_soc_updn_block.shape[0]

ist += masked_soc_upup_block.shape[0]
if self.overlap:
# ========== S_soc = S ⊗ I₂ : N×N S(k) to 2N×2N kronecker product ==========
S_soc = torch.zeros(kpoints.shape[0], 2*all_norb, 2*all_norb, dtype=self.ctype, device=self.device)
S_soc[:, :all_norb, :all_norb] = block
S_soc[:, all_norb:, all_norb:] = block
# Enforce strict Hermitian form to avoid non-positive-definite errors during training by torch._C._LinAlgError: linalg.cholesky.
# This issue only occurs when SOC+overlap is active and "overlap" is not frozen. It can be avoided by setting "freeze": ["overlap"].
S_soc = 0.5 * (S_soc + S_soc.transpose(1, 2).conj())

data[self.out_field] = S_soc
else:
HK_SOC = torch.zeros(kpoints.shape[0], 2*all_norb, 2*all_norb, dtype=self.ctype, device=self.device)
#HK_SOC[:,:all_norb,:all_norb] = block + block_uu
#HK_SOC[:,:all_norb,all_norb:] = block_ud
#HK_SOC[:,all_norb:,:all_norb] = block_ud.conj()
#HK_SOC[:,all_norb:,all_norb:] = block + block_uu.conj()
ist = 0
assert len(soc_upup_block) == len(soc_updn_block)
for i in range(len(soc_upup_block)):
assert soc_upup_block[i].shape == soc_updn_block[i].shape
mask = self.idp.mask_to_basis[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()[i]]
masked_soc_upup_block = soc_upup_block[i][mask][:,mask]
masked_soc_updn_block = soc_updn_block[i][mask][:,mask]
HK_SOC[:,ist:ist+masked_soc_upup_block.shape[0],ist:ist+masked_soc_upup_block.shape[1]] = masked_soc_upup_block.squeeze(0)
HK_SOC[:,ist:ist+masked_soc_updn_block.shape[0],ist+all_norb:ist+all_norb+masked_soc_updn_block.shape[1]] = masked_soc_updn_block.squeeze(0)
assert masked_soc_upup_block.shape[0] == masked_soc_upup_block.shape[1]
assert masked_soc_upup_block.shape[0] == masked_soc_updn_block.shape[0]

ist += masked_soc_upup_block.shape[0]

HK_SOC[:,all_norb:,:all_norb] = HK_SOC[:,:all_norb,all_norb:].conj()
HK_SOC[:,all_norb:,all_norb:] = HK_SOC[:,:all_norb,:all_norb].conj() + block
HK_SOC[:,:all_norb,:all_norb] = HK_SOC[:,:all_norb,:all_norb] + block
HK_SOC[:,all_norb:,:all_norb] = HK_SOC[:,:all_norb,all_norb:].conj()
HK_SOC[:,all_norb:,all_norb:] = HK_SOC[:,:all_norb,:all_norb].conj() + block
HK_SOC[:,:all_norb,:all_norb] = HK_SOC[:,:all_norb,:all_norb] + block

data[self.out_field] = HK_SOC
data[self.out_field] = HK_SOC
else:
data[self.out_field] = block

Expand Down
16 changes: 16 additions & 0 deletions dptb/nn/nnsk.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,22 @@ def to_json(self, version=2, basisref=None):
to_uniform = True
else:
print("The basisref is not used. since the onsite method is not uniform_noref.")
# add the support for soc uniform_noref when use ['s', 'p', 'd', 'f'] in soc case
if basisref is not None:
if self.model_options['nnsk']['soc']['method'] in ['uniform_noref']:
for atom, orb in self.basis.items():
new_basis[atom] = []
if atom not in basisref:
raise ValueError("The atom in the model basis should be in the basisref.")
for o in orb:
if o not in ['s', 'p', 'd', 'f']:
raise ValueError("For uniform_noref mode, the orb in the model basis should be in ['s', 'p', 'd', 'f'].")
if o not in list(basisref[atom].keys()):
raise ValueError("The orb in the model basis should be in the basisref.")
new_basis[atom].append(basisref[atom][o])
else:
print("The basisref is not used. since the soc method is not uniform_noref.")


ckpt = {}
# load hopping params
Expand Down
29 changes: 28 additions & 1 deletion dptb/nn/sktb/soc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def get_socLs(self, **kwargs):
class SOCFormula(BaseSOC):
num_paras_dict = {
'uniform': 1,
'uniform_noref': 1,
"none": 0,
"custom": None,
}
Expand All @@ -42,6 +43,8 @@ def __init__(
pass
elif functype == 'uniform':
assert hasattr(self, 'uniform')
elif functype == 'uniform_noref':
assert hasattr(self, 'uniform_noref')
elif functype == 'custom':
assert hasattr(self, 'custom')
else:
Expand All @@ -64,6 +67,8 @@ def get_socLs(self, **kwargs):
return self.none(**kwargs)
elif self.functype == 'uniform':
return self.uniform(**kwargs)
elif self.functype == 'uniform_noref':
return self.uniform_noref(**kwargs)
elif self.functype == 'custom':
return self.custom(**kwargs)
else:
Expand Down Expand Up @@ -114,4 +119,26 @@ def uniform(self, atomic_numbers: torch.Tensor, nn_soc_paras: torch.Tensor, **kw
idx = self.idp.transform_atom(atomic_numbers)

return nn_soc_paras[idx] + self.none(atomic_numbers=atomic_numbers)


def uniform_noref(self, atomic_numbers: torch.Tensor, nn_soc_paras: torch.Tensor, **kwargs):
    """The uniform SOC function with no reference: each orbital of a given
    atom type shares the same fitted SOC strength, and no database reference
    value is added (unlike ``uniform``, which adds ``self.none(...)``).

    Parameters
    ----------
    atomic_numbers : torch.Tensor(N) or torch.Tensor(N,1)
        The atomic number list.
    nn_soc_paras : torch.Tensor(N_atom_type, n_orb) or torch.Tensor(N_atom_type, n_orb, 1)
        The nn-fitted SOC parameters for each atom type.

    Returns
    -------
    torch.Tensor(N, n_orb)
        The per-atom SOC parameters, taken purely from the nn fit.
    """
    # Flatten possible (N, 1) input to (N,) so the gather below is per atom.
    atomic_numbers = atomic_numbers.reshape(-1)
    # Drop a trailing singleton parameter dimension, mirroring `uniform`.
    if nn_soc_paras.shape[-1] == 1:
        nn_soc_paras = nn_soc_paras.squeeze(-1)

    # Map atomic numbers to internal atom-type indices, then gather rows.
    idx = self.idp.transform_atom(atomic_numbers)

    return nn_soc_paras[idx]