Skip to content

Commit 9583324

Browse files
SUSYUSTC and JiaceSun authored
Add functionality of MPO with block diagonal form (#154)
* stack mpo * add test * add comments * update --------- Co-authored-by: JiaceSun <[email protected]>
1 parent 9706a29 commit 9583324

File tree

4 files changed

+110
-29
lines changed

4 files changed

+110
-29
lines changed

renormalizer/mps/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from renormalizer.mps.backend import backend
2-
from renormalizer.mps.mpo import Mpo
2+
from renormalizer.mps.mpo import Mpo, StackedMpo
33
from renormalizer.mps.mps import Mps, BraKetPair
44
from renormalizer.mps.mpdm import MpDm
55
from renormalizer.mps.thermalprop import ThermalProp, load_thermal_state

renormalizer/mps/gs.py

Lines changed: 78 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from renormalizer.mps.matrix import multi_tensor_contract, tensordot, asnumpy, asxp
1919
from renormalizer.mps.hop_expr import hop_expr
2020
from renormalizer.mps.svd_qn import get_qn_mask
21-
from renormalizer.mps import Mpo, Mps
21+
from renormalizer.mps import Mpo, Mps, StackedMpo
2222
from renormalizer.mps.lib import Environ, cvec2cmat
2323
from renormalizer.utils import Quantity, CompressConfig, CompressCriteria
2424

@@ -45,14 +45,14 @@ def construct_mps_mpo(model, mmax, nexciton, offset=Quantity(0)):
4545
return mps, mpo
4646

4747

48-
def optimize_mps(mps: Mps, mpo: Mpo, omega: float = None) -> Tuple[List, Mps]:
48+
def optimize_mps(mps: Mps, mpo: Union[Mpo, StackedMpo], omega: float = None) -> Tuple[List, Mps]:
4949
r"""DMRG ground state algorithm and state-averaged excited states algorithm
5050
5151
Parameters
5252
----------
5353
mps : renormalizer.mps.Mps
5454
initial guess of mps. The MPS is overwritten during the optimization.
55-
mpo : renormalizer.mps.Mpo
55+
mpo : Union[renormalizer.mps.Mpo, renormalizer.mps.StackedMpo]
5656
mpo of Hamiltonian
5757
omega: float, optional
5858
target the eigenpair near omega with special variational function
@@ -67,7 +67,7 @@ def optimize_mps(mps: Mps, mpo: Mpo, omega: float = None) -> Tuple[List, Mps]:
6767
mps : renormalizer.mps.Mps
6868
optimized ground state MPS.
6969
Note it's not the same with the overwritten input MPS.
70-
70+
7171
See Also
7272
--------
7373
renormalizer.utils.configs.OptimizeConfig : The optimization configuration.
@@ -95,14 +95,19 @@ def optimize_mps(mps: Mps, mpo: Mpo, omega: float = None) -> Tuple[List, Mps]:
9595
env = "L"
9696

9797
compress_config_bk = mps.compress_config
98-
98+
9999
# construct the environment matrix
100100
if omega is not None:
101+
if isinstance(mpo, StackedMpo):
102+
raise NotImplementedError("StackedMPO + omega is not implemented yet")
101103
identity = Mpo.identity(mpo.model)
102104
mpo = mpo.add(identity.scale(-omega))
103105
environ = Environ(mps, [mpo, mpo], env)
104106
else:
105-
environ = Environ(mps, mpo, env)
107+
if isinstance(mpo, StackedMpo):
108+
environ = [Environ(mps, item, env) for item in mpo.mpos]
109+
else:
110+
environ = Environ(mps, mpo, env)
106111

107112
macro_iteration_result = []
108113
# Idx of the active site with lowest energy for each sweep
@@ -111,7 +116,7 @@ def optimize_mps(mps: Mps, mpo: Mpo, omega: float = None) -> Tuple[List, Mps]:
111116
res_mps: Union[Mps, List[Mps]] = None
112117
for isweep, (compress_config, percent) in enumerate(mps.optimize_config.procedure):
113118
logger.debug(f"isweep: {isweep}")
114-
119+
115120
if isinstance(compress_config, CompressConfig):
116121
mps.compress_config = compress_config
117122
elif isinstance(compress_config, int):
@@ -156,19 +161,19 @@ def optimize_mps(mps: Mps, mpo: Mpo, omega: float = None) -> Tuple[List, Mps]:
156161
for res in res_mps:
157162
res.compress_config = compress_config_bk
158163
logger.info(f"{res_mps[0]}")
159-
164+
160165
return macro_iteration_result, res_mps
161166

162167

163168
def single_sweep(
164169
mps: Mps,
165-
mpo: Mpo,
170+
mpo: Union[Mpo, StackedMpo],
166171
environ: Environ,
167172
omega: float,
168173
percent: float,
169174
last_opt_e_idx: int
170175
):
171-
176+
172177
method = mps.optimize_config.method
173178
nroots = mps.optimize_config.nroots
174179

@@ -210,18 +215,26 @@ def single_sweep(
210215
if omega is None:
211216
operator = mpo
212217
else:
218+
assert isinstance(mpo, Mpo)
213219
operator = [mpo, mpo]
214220

215-
ltensor = environ.GetLR("L", lidx, mps, operator, itensor=None, method=lmethod)
216-
rtensor = environ.GetLR("R", ridx, mps, operator, itensor=None, method=rmethod)
221+
if isinstance(mpo, StackedMpo):
222+
ltensor = [environ_item.GetLR("L", lidx, mps, operator_item, itensor=None, method=lmethod) for environ_item, operator_item in zip(environ, operator.mpos)]
223+
rtensor = [environ_item.GetLR("R", ridx, mps, operator_item, itensor=None, method=rmethod) for environ_item, operator_item in zip(environ, operator.mpos)]
224+
else:
225+
ltensor = environ.GetLR("L", lidx, mps, operator, itensor=None, method=lmethod)
226+
rtensor = environ.GetLR("R", ridx, mps, operator, itensor=None, method=rmethod)
217227

218228
# get the quantum number pattern
219229
qnbigl, qnbigr, qnmat = mps._get_big_qn(cidx)
220230
qn_mask = get_qn_mask(qnmat, mps.qntot)
221231
cshape = qn_mask.shape
222232

223233
# center mo
224-
cmo = [asxp(mpo[idx]) for idx in cidx]
234+
if isinstance(mpo, StackedMpo):
235+
cmo = [[asxp(mpo_item[idx]) for idx in cidx] for mpo_item in mpo.mpos]
236+
else:
237+
cmo = [asxp(mpo[idx]) for idx in cidx]
225238

226239
use_direct_eigh = np.prod(cshape) < 1000 or mps.optimize_config.algo == "direct"
227240
if use_direct_eigh:
@@ -285,15 +298,15 @@ def single_sweep(
285298
return micro_iteration_result, res_mps, mpo
286299

287300

288-
def eigh_direct(
301+
def get_ham_direct(
289302
mps: Mps,
290303
qn_mask: np.ndarray,
291-
ltensor: xp.ndarray,
292-
rtensor: xp.ndarray,
304+
ltensor: Union[xp.ndarray, List[xp.ndarray]],
305+
rtensor: Union[xp.ndarray, List[xp.ndarray]],
293306
cmo: List[xp.ndarray],
294307
omega: float,
295308
):
296-
logger.debug(f"use direct eigensolver")
309+
logger.debug("use direct eigensolver")
297310

298311
# direct algorithm
299312
if omega is None:
@@ -347,6 +360,23 @@ def eigh_direct(
347360
)
348361
ham = ham[:, :, :, :, qn_mask][qn_mask, :]
349362

363+
return ham
364+
365+
366+
def eigh_direct(
367+
mps: Mps,
368+
qn_mask: np.ndarray,
369+
ltensor: Union[xp.ndarray, List[xp.ndarray]],
370+
rtensor: Union[xp.ndarray, List[xp.ndarray]],
371+
cmo: List[xp.ndarray],
372+
omega: float,
373+
):
374+
if isinstance(ltensor, list):
375+
assert isinstance(rtensor, list)
376+
assert len(ltensor) == len(rtensor)
377+
ham = sum([get_ham_direct(mps, qn_mask, ltensor_item, rtensor_item, cmo_item, omega) for ltensor_item, rtensor_item, cmo_item in zip(ltensor, rtensor, cmo)])
378+
else:
379+
ham = get_ham_direct(mps, qn_mask, ltensor, rtensor, cmo, omega)
350380
inverse = mps.optimize_config.inverse
351381
w, v = scipy.linalg.eigh(asnumpy(ham) * inverse)
352382

@@ -360,14 +390,13 @@ def eigh_direct(
360390
return e, c
361391

362392

363-
def eigh_iterative(
393+
def get_ham_iterative(
364394
mps: Mps,
365395
qn_mask: np.ndarray,
366-
ltensor: xp.ndarray,
367-
rtensor: xp.ndarray,
396+
ltensor: Union[xp.ndarray, List[xp.ndarray]],
397+
rtensor: Union[xp.ndarray, List[xp.ndarray]],
368398
cmo: List[xp.ndarray],
369399
omega: float,
370-
cguess: List[np.ndarray],
371400
):
372401
# iterative algorithm
373402
method = mps.optimize_config.method
@@ -428,6 +457,34 @@ def eigh_iterative(
428457
# contraction expression
429458
cshape = qn_mask.shape
430459
expr = hop_expr(ltensor, rtensor, cmo, cshape, omega is not None)
460+
return hdiag, expr
461+
462+
463+
def func_sum(funcs):
464+
def new_func(*args, **kwargs):
465+
return sum([func(*args, **kwargs) for func in funcs])
466+
return new_func
467+
468+
469+
def eigh_iterative(
470+
mps: Mps,
471+
qn_mask: np.ndarray,
472+
ltensor: Union[xp.ndarray, List[xp.ndarray]],
473+
rtensor: Union[xp.ndarray, List[xp.ndarray]],
474+
cmo: List[xp.ndarray],
475+
omega: float,
476+
cguess: List[np.ndarray],
477+
):
478+
# iterative algorithm
479+
inverse = mps.optimize_config.inverse
480+
if isinstance(ltensor, list):
481+
assert isinstance(rtensor, list)
482+
assert len(ltensor) == len(rtensor)
483+
ham = [get_ham_iterative(mps, qn_mask, ltensor_item, rtensor_item, cmo_item, omega) for ltensor_item, rtensor_item, cmo_item in zip(ltensor, rtensor, cmo)]
484+
hdiag = sum([hdiag_item for hdiag_item, expr_item in ham])
485+
expr = func_sum([expr_item for hdiag_item, expr_item in ham])
486+
else:
487+
hdiag, expr = get_ham_iterative(mps, qn_mask, ltensor, rtensor, cmo, omega)
431488

432489
count = 0
433490

renormalizer/mps/mpo.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def ph_onsite(cls, model: HolsteinModel, opera: str, mol_idx:int, ph_idx=0):
127127
def intersite(cls, model: HolsteinModel, e_opera: dict, ph_opera: dict, scale:
128128
Quantity=Quantity(1.)):
129129
r""" construct the inter site MPO
130-
130+
131131
Parameters
132132
----------
133133
model : HolsteinModel
@@ -142,7 +142,7 @@ def intersite(cls, model: HolsteinModel, e_opera: dict, ph_opera: dict, scale:
142142
Note
143143
-----
144144
the operator index starts from 0,1,2...
145-
145+
146146
"""
147147

148148
ops = []
@@ -330,7 +330,7 @@ def apply(self, mp: MatrixProduct, canonicalise: bool=False) -> MatrixProduct:
330330
# todo: use meta copy to save time, could be subtle when complex type is involved
331331
# todo: inplace version (saved memory and can be used in `hybrid_exact_propagator`)
332332
# the model is the same as the mps.model
333-
333+
334334
assert self.site_num == mp.site_num
335335
new_mps = self.promote_mt_type(mp.copy())
336336
if mp.is_mps:
@@ -388,14 +388,14 @@ def apply(self, mp: MatrixProduct, canonicalise: bool=False) -> MatrixProduct:
388388

389389
def contract(self, mps, algo="svd"):
390390
r""" an approximation of mpo @ mps/mpdm/mpo
391-
391+
392392
Parameters
393393
----------
394394
mps : `Mps`, `Mpo`, `MpDm`
395395
algo: str, optional
396396
The algorithm to compress mpo @ mps/mpdm/mpo. It could be ``svd``
397-
(default) and ``variational``.
398-
397+
(default) and ``variational``.
398+
399399
Returns
400400
-------
401401
new_mps : `Mps`
@@ -476,3 +476,16 @@ def is_hermitian(self):
476476
def __matmul__(self, other):
477477
return self.apply(other)
478478

479+
480+
class StackedMpo:
481+
"""
482+
An effective sparse representation of MPO in the block diagonal form.
483+
When it enters into the optimization, the Hamiltonian is calculated as
484+
the sum of Hamiltonians generated by each MPO, then the Hamiltonian is
485+
diagonalized and the MPS is updated.
486+
487+
Usage:
488+
optimize_mps(mps, StackedMpo([mpo1, mpo2, ...]))
489+
"""
490+
def __init__(self, mpos: List[Mpo]):
491+
self.mpos = mpos

renormalizer/mps/tests/test_gs.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from renormalizer.model import Model, h_qc
1010
from renormalizer.mps.backend import primme
1111
from renormalizer.mps.gs import construct_mps_mpo, optimize_mps
12-
from renormalizer.mps import Mpo, Mps
12+
from renormalizer.mps import Mpo, Mps, StackedMpo
1313
from renormalizer.tests.parameter import holstein_model
1414
from renormalizer.utils.configs import OFS
1515
from renormalizer.mps.tests import cur_dir
@@ -136,3 +136,14 @@ def test_qc(with_ofs):
136136
print(mpo)
137137
gs_e = min(energies)
138138
assert np.allclose(gs_e, fci_e, atol=5e-3)
139+
140+
141+
def test_stackedmpo():
142+
scheme = 1
143+
method = '1site'
144+
mps, mpo = construct_mps_mpo(holstein_model.switch_scheme(scheme), procedure[0][0], nexciton)
145+
mps.optimize_config.procedure = procedure
146+
mps.optimize_config.method = method
147+
energies1, _ = optimize_mps(mps.copy(), mpo)
148+
energies2, _ = optimize_mps(mps.copy(), StackedMpo([mpo, mpo]))
149+
assert np.all(np.abs(np.array(energies2) - np.array(energies1) * 2) < 1e-8)

0 commit comments

Comments (0)