Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/TreeTCI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ import SimpleTensorNetworks:
import Random: shuffle
include("treegraph_utils.jl")
include("simpletci.jl")
include("pivotcandidateproper.jl")
include("sweep2sitepathproper.jl")
include("pivotcandidateproposer.jl")
include("sweep2sitepathproposer.jl")
include("simpletci_optimize.jl")
include("simpletci_tensors.jl")
include("treetensornetwork.jl")
Expand Down
10 changes: 7 additions & 3 deletions src/pivotcandidateproper.jl → src/pivotcandidateproposer.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
"""
Abstract type for pivot candidate generation strategies
"""
abstract type AbstractPivotCandidateProper end
abstract type AbstractPivotCandidateProposer end

"""
Default strategy that uses kronecker product and union with extra indices
"""
struct DefaultPivotCandidateProper <: AbstractPivotCandidateProper end
struct DefaultPivotCandidateProposer <: AbstractPivotCandidateProposer end

"""
Default strategy that runs through within all indices of site tensor according to the bond and connect them with IJSet from neighbors
"""
function generate_pivot_candidates(
::DefaultPivotCandidateProper,
::DefaultPivotCandidateProposer,
tci::SimpleTCI{ValueType},
edge::NamedEdge,
extraIJset::Dict{SubTreeVertex,Vector{MultiIndex}},
Expand Down Expand Up @@ -60,6 +60,10 @@ function kronecker(
site_index = findfirst(==(site), Outkey)
filtered_subregions = filter(x -> x ≠ Set([site]), Outkey)

if site_index === nothing
return MultiIndex[]
end

return MultiIndex[
[is[1:site_index-1]..., j, is[site_index+1:end]...] for is in pivotset,
j = 1:localdim
Expand Down
1 change: 0 additions & 1 deletion src/simpletci.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,3 @@ function pushunique!(collection, items...)
pushunique!(collection, item)
end
end

62 changes: 8 additions & 54 deletions src/simpletci_optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,13 @@

# Keywords
- `tolerance::Union{Float64,Nothing} = nothing`: Error tolerance for convergence
- `pivottolerance::Union{Float64,Nothing} = nothing`: Deprecated, use tolerance instead
- `maxbonddim::Int = typemax(Int)`: Maximum bond dimension
- `maxiter::Int = 20`: Maximum number of iterations
- `sweepstrategy::AbstractSweep2sitePathProper = DefaultSweep2sitePathProper()`: Strategy for sweeping
- `pivotsearch::Symbol = :full`: Strategy for pivot search
- `sweepstrategy::AbstractSweep2sitePathProposer = DefaultSweep2sitePathProposer()`: Strategy for sweeping
- `verbosity::Int = 0`: Verbosity level
- `loginterval::Int = 10`: Interval for logging
- `normalizeerror::Bool = true`: Whether to normalize errors
- `ncheckhistory::Int = 3`: Number of history steps to check
- `maxnglobalpivot::Int = 5`: Maximum number of global pivots
- `nsearchglobalpivot::Int = 5`: Number of global pivots to search
- `tolmarginglobalsearch::Float64 = 10.0`: Tolerance margin for global search
- `strictlynested::Bool = false`: Whether to enforce strict nesting
- `checkbatchevaluatable::Bool = false`: Whether to check if function is batch evaluatable

# Returns
- `ranks`: Vector of ranks at each iteration
Expand All @@ -31,56 +24,21 @@
function optimize!(
tci::SimpleTCI{ValueType},
f;
tolerance::Union{Float64,Nothing} = nothing,
pivottolerance::Union{Float64,Nothing} = nothing,
tolerance::Float64 = 1e-8,
maxbonddim::Int = typemax(Int),
maxiter::Int = 20,
sweepstrategy::AbstractSweep2sitePathProper = DefaultSweep2sitePathProper(),
pivotsearch::Symbol = :full,
sweepstrategy::AbstractSweep2sitePathProposer = DefaultSweep2sitePathProposer(),
verbosity::Int = 0,
loginterval::Int = 10,
normalizeerror::Bool = true,
ncheckhistory::Int = 3,
maxnglobalpivot::Int = 5,
nsearchglobalpivot::Int = 5,
tolmarginglobalsearch::Float64 = 10.0,
strictlynested::Bool = false,
checkbatchevaluatable::Bool = false,
) where {ValueType}
errors = Float64[]
ranks = Int[]
nglobalpivots = Int[]
local tol::Float64

if checkbatchevaluatable && !(f isa BatchEvaluator)
error("Function `f` is not batch evaluatable")
end

if nsearchglobalpivot > 0 && nsearchglobalpivot < maxnglobalpivot
error("nsearchglobalpivot < maxnglobalpivot!")
end

# Deprecate the pivottolerance option
if !isnothing(pivottolerance)
if !isnothing(tolerance) && (tolerance != pivottolerance)
throw(
ArgumentError(
"Got different values for pivottolerance and tolerance in optimize!(TCI2). For TCI2, both of these options have the same meaning. Please assign only `tolerance`.",
),
)
else
@warn "The option `pivottolerance` of `optimize!(tci::TensorCI2, f)` is deprecated. Please update your code to use `tolerance`, as `pivottolerance` will be removed in the future."
tol = pivottolerance
end
elseif !isnothing(tolerance)
tol = tolerance
else # pivottolerance == tolerance == nothing, therefore set tol to default value
tol = 1e-8
end

tstart = time_ns()

if maxbonddim >= typemax(Int) && tol <= 0
if maxbonddim >= typemax(Int) && tolerance <= 0
throw(
ArgumentError(
"Specify either tolerance > 0 or some maxbonddim; otherwise, the convergence criterion is not reachable!",
Expand All @@ -91,7 +49,7 @@ function optimize!(
globalpivots = MultiIndex[]
for iter = 1:maxiter
errornormalization = normalizeerror ? tci.maxsamplevalue : 1.0
abstol = tol * errornormalization
abstol = tolerance * errornormalization

if verbosity > 1
println(" Walltime $(1e-9*(time_ns() - tstart)) sec: starting 2site sweep")
Expand All @@ -102,10 +60,8 @@ function optimize!(
tci,
f,
2;
iter1 = 1,
abstol = abstol,
maxbonddim = maxbonddim,
pivotsearch = pivotsearch,
verbosity = verbosity,
sweepstrategy = sweepstrategy,
)
Expand Down Expand Up @@ -142,17 +98,15 @@ function sweep2site!(
tci::SimpleTCI{ValueType},
f,
niter::Int;
iter1::Int = 1,
abstol::Float64 = 1e-8,
maxbonddim::Int = typemax(Int),
sweepstrategy::AbstractSweep2sitePathProper = DefaultSweep2sitePathProper(),
pivotsearch::Symbol = :full,
sweepstrategy::AbstractSweep2sitePathProposer = DefaultSweep2sitePathProposer(),
verbosity::Int = 0,
) where {ValueType}

edge_path = generate_sweep2site_path(sweepstrategy, tci)

for iter = iter1:iter1+niter-1
for _ = 1:niter
extraIJset = Dict(key => MultiIndex[] for key in keys(tci.IJset))
if length(tci.IJset_history) > 0
extraIJset = tci.IJset_history[end]
Expand Down Expand Up @@ -205,7 +159,7 @@ function updatepivots!(
N = length(tci.localdims)

(IJkey, combinedIJset) =
generate_pivot_candidates(DefaultPivotCandidateProper(), tci, edge, extraIJset)
generate_pivot_candidates(DefaultPivotCandidateProposer(), tci, edge, extraIJset)
Ikey, Jkey = first(IJkey), last(IJkey)

t1 = time_ns()
Expand Down
14 changes: 7 additions & 7 deletions src/sweep2sitepathproper.jl → src/sweep2sitepathproposer.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
"""
Abstract type for pivot candidate generation strategies
"""
abstract type AbstractSweep2sitePathProper end
abstract type AbstractSweep2sitePathProposer end

"""
Default strategy
"""
struct DefaultSweep2sitePathProper <: AbstractSweep2sitePathProper end
struct DefaultSweep2sitePathProposer <: AbstractSweep2sitePathProposer end

"""
Random strategy
"""
struct RandomSweep2sitePathProper <: AbstractSweep2sitePathProper end
struct RandomSweep2sitePathProposer <: AbstractSweep2sitePathProposer end

"""
LocalAdjacent strategy
"""
struct LocalAdjacentSweep2sitePathProper <: AbstractSweep2sitePathProper end
struct LocalAdjacentSweep2sitePathProposer <: AbstractSweep2sitePathProposer end

"""
Default strategy that return the sequence path defined by the edges(g)
"""
function generate_sweep2site_path(
::DefaultSweep2sitePathProper,
::DefaultSweep2sitePathProposer,
tci::SimpleTCI{ValueType},
) where {ValueType}
return collect(edges(tci.g))
Expand All @@ -32,7 +32,7 @@ end
Random strategy that returns a random sequence of edges
"""
function generate_sweep2site_path(
::RandomSweep2sitePathProper,
::RandomSweep2sitePathProposer,
tci::SimpleTCI{ValueType},
) where {ValueType}
return shuffle(collect(edges(tci.g)))
Expand All @@ -42,7 +42,7 @@ end
LocalAdjacent strategy that runs through within all indices of site tensor according to the bond and connect them with IJSet from neighbors
"""
function generate_sweep2site_path(
::LocalAdjacentSweep2sitePathProper,
::LocalAdjacentSweep2sitePathProposer,
tci::SimpleTCI{ValueType};
origin_edge = undef,
) where {ValueType}
Expand Down
5 changes: 5 additions & 0 deletions src/treetensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,8 @@ end
function (ttn::TreeTensorNetwork{V})(indexset) where {V}
return evaluate(ttn, indexset)
end

# Add length method for TreeTensorNetwork
function Base.length(ttn::TreeTensorNetwork)
return length(vertices(ttn.tensornetwork.data_graph))
end
17 changes: 2 additions & 15 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,4 @@
# using ReTestItems: runtests, @testitem
# using TreeTCI: TreeTCI

# runtests(TreeTCI)

using ReTestItems: runtests, @testitem
using TreeTCI: TreeTCI
using Test

@testset verbose = true "TreeTCI tests" begin
# @testset "Code quality (Aqua.jl)" begin
# Aqua.test_all(TreeTCI; unbound_args = false, deps_compat = false)
# end

@testset verbose = true "Actual tests" begin
include("simpletci_test.jl")
end
end
runtests(TreeTCI)
61 changes: 0 additions & 61 deletions test/simpletci_test.jl

This file was deleted.

25 changes: 25 additions & 0 deletions test/simpletci_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
@testitem "SimpleTCI" begin
using Test
using TreeTCI
import NamedGraphs: NamedGraph, NamedEdge, add_edge!, vertices, edges, has_edge

# make graph
g = NamedGraph(7)
add_edge!(g, 1, 2)
add_edge!(g, 2, 3)
add_edge!(g, 2, 4)
add_edge!(g, 4, 5)
add_edge!(g, 5, 6)
add_edge!(g, 5, 7)

localdims = fill(2, length(vertices(g)))
f(v) = 1 / (1 + v' * v)

ttn, ranks, errors = TreeTCI.crossinterpolate(Float64, f, localdims, g)
@test ttn([1, 1, 1, 1, 1, 1, 1]) ≈ f([1, 1, 1, 1, 1, 1, 1])
@test ttn([1, 1, 1, 1, 2, 2, 2]) ≈ f([1, 1, 1, 1, 2, 2, 2])
@test ttn([2, 2, 2, 2, 1, 1, 1]) ≈ f([2, 2, 2, 2, 1, 1, 1])
@test ttn([1, 2, 1, 2, 1, 2, 1]) ≈ f([1, 2, 1, 2, 1, 2, 1])
@test ttn([2, 1, 2, 1, 2, 1, 2]) ≈ f([2, 1, 2, 1, 2, 1, 2])
@test ttn([2, 2, 2, 2, 2, 2, 2]) ≈ f([2, 2, 2, 2, 2, 2, 2])
end
43 changes: 43 additions & 0 deletions test/treegraph_utils_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
@testitem "TreeGraphUtils" begin
using Test
using TreeTCI
import NamedGraphs: NamedGraph, NamedEdge, add_edge!, vertices, edges, has_edge

# make graph
g = NamedGraph(7)
add_edge!(g, 1, 2)
add_edge!(g, 2, 3)
add_edge!(g, 2, 4)
add_edge!(g, 4, 5)
add_edge!(g, 5, 6)
add_edge!(g, 5, 7)

e = NamedEdge(2 => 4)
v1, v2 = TreeTCI.separatevertices(g, e)
@test v1 == 2
@test v2 == 4

Ivertices = TreeTCI.subtreevertices(g, v2 => v1) # 4 -> 2
Jvertices = TreeTCI.subtreevertices(g, v1 => v2) # 2 -> 4

@test Ivertices == [1, 2, 3]
@test Jvertices == [4, 5, 6, 7]

subregions = TreeTCI.subregionvertices(g, e)
@test first(subregions) == [1, 2, 3]
@test last(subregions) == [4, 5, 6, 7]

@test Set(TreeTCI.adjacentedges(g, 4)) == Set([NamedEdge(2 => 4), NamedEdge(4 => 5)])

@test Set(TreeTCI.candidateedges(g, NamedEdge(2 => 4))) ==
Set([NamedEdge(1 => 2), NamedEdge(2 => 3), NamedEdge(4 => 5)])

@test TreeTCI.distanceedges(g, NamedEdge(2 => 4)) == Dict(
NamedEdge(2 => 4) => 0,
NamedEdge(1 => 2) => 1,
NamedEdge(2 => 3) => 1,
NamedEdge(4 => 5) => 1,
NamedEdge(5 => 6) => 2,
NamedEdge(5 => 7) => 2,
)
end
Loading