Skip to content

Commit 2703cfb

Browse files
committed
Rename SimpleTensorNetwork => TensorNetwork
1 parent bf7a420 commit 2703cfb

File tree

2 files changed

+20
-23
lines changed

2 files changed

+20
-23
lines changed

src/simpletensornetwork.jl

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,23 @@
11
"""
22
Generic tensor network data structure
33
"""
4-
mutable struct SimpleTensorNetwork <: AbstractDataGraph{Int,IndexedArray,IndexedArray}
4+
mutable struct TensorNetwork <: AbstractDataGraph{Int,IndexedArray,IndexedArray}
55
# data_graph: (undirected) graph of the tensor network
66
# An integer is assigned to each vertex (starting from 1 and increasing one by one).
77
# We can place an IndexedArray at each vertex of the graph, and an edge between two vertices.
88
# But, the latter is not supported by the current implementation of SimpleTensorNetworks.jl.
99
# This may be useful for supporting the Vidal notation.
1010
data_graph::DataGraph{Int,IndexedArray,IndexedArray,NamedGraph{Int},NamedEdge{Int}}
1111

12-
function SimpleTensorNetwork(
12+
function TensorNetwork(
1313
dg::DataGraph{Int,IndexedArray,IndexedArray,NamedGraph{Int},NamedEdge{Int}},
1414
)
15-
is_connected(dg) ||
16-
error("SimpleTensorNetwork is only supported for a connected graph.")
15+
is_connected(dg) || error("TensorNetwork is only supported for a connected graph.")
1716
new(dg)
1817
end
1918
end
2019

21-
function SimpleTensorNetwork(ts::AbstractVector{<:AbstractIndexedArray})
20+
function TensorNetwork(ts::AbstractVector{<:AbstractIndexedArray})
2221
g = NamedGraph(collect(eachindex(ts)))
2322
dg = DataGraph{Int,IndexedArray,IndexedArray,NamedGraph{Int},NamedEdge{Int}}(g)
2423

@@ -32,34 +31,33 @@ function SimpleTensorNetwork(ts::AbstractVector{<:AbstractIndexedArray})
3231
end
3332
end
3433
end
35-
tn = SimpleTensorNetwork(dg)
34+
tn = TensorNetwork(dg)
3635
return tn
3736
end
3837

39-
data_graph(tn::SimpleTensorNetwork) = getfield(tn, :data_graph)
40-
data_graph_type(TN::Type{<:SimpleTensorNetwork}) = fieldtype(TN, :data_graph)
41-
DataGraphs.underlying_graph(tn::SimpleTensorNetwork) = underlying_graph(data_graph(tn))
42-
DataGraphs.underlying_graph_type(TN::Type{<:SimpleTensorNetwork}) =
38+
data_graph(tn::TensorNetwork) = getfield(tn, :data_graph)
39+
data_graph_type(TN::Type{<:TensorNetwork}) = fieldtype(TN, :data_graph)
40+
DataGraphs.underlying_graph(tn::TensorNetwork) = underlying_graph(data_graph(tn))
41+
DataGraphs.underlying_graph_type(TN::Type{<:TensorNetwork}) =
4342
fieldtype(data_graph_type(TN), :underlying_graph)
44-
DataGraphs.vertex_data(graph::SimpleTensorNetwork, args...) =
43+
DataGraphs.vertex_data(graph::TensorNetwork, args...) =
4544
vertex_data(data_graph(graph), args...)
46-
DataGraphs.edge_data(graph::SimpleTensorNetwork, args...) =
47-
edge_data(data_graph(graph), args...)
45+
DataGraphs.edge_data(graph::TensorNetwork, args...) = edge_data(data_graph(graph), args...)
4846

49-
function Base.setindex!(tn::SimpleTensorNetwork, t::AbstractIndexedArray, v::Int)
47+
function Base.setindex!(tn::TensorNetwork, t::AbstractIndexedArray, v::Int)
5048
tn.data_graph[v] = t
5149
end
5250

53-
Base.getindex(tn::SimpleTensorNetwork, v::Int) = tn.data_graph[v]
51+
Base.getindex(tn::TensorNetwork, v::Int) = tn.data_graph[v]
5452

5553
"""
5654
Return if a tensor network `tn` has a cycle. If it has not a cycle, `tn` is a tree tensor network.
5755
"""
58-
Graphs.is_cyclic(tn::SimpleTensorNetwork) =
56+
Graphs.is_cyclic(tn::TensorNetwork) =
5957
Graphs.is_cyclic(tn.data_graph.underlying_graph.position_graph)
6058

6159

62-
Graphs.has_edge(tn::SimpleTensorNetwork, e::NamedEdge) = Graphs.has_edge(tn.data_graph, e)
60+
Graphs.has_edge(tn::TensorNetwork, e::NamedEdge) = Graphs.has_edge(tn.data_graph, e)
6361

6462
"""
6563
Contract all the tensors in a tensor network `tn` and return the result.
@@ -68,7 +66,7 @@ This function works only for tree tensor networks, i.e., `is_cyclic(tn) == false
6866
6967
root_vertex: The vertex to start the contraction. The default is 1.
7068
"""
71-
function complete_contraction(tn::SimpleTensorNetwork; root_vertex::Int = 1)
69+
function complete_contraction(tn::TensorNetwork; root_vertex::Int = 1)
7270
!Graphs.is_cyclic(tn) ||
7371
error("complete_contraction is not supported only for a tree tensor network.")
7472
res = tn[root_vertex]
@@ -84,7 +82,7 @@ Contract all the tensors in a subtree of a tensor network `tn` and return the re
8482
The subtree is defined by a vertex `v` and its parent vertex `parent_v`.
8583
Note that `parent_v` is not included in the subtree.
8684
"""
87-
function _contract_subtree(tn::SimpleTensorNetwork, v::Int, parent_v::Union{Int,Nothing})
85+
function _contract_subtree(tn::TensorNetwork, v::Int, parent_v::Union{Int,Nothing})
8886
res = tn[v]
8987
for nv in neighbors(tn.data_graph, v)
9088
if nv != parent_v

test/simpletensornetwork_test.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
@testitem "simpletensornetwork.jl" begin
2-
import SimpleTensorNetworks:
3-
Index, dim, IndexedArray, indices, permute, SimpleTensorNetwork
2+
import SimpleTensorNetworks: Index, dim, IndexedArray, indices, permute, TensorNetwork
43
import Graphs: is_connected, has_edge
54

65
@testset "Construction from IndexedArray objects" begin
@@ -14,7 +13,7 @@
1413
t3 = IndexedArray(rand(2, 2), [c, d])
1514

1615

17-
tn = SimpleTensorNetwork([t1, t2, t3])
16+
tn = TensorNetwork([t1, t2, t3])
1817

1918
@test has_edge(tn, 1 => 2)
2019
@test has_edge(tn, 2 => 1)
@@ -34,7 +33,7 @@
3433
t3 = IndexedArray(rand(2), [b])
3534
t4 = IndexedArray(rand(2), [c])
3635

37-
tn = SimpleTensorNetwork([t1, t2, t3, t4])
36+
tn = TensorNetwork([t1, t2, t3, t4])
3837
@test only(SimpleTensorNetworks.complete_contraction(tn; root_vertex = 1))
3938
only(t1 * t2 * t3 * t4)
4039
end

0 commit comments

Comments
 (0)