diff --git a/.github/workflows/Julia-tests.yml b/.github/workflows/Julia-tests.yml index 42c2aa0..46f00a8 100644 --- a/.github/workflows/Julia-tests.yml +++ b/.github/workflows/Julia-tests.yml @@ -7,13 +7,13 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - julia-version: [1.2.0, 1.3.0-rc3] + julia-version: [1.2, 1.3] os: [ubuntu-latest, macOS-latest] steps: - uses: actions/checkout@v1.0.0 - name: "Setup Julia environment ${{ matrix.julia-version }}" - uses: julia-actions/setup-julia@v0.2 + uses: julia-actions/setup-julia@v1.0.1 with: version: ${{ matrix.julia-version }} - name: "Run Tests" @@ -21,6 +21,5 @@ jobs: julia --color=yes --project=@. -e 'using Pkg; Pkg.activate(); Pkg.instantiate(); - Pkg.add(PackageSpec(url="https://github.com/PhilipVinc/QuantumLattices.jl")); println("Instantiated");' julia --color=yes --project=@. -e "using Pkg; Pkg.test(coverage=true)" diff --git a/.travis.yml b/.travis.yml index d9a4296..d8373c1 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,19 +8,16 @@ julia: - nightly matrix: allow_failures: - - julia: [1.3, nightly] - - os: osx + - julia: [nightly] notifications: email: false script: - if [[ -a .git/shallow ]]; then git fetch --unshallow; fi - travis_wait 20 julia --project=@. --color=yes -e 'using Pkg; Pkg.activate(); - println("Instantiated"); + println("Activate"); Pkg.instantiate(); - Pkg.add(PackageSpec(url="https://github.com/PhilipVinc/QuantumLattices.jl")); - println("Installed [QuantumLattices, Zygote]"); - println("Developed"); + println("Instantiate"); Pkg.test(coverage=true)'; after_success: - julia -e 'using Pkg; cd(Pkg.dir("NeuralQuantum")); Pkg.add("Coverage"); using Coverage; Codecov.submit(Codecov.process_folder())' @@ -32,7 +29,7 @@ jobs: os: linux script: - julia --project=docs/ -e 'using Pkg; - Pkg.add([PackageSpec(url="https://github.com/PhilipVinc/QuantumLattices.jl")]); + Pkg.activate(); Pkg.add("Documenter"); Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' diff --git a/Manifest.toml b/Manifest.toml index 7e8f046..164fd4c 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,9 +2,9 @@ [[AbstractFFTs]] deps = ["LinearAlgebra"] -git-tree-sha1 = "380e36c66edfa099cd90116b24c1ce8cafccac40" +git-tree-sha1 = "051c95d6836228d120f5f4b984dd5aba1624f716" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "0.4.1" +version = "0.5.0" [[Adapt]] deps = ["LinearAlgebra"] @@ -28,16 +28,10 @@ uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee" version = "0.8.10" [[BinaryProvider]] -deps = ["Libdl", "Logging", "SHA"] -git-tree-sha1 = "c7361ce8a2129f20b0e05a89f7070820cfed6648" +deps = ["Libdl", "SHA"] +git-tree-sha1 = "5b08ed6036d9d3f0ee6369410b830f8873d4024c" uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" -version = "0.5.6" - -[[CSTParser]] -deps = ["Tokenize"] -git-tree-sha1 = "c69698c3d4a7255bc1b4bc2afc09f59db910243b" -uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f" -version = "0.6.2" +version = "0.5.8" [[CommonSubexpressions]] deps = ["Test"] @@ -47,9 +41,9 @@ version = "0.2.0" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f" +git-tree-sha1 = "ed2c4abadf84c53d9e58510b5fc48912c2336fbb" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "2.1.0" +version = "2.2.0" [[Conda]] deps = ["JSON", "VersionParsing"] @@ -57,17 +51,11 @@ git-tree-sha1 = 
"9a11d428dcdc425072af4aea19ab1e8c3e01c032" uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d" version = "1.3.0" -[[Crayons]] -deps = ["Test"] -git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523" -uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" -version = "4.0.0" - [[DataStructures]] deps = ["InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "517ce30aa57cdfae1ab444a7c0aef8bb86345bc2" +git-tree-sha1 = "a1b652fb77ae8ca7ea328fa7ba5aa151036e5c10" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.17.1" +version = "0.17.6" [[Dates]] deps = ["Printf"] @@ -84,40 +72,50 @@ uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" version = "0.0.4" [[DiffRules]] -deps = ["Random", "Test"] -git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7" +deps = ["NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "f734b5f6bc9c909027ef99f6d91d5d9e4b111eed" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "0.0.10" +version = "0.1.0" [[Distributed]] deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" +[[DocStringExtensions]] +deps = ["LibGit2", "Markdown", "Pkg", "Test"] +git-tree-sha1 = "88bb0edb352b16608036faadcc071adda068582a" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.8.1" + +[[Documenter]] +deps = ["Base64", "Dates", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] +git-tree-sha1 = "0be9bf63e854a2408c2ecd3c600d68d4d87d8a73" +uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +version = "0.24.2" + [[FFTW]] -deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport", "Test"] -git-tree-sha1 = "6c5b420da0b8c12098048561b8d58f81adea506f" +deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport"] +git-tree-sha1 = "4cfd3d43819228b9e73ab46600d0af0aa5cedceb" uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" -version = "1.0.1" +version = "1.1.0" [[FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays"] -git-tree-sha1 = "16974065d5bfa867446d3228bc63f05a440e910b" +git-tree-sha1 = "1a9fe4e1323f38de0ba4da49eafd15b25ec62298" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.7.2" +version = "0.8.2" [[ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"] -git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b" +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"] +git-tree-sha1 = "da46ac97b17793eba44ff366dc6cb70f1238a738" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.3" +version = "0.10.7" [[IRTools]] deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "5bdac064676287838868e3715d44d1129d8f2d26" -repo-rev = "master" -repo-url = "https://github.com/MikeInnes/IRTools.jl.git" +git-tree-sha1 = "72421971e60917b8cd7737f9577c4f0f87eab306" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.2.3" +version = "0.3.0" [[Inflate]] deps = ["Pkg", "Printf", "Random", "Test"] @@ -161,10 +159,10 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[MacroTools]] -deps = ["CSTParser", "Compat", "DataStructures", "Test", "Tokenize"] -git-tree-sha1 = "d6e9dedb8c92c3465575442da456aec15a89ff76" +deps = ["Compat", "DataStructures", "Test"] +git-tree-sha1 = "82921f0e3bde6aebb8e524efc20f4042373c0c06" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = 
"0.5.1" +version = "0.5.2" [[Markdown]] deps = ["Base64"] @@ -180,10 +178,9 @@ uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" version = "0.6.0" [[NaNMath]] -deps = ["Compat"] -git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2" +git-tree-sha1 = "928b8ca9b2791081dc71a51c55347c27c618760f" uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "0.3.2" +version = "0.3.3" [[OrderedCollections]] deps = ["Random", "Serialization", "Test"] @@ -193,9 +190,9 @@ version = "1.1.0" [[Parsers]] deps = ["Dates", "Test"] -git-tree-sha1 = "ef0af6c8601db18c282d092ccbd2f01f3f0cd70b" +git-tree-sha1 = "0139ba59ce9bc680e2925aec5b7db79065d60556" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "0.3.7" +version = "0.3.10" [[Pkg]] deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] @@ -275,9 +272,9 @@ version = "0.8.0" [[StaticArrays]] deps = ["LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "db23bbf50064c582b6f2b9b043c8e7e98ea8c0c6" +git-tree-sha1 = "5a3bcb6233adabde68ebc97be66e95dcb787424c" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "0.11.0" +version = "0.12.1" [[Statistics]] deps = ["LinearAlgebra", "SparseArrays"] @@ -294,21 +291,15 @@ deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[TimerOutputs]] -deps = ["Crayons", "Printf", "Test", "Unicode"] -git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c" +deps = ["Printf"] +git-tree-sha1 = "311765af81bbb48d7bad01fb016d9c328c6ede03" uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.0" - -[[Tokenize]] -git-tree-sha1 = "dfcdbbfb2d0370716c815cbd6f8a364efb6f42cf" -uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624" -version = "0.5.6" +version = "0.5.3" [[TupleTools]] -deps = ["Random", "Test"] -git-tree-sha1 = "b006524003142128cc6d36189dce337729aa0050" +git-tree-sha1 = "62a7a6cd5a608ff71cecfdb612e67a0897836069" uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" -version = "1.1.0" +version = "1.2.0" [[URIParser]] deps = ["Test", "Unicode"] @@ -323,6 +314,11 @@ uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [[Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" +[[UnsafeArrays]] +git-tree-sha1 = "1de6ef280110c7ad3c5d2f7a31a360b57a1bde21" +uuid = "c4a57d5a-5b31-53a6-b365-19f8c011fbd6" +version = "1.0.0" + [[VersionParsing]] deps = ["Compat"] git-tree-sha1 = "c9d5aa108588b978bd859554660c8a5c4f2f7669" @@ -331,16 +327,12 @@ version = "1.1.3" [[Zygote]] deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "bc6f84f1ac81f1d1548264dc0c1816252e1c62ef" -repo-rev = "master" -repo-url = "https://github.com/FluxML/Zygote.jl.git" +git-tree-sha1 = "e4245b9c5362346e154b62842a89a18e0210b92b" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.3.4" +version = "0.4.1" [[ZygoteRules]] deps = ["MacroTools"] -git-tree-sha1 = "c4c29b30b8ff3be13d4244e78be7df2a42bc54d0" -repo-rev = "master" -repo-url = "https://github.com/FluxML/ZygoteRules.jl.git" +git-tree-sha1 = "b3b4882cc9accf6731a08cc39543fbc6b669dca8" uuid = "700de1a5-db45-46bc-99cf-38207098b444" version = "0.2.0" diff --git a/Project.toml b/Project.toml index 967d14d..9d454e2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,10 @@ name = "NeuralQuantum" uuid = "eb923273-1014-53d4-802c-abcb7262255a" authors = ["Filippo Vicentini "] -version = "0.1.1" +version = "0.1.2" [deps] Adapt = 
"79e6a3ab-5dfb-504d-930d-738a2a938a0e" -IRTools = "7869d1d1-7146-5819-86e3-90919afe41df" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -20,14 +19,16 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" +UnsafeArrays = "c4a57d5a-5b31-53a6-b365-19f8c011fbd6" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] +Zygote = ">= 0.4" julia = ">= 1.2" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" [targets] test = ["Test"] diff --git a/README.md b/README.md index 98a544b..fcf3bd6 100644 --- a/README.md +++ b/README.md @@ -18,14 +18,28 @@ pkg"add https://github.com/PhilipVinc/NeuralQuantum.jl" `QuantumLattices` is a custom package that allows defining new types of operators on a lattice. It's not needed natively but it is usefull to define hamiltonians on a lattice. +Alternatively you may activate the project included in the manifest that comes with NeuralQuantum. + + ## Example +The basic idea of the package is the following: you create an hamiltonian/lindbladian with QuantumOptics or QuantumLattices (the latter allows you to go to arbitrarily big lattices, but isn't yet very documented...). +Then, you create a `SteadyStateProblem`, which performs some transformations on those operators to put them in the shaped necessary to optimize them efficiently. +You also will probably need to create an `ObservablesProblem`, by providing it all the observables that you wish to monitor during the optimization. + +By default, if you don't provide the precision `Float32` is used. + +Then, you will pick a network, a sampler, and create an iterative sampler to sample the network. +You must write the training loop by yourself. Check the documentation and the examples in the folder `examples/` to better understand how to do this. + +*IMPORTANT:* If you want to use multithreaded samplers (identified by a `MT` at the beginning of their name), you will launch one markov chain per julia thread. As such, you will get much better performance if you set `JULIA_NUM_THREADS` environment variable to the number of physical cores in your computer before launching julia. + ``` # Load dependencies using NeuralQuantum, QuantumLattices -using Printf, ValueHistoriesLogger, Logging, ValueHistories +using Printf, Logging, ValueHistories # Select the numerical precision -T = Float64 +T = Float32 # Select how many sites you want Nsites = 6 diff --git a/examples/dissipative_spins_1d.jl b/examples/dissipative_spins_1d.jl index 9203f49..73e5317 100644 --- a/examples/dissipative_spins_1d.jl +++ b/examples/dissipative_spins_1d.jl @@ -3,7 +3,7 @@ using NeuralQuantum, QuantumLattices using Logging, Printf, ValueHistories # Select the numerical precision -T = Float64 +T = Float32 # Select how many sites you want Nsites = 7 @@ -22,7 +22,7 @@ Sx = QuantumLattices.LocalObservable(lind, sigmax, Nsites) Sy = QuantumLattices.LocalObservable(lind, sigmay, Nsites) Sz = QuantumLattices.LocalObservable(lind, sigmaz, Nsites) # Create the problem object with all the observables to be computed. -oprob = ObservablesProblem(Sx, Sy, Sz) +oprob = ObservablesProblem(T, Sx, Sy, Sz) # Define the Neural Network. 
A NDM with N visible spins and αa=2 and αh=1 @@ -33,7 +33,7 @@ cnet = cached(net) # Chose a sampler. Options are FullSumSampler() which sums over the whole space # ExactSampler() which does exact sampling or MCMCSamler which does a markov # chain. -sampl = MCMCSampler(Metropolis(Nsites), 3000, burn=50) +sampl = MCMCSampler(Metropolis(Nsites), 1000, burn=50) # Chose a sampler for the observables. osampl = FullSumSampler() @@ -69,6 +69,8 @@ for i=1:110 Optimisers.update!(optimizer, cnet, Δw) end +using QuantumOptics, Plots + # Optional: compute the exact solution ρ = last(steadystate.master(lind)[2]) ESx = real(expect(SparseOperator(Sx), ρ)) diff --git a/examples/spins_1d.jl b/examples/spins_1d.jl index b93d731..dafe347 100644 --- a/examples/spins_1d.jl +++ b/examples/spins_1d.jl @@ -3,7 +3,7 @@ using NeuralQuantum, QuantumLattices using Logging, Printf, ValueHistories # Select the numerical precision -T = Float64 +T = Float32 # Select how many sites you want sites = [3, 3] Nsites = prod(sites) @@ -14,6 +14,7 @@ lattice = SquareLattice(sites, PBC=true) Ĥ = quantum_ising_ham(lattice, g=1.0, V=2.0) # Create the Problem (cost function) for the given hamiltonian # targeting the ground state. +prob = GroundStateProblem(T, Ĥ) #-- Observables # Define the local observables to look at. @@ -21,7 +22,7 @@ Sx = QuantumLattices.LocalObservable(Ĥ, sigmax, Nsites) Sy = QuantumLattices.LocalObservable(Ĥ, sigmay, Nsites) Sz = QuantumLattices.LocalObservable(Ĥ, sigmaz, Nsites) # Create the problem object with all the observables to be computed. -oprob = ObservablesProblem(Sx, Sy, Sz, Ĥ) +oprob = ObservablesProblem(T, Sx, Sy, Sz, Ĥ) # Define the Neural Network. A RBM with N visible spins and α=2 @@ -39,7 +40,7 @@ osampl = FullSumSampler() # for more information on options type ?SR algo = SR(ϵ=T(0.001), use_iterative=true) # Optimizer: how big the steps of the descent should be -optimizer = Optimisers.Descent(0.005) +optimizer = Optimisers.Descent(0.01) # Create a multithreaded Iterative Sampler. is = MTIterativeSampler(cnet, sampl, prob, algo) @@ -69,8 +70,10 @@ for i=1:500 Optimisers.update!(optimizer, cnet, Δw) end +using QuantumOptics, Plots + # Optional: compute the exact solution -en, st = eigenstates(DenseOperator(ham)) +en, st = eigenstates(DenseOperator(Ĥ)) E_gs = real(minimum(en)) ψgs = first(st) ESx = real(expect(SparseOperator(Sx), ψgs)) diff --git a/src/Algorithms/Gradient/Gradient_eval.jl b/src/Algorithms/Gradient/Gradient_eval.jl index 256d377..6f92fc5 100644 --- a/src/Algorithms/Gradient/Gradient_eval.jl +++ b/src/Algorithms/Gradient/Gradient_eval.jl @@ -20,12 +20,13 @@ end ## Matrix whole space function sample_network!(res::MCMCGradientLEvaluationCache, prob::LRhoSquaredProblem, net, σ, wholespace=false) + T = typeof(res.Zave) CLO_i = res.LLO_i lnψ, ∇lnψ = logψ_and_∇logψ!(res.∇lnψ, net, σ) C_loc = compute_Cloc!(CLO_i, res.∇lnψ2, prob, net, σ, lnψ, res.σ) - prob = wholespace ? exp(2*real(lnψ)) : 1.0 + prob = wholespace ? exp(T(2)*real(lnψ)) : one(res.Zave) E = abs(C_loc)^2 res.Zave += prob diff --git a/src/Algorithms/SR/SR_eval.jl b/src/Algorithms/SR/SR_eval.jl index d01fcf6..2a420cd 100644 --- a/src/Algorithms/SR/SR_eval.jl +++ b/src/Algorithms/SR/SR_eval.jl @@ -5,10 +5,11 @@ function sample_network!( prob::HermitianMatrixProblem, net, σ, wholespace=false) + T = real(out_type(net)) lnψ, ∇lnψ = logψ_and_∇logψ!(res.∇lnψ, net, σ) E = compute_Cloc(prob, net, σ, lnψ, res.σ) - prob = wholespace ? exp(2*real(lnψ)) : 1.0 + prob = wholespace ? 
exp(T(2)*real(lnψ)) : one(T) res.Eave += prob * E res.Zave += prob #1.0 #exp(2*real(lnψ)) push!(res.Evalues, prob*E) @@ -24,12 +25,13 @@ end function sample_network!(res::MCMCSRLEvaluationCache, prob::LRhoSquaredProblem, net, σ, wholespace=false) + T = real(out_type(net)) CLO_i = res.LLO_i lnψ, ∇lnψ = logψ_and_∇logψ!(res.∇lnψ, net, σ) C_loc = compute_Cloc!(CLO_i, res.∇lnψ2, prob, net, σ, lnψ, res.σ) - prob = wholespace ? exp(2*real(lnψ)) : 1.0 + prob = wholespace ? exp(T(2)*real(lnψ)) : one(T) E = abs(C_loc)^2 res.Zave += prob diff --git a/src/IterativeInterface/Batched/Accumulator.jl b/src/IterativeInterface/Batched/Accumulator.jl new file mode 100644 index 0000000..20368f1 --- /dev/null +++ b/src/IterativeInterface/Batched/Accumulator.jl @@ -0,0 +1,12 @@ + +function Accumulator(net, prob, n_tot, batch_sz) + if prob isa HamiltonianGSEnergyProblem + return LocalKetAccumulator(net, state(prob, net), n_tot, batch_sz) + #elseif prob isa HamiltonianGSVarianceProblem + # return LocalGradAccumulator(net, state(net, prob), n_tot, batch_sz) + elseif prob isa LRhoKLocalSOpProblem + return LocalGradAccumulator(net, state(prob, net), n_tot, batch_sz) + else + throw("problem not handled") + end +end diff --git a/src/IterativeInterface/Batched/BatchedSampler.jl b/src/IterativeInterface/Batched/BatchedSampler.jl new file mode 100644 index 0000000..0f73934 --- /dev/null +++ b/src/IterativeInterface/Batched/BatchedSampler.jl @@ -0,0 +1,151 @@ +export BatchedSampler + +mutable struct BatchedSampler{N,BN,P,IC,EC,S,SC,V,Vb,Vi,Pv,Gv,LC} <: AbstractIterativeSampler + net::N + bnet::BN + problem::P + itercache::IC + sampled_values::EC + + sampler::S + sampler_cache::SC + + ψvals::Pv + ∇vals::Gv + ψ_batch::Pv + ∇ψ_batch::Gv + + accum::LC + + v::V + vc::Vb + vi_vec::Vi + batch_sz::Int +end + + +""" + IterativeSampler(net, sampler, algorithm, problem) + +Create a single-thread iterative sampler for the quantities defined +by algorithm. +""" +function BatchedSampler(net, + sampl, + prob, + algo=prob; + batch_sz=2^3) + + !ispow2(batch_sz) && @warn "Batch size is not a power of 2. Bad performance guaranteed." + + cnet = cached(net) + bnet = cached(net, batch_sz) + evaluated_vals = EvaluatedNetwork(algo, net) + itercache = SamplingCache(algo, prob, net) + v = state(prob, net) + sampler_cache = init_sampler!(sampl, net, v) + + # Total length of the markov chain, used to preallocate + N_tot = chain_length(sampl, sampler_cache) + @assert N_tot>0 "Error: chain length not inferred" + + ψvals = similar(trainable_first(net), out_type(net), 1, N_tot) + ∇vals = grad_cache(net, N_tot) + vi_vec = zeros(Int, N_tot) + + ψ_batch = similar(trainable_first(net), out_type(net), 1, batch_sz) + ∇ψ_batch = grad_cache(net, batch_sz) + + accum = Accumulator(net, prob, N_tot, batch_sz) + + vc = preallocate_state_batch(trainable_first(net), + input_type(net), + v, + batch_sz) + + BatchedSampler(cnet, bnet, + prob, + itercache, + evaluated_vals, + sampl, + sampler_cache, + ψvals, + ∇vals, + ψ_batch, + ∇ψ_batch, + accum, + v, + vc, + vi_vec, + batch_sz) +end + +""" + sample!(is::IterativeSampler) + +Samples the quantities accordingly, returning the sampled values and sampling history. 
+""" +function sample!(is::BatchedSampler) + init_sampler!(is.sampler, is.net, is.v, is.sampler_cache) + #vc_vec = zeros(0) + vi_vec = is.vi_vec .= 0 + for i=1:typemax(Int) + vi_vec[i] = index(is.v) + !samplenext!(is.v, is.sampler, is.net, is.sampler_cache) && break + end + + b_sz = is.batch_sz + Nv = length(vi_vec) + vc = is.vc + + # Those won't work on the cpu + ψvals_data = uview(is.ψvals) + ∇vals_data = uview(first(vec_data(is.∇vals))) + ∇ψ_batch_data = uview(first(vec_data(is.∇ψ_batch))) + + for i=1:b_sz:(Nv-b_sz) + for j=1:b_sz + set_index!(is.v, i+j-1) + store_state!(vc, config(is.v), j) + end + @views logψ_and_∇logψ!(is.∇ψ_batch, ψvals_data[:,i:i+b_sz-1], is.bnet, vc) + ∇vals_data[:,i:i+b_sz-1] .= ∇ψ_batch_data + end + + i = last(1:b_sz:(Nv)); l = Nv-i+1 + for j=1:b_sz + set_index!(is.v, i+j-1) + store_state!(vc, config(is.v), j) + end + @views logψ_and_∇logψ!(is.∇ψ_batch, ψvals_data[i:end], is.bnet, vc) + @views ∇vals_data[:,i:i+l-1] .= ∇ψ_batch_data[:,1:l] + + compute_local_term!(is) + + # Now I have gradients and other stuff + + return is.accum +end + +function compute_local_term!(is::BatchedSampler) + vi_vec = is.vi_vec + ψvals_data = collect(uview(is.ψvals)) + ∇vals_data = uview(first(vec_data(is.∇vals))) + ∇ψ_batch_data = uview(first(vec_data(is.∇ψ_batch))) + + ## Ended sampling those things + accum = is.accum + init!(accum) + for (i, σi)=enumerate(vi_vec) + σ = set_index!(is.v, σi) + @views push!(accum, is.ψvals[i], ∇vals_data[:,i]) + accumulate_connections!(accum, is.problem.L, σ) + end + finalize!(accum) +end + +Base.show(io::IO, is::BatchedSampler) = print(io, + "BatchedSampler for :"* + "\n\tnet\t\t: $(is.net)"* + "\n\tproblem\t: $(is.problem)"* + "\n\tsampler\t: $(is.sampler)") diff --git a/src/IterativeInterface/Batched/GradientBatchAccumulator.jl b/src/IterativeInterface/Batched/GradientBatchAccumulator.jl new file mode 100644 index 0000000..f02acdc --- /dev/null +++ b/src/IterativeInterface/Batched/GradientBatchAccumulator.jl @@ -0,0 +1,101 @@ +mutable struct GradientBatchAccumulator{N,A,B,C,C2,D,D2,E,F,B2,F2} + bnet::N + + in_buf::A # the matrix of Nsites x batchsz used as state + out_buf::B + ∇out_buf::F + + res::B2 + ∇res::F2 + + ψ0_buf_g::C # Buffers alls the <σ|ψ> of the denominator + ψ0_buf_c::C2 # Buffers alls the <σ|ψ> of the denominator + ∇0_buf::E + mel_buf_c::D # ⟨σ|Ô|σ'⟩ in the buffer + mel_buf_g::D2 # ⟨σ|Ô|σ'⟩ in the buffer + + buf_n::Int # Counter for elements in buffer + batch_sz::Int +end + +function GradientBatchAccumulator(net::NeuralNetwork, v::State, batch_sz) + bnet = cached(net, batch_sz) + CT = Complex{real(out_type(net))} + + w = trainable_first(net) + RT = real(eltype(w)) + in_buf = preallocate_state_batch(w, RT, v, batch_sz) + out_buf = similar(w, out_type(bnet), 1, batch_sz) + ∇out_buf = grad_cache(net, batch_sz) + + res = similar(w, CT, 1, batch_sz) + ∇res = grad_cache(CT, net, batch_sz) + + ψ0_buf_g = similar(w, out_type(net), 1, batch_sz) + ψ0_buf_c = zeros(out_type(net), 1, batch_sz) + ∇0_buf = grad_cache(net, batch_sz) + mel_buf_g = similar(w, CT, 1, batch_sz) + mel_buf_c = zeros(CT, 1, batch_sz) + + if typeof(mel_buf_g) == typeof(mel_buf_c) + mel_buf_g = nothing + ψ0_buf_g = nothing + end + + return GradientBatchAccumulator( + bnet, + in_buf, out_buf, ∇out_buf, + res, ∇res, + ψ0_buf_g, ψ0_buf_c, + ∇0_buf, + mel_buf_c, mel_buf_g, + 0, batch_sz) +end + +Base.length(a::GradientBatchAccumulator) = a.batch_sz + +isfull(a::GradientBatchAccumulator) = a.buf_n == length(a) + +init!(c::GradientBatchAccumulator) = c.buf_n = 0 + +function 
(c::GradientBatchAccumulator)(mel, v, ψ0, ∇0_buf) + # Increase the step in our internal buffer + # this is guaranteed to always be < max_capacity + c.buf_n = c.buf_n + 1 + + c.ψ0_buf_c[ c.buf_n] = ψ0 + c.mel_buf_c[c.buf_n] = mel + + store_state!(c.in_buf, v, c.buf_n) + c∇0_buf = vec_data(c.∇0_buf)[1] + dd = vec_data(∇0_buf)[1] + @uviews c∇0_buf dd begin + c∇0_buf[:,c.buf_n] .= dd + end + return c +end + +function process_accumulator!(c::GradientBatchAccumulator) + out_buf = c.out_buf + ∇out = c.∇out_buf + init!(c) + + logψ_and_∇logψ!(∇out, out_buf, c.bnet, c.in_buf) + #out_buf .-= c.ψ0_buf #logΔ + #out_buf .= exp.(out_buf) #exp(logΔ) + #out_buf .*= c.mel_buf + ψ0_buf = isnothing(c.ψ0_buf_g) ? c.ψ0_buf_c : copy!(c.ψ0_buf_g, c.ψ0_buf_c) + mel_buf = isnothing(c.mel_buf_g) ? c.ψ0_buf_c : copy!(c.mel_buf_g, c.mel_buf_c) + + c.res .= mel_buf .* exp.(out_buf .- ψ0_buf) + #collect ? if using the gpu... need to think about this + + ∇res = vec_data(c.∇res)[1] + ∇out = vec_data(∇out)[1] + ∇0 = vec_data(c.∇0_buf)[1] + + ∇res .= mel_buf .* (∇out .- ∇0) + + return nothing + #return c.out2_buf, c.∇out2_buf +end diff --git a/src/IterativeInterface/Batched/LocalGradAccumulator.jl b/src/IterativeInterface/Batched/LocalGradAccumulator.jl new file mode 100644 index 0000000..bbc7d6e --- /dev/null +++ b/src/IterativeInterface/Batched/LocalGradAccumulator.jl @@ -0,0 +1,146 @@ +mutable struct LocalGradAccumulator{a,b,B,C,D,S,Ac} <: AbstractAccumulator + cur_ψval::a # The last value seen of <σ|ψ> + cur_∇ψ::b + ψ_counter::B # Stores a counter, referring to how many non zero + # elements referring to σ we have found + cur_ψ::Int # Stores a counter, referring to which i-th element + # of the above array we are currently processing + + Oloc::C # ⟨σ|Ô|ψ⟩ computed + ∇Oloc::D # ⟨σ|Ô|ψ⟩ computed + n_tot::Int # total number of σ' matrix element computed + + acc::Ac # The accumulator + σ::S # just a temporary state to perform operation, cached +end + +function LocalGradAccumulator(net, σ, n_tot, batch_sz) + IT = input_type(net) + OT = out_type(net) + CT = Complex{real(out_type(net))} + f = trainable_first(net) + + cur_ψval = zero(OT) + cur_∇ψ = grad_cache(net) + ψ_counter = zeros(Int, n_tot) + Oloc = collect(similar(f, CT, n_tot)) + ∇Oloc = grad_cache(CT, net, n_tot) + + _σ = deepcopy(σ) + accum = GradientBatchAccumulator(net, σ, batch_sz) + + return LocalGradAccumulator(cur_ψval, cur_∇ψ, ψ_counter, 0, + Oloc, ∇Oloc, 0, + accum, _σ) +end + +batch_size(a::LocalGradAccumulator) = length(a.acc) + +function Base.resize!(c::LocalGradAccumulator, n_tot) + resize!(c.Oloc, n_tot) + + CT = eltype(vec_data(c.∇Oloc)[1]) + c.∇Oloc = grad_cache(CT, c.acc.bnet, n_tot) + + resize!(c.ψ_counter, n_tot) + + init!(c) + return c +end + +function init!(c::LocalGradAccumulator) + c.cur_ψ = 0 + c.n_tot = 0 + + c.Oloc .= 0.0 + + for v=vec_data(c.∇Oloc) + v.= 0.0 + end + + c.ψ_counter .= 0 + + init!(c.acc) + return c +end + +function Base.push!(c::LocalGradAccumulator, ψval::Number, ∇val) + c.cur_ψ += 1 + c.ψ_counter[c.cur_ψ] = 0 + c.cur_ψval = ψval + #c.cur_∇ψ.tuple_all_weights[1] .= ∇val + copyto!(c.cur_∇ψ.tuple_all_weights[1], ∇val) + return c +end + +function (c::LocalGradAccumulator)(mel::Number, cngs_l, cngs_r, v::State) + n_cngs_l = isnothing(cngs_l) ? 0 : length(cngs_l) + n_cngs_r = isnothing(cngs_r) ? 
0 : length(cngs_r) + + mel == 0.0 && return c + + if n_cngs_l == 0 && n_cngs_r == 0 + c.Oloc[c.cur_ψ] += mel + #c.∇Oloc[i] += 0 + c.n_tot += 1 + else + σ = set_index!(c.σ, index(v)) + apply!(σ, cngs_l, cngs_r) + _send_to_accumulator(c, mel, σ) + end + return c +end + +function (c::LocalGradAccumulator)(mel::Number, cngs::StateChanges, v::State) + mel == 0.0 && return c + + if length(cngs) == 0 + c.Oloc[c.cur_ψ] += mel + #c.∇Oloc[i] += 0 + c.n_tot += 1 + else + σ = set_index!(c.σ, index(v)) + apply!(σ, cngs) + _send_to_accumulator(c, mel, σ) + end + return c +end + +function _send_to_accumulator(c::LocalGradAccumulator, mel, σ) + c.ψ_counter[c.cur_ψ] += 1 + + cσ = config(σ) + c.acc(mel, config(σ), c.cur_ψval, c.cur_∇ψ) + isfull(c.acc) && process_buffer!(c) + return c +end + + +finalize!(c::LocalGradAccumulator) = + process_buffer!(c, c.acc.buf_n) + +function process_buffer!(c::LocalGradAccumulator, k=length(c.acc)) + #out, ∇out = process_accumulator!(c.acc) + process_accumulator!(c.acc) + out = collect(c.acc.res) + ∇out = c.acc.∇res + #collect ? if using the gpu... need to think about this + + # Unsafe stuff can't be returned! + ∇Oloc = uview(vec_data(c.∇Oloc)[1]) + ∇out = uview(vec_data(∇out)[1]) + + i = c.cur_ψ + while k>0 + for j=1:c.ψ_counter[i] + c.Oloc[i] += out[k] + view(∇Oloc, :, i) .+= view(∇out, :,k) + k -= 1 + c.ψ_counter[i] = c.ψ_counter[i] - 1 + c.n_tot += 1 + end + i -= 1 + end + + return c +end diff --git a/src/IterativeInterface/Batched/LocalKetAccumulator.jl b/src/IterativeInterface/Batched/LocalKetAccumulator.jl index b660f8f..841dbba 100644 --- a/src/IterativeInterface/Batched/LocalKetAccumulator.jl +++ b/src/IterativeInterface/Batched/LocalKetAccumulator.jl @@ -1,6 +1,5 @@ -mutable struct LocalKetAccumulator{a,A,B,C,D,E,F,N} <: AbstractAccumulator +mutable struct LocalKetAccumulator{a,B,C,S,Ac} <: AbstractAccumulator cur_ψval::a # The last value seen of <σ|ψ> - ψ_vals::A # Stores all the values <σ|ψ> computed for this chain ψ_counter::B # Stores a counter, referring to how many non zero # elements referring to σ we have found cur_ψ::Int # Stores a counter, referring to which i-th element @@ -9,14 +8,8 @@ mutable struct LocalKetAccumulator{a,A,B,C,D,E,F,N} <: AbstractAccumulator Oloc::C # ⟨σ|Ô|ψ⟩ computed n_tot::Int # total number of σ' matrix element computed - mel_buf::D # ⟨σ|Ô|σ'⟩ in the buffer - ψ0_buf::A # Buffers alls the <σ|ψ> of the denominator - v_buf::F # the matrix of Nsites x batchsz used as state - buf_n::Int # Counter for elements in buffer - - σ::E # just a temporary state to perform operation, cached - batch_sz::Int - bnet::N + acc::Ac # The accumulator + σ::S # just a temporary state to perform operation, cached end function LocalKetAccumulator(net, σ, batch_sz) @@ -25,28 +18,27 @@ function LocalKetAccumulator(net, σ, batch_sz) f = trainable_first(net) cur_ψval = zero(OT) - ψ_vals = similar(f, OT, 1, 2) - ψ0_buf = similar(f, OT, 1, batch_sz) - ψ_counter = zeros(Int, batch_sz) + ψ_counter = zeros(Int, 1) Oloc = similar(f, OT, batch_sz) - mel_buf = zeros(OT, batch_sz) _σ = deepcopy(σ) - v = similar(f, IT, length(config(σ)), batch_sz) + accum = ScalarBatchAccumulator(net, σ, batch_sz) - return LocalKetAccumulator(cur_ψval, ψ_vals, ψ_counter, 0, + return LocalKetAccumulator(cur_ψval, ψ_counter, 0, Oloc, 0, - mel_buf, ψ0_buf, v, 0, - _σ, batch_sz, cached(net, batch_sz)) + accum, _σ) end -function init!(c::LocalKetAccumulator, ψ_vals) - c.ψ_vals = ψ_vals +batch_size(a::LocalKetAccumulator) = length(a.acc) + +function init!(c::LocalKetAccumulator, 
chain_len::Int) c.cur_ψ = 0 - c.buf_n = 0 - c.Oloc = similar(ψ_vals, length(ψ_vals)) + resize!(c.Oloc, chain_len) c.Oloc .= 0.0 - c.ψ_counter = zeros(Int, length(ψ_vals)) + resize!(c.ψ_counter, chain_len) + c.ψ_counter .= 0 + + init!(c.acc) return c end @@ -57,47 +49,64 @@ function Base.push!(c::LocalKetAccumulator, ψval::Number) return c end -function (c::LocalKetAccumulator)(mel, cngs, v) - i = c.cur_ψ +# I don't really need two versions of this command, +# but Julia is stupid so I need. +function (c::LocalKetAccumulator)(mel::Number, cngs_l, cngs_r, v::State) + n_cngs_l = isnothing(cngs_l) ? 0 : length(cngs_l) + n_cngs_r = isnothing(cngs_r) ? 0 : length(cngs_r) + + mel == 0.0 && return c # If we have no changes, simply add the element to ⟨σ|Ô|ψ⟩ because # exp(logψ(σ)-logψ(σ)) = exp(0) = 1 - if length(cngs) == 0 - c.Oloc[i] += mel + if n_cngs_l == 0 && n_cngs_r == 0 + c.Oloc[c.cur_ψ] += mel c.n_tot += 1 else - # Increase the step in our internal buffer - # this is guaranteed to always be < max_capacity - c.buf_n = c.buf_n + 1 + σ = set_index!(c.σ, index(v)) + apply!(σ, cngs_l, cngs_r) + _send_to_accumulator(c, mel, σ) + end + return c +end - c.ψ_counter[i] += 1 - c.ψ0_buf[c.buf_n] = c.cur_ψ #c.ψ_vals[i] +function (c::LocalKetAccumulator)(mel::Number, cngs::StateChanges, v::State) + mel == 0.0 && return c + # If we have no changes, simply add the element to ⟨σ|Ô|ψ⟩ because + # exp(logψ(σ)-logψ(σ)) = exp(0) = 1 + if length(cngs) == 0 + c.Oloc[c.cur_ψ] += mel + c.n_tot += 1 + else σ = set_index!(c.σ, index(v)) apply!(σ, cngs) - c.mel_buf[c.buf_n] = mel - c.v_buf[:,c.buf_n] .= config(v) - c.buf_n == c.batch_sz && process_buffer!(c) + _send_to_accumulator(c, mel, σ) end + return c +end + +function _send_to_accumulator(c::LocalKetAccumulator, mel, σ) + c.ψ_counter[c.cur_ψ] += 1 + cσ = config(σ) + c.acc(mel, config(σ), c.cur_ψval) + isfull(c.acc) && process_buffer!(c) return c end -finalize!(c::LocalKetAccumulator) = - process_buffer!(c, c.buf_n) -function process_buffer!(c::LocalKetAccumulator, k=c.batch_sz) - net = c.bnet +finalize!(c::LocalKetAccumulator) = + process_buffer!(c, c.acc.buf_n) - out = net(c.v_buf) - out .-= c.ψ0_buf - out .= exp.(out) +function process_buffer!(c::LocalKetAccumulator, k=length(c.acc)) + out = process_accumulator!(c.acc) #collect ? if using the gpu... need to think about this i = c.cur_ψ while k>0 for j=1:c.ψ_counter[i] - c.Oloc[i] += out[k] * c.mel_buf[k] + c.Oloc[i] += out[k] k -= 1 c.ψ_counter[i] = c.ψ_counter[i] - 1 c.n_tot += 1 @@ -105,6 +114,5 @@ function process_buffer!(c::LocalKetAccumulator, k=c.batch_sz) i -= 1 end - c.buf_n = 0 return c end diff --git a/src/IterativeInterface/Batched/ScalarBatchAccumulator.jl b/src/IterativeInterface/Batched/ScalarBatchAccumulator.jl new file mode 100644 index 0000000..c5dbe77 --- /dev/null +++ b/src/IterativeInterface/Batched/ScalarBatchAccumulator.jl @@ -0,0 +1,108 @@ +""" + ScalarBatchAccumulator(net, state, batch_size) + +A ScalarBatchAccumulator is used to evaluate the contribution to +local observables ⟨σ|Ô|σ'⟩ψ(σ')/ψ(σ) , but by computing the neural +network ψ(σ) in batches of size `batch_size`. This is essential +to extrract a speedup when using GPUs. + +This is an internal implementation detail of NeuralQuantum, and should +not be relied upon. + +Once constructed, a ScalarBatchAccumulator is supposed to be used as follows: +- if `isfull(sba) == true` you should not push new elements to it (an error +will be throw otherwise) +- data is pushed as `sba(⟨σ|Ô|σ'⟩, σ', ψ(σ))`. 
+The configuration should be passed as a vector (if ket) or as a tuple +of two vectors (if density matrix). +""" +mutable struct ScalarBatchAccumulator{N,A,B,Cc,Cg,Dc,Dg} + bnet::N # A batched version of the cached neural network + + in_buf::A # the matrix of Nsites x batchsz used as input + out_buf::B # The row vector of outputs + + ψ0_buf_c::Cc # Buffers alls the <σ|ψ> of the denominator + ψ0_buf_g::Cg # Buffers alls the <σ|ψ> of the denominator + mel_buf_c::Dc # ⟨σ|Ô|σ'⟩ in the buffer + mel_buf_g::Dg # ⟨σ|Ô|σ'⟩ in the buffer + + buf_n::Int # Counter for elements in buffer + batch_sz::Int # batch size +end + +function ScalarBatchAccumulator(net::NeuralNetwork, v::State, batch_sz) + bnet = cached(net, batch_sz) + + w = trainable_first(net) + RT = real(eltype(w)) + in_buf = preallocate_state_batch(w, RT, v, batch_sz) + out_buf = similar(w, out_type(net), 1, batch_sz) + + ψ0_buf_c = zeros(out_type(net), 1, batch_sz) + ψ0_buf_g = similar(w, out_type(net), 1, batch_sz) + mel_buf_c = zeros(out_type(net), 1, batch_sz) + mel_buf_g = similar(w, out_type(net), 1, batch_sz) + + return ScalarBatchAccumulator( + bnet, in_buf, out_buf, + ψ0_buf_c, ψ0_buf_g, mel_buf_c, mel_buf_g, 0, batch_sz) +end + + +Base.length(a::ScalarBatchAccumulator) = a.batch_sz +Base.count(a::ScalarBatchAccumulator) = a.buf_n +isfull(a::ScalarBatchAccumulator) = count(a) == length(a) + +""" + init!(c::ScalarBatchAccumulator) + +Resets the internal counter of the accumulator, deleting +all previously accumulated (but not computed) values. +""" +init!(c::ScalarBatchAccumulator) = c.buf_n = 0 + +function (c::ScalarBatchAccumulator)(mel, v, ψ0) + @assert !isfull(c) "Pushed data to a full accumulator." + + # Increase the step in our internal buffer + # this should be guaranteed to always be < max_capacity + c.buf_n = c.buf_n + 1 + + c.ψ0_buf_c[c.buf_n] = ψ0 + c.mel_buf_c[c.buf_n] = mel + store_state!(c.in_buf, v, c.buf_n) +end + +""" + process_accumulator!(c) + +Processes all states stored in the accumulator, by computing their +relative local contribution. + +It is safe to call this even if the accumulator is not full. In this +case all data beyond the count should be disregarded as it was +not initialized. + +The output will be returned. You should not assume ownership of +the output, as it is preallocated and will be used for further +computations of the accumulator. +""" +function process_accumulator!(c::ScalarBatchAccumulator) + out = c.out_buf + + # Compute the batch of logψ neural networks + logψ!(out, c.bnet, c.in_buf) + + ψ0_buf = isnothing(c.ψ0_buf_g) ? c.ψ0_buf_c : copy!(c.ψ0_buf_g, c.ψ0_buf_c) + mel_buf = isnothing(c.mel_buf_g) ? c.ψ0_buf_c : copy!(c.mel_buf_g, c.mel_buf_c) + + # compute the local contributions + out .= mel_buf .* exp.(out .- ψ0_buf) + #collect ? if using the gpu... 
need to think about this + + # Reset th ecounter of the batch accumulator + init!(c) + + return out +end diff --git a/src/Networks/ClosedSystems/Chain.jl b/src/Networks/ClosedSystems/Chain.jl index ec2e9e4..9a21f91 100644 --- a/src/Networks/ClosedSystems/Chain.jl +++ b/src/Networks/ClosedSystems/Chain.jl @@ -29,7 +29,15 @@ struct ChainCache{T<:Tuple} <: NNCache{Chain} valid::Bool end -cache(l::Chain) = ChainCache(cache.(l.layers), false) +cache(l::Chain, in_T, in_sz) = begin + caches = [] + for layer = l.layers + c = cache(layer, in_T, in_sz) + in_T, in_sz = layer_out_type_size(layer, in_T, in_sz) + push!(caches, c) + end + ChainCache(Tuple(caches), false) +end # Applychain with caches applychain(::Tuple{}, ::Tuple{}, x) = x diff --git a/src/Networks/ClosedSystems/RBMBatched.jl b/src/Networks/ClosedSystems/RBMBatched.jl index 8c182a0..e50df1b 100644 --- a/src/Networks/ClosedSystems/RBMBatched.jl +++ b/src/Networks/ClosedSystems/RBMBatched.jl @@ -27,17 +27,17 @@ cache(net::RBM, batch_sz) = begin similar(net.b, n_v, batch_sz), false) end +batch_size(c::RBMBatchedCache) = size(c.θ, 2) -(net::RBM)(c::RBMBatchedCache, σ::State) = net(c, config(σ)) -function (net::RBM)(c::RBMBatchedCache, σ_r::AbstractArray) +function logψ!(out::AbstractArray, net::RBM, c::RBMBatchedCache, σ_r::AbstractArray) θ = c.θ θ_tmp = c.θ_tmp logℒθ = c.logℒθ - res = c.res + res = out#c.res T = eltype(θ) # copy the states to complex valued states for the computations. - σ = copyto!(c.σ, σ_r) + σ = c.σ; σ.=σ_r #σ = copy!(c.σ, σ_r) #θ .= net.b .+ net.W * σ mul!(θ, net.W, σ) @@ -49,20 +49,23 @@ function (net::RBM)(c::RBMBatchedCache, σ_r::AbstractArray) conj!(res) Base.mapreducedim!(identity, +, res, logℒθ) - return res + # TODO make this better + #copyto!(out, 1, res, 1, length(out)) + + return out end -function logψ_and_∇logψ!(∇logψ, net::RBM, c::RBMBatchedCache, σ_r) +function logψ_and_∇logψ!(∇logψ, out, net::RBM, c::RBMBatchedCache, σ_r) θ = c.θ θ_tmp = c.θ_tmp logℒθ = c.logℒθ ∂logℒθ = c.∂logℒθ - res = c.res + res = out # c.res res_tmp = c.res_tmp T = eltype(θ) # copy the states to complex valued states for the computations. 
- σ = copyto!(c.σ, σ_r) + σ = c.σ; σ.=σ_r #σ = copy!(c.σ, σ_r) #θ .= net.b .+ net.W * σ mul!(θ, net.W, σ) @@ -81,5 +84,8 @@ function logψ_and_∇logψ!(∇logψ, net::RBM, c::RBMBatchedCache, σ_r) _batched_outer_prod!(∇logψ.W, ∂logℒθ, σ) - return res + # TODO make this better + #copyto!(out, 1, res, 1, length(out)) + + return out end diff --git a/src/Networks/ClosedSystems/SimpleLayers.jl b/src/Networks/ClosedSystems/SimpleLayers.jl index 00bbca6..f55b68b 100644 --- a/src/Networks/ClosedSystems/SimpleLayers.jl +++ b/src/Networks/ClosedSystems/SimpleLayers.jl @@ -22,13 +22,22 @@ struct DenseCache{Ta,Tb,Tc,Td} valid::Bool end -cache(l::Dense{Ta,Tb}) where {Ta,Tb} = - DenseCache(similar(l.W, size(l.W,2)), +function cache(l::Dense{Ta,Tb}, in_T ,in_sz) where {Ta,Tb} + c = DenseCache(similar(l.W, size(l.W,2)), similar(l.b), similar(l.W, size(l.W,1)), similar(l.b), similar(l.b), false) + return c +end + +function layer_out_type_size(l::Dense, in_T ,in_sz) + T1 = promote_type(in_T, eltype(l.W)) + out_T = promote_type(T1, eltype(l.b)) + out_sz = size(l.b) + return out_T, out_sz +end function (l::Dense)(c::DenseCache, x) # The preallocated caches @@ -81,12 +90,18 @@ mutable struct WSumCache{Ta,Tb,Tc} valid::Bool end -cache(l::WSum) = +cache(l::WSum, in_T, in_sz) = WSumCache(similar(l.c, Complex{real(eltype(l.c))}), zero(Complex{real(eltype(l.c))}), similar(l.c, Complex{real(eltype(l.c))}, 1, length(l.c)), false) +function layer_out_type_size(l::WSum, in_T ,in_sz) + out_T = Complex{real(eltype(l.c))} + return out_T, (1,) +end + + function (l::WSum)(c::WSumCache, x) σ = copyto!(c.σᵢₙ, x) diff --git a/src/Networks/MixedDensityMatrix/NDM.jl b/src/Networks/MixedDensityMatrix/NDM.jl index a5d7685..9213f37 100644 --- a/src/Networks/MixedDensityMatrix/NDM.jl +++ b/src/Networks/MixedDensityMatrix/NDM.jl @@ -141,9 +141,6 @@ cache(net::NDM) = similar(net.d_λ, complex(eltype(net.d_λ))), similar(net.d_λ, eltype(net.d_λ)), similar(net.d_λ, complex(eltype(net.d_λ))), - #VT(T, length(net.h_μ)), zeros(T, length(net.h_μ)), - #VT(T, length(net.h_μ)), zeros(T, length(net.h_μ)), - #zeros(Complex{T}, length(net.d_λ)), zeros(T, length(net.d_λ)), zeros(Complex{T}, length(net.d_λ)), similar(net.b_μ), -1, @@ -296,13 +293,13 @@ function logψ_and_∇logψ!(∇logψ, W::NDM, c::NDMCache, σr,σc) _Π .= _Π_tmp LinearAlgebra.BLAS.gemv!('N', T(0.5), W.u_μ, Δσ, T(0.0), _Π_tmp) _Π .+= T(1.0)im.* _Π_tmp - #@info "_Π diff " maximum(abs.(_Π - (T(0.5) * transpose(transpose(∑σ)*W.u_λ) + T(0.5)im* transpose(transpose(Δσ)*W.u_μ) .+ W.d_λ))) # --- End common terms with computation of ψ --- # # Compute additional terms for derivatives ∂logℒ_λ_σp = c.∂logℒ_λ_σp; ∂logℒ_λ_σp .= ∂logℒ.(θλ_σp) ∂logℒ_μ_σp = c.∂logℒ_μ_σp; ∂logℒ_μ_σp .= ∂logℒ.(θμ_σp) + ∂logℒ_Π = c.∂logℒ_Π; ∂logℒ_Π .= ∂logℒ.(_Π) # Store the derivatives diff --git a/src/Networks/MixedDensityMatrix/NDMBatched.jl b/src/Networks/MixedDensityMatrix/NDMBatched.jl new file mode 100644 index 0000000..f6e12c0 --- /dev/null +++ b/src/Networks/MixedDensityMatrix/NDMBatched.jl @@ -0,0 +1,273 @@ +# Cached version +mutable struct NDMBatchedCache{T,VT,VCT} <: NNBatchedCache{NDM} + θλ_σ::VT + θμ_σ::VT + θλ_σp::VT + θμ_σp::VT + + θλ_σ_tmp::VT + θμ_σ_tmp::VT + θλ_σp_tmp::VT + θμ_σp_tmp::VT + + σr::VT + σc::VT + ∑σ::VT + Δσ::VT + + ∑logℒ_λ_σ::T + ∑logℒ_μ_σ::T + ∑logℒ_λ_σp::T + ∑logℒ_μ_σp::T + ∂logℒ_λ_σ::VT + ∂logℒ_μ_σ::VT + ∂logℒ_λ_σp::VT + ∂logℒ_μ_σp::VT + + Γ_λ::T + Γ_μ::T + Π::VCT + + _Π::VCT + _Π2::VCT + _Π_tmp::VT + ∂logℒ_Π::VCT + + σ_row_cache::VT + i_σ_row_cache::Int + + valid::Bool # = false +end + 
+cache(net::NDM, batch_sz) = begin + n_h = length(net.h_μ) + n_v = length(net.b_μ) + n_a = length(net.d_λ) + CT = complex(eltype(net.d_λ)) + + NDMBatchedCache(similar(net.h_μ, n_h, batch_sz), + similar(net.h_μ, n_h, batch_sz), + similar(net.h_μ, n_h, batch_sz), + similar(net.h_μ, n_h, batch_sz), + + similar(net.h_μ, n_h, batch_sz), + similar(net.h_μ, n_h, batch_sz), + similar(net.h_μ, n_h, batch_sz), + similar(net.h_μ, n_h, batch_sz), + + similar(net.b_μ, n_v, batch_sz), + similar(net.b_μ, n_v, batch_sz), + similar(net.b_μ, n_v, batch_sz), + similar(net.b_μ, n_v, batch_sz), + + similar(net.b_μ, 1, batch_sz), + similar(net.b_μ, 1, batch_sz), + similar(net.b_μ, 1, batch_sz), + similar(net.b_μ, 1, batch_sz), + similar(net.h_μ, n_h, batch_sz), + similar(net.h_μ, n_h, batch_sz), + similar(net.h_μ, n_h, batch_sz), + similar(net.h_μ, n_h, batch_sz), + + similar(net.b_μ, 1, batch_sz), + similar(net.b_μ, 1, batch_sz), + similar(net.b_μ, CT, 1, batch_sz), + + similar(net.d_λ, CT, n_a, batch_sz), + similar(net.d_λ, CT, n_a, batch_sz), + similar(net.d_λ, n_a, batch_sz), + similar(net.d_λ, CT, n_a, batch_sz), + + similar(net.b_μ, n_v, batch_sz), + -1, + + false) +end + + +function logψ!(out::AbstractArray, W::NDM, c::NDMBatchedCache, σr_r, σc_r) + ∑σ = c.∑σ + Δσ = c.Δσ + θλ_σ = c.θλ_σ + θμ_σ = c.θμ_σ + θλ_σp = c.θλ_σp + θμ_σp = c.θμ_σp + _Π = c._Π + _Π2 = c._Π2 + _Π_tmp = c._Π_tmp + T = eltype(c.θλ_σ) + + σr = c.σr; copy!(σr, σr_r) + σc = c.σc; copy!(σc, σc_r) + + if !c.valid || c.σ_row_cache ≠ σr + c.σ_row_cache .= σr + c.valid = true + + # θλ_σ .= W.h_λ + W.w_λ*σr + mul!(θλ_σ, W.w_λ, σr) + θλ_σ .+= W.h_λ + + # θμ_σ .= W.h_μ + W.w_μ*σr + mul!(θμ_σ, W.w_μ, σr) + θμ_σ .+= W.h_μ + + c.θλ_σ_tmp .= logℒ.(θλ_σ) + c.∑logℒ_λ_σ .= 0.0 + Base.mapreducedim!(identity, +, c.∑logℒ_λ_σ, c.θλ_σ_tmp) + + c.θμ_σ_tmp .= logℒ.(θμ_σ) + c.∑logℒ_μ_σ .= 0.0 + Base.mapreducedim!(identity, +, c.∑logℒ_μ_σ, c.θμ_σ_tmp) + + c.∂logℒ_λ_σ .= ∂logℒ.(θλ_σ) + c.∂logℒ_μ_σ .= ∂logℒ.(θμ_σ) + end + + ∑logℒ_λ_σ = c.∑logℒ_λ_σ + ∑logℒ_μ_σ = c.∑logℒ_μ_σ + ∂logℒ_λ_σ = c.∂logℒ_λ_σ + ∂logℒ_μ_σ = c.∂logℒ_μ_σ + + ∑σ .= σr .+ σc + Δσ .= σr .- σc + + #θλ_σp .= W.h_λ + W.w_λ*σc + #θμ_σp .= W.h_μ + W.w_μ*σc + mul!(θλ_σp, W.w_λ, σc) + θλ_σp .+= W.h_λ + mul!(θμ_σp, W.w_μ, σc) + θμ_σp .+= W.h_μ + + c.θλ_σp_tmp .= logℒ.(θλ_σp) + c.∑logℒ_λ_σp .= 0.0 + Base.mapreducedim!(identity, +, c.∑logℒ_λ_σp, c.θλ_σp_tmp) + + c.θμ_σp_tmp .= logℒ.(θμ_σp) + c.∑logℒ_μ_σp .= 0.0 + Base.mapreducedim!(identity, +, c.∑logℒ_μ_σp, c.θμ_σp_tmp) + + mul!(_Π_tmp, W.u_λ, ∑σ) + _Π .= 0.5 .* _Π_tmp .+ W.d_λ + mul!(_Π_tmp, W.u_μ, Δσ) + _Π .+= T(0.5)im.* _Π_tmp + + Γ_λ = c.Γ_λ .=0.0 + Γ_μ = c.Γ_μ .=0.0 + mul!(Γ_λ, transpose(W.b_λ), ∑σ) + Γ_λ .= T(0.5) .* (Γ_λ .+ c.∑logℒ_λ_σ .+ c.∑logℒ_λ_σp) + mul!(Γ_μ, transpose(W.b_μ), Δσ) + Γ_μ .= T(0.5) .* (Γ_μ .+ c.∑logℒ_μ_σ .- c.∑logℒ_μ_σp) + + _Π .= logℒ.(_Π) + Π = c.Π .=0.0 + Base.mapreducedim!(identity, +, Π, _Π) + + out .= Γ_λ .+ T(1.0)im .* Γ_μ .+ Π + return out +end + +function logψ_and_∇logψ!(∇logψ, out, W::NDM, c::NDMBatchedCache, σr_r, σc_r) + ∑σ = c.∑σ + Δσ = c.Δσ + θλ_σ = c.θλ_σ + θμ_σ = c.θμ_σ + θλ_σp = c.θλ_σp + θμ_σp = c.θμ_σp + _Π = c._Π + _Π2 = c._Π2 + _Π_tmp = c._Π_tmp + T = eltype(c.θλ_σ) + + σr = c.σr; copy!(σr, σr_r) + σc = c.σc; copy!(σc, σc_r) + + if !c.valid || c.σ_row_cache ≠ σr + c.σ_row_cache .= σr + c.valid = true + + # θλ_σ .= W.h_λ + W.w_λ*σr + mul!(θλ_σ, W.w_λ, σr) + θλ_σ .+= W.h_λ + + # θμ_σ .= W.h_μ + W.w_μ*σr + mul!(θμ_σ, W.w_μ, σr) + θμ_σ .+= W.h_μ + + c.θλ_σ_tmp .= logℒ.(θλ_σ) + c.∑logℒ_λ_σ .= 0.0 + 
Base.mapreducedim!(identity, +, c.∑logℒ_λ_σ, c.θλ_σ_tmp) + + c.θμ_σ_tmp .= logℒ.(θμ_σ) + c.∑logℒ_μ_σ .= 0.0 + Base.mapreducedim!(identity, +, c.∑logℒ_μ_σ, c.θμ_σ_tmp) + + c.∂logℒ_λ_σ .= ∂logℒ.(θλ_σ) + c.∂logℒ_μ_σ .= ∂logℒ.(θμ_σ) + end + + ∑logℒ_λ_σ = c.∑logℒ_λ_σ + ∑logℒ_μ_σ = c.∑logℒ_μ_σ + ∂logℒ_λ_σ = c.∂logℒ_λ_σ + ∂logℒ_μ_σ = c.∂logℒ_μ_σ + + ∑σ .= σr .+ σc + Δσ .= σr .- σc + + #θλ_σp .= W.h_λ + W.w_λ*σc + #θμ_σp .= W.h_μ + W.w_μ*σc + mul!(θλ_σp, W.w_λ, σc) + θλ_σp .+= W.h_λ + mul!(θμ_σp, W.w_μ, σc) + θμ_σp .+= W.h_μ + + c.θλ_σp_tmp .= logℒ.(θλ_σp) + c.∑logℒ_λ_σp .= 0.0 + Base.mapreducedim!(identity, +, c.∑logℒ_λ_σp, c.θλ_σp_tmp) + + c.θμ_σp_tmp .= logℒ.(θμ_σp) + c.∑logℒ_μ_σp .= 0.0 + Base.mapreducedim!(identity, +, c.∑logℒ_μ_σp, c.θμ_σp_tmp) + + mul!(_Π_tmp, W.u_λ, ∑σ) + _Π .= 0.5 .* _Π_tmp .+ W.d_λ + mul!(_Π_tmp, W.u_μ, Δσ) + _Π .+= T(0.5)im.* _Π_tmp + + Γ_λ = c.Γ_λ .=0.0 + Γ_μ = c.Γ_μ .=0.0 + mul!(Γ_λ, transpose(W.b_λ), ∑σ) + Γ_λ .= T(0.5) .* (Γ_λ .+ c.∑logℒ_λ_σ .+ c.∑logℒ_λ_σp) + mul!(Γ_μ, transpose(W.b_μ), Δσ) + Γ_μ .= T(0.5) .* (Γ_μ .+ c.∑logℒ_μ_σ .- c.∑logℒ_μ_σp) + + _Π2 .= logℒ.(_Π) + Π = c.Π .=0.0 + Base.mapreducedim!(identity, +, Π, _Π2) + + out .= Γ_λ .+ T(1.0)im .* Γ_μ .+ Π + # --- End common terms with computation of ψ --- # + ∂logℒ_λ_σp = c.∂logℒ_λ_σp; ∂logℒ_λ_σp .= ∂logℒ.(θλ_σp) + ∂logℒ_μ_σp = c.∂logℒ_μ_σp; ∂logℒ_μ_σp .= ∂logℒ.(θμ_σp) + ∂logℒ_Π = c.∂logℒ_Π; ∂logℒ_Π .= ∂logℒ.(_Π) + + # Store the derivatives + ∇logψ.b_λ .= T(0.5) .* ∑σ + ∇logψ.b_μ .= T(0.5)im .* Δσ + + ∇logψ.h_λ .= T(0.5) .* (∂logℒ_λ_σ .+ ∂logℒ_λ_σp) + ∇logψ.h_μ .= T(0.5)im .* (∂logℒ_μ_σ .- ∂logℒ_μ_σp) + + #∇logψ.w_λ .= T(0.5) .* (∂logℒ_λ_σ.*transpose(σr) .+ ∂logℒ_λ_σp.*transpose(σc)) + #∇logψ.w_μ .= T(0.5)im .* (∂logℒ_μ_σ.*transpose(σr) .- ∂logℒ_μ_σp.*transpose(σc)) + _batched_outer_prod_∑!(∇logψ.w_λ, T(0.5), ∂logℒ_λ_σ, σr, ∂logℒ_λ_σp, σc) + _batched_outer_prod_Δ!(∇logψ.w_μ, T(0.5)im, ∂logℒ_μ_σ, σr, ∂logℒ_μ_σp, σc) + + ∇logψ.d_λ .= ∂logℒ_Π +# ∇logψ.u_λ .= T(0.5) .* ∂logℒ_Π .* transpose(∑σ) +# ∇logψ.u_μ .= T(0.5)im .* ∂logℒ_Π .* transpose(Δσ) + _batched_outer_prod!(∇logψ.u_λ, T(0.5), ∂logℒ_Π, ∑σ) + _batched_outer_prod!(∇logψ.u_μ, T(0.5)im, ∂logℒ_Π, Δσ) + + return out +end diff --git a/src/Networks/MixedDensityMatrix/NDMComplex.jl b/src/Networks/MixedDensityMatrix/NDMComplex.jl index fd962ee..035e85b 100644 --- a/src/Networks/MixedDensityMatrix/NDMComplex.jl +++ b/src/Networks/MixedDensityMatrix/NDMComplex.jl @@ -32,9 +32,9 @@ inita=(dims...)->rescaled_normal(T, 0.005, dims...)) Refs: """ -NDMComplex(args...) = NDMComplex(Complex{STD_REAL_PREC}, args...) -NDMComplex(::Real, ::Int) = throw("NDMComplex needs complex type") -NDMComplex(T::Type{<:Complex}, in, αh, αa, +NDMComplex(args...) = NDMComplex(STD_REAL_PREC, args...) +NDMComplex(T::Type{<:Real}, args...) = _NDMComplex(Complex{T}, args...) +_NDMComplex(T::Type{<:Complex}, in, αh, αa, initW=(dims...)->rescaled_normal(T, 0.01, dims...), initb=(dims...)->rescaled_normal(T, 0.005, dims...), inita=(dims...)->rescaled_normal(T, 0.005, dims...)) = diff --git a/src/Networks/MixedDensityMatrix/RBMSplit.jl b/src/Networks/MixedDensityMatrix/RBMSplit.jl index b2eec06..7f7ef11 100644 --- a/src/Networks/MixedDensityMatrix/RBMSplit.jl +++ b/src/Networks/MixedDensityMatrix/RBMSplit.jl @@ -43,8 +43,9 @@ out_type(net::RBMSplit) = eltype(net.Wr) is_analytic(net::RBMSplit) = true -(net::RBMSplit)(σ::State) = net(config(σ)...) -(net::RBMSplit)(σr, σc) = transpose(net.ar)*σr .+ transpose(net.ac)*σc .+ sum_autobatch(logℒ.(net.b .+ +(net::RBMSplit)(σ::State) = net(config(σ)...) 
+(net::RBMSplit)(σ::NTuple{N,<:AbstractArray}) where {N} = net(σ...) +(net::RBMSplit)(σr, σc) = transpose(net.ar)*σr .+ transpose(net.ac)*σc .+ sum_autobatch(logℒ.(net.b .+ net.Wr*σr .+ net.Wc*σc)) @@ -108,24 +109,14 @@ function (net::RBMSplit)(c::RBMSplitCache, σr_r, σc_r) end function logψ_and_∇logψ!(∇logψ, net::RBMSplit, c::RBMSplitCache, σr_r, σc_r) - θ = c.θ - θ_tmp = c.θ_tmp - logℒθ = c.logℒθ - ∂logℒθ = c.∂logℒθ - T = eltype(θ) + # Forward pass + lnψ = net(c, σr_r, σc_r) - # copy the states to complex valued states for the computations. - σr = c.σr; copyto!(σr, σr_r) - σc = c.σc; copyto!(σc, σc_r) - - #θ .= net.b .+ - # net.Wr*σr .+ - # net.Wc*σc - mul!(θ, net.Wr, σr) - mul!(θ_tmp, net.Wc, σc) - θ .+= net.b .+ θ_tmp + σr = c.σr; + σc = c.σc; + θ = c.θ + ∂logℒθ = c.∂logℒθ - logℒθ .= logℒ.(θ) ∂logℒθ .= ∂logℒ.(θ) ∇logψ.ar .= σr @@ -134,6 +125,5 @@ function logψ_and_∇logψ!(∇logψ, net::RBMSplit, c::RBMSplitCache, σr_r, ∇logψ.Wr .= ∂logℒθ .* transpose(σr) ∇logψ.Wc .= ∂logℒθ .* transpose(σc) - lnψ = dot(σr,net.ar) + dot(σc,net.ac) + sum(logℒθ) return lnψ end diff --git a/src/Networks/MixedDensityMatrix/RBMSplitBatched.jl b/src/Networks/MixedDensityMatrix/RBMSplitBatched.jl index e158786..1a68663 100644 --- a/src/Networks/MixedDensityMatrix/RBMSplitBatched.jl +++ b/src/Networks/MixedDensityMatrix/RBMSplitBatched.jl @@ -30,17 +30,17 @@ cache(net::RBMSplit, batch_sz) = begin false) end -function (net::RBMSplit)(c::RBMSplitBatchedCache, σr_r, σc_r) +function logψ!(out::AbstractArray, net::RBMSplit, c::RBMSplitBatchedCache, σr_r, σc_r) θ = c.θ θ_tmp = c.θ_tmp logℒθ = c.logℒθ - res = c.res + res = out res_tmp = c.res_tmp T = eltype(θ) # copy the states to complex valued states for the computations. - σr = c.σr; copyto!(σr, σr_r) - σc = c.σc; copyto!(σc, σc_r) + σr = c.σr; σr .= σr_r #copyto!(σr, σr_r) + σc = c.σc; σc .= σc_r #copyto!(σc, σc_r) mul!(θ, net.Wr, σr) mul!(θ_tmp, net.Wc, σc) @@ -48,26 +48,29 @@ function (net::RBMSplit)(c::RBMSplitBatchedCache, σr_r, σc_r) logℒθ .= NeuralQuantum.logℒ.(θ) #res = σr'*net.ar + σc'*net.ac # + sum(logℒθ, dims=1) - mul!(res_tmp, net.ar', σr) - mul!(res, net.ac', σc) + mul!(res_tmp, transpose(net.ar), σr) + mul!(res, transpose(net.ac), σc) res .+= res_tmp Base.mapreducedim!(identity, +, res, logℒθ) - return res + # TODO make this better + #copyto!(out, 1, res, 1, length(out)) + + return out end -function logψ_and_∇logψ!(∇logψ, net::RBMSplit, c::RBMSplitBatchedCache, σr_r, σc_r) +function logψ_and_∇logψ!(∇logψ, out, net::RBMSplit, c::RBMSplitBatchedCache, σr_r, σc_r) θ = c.θ θ_tmp = c.θ_tmp logℒθ = c.logℒθ ∂logℒθ = c.∂logℒθ - res = c.res + res = out res_tmp = c.res_tmp T = eltype(θ) # copy the states to complex valued states for the computations. 
- σr = c.σr; copyto!(σr, σr_r) - σc = c.σc; copyto!(σc, σc_r) + σr = c.σr; σr .= σr_r #copyto!(σr, σr_r) + σc = c.σc; σc .= σc_r #copyto!(σc, σc_r) mul!(θ, net.Wr, σr) mul!(θ_tmp, net.Wc, σc) @@ -85,11 +88,12 @@ function logψ_and_∇logψ!(∇logψ, net::RBMSplit, c::RBMSplitBatchedCache, ∇logψ.ar .= σr ∇logψ.ac .= σc ∇logψ.b .= ∂logℒθ - #∇logψ.Wr .= ∂logℒθ .* transpose(σr) - #∇logψ.Wc .= ∂logℒθ .* transpose(σc) _batched_outer_prod!(∇logψ.Wr, ∂logℒθ, σr) _batched_outer_prod!(∇logψ.Wc, ∂logℒθ, σc) - return res + # TODO make this better + #copyto!(out, 1, res, 1, length(out)) + + return out end diff --git a/src/Networks/NetworkWrappers.jl b/src/Networks/NetworkWrappers.jl index 56cfd45..59ee4f1 100644 --- a/src/Networks/NetworkWrappers.jl +++ b/src/Networks/NetworkWrappers.jl @@ -1,10 +1,11 @@ export PureStateAnsatz -struct PureStateAnsatz{A,IT,OT,Anal} <: KetNeuralNetwork +struct PureStateAnsatz{A,IT,OT,Anal,S} <: KetNeuralNetwork __ansatz::A + in_size::S end -function PureStateAnsatz(ansatz) +function PureStateAnsatz(ansatz, in_size) # input type in_t = real(eltype(trainable_first(ansatz))) # output_type @@ -12,23 +13,33 @@ function PureStateAnsatz(ansatz) # is_analytic anal = true ansatz = (ansatz,) - return PureStateAnsatz{typeof(ansatz), in_t, out_t, anal}(ansatz) + S=typeof(in_size) + return PureStateAnsatz{typeof(ansatz), in_t, out_t, anal,S}(ansatz, in_size) end @forward PureStateAnsatz.__m_ansatz Base.getindex, Base.length, Base.first, Base.last, - Base.iterate, Base.lastindex, cache + Base.iterate, Base.lastindex ansatz(psa::PureStateAnsatz) = first(getfield(psa, :__ansatz)) -@functor PureStateAnsatz +functor(psa::PureStateAnsatz) = psa.__ansatz, + a -> PureStateAnsatz(a, input_size(psa)) + +cache(psa::PureStateAnsatz) = + cache(psa.__m_ansatz, input_type(psa), input_size(psa)) + +cache(psa::PureStateAnsatz, batch_sz::Int) = + cache(psa.__m_ansatz, input_type(psa), input_size(psa), batch_sz) + (c::PureStateAnsatz)(x::Vararg{N,V}) where {N,V} = ansatz(c)(config.(x)...) (c::PureStateAnsatz)(cache::NNCache, σ) = ansatz(c)(cache, config(σ)) logψ_and_∇logψ!(∇logψ, net::PureStateAnsatz, c::NNCache, σ) = - logψ_and_∇logψ!(∇logψ.__ansatz[1], ansatz(net), c, σ) + logψ_and_∇logψ!(∇logψ[1], ansatz(net), c, σ) input_type(::PureStateAnsatz{A,IT,OT,AN}) where {A,IT,OT,AN} = IT +input_size(psa::PureStateAnsatz) = getfield(psa, :in_size) out_type(::PureStateAnsatz{A,IT,OT,AN}) where {A,IT,OT,AN} = OT is_analytic(::PureStateAnsatz{A,IT,OT,AN}) where {A,IT,OT,AN} = AN diff --git a/src/NeuralQuantum.jl b/src/NeuralQuantum.jl index 4d22d8a..733c544 100644 --- a/src/NeuralQuantum.jl +++ b/src/NeuralQuantum.jl @@ -1,16 +1,16 @@ module NeuralQuantum # Using statements -using Reexport, Requires +using Reexport +using Requires using MacroTools: @forward using QuantumOpticsBase using LightGraphs - using Zygote -using Random: AbstractRNG, MersenneTwister, GLOBAL_RNG -using LinearAlgebra, SparseArrays, Strided using NNlib +using Random: Random, AbstractRNG, MersenneTwister, GLOBAL_RNG, rand! +using LinearAlgebra, SparseArrays, Strided, UnsafeArrays include("IterativeSolvers/minresqlp.jl") using .MinresQlp @@ -24,27 +24,18 @@ using .Optimisers import .Optimisers: update, update! export Optimisers -# Imports -import Base: length, UInt, eltype, copy, deepcopy, iterate -import Random: rand! 
-import QuantumOpticsBase: basis - # Abstract Types abstract type NeuralNetwork end abstract type State end abstract type FiniteBasisState <: State end -abstract type AbstractProblem end -abstract type AbstractSteadyStateProblem <: AbstractProblem end -abstract type HermitianMatrixProblem <: AbstractSteadyStateProblem end -abstract type LRhoSquaredProblem <: AbstractSteadyStateProblem end -abstract type OpenTimeEvolutionProblem <: AbstractSteadyStateProblem end -abstract type OperatorEstimationProblem <: AbstractProblem end - abstract type Sampler end abstract type AbstractAccumulator end +abstract type AbstractProblem end +include("Problems/base_problems.jl") + # Type describing the parallel backend used by a solver. abstract type ParallelType end struct NotParallel <: ParallelType end @@ -59,7 +50,6 @@ include("base_states.jl") include("base_derivatives.jl") include("base_networks.jl") include("base_cached_networks.jl") -include("base_batched_networks.jl") include("treelike.jl") # from flux include("tuple_logic.jl") @@ -72,20 +62,24 @@ include("States/NAryState.jl") include("States/DoubleState.jl") include("States/PurifiedState.jl") include("States/DiagonalStateWrapper.jl") -export local_index include("States/ModifiedState.jl") -export ModifiedState +export ModifiedState, local_index + +include("base_batched_networks.jl") # Linear Operators -import Base: + +import Base: +, * include("Operators/BaseOperators.jl") include("Operators/OpConnection.jl") include("Operators/OpConnectionIndex.jl") include("Operators/KLocalOperator.jl") include("Operators/KLocalOperatorSum.jl") +include("Operators/KLocalOperatorTensor.jl") +include("Operators/KLocalLiouvillian.jl") + include("Operators/GraphConversion.jl") export OpConnection -export KLocalOperator, KLocalOperatorSum, KLocalOperatorRow, operators +export KLocalOperator, KLocalOperatorTensor, KLocalOperatorSum, KLocalOperatorRow, operators export row_valdiff, row_valdiff_index, col_valdiff, sites, conn_type export duplicate @@ -95,6 +89,7 @@ include("Networks/utils.jl") # Mixed Density Matrices include("Networks/MixedDensityMatrix/NDM.jl") +include("Networks/MixedDensityMatrix/NDMBatched.jl") include("Networks/MixedDensityMatrix/NDMComplex.jl") include("Networks/MixedDensityMatrix/NDMSymm.jl") include("Networks/MixedDensityMatrix/RBMSplit.jl") @@ -117,6 +112,7 @@ export LdagLSparseOpProblem, LRhoSparseSuperopProblem, LdagLProblem, LdagLFullPr include("Problems/SteadyStateLindblad/LdagLSparseOpProblem.jl") include("Problems/SteadyStateLindblad/LdagLSparseSuperopProblem.jl") include("Problems/SteadyStateLindblad/LRhoKLocalOpProblem.jl") +include("Problems/SteadyStateLindblad/LRhoKLocalSOpProblem.jl") include("Problems/SteadyStateLindblad/LRhoSparseOpProblem.jl") include("Problems/SteadyStateLindblad/LRhoSparseSuperopProblem.jl") const LdagLFullProblem = LRhoSparseSuperopProblem @@ -127,6 +123,7 @@ include("Problems/SteadyStateLindblad/build_SteadyStateProblem.jl") # Hamiltonian problems include("Problems/Hamiltonian/HamiltonianGSEnergyProblem.jl") +include("Problems/Hamiltonian/HamiltonianGSVarianceProblem.jl") include("Problems/Hamiltonian/build_GroundStateProblem.jl") # Observables problem @@ -184,20 +181,26 @@ include("IterativeInterface/BaseIterativeSampler.jl") include("IterativeInterface/IterativeSampler.jl") include("IterativeInterface/MTIterativeSampler.jl") +include("IterativeInterface/Batched/BatchedSampler.jl") +include("IterativeInterface/Batched/ScalarBatchAccumulator.jl") 
+include("IterativeInterface/Batched/GradientBatchAccumulator.jl") include("IterativeInterface/Batched/LocalKetAccumulator.jl") +include("IterativeInterface/Batched/LocalGradAccumulator.jl") +include("IterativeInterface/Batched/Accumulator.jl") + export sample! function __init__() @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin - using .CuArrays + import .CuArrays: CuArrays, @cufunc - CuArrays.@cufunc ℒ(x) = one(x) + exp(x) + #= @cufunc NeuralQuantum.ℒ(x) = one(x) + exp(x) - CuArrays.@cufunc ∂logℒ(x) = one(x)/(one(x)+exp(-x)) + @cufunc NeuralQuantum.∂logℒ(x) = one(x)/(one(x)+exp(-x)) - CuArrays.@cufunc logℒ(x::Real) = log1p(exp(x)) - CuArrays.@cufunc logℒ(x::Complex) = log(one(x) + exp(x)) + @cufunc NeuralQuantum.logℒ(x::Real) = log1p(exp(x)) + @cufunc NeuralQuantum.logℒ(x::Complex) = log(one(x) + exp(x))=# end @require QuantumOptics="6e0679c1-51ea-5a7c-ac74-d61b76210b0c" begin diff --git a/src/Operators/GraphConversion.jl b/src/Operators/GraphConversion.jl index 82c4dee..37abc3b 100644 --- a/src/Operators/GraphConversion.jl +++ b/src/Operators/GraphConversion.jl @@ -17,6 +17,7 @@ function to_linear_operator(ham::GraphOperator, c_ops::Vector, T::Union{Nothing, # default type T = isnothing(T) ? eltype(first(ham_locs).data) : T + T = T<:Real ? Complex{T} : T op_loc = KLocalOperatorRow(T, [1], [length(basis(first(ham_locs)))], first(ham_locs).data) @@ -100,6 +101,7 @@ function to_linear_operator(op::GraphOperator, T::Union{Nothing, Type{<:Number}} # default type T = isnothing(T) ? eltype(first(op_locs).data) : T + T = T<:Real ? Complex{T} : T op_loc = KLocalOperatorRow(T, [1], [length(basis(first(op_locs)))], first(op_locs).data) @@ -133,3 +135,26 @@ function to_linear_operator(op::GraphOperator, T::Union{Nothing, Type{<:Number}} return res_op end + +""" + to_matrix(operator) +Converts to a dense matrix the KLocal Operator +""" +function to_matrix(op::AbsLinearOperator, σ) + N = spacedimension(σ) + mat = zeros(ComplexF64, N, N) + + for i = 1:N + set_index!(σ, i) + fun = (mel, cngs, σ) -> begin + σp = apply(σ, cngs) + j = index(σp) + mat[i, j] += mel + end + + map_connections(fun, op, σ) + end + return mat +end + +Base.Matrix(op::AbsLinearOperator, σ) = to_matrix(op, σ) diff --git a/src/Operators/KLocalLiouvillian.jl b/src/Operators/KLocalLiouvillian.jl new file mode 100644 index 0000000..13e082d --- /dev/null +++ b/src/Operators/KLocalLiouvillian.jl @@ -0,0 +1,44 @@ +struct KLocalLiouvillian{T,A,B,C} <: AbsLinearOperator + sites::T + HnH_l::A + HnH_r::B + LLdag::C +end + +function KLocalLiouvillian(HnH, Lops) + T = eltype(HnH) + HnH_l = T(-1.0im) * KLocalOperatorTensor(HnH, nothing) + HnH_r = T(1.0im) * KLocalOperatorTensor(nothing, HnH') + + LLdag_list = [KLocalOperatorTensor(L, conj(L)) for L=Lops] + LLdag = isempty(LLdag_list) ? 
[] : sum(LLdag_list) + + return KLocalLiouvillian([], HnH_l, HnH_r, LLdag) +end + +sites(op::KLocalLiouvillian) = op.sites + +conn_type(op::KLocalLiouvillian) = conn_type(op.HnH_l) + +accumulate_connections!(a, b::Vector, c) = nothing + +function accumulate_connections!(acc::AbstractAccumulator, op::KLocalLiouvillian, v::DoubleState) + accumulate_connections!(acc, op.HnH_l, v) + accumulate_connections!(acc, op.HnH_r, v) + accumulate_connections!(acc, op.LLdag, v) + + return acc +end + +function map_connections(fun::Function, op::KLocalLiouvillian, v::DoubleState) + map_connections(fun, op.HnH_l, v) + map_connections(fun, op.HnH_r, v) + map_connections(fun, op.LLdag, v) + return nothing +end + +Base.show(io::IO, m::MIME"text/plain", op::KLocalLiouvillian) = begin + T = eltype(op.HnH_l) + + print(io, "KLocalLiouvillian($T)") +end diff --git a/src/Operators/KLocalOperator.jl b/src/Operators/KLocalOperator.jl index d1c6fa2..4d693cc 100644 --- a/src/Operators/KLocalOperator.jl +++ b/src/Operators/KLocalOperator.jl @@ -103,6 +103,10 @@ function KLocalOperatorRow(T::Type{<:Number}, sites::AbstractVector, hilb_dims:: mel, to_change, new_values, op_conns, new_indices) end +KLocalOperator(op::KLocalOperator, mat::AbstractMatrix) = + KLocalOperatorRow(copy(sites(op)), copy(hilb_dims(op)), mat) + + ## Accessors """ sites(op::KLocalOperator) @@ -112,8 +116,12 @@ acts, in no particular order. """ sites(op::KLocalOperator) = op.sites +hilb_dims(op::KLocalOperator) = op.hilb_dims + operators(op::KLocalOperator) = (op,) +densemat(op::KLocalOperator) = op.mat + conn_type(top::Type{KLocalOperator{SV,M,Vel,Vti,Vtc,Vtv,OC}}) where {SV, M, Vel, Vti, Vtc, Vtv, OC} = OpConnection{Vel, eltype(Vtc), eltype(Vtv)} conn_type(op::KLocalOperator{SV,M,Vel,Vti,Vtc,Vtv,OC}) where {SV, M, Vel, Vti, Vtc, Vtv, OC} = @@ -145,6 +153,16 @@ function row_valdiff_index!(conn::OpConnectionIndex, op::KLocalOperator, v::Stat append!(conn, (mel, ids)) end + +function map_connections(fun::Function, op::KLocalOperator, v::State) + r = local_index(v, sites(op)) + + for (mel, changes) = op.op_conns[r] + fun(mel, changes, v) + end + return nothing +end + function accumulate_connections!(acc::AbstractAccumulator, op::KLocalOperator, v::State) # If it is a doublestate, we are probably computing Operator x densitymatrix, # so we only iterate along the column of v @@ -161,8 +179,9 @@ function accumulate_connections!(acc::AbstractAccumulator, op::KLocalOperator, v return acc end -# sum -function sum_samesite!(op_l::KLocalOperator, op_r::KLocalOperator) +_sum_samesite(op_l::KLocalOperator, op_r::KLocalOperator) = _sum_samesite!(duplicate(op_l), op_r) + +function _sum_samesite!(op_l::KLocalOperator, op_r::KLocalOperator) @assert op_l.sites == op_r.sites @assert length(op_l.mel) == length(op_r.mel) @@ -202,25 +221,130 @@ function sum_samesite!(op_l::KLocalOperator, op_r::KLocalOperator) return op_l end -sum_samesite(op_l::KLocalOperator, op_r::KLocalOperator) = sum_samesite!(duplicate(op_l), op_r) - -Base.transpose(op::KLocalOperator) = KLocalOperatorRow(deepcopy(op.sites), - deepcopy(op.hilb_dims), - transpose(op.mat)|>collect) +Base.transpose(op::KLocalOperator) = + KLocalOperator(op, collect(transpose(op.mat))) function Base.conj!(op::KLocalOperator) conj!(op.mat) - conj!(op.op_conns) + map(conj!, op.op_conns) for el=op.mel - conj!.(el) + conj!(el) end return op end Base.conj(op::KLocalOperator) = conj!(duplicate(op)) +Base.adjoint(op::KLocalOperator) = conj(transpose(op)) + +*(a::Number, b::KLocalOperator) = + _op_alpha_prod(b,a) 
+*(b::KLocalOperator, a::Number) = + _op_alpha_prod(b,a) +_op_alpha_prod(op::KLocalOperator, a::Number) = + KLocalOperator(op, a*op.mat) + +function *(opl::KLocalOperator, opr::KLocalOperator) + if sites(opl) == sites(opr) + return KLocalOperator(opl, opl.mat * opr.mat) + else + disjoint = true + for s=sites(opr) + if s ∈ sites(opl) + disjoint = false + break + end + end + if disjoint + _kop_kop_disjoint_prod(opl,opr) + else + _kop_kop_joint_prod(opl, opr) + end + end +end + +function _kop_kop_disjoint_prod(opl::KLocalOperator, opr::KLocalOperator) + sl = sites(opl); sr = sites(opr) + if length(sl) == 1 && length(sr) == 1 + # it's commutative + if sl[1] > sr[1] + _op =opl + opl = opr + opr = _op + end + sl = first(sites(opl)) + sr = first(sites(opr)) -Base.adjoint(op::KLocalOperator) = conj(tranpose(op)) + hilb_dim_l = first(hilb_dims(opl)) + hilb_dim_r = first(hilb_dims(opr)) + + # inverted also in QuantumOptics... who knows why + mat = kron(opr.mat, opl.mat) + return KLocalOperatorRow([sl, sr], [hilb_dim_l, hilb_dim_r], mat) + else + sites_new = sort(vcat(sites(opl), sites(opr))) + ids_l = [findfirst(i .==sites_new) for i=sites(opl)] + ids_r = [findfirst(i .==sites_new) for i=sites(opr)] + throw("to implement error $(sites(opl)) and $(sites(opr))") + end +end + +function _kop_kop_joint_prod(opl::KLocalOperator, opr::KLocalOperator) + sl = sites(opl); sr = sites(opr) + if length(sl) == 1 || length(sr) == 1 + reversed = false + if length(sl) == 1 + _op = opl + opl = opr + opr = _op + reversed = true + end + # opl has many dims, opr only 1 + sr = first(sites(opr)) + r_index = findfirst(sr .== sl) + hdim_r = first(hilb_dims(opr)) + + matrices = [Matrix(I, d, d) for d=hilb_dims(opl)] + matrices[r_index] = opr.mat + mat_r = kron(matrices...) + prod_mat = reversed ? mat_r*opl.mat : opl.mat*mat_r + + return KLocalOperator(opl, prod_mat) + else + sites_new = sort(vcat(sites(opl), sites(opr))) + ids_l = [findfirst(i .==sites_new) for i=sites(opl)] + ids_r = [findfirst(i .==sites_new) for i=sites(opr)] + throw("to implement error") + end +end + +function permutesystems(a::AbstractMatrix, h_dims::Vector, perm::Vector{Int}) + #@assert length(a.basis_l.bases) == length(a.basis_r.bases) == length(perm) + #@assert isperm(perm) + data = reshape(a, [h_dims; h_dims]...) 
+ data = permutedims(data, [perm; perm .+ length(perm)]) + data = reshape(data, prod(h_dims), prod(h_dims)) + return data +end Base.eltype(::T) where {T<:KLocalOperator} = eltype(T) Base.eltype(T::Type{KLocalOperator{SV,M,Vel,Vti,Vtc,Vtv,OC}}) where {SV,M,Vel,Vti,Vtc,Vtv,OC} = eltype(eltype(Vel)) + +Base.show(io::IO, op::KLocalOperator) = begin + T = eltype(op) + s = sites(op) + dims = hilb_dims(op) + mat = densemat(op) + + print(io, "KLocalOperatorRowa($T, $s, $dims, $mat)") +end + +Base.show(io::IO, m::MIME"text/plain", op::KLocalOperator) = begin + T = eltype(op) + s = sites(op) + dims = hilb_dims(op) + mat = densemat(op) + + print(io, "KLocalOperator($T)\n sites: $s\n Hilb: $dims\n") + Base.print_array(io, mat) +end diff --git a/src/Operators/KLocalOperatorSum.jl b/src/Operators/KLocalOperatorSum.jl index d51d78e..4f00a7c 100644 --- a/src/Operators/KLocalOperatorSum.jl +++ b/src/Operators/KLocalOperatorSum.jl @@ -46,6 +46,13 @@ function row_valdiff_index!(conn::OpConnectionIndex, op::KLocalOperatorSum, v::S conn end +function map_connections(fun::Function, ∑Ô::KLocalOperatorSum, v::State) + for Ô=operators(∑Ô) + map_connections(fun, Ô, v) + end + return nothing +end + function accumulate_connections!(acc::AbstractAccumulator, ∑Ô::KLocalOperatorSum, v::State) for Ô=operators(∑Ô) accumulate_connections!(acc, Ô, v) @@ -60,7 +67,7 @@ function Base.sum!(op_sum::KLocalOperatorSum, op::AbsLinearOperator) push!(op_sum.sites, sites(op)) push!(op_sum.operators, op) else - sum_samesite!(op_sum.operators[id], op) + _sum_samesite!(op_sum.operators[id], op) end return op_sum @@ -73,8 +80,13 @@ function Base.sum!(op_l::KLocalOperatorSum, op_r::KLocalOperatorSum) op_l end -+(op_l::KLocalOperatorSum, op::KLocalOperator) = sum!(duplicate(op_l), op) -+(op::KLocalOperator, ops::KLocalOperatorSum) = ops + op ++(op_l::KLocalOperatorSum, op::AbsLinearOperator) = sum!(duplicate(op_l), op) ++(op::AbsLinearOperator, ops::KLocalOperatorSum) = ops + op ++(op_l::KLocalOperatorSum, op_r::KLocalOperatorSum) = sum!(duplicate(op_l), op_r) ++(op_l::KLocalOperator, op_r::KLocalOperator) = begin + sites(op_l) == sites(op_r) && return _sum_samesite(op_l, op_r) + return KLocalOperatorSum(op_l) + op_r +end function Base.transpose(ops::KLocalOperatorSum) new_sites = similar(ops.sites) @@ -107,9 +119,39 @@ end Base.adjoint(ops::KLocalOperatorSum) = conj!(transpose(ops)) +function *(opl::KLocalOperatorSum, opr::KLocalOperator) + ∑op = duplicate(opl) + for (i,op)=enumerate(operators(∑op)) + op_new = op*opr + ∑op.operators[i] = op_new + ∑op.sites[i] = sites(op_new) + end + return ∑op +end + +function *(opl::KLocalOperator, opr::KLocalOperatorSum) + ∑op = duplicate(opr) + for (i,op)=enumerate(operators(∑op)) + op_new = opl*op + ∑op.operators[i] = op_new + ∑op.sites[i] = sites(op_new) + end + return ∑op +end + Base.show(io::IO, ::MIME"text/plain", op::KLocalOperatorSum) = print(io, "KLocalOperatorSum: \n\t -sites: $(op.sites)") Base.eltype(::T) where {T<:KLocalOperatorSum} = eltype(T) Base.eltype(T::Type{KLocalOperatorSum{Vec,VOp}}) where {Vec,VOp} = eltype(eltype(VOp)) + +*(a::Number, b::KLocalOperatorSum) = + _op_alpha_prod(b,a) +*(b::KLocalOperatorSum, a::Number) = + _op_alpha_prod(b,a) + +function _op_alpha_prod(ops::KLocalOperatorSum, a::Number) + op_all = [a*op for op=operators(ops)] + return sum(op_all) +end diff --git a/src/Operators/KLocalOperatorTensor.jl b/src/Operators/KLocalOperatorTensor.jl new file mode 100644 index 0000000..4888a69 --- /dev/null +++ b/src/Operators/KLocalOperatorTensor.jl @@ -0,0 +1,189 @@ +""" + 
KLocalOperatorTensor
+
+A KLocalOperator representing the tensor product `op_l ⊗ op_r` of two
+KLocalOperator-s, acting respectively on the row and on the column
+configurations of a vectorised density matrix. Either factor may be
+`nothing`, in which case it stands for the identity on that side.
+"""
+struct KLocalOperatorTensor{T,O1,O2} <: AbsLinearOperator
+    sites::T
+    # tuple of (row sites, column sites) acted upon
+    op_l::O1
+    op_r::O2
+end
+
+function KLocalOperatorTensor(op_l, op_r)
+    if isnothing(op_l)
+        st = (Int[], sites(op_r))
+    elseif isnothing(op_r)
+        st = (sites(op_l), Int[])
+    else
+        st = (sites(op_l), sites(op_r))
+    end
+    KLocalOperatorTensor(st, op_l, op_r)
+end
+
+function KLocalOperatorTensor(op_l::KLocalOperatorSum, op_r::Nothing)
+    ops = [KLocalOperatorTensor(op, op_r) for op=operators(op_l)]
+    return sum(ops)
+end
+
+function KLocalOperatorTensor(op_l::Nothing, op_r::KLocalOperatorSum)
+    ops = [KLocalOperatorTensor(op_l, op) for op=operators(op_r)]
+    return sum(ops)
+end
+
+KLocalOperatorSum(op::KLocalOperatorTensor) = KLocalOperatorSum([sites(op)], [op])
+
+function KLocalOperatorTensor(op_l::KLocalOperatorSum, op_r::KLocalOperatorSum)
+    throw("error not impl")
+end
+
+sites(op::KLocalOperatorTensor) = op.sites
+
+
+## Accessors
+operators(op::KLocalOperatorTensor) = (op,)
+conn_type(op::KLocalOperatorTensor{T,O1,O2}) where {T,O1,O2} = conn_type(op.op_l)
+conn_type(op::KLocalOperatorTensor{T,Nothing,O2}) where {T,O2} =
+    conn_type(op.op_r)
+conn_type(::Type{KLocalOperatorTensor{T,O1,O2}}) where {T,O1,O2} =
+    conn_type(O1)
+conn_type(::Type{KLocalOperatorTensor{T,Nothing,O2}}) where{T,O2} =
+    conn_type(O2)
+
+
+duplicate(::Nothing) = nothing
+
+function duplicate(op::KLocalOperatorTensor)
+    KLocalOperatorTensor(duplicate(op.op_l), duplicate(op.op_r))
+end
+
+function row_valdiff!(conn::OpConnection, op::KLocalOperatorTensor, v::DoubleState)
+    op_r = op.op_r
+    op_l = op.op_l
+
+    if op_r === nothing
+        r_r = local_index(row(v), sites(op_l))
+        append!(conn, op.op_conns[r])
+    elseif op_l === nothing
+        r_c = local_index(col(v), sites(op_r))
+        append!(conn, op.op_conns[r])
+    else
+        r_r = local_index(row(v), sites(op_l))
+        r_c = local_index(col(v), sites(op_r))
+        #append!(conn, op.op_conns[r])
+        throw("Not implemented")
+    end
+    return conn
+end
+
+
+function map_connections(fun::Function, op::KLocalOperatorTensor, v::DoubleState)
+    op_r = op.op_r
+    op_l = op.op_l
+
+    if op_r === nothing
+        r = local_index(row(v), sites(op_l))
+
+        for (mel,changes)=op_l.op_conns[r]
+            #fun(mel, 1.0, changes, nothing, v)
+            fun(mel, (changes, nothing), v)
+        end
+    elseif op_l === nothing
+        r = local_index(col(v), sites(op_r))
+
+        for (mel,changes)=op_r.op_conns[r]
+            #fun(1.0, mel, nothing, changes, v)
+            fun(mel, (nothing, changes), v)
+        end
+    else
+        r_r = local_index(row(v), sites(op_l))
+        r_c = local_index(col(v), sites(op_r))
+
+        for (mel_r, changes_r)=op_l.op_conns[r_r]
+            for (mel_c, changes_c)=op_r.op_conns[r_c]
+                #fun(mel_r, mel_c, changes_r, changes_c, v)
+                fun(mel_r*mel_c, (changes_r, changes_c), v)
+            end
+        end
+    end
+    return nothing
+end
+
+function accumulate_connections!(acc::AbstractAccumulator, op::KLocalOperatorTensor, v::DoubleState)
+    op_l = op.op_l
+    op_r = op.op_r
+
+    if op_r === nothing
+        r = local_index(row(v), sites(op_l))
+
+        for (mel,changes)=op_l.op_conns[r]
+            #fun(mel, 1.0, changes, nothing, v)
+            acc(mel, changes, nothing, v)
+        end
+    elseif op_l === nothing
+        r = local_index(col(v), sites(op_r))
+
+        for (mel,changes)=op_r.op_conns[r]
+            #fun(1.0, mel, nothing, changes, v)
+            acc(mel, nothing, changes, v)
+        end
+    else
+        r_r = local_index(row(v), sites(op_l))
+        r_c = local_index(col(v), sites(op_r))
+
+        for (mel_r,
changes_r)=op_l.op_conns[r_r] + for (mel_c, changes_c)=op_r.op_conns[r_c] + acc(mel_r*mel_c, changes_r, changes_c, v) + end + end + end + return acc +end + + +_sum_samesite(op_l::KLocalOperatorTensor, op_r::KLocalOperatorTensor) = _sum_samesite!(duplicate(op_l), op_r) + +_sum_samesite!(::Nothing, ::Nothing) = nothing +function _sum_samesite!(op_l::KLocalOperatorTensor, op_r::KLocalOperatorTensor) + @assert op_l.sites == op_r.sites + + _sum_samesite!(op_l.op_l, op_r.op_l) + _sum_samesite!(op_l.op_r, op_r.op_r) + return op_l +end + +function Base.conj!(op::KLocalOperatorTensor) + !isnothing(op.op_l) && conj!(op.op_l) + !isnothing(op.op_r) && conj!(op.op_r) + return op +end + +Base.transpose(op::KLocalOperatorTensor) = + KLocalOperatorTensor(op.op_r, op.op_l) + +Base.conj(op::KLocalOperatorTensor) = conj!(duplicate(op)) + +Base.adjoint(op::KLocalOperatorTensor) = conj(transpose(op)) + + +Base.:+(op_l::KLocalOperatorTensor, op_r::KLocalOperatorTensor) = begin + sites(op_l) == sites(op_r) && return _sum_samesite(op_l, op_r) + return KLocalOperatorSum(op_l) + op_r +end + +Base.:*(a::Number, b::KLocalOperatorTensor) = + _op_alpha_prod(b,a) +Base.:*(b::KLocalOperatorTensor, a::Number) = + _op_alpha_prod(b,a) + +function _op_alpha_prod(op::KLocalOperatorTensor, a::Number) + if isnothing(op.op_l) + return KLocalOperatorTensor(nothing, a*op.op_r) + elseif isnothing(op.op_r) + return KLocalOperatorTensor(a*op.op_l, nothing) + else + a = sqrt(a) + return KLocalOperatorTensor(a*op.op_l, a*op.op_r) + end +end diff --git a/src/Problems/Hamiltonian/HamiltonianGSEnergyProblem.jl b/src/Problems/Hamiltonian/HamiltonianGSEnergyProblem.jl index 50c2def..31e5062 100644 --- a/src/Problems/Hamiltonian/HamiltonianGSEnergyProblem.jl +++ b/src/Problems/Hamiltonian/HamiltonianGSEnergyProblem.jl @@ -6,17 +6,17 @@ struct HamiltonianGSEnergyProblem{B, SM} <: HermitianMatrixProblem where {B<:Bas end HamiltonianGSEnergyProblem(args...) = HamiltonianGSEnergyProblem(STD_REAL_PREC, args...) 
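# Illustrative sketch, not part of the patch: the one-argument constructor just
# above forwards to STD_REAL_PREC, so for an assumed GraphOperator `ham` (e.g. a
# quantum Ising Hamiltonian built as in the test suite) these two calls build
# the same problem:
prob = HamiltonianGSEnergyProblem(ham)
prob = HamiltonianGSEnergyProblem(STD_REAL_PREC, ham; operators=true)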
-HamiltonianGSEnergyProblem(T::Type{<:Number}, gl::GraphOperator; operators=true) = begin +HamiltonianGSEnergyProblem(T::Type{<:Real}, gl::GraphOperator; operators=true) = begin if operators - return HamiltonianGSEnergyProblem(basis(gl), to_linear_operator(gl), 0.0) + return HamiltonianGSEnergyProblem(basis(gl), to_linear_operator(gl, T), 0.0) else return HamiltonianGSEnergyProblem(T, SparseOperator(gl)) end end HamiltonianGSEnergyProblem(T::Type{<:Number}, Ham::SparseOperator) = - HamiltonianGSEnergyProblem(Ham.basis_l, data(Ham), 0.0) + HamiltonianGSEnergyProblem(Ham.basis_l, Complex{T}.(data(Ham)), 0.0) -basis(prob::HamiltonianGSEnergyProblem) = prob.HilbSpace +QuantumOpticsBase.basis(prob::HamiltonianGSEnergyProblem) = prob.HilbSpace function compute_Cloc(prob::HamiltonianGSEnergyProblem{B,SM}, net::KetNet, σ::State, lnψ=net(σ), σp=deepcopy(σ)) where {B,SM<:SparseMatrixCSC} @@ -53,9 +53,7 @@ function compute_Cloc(prob::HamiltonianGSEnergyProblem{B,SM}, net::KetNet, σ::S r = local_index(σ, sites(op)) for (mel, changes)=op.op_conns[r] set_index!(σp, index(σ)) - for (site,val)=changes - setat!(σp, site, val) - end + apply!(σp, changes) log_ratio = logψ(net, σp) - lnψ C_loc += mel * exp(log_ratio) diff --git a/src/Problems/Hamiltonian/HamiltonianGSVarianceProblem.jl b/src/Problems/Hamiltonian/HamiltonianGSVarianceProblem.jl new file mode 100644 index 0000000..e092c6f --- /dev/null +++ b/src/Problems/Hamiltonian/HamiltonianGSVarianceProblem.jl @@ -0,0 +1,87 @@ +struct HamiltonianGSVarianceProblem{B, SM} <: LRhoSquaredProblem where {B<:Basis, + SM} + HilbSpace::B # 0 + H::SM + ρss +end + +HamiltonianGSVarianceProblem(args...) = HamiltonianGSVarianceProblem(STD_REAL_PREC, args...) +HamiltonianGSVarianceProblem(T::Type{<:Number}, gl::GraphOperator; operators=true) = begin + if operators + return HamiltonianGSVarianceProblem(basis(gl), to_linear_operator(gl), 0.0) + else + return HamiltonianGSVarianceProblem(T, SparseOperator(gl)) + end +end +HamiltonianGSVarianceProblem(T::Type{<:Number}, Ham::SparseOperator) = + HamiltonianGSVarianceProblem(Ham.basis_l, data(Ham), 0.0) + +QuantumOpticsBase.basis(prob::HamiltonianGSVarianceProblem) = prob.HilbSpace + +function compute_Cloc!(LLO_i, ∇lnψ, prob::HamiltonianGSVarianceProblem{B,SM}, net::KetNet, σ::State, + lnψ=net(σ), σp=deepcopy(σ)) where {B,SM<:SparseMatrixCSC} + H = prob.H + + for el=LLO_i + el .= 0.0 + end + + #### Now compute E(S) = Σₛ⟨s|Hψ⟩/⟨s|ψ⟩ + C_loc = zero(Complex{real(out_type(net))}) + # Iterate through all elements in row i_σ of the matrix computing + # ⟨i_σ|H|ψ⟩ = Σ_{i_σp} ⟨i_σ|H|i_σp⟩⟨i_σp|ψ⟩ + # NOTE: H is CSC, but I would like a CSR matrix. Since it is hermitian I + # can simply take the conjugate of the elements in the columns + i_σ = index(σ) + for row_id = H.colptr[i_σ]:(H.colptr[i_σ+1]-1) + # Find nonzero elements s by doing + i_σp = H.rowval[row_id] + # BackConvert to int + set_index!(σp, i_σp) + # Compute the log(ψ(σ)/ψ(σ')), by only computing differences. 
+ lnψ_i, ∇lnψ_i = logψ_and_∇logψ!(∇lnψ, net, σp) + C_loc_i = H.nzval[row_id] * exp(lnψ_i - lnψ) #TODO check + + C_loc += C_loc_i + for (LLOave, _∇lnψ)= zip(LLO_i, ∇lnψ_i.tuple_all_weights) + LLOave .+= C_loc_i .* (_∇lnψ) + end + end + + return C_loc +end + +function compute_Cloc!(∇𝒞σ, ∇lnψ, prob::HamiltonianGSVarianceProblem{B,SM}, net::KetNet, σ::State, + lnψσ=net(σ), σp=deepcopy(σ)) where {B,SM<:AbsLinearOperator} + H = prob.H + + for ∇𝒞ᵢ=∇𝒞σ + ∇𝒞ᵢ .= 0.0 + end + + #### Now compute E(S) = Σₛ⟨s|H|ψ⟩/⟨s|ψ⟩ + Cσ = zero(Complex{real(out_type(net))}) + for op=operators(H) + r = local_index(σ, sites(op)) + for (mel, changes)=op.op_conns[r] + set_index!(σp, index(σ)) + apply!(σp, changes) + + lnψ_σ̃ , ∇lnψ_σ̃ = logψ_and_∇logψ!(∇lnψ, net, σp) + 𝒞σ̃ = mel * exp(lnψ_σ̃ - lnψσ) + for (∇𝒞σᵢ, ∇lnψ_σ̃ᵢ)= zip(∇𝒞σ, ∇lnψ_σ̃.tuple_all_weights) + ∇𝒞σᵢ .+= 𝒞σ̃ .* ∇lnψ_σ̃ᵢ + end + Cσ += 𝒞σ̃ + end + end + + return Cσ +end + +# pretty printing +Base.show(io::IO, p::HamiltonianGSVarianceProblem) = print(io, + """ + HamiltonianGSVarianceProblem: target minimum ground state energy + - space : $(basis(p)) + - using operators : $(p.H isa AbsLinearOperator)""") diff --git a/src/Problems/ObservablesProblem.jl b/src/Problems/ObservablesProblem.jl index b54bf4b..e257ccd 100644 --- a/src/Problems/ObservablesProblem.jl +++ b/src/Problems/ObservablesProblem.jl @@ -41,7 +41,7 @@ function ObservablesProblem(T::Type{<:Number}, obs::Any...; operator=true) op = el end push!(names, Symbol(name)) - push!(matrices_trans, to_linear_operator(op)) + push!(matrices_trans, to_linear_operator(op, T)) b = basis(op) end @@ -81,7 +81,7 @@ function ObservablesProblem(T::Type{<:Number}, obs::Any...; operator=true) end -basis(prob::ObservablesProblem) = prob.HilbSpace +QuantumOpticsBase.basis(prob::ObservablesProblem) = prob.HilbSpace state(T::Type{<:Number}, prob::ObservablesProblem, net::MatrixNet) = DiagonalStateWrapper(state(T, basis(prob), net)) diff --git a/src/Problems/SteadyStateLindblad/LRhoKLocalOpProblem.jl b/src/Problems/SteadyStateLindblad/LRhoKLocalOpProblem.jl index 8181a25..1314a26 100644 --- a/src/Problems/SteadyStateLindblad/LRhoKLocalOpProblem.jl +++ b/src/Problems/SteadyStateLindblad/LRhoKLocalOpProblem.jl @@ -20,7 +20,7 @@ function LRhoKLocalOpProblem(T, gl::GraphLindbladian) return LRhoKLocalOpProblem(basis(gl), HnH, c_ops, 0.0) end -basis(prob::LRhoKLocalOpProblem) = prob.HilbSpace +QuantumOpticsBase.basis(prob::LRhoKLocalOpProblem) = prob.HilbSpace # Standard method dispatched when the state is generic (non lut). # will work only if 𝝝 and 𝝝p are the same type (and non lut!) 
@@ -28,6 +28,8 @@ function compute_Cloc!(LLO_i, ∇lnψ, prob::LRhoKLocalOpProblem, net::MatrixNet, 𝝝::S, lnψ=net(𝝝), 𝝝p::S=deepcopy(𝝝)) where {S} # hey + T = real(out_type(net)) + CT = Complex{T} HnH = prob.HnH L_ops = prob.L_ops @@ -39,7 +41,7 @@ function compute_Cloc!(LLO_i, ∇lnψ, prob::LRhoKLocalOpProblem, el .= 0.0 end - C_loc = zero(Complex{real(out_type(net))}) + C_loc = zero(CT) # ⟨σ|Hρ|σt⟩ (using hermitianity of HdH) # diffs_hnh = row_valdiff(HnH, row(𝝝)) @@ -51,7 +53,7 @@ function compute_Cloc!(LLO_i, ∇lnψ, prob::LRhoKLocalOpProblem, apply!(𝝝p_row, changes) lnψ_i, ∇lnψ_i = logψ_and_∇logψ!(∇lnψ, net, 𝝝p) - C_loc_i = -1.0im * mel * exp(lnψ_i - lnψ) + C_loc_i = -T(1.0)im * mel * exp(lnψ_i - lnψ) for (LLOave, _∇lnψ)= zip(LLO_i, ∇lnψ_i.tuple_all_weights) LLOave .+= C_loc_i .* _∇lnψ end @@ -69,7 +71,7 @@ function compute_Cloc!(LLO_i, ∇lnψ, prob::LRhoKLocalOpProblem, apply!(𝝝p_col, changes) lnψ_i, ∇lnψ_i = logψ_and_∇logψ!(∇lnψ, net, 𝝝p) - C_loc_i = 1.0im * conj(mel) * exp(lnψ_i - lnψ) + C_loc_i = T(1.0)im * conj(mel) * exp(lnψ_i - lnψ) for (LLOave, _∇lnψ)= zip(LLO_i, ∇lnψ_i.tuple_all_weights) LLOave .+= C_loc_i .* _∇lnψ end diff --git a/src/Problems/SteadyStateLindblad/LRhoKLocalSOpProblem.jl b/src/Problems/SteadyStateLindblad/LRhoKLocalSOpProblem.jl new file mode 100644 index 0000000..ef9e7e3 --- /dev/null +++ b/src/Problems/SteadyStateLindblad/LRhoKLocalSOpProblem.jl @@ -0,0 +1,25 @@ +""" + LRhoKLocalOpProblem <: AbstractProblem + +Problem or finding the steady state of a ℒdagℒ matrix by computing +𝒞 = ∑|ρ(σ)|²|⟨⟨σ|ℒ |ρ⟩⟩|² using the sparse Liouvillian matrix. + +DO NOT USE WITH COMPLEX-WEIGHT NETWORKS, AS IT DOES NOT WORK +""" +struct LRhoKLocalSOpProblem{B, LL} <: LRhoSquaredProblem where {B<:Basis} + HilbSpace::B # 0 + L::LL +end + +LRhoKLocalSOpProblem(gl::GraphLindbladian) = LRhoKLocalSOpProblem(STD_REAL_PREC, gl) +function LRhoKLocalSOpProblem(T, gl::GraphLindbladian) + HnH, c_ops, c_ops_t = to_linear_operator(gl, Complex{real(T)}) + Liouv = KLocalLiouvillian(HnH, c_ops) + return LRhoKLocalSOpProblem(basis(gl), Liouv) +end + +QuantumOpticsBase.basis(prob::LRhoKLocalSOpProblem) = prob.HilbSpace + +# pretty printing +Base.show(io::IO, p::LRhoKLocalSOpProblem) = print(io, + "LRhoKLocalSOpProblem on space $(basis(p)) computing the variance of Lrho using the sparse liouvillian") diff --git a/src/Problems/SteadyStateLindblad/LRhoSparseOpProblem.jl b/src/Problems/SteadyStateLindblad/LRhoSparseOpProblem.jl index 96aaaf8..540118d 100644 --- a/src/Problems/SteadyStateLindblad/LRhoSparseOpProblem.jl +++ b/src/Problems/SteadyStateLindblad/LRhoSparseOpProblem.jl @@ -52,7 +52,7 @@ function LRhoSparseOpProblem(T::Type{<:Number}, Hilb::Basis, Ham::DataOperator, c_ops[i] = c_ops_q[i].data c_ops_h[i] = c_ops[i]' c_ops_trans[i] = transpose(c_ops[i]) - H_eff -= 0.5im * (c_ops[i]'*c_ops[i]) + H_eff -= T(0.5im) * (c_ops[i]'*c_ops[i]) end LRhoSparseOpProblem{typeof(Hilb), ST}(Hilb, # 0 @@ -64,7 +64,7 @@ function LRhoSparseOpProblem(T::Type{<:Number}, Hilb::Basis, Ham::DataOperator, 0.0) end -basis(prob::LRhoSparseOpProblem) = prob.HilbSpace +QuantumOpticsBase.basis(prob::LRhoSparseOpProblem) = prob.HilbSpace function compute_Cloc!(LLO_i, ∇lnψ, prob::LRhoSparseOpProblem, net::MatrixNet, 𝝝, lnψ=net(𝝝), 𝝝p=deepcopy(𝝝)) @@ -74,6 +74,9 @@ function compute_Cloc!(LLO_i, ∇lnψ, prob::LRhoSparseOpProblem, net::MatrixNet c_ops_h = prob.L_ops_h c_ops_trans = prob.L_ops_t + T = real(out_type(net)) + CT = Complex{T} + σ = row(𝝝) σt = col(𝝝) set_index!(𝝝p, index(𝝝)) @@ -95,7 +98,7 @@ function compute_Cloc!(LLO_i, ∇lnψ, 
prob::LRhoSparseOpProblem, net::MatrixNet i_σ_p = HnH_t.rowval[row_id] set_index!(𝝝p_row, i_σ_p) lnψ_i, ∇lnψ_i = logψ_and_∇logψ!(∇lnψ, net, 𝝝p) - C_loc_i = -1.0im * HnH_t.nzval[row_id] * exp(lnψ_i - lnψ) + C_loc_i = -T(1.0)im * HnH_t.nzval[row_id] * exp(lnψ_i - lnψ) for (LLOave, _∇lnψ)= zip(LLO_i, ∇lnψ_i.tuple_all_weights) LLOave .+= C_loc_i .* _∇lnψ @@ -109,7 +112,7 @@ function compute_Cloc!(LLO_i, ∇lnψ, prob::LRhoSparseOpProblem, net::MatrixNet i_σ_p = HnH.rowval[row_id] set_index!(𝝝p_col, i_σ_p) lnψ_i, ∇lnψ_i = logψ_and_∇logψ!(∇lnψ, net, 𝝝p) - C_loc_i = 1.0im * conj(HnH_t.nzval[row_id]) * exp(lnψ_i - lnψ) + C_loc_i = T(1.0)im * conj(HnH_t.nzval[row_id]) * exp(lnψ_i - lnψ) for (LLOave, _∇lnψ)= zip(LLO_i, ∇lnψ_i.tuple_all_weights) LLOave .+= C_loc_i .* _∇lnψ diff --git a/src/Problems/SteadyStateLindblad/LRhoSparseSuperopProblem.jl b/src/Problems/SteadyStateLindblad/LRhoSparseSuperopProblem.jl index 870cd4e..449df96 100644 --- a/src/Problems/SteadyStateLindblad/LRhoSparseSuperopProblem.jl +++ b/src/Problems/SteadyStateLindblad/LRhoSparseSuperopProblem.jl @@ -33,7 +33,7 @@ LRhoSparseSuperopProblem(T::Type{<:Number}, Liouv::SparseSuperOperator) = LRhoSparseSuperopProblem(T::Type{<:Number}, Ham::DataOperator, cops::Vector) = LRhoSparseSuperopProblem(T, liouvillian(Ham, cops)) -basis(prob::LRhoSparseSuperopProblem) = prob.HilbSpace +QuantumOpticsBase.basis(prob::LRhoSparseSuperopProblem) = prob.HilbSpace function compute_Cloc!(LLO_i, ∇lnψ, prob::LRhoSparseSuperopProblem, net::MatrixNet, σ, lnψ=net(σ), σp=deepcopy(σ)) diff --git a/src/Problems/SteadyStateLindblad/LdagLSparseOpProblem.jl b/src/Problems/SteadyStateLindblad/LdagLSparseOpProblem.jl index d22829b..18bc52c 100644 --- a/src/Problems/SteadyStateLindblad/LdagLSparseOpProblem.jl +++ b/src/Problems/SteadyStateLindblad/LdagLSparseOpProblem.jl @@ -18,7 +18,7 @@ struct LdagLSparseOpProblem{B, SM} <: HermitianMatrixProblem where {B<:Basis, ρss end -basis(prob::LdagLSparseOpProblem) = prob.HilbSpace +QuantumOpticsBase.basis(prob::LdagLSparseOpProblem) = prob.HilbSpace """ @@ -53,7 +53,7 @@ function LdagLSparseOpProblem(T::Type{<:Number}, Hilb::Basis, Ham::DataOperator, for i=1:length(c_ops) c_ops[i] = c_ops_q[i].data c_ops_trans[i] = transpose(c_ops[i]) - H_eff -= 0.5im * (c_ops[i]'*c_ops[i]) + H_eff -= T(0.5im) * (c_ops[i]'*c_ops[i]) end LdH_ops = Vector{ST}(undef, length(c_ops)) @@ -94,6 +94,9 @@ function compute_Cloc(prob::LdagLSparseOpProblem, net::MatrixNet, 𝝝, lnψ=net HdL_ops = prob.HdL_ops LdL_ops_t = prob.LdL_ops_t + T = real(out_type(net)) + CT = Complex{T} + σ = row(𝝝) σt = col(𝝝) set_index!(𝝝p, index(𝝝)) @@ -182,7 +185,7 @@ function compute_Cloc(prob::LdagLSparseOpProblem, net::MatrixNet, 𝝝, lnψ=net log_ratio = logψ(net, 𝝝p) - lnψ #@assert (mat[i_σ_p, i_σt_p] - log_ratio) == 0 - ΔE = -1.0im * val_σ_p * LdH.nzval[int_row_id] * exp(log_ratio) # mat[i_σ_p, i_σt_p] + ΔE = -T(1.0)im * val_σ_p * LdH.nzval[int_row_id] * exp(log_ratio) # mat[i_σ_p, i_σt_p] C_loc += ΔE # 2.0*real(ΔE) end end @@ -205,7 +208,7 @@ function compute_Cloc(prob::LdagLSparseOpProblem, net::MatrixNet, 𝝝, lnψ=net log_ratio = logψ(net, 𝝝p) - lnψ #@assert (mat[i_σt_p, i_σ_p] - log_ratio) == 0 - ΔE = 1.0im * conj(val_σ_p) * conj(LdH.nzval[int_row_id]) * exp(log_ratio) # mat[i_σt_p, i_σ_p] + ΔE = T(1.0)im * conj(val_σ_p) * conj(LdH.nzval[int_row_id]) * exp(log_ratio) # mat[i_σt_p, i_σ_p] C_loc += ΔE # 2.0*real(ΔE) end end @@ -230,7 +233,7 @@ function compute_Cloc(prob::LdagLSparseOpProblem, net::MatrixNet, 𝝝, lnψ=net log_ratio = logψ(net, 𝝝p) - lnψ #@assert (mat[i_σ_p, 
i_σt_p] - log_ratio) == 0 - ΔE = 1.0im * conj(val_σ_p) * HdL.nzval[int_row_id] * exp(log_ratio) # mat[i_σ_p, i_σt_p] + ΔE = T(1.0)im * conj(val_σ_p) * HdL.nzval[int_row_id] * exp(log_ratio) # mat[i_σ_p, i_σt_p] C_loc += ΔE # 2.0*real(ΔE) end end @@ -253,7 +256,7 @@ function compute_Cloc(prob::LdagLSparseOpProblem, net::MatrixNet, 𝝝, lnψ=net log_ratio = logψ(net, 𝝝p) - lnψ #@assert (mat[i_σt_p, i_σ_p] - log_ratio) == 0 - ΔE = -1.0im * val_σ_p * conj(HdL.nzval[int_row_id]) * exp(log_ratio) # mat[i_σt_p, i_σ_p] + ΔE = -T(1.0)im * val_σ_p * conj(HdL.nzval[int_row_id]) * exp(log_ratio) # mat[i_σt_p, i_σ_p] C_loc += ΔE # 2.0*real(ΔE) end end diff --git a/src/Problems/SteadyStateLindblad/LdagLSparseSuperopProblem.jl b/src/Problems/SteadyStateLindblad/LdagLSparseSuperopProblem.jl index 1f56b06..3653623 100644 --- a/src/Problems/SteadyStateLindblad/LdagLSparseSuperopProblem.jl +++ b/src/Problems/SteadyStateLindblad/LdagLSparseSuperopProblem.jl @@ -32,7 +32,7 @@ LdagLSparseSuperopProblem(T::Type{<:Number}, Ham::DataOperator, cops::Vector) = LdagLSparseSuperopProblem(T::Type{<:Number}, Liouv::SparseSuperOperator) = LdagLSparseSuperopProblem(first(Liouv.basis_l), Liouv.data'*Liouv.data, 0.0) -basis(prob::LdagLSparseSuperopProblem) = prob.HilbSpace +QuantumOpticsBase.basis(prob::LdagLSparseSuperopProblem) = prob.HilbSpace function compute_Cloc(prob::LdagLSparseSuperopProblem, net::MatrixNet, σ, lnψ=net(σ), σp=deepcopy(σ)) ℒdagℒ = prob.LdagL diff --git a/src/Problems/SteadyStateLindblad/build_SteadyStateProblem.jl b/src/Problems/SteadyStateLindblad/build_SteadyStateProblem.jl index a239568..091872b 100644 --- a/src/Problems/SteadyStateLindblad/build_SteadyStateProblem.jl +++ b/src/Problems/SteadyStateLindblad/build_SteadyStateProblem.jl @@ -1,7 +1,7 @@ export SteadyStateProblem """ - SteadyStateProblem([T=STD_REAL_PREC], ℒ, operators=true, variance=true) + SteadyStateProblem([T=STD_REAL_PREC], ℒ, operators=true, variance=true, superop=false) SteadyStateProblem([T=STD_REAL_PREC], H, J, operators=true, variance=true) Returns the problem targeting the steady-state of the Liouvillian `ℒ` through @@ -12,6 +12,8 @@ better convergence properies. The sampling is performed on ℒ'ℒ ρ otherwise. See appendix of https://arxiv.org/abs/1902.10104 for more info. +If `superop=false` a Klocal operator of the lindbladian is used (unstable). + If `operators=true` a memory efficient representation of the hamiltonian is used, resulting in less memory consuption but higher CPU usage. This is needed for lattices bigger than a certain threshold. @@ -34,10 +36,18 @@ function SteadyStateProblem(T::Type{<:Number}, H::DataOperator, J::AbstractVecto end # Dispatched when called with the object for the whole liouvillian -function SteadyStateProblem(T::Type{<:Number}, ℒ; operators=true, variance=true, kwargs...) +function SteadyStateProblem(T::Type{<:Number}, ℒ; operators=true, variance=true, superop=false, kwargs...) base = basis(ℒ) - if operators + if superop + if !variance + throw("Can't use operators=true and variance=false. Operators are not + compatible with non-variance minimization.") + end + + return LRhoKLocalSOpProblem(T, ℒ) + + elseif operators if !variance throw("Can't use operators=true and variance=false. 
Operators are not compatible with non-variance minimization.") diff --git a/src/Problems/base_problems.jl b/src/Problems/base_problems.jl new file mode 100644 index 0000000..712ddca --- /dev/null +++ b/src/Problems/base_problems.jl @@ -0,0 +1,5 @@ +abstract type AbstractSteadyStateProblem <: AbstractProblem end +abstract type HermitianMatrixProblem <: AbstractSteadyStateProblem end +abstract type LRhoSquaredProblem <: AbstractSteadyStateProblem end +abstract type OpenTimeEvolutionProblem <: AbstractSteadyStateProblem end +abstract type OperatorEstimationProblem <: AbstractProblem end diff --git a/src/Problems/time_evo/time_evo_L.jl b/src/Problems/time_evo/time_evo_L.jl index 0423ab0..d9d8546 100644 --- a/src/Problems/time_evo/time_evo_L.jl +++ b/src/Problems/time_evo/time_evo_L.jl @@ -19,7 +19,7 @@ time_evo_L(T::Type{<:Number}, Liouv::SparseSuperOperator) = time_evo_L(T::Type{<:Number}, Ham::DataOperator, cops::Vector) = time_evo_L(T, liouvillian(Ham, cops)) -basis(prob::time_evo_L) = prob.HilbSpace +QuantumOpticsBase.basis(prob::time_evo_L) = prob.HilbSpace function compute_Cloc(prob::time_evo_L, net::MatrixNet, σ, lnψ=net(σ), σp=deepcopy(σ)) ℒ = prob.L diff --git a/src/Samplers/Exact.jl b/src/Samplers/Exact.jl index b6db383..ea99e72 100644 --- a/src/Samplers/Exact.jl +++ b/src/Samplers/Exact.jl @@ -35,6 +35,8 @@ function init_sampler!(sampler::ExactSampler, net::Union{MatrixNet,KetNet}, σ, return c end +chain_length(s::ExactSampler, c::ExactSamplerCache) = s.samples_length + done(s::ExactSampler, σ, c) = c.steps_done >= s.samples_length function samplenext!(σ, s::ExactSampler, net::Union{MatrixNet,KetNet}, c) diff --git a/src/Samplers/FullSum.jl b/src/Samplers/FullSum.jl index 354bd7a..107f4bc 100644 --- a/src/Samplers/FullSum.jl +++ b/src/Samplers/FullSum.jl @@ -18,6 +18,8 @@ function init_sampler!(sampler::FullSumSampler, net, σ::FiniteBasisState, c::Fu return c end +chain_length(s::FullSumSampler, c::FullSumSamplerCache) = length(c.interval) + done(s::FullSumSampler, σ, c) = c.last_position >= length(c.interval) function samplenext!(σ, s::FullSumSampler, net, c) diff --git a/src/Samplers/MCMCSampler.jl b/src/Samplers/MCMCSampler.jl index 541fa27..889499e 100644 --- a/src/Samplers/MCMCSampler.jl +++ b/src/Samplers/MCMCSampler.jl @@ -36,7 +36,6 @@ function init_sampler!(s::MCMCSampler, net, σ, c::MCMCSamplerCache) c.steps_done = 0 c.steps_accepted = 0 set_index!(σ, rand(c.rng, 1:spacedimension(σ))) - init_lut!(σ, net) while c.steps_done < s.burn_length markov_chain_step!(σ, s, net, c) @@ -49,7 +48,9 @@ end init_sampler_rule_cache!(rc, s::MCMCSampler, net, σ, c::MCMCSamplerCache) = nothing -done(s::MCMCSampler, σ, c) = c.steps_done >= s.chain_length +chain_length(s::MCMCSampler, c::MCMCSamplerCache) = s.chain_length + +done(s::MCMCSampler, σ, c) = c.steps_done >= s.chain_length-1 function samplenext!(σ, s::MCMCSampler, net, c) # Check termination condition, and return if verified diff --git a/src/Samplers/base_samplers.jl b/src/Samplers/base_samplers.jl index d5255e2..5290f42 100644 --- a/src/Samplers/base_samplers.jl +++ b/src/Samplers/base_samplers.jl @@ -30,3 +30,10 @@ provided, one will be initialized and returned. The state σ is the first in the list of sampled states. """ init_sampler!(s::Sampler, net, σ) = init_sampler!(s, net, σ, cache(s, σ, net)) + +""" + chain_length(sampler, sampler_cache) -> Int + +Returns the estimated length of the chain. 
+""" +function chain_length end diff --git a/src/States/DiagonalStateWrapper.jl b/src/States/DiagonalStateWrapper.jl index 3b3ab35..cd1ac2e 100644 --- a/src/States/DiagonalStateWrapper.jl +++ b/src/States/DiagonalStateWrapper.jl @@ -13,7 +13,7 @@ index(s::DiagonalStateWrapper) = index(row(s.parent)) index_to_int(s::DiagonalStateWrapper, id) = index_to_int(row(s.parent), id) flipped(a::DiagonalStateWrapper, b::DiagonalStateWrapper) = flipped(a.parent, b.parent) -@inline eltype(state::DiagonalStateWrapper) = eltype(state.parent) +@inline Base.eltype(state::DiagonalStateWrapper) = eltype(state.parent) @inline config(state::DiagonalStateWrapper) = config(state.parent) zero!(s::DiagonalStateWrapper) = zero!(s.parent) @@ -45,7 +45,7 @@ function add!(s::DiagonalStateWrapper, i) return s end -function rand!(rng::AbstractRNG, s::DiagonalStateWrapper) +function Random.rand!(rng::AbstractRNG, s::DiagonalStateWrapper) rand!(rng, row(s.parent)) set!(col(s.parent), toint(row(s.parent))) return s diff --git a/src/States/DoubleState.jl b/src/States/DoubleState.jl index d3321e5..c829b20 100644 --- a/src/States/DoubleState.jl +++ b/src/States/DoubleState.jl @@ -27,7 +27,7 @@ DoubleState{ST}(n, i_σ=0) where ST = set!(DoubleState(ST(n, 0), ST(n, 0)), i_σ @inline spacedimension(state::DoubleState) = state.space_dim @inline nsites(state::DoubleState) = 2*nsites(state.σ_row) @inline local_dimension(state::DoubleState{ST}) where {ST} = local_dimension(ST) -@inline eltype(state::DoubleState) = eltype(row(state)) +@inline Base.eltype(state::DoubleState) = eltype(row(state)) toint(state::DoubleState) = _toint(col(state), row(state)) #toint(state.σ_row, state.σ_col) index(state::DoubleState) = toint(state) + 1 # was before @@ -56,27 +56,36 @@ function setat!(v::DoubleState, i::Int, val) i > v.n ? setat!(v.σ_row, i-v.n, val) : setat!(v.σ_col, i, val) end +""" + apply!(state::State, changes) + +Applies the changes `changes` to the `state`. +If `state isa DoubleState` then single-value changes +are applied to the columns of the state (in order to +compute matrix-operator products). Otherwise it should +be a tuple with changes of row and columns +""" function apply!(state::DoubleState, changes::StateChanges) - for (site, val) = row(changes) + for (site, val) = changes + #setat!(col(state), site, val) + # The code below automatically applies it only + # to columns. 
setat!(state, site, val) end end -function apply!(state::DoubleState, changes::Tuple{StateChanges}) - changes_r, changes_c = changes - for (id, val) = row(changes_r) - setat!(row(state), id, val) - end - for (id, val) = col(changes_c) - setat!(col(state), id, val) - end +@inline apply!(state::DoubleState, (changes_r, changes_c)::Tuple) = + apply!(state, changes_r, changes_c) +function apply!(state::DoubleState, changes_r, changes_c) + apply!(row(state), changes_r) + apply!(col(state), changes_c) return state end set_index!(v::DoubleState, i::Integer) = set!(v, index_to_int(v, i)) function set!(v::DoubleState, i::Integer) row = div(i, spacedimension(v.σ_row)) #row = i>>(nsites(state.σ_row)) - col = i - row*spacedimension(v.σ_row)#col = i - (row<< nsites(state.σ_row)) + col = i - row*spacedimension(v.σ_row) #col = i - (row<< nsites(state.σ_row)) set!(v.σ_col, row) #i set!(v.σ_row, col) #j @@ -84,7 +93,7 @@ function set!(v::DoubleState, i::Integer) end set!(v::DoubleState, i_row, i_col) = (set!(row(v), i_row); set!(col(v), i_col); v) -function rand!(rng::AbstractRNG, state::DoubleState) +function Random.rand!(rng::AbstractRNG, state::DoubleState) rand!(rng, state.σ_row) rand!(rng, state.σ_col) state diff --git a/src/States/ModifiedState.jl b/src/States/ModifiedState.jl index 243ce35..9f29b10 100644 --- a/src/States/ModifiedState.jl +++ b/src/States/ModifiedState.jl @@ -51,7 +51,7 @@ raw_config(s::DoubleState{<:ModifiedState}) = (raw_config(s.σ_row), raw_config( @inline nsites(s::ModifiedState) = nsites(raw_state(s)) @inline local_dimension(s::ModifiedState{S,C}) where {S,C} = local_dimension(S) @inline local_dimension(s::Type{ModifiedState{S,C}}) where {S,C} = local_dimension(S) -@inline eltype(s::ModifiedState) = eltype(raw_state(s)) +@inline Base.eltype(s::ModifiedState) = eltype(raw_state(s)) toint(s::ModifiedState) = toint(apply_warn_raw!(s)) index(s::ModifiedState) = index(apply_warn_raw!(s)) @@ -111,7 +111,7 @@ end set_index!(s::ModifiedState, val) = (zero!(changes(s)); set_index!(raw_state(s), val)) set!(s::ModifiedState, val) = (zero!(changes(s)); set!(raw_state(s), val); return s) add!(s::ModifiedState, val) = add!(apply_warn_raw!(state), val) -rand!(rng, s::ModifiedState) = (zero!(changes(s)); +Random.rand!(rng, s::ModifiedState) = (zero!(changes(s)); rand!(rng, raw_state(s))) Base.show(io::IO, ::MIME"text/plain", s::ModifiedState) = diff --git a/src/States/NAryState.jl b/src/States/NAryState.jl index b6b7804..3f46532 100644 --- a/src/States/NAryState.jl +++ b/src/States/NAryState.jl @@ -31,7 +31,7 @@ export NAryState @inline nsites(state::NAryState) = state.n @inline local_dimension(state::Type{NAryState{T,N}}) where {T,N} = N @inline local_dimension(state::NAryState{T,N}) where {T,N} = local_dimension(typeof(state)) -@inline eltype(state::NAryState{T,N}) where {T,N} = T +@inline Base.eltype(state::NAryState{T,N}) where {T,N} = T @inline toint(state::NAryState) = state.i_σ @inline index(state::NAryState) = toint(state)+1 @@ -94,6 +94,20 @@ function setat!(state::NAryState{T, N}, i::Int, val::T) where {T, N} return old_val end +function setat!(state::NAryState{T1, N}, i::Int, val::T2) where {T1,T2, N} + #=throw(""" + Error: cannot use setat! on a state with precision $T1 with + a value with precision $T2. + This error often occurs when you initialize your problem with a + different precision than your network or state. 
+ """)=# + old_val = state.σ[i] + + state.i_σ += (Int(val)-Int(old_val))*N^(i-1) + state.σ[i] = val + return old_val +end + set_index!(state::NAryState, val::Integer) = set!(state, index_to_int(state, val)) function set!(state::NAryState{T, N}, val::Integer) where {T, N} state.i_σ = val @@ -108,7 +122,7 @@ function add!(state::NAryState, val::Integer) state end -function rand!(rng::AbstractRNG, state::NAryState) +function Random.rand!(rng::AbstractRNG, state::NAryState) val = rand(rng, 0:(spacedimension(state)-1)) set!(state, val) end @@ -161,7 +175,8 @@ function String(bv::Vector, toInt=true) str end -Base.show(io::IO, ::MIME"text/plain", bs::NAryState{T,2}) where T = print(io, "NAryState(",bs.n,") : ", StringToSpin(bs.σ,false), +Base.show(io::IO, ::MIME"text/plain", bs::NAryState{T,2}) where T = print(io, + "NAryState{$T}($(bs.n)) : ", StringToSpin(bs.σ,false), " = ", bs.i_σ) Base.show(io::IO, bs::NAryState{T,2}) where T = print(io, bs.i_σ, StringToSpin(bs.σ,false)) function StringToSpin(bv::Vector, toInt=true) diff --git a/src/States/PurifiedState.jl b/src/States/PurifiedState.jl index 6f447d7..8d5f1f3 100644 --- a/src/States/PurifiedState.jl +++ b/src/States/PurifiedState.jl @@ -32,7 +32,7 @@ PurifiedState{ST}(n, add, i_σ=0) where ST = @inline index_to_int(state::PurifiedState, id, add::FiniteBasisState) = index_to_int(state, id, toint(add)) @inline index_to_int(state::PurifiedState, id, add) = add*spacedimension(sys(state)) + index_to_int(sys(state), id) -@inline eltype(state::PurifiedState) = eltype(sys(state)) +@inline Base.eltype(state::PurifiedState) = eltype(sys(state)) # custom accessor sys(v::PurifiedState) = v.σ_sys @@ -61,7 +61,7 @@ function set!(v::PurifiedState, i::Integer) end set!(v::PurifiedState, sys, add) = set!(sys(v), row) && set!(add(v), col) && v -function rand!(rng::AbstractRNG, state::PurifiedState) +function Random.rand!(rng::AbstractRNG, state::PurifiedState) rand!(rng, state.σ_sys) rand!(rng, state.σ_add) end diff --git a/src/base_batched_networks.jl b/src/base_batched_networks.jl index fc5657e..4536e76 100644 --- a/src/base_batched_networks.jl +++ b/src/base_batched_networks.jl @@ -1,3 +1,20 @@ +export logψ! + +# Definitions for batched evaluation of networks +# When the networrks are not cached and therefore allocate +# the result structure +@inline logψ!(out::AbstractArray, net::NeuralNetwork, σ::State) where N = + out .= logψ(net, config(σ)) +@inline logψ!(out::AbstractArray, net::NeuralNetwork, σ::NTuple{N,<:AbstractArray}) where N = + out .= logψ(net, σ) +@inline logψ!(out::AbstractArray, net::NeuralNetwork, σ::AbstractArray) = + out .= logψ(net, σ) +#@inline function logψ_and_∇logψ!(der, out, n::NeuralNetwork, σ...) +# lnψ, der = logψ_and_∇logψ!(der, n, σ...) +# out .= lnψ +#¶ return (out, der) +#end + """ NNCache{N} @@ -12,16 +29,135 @@ cached(net::NeuralNetwork, batch_sz::Int) = cached(net::CachedNet, batch_sz::Int) = CachedNet(net.net, cache(net.net, batch_sz)) +batch_size(cache::NNBatchedCache) = throw("Not Implemented") + +# Definition for inplace evaluation of batched cached networks +@inline logψ!(out::AbstractArray, net::CachedNet, σ::NTuple{N,<:AbstractArray}) where N = + logψ!(out, net.net, net.cache, σ...) 
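# Illustrative sketch, not part of the patch: typical call sequence for the
# batched, in-place entry points defined here, mirroring
# test/Machines/test_batched.jl. The network `net` (e.g. RBMSplit(Float32, 4, 2)),
# a matching state `v` and the batch size `b_sz` are assumed to exist.
cnet = cached(net, b_sz)                                     # network with a batched cache
vb   = preallocate_state_batch(trainable_first(net), Float32, v, b_sz)
vb isa Tuple ? rand!.(vb) : rand!(vb)                        # fill the batch with random configurations
out  = similar(trainable_first(net), out_type(net), 1, b_sz)
logψ!(out, cnet, vb)                                         # one log-amplitude per column of the batch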
+@inline logψ!(out::AbstractArray, cnet::CachedNet, σ::AbstractArray) = + logψ!(out, cnet.net, cnet.cache, σ) +@inline logψ!(out::AbstractArray, cnet::CachedNet, σ::State) = + logψ!(out, cnet, config(σ)) + +# Definition for allocating evaluation of batched cached networks +# Shadowing things at ~80 of base_cached_networks.jl +@inline logψ(net::CachedNet{NN,NC}, σ::NTuple{N,<:AbstractArray}) where {N,NN,NC<:NNBatchedCache} = begin + b_sz = last(size(first(σ))) + out = similar(trainable_first(net), out_type(net), 1, b_sz) + logψ!(out, net.net, net.cache, σ...) +end +@inline logψ(net::CachedNet{NN,NC}, σ::Vararg{N,T}) where {N,T,NN,NC<:NNBatchedCache} = begin + b_sz = last(size(first(σ))) + out = similar(trainable_first(net), out_type(net), 1, b_sz) + logψ!(out, net.net, net.cache, σ...) +end +@inline logψ(cnet::CachedNet{NN,NC}, σ::AbstractArray) where {NN,NC<:NNBatchedCache} = begin + b_sz = last(size(σ)) + out = similar(trainable_first(cnet), out_type(cnet), 1, b_sz) + logψ!(out, cnet.net, cnet.cache, σ) +end + +# Declare the two functions, even if config(blabla)=blabla, because of a shitty +# Julia's performance bug #32761 +# see https://github.com/JuliaLang/julia/issues/32761 +@inline logψ_and_∇logψ!(der, n::CachedNet{NN,NC}, σ::AbstractArray) where {NN,NC<:NNBatchedCache} = begin + b_sz = last(size(σ)) + out = similar(trainable_first(n), out_type(n), 1, b_sz) + logψ_and_∇logψ!(der, out, n.net, n.cache, σ) + return out, der +end +@inline logψ_and_∇logψ!(der, n::CachedNet{NN,NC}, σ::NTuple{N,<:AbstractArray}) where {N,NN,NC<:NNBatchedCache} = begin + b_sz = last(size(first(σ))) + out = similar(trainable_first(n), out_type(n), 1, b_sz) + logψ_and_∇logψ!(der, out, n.net, n.cache, σ...) + return out, der +end +@inline logψ_and_∇logψ!(der, n::CachedNet{NN,NC}, σ::Vararg{<:AbstractArray,N}) where {N,T,NN,NC<:NNBatchedCache} = begin + b_sz = last(size(first(σ))) + out = similar(trainable_first(n), out_type(n), 1, b_sz) + logψ_and_∇logψ!(der, out, n.net, n.cache, σ...) + return out, der +end + +@inline logψ_and_∇logψ!(der, out, n::CachedNet, σ::State) = + logψ_and_∇logψ!(der, out, n, config(σ)) +@inline function logψ_and_∇logψ!(der, out, n::CachedNet, σ::NTuple{N,<:AbstractArray}) where N + logψ_and_∇logψ!(der, out, n.net, n.cache, σ...) + return (out, der) +end +@inline function logψ_and_∇logψ!(der, out, n::CachedNet, σ::Vararg{<:AbstractArray,N}) where N + logψ_and_∇logψ!(der, out, n.net, n.cache, σ...) 
+ return (out, der) +end +@inline function logψ_and_∇logψ!(der, out, n::CachedNet, σ::AbstractArray) where N + logψ_and_∇logψ!(der, out, n.net, n.cache, σ); + return (out, der) +end + + # -grad_cache(net::NeuralNetwork, batch_sz) = begin - is_analytic(net) && return RealDerivative(net, batch_sz) - return WirtingerDerivative(net, batch_sz) +grad_cache(net::NeuralNetwork, batch_sz) = + grad_cache(out_type(net), net, batch_sz) +grad_cache(T::Type{<:Number}, net::NeuralNetwork, batch_sz) = begin + is_analytic(net) && return RealDerivative(T, net, batch_sz) + return WirtingerDerivative(T, net, batch_sz) end -function RealDerivative(net::NeuralNetwork, batch_sz::Int) +function RealDerivative(T::Type{<:Number}, net::NeuralNetwork, batch_sz::Int) pars = trainable(net) - vec = similar(trainable_first(pars), out_type(net), _tlen(pars), batch_sz) + vec = similar(trainable_first(pars), T, _tlen(pars), batch_sz) i, fields = batched_weight_tuple(net, vec) return RealDerivative(fields, [vec]) end + + +## Things for batched states +preallocate_state_batch(arrT::Array, + T::Type{<:Real}, + v::NAryState, + batch_sz) = + _std_state_batch(arrT, T, v, batch_sz) + +preallocate_state_batch(arrT::Array, + T::Type{<:Real}, + v::DoubleState, + batch_sz) = + _std_state_batch(arrT, T, v, batch_sz) + + +_std_state_batch(arrT::AbstractArray, + T::Type{<:Number}, + v::NAryState, + batch_sz) = + similar(arrT, T, nsites(v), batch_sz) + +_std_state_batch(arrT::AbstractArray, + T::Type{<:Number}, + v::DoubleState, + batch_sz) = begin + vl = similar(arrT, T, nsites(row(v)), batch_sz) + vr = similar(arrT, T, nsites(col(v)), batch_sz) + return (vl, vr) +end + + +@inline store_state!(cache::Array, + v::AbstractVector, + i::Integer) = begin + #@uviews cache v begin + uview(cache, :, i) .= v + #end + return cache +end + +@inline store_state!(cache::NTuple{2,<:Matrix}, + (vl, vr)::NTuple{2,<:AbstractVector}, + i::Integer) = begin + cache_l, cache_r = cache + #@uviews cache_l cache_r vl vr begin + uview(cache_l, :,i) .= vl + uview(cache_r, :,i) .= vr + #end + return cache +end diff --git a/src/base_cached_networks.jl b/src/base_cached_networks.jl index 472947f..21694b5 100644 --- a/src/base_cached_networks.jl +++ b/src/base_cached_networks.jl @@ -53,7 +53,7 @@ cached(net::CachedNet) = CachedNet(net.net, cache(net)) Copy a cached network, building a shallow copy of the network and a deep-copy of the cache. """ -copy(cnet::CachedNet) = CachedNet(cnet.net, deepcopy(cnet.cache)) +Base.copy(cnet::CachedNet) = CachedNet(cnet.net, deepcopy(cnet.cache)) """ cache(net) @@ -62,6 +62,7 @@ Constructs the `NNCache{typeof(net)}` object that holds the cache for this netwo If it has not been implemented returns nothing. """ cache(net) = nothing +cache(net::CachedNet) = cache(net.net) """ weights(net) @@ -74,7 +75,13 @@ weights(net) = trainable(net) @inline (cnet::CachedNet)(σ...) = logψ(cnet, σ...) # When you call logψ on a cached net use the cache to compute the net -@inline logψ(cnet::CachedNet, σ...) = cnet.net(cnet.cache, config(σ)...) +@inline logψ(cnet::CachedNet, σ::State) =logψ(cnet, config(σ)) +@inline logψ(cnet::CachedNet, σ::NTuple{N,<:AbstractArray}) where N = + cnet.net(cnet.cache, σ...) +@inline logψ(cnet::CachedNet, σ::Vararg{N,V}) where {N,V} = + cnet.net(cnet.cache, σ...) 
+@inline logψ(cnet::CachedNet, σ::AbstractArray) where N = + cnet.net(cnet.cache, σ) function logψ_and_∇logψ(n::CachedNet, σ::Vararg{N,V}) where {N,V} #@warn "Inefficient calling logψ_and_∇logψ for cachedNet" @@ -88,10 +95,10 @@ end # see https://github.com/JuliaLang/julia/issues/32761 @inline logψ_and_∇logψ!(der, n::CachedNet, σ::State) = logψ_and_∇logψ!(der, n, config(σ)) @inline function logψ_and_∇logψ!(der, n::CachedNet, σ::NTuple{N,AbstractArray}) where N - lψ = logψ_and_∇logψ!(der, n.net, n.cache, config(σ)...) + lψ = logψ_and_∇logψ!(der, n.net, n.cache, σ...) return (lψ, der) end -@inline function logψ_and_∇logψ!(der, n::CachedNet, σ::Vararg{AbstractArray,N}) where N +@inline function logψ_and_∇logψ!(der, n::CachedNet, σ::Vararg{<:AbstractArray,N}) where N lψ = logψ_and_∇logψ!(der, n.net, n.cache, σ...); return (lψ, der) end @@ -102,7 +109,12 @@ end end ## Optimisation of cachednet -update!(opt, cnet::CachedNet, Δ, state=nothing) = (update!(opt, weights(cnet), weights(Δ), state); invalidate_cache!(cnet.cache)) +update!(opt, cnet::CachedNet, Δ, state=nothing) = begin + update!(opt, weights(cnet), weights(Δ), state) + invalidate_cache!(cnet.cache) + return nothing +end + apply!(opt, val1::Union{NeuralNetwork, CachedNet}, val2::Union{NeuralNetwork, CachedNet}, args...) = apply!(weights(val1), weights(val2), args...) @@ -139,7 +151,7 @@ const KetNet = Union{KetNeuralNetwork, CachedNet{<:KetNeuralNetwork}} # This overrides the standard behaviour of net(σ...) because Vector unpacking # should not happen -@inline logψ(cnet::CachedNet{<:KetNeuralNetwork}, σ) = cnet.net(cnet.cache, config(σ)) +#@inline logψ(cnet::CachedNet{<:KetNeuralNetwork}, σ) = cnet.net(cnet.cache, config(σ)) function logψ_and_∇logψ(n::CachedNet{<:KetNeuralNetwork}, σ) ∇lnψ = grad_cache(n) diff --git a/src/base_derivatives.jl b/src/base_derivatives.jl index 7941b0e..dd1c3ba 100644 --- a/src/base_derivatives.jl +++ b/src/base_derivatives.jl @@ -19,6 +19,8 @@ end @inline vec_data(s::RealDerivative) = getfield(s, :vectorised_data) @inline fields(s::RealDerivative) = getfield(s, :fields) +weights(der::RealDerivative) = der + function RealDerivative(net::NeuralNetwork) pars = trainable(net) @@ -46,9 +48,11 @@ Base.imag(s::WirtingerDerivative) = s.c_derivatives end function WirtingerDerivative(net::NeuralNetwork) - vec = similar(trainable_first(net), out_type(net), trainable_length(net)*2) - i, fields_r = weight_tuple(net, fieldnames(typeof(net)), vec) - i, fields_c = weight_tuple(net, fieldnames(typeof(net)), vec, i+1) + pars = trainable(net) + + vec = similar(trainable_first(net), out_type(net), _tlen(pars)*2) + i, fields_r = weight_tuple(net, vec) + i, fields_c = weight_tuple(net, vec, i+1) return WirtingerDerivative(fields_r, fields_c, [vec]) end diff --git a/src/base_networks.jl b/src/base_networks.jl index bba9245..d8995a9 100644 --- a/src/base_networks.jl +++ b/src/base_networks.jl @@ -28,6 +28,7 @@ If `net isa CachedNet` then the computation will be performed efficiently with minimal allocations. """ @inline logψ(net::NeuralNetwork, σ) = net(σ) +@inline logψ(net::NeuralNetwork, σ::NTuple{N,<:AbstractArray}) where N = net(σ...) @inline log_prob_ψ(net, σ...) = 2.0*real(net(σ...)) @inline ∇logψ(args...) = logψ_and_∇logψ(args...)[2] @inline ∇logψ!(args...) = logψ_and_∇logψ!(args...)[2] @@ -123,4 +124,4 @@ num_params(net::NeuralNetwork) = trainable_length(net) # TODO does this even make sense?! # the idea was that a shallow-copy of the weights of the net is not # even a copy.... 
-copy(net::NeuralNetwork) = net +Base.copy(net::NeuralNetwork) = net diff --git a/src/base_states.jl b/src/base_states.jl index 8244444..d84d0fc 100644 --- a/src/base_states.jl +++ b/src/base_states.jl @@ -4,14 +4,37 @@ add!(v::FiniteBasisState, i) = set!(v, toint(v)+i) zero!(v::FiniteBasisState) = set!(v, 0) @inline config(v) = v -rand!(v::State) = rand!(GLOBAL_RNG, v) +Random.rand!(v::State) = rand!(GLOBAL_RNG, v) flipat!(v::State, i) = flipat!(GLOBAL_RNG, v, i) flipat_fast!(v::State, i) = flipat_fast!(GLOBAL_RNG, v, i) export NAryState, DoubleState, BinaryState export local_dimension, spacedimension export nsites, toint, index, index_to_int, flipped, row, col, config -export add!, zero! +export add!, zero!, apply! export setat!, set!, set_index!, rand! +""" + apply!(state::State, changes) + +Applies the changes `changes` to the `state`. + +If `state isa DoubleState` then single-value changes +are applied to the columns of the state (in order to +compute matrix-operator products). Otherwise it should +be a tuple with changes of row and columns. + +If changes is nothing, does nothing. +""" apply!(σ::State, cngs::Nothing) = σ + +""" + apply(state::State, cngs) + +Applies the changes `cngs` to the state `σ`, by allocating a +copy. + +See also @ref(apply!) +""" +apply(σ::State, cngs) = apply!(deepcopy(σ), cngs) + diff --git a/src/tuple_logic.jl b/src/tuple_logic.jl index dbd71e6..3bd7e6b 100644 --- a/src/tuple_logic.jl +++ b/src/tuple_logic.jl @@ -67,7 +67,11 @@ end function weight_tuple(x::AbstractArray{<:Number}, vec::AbstractVector, start) length(vec) < start+length(x)-1 && resize!(vec, start+length(x)-1) @views data_vec = vec[start:start+length(x)-1] - reshpd_params = reshape(data_vec, size(x)) + if size(x) == size(data_vec) + reshpd_params = data_vec + else + reshpd_params = reshape(data_vec, size(x)) + end reshpd_params .= x return length(x), reshpd_params end @@ -106,11 +110,15 @@ function batched_weight_tuple(x::AbstractArray{<:Number}, vec::AbstractMatrix, s bsz = size(vec, 2) @views data_vec = vec[start:start+length(x)-1, :] - reshpd_params = reshape(data_vec, size(x)..., bsz) - reshpd_params .= x - if reshpd_params isa Base.ReshapedArray - reshpd_params = StridedView(reshpd_params) + if size(data_vec) != (size(x)..., bsz) + reshpd_params = reshape(data_vec, size(x)..., bsz) + if x isa Array + reshpd_params = StridedView(reshpd_params) + end + else + reshpd_params = data_vec end + #reshpd_params .= x return length(x), reshpd_params end # ? stridedView? 
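# Illustrative sketch, not part of the patch: how the new apply/apply! helpers
# in base_states.jl combine with map_connections, in the spirit of to_matrix
# earlier in this diff. `op` (a KLocalOperator) and a matching state `σ` are
# assumed to exist.
map_connections(op, σ) do mel, cngs, _
    σp = apply(σ, cngs)           # copy σ and apply this connection's changes
    println(index(σ), " -> ", index(σp), "  matrix element ", mel)
end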
diff --git a/src/utils/math.jl b/src/utils/math.jl
index b5d2fb7..7441fd5 100644
--- a/src/utils/math.jl
+++ b/src/utils/math.jl
@@ -36,3 +36,44 @@ Internally uses the fact that R is a StridedView
     end=#
     return R
 end
+
+@inline function _batched_outer_prod!(R::StridedView, α, vb, wb)
+    #@unsafe_strided R begin
+    @inbounds @simd for i=1:size(R, 3)
+        for j=1:size(wb, 1)
+            for k=1:size(vb, 1)
+                R[k,j,i] = α * vb[k,i]*conj(wb[j,i])
+            end
+        end
+    end
+    #end
+
+    #=@unsafe_strided R vb wb begin
+        for i=1:size(R, 3)
+            BLAS.ger!(1.0, vb[:,i], wb[:,i], R[:,:,i])
+        end
+    end=#
+    return R
+end
+
+@inline function _batched_outer_prod_∑!(R::StridedView, α, vb, wb, vb2, wb2)
+    @inbounds @simd for i=1:size(R, 3)
+        for j=1:size(wb, 1)
+            for k=1:size(vb, 1)
+                R[k,j,i] = α * (vb[k,i]*conj(wb[j,i]) + vb2[k,i]*conj(wb2[j,i]))
+            end
+        end
+    end
+    return R
+end
+
+@inline function _batched_outer_prod_Δ!(R::StridedView, α, vb, wb, vb2, wb2)
+    @inbounds @simd for i=1:size(R, 3)
+        for j=1:size(wb, 1)
+            for k=1:size(vb, 1)
+                R[k,j,i] = α * (vb[k,i]*conj(wb[j,i]) - vb2[k,i]*conj(wb2[j,i]))
+            end
+        end
+    end
+    return R
+end
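For reference, the kernels added to `src/utils/math.jl` compute, for every batch index `i`, the scaled outer product `R[:, :, i] = α * v[:, i] * w[:, i]'` (second factor conjugated); the `∑` and `Δ` variants add or subtract a second such product. A naive, allocating sketch of the same operation (not the package's implementation) for comparison:

```julia
# Reference version of the batched outer product: one scaled outer product of the
# i-th columns per batch slice, with the second factor conjugated.
function batched_outer_prod_reference(α, v::AbstractMatrix, w::AbstractMatrix)
    R = zeros(promote_type(eltype(v), eltype(w)), size(v, 1), size(w, 1), size(v, 2))
    for i in axes(v, 2)
        R[:, :, i] .= α .* (v[:, i] * w[:, i]')   # R[k,j,i] = α * v[k,i] * conj(w[j,i])
    end
    return R
end

v, w = randn(ComplexF64, 3, 5), randn(ComplexF64, 4, 5)
R = batched_outer_prod_reference(0.5, v, w)        # size(R) == (3, 4, 5)
```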
diff --git a/test/Machines/test_batched.jl b/test/Machines/test_batched.jl
new file mode 100644
index 0000000..81772ef
--- /dev/null
+++ b/test/Machines/test_batched.jl
@@ -0,0 +1,84 @@
+using NeuralQuantum, Test
+using NeuralQuantum: set_index!, trainable_first, preallocate_state_batch
+num_types = [Float32, Float64]
+
+machines = Dict()
+
+ma = (T, N) -> RBMSplit(T, N, 2)
+machines["RBMSplit"] = ma
+
+ma = (T, N) -> NDM(T, N, 2, 3)
+machines["NDM"] = ma
+
+ma = (T, N) -> RBM(T, N, 2)
+machines["RBM"] = ma
+
+N = 4
+T = Float32
+b_sz = 3
+
+@testset "test batched dispatch - values: $name" for name=keys(machines)
+    net = machines[name](T,N)
+
+    cnet = cached(net, b_sz)
+    v = state(T, SpinBasis(1//2)^N, net)
+    vb = preallocate_state_batch(trainable_first(net), T,
+                                 v, b_sz)
+    if vb isa Tuple
+        rand!.(vb)
+    else
+        rand!(vb)
+    end
+
+    @test net(vb) ≈ cnet(vb)
+    if (vb isa Tuple)
+        @test net(vb...) ≈ cnet(vb...)
+    end
+
+    o = rand(Complex{T}, 1, b_sz)
+    o2 = similar(o)
+
+    oo = logψ!(o, net, vb)
+    oo2 = logψ!(o2, cnet, vb)
+    @test oo === o
+    @test oo2 === o2
+    @test oo ≈ oo2
+end
+
+@testset "test cached dispatch - inplace gradients: $name" for name=keys(machines)
+    net = machines[name](T,N)
+
+    cnet = cached(net, b_sz)
+    v = state(T, SpinBasis(1//2)^N, net)
+    vb = preallocate_state_batch(trainable_first(net), T,
+                                 v, b_sz)
+    if vb isa Tuple
+        rand!.(vb)
+    else
+        rand!(vb)
+    end
+    g1 = grad_cache(net, b_sz)
+    g2 = grad_cache(net, b_sz)
+    g3 = grad_cache(net, b_sz)
+
+    v1, gg1 = logψ_and_∇logψ!(g1, cnet, vb)
+    v3 = similar(v1)
+    vv3, gg3 = logψ_and_∇logψ!(g3, v3, cnet, vb)
+    @test v1 ≈ v3
+    @test gg1 ≈ gg3
+    @test g3 === gg3
+    @test v3 === vv3
+
+    if vb isa Tuple
+        v2, gg2 = logψ_and_∇logψ!(g2, cnet, vb...)
+        @test v1 ≈ v2
+        @test gg1 ≈ gg2
+        @test g1 === gg1 && g2 === gg2
+
+        vv3, gg3 = logψ_and_∇logψ!(g3, v3, cnet, vb...)
+        @test v1 ≈ v3
+        @test gg1 ≈ gg3
+        @test g3 === gg3
+        @test v3 === vv3
+    end
+end
diff --git a/test/Machines/test_cached.jl b/test/Machines/test_cached.jl
index e69de29..d5afec3 100644
--- a/test/Machines/test_cached.jl
+++ b/test/Machines/test_cached.jl
@@ -0,0 +1,68 @@
+using NeuralQuantum, Test
+using NeuralQuantum: set_index!
+num_types = [Float32, Float64]
+
+machines = Dict()
+
+ma = (T, N) -> RBMSplit(T, N, 2)
+machines["RBMSplit"] = ma
+
+ma = (T, N) -> RBM(T, N, 2)
+machines["RBM"] = ma
+
+N = 4
+T = Float32
+
+@testset "test cached dispatch - values: $name" for name=keys(machines)
+    net = machines[name](T,N)
+
+    cnet = cached(net)
+    v = state(T, SpinBasis(1//2)^N, net)
+    arr_v = config(v)
+
+    @test net(v) ≈ cnet(v)
+    @test net(arr_v) ≈ cnet(arr_v)
+    if !(arr_v isa AbstractVector)
+        @test net(arr_v...) ≈ cnet(arr_v...)
+    end
+end
+
+@testset "test cached dispatch - allocating gradients: $name" for name=keys(machines)
+    net = machines[name](T,N)
+
+    cnet = cached(net)
+    v = state(T, SpinBasis(1//2)^N, net)
+    arr_v = config(v)
+
+    @test ∇logψ(net, v) ≈ ∇logψ(cnet, v)
+    @test ∇logψ(net, arr_v) ≈ ∇logψ(cnet, arr_v)
+    if !(arr_v isa AbstractVector)
+        #@test ∇logψ(net, arr_v...) ≈ ∇logψ(cnet, arr_v...)
+    end
+end
+
+@testset "test cached dispatch - inplace gradients: $name" for name=keys(machines)
+    net = machines[name](T,N)
+
+    cnet = cached(net)
+    v = state(T, SpinBasis(1//2)^N, net)
+    arr_v = config(v)
+    g1 = grad_cache(net)
+    g2 = grad_cache(net)
+
+    v1, gg1 = logψ_and_∇logψ!(g1, net, v)
+    v2, gg2 = logψ_and_∇logψ!(g2, cnet, v)
+    @test v1 ≈ v2
+    @test gg1 ≈ gg2
+    @test g1 === gg1 && g2 === gg2
+
+    v1, gg1 = logψ_and_∇logψ!(g1, net, arr_v)
+    v2, gg2 = logψ_and_∇logψ!(g2, cnet, arr_v)
+    @test v1 ≈ v2
+    @test gg1 ≈ gg2
+    @test g1 === gg1 && g2 === gg2
+
+    if !(arr_v isa AbstractVector)
+        #@test ∇logψ(net, arr_v...) ≈ ∇logψ(cnet, arr_v...)
+    end
+end
diff --git a/test/Machines/test_grad.jl b/test/Machines/test_grad.jl
index 21af7d2..a588dde 100644
--- a/test/Machines/test_grad.jl
+++ b/test/Machines/test_grad.jl
@@ -14,7 +14,7 @@ im_machines["RBMSplit"] = ma
 ma = (T, N) -> RBM(T, N, 2)
 im_machines["RBM"] = ma

-ma = (T, N) -> PureStateAnsatz(Chain(Dense(N, N*2), Dense(N*2, N*3), WSum(N*3)))
+ma = (T, N) -> PureStateAnsatz(Chain(Dense(N, N*2), Dense(N*2, N*3), WSum(N*3)), N)
 re_machines["ChainKet"] = ma
diff --git a/test/Machines/test_ndmcomplex.jl b/test/Machines/test_ndmcomplex.jl
index 0f877f9..3e204c5 100644
--- a/test/Machines/test_ndmcomplex.jl
+++ b/test/Machines/test_ndmcomplex.jl
@@ -11,7 +11,6 @@ N = 4

 @testset "Test Properties $name" for name=keys(machines)
     for T=num_types
-        T = Complex{T}
         net = machines[name](T,N)

         cnet = cached(net)
@@ -23,7 +22,7 @@ end

 @testset "Test Cached Value $name" for name=keys(machines)
     for T=num_types
-        net = machines[name](Complex{T},N)
+        net = machines[name](T,N)
         cnet = cached(net)

         v = state(real(T), SpinBasis(1//2)^N, net)
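The Machines tests above exercise both the allocating and the in-place gradient entry points. A short usage sketch, grounded in those tests (it assumes the same exported API; `RBM`, `cached`, `state`, `grad_cache` and the gradient calls are used exactly as in `test_cached.jl`):

```julia
using NeuralQuantum

net  = RBM(Float32, 4, 2)
cnet = cached(net)
v    = state(Float32, SpinBasis(1//2)^4, net)

∇1     = ∇logψ(cnet, v)                  # allocating: returns a fresh derivative
der    = grad_cache(net)                 # preallocated derivative storage
lψ, ∇2 = logψ_and_∇logψ!(der, cnet, v)   # in place: fills and returns `der`

∇2 === der                               # true, as the `===` checks in the tests assert
```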
diff --git a/test/Problems/observables.jl b/test/Problems/observables.jl
index 61faa59..5969a59 100644
--- a/test/Problems/observables.jl
+++ b/test/Problems/observables.jl
@@ -11,8 +11,8 @@ Sx = QuantumLattices.LocalObservable(lind, sigmax, Nsites)
 Sy = QuantumLattices.LocalObservable(lind, sigmay, Nsites)
 Sz = QuantumLattices.LocalObservable(lind, sigmaz, Nsites)
 H = lind.H
-oprob = ObservablesProblem(Sx, Sy, Sz, H, operator=false)
-oprob_op = ObservablesProblem(Sx, Sy, Sz, H, operator=true)
+oprob = ObservablesProblem(T, Sx, Sy, Sz, H, operator=false)
+oprob_op = ObservablesProblem(T, Sx, Sy, Sz, H, operator=true)
 obs_dense = [DenseOperator(op).data for op=[Sx, Sy, Sz, H]]

 net = cached(RBMSplit(T, Nsites, 2))
@@ -59,8 +59,8 @@ ham = quantum_ising_ham(lattice, g=1.0, V=2.0)
 Sx = QuantumLattices.LocalObservable(ham, sigmax, Nsites)
 Sy = QuantumLattices.LocalObservable(ham, sigmay, Nsites)
 Sz = QuantumLattices.LocalObservable(ham, sigmaz, Nsites)
-oprob = ObservablesProblem(Sx, Sy, Sz, ham, operator=false)
-oprob_op = ObservablesProblem(Sx, Sy, Sz, ham, operator=true)
+oprob = ObservablesProblem(T, Sx, Sy, Sz, ham, operator=false)
+oprob_op = ObservablesProblem(T, Sx, Sy, Sz, ham, operator=true)
 obs_dense = [DenseOperator(op).data for op=[Sx, Sy, Sz, ham]]

 net = cached(RBM(Complex{T}, Nsites, 2))
diff --git a/test/runtests.jl b/test/runtests.jl
index e8f1306..8d83c16 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -4,7 +4,9 @@ using Test
 @testset "NeuralQuantum" begin
     println("Testing machines...")
     @testset "Machines" begin
+        include("Machines/test_cached.jl")
         include("Machines/test_grad.jl")
+        include("Machines/test_batched.jl")
         include("Machines/test_ndmcomplex.jl")
     end