Skip to content

Commit 830d0df

Browse files
committed
fix adaptive_patching logic
1 parent 1072bf4 commit 830d0df

File tree

2 files changed

+70
-25
lines changed

2 files changed

+70
-25
lines changed

src/patching.jl

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -81,29 +81,18 @@ If the bond dimension of a SubDomainMPS exceeds `maxdim`, perform patching.
8181
function adaptive_patching(
8282
subdmps::SubDomainMPS, patchorder; cutoff=0.0, maxdim=typemax(Int)
8383
)::Vector{SubDomainMPS}
84-
if maxbonddim(subdmps) <= maxdim
84+
prjidx = _next_projindex(subdmps.projector, patchorder)
85+
if maxbonddim(subdmps) <= maxdim || prjidx === nothing
8586
return [subdmps]
8687
end
8788

88-
# If the bond dimension exceeds maxdim, perform patching
89-
refined_subdmpss = SubDomainMPS[]
90-
nextprjidx = _next_projindex(subdmps.projector, patchorder)
91-
if nextprjidx === nothing
92-
return [subdmps]
93-
end
94-
95-
for prjval in 1:ITensors.dim(nextprjidx)
96-
prj_ = subdmps.projector & Projector(nextprjidx => prjval)
97-
subdmps_ = truncate(project(subdmps, prj_); cutoff, maxdim)
98-
if maxbonddim(subdmps_) <= maxdim
99-
push!(refined_subdmpss, subdmps_)
100-
else
101-
append!(
102-
refined_subdmpss, adaptive_patching(subdmps_, patchorder; cutoff, maxdim)
103-
)
104-
end
89+
# The bond dimension exceeds maxdim and there's an index to be projected, so perform patching
90+
prjvals = 1:ITensors.dim(prjidx)
91+
return mapreduce(vcat, prjvals) do prjval
92+
prj_ = subdmps.projector & Projector(prjidx => prjval)
93+
subdmps_ = truncate(project(subdmps, prj_); cutoff)
94+
adaptive_patching(subdmps_, patchorder; cutoff, maxdim)
10595
end
106-
return refined_subdmpss
10796
end
10897

10998
"""
@@ -113,14 +102,14 @@ Do patching recursively to reduce the bond dimension.
113102
If the bond dimension of a SubDomainMPS exceeds `maxdim`, perform patching.
114103
"""
115104
function adaptive_patching(
116-
prjmpss::PartitionedMPS, patchorder; cutoff=0.0, maxdim=typemax(Int)
105+
prjmps::PartitionedMPS, patchorder; cutoff=0.0, maxdim=typemax(Int)
117106
)::PartitionedMPS
118107
return PartitionedMPS(
119108
collect(
120-
Iterators.flatten((
121-
apdaptive_patching(prjmps; cutoff, maxdim, patchorder) for
122-
prjmps in values(prjmpss)
123-
)),
109+
Iterators.flatten(
110+
adaptive_patching(subdmps, patchorder; cutoff, maxdim) for
111+
subdmps in values(prjmps)
112+
),
124113
),
125114
)
126115
end

test/patching_tests.jl

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,55 @@
11
using Test
22
import PartitionedMPSs:
3-
PartitionedMPSs, Projector, project, SubDomainMPS, adaptive_patching, PartitionedMPS
3+
PartitionedMPSs,
4+
Projector,
5+
project,
6+
SubDomainMPS,
7+
adaptive_patching,
8+
PartitionedMPS,
9+
maxbonddim
410
import FastMPOContractions as FMPOC
511
using ITensors
12+
using ITensorMPS
613
using Random
14+
import TensorCrossInterpolation as TCI
15+
import QuanticsGrids as QG
16+
17+
include("_util.jl")
18+
19+
# first example from the patching paper
20+
function examplefunc(r)
21+
r1 = (0.2, 0.2)
22+
r2 = (0.8, 0.8)
23+
A = 10^4
24+
σ = 1e-1
25+
k = 10^3
26+
return A * exp(-norm(r .- r1)^2 / (2σ^2)) * sin(k * norm(r .- r1)) +
27+
A / 2 * exp(-norm(r .- r2)^2 /^2 / 2)) * sin(k * norm(r .- r2) / 2)
28+
end
29+
30+
function tci_examplefunc(R)
31+
grid = QG.DiscretizedGrid{2}(R; unfoldingscheme=:interleaved)
32+
qf(x) = examplefunc(QG.quantics_to_origcoord(grid, x))
33+
localdims = fill(2, 2 * R)
34+
35+
Npivots = 5
36+
Random.seed!(1)
37+
firstpivots = [[rand(1:dim) for dim in localdims] for _ in 1:Npivots]
38+
firstpivots = map(p -> TCI.optfirstpivot(qf, localdims, p), firstpivots)
39+
tci_func, _ = TCI.crossinterpolate2(Float64, qf, localdims, firstpivots; tolerance=1e-7)
40+
41+
sitesx = [Index(2, "Qubit,x=$n") for n in 1:R]
42+
sitesy = [Index(2, "Qubit,y=$n") for n in 1:R]
43+
44+
sites_ = collect.(zip(sitesx, sitesy))
45+
sites = collect(Iterators.flatten(sites_))
46+
47+
mps = MPS(TCI.TensorTrain(tci_func); sites)
48+
sdmps = SubDomainMPS(mps)
49+
pmps = PartitionedMPS([sdmps])
50+
51+
return sites, pmps
52+
end
753

854
@testset "patching.jl" begin
955
@testset "adaptive_patching" begin
@@ -26,4 +72,14 @@ using Random
2672

2773
@test MPS(partmps) MPS(subdmps) rtol = 1e-12
2874
end
75+
76+
@testset "adaptive_patching PartitionedMPS" begin
77+
R = 15
78+
sites, pmps = tci_examplefunc(R)
79+
pmps_partitioned = adaptive_patching(pmps, sites; maxdim=100, cutoff=1e-10)
80+
81+
@test maximum(maxbonddim, values(pmps_partitioned)) <= 100
82+
@test length(values(pmps_partitioned)) > 1
83+
@test MPS(pmps_partitioned) MPS(pmps) rtol = 6e-5
84+
end
2985
end

0 commit comments

Comments
 (0)