Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
Manifest.toml
Manifest.toml
samples/
1 change: 0 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ TensorCrossInterpolation = "b261b2ec-6378-4871-b32e-9173bb050604"
[compat]
DataGraphs = "0.2.5"
Graphs = "1.12.0"
JuliaFormatter = "1"
NamedGraphs = "0.6.4"
Random = "1.10"
SimpleTensorNetworks = "0.1.0"
Expand Down
105 changes: 88 additions & 17 deletions src/pivotcandidateproposer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,41 +8,101 @@ Default strategy that uses kronecker product and union with extra indices
"""
struct DefaultPivotCandidateProposer <: AbstractPivotCandidateProposer end

"""
Truncated default strategy that uses kronecker product and union with extra indices
"""
struct TruncatedDefaultPivotCandidateProposer <: AbstractPivotCandidateProposer end

"""
Simple strategy that uses kronecker product and union with extra indices
"""
struct SimplePivotCandidateProposer <: AbstractPivotCandidateProposer end

"""
Default strategy that runs through within all indices of site tensor according to the bond and connect them with IJSet from neighbors
"""
function generate_pivot_candidates(
::DefaultPivotCandidateProposer,
tci::SimpleTCI{ValueType},
edge::NamedEdge,
extraIJset::Dict{SubTreeVertex,Vector{MultiIndex}},
) where {ValueType}

vp, vq = separatevertices(tci.g, edge)
Ikey, subIkey = subtreevertices(tci.g, vq => vp), vp
Jkey, subJkey = subtreevertices(tci.g, vp => vq), vq
Ikey = subtreevertices(tci.g, vq => vp)
Jkey = subtreevertices(tci.g, vp => vq)

adjacent_edges_vp = adjacentedges(tci.g, vp; combinededges = edge)
InIkeys = edgeInIJkeys(tci.g, vp, adjacent_edges_vp)
Ipivots = pivotset(tci.IJset, InIkeys, Ikey, tci.localdims[vp])
Isite_index = findfirst(==(vp), Ikey)

adjacent_edges_vq = adjacentedges(tci.g, vq; combinededges = edge)
InJkeys = edgeInIJkeys(tci.g, vq, adjacent_edges_vq)
Jpivots = pivotset(tci.IJset, InJkeys, Jkey, tci.localdims[vq])
Jsite_index = findfirst(==(vq), Jkey)

Iset = kronecker(Ipivots, Isite_index, tci.localdims[vp])
Jset = kronecker(Jpivots, Jsite_index, tci.localdims[vq])

# Generate base index sets for both sides
Iset = kronecker(tci.IJset, Ikey, InIkeys, vp, tci.localdims[vp])
Jset = kronecker(tci.IJset, Jkey, InJkeys, vq, tci.localdims[vq])
extraIJset = if length(tci.IJset_history) > 0
extraIJset = tci.IJset_history[end]
else
Dict(key => MultiIndex[] for key in keys(tci.IJset))
end

# Combine with extra indices if available
Icombined = union(Iset, extraIJset[Ikey])
Jcombined = union(Jset, extraIJset[Jkey])
return (Ikey => Jkey), Dict(Ikey => Icombined, Jkey => Jcombined)
return Dict(Ikey => Icombined, Jkey => Jcombined)
end

function kronecker(
function generate_pivot_candidates(
::TruncatedDefaultPivotCandidateProposer,
tci::SimpleTCI{ValueType},
edge::NamedEdge,
) where {ValueType}
vp, vq = separatevertices(tci.g, edge)

Ikey = subtreevertices(tci.g, vq => vp)
Jkey = subtreevertices(tci.g, vp => vq)
chis = Dict(Ikey => tci.localdims[vp] * length(tci.IJset[Ikey]), Jkey => tci.localdims[vq] * length(tci.IJset[Jkey]))

IJcombined = generate_pivot_candidates(DefaultPivotCandidateProposer(), tci, edge)
IJcombined = Dict(
key => sample_ordered_pivots(IJcombined[key], chis[key]) for
key in keys(IJcombined)
)
return IJcombined
end

function generate_pivot_candidates(
::SimplePivotCandidateProposer,
tci::SimpleTCI{ValueType},
edge::NamedEdge,
) where {ValueType}
vp, vq = separatevertices(tci.g, edge)

Ikey = subtreevertices(tci.g, vq => vp)
Ichi = tci.localdims[vp] * length(tci.IJset[Ikey])
Iset = [[rand(1:tci.localdims[i]) for i in Ikey] for _ = 1:Ichi]

Jkey = subtreevertices(tci.g, vp => vq)
Jchi = tci.localdims[vq] * length(tci.IJset[Jkey])
Jset = [[rand(1:tci.localdims[j]) for j in Jkey] for _ = 1:Jchi]
extraIJset = if length(tci.IJset_history) > 0
extraIJset = tci.IJset_history[end]
else
Dict(key => MultiIndex[] for key in keys(tci.IJset))
end
Icombined = union(Iset, extraIJset[Ikey])
Jcombined = union(Jset, extraIJset[Jkey])
return Dict(Ikey => Icombined, Jkey => Jcombined)
end


function pivotset(
IJset::Dict{SubTreeVertex,Vector{MultiIndex}},
Inkeys::Vector{SubTreeVertex},
Outkey::SubTreeVertex, # original subregions order
Inkeys::Vector{SubTreeVertex}, # original subregions order
site::Int, # direct connected site
localdim::Int,
)
pivotset = MultiIndex[]
Expand All @@ -56,16 +116,27 @@ function kronecker(
end
push!(pivotset, indexset)
end
return pivotset
end

site_index = findfirst(==(site), Outkey)
filtered_subregions = filter(x -> x ≠ Set([site]), Outkey)

if site_index === nothing
return MultiIndex[]
function sample_ordered_pivots(pivots::Vector{MultiIndex}, maxsize::Int)
n = length(pivots)
@show n, maxsize
if n ≤ maxsize
return pivots
end
selected_indices = shuffle(1:n)[1:maxsize]
return pivots[sort(selected_indices)]
end

function kronecker(
pivotset::Vector{MultiIndex},
site_index::Union{Int,Nothing},
localdims::Int,
)
isnothing(site_index) && return MultiIndex[]
return MultiIndex[
[is[1:site_index-1]..., j, is[site_index+1:end]...] for is in pivotset,
j = 1:localdim
j = 1:localdims
][:]
end
67 changes: 50 additions & 17 deletions src/simpletci.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,58 @@
MultiIndex = Vector{Int}
SubTreeVertex = Vector{Int}

@doc """
SimpleTCI{ValueType}

Tree tensor cross interpolation (TCI) for tree tensor networks.

# Fields
- `IJset::Dict{SubTreeVertex,Vector{MultiIndex}}`: Pivots sets for each subtrees
- `localdims::Vector{Int}`: Local dimensions for each vertex tensor
- `g::NamedGraph`: Tree graph structure
- `bonderrors::Dict{NamedEdge,Float64}`: Error estimate per bond by 2-site sweep
- `pivoterrors::Vector{Float64}`: Error estimate for backtruncation of bonds
- `maxsamplevalue::Float64`: Maximum sample value for error normalization
- `IJset_history::Vector{Dict{SubTreeVertex,Vector{MultiIndex}}}`: History of pivots sets for each sweep

# Example
```julia
# Create a simple tree graph
g = NamedGraph([1, 2, 3])
add_edge!(g, 1 => 2)
add_edge!(g, 2 => 3)

# Define local dimensions
localdims = [2, 2, 2]

# Create a SimpleTCI instance
tci = SimpleTCI{Float64}(localdims, g)

# Add initial pivots
addglobalpivots!(tci, [[1,1,1], [2,1,1]])
```
"""
mutable struct SimpleTCI{ValueType}
IJset::Dict{SubTreeVertex,Vector{MultiIndex}}
localdims::Vector{Int}
g::NamedGraph
#"Error estimate per bond by 2site sweep."
bonderrors::Dict{NamedEdge,Float64} # key is the bond id
# "Error estimate for backtruncation of bonds."
pivoterrors::Vector{Float64} # key is the bond id
#"Maximum sample for error normalization."
bonderrors::Dict{NamedEdge,Float64}
pivoterrors::Vector{Float64}
maxsamplevalue::Float64
IJset_history::Vector{Dict{SubTreeVertex,Vector{MultiIndex}}}

function SimpleTCI{ValueType}(localdims::Vector{Int}, g::NamedGraph) where {ValueType}
length(localdims) > 1 || error("localdims should have at least 2 elements!")
n = length(localdims)
n > 1 || error("localdims should have at least 2 elements!")
n == length(vertices(g)) || error(
"The number of vertices in the graph must be equal to the length of localdims.",
)
!Graphs.is_cyclic(g) ||
error("SimpleTCI is not supported for loopy tensor network.")

# assign the key for each bond
bonderrors = Dict(e => 0.0 for e in edges(g))

!Graphs.is_cyclic(g) ||
error("TreeTensorNetwork is not supported for loopy tensor network.")

new{ValueType}(
Dict{SubTreeVertex,Vector{MultiIndex}}(), # IJset
localdims,
Expand All @@ -35,6 +65,10 @@ mutable struct SimpleTCI{ValueType}
end
end

"""
Initialize a SimpleTCI instance with a function, local dimensions, and graph.
The initial grobal pivots are set to ones(Int, length(localdims)).
"""
function SimpleTCI{ValueType}(
func::F,
localdims::Vector{Int},
Expand All @@ -49,14 +83,14 @@ function SimpleTCI{ValueType}(
return tci
end

@doc """
Add global pivots to index sets
"""
"""
Add global pivots to IJset.
"""
function addglobalpivots!(
tci::SimpleTCI{ValueType},
pivots::Vector{MultiIndex},
) where {ValueType}
if any(length(tci.localdims) .!= length.(pivots)) # AbstructTreeTensorNetworkをから引き継ぎlength(tci)ができると良い
if any(length(tci.localdims) .!= length.(pivots))
throw(DimensionMismatch("Please specify a pivot as one index per leg of the TTN."))
end
for pivot in pivots
Expand All @@ -71,17 +105,16 @@ function addglobalpivots!(
if !haskey(tci.IJset, Jset_key)
tci.IJset[Jset_key] = Vector{MultiIndex}()
end
pushunique!(tci.IJset[Iset_key], [pivot[i] for i in Iset_key])
pushunique!(tci.IJset[Jset_key], [pivot[j] for j in Jset_key])
pushunique!(tci.IJset[Iset_key], MultiIndex([pivot[i] for i in Iset_key]))
pushunique!(tci.IJset[Jset_key], MultiIndex([pivot[j] for j in Jset_key]))
end
end

tci.IJset[[i for i = 1:length(tci.localdims)]] = Int[]
tci.IJset[[i for i = 1:length(tci.localdims)]] = MultiIndex[]

nothing
end


function pushunique!(collection, item)
if !(item in collection)
push!(collection, item)
Expand Down
Loading
Loading