From 7e547d70b18a50836c88abee7e4909bfffef4e42 Mon Sep 17 00:00:00 2001 From: watayo Date: Thu, 13 Mar 2025 15:00:57 +0900 Subject: [PATCH 1/8] format and simple path --- Manifest.toml | 379 +++++++++++++++++++++++++++++++ Project.toml | 2 - docs/make.jl | 5 +- samples/devtci.jl | 3 +- src/TreeTCI.jl | 14 +- src/abstracttreetensornetwork.jl | 40 ---- src/pivotcandidateproper.jl | 4 +- src/simpletci.jl | 141 ++++++------ src/simpletci_utils.jl | 37 +-- src/sweep2sitepathproper.jl | 19 +- src/tree_utils.jl | 21 +- src/treetensornetwork.jl | 17 +- test/simpletci_utils_test.jl | 10 +- 13 files changed, 515 insertions(+), 177 deletions(-) create mode 100644 Manifest.toml delete mode 100644 src/abstracttreetensornetwork.jl diff --git a/Manifest.toml b/Manifest.toml new file mode 100644 index 0000000..6c454af --- /dev/null +++ b/Manifest.toml @@ -0,0 +1,379 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.11.1" +manifest_format = "2.0" +project_hash = "aa27f51668b9cb208b591bf6f6b28585e35c8ef7" + +[[deps.AbstractTrees]] +git-tree-sha1 = "2d9c9a55f9c93e8887ad391fbae72f8ef55e1177" +uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +version = "0.4.5" + +[[deps.Adapt]] +deps = ["LinearAlgebra", "Requires"] +git-tree-sha1 = "cd8b948862abee8f3d3e9b73a102a9ca924debb0" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "4.2.0" +weakdeps = ["SparseArrays", "StaticArrays"] + + [deps.Adapt.extensions] + AdaptSparseArraysExt = "SparseArrays" + AdaptStaticArraysExt = "StaticArrays" + +[[deps.ArnoldiMethod]] +deps = ["LinearAlgebra", "Random", "StaticArrays"] +git-tree-sha1 = "d57bd3762d308bded22c3b82d033bff85f6195c6" +uuid = "ec485272-7323-5ecc-a04f-4719b315124d" +version = "0.4.0" + +[[deps.ArrayInterface]] +deps = ["Adapt", "LinearAlgebra"] +git-tree-sha1 = "017fcb757f8e921fb44ee063a7aafe5f89b86dd1" +uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +version = "7.18.0" + + [deps.ArrayInterface.extensions] + ArrayInterfaceBandedMatricesExt = "BandedMatrices" + ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices" + ArrayInterfaceCUDAExt = "CUDA" + ArrayInterfaceCUDSSExt = "CUDSS" + ArrayInterfaceChainRulesCoreExt = "ChainRulesCore" + ArrayInterfaceChainRulesExt = "ChainRules" + ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore" + ArrayInterfaceReverseDiffExt = "ReverseDiff" + ArrayInterfaceSparseArraysExt = "SparseArrays" + ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore" + ArrayInterfaceTrackerExt = "Tracker" + + [deps.ArrayInterface.weakdeps] + BandedMatrices = "aae01518-5342-5314-be14-df237901396f" + BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e" + ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" +version = "1.11.0" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" +version = "1.11.0" + +[[deps.CommonWorldInvalidations]] +git-tree-sha1 = "ae52d1c52048455e85a387fbee9be553ec2b68d0" +uuid = "f70d9fcc-98c5-4d4a-abd7-e4cdeebd8ca8" +version = "1.0.0" + +[[deps.Compat]] +deps = ["TOML", "UUIDs"] +git-tree-sha1 = "8ae8d32e09f0dcf42a36b90d4e17f5dd2e4c4215" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.16.0" +weakdeps = ["Dates", "LinearAlgebra"] + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.1.1+0" + +[[deps.DataGraphs]] +deps = ["Dictionaries", "Graphs", "NamedGraphs", "PackageExtensionCompat", "SimpleTraits"] +git-tree-sha1 = "ba80c479f54904ea2d574eba7e03e434c83d71c8" +uuid = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a" +version = "0.2.5" + + [deps.DataGraphs.extensions] + DataGraphsGraphsFlowsExt = "GraphsFlows" + + [deps.DataGraphs.weakdeps] + GraphsFlows = "06909019-6f44-4949-96fc-b9d9aaa02889" + +[[deps.DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.20" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +version = "1.11.0" + +[[deps.Dictionaries]] +deps = ["Indexing", "Random", "Serialization"] +git-tree-sha1 = "1cdab237b6e0d0960d5dcbd2c0ebfa15fa6573d9" +uuid = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" +version = "0.4.4" + +[[deps.Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" +version = "1.11.0" + +[[deps.EllipsisNotation]] +deps = ["StaticArrayInterface"] +git-tree-sha1 = "3507300d4343e8e4ad080ad24e335274c2e297a9" +uuid = "da5c29d0-fa7d-589e-88eb-ea29b0a81949" +version = "1.8.0" + +[[deps.Graphs]] +deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] +git-tree-sha1 = "1dc470db8b1131cfc7fb4c115de89fe391b9e780" +uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" +version = "1.12.0" + +[[deps.IfElse]] +git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" +uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" +version = "0.1.1" + +[[deps.Indexing]] +git-tree-sha1 = "ce1566720fd6b19ff3411404d4b977acd4814f9f" +uuid = "313cdc1a-70c2-5d6a-ae34-0150d3930a38" +version = "1.1.1" + +[[deps.Inflate]] +git-tree-sha1 = "d1b1b796e47d94588b3757fe84fbf65a5ec4a80d" +uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" +version = "0.1.5" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +version = "1.11.0" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" +version = "1.11.0" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +version = "1.11.0" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" +version = "1.11.0" + +[[deps.MacroTools]] +git-tree-sha1 = "72aebe0b5051e5143a079a4685a46da330a40472" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.15" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" +version = "1.11.0" + +[[deps.Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" +version = "1.11.0" + +[[deps.NamedGraphs]] +deps = ["AbstractTrees", "Dictionaries", "Graphs", "LinearAlgebra", "PackageExtensionCompat", "Random", "SimpleTraits", "SparseArrays", "SplitApplyCombine", "Suppressor"] +git-tree-sha1 = "c520ef2017b7c4f1bf7f60025b831babb8d3eaed" +uuid = "678767b0-92e7-4007-89e4-4527a8725b19" +version = "0.6.4" + + [deps.NamedGraphs.extensions] + NamedGraphsGraphsFlowsExt = "GraphsFlows" + NamedGraphsKaHyParExt = "KaHyPar" + NamedGraphsMetisExt = "Metis" + NamedGraphsSymRCMExt = "SymRCM" + + [deps.NamedGraphs.weakdeps] + GraphsFlows = "06909019-6f44-4949-96fc-b9d9aaa02889" + KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880" + Metis = "2679e427-3c69-5b7f-982b-ece356f1e94b" + SymRCM = "286e6d88-80af-4590-acc9-0001b223b9bd" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.27+1" + +[[deps.OrderedCollections]] +git-tree-sha1 = "cc4054e898b852042d7b503313f7ad03de99c3dd" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.8.0" + +[[deps.PackageExtensionCompat]] +git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" +uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" +version = "1.0.2" +weakdeps = ["Requires", "TOML"] + +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" +uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.2.1" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.4.3" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" +version = "1.11.0" + +[[deps.QuadGK]] +deps = ["DataStructures", "LinearAlgebra"] +git-tree-sha1 = "9da16da70037ba9d701192e27befedefb91ec284" +uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" +version = "2.11.2" + + [deps.QuadGK.extensions] + QuadGKEnzymeExt = "Enzyme" + + [deps.QuadGK.weakdeps] + Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" + +[[deps.Random]] +deps = ["SHA"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +version = "1.11.0" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "62389eeff14780bfe55195b7204c0d8738436d64" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.1" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +version = "1.11.0" + +[[deps.SharedArrays]] +deps = ["Distributed", "Mmap", "Random", "Serialization"] +uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" +version = "1.11.0" + +[[deps.SimpleTensorNetworks]] +deps = ["DataGraphs", "Graphs", "NamedGraphs"] +git-tree-sha1 = "c425260a467bb77990a523e3b44d29d687bed4e5" +uuid = "3075f829-f72e-4896-a859-7fe0a9cabb9b" +version = "0.1.0" + +[[deps.SimpleTraits]] +deps = ["InteractiveUtils", "MacroTools"] +git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" +uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" +version = "0.9.4" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" +version = "1.11.0" + +[[deps.SparseArrays]] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +version = "1.11.0" + +[[deps.SplitApplyCombine]] +deps = ["Dictionaries", "Indexing"] +git-tree-sha1 = "c06d695d51cfb2187e6848e98d6252df9101c588" +uuid = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66" +version = "1.2.3" + +[[deps.Static]] +deps = ["CommonWorldInvalidations", "IfElse", "PrecompileTools"] +git-tree-sha1 = "f737d444cb0ad07e61b3c1bef8eb91203c321eff" +uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +version = "1.2.0" + +[[deps.StaticArrayInterface]] +deps = ["ArrayInterface", "Compat", "IfElse", "LinearAlgebra", "PrecompileTools", "Static"] +git-tree-sha1 = "96381d50f1ce85f2663584c8e886a6ca97e60554" +uuid = "0d7ed370-da01-4f52-bd93-41d350b8b718" +version = "1.8.0" + + [deps.StaticArrayInterface.extensions] + StaticArrayInterfaceOffsetArraysExt = "OffsetArrays" + StaticArrayInterfaceStaticArraysExt = "StaticArrays" + + [deps.StaticArrayInterface.weakdeps] + OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] +git-tree-sha1 = "0feb6b9031bd5c51f9072393eb5ab3efd31bf9e4" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.9.13" + + [deps.StaticArrays.extensions] + StaticArraysChainRulesCoreExt = "ChainRulesCore" + StaticArraysStatisticsExt = "Statistics" + + [deps.StaticArrays.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.3" + +[[deps.Statistics]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "ae3bb1eb3bba077cd276bc5cfc337cc65c3075c0" +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.11.1" +weakdeps = ["SparseArrays"] + + [deps.Statistics.extensions] + SparseArraysExt = ["SparseArrays"] + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "7.7.0+0" + +[[deps.Suppressor]] +deps = ["Logging"] +git-tree-sha1 = "6dbb5b635c5437c68c28c2ac9e39b87138f37c0a" +uuid = "fd094767-a336-5f1f-9728-57cf17d0bbfb" +version = "0.2.8" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" + +[[deps.TensorCrossInterpolation]] +deps = ["EllipsisNotation", "LinearAlgebra", "QuadGK"] +git-tree-sha1 = "378bca655cc4596c9db7adba9ac3e089bac9e3c5" +uuid = "b261b2ec-6378-4871-b32e-9173bb050604" +version = "0.9.14" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +version = "1.11.0" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" +version = "1.11.0" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.11.0+0" diff --git a/Project.toml b/Project.toml index 716a020..ae6cfe6 100644 --- a/Project.toml +++ b/Project.toml @@ -6,9 +6,7 @@ version = "0.1.0" [deps] DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" -JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19" -Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" SimpleTensorNetworks = "3075f829-f72e-4896-a859-7fe0a9cabb9b" TensorCrossInterpolation = "b261b2ec-6378-4871-b32e-9173bb050604" diff --git a/docs/make.jl b/docs/make.jl index 6f10cfa..d1a07aa 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -12,10 +12,7 @@ makedocs(; edit_link = "main", assets = String[], ), - pages = [ - "Home" => "index.md", - "API Reference" => "api.md" - ], + pages = ["Home" => "index.md", "API Reference" => "api.md"], ) deploydocs(; repo = "github.com/tensor4all/TreeTCI.jl.git", devbranch = "main") diff --git a/samples/devtci.jl b/samples/devtci.jl index cfb77d6..f5cc157 100644 --- a/samples/devtci.jl +++ b/samples/devtci.jl @@ -15,7 +15,8 @@ function main() f(v) = 1 / (1 + v' * v) tolerance = 1e-8 - mpn, ranks, errors = TreeTCI.TCI.crossinterpolate2(Float64, f, localdims; tolerance = tolerance) + mpn, ranks, errors = + TreeTCI.TCI.crossinterpolate2(Float64, f, localdims; tolerance = tolerance) ttn, ranks, errors = TreeTCI.crossinterpolate(Float64, f, localdims, g) @show f([1, 1, 1, 1, 2, 1, 1]), f([1, 2, 1, 2, 2, 1, 1]), f([2, 2, 2, 2, 2, 2, 2]) @show mpn([1, 1, 1, 1, 2, 1, 1]), mpn([1, 2, 1, 2, 2, 1, 1]), mpn([2, 2, 2, 2, 2, 2, 2]) diff --git a/src/TreeTCI.jl b/src/TreeTCI.jl index aab7804..d251a0c 100644 --- a/src/TreeTCI.jl +++ b/src/TreeTCI.jl @@ -2,9 +2,19 @@ module TreeTCI import Graphs import NamedGraphs: - NamedGraph, NamedEdge, is_directed, outneighbors, has_edge, edges, vertices, src, dst, namedgraph_dijkstra_shortest_paths + NamedGraph, + NamedEdge, + is_directed, + outneighbors, + has_edge, + edges, + vertices, + src, + dst, + namedgraph_dijkstra_shortest_paths import TensorCrossInterpolation as TCI -import SimpleTensorNetworks: TensorNetwork, IndexedArray, Index, complete_contraction, getindex, contract +import SimpleTensorNetworks: + TensorNetwork, IndexedArray, Index, complete_contraction, getindex, contract include("tree_utils.jl") include("simpletci.jl") include("simpletci_utils.jl") diff --git a/src/abstracttreetensornetwork.jl b/src/abstracttreetensornetwork.jl deleted file mode 100644 index 0a83d47..0000000 --- a/src/abstracttreetensornetwork.jl +++ /dev/null @@ -1,40 +0,0 @@ -abstract type AbstractTreeTensorNetwork{V} <: Function end - -""" - function evaluate( - ttn::TreeTensorNetwork{V}, - indexset::Union{AbstractVector{Int}, NTuple{N, Int}} - )::V where {V} - -Evaluates the tensor train `tt` at indices given by `indexset`. -""" -function evaluate( - ttn::AbstractTreeTensorNetwork{V}, - indexset::Union{AbstractVector{Int},NTuple{N,Int}}, -)::V where {N,V} - if length(indexset) != length(ttn.sitetensors) - throw( - ArgumentError( - "To evaluate a tt of length $(length(ttn)), you have to provide $(length(ttn)) indices, but there were $(length(indexset)).", - ), - ) - end - sitetensors = IndexedArray[] - # TODO: site tensorを作る関数を作成 - for (Tinfo, i) in zip(ttn.sitetensors, indexset) - T, edges = Tinfo - inds = (i, ntuple(_ -> :, ndims(T) - 1)...) - T = T[inds...] - indexs = [ - Index(size(T)[j], "$(src(edges[j]))=>$(dst(edges[j]))") for j = 1:length(edges) - ] - t = IndexedArray(T, indexs) - push!(sitetensors, t) - end - tn = TensorNetwork(sitetensors) - return only(complete_contraction(tn)) -end - -function (ttn::AbstractTreeTensorNetwork{V})(indexset) where {V} - return evaluate(ttn, indexset) -end diff --git a/src/pivotcandidateproper.jl b/src/pivotcandidateproper.jl index 9c4db9c..826fe8f 100644 --- a/src/pivotcandidateproper.jl +++ b/src/pivotcandidateproper.jl @@ -22,10 +22,10 @@ function generate_pivot_candidates( Ikey, subIkey = subtreevertices(tci.g, vq => vp), vp Jkey, subJkey = subtreevertices(tci.g, vp => vq), vq - adjacent_edges_vp = adjacentedges(tci.g, vp; combinededges=edge) + adjacent_edges_vp = adjacentedges(tci.g, vp; combinededges = edge) InIkeys = edgeInIJkeys(tci.g, vp, adjacent_edges_vp) - adjacent_edges_vq = adjacentedges(tci.g, vq; combinededges=edge) + adjacent_edges_vq = adjacentedges(tci.g, vq; combinededges = edge) InJkeys = edgeInIJkeys(tci.g, vq, adjacent_edges_vq) # Generate base index sets for both sides diff --git a/src/simpletci.jl b/src/simpletci.jl index 0299bad..08f3281 100644 --- a/src/simpletci.jl +++ b/src/simpletci.jl @@ -51,9 +51,9 @@ function SimpleTCI{ValueType}( return tci end -@doc""" -Add global pivots to index sets -""" +@doc """ + Add global pivots to index sets + """ function addglobalpivots!( tci::SimpleTCI{ValueType}, pivots::Vector{MultiIndex}, @@ -83,36 +83,36 @@ function addglobalpivots!( nothing end -@doc""" - optimize!(tci::SimpleTCI{ValueType}, f; kwargs...) - -Optimize the tensor cross interpolation (TCI) by iteratively updating pivots. - -# Arguments -- `tci`: The SimpleTCI object to optimize -- `f`: The function to interpolate - -# Keywords -- `tolerance::Union{Float64,Nothing} = nothing`: Error tolerance for convergence -- `pivottolerance::Union{Float64,Nothing} = nothing`: Deprecated, use tolerance instead -- `maxbonddim::Int = typemax(Int)`: Maximum bond dimension -- `maxiter::Int = 20`: Maximum number of iterations -- `sweepstrategy::Symbol = :backandforth`: Strategy for sweeping -- `pivotsearch::Symbol = :full`: Strategy for pivot search -- `verbosity::Int = 0`: Verbosity level -- `loginterval::Int = 10`: Interval for logging -- `normalizeerror::Bool = true`: Whether to normalize errors -- `ncheckhistory::Int = 3`: Number of history steps to check -- `maxnglobalpivot::Int = 5`: Maximum number of global pivots -- `nsearchglobalpivot::Int = 5`: Number of global pivots to search -- `tolmarginglobalsearch::Float64 = 10.0`: Tolerance margin for global search -- `strictlynested::Bool = false`: Whether to enforce strict nesting -- `checkbatchevaluatable::Bool = false`: Whether to check if function is batch evaluatable - -# Returns -- `ranks`: Vector of ranks at each iteration -- `errors`: Vector of normalized errors at each iteration -""" +@doc """ + optimize!(tci::SimpleTCI{ValueType}, f; kwargs...) + + Optimize the tensor cross interpolation (TCI) by iteratively updating pivots. + + # Arguments + - `tci`: The SimpleTCI object to optimize + - `f`: The function to interpolate + + # Keywords + - `tolerance::Union{Float64,Nothing} = nothing`: Error tolerance for convergence + - `pivottolerance::Union{Float64,Nothing} = nothing`: Deprecated, use tolerance instead + - `maxbonddim::Int = typemax(Int)`: Maximum bond dimension + - `maxiter::Int = 20`: Maximum number of iterations + - `sweepstrategy::Symbol = :backandforth`: Strategy for sweeping + - `pivotsearch::Symbol = :full`: Strategy for pivot search + - `verbosity::Int = 0`: Verbosity level + - `loginterval::Int = 10`: Interval for logging + - `normalizeerror::Bool = true`: Whether to normalize errors + - `ncheckhistory::Int = 3`: Number of history steps to check + - `maxnglobalpivot::Int = 5`: Maximum number of global pivots + - `nsearchglobalpivot::Int = 5`: Number of global pivots to search + - `tolmarginglobalsearch::Float64 = 10.0`: Tolerance margin for global search + - `strictlynested::Bool = false`: Whether to enforce strict nesting + - `checkbatchevaluatable::Bool = false`: Whether to check if function is batch evaluatable + + # Returns + - `ranks`: Vector of ranks at each iteration + - `errors`: Vector of normalized errors at each iteration + """ function optimize!( tci::SimpleTCI{ValueType}, f; @@ -120,7 +120,7 @@ function optimize!( pivottolerance::Union{Float64,Nothing} = nothing, maxbonddim::Int = typemax(Int), maxiter::Int = 20, - sweepstrategy::Symbol = :backandforth, # TODO: Implement for Tree structure + sweepstrategy::Symbol = :default, pivotsearch::Symbol = :full, verbosity::Int = 0, loginterval::Int = 10, @@ -218,11 +218,11 @@ function optimize!( return ranks, errors ./ errornormalization end -@doc""" -Perform 2site sweeps on a SimpleTCI. -!TODO: Implement for Tree structure +@doc """ + Perform 2site sweeps on a SimpleTCI. + !TODO: Implement for Tree structure -""" + """ function sweep2site!( tci::SimpleTCI{ValueType}, f, @@ -230,12 +230,18 @@ function sweep2site!( iter1::Int = 1, abstol::Float64 = 1e-8, maxbonddim::Int = typemax(Int), - sweepstrategy::Symbol = :backandforth, + sweepstrategy::Symbol = :default, pivotsearch::Symbol = :full, verbosity::Int = 0, ) where {ValueType} - edge_path = generate_sweep2site_path(DefaultSweep2sitePathProper(), tci) + if sweepstrategy == :default + edge_path = generate_sweep2site_path(DefaultSweep2sitePathProper(), tci) + elseif sweepstrategy == :localadjacent + edge_path = generate_sweep2site_path(LocalAdjacentSweep2sitePathProper(), tci) + else + error("Invalid sweep strategy: $sweepstrategy") + end for iter = iter1:iter1+niter-1 @@ -285,30 +291,18 @@ function updatepivots!( extraIJset::Dict{SubTreeVertex,Vector{MultiIndex}} = Dict{ SubTreeVertex, Vector{MultiIndex}, - }(), - ) where {F,ValueType} + }(), +) where {F,ValueType} - N = length(tci.localdims) + N = length(tci.localdims) - (IJkey, combinedIJset) = generate_pivot_candidates( - DefaultPivotCandidateProper(), - tci, - edge, - extraIJset, - ) - Ikey, Jkey = first(IJkey), last(IJkey) - - t1 = time_ns() - Pi = reshape( - filltensor( - ValueType, - f, - tci.localdims, - combinedIJset, - [Ikey], - [Jkey], - Val(0), - ), + (IJkey, combinedIJset) = + generate_pivot_candidates(DefaultPivotCandidateProper(), tci, edge, extraIJset) + Ikey, Jkey = first(IJkey), last(IJkey) + + t1 = time_ns() + Pi = reshape( + filltensor(ValueType, f, tci.localdims, combinedIJset, [Ikey], [Jkey], Val(0)), length(combinedIJset[Ikey]), length(combinedIJset[Jkey]), ) @@ -324,18 +318,19 @@ function updatepivots!( t3 = time_ns() if verbosity > 2 - x, y = length(combinedIJset[Ikey]), length(combinedIJset[Jkey]), + x, y = length(combinedIJset[Ikey]), + length(combinedIJset[Jkey]), println( " Computing Pi ($x x $y) at bond $b: $(1e-9*(t2-t1)) sec, LU: $(1e-9*(t3-t2)) sec", - ) - end + ) + end - tci.IJset[Ikey] = combinedIJset[Ikey][TCI.rowindices(luci)] - tci.IJset[Jkey] = combinedIJset[Jkey][TCI.colindices(luci)] + tci.IJset[Ikey] = combinedIJset[Ikey][TCI.rowindices(luci)] + tci.IJset[Jkey] = combinedIJset[Jkey][TCI.colindices(luci)] - updateerrors!(tci, edge, TCI.pivoterrors(luci)) - nothing - end + updateerrors!(tci, edge, TCI.pivoterrors(luci)) + nothing +end function updatemaxsample!(tci::SimpleTCI{V}, samples::Array{V}) where {V} @@ -352,11 +347,7 @@ function updateerrors!( nothing end -function updateedgeerror!( - tci::SimpleTCI{T}, - edge::NamedEdge, - error::Float64, -) where {T} +function updateedgeerror!(tci::SimpleTCI{T}, edge::NamedEdge, error::Float64) where {T} tci.bonderrors[edge] = error nothing end diff --git a/src/simpletci_utils.jl b/src/simpletci_utils.jl index af26b85..313cbdf 100644 --- a/src/simpletci_utils.jl +++ b/src/simpletci_utils.jl @@ -4,7 +4,8 @@ function fillsitetensors( center_vertex::Int = 0, ) where {ValueType} - sitetensors = Vector{Pair{Array{ValueType},Vector{NamedEdge}}}(undef, length(vertices(tci.g))) + sitetensors = + Vector{Pair{Array{ValueType},Vector{NamedEdge}}}(undef, length(vertices(tci.g))) if center_vertex ∉ vertices(tci.g) center_vertex = first(vertices(tci.g)) @@ -18,17 +19,21 @@ function fillsitetensors( for child in children # adjacent_edges = adjacentedges(tci.g, child) parent = state.parents[child] - edge = filter(e -> src(e) == parent && dst(e) == child || dst(e) == parent && src(e) == child, edges(tci.g)) + edge = filter( + e -> + src(e) == parent && dst(e) == child || + dst(e) == parent && src(e) == child, + edges(tci.g), + ) edge = isempty(edge) ? nothing : only(edge) incomingedges = setdiff(adjacentedges(tci.g, child), Set([edge])) - InKeys = !isempty(incomingedges) ? edgeInIJkeys(tci.g, child, incomingedges) : SubTreeVertex[] + InKeys = + !isempty(incomingedges) ? edgeInIJkeys(tci.g, child, incomingedges) : + SubTreeVertex[] OutKeys = edge != nothing ? edgeInIJkeys(tci.g, child, edge) : SubTreeVertex[] if d != 0 T = sitetensor(tci, child, edge, InKeys => OutKeys, f) - sitetensors[child] = T => vcat( - incomingedges, - [edge], - ) + sitetensors[child] = T => vcat(incomingedges, [edge]) else T = sitetensor(tci, child, edge, InKeys => OutKeys, f, core = true) sitetensors[child] = T => incomingedges @@ -91,11 +96,11 @@ function sitetensor( length(tci.IJset[I1key]) == sum([length(tci.IJset[key]) for key in Outkeys]) || error("Pivot matrix at bond $(site) is not square!") Tmat = transpose(transpose(P) \ transpose(Pi1)) T = reshape( - Tmat, - tci.localdims[site], - [length(tci.IJset[key]) for key in Inkeys]..., - [length(tci.IJset[key]) for key in Outkeys]..., - ) + Tmat, + tci.localdims[site], + [length(tci.IJset[key]) for key in Inkeys]..., + [length(tci.IJset[key]) for key in Outkeys]..., + ) return T end @@ -186,11 +191,7 @@ function _call( return result end -function edgeInIJkeys( - g::NamedGraph, - v::Int, - combinededges -) +function edgeInIJkeys(g::NamedGraph, v::Int, combinededges) if combinededges isa NamedEdge combinededges = [combinededges] end @@ -204,4 +205,4 @@ function edgeInIJkeys( end end return keys -end \ No newline at end of file +end diff --git a/src/sweep2sitepathproper.jl b/src/sweep2sitepathproper.jl index 2351cbe..8cf1779 100644 --- a/src/sweep2sitepathproper.jl +++ b/src/sweep2sitepathproper.jl @@ -9,10 +9,20 @@ Default strategy that uses kronecker product and union with extra indices struct DefaultSweep2sitePathProper <: Sweep2sitePathProper end """ -Default strategy that runs through within all indices of site tensor according to the bond and connect them with IJSet from neighbors +Default strategy that return the sequence path defined by the edges(g) """ function generate_sweep2site_path( ::DefaultSweep2sitePathProper, + tci::SimpleTCI{ValueType}, +) where {ValueType} + return collect(edges(tci.g)) +end + +""" +LocalAdjacent strategy that runs through within all indices of site tensor according to the bond and connect them with IJSet from neighbors +""" +function generate_sweep2site_path( + ::LocalAdjacentSweep2sitePathProper, tci::SimpleTCI{ValueType}; origin_edge = undef, ) where {ValueType} @@ -42,10 +52,7 @@ function generate_sweep2site_path( while true candidates = candidateedges(tci.g, center_edge) - candidates = filter( - e -> flags[e] == 0, - candidates - ) + candidates = filter(e -> flags[e] == 0, candidates) # If candidates is empty, exit while loop if isempty(candidates) @@ -74,4 +81,4 @@ function generate_sweep2site_path( end return edge_path -end \ No newline at end of file +end diff --git a/src/tree_utils.jl b/src/tree_utils.jl index bb688f1..04158d7 100644 --- a/src/tree_utils.jl +++ b/src/tree_utils.jl @@ -35,8 +35,8 @@ end function adjacentedges( g::NamedGraph, vertex::Int; - combinededges::Union{NamedEdge, Vector{NamedEdge}} = Vector{NamedEdge}() -) ::Vector{NamedEdge} + combinededges::Union{NamedEdge,Vector{NamedEdge}} = Vector{NamedEdge}(), +)::Vector{NamedEdge} if combinededges isa NamedEdge combinededges = [combinededges] end @@ -50,19 +50,15 @@ function adjacentedges( return adjedges end -function candidateedges( - g::NamedGraph, - edge::NamedEdge, -)::Vector{NamedEdge} +function candidateedges(g::NamedGraph, edge::NamedEdge)::Vector{NamedEdge} p, q = separatevertices(g, edge) - candidates = adjacentedges(g, p; combinededges=edge) ∪ adjacentedges(g, q; combinededges=edge) + candidates = + adjacentedges(g, p; combinededges = edge) ∪ + adjacentedges(g, q; combinededges = edge) return candidates end -function distanceedges( - g::NamedGraph, - edge::NamedEdge, -)::Dict{NamedEdge,Int} +function distanceedges(g::NamedGraph, edge::NamedEdge)::Dict{NamedEdge,Int} p, q = separatevertices(g, edge) distances = Dict{NamedEdge,Int}() distances[edge] = 0 @@ -80,8 +76,7 @@ function distanceBFSedge( candidates = filter(cand -> cand ∉ keys(distances), candidates) for cand in candidates distances[cand] = distances[edge] + 1 - distances = - merge!(distances, distanceBFSedge(g, cand, distances)) + distances = merge!(distances, distanceBFSedge(g, cand, distances)) end return distances end diff --git a/src/treetensornetwork.jl b/src/treetensornetwork.jl index 1e189b4..1109d46 100644 --- a/src/treetensornetwork.jl +++ b/src/treetensornetwork.jl @@ -2,7 +2,8 @@ mutable struct TreeTensorNetwork{ValueType} tensornetwork::TensorNetwork function TreeTensorNetwork( - g::NamedGraph, sitetensors::Vector{Pair{Array{ValueType},Vector{NamedEdge}}}, + g::NamedGraph, + sitetensors::Vector{Pair{Array{ValueType},Vector{NamedEdge}}}, ) where {ValueType} !Graphs.is_cyclic(g) || error("TreeTensorNetwork is not supported for loopy tensor network.") @@ -11,8 +12,9 @@ mutable struct TreeTensorNetwork{ValueType} indexs = vcat( Index(size(T)[1], "s$i"), [ - Index(size(T)[j+1], "$(src(edges[j]))=>$(dst(edges[j]))") for j = 1:length(edges) - ] + Index(size(T)[j+1], "$(src(edges[j]))=>$(dst(edges[j]))") for + j = 1:length(edges) + ], ) t = IndexedArray(T, indexs) push!(ttntensors, t) @@ -25,7 +27,7 @@ end function crossinterpolate( ::Type{ValueType}, f, - localdims::Union{Vector{Int}, NTuple{N,Int}}, + localdims::Union{Vector{Int},NTuple{N,Int}}, g::NamedGraph, initialpivots::Vector{MultiIndex} = [ones(Int, length(localdims))]; kwargs..., @@ -39,7 +41,7 @@ end function evaluate( ttn::TreeTensorNetwork{ValueType}, indexset::Union{AbstractVector{Int},NTuple{N,Int}}, - ) where {N, ValueType} +) where {N,ValueType} tn = deepcopy(ttn.tensornetwork) if length(indexset) != length(vertices(tn.data_graph)) throw( @@ -51,8 +53,8 @@ function evaluate( for i = 1:length(vertices(tn.data_graph)) t = tn[i] site = IndexedArray( - [j == indexset[i] ? 1.0 : 0.0 for j in 1:t.indices[1].dim], - [t.indices[1]] + [j == indexset[i] ? 1.0 : 0.0 for j = 1:t.indices[1].dim], + [t.indices[1]], ) tn[i] = contract(t, site) end @@ -62,4 +64,3 @@ end function (ttn::TreeTensorNetwork{V})(indexset) where {V} return evaluate(ttn, indexset) end - diff --git a/test/simpletci_utils_test.jl b/test/simpletci_utils_test.jl index 8d2a467..5eda4ee 100644 --- a/test/simpletci_utils_test.jl +++ b/test/simpletci_utils_test.jl @@ -30,13 +30,11 @@ import NamedGraphs: NamedGraph, NamedEdge, add_edge!, edges, has_edge @test last(subregions) == [4, 5, 6, 7] - @test Set(TreeTCI.adjacentedges(g, 4)) == Set( - [NamedEdge(2 => 4), NamedEdge(4 => 5)] - ) + @test Set(TreeTCI.adjacentedges(g, 4)) == + Set([NamedEdge(2 => 4), NamedEdge(4 => 5)]) - @test Set(TreeTCI.candidateedges(g, NamedEdge(2 => 4))) == Set( - [NamedEdge(1 => 2), NamedEdge(2 => 3), NamedEdge(4 => 5)] - ) + @test Set(TreeTCI.candidateedges(g, NamedEdge(2 => 4))) == + Set([NamedEdge(1 => 2), NamedEdge(2 => 3), NamedEdge(4 => 5)]) @test TreeTCI.distanceedges(g, NamedEdge(2 => 4)) == Dict( NamedEdge(2 => 4) => 0, From 27671bccaccb2532611c12ccaffef7eb5ded4866 Mon Sep 17 00:00:00 2001 From: watayo Date: Thu, 13 Mar 2025 15:56:14 +0900 Subject: [PATCH 2/8] revise distanceedges --- Manifest.toml | 2 +- Project.toml | 2 ++ src/TreeTCI.jl | 3 +- src/simpletci.jl | 5 +++ src/sweep2sitepathproper.jl | 22 ++++++++++++- src/{tree_utils.jl => treegraph_utils.jl} | 31 +++++++++---------- test/runtests.jl | 2 +- ..._utils_test.jl => treegraph_utils_test.jl} | 0 8 files changed, 47 insertions(+), 20 deletions(-) rename src/{tree_utils.jl => treegraph_utils.jl} (78%) rename test/{simpletci_utils_test.jl => treegraph_utils_test.jl} (100%) diff --git a/Manifest.toml b/Manifest.toml index 6c454af..eb379a2 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.11.1" manifest_format = "2.0" -project_hash = "aa27f51668b9cb208b591bf6f6b28585e35c8ef7" +project_hash = "7b99dad25e08c441325ada1fca8f1debd62d95c4" [[deps.AbstractTrees]] git-tree-sha1 = "2d9c9a55f9c93e8887ad391fbae72f8ef55e1177" diff --git a/Project.toml b/Project.toml index ae6cfe6..e714596 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "0.1.0" DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SimpleTensorNetworks = "3075f829-f72e-4896-a859-7fe0a9cabb9b" TensorCrossInterpolation = "b261b2ec-6378-4871-b32e-9173bb050604" @@ -14,6 +15,7 @@ TensorCrossInterpolation = "b261b2ec-6378-4871-b32e-9173bb050604" DataGraphs = "0.2.5" Graphs = "1.12.0" NamedGraphs = "0.6.4" +Random = "1.11.0" SimpleTensorNetworks = "0.1.0" TensorCrossInterpolation = "0.9.13" diff --git a/src/TreeTCI.jl b/src/TreeTCI.jl index d251a0c..97f91ac 100644 --- a/src/TreeTCI.jl +++ b/src/TreeTCI.jl @@ -15,7 +15,8 @@ import NamedGraphs: import TensorCrossInterpolation as TCI import SimpleTensorNetworks: TensorNetwork, IndexedArray, Index, complete_contraction, getindex, contract -include("tree_utils.jl") +import Random: shuffle +include("treegraph_utils.jl") include("simpletci.jl") include("simpletci_utils.jl") include("pivotcandidateproper.jl") diff --git a/src/simpletci.jl b/src/simpletci.jl index 08f3281..ddfacbb 100644 --- a/src/simpletci.jl +++ b/src/simpletci.jl @@ -237,8 +237,13 @@ function sweep2site!( if sweepstrategy == :default edge_path = generate_sweep2site_path(DefaultSweep2sitePathProper(), tci) + elseif sweepstrategy == :localadjacent edge_path = generate_sweep2site_path(LocalAdjacentSweep2sitePathProper(), tci) + + elseif sweepstrategy == :random + edge_path = generate_sweep2site_path(RandomSweep2sitePathProper(), tci) + else error("Invalid sweep strategy: $sweepstrategy") end diff --git a/src/sweep2sitepathproper.jl b/src/sweep2sitepathproper.jl index 8cf1779..a793a9b 100644 --- a/src/sweep2sitepathproper.jl +++ b/src/sweep2sitepathproper.jl @@ -4,10 +4,20 @@ Abstract type for pivot candidate generation strategies abstract type Sweep2sitePathProper end """ -Default strategy that uses kronecker product and union with extra indices +Default strategy """ struct DefaultSweep2sitePathProper <: Sweep2sitePathProper end +""" +Random strategy +""" +struct RandomSweep2sitePathProper <: Sweep2sitePathProper end + +""" +LocalAdjacent strategy +""" +struct LocalAdjacentSweep2sitePathProper <: Sweep2sitePathProper end + """ Default strategy that return the sequence path defined by the edges(g) """ @@ -18,6 +28,16 @@ function generate_sweep2site_path( return collect(edges(tci.g)) end +""" +Random strategy that returns a random sequence of edges +""" +function generate_sweep2site_path( + ::RandomSweep2sitePathProper, + tci::SimpleTCI{ValueType}, +) where {ValueType} + return shuffle(collect(edges(tci.g))) +end + """ LocalAdjacent strategy that runs through within all indices of site tensor according to the bond and connect them with IJSet from neighbors """ diff --git a/src/tree_utils.jl b/src/treegraph_utils.jl similarity index 78% rename from src/tree_utils.jl rename to src/treegraph_utils.jl index 04158d7..d9cf68a 100644 --- a/src/tree_utils.jl +++ b/src/treegraph_utils.jl @@ -61,22 +61,21 @@ end function distanceedges(g::NamedGraph, edge::NamedEdge)::Dict{NamedEdge,Int} p, q = separatevertices(g, edge) distances = Dict{NamedEdge,Int}() - distances[edge] = 0 - distances = distanceBFSedge(g, edge, distances) - return distances -end -function distanceBFSedge( - g::NamedGraph, - edge::NamedEdge, - distances::Dict{NamedEdge,Int}, -)::Dict{NamedEdge,Int} - - candidates = candidateedges(g, edge) - candidates = filter(cand -> cand ∉ keys(distances), candidates) - for cand in candidates - distances[cand] = distances[edge] + 1 - distances = merge!(distances, distanceBFSedge(g, cand, distances)) + function compute_distances(root, opposite, state) + for subvertex in filter(v -> v != root, subtreevertices(g, opposite => root)) + parent = state.parents[subvertex] + e = NamedEdge(parent, subvertex) + distances[e ∈ edges(g) ? e : reverse(e)] = state.dists[subvertex] + end end + + state_p = namedgraph_dijkstra_shortest_paths(g, p) + compute_distances(p, q, state_p) + + state_q = namedgraph_dijkstra_shortest_paths(g, q) + compute_distances(q, p, state_q) + + distances[edge] = 0 return distances -end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 149b069..8022863 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,6 +12,6 @@ using Test # end @testset verbose = true "Actual tests" begin - include("simpletci_utils_test.jl") + include("treegraph_utils_test.jl") end end diff --git a/test/simpletci_utils_test.jl b/test/treegraph_utils_test.jl similarity index 100% rename from test/simpletci_utils_test.jl rename to test/treegraph_utils_test.jl From 70adff32b5ff0d849b46d0f272c528000ccac24c Mon Sep 17 00:00:00 2001 From: watayo Date: Thu, 13 Mar 2025 16:05:30 +0900 Subject: [PATCH 3/8] complete test --- samples/devtci.jl | 27 ------------------- test/runtests.jl | 2 +- ...egraph_utils_test.jl => simpletci_test.jl} | 17 +++++++++--- 3 files changed, 15 insertions(+), 31 deletions(-) delete mode 100644 samples/devtci.jl rename test/{treegraph_utils_test.jl => simpletci_test.jl} (63%) diff --git a/samples/devtci.jl b/samples/devtci.jl deleted file mode 100644 index f5cc157..0000000 --- a/samples/devtci.jl +++ /dev/null @@ -1,27 +0,0 @@ -using Revise -using TreeTCI -using NamedGraphs: NamedGraph, add_edge!, edges - -function main() - localdims = fill(2, 7) - g = NamedGraph(7) - add_edge!(g, 1, 2) - add_edge!(g, 2, 3) - add_edge!(g, 2, 4) - add_edge!(g, 4, 5) - add_edge!(g, 5, 6) - add_edge!(g, 5, 7) - - f(v) = 1 / (1 + v' * v) - tolerance = 1e-8 - - mpn, ranks, errors = - TreeTCI.TCI.crossinterpolate2(Float64, f, localdims; tolerance = tolerance) - ttn, ranks, errors = TreeTCI.crossinterpolate(Float64, f, localdims, g) - @show f([1, 1, 1, 1, 2, 1, 1]), f([1, 2, 1, 2, 2, 1, 1]), f([2, 2, 2, 2, 2, 2, 2]) - @show mpn([1, 1, 1, 1, 2, 1, 1]), mpn([1, 2, 1, 2, 2, 1, 1]), mpn([2, 2, 2, 2, 2, 2, 2]) - @show ttn([1, 1, 1, 1, 2, 1, 1]), ttn([1, 2, 1, 2, 2, 1, 1]), ttn([2, 2, 2, 2, 2, 2, 2]) - nothing -end - -main() diff --git a/test/runtests.jl b/test/runtests.jl index 8022863..236fbbd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,6 +12,6 @@ using Test # end @testset verbose = true "Actual tests" begin - include("treegraph_utils_test.jl") + include("simpletci_test.jl") end end diff --git a/test/treegraph_utils_test.jl b/test/simpletci_test.jl similarity index 63% rename from test/treegraph_utils_test.jl rename to test/simpletci_test.jl index 5eda4ee..a073ff1 100644 --- a/test/treegraph_utils_test.jl +++ b/test/simpletci_test.jl @@ -1,6 +1,6 @@ using Test using TreeTCI -import NamedGraphs: NamedGraph, NamedEdge, add_edge!, edges, has_edge +import NamedGraphs: NamedGraph, NamedEdge, add_edge!, vertices, edges, has_edge @testset "simpletci.jl" begin # make graph @@ -12,8 +12,7 @@ import NamedGraphs: NamedGraph, NamedEdge, add_edge!, edges, has_edge add_edge!(g, 5, 6) add_edge!(g, 5, 7) - - @testset "SubTreeVertex" begin + @testset "TreeGraphUtils" begin e = NamedEdge(2 => 4) v1, v2 = TreeTCI.separatevertices(g, e) @test v1 == 2 @@ -47,4 +46,16 @@ import NamedGraphs: NamedGraph, NamedEdge, add_edge!, edges, has_edge end + @testset "SimpleTCI" begin + localdims = fill(2, length(vertices(g))) + f(v) = 1 / (1 + v' * v) + + ttn, ranks, errors = TreeTCI.crossinterpolate(Float64, f, localdims, g) + @test ttn([1, 1, 1, 1, 1, 1, 1]) ≈ f([1, 1, 1, 1, 1, 1, 1]) + @test ttn([1, 1, 1, 1, 2, 2, 2]) ≈ f([1, 1, 1, 1, 2, 2, 2]) + @test ttn([2, 2, 2, 2, 1, 1, 1]) ≈ f([2, 2, 2, 2, 1, 1, 1]) + @test ttn([1, 2, 1, 2, 1, 2, 1]) ≈ f([1, 2, 1, 2, 1, 2, 1]) + @test ttn([2, 1, 2, 1, 2, 1, 2]) ≈ f([2, 1, 2, 1, 2, 1, 2]) + @test ttn([2, 2, 2, 2, 2, 2, 2]) ≈ f([2, 2, 2, 2, 2, 2, 2]) + end end From 71fe2472a546300516ba0c4fc8d8d95aa7b45f5c Mon Sep 17 00:00:00 2001 From: watayo Date: Thu, 13 Mar 2025 16:07:48 +0900 Subject: [PATCH 4/8] format --- src/treegraph_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/treegraph_utils.jl b/src/treegraph_utils.jl index d9cf68a..a01ace9 100644 --- a/src/treegraph_utils.jl +++ b/src/treegraph_utils.jl @@ -78,4 +78,4 @@ function distanceedges(g::NamedGraph, edge::NamedEdge)::Dict{NamedEdge,Int} distances[edge] = 0 return distances -end \ No newline at end of file +end From 33ee39f1a9d316c23f7d0d98eb925226cee53a7d Mon Sep 17 00:00:00 2001 From: watayo Date: Thu, 13 Mar 2025 16:15:47 +0900 Subject: [PATCH 5/8] Remove Manifest.toml from repository --- .gitignore | 1 + Manifest.toml | 379 -------------------------------------------------- 2 files changed, 1 insertion(+), 379 deletions(-) create mode 100644 .gitignore delete mode 100644 Manifest.toml diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2251642 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +Manifest.toml \ No newline at end of file diff --git a/Manifest.toml b/Manifest.toml deleted file mode 100644 index eb379a2..0000000 --- a/Manifest.toml +++ /dev/null @@ -1,379 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -julia_version = "1.11.1" -manifest_format = "2.0" -project_hash = "7b99dad25e08c441325ada1fca8f1debd62d95c4" - -[[deps.AbstractTrees]] -git-tree-sha1 = "2d9c9a55f9c93e8887ad391fbae72f8ef55e1177" -uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" -version = "0.4.5" - -[[deps.Adapt]] -deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "cd8b948862abee8f3d3e9b73a102a9ca924debb0" -uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "4.2.0" -weakdeps = ["SparseArrays", "StaticArrays"] - - [deps.Adapt.extensions] - AdaptSparseArraysExt = "SparseArrays" - AdaptStaticArraysExt = "StaticArrays" - -[[deps.ArnoldiMethod]] -deps = ["LinearAlgebra", "Random", "StaticArrays"] -git-tree-sha1 = "d57bd3762d308bded22c3b82d033bff85f6195c6" -uuid = "ec485272-7323-5ecc-a04f-4719b315124d" -version = "0.4.0" - -[[deps.ArrayInterface]] -deps = ["Adapt", "LinearAlgebra"] -git-tree-sha1 = "017fcb757f8e921fb44ee063a7aafe5f89b86dd1" -uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "7.18.0" - - [deps.ArrayInterface.extensions] - ArrayInterfaceBandedMatricesExt = "BandedMatrices" - ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices" - ArrayInterfaceCUDAExt = "CUDA" - ArrayInterfaceCUDSSExt = "CUDSS" - ArrayInterfaceChainRulesCoreExt = "ChainRulesCore" - ArrayInterfaceChainRulesExt = "ChainRules" - ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore" - ArrayInterfaceReverseDiffExt = "ReverseDiff" - ArrayInterfaceSparseArraysExt = "SparseArrays" - ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore" - ArrayInterfaceTrackerExt = "Tracker" - - [deps.ArrayInterface.weakdeps] - BandedMatrices = "aae01518-5342-5314-be14-df237901396f" - BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e" - ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" - ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" - SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - -[[deps.Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" -version = "1.11.0" - -[[deps.Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" -version = "1.11.0" - -[[deps.CommonWorldInvalidations]] -git-tree-sha1 = "ae52d1c52048455e85a387fbee9be553ec2b68d0" -uuid = "f70d9fcc-98c5-4d4a-abd7-e4cdeebd8ca8" -version = "1.0.0" - -[[deps.Compat]] -deps = ["TOML", "UUIDs"] -git-tree-sha1 = "8ae8d32e09f0dcf42a36b90d4e17f5dd2e4c4215" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.16.0" -weakdeps = ["Dates", "LinearAlgebra"] - - [deps.Compat.extensions] - CompatLinearAlgebraExt = "LinearAlgebra" - -[[deps.CompilerSupportLibraries_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.1.1+0" - -[[deps.DataGraphs]] -deps = ["Dictionaries", "Graphs", "NamedGraphs", "PackageExtensionCompat", "SimpleTraits"] -git-tree-sha1 = "ba80c479f54904ea2d574eba7e03e434c83d71c8" -uuid = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a" -version = "0.2.5" - - [deps.DataGraphs.extensions] - DataGraphsGraphsFlowsExt = "GraphsFlows" - - [deps.DataGraphs.weakdeps] - GraphsFlows = "06909019-6f44-4949-96fc-b9d9aaa02889" - -[[deps.DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.20" - -[[deps.Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" -version = "1.11.0" - -[[deps.Dictionaries]] -deps = ["Indexing", "Random", "Serialization"] -git-tree-sha1 = "1cdab237b6e0d0960d5dcbd2c0ebfa15fa6573d9" -uuid = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" -version = "0.4.4" - -[[deps.Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" -version = "1.11.0" - -[[deps.EllipsisNotation]] -deps = ["StaticArrayInterface"] -git-tree-sha1 = "3507300d4343e8e4ad080ad24e335274c2e297a9" -uuid = "da5c29d0-fa7d-589e-88eb-ea29b0a81949" -version = "1.8.0" - -[[deps.Graphs]] -deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] -git-tree-sha1 = "1dc470db8b1131cfc7fb4c115de89fe391b9e780" -uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" -version = "1.12.0" - -[[deps.IfElse]] -git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" -uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" -version = "0.1.1" - -[[deps.Indexing]] -git-tree-sha1 = "ce1566720fd6b19ff3411404d4b977acd4814f9f" -uuid = "313cdc1a-70c2-5d6a-ae34-0150d3930a38" -version = "1.1.1" - -[[deps.Inflate]] -git-tree-sha1 = "d1b1b796e47d94588b3757fe84fbf65a5ec4a80d" -uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" -version = "0.1.5" - -[[deps.InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -version = "1.11.0" - -[[deps.Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" -version = "1.11.0" - -[[deps.LinearAlgebra]] -deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -version = "1.11.0" - -[[deps.Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" -version = "1.11.0" - -[[deps.MacroTools]] -git-tree-sha1 = "72aebe0b5051e5143a079a4685a46da330a40472" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.15" - -[[deps.Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" -version = "1.11.0" - -[[deps.Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" -version = "1.11.0" - -[[deps.NamedGraphs]] -deps = ["AbstractTrees", "Dictionaries", "Graphs", "LinearAlgebra", "PackageExtensionCompat", "Random", "SimpleTraits", "SparseArrays", "SplitApplyCombine", "Suppressor"] -git-tree-sha1 = "c520ef2017b7c4f1bf7f60025b831babb8d3eaed" -uuid = "678767b0-92e7-4007-89e4-4527a8725b19" -version = "0.6.4" - - [deps.NamedGraphs.extensions] - NamedGraphsGraphsFlowsExt = "GraphsFlows" - NamedGraphsKaHyParExt = "KaHyPar" - NamedGraphsMetisExt = "Metis" - NamedGraphsSymRCMExt = "SymRCM" - - [deps.NamedGraphs.weakdeps] - GraphsFlows = "06909019-6f44-4949-96fc-b9d9aaa02889" - KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880" - Metis = "2679e427-3c69-5b7f-982b-ece356f1e94b" - SymRCM = "286e6d88-80af-4590-acc9-0001b223b9bd" - -[[deps.OpenBLAS_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] -uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.27+1" - -[[deps.OrderedCollections]] -git-tree-sha1 = "cc4054e898b852042d7b503313f7ad03de99c3dd" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.8.0" - -[[deps.PackageExtensionCompat]] -git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" -uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" -version = "1.0.2" -weakdeps = ["Requires", "TOML"] - -[[deps.PrecompileTools]] -deps = ["Preferences"] -git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" -uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.2.1" - -[[deps.Preferences]] -deps = ["TOML"] -git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" -uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.3" - -[[deps.Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" -version = "1.11.0" - -[[deps.QuadGK]] -deps = ["DataStructures", "LinearAlgebra"] -git-tree-sha1 = "9da16da70037ba9d701192e27befedefb91ec284" -uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" -version = "2.11.2" - - [deps.QuadGK.extensions] - QuadGKEnzymeExt = "Enzyme" - - [deps.QuadGK.weakdeps] - Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" - -[[deps.Random]] -deps = ["SHA"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -version = "1.11.0" - -[[deps.Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "62389eeff14780bfe55195b7204c0d8738436d64" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.3.1" - -[[deps.SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" -version = "0.7.0" - -[[deps.Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" -version = "1.11.0" - -[[deps.SharedArrays]] -deps = ["Distributed", "Mmap", "Random", "Serialization"] -uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" -version = "1.11.0" - -[[deps.SimpleTensorNetworks]] -deps = ["DataGraphs", "Graphs", "NamedGraphs"] -git-tree-sha1 = "c425260a467bb77990a523e3b44d29d687bed4e5" -uuid = "3075f829-f72e-4896-a859-7fe0a9cabb9b" -version = "0.1.0" - -[[deps.SimpleTraits]] -deps = ["InteractiveUtils", "MacroTools"] -git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" -uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" -version = "0.9.4" - -[[deps.Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" -version = "1.11.0" - -[[deps.SparseArrays]] -deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -version = "1.11.0" - -[[deps.SplitApplyCombine]] -deps = ["Dictionaries", "Indexing"] -git-tree-sha1 = "c06d695d51cfb2187e6848e98d6252df9101c588" -uuid = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66" -version = "1.2.3" - -[[deps.Static]] -deps = ["CommonWorldInvalidations", "IfElse", "PrecompileTools"] -git-tree-sha1 = "f737d444cb0ad07e61b3c1bef8eb91203c321eff" -uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" -version = "1.2.0" - -[[deps.StaticArrayInterface]] -deps = ["ArrayInterface", "Compat", "IfElse", "LinearAlgebra", "PrecompileTools", "Static"] -git-tree-sha1 = "96381d50f1ce85f2663584c8e886a6ca97e60554" -uuid = "0d7ed370-da01-4f52-bd93-41d350b8b718" -version = "1.8.0" - - [deps.StaticArrayInterface.extensions] - StaticArrayInterfaceOffsetArraysExt = "OffsetArrays" - StaticArrayInterfaceStaticArraysExt = "StaticArrays" - - [deps.StaticArrayInterface.weakdeps] - OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - -[[deps.StaticArrays]] -deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "0feb6b9031bd5c51f9072393eb5ab3efd31bf9e4" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.13" - - [deps.StaticArrays.extensions] - StaticArraysChainRulesCoreExt = "ChainRulesCore" - StaticArraysStatisticsExt = "Statistics" - - [deps.StaticArrays.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - -[[deps.StaticArraysCore]] -git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" -uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.3" - -[[deps.Statistics]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "ae3bb1eb3bba077cd276bc5cfc337cc65c3075c0" -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.11.1" -weakdeps = ["SparseArrays"] - - [deps.Statistics.extensions] - SparseArraysExt = ["SparseArrays"] - -[[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] -uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "7.7.0+0" - -[[deps.Suppressor]] -deps = ["Logging"] -git-tree-sha1 = "6dbb5b635c5437c68c28c2ac9e39b87138f37c0a" -uuid = "fd094767-a336-5f1f-9728-57cf17d0bbfb" -version = "0.2.8" - -[[deps.TOML]] -deps = ["Dates"] -uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.3" - -[[deps.TensorCrossInterpolation]] -deps = ["EllipsisNotation", "LinearAlgebra", "QuadGK"] -git-tree-sha1 = "378bca655cc4596c9db7adba9ac3e089bac9e3c5" -uuid = "b261b2ec-6378-4871-b32e-9173bb050604" -version = "0.9.14" - -[[deps.UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" -version = "1.11.0" - -[[deps.Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" -version = "1.11.0" - -[[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.11.0+0" From 7f63756406fbde7e7503dec3f2006940c132dd2c Mon Sep 17 00:00:00 2001 From: watayo Date: Thu, 13 Mar 2025 16:17:32 +0900 Subject: [PATCH 6/8] rem --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e714596..aa7e57b 100644 --- a/Project.toml +++ b/Project.toml @@ -15,7 +15,7 @@ TensorCrossInterpolation = "b261b2ec-6378-4871-b32e-9173bb050604" DataGraphs = "0.2.5" Graphs = "1.12.0" NamedGraphs = "0.6.4" -Random = "1.11.0" +Random = "1.11" SimpleTensorNetworks = "0.1.0" TensorCrossInterpolation = "0.9.13" From 70f3942f5a3ae8325d11371ef421caf77c22cd5a Mon Sep 17 00:00:00 2001 From: watayo Date: Thu, 13 Mar 2025 16:38:10 +0900 Subject: [PATCH 7/8] Abstract+ --- Project.toml | 2 +- src/TreeTCI.jl | 3 +- src/pivotcandidateproper.jl | 4 +- src/simpletci.jl | 307 ------------------ src/simpletci_optimize.jl | 276 ++++++++++++++++ ...impletci_utils.jl => simpletci_tensors.jl} | 0 src/sweep2sitepathproper.jl | 8 +- 7 files changed, 285 insertions(+), 315 deletions(-) create mode 100644 src/simpletci_optimize.jl rename src/{simpletci_utils.jl => simpletci_tensors.jl} (100%) diff --git a/Project.toml b/Project.toml index aa7e57b..bccfef1 100644 --- a/Project.toml +++ b/Project.toml @@ -15,7 +15,7 @@ TensorCrossInterpolation = "b261b2ec-6378-4871-b32e-9173bb050604" DataGraphs = "0.2.5" Graphs = "1.12.0" NamedGraphs = "0.6.4" -Random = "1.11" +Random = "1.10" SimpleTensorNetworks = "0.1.0" TensorCrossInterpolation = "0.9.13" diff --git a/src/TreeTCI.jl b/src/TreeTCI.jl index 97f91ac..8fc82a2 100644 --- a/src/TreeTCI.jl +++ b/src/TreeTCI.jl @@ -18,8 +18,9 @@ import SimpleTensorNetworks: import Random: shuffle include("treegraph_utils.jl") include("simpletci.jl") -include("simpletci_utils.jl") include("pivotcandidateproper.jl") include("sweep2sitepathproper.jl") +include("simpletci_optimize.jl") +include("simpletci_tensors.jl") include("treetensornetwork.jl") end diff --git a/src/pivotcandidateproper.jl b/src/pivotcandidateproper.jl index 826fe8f..7ec8659 100644 --- a/src/pivotcandidateproper.jl +++ b/src/pivotcandidateproper.jl @@ -1,12 +1,12 @@ """ Abstract type for pivot candidate generation strategies """ -abstract type PivotCandidateProper end +abstract type AbstractPivotCandidateProper end """ Default strategy that uses kronecker product and union with extra indices """ -struct DefaultPivotCandidateProper <: PivotCandidateProper end +struct DefaultPivotCandidateProper <: AbstractPivotCandidateProper end """ Default strategy that runs through within all indices of site tensor according to the bond and connect them with IJSet from neighbors diff --git a/src/simpletci.jl b/src/simpletci.jl index ddfacbb..8efa38d 100644 --- a/src/simpletci.jl +++ b/src/simpletci.jl @@ -1,8 +1,6 @@ MultiIndex = Vector{Int} SubTreeVertex = Vector{Int} -using Base: SimpleLogger - mutable struct SimpleTCI{ValueType} IJset::Dict{SubTreeVertex,Vector{MultiIndex}} localdims::Vector{Int} @@ -82,308 +80,3 @@ function addglobalpivots!( nothing end - -@doc """ - optimize!(tci::SimpleTCI{ValueType}, f; kwargs...) - - Optimize the tensor cross interpolation (TCI) by iteratively updating pivots. - - # Arguments - - `tci`: The SimpleTCI object to optimize - - `f`: The function to interpolate - - # Keywords - - `tolerance::Union{Float64,Nothing} = nothing`: Error tolerance for convergence - - `pivottolerance::Union{Float64,Nothing} = nothing`: Deprecated, use tolerance instead - - `maxbonddim::Int = typemax(Int)`: Maximum bond dimension - - `maxiter::Int = 20`: Maximum number of iterations - - `sweepstrategy::Symbol = :backandforth`: Strategy for sweeping - - `pivotsearch::Symbol = :full`: Strategy for pivot search - - `verbosity::Int = 0`: Verbosity level - - `loginterval::Int = 10`: Interval for logging - - `normalizeerror::Bool = true`: Whether to normalize errors - - `ncheckhistory::Int = 3`: Number of history steps to check - - `maxnglobalpivot::Int = 5`: Maximum number of global pivots - - `nsearchglobalpivot::Int = 5`: Number of global pivots to search - - `tolmarginglobalsearch::Float64 = 10.0`: Tolerance margin for global search - - `strictlynested::Bool = false`: Whether to enforce strict nesting - - `checkbatchevaluatable::Bool = false`: Whether to check if function is batch evaluatable - - # Returns - - `ranks`: Vector of ranks at each iteration - - `errors`: Vector of normalized errors at each iteration - """ -function optimize!( - tci::SimpleTCI{ValueType}, - f; - tolerance::Union{Float64,Nothing} = nothing, - pivottolerance::Union{Float64,Nothing} = nothing, - maxbonddim::Int = typemax(Int), - maxiter::Int = 20, - sweepstrategy::Symbol = :default, - pivotsearch::Symbol = :full, - verbosity::Int = 0, - loginterval::Int = 10, - normalizeerror::Bool = true, - ncheckhistory::Int = 3, - maxnglobalpivot::Int = 5, - nsearchglobalpivot::Int = 5, - tolmarginglobalsearch::Float64 = 10.0, - strictlynested::Bool = false, - checkbatchevaluatable::Bool = false, -) where {ValueType} - errors = Float64[] - ranks = Int[] - nglobalpivots = Int[] - local tol::Float64 - - if checkbatchevaluatable && !(f isa BatchEvaluator) - error("Function `f` is not batch evaluatable") - end - - if nsearchglobalpivot > 0 && nsearchglobalpivot < maxnglobalpivot - error("nsearchglobalpivot < maxnglobalpivot!") - end - - # Deprecate the pivottolerance option - if !isnothing(pivottolerance) - if !isnothing(tolerance) && (tolerance != pivottolerance) - throw( - ArgumentError( - "Got different values for pivottolerance and tolerance in optimize!(TCI2). For TCI2, both of these options have the same meaning. Please assign only `tolerance`.", - ), - ) - else - @warn "The option `pivottolerance` of `optimize!(tci::TensorCI2, f)` is deprecated. Please update your code to use `tolerance`, as `pivottolerance` will be removed in the future." - tol = pivottolerance - end - elseif !isnothing(tolerance) - tol = tolerance - else # pivottolerance == tolerance == nothing, therefore set tol to default value - tol = 1e-8 - end - - tstart = time_ns() - - if maxbonddim >= typemax(Int) && tol <= 0 - throw( - ArgumentError( - "Specify either tolerance > 0 or some maxbonddim; otherwise, the convergence criterion is not reachable!", - ), - ) - end - - globalpivots = MultiIndex[] - for iter = 1:maxiter - errornormalization = normalizeerror ? tci.maxsamplevalue : 1.0 - abstol = tol * errornormalization - - if verbosity > 1 - println(" Walltime $(1e-9*(time_ns() - tstart)) sec: starting 2site sweep") - flush(stdout) - end - - sweep2site!( - tci, - f, - 2; - iter1 = 1, - abstol = abstol, - maxbonddim = maxbonddim, - pivotsearch = pivotsearch, - verbosity = verbosity, - sweepstrategy = sweepstrategy, - ) - if verbosity > 0 && length(globalpivots) > 0 && mod(iter, loginterval) == 0 - abserr = [abs(evaluate(tci, p) - f(p)) for p in globalpivots] - nrejections = length(abserr .> abstol) - if nrejections > 0 - println( - " Rejected $(nrejections) global pivots added in the previous iteration, errors are $(abserr)", - ) - flush(stdout) - end - end - push!(errors, last(pivoterror(tci))) - - if verbosity > 1 - println( - " Walltime $(1e-9*(time_ns() - tstart)) sec: start searching global pivots", - ) - flush(stdout) - end - end - - errornormalization = normalizeerror ? tci.maxsamplevalue : 1.0 - return ranks, errors ./ errornormalization -end - -@doc """ - Perform 2site sweeps on a SimpleTCI. - !TODO: Implement for Tree structure - - """ -function sweep2site!( - tci::SimpleTCI{ValueType}, - f, - niter::Int; - iter1::Int = 1, - abstol::Float64 = 1e-8, - maxbonddim::Int = typemax(Int), - sweepstrategy::Symbol = :default, - pivotsearch::Symbol = :full, - verbosity::Int = 0, -) where {ValueType} - - if sweepstrategy == :default - edge_path = generate_sweep2site_path(DefaultSweep2sitePathProper(), tci) - - elseif sweepstrategy == :localadjacent - edge_path = generate_sweep2site_path(LocalAdjacentSweep2sitePathProper(), tci) - - elseif sweepstrategy == :random - edge_path = generate_sweep2site_path(RandomSweep2sitePathProper(), tci) - - else - error("Invalid sweep strategy: $sweepstrategy") - end - - - for iter = iter1:iter1+niter-1 - extraIJset = Dict(key => MultiIndex[] for key in keys(tci.IJset)) - if length(tci.IJset_history) > 0 - extraIJset = tci.IJset_history[end] - end - - push!(tci.IJset_history, deepcopy(tci.IJset)) - - flushpivoterror!(tci) - - for edge in edge_path - updatepivots!( - tci, - edge, - f; - abstol = abstol, - maxbonddim = maxbonddim, - verbosity = verbosity, - extraIJset = extraIJset, - ) - end - end - - nothing -end - - -function flushpivoterror!(tci::SimpleTCI{ValueType}) where {ValueType} - tci.pivoterrors = Float64[] - nothing -end - -""" -Update pivots at bond `b` of `tci` using the TCI2 algorithm. -Site tensors will be invalidated. -""" -function updatepivots!( - tci::SimpleTCI{ValueType}, - edge::NamedEdge, - f::F; - reltol::Float64 = 1e-14, - abstol::Float64 = 0.0, - maxbonddim::Int = typemax(Int), - verbosity::Int = 0, - extraIJset::Dict{SubTreeVertex,Vector{MultiIndex}} = Dict{ - SubTreeVertex, - Vector{MultiIndex}, - }(), -) where {F,ValueType} - - N = length(tci.localdims) - - (IJkey, combinedIJset) = - generate_pivot_candidates(DefaultPivotCandidateProper(), tci, edge, extraIJset) - Ikey, Jkey = first(IJkey), last(IJkey) - - t1 = time_ns() - Pi = reshape( - filltensor(ValueType, f, tci.localdims, combinedIJset, [Ikey], [Jkey], Val(0)), - length(combinedIJset[Ikey]), - length(combinedIJset[Jkey]), - ) - t2 = time_ns() - - updatemaxsample!(tci, Pi) - - luci = TCI.MatrixLUCI(Pi, reltol = reltol, abstol = abstol, maxrank = maxbonddim) - # TODO: we will implement luci according to optimal index subsets by following step - # 1. Compute the optimal index subsets (We also need the indices to set new pivots) - # 2. Reshape the Pi matrix by the optimal index subsets - # 3. Compute the LUCI by the reshaped Pi matrix - - t3 = time_ns() - if verbosity > 2 - x, y = length(combinedIJset[Ikey]), - length(combinedIJset[Jkey]), - println( - " Computing Pi ($x x $y) at bond $b: $(1e-9*(t2-t1)) sec, LU: $(1e-9*(t3-t2)) sec", - ) - end - - tci.IJset[Ikey] = combinedIJset[Ikey][TCI.rowindices(luci)] - tci.IJset[Jkey] = combinedIJset[Jkey][TCI.colindices(luci)] - - updateerrors!(tci, edge, TCI.pivoterrors(luci)) - nothing -end - - -function updatemaxsample!(tci::SimpleTCI{V}, samples::Array{V}) where {V} - tci.maxsamplevalue = TCI.maxabs(tci.maxsamplevalue, samples) -end - -function updateerrors!( - tci::SimpleTCI{T}, - edge::NamedEdge, - errors::AbstractVector{Float64}, -) where {T} - updateedgeerror!(tci, edge, last(errors)) - updatepivoterror!(tci, errors) - nothing -end - -function updateedgeerror!(tci::SimpleTCI{T}, edge::NamedEdge, error::Float64) where {T} - tci.bonderrors[edge] = error - nothing -end - -function updatepivoterror!(tci::SimpleTCI{T}, errors::AbstractVector{Float64}) where {T} - erroriter = Iterators.map(max, TCI.padzero(tci.pivoterrors), TCI.padzero(errors)) - tci.pivoterrors = - Iterators.take(erroriter, max(length(tci.pivoterrors), length(errors))) |> collect - nothing -end - -function pivoterror(tci::SimpleTCI{T}) where {T} - return maxbonderror(tci) -end - -function maxbonderror(tci::SimpleTCI{T}) where {T} - return maximum(values(tci.bonderrors)) -end - -""" -Return if site tensors are available -""" - -function pushunique!(collection, item) - if !(item in collection) - push!(collection, item) - end -end - -function pushunique!(collection, items...) - for item in items - pushunique!(collection, item) - end -end diff --git a/src/simpletci_optimize.jl b/src/simpletci_optimize.jl new file mode 100644 index 0000000..7c4d77a --- /dev/null +++ b/src/simpletci_optimize.jl @@ -0,0 +1,276 @@ +@doc """ + optimize!(tci::SimpleTCI{ValueType}, f; kwargs...) + + Optimize the tensor cross interpolation (TCI) by iteratively updating pivots. + + # Arguments + - `tci`: The SimpleTCI object to optimize + - `f`: The function to interpolate + + # Keywords + - `tolerance::Union{Float64,Nothing} = nothing`: Error tolerance for convergence + - `pivottolerance::Union{Float64,Nothing} = nothing`: Deprecated, use tolerance instead + - `maxbonddim::Int = typemax(Int)`: Maximum bond dimension + - `maxiter::Int = 20`: Maximum number of iterations + - `sweepstrategy::AbstractSweep2sitePathProper = DefaultSweep2sitePathProper()`: Strategy for sweeping + - `pivotsearch::Symbol = :full`: Strategy for pivot search + - `verbosity::Int = 0`: Verbosity level + - `loginterval::Int = 10`: Interval for logging + - `normalizeerror::Bool = true`: Whether to normalize errors + - `ncheckhistory::Int = 3`: Number of history steps to check + - `maxnglobalpivot::Int = 5`: Maximum number of global pivots + - `nsearchglobalpivot::Int = 5`: Number of global pivots to search + - `tolmarginglobalsearch::Float64 = 10.0`: Tolerance margin for global search + - `strictlynested::Bool = false`: Whether to enforce strict nesting + - `checkbatchevaluatable::Bool = false`: Whether to check if function is batch evaluatable + + # Returns + - `ranks`: Vector of ranks at each iteration + - `errors`: Vector of normalized errors at each iteration + """ +function optimize!( + tci::SimpleTCI{ValueType}, + f; + tolerance::Union{Float64,Nothing} = nothing, + pivottolerance::Union{Float64,Nothing} = nothing, + maxbonddim::Int = typemax(Int), + maxiter::Int = 20, + sweepstrategy::AbstractSweep2sitePathProper = DefaultSweep2sitePathProper(), + pivotsearch::Symbol = :full, + verbosity::Int = 0, + loginterval::Int = 10, + normalizeerror::Bool = true, + ncheckhistory::Int = 3, + maxnglobalpivot::Int = 5, + nsearchglobalpivot::Int = 5, + tolmarginglobalsearch::Float64 = 10.0, + strictlynested::Bool = false, + checkbatchevaluatable::Bool = false, +) where {ValueType} + errors = Float64[] + ranks = Int[] + nglobalpivots = Int[] + local tol::Float64 + + if checkbatchevaluatable && !(f isa BatchEvaluator) + error("Function `f` is not batch evaluatable") + end + + if nsearchglobalpivot > 0 && nsearchglobalpivot < maxnglobalpivot + error("nsearchglobalpivot < maxnglobalpivot!") + end + + # Deprecate the pivottolerance option + if !isnothing(pivottolerance) + if !isnothing(tolerance) && (tolerance != pivottolerance) + throw( + ArgumentError( + "Got different values for pivottolerance and tolerance in optimize!(TCI2). For TCI2, both of these options have the same meaning. Please assign only `tolerance`.", + ), + ) + else + @warn "The option `pivottolerance` of `optimize!(tci::TensorCI2, f)` is deprecated. Please update your code to use `tolerance`, as `pivottolerance` will be removed in the future." + tol = pivottolerance + end + elseif !isnothing(tolerance) + tol = tolerance + else # pivottolerance == tolerance == nothing, therefore set tol to default value + tol = 1e-8 + end + + tstart = time_ns() + + if maxbonddim >= typemax(Int) && tol <= 0 + throw( + ArgumentError( + "Specify either tolerance > 0 or some maxbonddim; otherwise, the convergence criterion is not reachable!", + ), + ) + end + + globalpivots = MultiIndex[] + for iter = 1:maxiter + errornormalization = normalizeerror ? tci.maxsamplevalue : 1.0 + abstol = tol * errornormalization + + if verbosity > 1 + println(" Walltime $(1e-9*(time_ns() - tstart)) sec: starting 2site sweep") + flush(stdout) + end + + sweep2site!( + tci, + f, + 2; + iter1 = 1, + abstol = abstol, + maxbonddim = maxbonddim, + pivotsearch = pivotsearch, + verbosity = verbosity, + sweepstrategy = sweepstrategy, + ) + if verbosity > 0 && length(globalpivots) > 0 && mod(iter, loginterval) == 0 + abserr = [abs(evaluate(tci, p) - f(p)) for p in globalpivots] + nrejections = length(abserr .> abstol) + if nrejections > 0 + println( + " Rejected $(nrejections) global pivots added in the previous iteration, errors are $(abserr)", + ) + flush(stdout) + end + end + push!(errors, last(pivoterror(tci))) + + if verbosity > 1 + println( + " Walltime $(1e-9*(time_ns() - tstart)) sec: start searching global pivots", + ) + flush(stdout) + end + end + + errornormalization = normalizeerror ? tci.maxsamplevalue : 1.0 + return ranks, errors ./ errornormalization +end + +@doc """ + Perform 2site sweeps on a SimpleTCI. + !TODO: Implement for Tree structure + + """ +function sweep2site!( + tci::SimpleTCI{ValueType}, + f, + niter::Int; + iter1::Int = 1, + abstol::Float64 = 1e-8, + maxbonddim::Int = typemax(Int), + sweepstrategy::AbstractSweep2sitePathProper = DefaultSweep2sitePathProper(), + pivotsearch::Symbol = :full, + verbosity::Int = 0, +) where {ValueType} + + edge_path = generate_sweep2site_path(sweepstrategy, tci) + + for iter = iter1:iter1+niter-1 + extraIJset = Dict(key => MultiIndex[] for key in keys(tci.IJset)) + if length(tci.IJset_history) > 0 + extraIJset = tci.IJset_history[end] + end + + push!(tci.IJset_history, deepcopy(tci.IJset)) + + flushpivoterror!(tci) + + for edge in edge_path + updatepivots!( + tci, + edge, + f; + abstol = abstol, + maxbonddim = maxbonddim, + verbosity = verbosity, + extraIJset = extraIJset, + ) + end + end + + nothing +end + + +function flushpivoterror!(tci::SimpleTCI{ValueType}) where {ValueType} + tci.pivoterrors = Float64[] + nothing +end + +""" +Update pivots at bond `b` of `tci` using the TCI2 algorithm. +Site tensors will be invalidated. +""" +function updatepivots!( + tci::SimpleTCI{ValueType}, + edge::NamedEdge, + f::F; + reltol::Float64 = 1e-14, + abstol::Float64 = 0.0, + maxbonddim::Int = typemax(Int), + verbosity::Int = 0, + extraIJset::Dict{SubTreeVertex,Vector{MultiIndex}} = Dict{ + SubTreeVertex, + Vector{MultiIndex}, + }(), +) where {F,ValueType} + + N = length(tci.localdims) + + (IJkey, combinedIJset) = + generate_pivot_candidates(DefaultPivotCandidateProper(), tci, edge, extraIJset) + Ikey, Jkey = first(IJkey), last(IJkey) + + t1 = time_ns() + Pi = reshape( + filltensor(ValueType, f, tci.localdims, combinedIJset, [Ikey], [Jkey], Val(0)), + length(combinedIJset[Ikey]), + length(combinedIJset[Jkey]), + ) + t2 = time_ns() + + updatemaxsample!(tci, Pi) + + luci = TCI.MatrixLUCI(Pi, reltol = reltol, abstol = abstol, maxrank = maxbonddim) + # TODO: we will implement luci according to optimal index subsets by following step + # 1. Compute the optimal index subsets (We also need the indices to set new pivots) + # 2. Reshape the Pi matrix by the optimal index subsets + # 3. Compute the LUCI by the reshaped Pi matrix + + t3 = time_ns() + if verbosity > 2 + x, y = length(combinedIJset[Ikey]), + length(combinedIJset[Jkey]), + println( + " Computing Pi ($x x $y) at bond $b: $(1e-9*(t2-t1)) sec, LU: $(1e-9*(t3-t2)) sec", + ) + end + + tci.IJset[Ikey] = combinedIJset[Ikey][TCI.rowindices(luci)] + tci.IJset[Jkey] = combinedIJset[Jkey][TCI.colindices(luci)] + + updateerrors!(tci, edge, TCI.pivoterrors(luci)) + nothing +end + + +function updatemaxsample!(tci::SimpleTCI{V}, samples::Array{V}) where {V} + tci.maxsamplevalue = TCI.maxabs(tci.maxsamplevalue, samples) +end + +function updateerrors!( + tci::SimpleTCI{T}, + edge::NamedEdge, + errors::AbstractVector{Float64}, +) where {T} + updateedgeerror!(tci, edge, last(errors)) + updatepivoterror!(tci, errors) + nothing +end + +function updateedgeerror!(tci::SimpleTCI{T}, edge::NamedEdge, error::Float64) where {T} + tci.bonderrors[edge] = error + nothing +end + +function updatepivoterror!(tci::SimpleTCI{T}, errors::AbstractVector{Float64}) where {T} + erroriter = Iterators.map(max, TCI.padzero(tci.pivoterrors), TCI.padzero(errors)) + tci.pivoterrors = + Iterators.take(erroriter, max(length(tci.pivoterrors), length(errors))) |> collect + nothing +end + +function pivoterror(tci::SimpleTCI{T}) where {T} + return maxbonderror(tci) +end + +function maxbonderror(tci::SimpleTCI{T}) where {T} + return maximum(values(tci.bonderrors)) +end diff --git a/src/simpletci_utils.jl b/src/simpletci_tensors.jl similarity index 100% rename from src/simpletci_utils.jl rename to src/simpletci_tensors.jl diff --git a/src/sweep2sitepathproper.jl b/src/sweep2sitepathproper.jl index a793a9b..7b0e930 100644 --- a/src/sweep2sitepathproper.jl +++ b/src/sweep2sitepathproper.jl @@ -1,22 +1,22 @@ """ Abstract type for pivot candidate generation strategies """ -abstract type Sweep2sitePathProper end +abstract type AbstractSweep2sitePathProper end """ Default strategy """ -struct DefaultSweep2sitePathProper <: Sweep2sitePathProper end +struct DefaultSweep2sitePathProper <: AbstractSweep2sitePathProper end """ Random strategy """ -struct RandomSweep2sitePathProper <: Sweep2sitePathProper end +struct RandomSweep2sitePathProper <: AbstractSweep2sitePathProper end """ LocalAdjacent strategy """ -struct LocalAdjacentSweep2sitePathProper <: Sweep2sitePathProper end +struct LocalAdjacentSweep2sitePathProper <: AbstractSweep2sitePathProper end """ Default strategy that return the sequence path defined by the edges(g) From 91d0da04fdd126b00806ba960c7e71feefbabc36 Mon Sep 17 00:00:00 2001 From: watayo Date: Thu, 13 Mar 2025 16:42:00 +0900 Subject: [PATCH 8/8] pushunique --- src/simpletci.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/simpletci.jl b/src/simpletci.jl index 8efa38d..a4c41d6 100644 --- a/src/simpletci.jl +++ b/src/simpletci.jl @@ -80,3 +80,17 @@ function addglobalpivots!( nothing end + + +function pushunique!(collection, item) + if !(item in collection) + push!(collection, item) + end +end + +function pushunique!(collection, items...) + for item in items + pushunique!(collection, item) + end +end +