Skip to content

Commit

Permalink
Fix compilation errors and add experimental batch support (#3)
Browse files Browse the repository at this point in the history
* Add the cases for dispatching of cached networks depending on inputs, and for batched too.
Add new preallocated outputs for batched variantts

* BatchAccumulators

* Add NDMBatched

* fix import orders

* Add KLocal Tensor Operators

* Experimental SuperOperator problem

* chain length

* UnsafeArrays for manipulations

* Batched evaluate

* Plug points for Neural gpu

* Remove cuarrays

* Fix ambiguities in state generation batch

* Improve type stability and remove allocations

* Type stability for operator hamiltonian problem

* Fix optimisers behaviour

* When building batched derivative cache, don't reshape if not needed.

* Improve batched calculator

* v0.1.2

* update manifest

* Update test configuration

* backports

* fix travis
  • Loading branch information
PhilipVinc authored Nov 28, 2019
1 parent f68c135 commit 3fc5392
Show file tree
Hide file tree
Showing 64 changed files with 2,127 additions and 312 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/Julia-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,19 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
julia-version: [1.2.0, 1.3.0-rc3]
julia-version: [1.2, 1.3]
os: [ubuntu-latest, macOS-latest]

steps:
- uses: actions/[email protected]
- name: "Setup Julia environment ${{ matrix.julia-version }}"
uses: julia-actions/setup-julia@v0.2
uses: julia-actions/setup-julia@v1.0.1
with:
version: ${{ matrix.julia-version }}
- name: "Run Tests"
run: |
julia --color=yes --project=@. -e 'using Pkg;
Pkg.activate();
Pkg.instantiate();
Pkg.add(PackageSpec(url="https://github.com/PhilipVinc/QuantumLattices.jl"));
println("Instantiated");'
julia --color=yes --project=@. -e "using Pkg; Pkg.test(coverage=true)"
11 changes: 4 additions & 7 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,16 @@ julia:
- nightly
matrix:
allow_failures:
- julia: [1.3, nightly]
- os: osx
- julia: [nightly]
notifications:
email: false
script:
- if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
- travis_wait 20 julia --project=@. --color=yes -e 'using Pkg;
Pkg.activate();
println("Instantiated");
println("Activate");
Pkg.instantiate();
Pkg.add(PackageSpec(url="https://github.com/PhilipVinc/QuantumLattices.jl"));
println("Installed [QuantumLattices, Zygote]");
println("Developed");
println("Instantiate");
Pkg.test(coverage=true)';
after_success:
- julia -e 'using Pkg; cd(Pkg.dir("NeuralQuantum")); Pkg.add("Coverage"); using Coverage; Codecov.submit(Codecov.process_folder())'
Expand All @@ -32,7 +29,7 @@ jobs:
os: linux
script:
- julia --project=docs/ -e 'using Pkg;
Pkg.add([PackageSpec(url="https://github.com/PhilipVinc/QuantumLattices.jl")]);
Pkg.activate();
Pkg.add("Documenter");
Pkg.develop(PackageSpec(path=pwd()));
Pkg.instantiate()'
Expand Down
120 changes: 56 additions & 64 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

[[AbstractFFTs]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "380e36c66edfa099cd90116b24c1ce8cafccac40"
git-tree-sha1 = "051c95d6836228d120f5f4b984dd5aba1624f716"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "0.4.1"
version = "0.5.0"

[[Adapt]]
deps = ["LinearAlgebra"]
Expand All @@ -28,16 +28,10 @@ uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
version = "0.8.10"

[[BinaryProvider]]
deps = ["Libdl", "Logging", "SHA"]
git-tree-sha1 = "c7361ce8a2129f20b0e05a89f7070820cfed6648"
deps = ["Libdl", "SHA"]
git-tree-sha1 = "5b08ed6036d9d3f0ee6369410b830f8873d4024c"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.6"

[[CSTParser]]
deps = ["Tokenize"]
git-tree-sha1 = "c69698c3d4a7255bc1b4bc2afc09f59db910243b"
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
version = "0.6.2"
version = "0.5.8"

[[CommonSubexpressions]]
deps = ["Test"]
Expand All @@ -47,27 +41,21 @@ version = "0.2.0"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
git-tree-sha1 = "ed2c4abadf84c53d9e58510b5fc48912c2336fbb"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "2.1.0"
version = "2.2.0"

[[Conda]]
deps = ["JSON", "VersionParsing"]
git-tree-sha1 = "9a11d428dcdc425072af4aea19ab1e8c3e01c032"
uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d"
version = "1.3.0"

[[Crayons]]
deps = ["Test"]
git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "4.0.0"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "517ce30aa57cdfae1ab444a7c0aef8bb86345bc2"
git-tree-sha1 = "a1b652fb77ae8ca7ea328fa7ba5aa151036e5c10"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.17.1"
version = "0.17.6"

[[Dates]]
deps = ["Printf"]
Expand All @@ -84,40 +72,50 @@ uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "0.0.4"

[[DiffRules]]
deps = ["Random", "Test"]
git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7"
deps = ["NaNMath", "Random", "SpecialFunctions"]
git-tree-sha1 = "f734b5f6bc9c909027ef99f6d91d5d9e4b111eed"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.0.10"
version = "0.1.0"

[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[DocStringExtensions]]
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
git-tree-sha1 = "88bb0edb352b16608036faadcc071adda068582a"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.8.1"

[[Documenter]]
deps = ["Base64", "Dates", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
git-tree-sha1 = "0be9bf63e854a2408c2ecd3c600d68d4d87d8a73"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
version = "0.24.2"

[[FFTW]]
deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport", "Test"]
git-tree-sha1 = "6c5b420da0b8c12098048561b8d58f81adea506f"
deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport"]
git-tree-sha1 = "4cfd3d43819228b9e73ab46600d0af0aa5cedceb"
uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
version = "1.0.1"
version = "1.1.0"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "16974065d5bfa867446d3228bc63f05a440e910b"
git-tree-sha1 = "1a9fe4e1323f38de0ba4da49eafd15b25ec62298"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.7.2"
version = "0.8.2"

[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "da46ac97b17793eba44ff366dc6cb70f1238a738"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.3"
version = "0.10.7"

[[IRTools]]
deps = ["InteractiveUtils", "MacroTools", "Test"]
git-tree-sha1 = "5bdac064676287838868e3715d44d1129d8f2d26"
repo-rev = "master"
repo-url = "https://github.com/MikeInnes/IRTools.jl.git"
git-tree-sha1 = "72421971e60917b8cd7737f9577c4f0f87eab306"
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
version = "0.2.3"
version = "0.3.0"

[[Inflate]]
deps = ["Pkg", "Printf", "Random", "Test"]
Expand Down Expand Up @@ -161,10 +159,10 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[MacroTools]]
deps = ["CSTParser", "Compat", "DataStructures", "Test", "Tokenize"]
git-tree-sha1 = "d6e9dedb8c92c3465575442da456aec15a89ff76"
deps = ["Compat", "DataStructures", "Test"]
git-tree-sha1 = "82921f0e3bde6aebb8e524efc20f4042373c0c06"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.1"
version = "0.5.2"

[[Markdown]]
deps = ["Base64"]
Expand All @@ -180,10 +178,9 @@ uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.6.0"

[[NaNMath]]
deps = ["Compat"]
git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2"
git-tree-sha1 = "928b8ca9b2791081dc71a51c55347c27c618760f"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.2"
version = "0.3.3"

[[OrderedCollections]]
deps = ["Random", "Serialization", "Test"]
Expand All @@ -193,9 +190,9 @@ version = "1.1.0"

[[Parsers]]
deps = ["Dates", "Test"]
git-tree-sha1 = "ef0af6c8601db18c282d092ccbd2f01f3f0cd70b"
git-tree-sha1 = "0139ba59ce9bc680e2925aec5b7db79065d60556"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "0.3.7"
version = "0.3.10"

[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
Expand Down Expand Up @@ -275,9 +272,9 @@ version = "0.8.0"

[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "db23bbf50064c582b6f2b9b043c8e7e98ea8c0c6"
git-tree-sha1 = "5a3bcb6233adabde68ebc97be66e95dcb787424c"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.11.0"
version = "0.12.1"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
Expand All @@ -294,21 +291,15 @@ deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[TimerOutputs]]
deps = ["Crayons", "Printf", "Test", "Unicode"]
git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
deps = ["Printf"]
git-tree-sha1 = "311765af81bbb48d7bad01fb016d9c328c6ede03"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.0"

[[Tokenize]]
git-tree-sha1 = "dfcdbbfb2d0370716c815cbd6f8a364efb6f42cf"
uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
version = "0.5.6"
version = "0.5.3"

[[TupleTools]]
deps = ["Random", "Test"]
git-tree-sha1 = "b006524003142128cc6d36189dce337729aa0050"
git-tree-sha1 = "62a7a6cd5a608ff71cecfdb612e67a0897836069"
uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
version = "1.1.0"
version = "1.2.0"

[[URIParser]]
deps = ["Test", "Unicode"]
Expand All @@ -323,6 +314,11 @@ uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[UnsafeArrays]]
git-tree-sha1 = "1de6ef280110c7ad3c5d2f7a31a360b57a1bde21"
uuid = "c4a57d5a-5b31-53a6-b365-19f8c011fbd6"
version = "1.0.0"

[[VersionParsing]]
deps = ["Compat"]
git-tree-sha1 = "c9d5aa108588b978bd859554660c8a5c4f2f7669"
Expand All @@ -331,16 +327,12 @@ version = "1.1.3"

[[Zygote]]
deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
git-tree-sha1 = "bc6f84f1ac81f1d1548264dc0c1816252e1c62ef"
repo-rev = "master"
repo-url = "https://github.com/FluxML/Zygote.jl.git"
git-tree-sha1 = "e4245b9c5362346e154b62842a89a18e0210b92b"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.3.4"
version = "0.4.1"

[[ZygoteRules]]
deps = ["MacroTools"]
git-tree-sha1 = "c4c29b30b8ff3be13d4244e78be7df2a42bc54d0"
repo-rev = "master"
repo-url = "https://github.com/FluxML/ZygoteRules.jl.git"
git-tree-sha1 = "b3b4882cc9accf6731a08cc39543fbc6b669dca8"
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
version = "0.2.0"
7 changes: 4 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
name = "NeuralQuantum"
uuid = "eb923273-1014-53d4-802c-abcb7262255a"
authors = ["Filippo Vicentini <[email protected]>"]
version = "0.1.1"
version = "0.1.2"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -20,14 +19,16 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
UnsafeArrays = "c4a57d5a-5b31-53a6-b365-19f8c011fbd6"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
Zygote = ">= 0.4"
julia = ">= 1.2"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"

[targets]
test = ["Test"]
18 changes: 16 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,28 @@ pkg"add https://github.com/PhilipVinc/NeuralQuantum.jl"
`QuantumLattices` is a custom package that allows defining new types of operators on a lattice.
It's not needed natively but it is usefull to define hamiltonians on a lattice.

Alternatively you may activate the project included in the manifest that comes with NeuralQuantum.


## Example
The basic idea of the package is the following: you create an hamiltonian/lindbladian with QuantumOptics or QuantumLattices (the latter allows you to go to arbitrarily big lattices, but isn't yet very documented...).
Then, you create a `SteadyStateProblem`, which performs some transformations on those operators to put them in the shaped necessary to optimize them efficiently.
You also will probably need to create an `ObservablesProblem`, by providing it all the observables that you wish to monitor during the optimization.

By default, if you don't provide the precision `Float32` is used.

Then, you will pick a network, a sampler, and create an iterative sampler to sample the network.
You must write the training loop by yourself. Check the documentation and the examples in the folder `examples/` to better understand how to do this.

*IMPORTANT:* If you want to use multithreaded samplers (identified by a `MT` at the beginning of their name), you will launch one markov chain per julia thread. As such, you will get much better performance if you set `JULIA_NUM_THREADS` environment variable to the number of physical cores in your computer before launching julia.

```
# Load dependencies
using NeuralQuantum, QuantumLattices
using Printf, ValueHistoriesLogger, Logging, ValueHistories
using Printf, Logging, ValueHistories
# Select the numerical precision
T = Float64
T = Float32
# Select how many sites you want
Nsites = 6
Expand Down
8 changes: 5 additions & 3 deletions examples/dissipative_spins_1d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using NeuralQuantum, QuantumLattices
using Logging, Printf, ValueHistories

# Select the numerical precision
T = Float64
T = Float32
# Select how many sites you want
Nsites = 7

Expand All @@ -22,7 +22,7 @@ Sx = QuantumLattices.LocalObservable(lind, sigmax, Nsites)
Sy = QuantumLattices.LocalObservable(lind, sigmay, Nsites)
Sz = QuantumLattices.LocalObservable(lind, sigmaz, Nsites)
# Create the problem object with all the observables to be computed.
oprob = ObservablesProblem(Sx, Sy, Sz)
oprob = ObservablesProblem(T, Sx, Sy, Sz)


# Define the Neural Network. A NDM with N visible spins and αa=2 and αh=1
Expand All @@ -33,7 +33,7 @@ cnet = cached(net)
# Chose a sampler. Options are FullSumSampler() which sums over the whole space
# ExactSampler() which does exact sampling or MCMCSamler which does a markov
# chain.
sampl = MCMCSampler(Metropolis(Nsites), 3000, burn=50)
sampl = MCMCSampler(Metropolis(Nsites), 1000, burn=50)
# Chose a sampler for the observables.
osampl = FullSumSampler()

Expand Down Expand Up @@ -69,6 +69,8 @@ for i=1:110
Optimisers.update!(optimizer, cnet, Δw)
end

using QuantumOptics, Plots

# Optional: compute the exact solution
ρ = last(steadystate.master(lind)[2])
ESx = real(expect(SparseOperator(Sx), ρ))
Expand Down
Loading

0 comments on commit 3fc5392

Please sign in to comment.