18
18
from renormalizer .mps .matrix import multi_tensor_contract , tensordot , asnumpy , asxp
19
19
from renormalizer .mps .hop_expr import hop_expr
20
20
from renormalizer .mps .svd_qn import get_qn_mask
21
- from renormalizer .mps import Mpo , Mps
21
+ from renormalizer .mps import Mpo , Mps , StackedMpo
22
22
from renormalizer .mps .lib import Environ , cvec2cmat
23
23
from renormalizer .utils import Quantity , CompressConfig , CompressCriteria
24
24
@@ -45,14 +45,14 @@ def construct_mps_mpo(model, mmax, nexciton, offset=Quantity(0)):
45
45
return mps , mpo
46
46
47
47
48
- def optimize_mps (mps : Mps , mpo : Mpo , omega : float = None ) -> Tuple [List , Mps ]:
48
+ def optimize_mps (mps : Mps , mpo : Union [ Mpo , StackedMpo ] , omega : float = None ) -> Tuple [List , Mps ]:
49
49
r"""DMRG ground state algorithm and state-averaged excited states algorithm
50
50
51
51
Parameters
52
52
----------
53
53
mps : renormalizer.mps.Mps
54
54
initial guess of mps. The MPS is overwritten during the optimization.
55
- mpo : renormalizer.mps.Mpo
55
+ mpo : Union[ renormalizer.mps.Mpo, renormalizer.mps.StackedMpo]
56
56
mpo of Hamiltonian
57
57
omega: float, optional
58
58
target the eigenpair near omega with special variational function
@@ -67,7 +67,7 @@ def optimize_mps(mps: Mps, mpo: Mpo, omega: float = None) -> Tuple[List, Mps]:
67
67
mps : renormalizer.mps.Mps
68
68
optimized ground state MPS.
69
69
Note it's not the same with the overwritten input MPS.
70
-
70
+
71
71
See Also
72
72
--------
73
73
renormalizer.utils.configs.OptimizeConfig : The optimization configuration.
@@ -95,14 +95,19 @@ def optimize_mps(mps: Mps, mpo: Mpo, omega: float = None) -> Tuple[List, Mps]:
95
95
env = "L"
96
96
97
97
compress_config_bk = mps .compress_config
98
-
98
+
99
99
# construct the environment matrix
100
100
if omega is not None :
101
+ if isinstance (mpo , StackedMpo ):
102
+ raise NotImplementedError ("StackedMPO + omega is not implemented yet" )
101
103
identity = Mpo .identity (mpo .model )
102
104
mpo = mpo .add (identity .scale (- omega ))
103
105
environ = Environ (mps , [mpo , mpo ], env )
104
106
else :
105
- environ = Environ (mps , mpo , env )
107
+ if isinstance (mpo , StackedMpo ):
108
+ environ = [Environ (mps , item , env ) for item in mpo .mpos ]
109
+ else :
110
+ environ = Environ (mps , mpo , env )
106
111
107
112
macro_iteration_result = []
108
113
# Idx of the active site with lowest energy for each sweep
@@ -111,7 +116,7 @@ def optimize_mps(mps: Mps, mpo: Mpo, omega: float = None) -> Tuple[List, Mps]:
111
116
res_mps : Union [Mps , List [Mps ]] = None
112
117
for isweep , (compress_config , percent ) in enumerate (mps .optimize_config .procedure ):
113
118
logger .debug (f"isweep: { isweep } " )
114
-
119
+
115
120
if isinstance (compress_config , CompressConfig ):
116
121
mps .compress_config = compress_config
117
122
elif isinstance (compress_config , int ):
@@ -156,19 +161,19 @@ def optimize_mps(mps: Mps, mpo: Mpo, omega: float = None) -> Tuple[List, Mps]:
156
161
for res in res_mps :
157
162
res .compress_config = compress_config_bk
158
163
logger .info (f"{ res_mps [0 ]} " )
159
-
164
+
160
165
return macro_iteration_result , res_mps
161
166
162
167
163
168
def single_sweep (
164
169
mps : Mps ,
165
- mpo : Mpo ,
170
+ mpo : Union [ Mpo , StackedMpo ] ,
166
171
environ : Environ ,
167
172
omega : float ,
168
173
percent : float ,
169
174
last_opt_e_idx : int
170
175
):
171
-
176
+
172
177
method = mps .optimize_config .method
173
178
nroots = mps .optimize_config .nroots
174
179
@@ -210,18 +215,26 @@ def single_sweep(
210
215
if omega is None :
211
216
operator = mpo
212
217
else :
218
+ assert isinstance (mpo , Mpo )
213
219
operator = [mpo , mpo ]
214
220
215
- ltensor = environ .GetLR ("L" , lidx , mps , operator , itensor = None , method = lmethod )
216
- rtensor = environ .GetLR ("R" , ridx , mps , operator , itensor = None , method = rmethod )
221
+ if isinstance (mpo , StackedMpo ):
222
+ ltensor = [environ_item .GetLR ("L" , lidx , mps , operator_item , itensor = None , method = lmethod ) for environ_item , operator_item in zip (environ , operator .mpos )]
223
+ rtensor = [environ_item .GetLR ("R" , ridx , mps , operator_item , itensor = None , method = rmethod ) for environ_item , operator_item in zip (environ , operator .mpos )]
224
+ else :
225
+ ltensor = environ .GetLR ("L" , lidx , mps , operator , itensor = None , method = lmethod )
226
+ rtensor = environ .GetLR ("R" , ridx , mps , operator , itensor = None , method = rmethod )
217
227
218
228
# get the quantum number pattern
219
229
qnbigl , qnbigr , qnmat = mps ._get_big_qn (cidx )
220
230
qn_mask = get_qn_mask (qnmat , mps .qntot )
221
231
cshape = qn_mask .shape
222
232
223
233
# center mo
224
- cmo = [asxp (mpo [idx ]) for idx in cidx ]
234
+ if isinstance (mpo , StackedMpo ):
235
+ cmo = [[asxp (mpo_item [idx ]) for idx in cidx ] for mpo_item in mpo .mpos ]
236
+ else :
237
+ cmo = [asxp (mpo [idx ]) for idx in cidx ]
225
238
226
239
use_direct_eigh = np .prod (cshape ) < 1000 or mps .optimize_config .algo == "direct"
227
240
if use_direct_eigh :
@@ -285,15 +298,15 @@ def single_sweep(
285
298
return micro_iteration_result , res_mps , mpo
286
299
287
300
288
- def eigh_direct (
301
+ def get_ham_direct (
289
302
mps : Mps ,
290
303
qn_mask : np .ndarray ,
291
- ltensor : xp .ndarray ,
292
- rtensor : xp .ndarray ,
304
+ ltensor : Union [ xp .ndarray , List [ xp . ndarray ]] ,
305
+ rtensor : Union [ xp .ndarray , List [ xp . ndarray ]] ,
293
306
cmo : List [xp .ndarray ],
294
307
omega : float ,
295
308
):
296
- logger .debug (f "use direct eigensolver" )
309
+ logger .debug ("use direct eigensolver" )
297
310
298
311
# direct algorithm
299
312
if omega is None :
@@ -347,6 +360,23 @@ def eigh_direct(
347
360
)
348
361
ham = ham [:, :, :, :, qn_mask ][qn_mask , :]
349
362
363
+ return ham
364
+
365
+
366
+ def eigh_direct (
367
+ mps : Mps ,
368
+ qn_mask : np .ndarray ,
369
+ ltensor : Union [xp .ndarray , List [xp .ndarray ]],
370
+ rtensor : Union [xp .ndarray , List [xp .ndarray ]],
371
+ cmo : List [xp .ndarray ],
372
+ omega : float ,
373
+ ):
374
+ if isinstance (ltensor , list ):
375
+ assert isinstance (rtensor , list )
376
+ assert len (ltensor ) == len (rtensor )
377
+ ham = sum ([get_ham_direct (mps , qn_mask , ltensor_item , rtensor_item , cmo_item , omega ) for ltensor_item , rtensor_item , cmo_item in zip (ltensor , rtensor , cmo )])
378
+ else :
379
+ ham = get_ham_direct (mps , qn_mask , ltensor , rtensor , cmo , omega )
350
380
inverse = mps .optimize_config .inverse
351
381
w , v = scipy .linalg .eigh (asnumpy (ham ) * inverse )
352
382
@@ -360,14 +390,13 @@ def eigh_direct(
360
390
return e , c
361
391
362
392
363
- def eigh_iterative (
393
+ def get_ham_iterative (
364
394
mps : Mps ,
365
395
qn_mask : np .ndarray ,
366
- ltensor : xp .ndarray ,
367
- rtensor : xp .ndarray ,
396
+ ltensor : Union [ xp .ndarray , List [ xp . ndarray ]] ,
397
+ rtensor : Union [ xp .ndarray , List [ xp . ndarray ]] ,
368
398
cmo : List [xp .ndarray ],
369
399
omega : float ,
370
- cguess : List [np .ndarray ],
371
400
):
372
401
# iterative algorithm
373
402
method = mps .optimize_config .method
@@ -428,6 +457,34 @@ def eigh_iterative(
428
457
# contraction expression
429
458
cshape = qn_mask .shape
430
459
expr = hop_expr (ltensor , rtensor , cmo , cshape , omega is not None )
460
+ return hdiag , expr
461
+
462
+
463
+ def func_sum (funcs ):
464
+ def new_func (* args , ** kwargs ):
465
+ return sum ([func (* args , ** kwargs ) for func in funcs ])
466
+ return new_func
467
+
468
+
469
+ def eigh_iterative (
470
+ mps : Mps ,
471
+ qn_mask : np .ndarray ,
472
+ ltensor : Union [xp .ndarray , List [xp .ndarray ]],
473
+ rtensor : Union [xp .ndarray , List [xp .ndarray ]],
474
+ cmo : List [xp .ndarray ],
475
+ omega : float ,
476
+ cguess : List [np .ndarray ],
477
+ ):
478
+ # iterative algorithm
479
+ inverse = mps .optimize_config .inverse
480
+ if isinstance (ltensor , list ):
481
+ assert isinstance (rtensor , list )
482
+ assert len (ltensor ) == len (rtensor )
483
+ ham = [get_ham_iterative (mps , qn_mask , ltensor_item , rtensor_item , cmo_item , omega ) for ltensor_item , rtensor_item , cmo_item in zip (ltensor , rtensor , cmo )]
484
+ hdiag = sum ([hdiag_item for hdiag_item , expr_item in ham ])
485
+ expr = func_sum ([expr_item for hdiag_item , expr_item in ham ])
486
+ else :
487
+ hdiag , expr = get_ham_iterative (mps , qn_mask , ltensor , rtensor , cmo , omega )
431
488
432
489
count = 0
433
490
0 commit comments