diff --git a/.github/workflows/build_release.yml b/.github/workflows/build_release.yml index 6f3737ab..e2935be5 100644 --- a/.github/workflows/build_release.yml +++ b/.github/workflows/build_release.yml @@ -45,9 +45,9 @@ jobs: - name: Build wheels run: python -m cibuildwheel --output-dir wheelhouse env: - CIBW_BEFORE_BUILD_WINDOWS: pip3 install torch==1.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --ignore-installed - CIBW_BEFORE_BUILD_MACOS: pip3 install torch==1.9.0 -f https://download.pytorch.org/whl/torch_stable.html --ignore-installed - CIBW_BEFORE_BUILD_LINUX: pip3 install torch==1.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --ignore-installed + CIBW_BEFORE_BUILD_WINDOWS: pip3 install torch==1.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --ignore-installed + CIBW_BEFORE_BUILD_MACOS: pip3 install torch==1.10.0 -f https://download.pytorch.org/whl/torch_stable.html --ignore-installed + CIBW_BEFORE_BUILD_LINUX: pip3 install torch==1.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --ignore-installed CIBW_REPAIR_WHEEL_COMMAND: "" CIBW_BUILD: cp37-* cp38-* cp39-* CIBW_SKIP: "*-manylinux_i686 *-win32" diff --git a/CHANGELOG.md b/CHANGELOG.md index 26c67deb..2bcb95cd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,13 +1,18 @@ ## Added - -1. Partial support for the `groups` argument for convolutional layers. +1. The `random_crossbar_init` argument to memtorch.bh.Crossbar. If true, this is used to initialize crossbars to random device conductances in between 1/Ron and 1/Roff. +2. `CUDA_device_idx` to `setup.py` to allow users to specify the `CUDA` device to use when installing `MemTorch` from source. +3. Implementations of CUDA accelerated passive crossbar programming routines for the 2021 Data-Driven model. +4. A BiBTeX entry, which can be used to cite the corresponding OSP paper. ## Fixed - -1. 
Patching procedure in `memtorch.mn.module.patch_model` and `memtorch.bh.nonideality.apply_nonidealities` to fix semantic error in `Tutorial.ipynb`. -2. Import statement in `Exemplar_Simulations.ipynb`. +1. In the getting started tutorial, Section 4.1 was a code cell. This has since been converted to a markdown cell. +2. OOM errors encountered when modeling passive inference routines of crossbars. ## Enhanced -1. Further modularized patching logic in `memtorch.bh.nonideality.NonIdeality` and `memtorch.mn.Module`. -2. Modified default number of worker in `memtorch.utils` from 2 to 1. +1. Templated quantize bindings and fixed semantic error in `memtorch.bh.nonideality.FiniteConductanceStates`. +2. The memory consumption when modeling passive inference routines. +3. The sparse factorization method used to solve sparse linear matrix systems. +4. The `naive_program` routine for crossbar programming. The maximum number of crossbar programming iterations is now configurable. +5. Updated ReadTheDocs documentation for `memtorch.bh.Crossbar`. +6. Updated the version of `PyTorch` used to build Python wheels from `1.9.0` to `1.10.0`. \ No newline at end of file diff --git a/README.md b/README.md index 458cd8c7..b8f83afe 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,9 @@ [![codecov](https://codecov.io/gh/coreylammie/MemTorch/branch/master/graph/badge.svg)](https://codecov.io/gh/coreylammie/MemTorch) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) -MemTorch is a _Simulation Framework for Memristive Deep Learning Systems_, which integrates directly with the well-known PyTorch Machine Learning (ML) library. MemTorch is formally described in _MemTorch: An Open-source Simulation Framework for Memristive Deep Learning Systems_, which is openly accessible [here](https://arxiv.org/abs/2004.10971). 
+MemTorch is a _Simulation Framework for Memristive Deep Learning Systems_, which integrates directly with the well-known PyTorch Machine Learning (ML) library. MemTorch is formally described in _MemTorch: An Open-source Simulation Framework for Memristive Deep Learning Systems_, which is openly accessible [here](https://arxiv.org/abs/2004.10971). + +_MemTorch: An Open-source Simulation Framework for Memristive Deep Learning Systems_ has been published as an Original Software Publication (OSP) in the *Neurocomputing* journal [here](https://doi.org/10.1016/j.neucom.2022.02.043). We kindly ask that the following [BibTeX entry](https://github.com/coreylammie/MemTorch/blob/master/citation.bib?raw=True) is used to cite MemTorch, if you use it in your work. ![Overview](https://github.com/coreylammie/MemTorch/blob/master/overview.svg) @@ -81,14 +83,15 @@ _Be sure to merge the latest from 'upstream' before making a pull request_. This To cite _MemTorch: An Open-source Simulation Framework for Memristive Deep Learning Systems_, use the following BibTex entry: ``` -@misc{lammie2020memtorch, - title={{MemTorch: An Open-source Simulation Framework for Memristive Deep Learning Systems}}, - author={Corey Lammie and Wei Xiang and Bernab\'e Linares-Barranco and Mostafa Rahimi Azghadi}, - month=Apr., - year={2020}, - eprint={2004.10971}, - archivePrefix={arXiv}, - primaryClass={cs.ET} +@Article{Lammie2022, + author = {Corey Lammie and Wei Xiang and Bernabé Linares-Barranco and Mostafa Rahimi Azghadi}, + title = {{MemTorch: An Open-source Simulation Framework for Memristive Deep Learning Systems}}, + journal = {Neurocomputing}, + year = {2022}, + issn = {0925-2312}, + doi = {https://doi.org/10.1016/j.neucom.2022.02.043}, + keywords = {Memristors, RRAM, Non-Ideal Device Characteristics, Deep Learning, Simulation Framework}, + url = {https://www.sciencedirect.com/science/article/pii/S0925231222002053}, } ``` diff --git a/citation.bib b/citation.bib new file mode 100644 index 
00000000..dc5b8fe6 --- /dev/null +++ b/citation.bib @@ -0,0 +1,10 @@ +@Article{Lammie2022, + author = {Corey Lammie and Wei Xiang and Bernabé Linares-Barranco and Mostafa Rahimi Azghadi}, + title = {{MemTorch: An Open-source Simulation Framework for Memristive Deep Learning Systems}}, + journal = {Neurocomputing}, + year = {2022}, + issn = {0925-2312}, + doi = {https://doi.org/10.1016/j.neucom.2022.02.043}, + keywords = {Memristors, RRAM, Non-Ideal Device Characteristics, Deep Learning, Simulation Framework}, + url = {https://www.sciencedirect.com/science/article/pii/S0925231222002053}, +} \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index e25b6467..06d0b0ac 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -17,11 +17,11 @@ # -- Project information ----------------------------------------------------- project = "MemTorch" -copyright = "2021, Corey Lammie" +copyright = "2022, Corey Lammie" author = "Corey Lammie" # The full version, including alpha/beta/rc tags -release = "1.1.5" +release = "1.1.6" autodoc_inherit_docstrings = False # -- General configuration --------------------------------------------------- diff --git a/docs/memtorch.bh.rst b/docs/memtorch.bh.rst index 444c3b3a..3c9f0421 100644 --- a/docs/memtorch.bh.rst +++ b/docs/memtorch.bh.rst @@ -30,6 +30,14 @@ Class used to model memristor crossbars and to manage modular crossbar tiles. .. note:: **use_bindings** is enabled by default, to accelerate operation using C++/CUDA (if supported) bindings. +.. warning:: + As of version 1.1.6, the **write_conductance_matrix** method exhibits different behavior when **self.use_bindings** is True, **CUDA** operation is enabled, and the **Data_Driven2021** memristor model is used. + + When **self.use_bindings** is True, **CUDA** operation is enabled, and the **Data_Driven2021** memristor model is used, the programming voltage is force adjusted by **force_adjustment_voltage** when a device becomes stuck. 
+ For all other models, or when **CUDA** operation is not enabled or **self.use_bindings** is false, the conductance state of the device being modelled is adjusted using **force_adjustment** when it becomes stuck. + + This behavior will be made consistent across Python, C++, and CUDA bindings in a future release. + .. automodule:: memtorch.bh.crossbar.Crossbar :members: :undoc-members: diff --git a/memtorch/bh/crossbar/Crossbar.py b/memtorch/bh/crossbar/Crossbar.py index 049c6a0f..7ccacac7 100644 --- a/memtorch/bh/crossbar/Crossbar.py +++ b/memtorch/bh/crossbar/Crossbar.py @@ -210,6 +210,7 @@ def write_conductance_matrix( ) else: raise Exception("Unsupported crossbar shape.") + if self.tile_shape is not None: conductance_matrix, tiles_map = gen_tiles( conductance_matrix, @@ -232,7 +233,7 @@ def write_conductance_matrix( conductance_matrix = torch.max( torch.min(conductance_matrix.to(self.device), max), min ) - if transistor or programming_routine is None: + if transistor: self.conductance_matrix = conductance_matrix self.max_abs_conductance = ( torch.abs(self.conductance_matrix).flatten().max() @@ -265,6 +266,9 @@ def write_conductance_matrix( ) self.update(from_devices=False) else: + assert ( + programming_routine is not None + ), "If memtorch_cuda_bindings.simulate_passive is not used, a programming routine must be provided." if self.tile_shape is not None: for i in range(0, self.devices.shape[0]): for j in range(0, self.devices.shape[1]): diff --git a/memtorch/bh/crossbar/Program.py b/memtorch/bh/crossbar/Program.py index 0d47f0b2..7c94e4fc 100644 --- a/memtorch/bh/crossbar/Program.py +++ b/memtorch/bh/crossbar/Program.py @@ -22,6 +22,7 @@ def naive_program( force_adjustment_rel_tol=1e-1, force_adjustment_pos_voltage_threshold=0, force_adjustment_neg_voltage_threshold=0, + failure_iteration_threshold=1000, simulate_neighbours=True, ): """Method to program (alter) the conductance of a given device within a crossbar. 
@@ -54,6 +55,8 @@ def naive_program( Positive voltage level threshold (V) to enable force adjustment. force_adjustment_neg_voltage_threshold : float Negative voltage level threshold (V) to enable force adjustment. + failure_iteration_threshold : int + Failure iteration threshold. simulate_neighbours : bool Simulate neighbours (True). @@ -142,7 +145,7 @@ def naive_program( ) iterations += 1 - if iterations % 100 == 0 and time.time() > timeout: + if iterations >= failure_iteration_threshold and time.time() > timeout: warnings.warn("Failed to program device to rel_tol (%f)." % rel_tol) break diff --git a/memtorch/cu/simulate_passive.cpp b/memtorch/cu/simulate_passive.cpp index 777aa7bd..fd21c9a8 100644 --- a/memtorch/cu/simulate_passive.cpp +++ b/memtorch/cu/simulate_passive.cpp @@ -10,100 +10,190 @@ #include "simulate_passive_kernels.cuh" -//Default values -std::vector r_p{2699.2336, -672.930205}; -std::vector r_n{649.413746, -1474.32358}; +// Default values for r_p and r_n parameters of the Data_Driven2021 model. +std::vector r_p{2699.2336, -672.930205}; // r_p voltage-dependent resistive boundary function coefficients. +std::vector r_n{649.413746, -1474.32358}; // r_n voltage-dependent resistive boundary function coefficients. 
void simulate_passive_bindings(py::module_ &m) { - - //Data_Driven2021 model + // Data_Driven2021 model m.def( "simulate_passive", - [&](at::Tensor conductance_matrix, at::Tensor device_matrix,int cuda_malloc_heap_size, float rel_tol, - float pulse_duration, float refactory_period, float pos_voltage_level, float neg_voltage_level, - float timeout, float force_adjustment, float force_adjustment_rel_tol, float force_adjustment_pos_voltage_threshold, - float force_adjustment_neg_voltage_threshold, float time_series_resolution , float r_off, float r_on, float A_p, float A_n, float t_p, float t_n, - float k_p, float k_n, std::vector r_p, std::vector r_n, float a_p, float a_n, float b_p, float b_n, bool sim_neighbors) { - return simulate_passive_dd(conductance_matrix, device_matrix,cuda_malloc_heap_size, rel_tol, - pulse_duration, refactory_period, pos_voltage_level, neg_voltage_level, - timeout, force_adjustment, force_adjustment_rel_tol, force_adjustment_pos_voltage_threshold, - force_adjustment_neg_voltage_threshold,time_series_resolution,r_off,r_on,A_p,A_n,t_p,t_n,k_p,k_n,r_p,r_n,a_p,a_n,b_p,b_n, sim_neighbors); + [&](at::Tensor conductance_matrix, at::Tensor device_matrix, + int cuda_malloc_heap_size, float rel_tol, float pulse_duration, + float refactory_period, float pos_voltage_level, + float neg_voltage_level, float timeout, float force_adjustment, + float force_adjustment_rel_tol, + float force_adjustment_pos_voltage_threshold, + float force_adjustment_neg_voltage_threshold, + float force_adjustment_voltage, + int failure_iteration_threshold, + float time_series_resolution, float r_off, float r_on, float A_p, + float A_n, float t_p, float t_n, float k_p, float k_n, + std::vector r_p, std::vector r_n, float a_p, float a_n, + float b_p, float b_n, bool sim_neighbors) { + return simulate_passive_dd( + conductance_matrix, device_matrix, cuda_malloc_heap_size, rel_tol, + pulse_duration, refactory_period, pos_voltage_level, + neg_voltage_level, timeout, 
force_adjustment, + force_adjustment_rel_tol, force_adjustment_pos_voltage_threshold, + force_adjustment_neg_voltage_threshold, + force_adjustment_voltage, failure_iteration_threshold, + time_series_resolution, + r_off, r_on, A_p, A_n, t_p, t_n, k_p, k_n, r_p, r_n, a_p, a_n, b_p, + b_n, sim_neighbors); }, - py::arg("conductance_matrix"), py::arg("device_matrix"),py::arg("cuda_malloc_heap_size")=50, py::arg("rel_tol")=0.1, - py::arg("pulse_duration") = 1e-3, py::arg("refactory_period") = 0, py::arg("pos_voltage_level") = 1.0, - py::arg("neg_voltage_level") = -1.0, py::arg("timeout") = 5, py::arg("force_adjustment") = 1e-3, - py::arg("force_adjustment_rel_tol") = 1e-1, py::arg("force_adjustment_pos_voltage_threshold") = 0, - py::arg("force_adjustment_neg_voltage_threshold") = 0, py::arg("time_series_resolution") = 1e-10, py::arg("r_off") = 10000, py::arg("r_on") = 1000, py::arg("A_p") = 600.10075, - py::arg("A_n")=-34.5988399, py::arg("t_p") = -0.0212028, py::arg("t_n") = -0.05343997, py::arg("k_p") = 5.11e-4, py::arg("k_n") = 1.17e-3, - py::arg("r_p") = r_p, py::arg("r_n") = r_n, py::arg("a_p")=0.32046175, - py::arg("a_n")=0.32046175, py::arg("b_p")=2.71689828, py::arg("b_n")=2.71689828, py::arg("simulate_neighbours") = true); //Maybe change order of simulate_neighbours to before memristor args + py::arg("conductance_matrix"), py::arg("device_matrix"), + py::arg("cuda_malloc_heap_size") = 50, py::arg("rel_tol") = 0.1, + py::arg("pulse_duration") = 1e-3, py::arg("refactory_period") = 0, + py::arg("pos_voltage_level") = 1.0, py::arg("neg_voltage_level") = -1.0, + py::arg("timeout") = 5, py::arg("force_adjustment") = 1e-3, + py::arg("force_adjustment_rel_tol") = 1e-1, + py::arg("force_adjustment_pos_voltage_threshold") = 0, + py::arg("force_adjustment_neg_voltage_threshold") = 0, + py::arg("force_adjustment_voltage") = 0.2, + py::arg("failure_iteration_threshold") = 1000, + py::arg("time_series_resolution") = 1e-10, py::arg("r_off") = 10000, + py::arg("r_on") = 1000, 
py::arg("A_p") = 600.10075, + py::arg("A_n") = -34.5988399, py::arg("t_p") = -0.0212028, + py::arg("t_n") = -0.05343997, py::arg("k_p") = 5.11e-4, + py::arg("k_n") = 1.17e-3, py::arg("r_p") = r_p, py::arg("r_n") = r_n, + py::arg("a_p") = 0.32046175, py::arg("a_n") = 0.32046175, + py::arg("b_p") = 2.71689828, py::arg("b_n") = 2.71689828, + py::arg("simulate_neighbours") = + true); - //Linear Ion Drift + // Linear Ion Drift m.def( "simulate_passive", - [&](at::Tensor conductance_matrix, at::Tensor device_matrix,int cuda_malloc_heap_size, float rel_tol, - float pulse_duration, float refactory_period, float pos_voltage_level, float neg_voltage_level, - float timeout, float force_adjustment, float force_adjustment_rel_tol, float force_adjustment_pos_voltage_threshold, - float force_adjustment_neg_voltage_threshold, float time_series_resolution , float r_off, float r_on, float u_v, - float d,float pos_write_threshold, float neg_write_threshold, float p, bool sim_neighbors) { - return simulate_passive_linearIonDrift(conductance_matrix, device_matrix,cuda_malloc_heap_size, rel_tol, - pulse_duration, refactory_period, pos_voltage_level, neg_voltage_level, - timeout, force_adjustment, force_adjustment_rel_tol, force_adjustment_pos_voltage_threshold, - force_adjustment_neg_voltage_threshold,time_series_resolution,r_off,r_on, u_v, - d, pos_write_threshold, neg_write_threshold, p,sim_neighbors); + [&](at::Tensor conductance_matrix, at::Tensor device_matrix, + int cuda_malloc_heap_size, float rel_tol, float pulse_duration, + float refactory_period, float pos_voltage_level, + float neg_voltage_level, float timeout, float force_adjustment, + float force_adjustment_rel_tol, + float force_adjustment_pos_voltage_threshold, + float force_adjustment_neg_voltage_threshold, + float force_adjustment_voltage, + int failure_iteration_threshold, + float time_series_resolution, float r_off, float r_on, float u_v, + float d, float pos_write_threshold, float neg_write_threshold, + float p, bool 
sim_neighbors) { + return simulate_passive_linearIonDrift( + conductance_matrix, device_matrix, cuda_malloc_heap_size, rel_tol, + pulse_duration, refactory_period, pos_voltage_level, + neg_voltage_level, timeout, force_adjustment, + force_adjustment_rel_tol, force_adjustment_pos_voltage_threshold, + force_adjustment_neg_voltage_threshold, + force_adjustment_voltage, failure_iteration_threshold, + time_series_resolution, + r_off, r_on, u_v, d, pos_write_threshold, neg_write_threshold, p, + sim_neighbors); }, - py::arg("conductance_matrix"), py::arg("device_matrix"),py::arg("cuda_malloc_heap_size")=50, py::arg("rel_tol")=0.1, - py::arg("pulse_duration") = 1e-3, py::arg("refactory_period") = 0, py::arg("pos_voltage_level") = 1.0, - py::arg("neg_voltage_level") = -1.0, py::arg("timeout") = 5, py::arg("force_adjustment") = 1e-3, - py::arg("force_adjustment_rel_tol") = 1e-1, py::arg("force_adjustment_pos_voltage_threshold") = 0, - py::arg("force_adjustment_neg_voltage_threshold") = 0, py::arg("time_series_resolution") = 1e-4, py::arg("r_off") = 10000, py::arg("r_on") = 1000, py::arg("u_v") = 1e-14, - py::arg("d") = 10e-9, py::arg("pos_write_threshold") = 0.55, py::arg("neg_write_threshold") = -0.55, py::arg("p") = 1, py::arg("simulate_neighbours") = true); + py::arg("conductance_matrix"), py::arg("device_matrix"), + py::arg("cuda_malloc_heap_size") = 50, py::arg("rel_tol") = 0.1, + py::arg("pulse_duration") = 1e-3, py::arg("refactory_period") = 0, + py::arg("pos_voltage_level") = 1.0, py::arg("neg_voltage_level") = -1.0, + py::arg("timeout") = 5, py::arg("force_adjustment") = 1e-3, + py::arg("force_adjustment_rel_tol") = 1e-1, + py::arg("force_adjustment_pos_voltage_threshold") = 0, + py::arg("force_adjustment_neg_voltage_threshold") = 0, + py::arg("force_adjustment_voltage") = 0.2, + py::arg("failure_iteration_threshold") = 1000, + py::arg("time_series_resolution") = 1e-4, py::arg("r_off") = 10000, + py::arg("r_on") = 1000, py::arg("u_v") = 1e-14, py::arg("d") = 10e-9, 
+ py::arg("pos_write_threshold") = 0.55, + py::arg("neg_write_threshold") = -0.55, py::arg("p") = 1, + py::arg("simulate_neighbours") = true); - //VTEAM - m.def( + // VTEAM + m.def( "simulate_passive", - [&](at::Tensor conductance_matrix, at::Tensor device_matrix,int cuda_malloc_heap_size, float rel_tol, - float pulse_duration, float refactory_period, float pos_voltage_level, float neg_voltage_level, - float timeout, float force_adjustment, float force_adjustment_rel_tol, float force_adjustment_pos_voltage_threshold, - float force_adjustment_neg_voltage_threshold, float time_series_resolution , float r_off, float r_on, float d, - float k_on, float k_off, float alpha_on, float alpha_off, float v_on, float v_off, float x_on, float x_off, bool sim_neighbors) { - return simulate_passive_VTEAM(conductance_matrix, device_matrix,cuda_malloc_heap_size, rel_tol, - pulse_duration, refactory_period, pos_voltage_level, neg_voltage_level, - timeout, force_adjustment, force_adjustment_rel_tol, force_adjustment_pos_voltage_threshold, - force_adjustment_neg_voltage_threshold,time_series_resolution,r_off,r_on,d, - k_on, k_off, alpha_on, alpha_off, v_on, v_off, x_on, x_off, sim_neighbors); + [&](at::Tensor conductance_matrix, at::Tensor device_matrix, + int cuda_malloc_heap_size, float rel_tol, float pulse_duration, + float refactory_period, float pos_voltage_level, + float neg_voltage_level, float timeout, float force_adjustment, + float force_adjustment_rel_tol, + float force_adjustment_pos_voltage_threshold, + float force_adjustment_neg_voltage_threshold, + float force_adjustment_voltage, + int failure_iteration_threshold, + float time_series_resolution, float r_off, float r_on, float d, + float k_on, float k_off, float alpha_on, float alpha_off, float v_on, + float v_off, float x_on, float x_off, bool sim_neighbors) { + return simulate_passive_VTEAM( + conductance_matrix, device_matrix, cuda_malloc_heap_size, rel_tol, + pulse_duration, refactory_period, pos_voltage_level, + 
neg_voltage_level, timeout, force_adjustment, + force_adjustment_rel_tol, force_adjustment_pos_voltage_threshold, + force_adjustment_neg_voltage_threshold, + force_adjustment_voltage, failure_iteration_threshold, + time_series_resolution, + r_off, r_on, d, k_on, k_off, alpha_on, alpha_off, v_on, v_off, x_on, + x_off, sim_neighbors); }, - py::arg("conductance_matrix"), py::arg("device_matrix"),py::arg("cuda_malloc_heap_size")=50, py::arg("rel_tol")=0.1, - py::arg("pulse_duration") = 1e-3, py::arg("refactory_period") = 0, py::arg("pos_voltage_level") = 1.0, - py::arg("neg_voltage_level") = -1.0, py::arg("timeout") = 5, py::arg("force_adjustment") = 1e-3, - py::arg("force_adjustment_rel_tol") = 1e-1, py::arg("force_adjustment_pos_voltage_threshold") = 0, - py::arg("force_adjustment_neg_voltage_threshold") = 0, py::arg("time_series_resolution") = 1e-10, py::arg("r_off") = 10000, py::arg("r_on") = 1000, py::arg("d") = 3e-9, - py::arg("k_on") =-10, py::arg("k_off") = 5e-4, py::arg("alpha_on") =3, py::arg("alpha_off") = 1, py::arg("v_on") = 0.2, py::arg("v_off") = 0.02, py::arg("x_on") = 0, - py::arg("x_off") = 3e-9, py::arg("simulate_neighbours") = true); + py::arg("conductance_matrix"), py::arg("device_matrix"), + py::arg("cuda_malloc_heap_size") = 50, py::arg("rel_tol") = 0.1, + py::arg("pulse_duration") = 1e-3, py::arg("refactory_period") = 0, + py::arg("pos_voltage_level") = 1.0, py::arg("neg_voltage_level") = -1.0, + py::arg("timeout") = 5, py::arg("force_adjustment") = 1e-3, + py::arg("force_adjustment_rel_tol") = 1e-1, + py::arg("force_adjustment_pos_voltage_threshold") = 0, + py::arg("force_adjustment_neg_voltage_threshold") = 0, + py::arg("force_adjustment_voltage") = 0.2, + py::arg("failure_iteration_threshold") = 1000, + py::arg("time_series_resolution") = 1e-10, py::arg("r_off") = 10000, + py::arg("r_on") = 1000, py::arg("d") = 3e-9, py::arg("k_on") = -10, + py::arg("k_off") = 5e-4, py::arg("alpha_on") = 3, + py::arg("alpha_off") = 1, py::arg("v_on") = 0.2, 
py::arg("v_off") = 0.02, + py::arg("x_on") = 0, py::arg("x_off") = 3e-9, + py::arg("simulate_neighbours") = true); - //Stanford_PKU - m.def( + // Stanford_PKU + m.def( "simulate_passive", - [&](at::Tensor conductance_matrix, at::Tensor device_matrix,int cuda_malloc_heap_size, float rel_tol, - float pulse_duration, float refactory_period, float pos_voltage_level, float neg_voltage_level, - float timeout, float force_adjustment, float force_adjustment_rel_tol, float force_adjustment_pos_voltage_threshold, - float force_adjustment_neg_voltage_threshold, float time_series_resolution , float r_off, float r_on, float gap_init, - float g_0, float V_0, float I_0, float read_voltage, float T_init, float R_th, float gamma_init, - float beta, float t_ox, float F_min, float vel_0, float E_a, float a_0, float delta_g_init, + [&](at::Tensor conductance_matrix, at::Tensor device_matrix, + int cuda_malloc_heap_size, float rel_tol, float pulse_duration, + float refactory_period, float pos_voltage_level, + float neg_voltage_level, float timeout, float force_adjustment, + float force_adjustment_rel_tol, + float force_adjustment_pos_voltage_threshold, + float force_adjustment_neg_voltage_threshold, + float force_adjustment_voltage, + int failure_iteration_threshold, + float time_series_resolution, float r_off, float r_on, float gap_init, + float g_0, float V_0, float I_0, float read_voltage, float T_init, + float R_th, float gamma_init, float beta, float t_ox, float F_min, + float vel_0, float E_a, float a_0, float delta_g_init, float model_switch, float T_crit, float T_smth, bool sim_neighbors) { - return simulate_passive_Stanford_PKU(conductance_matrix, device_matrix,cuda_malloc_heap_size, rel_tol, - pulse_duration, refactory_period, pos_voltage_level, neg_voltage_level, - timeout, force_adjustment, force_adjustment_rel_tol, force_adjustment_pos_voltage_threshold, - force_adjustment_neg_voltage_threshold,time_series_resolution,r_off,r_on, gap_init, - g_0, V_0, I_0, read_voltage, 
T_init, R_th, gamma_init, beta, t_ox, F_min, vel_0, E_a, a_0, - delta_g_init, model_switch, T_crit, T_smth, sim_neighbors); + return simulate_passive_Stanford_PKU( + conductance_matrix, device_matrix, cuda_malloc_heap_size, rel_tol, + pulse_duration, refactory_period, pos_voltage_level, + neg_voltage_level, timeout, force_adjustment, + force_adjustment_rel_tol, force_adjustment_pos_voltage_threshold, + force_adjustment_neg_voltage_threshold, + force_adjustment_voltage, failure_iteration_threshold, + time_series_resolution, + r_off, r_on, gap_init, g_0, V_0, I_0, read_voltage, T_init, R_th, + gamma_init, beta, t_ox, F_min, vel_0, E_a, a_0, delta_g_init, + model_switch, T_crit, T_smth, sim_neighbors); }, - py::arg("conductance_matrix"), py::arg("device_matrix"),py::arg("cuda_malloc_heap_size")=50, py::arg("rel_tol")=0.1, - py::arg("pulse_duration") = 1e-3, py::arg("refactory_period") = 0, py::arg("pos_voltage_level") = 1.0, - py::arg("neg_voltage_level") = -1.0, py::arg("timeout") = 5, py::arg("force_adjustment") = 1e-3, - py::arg("force_adjustment_rel_tol") = 1e-1, py::arg("force_adjustment_pos_voltage_threshold") = 0, - py::arg("force_adjustment_neg_voltage_threshold") = 0, py::arg("time_series_resolution") = 1e-10, py::arg("r_off") = 10000, py::arg("r_on") = 1000, py::arg("gap_init") = 2e-10, - py::arg("g_0") = 0.25e-9, py::arg("V_0") = 0.25, py::arg("I_0") = 1000e-6, py::arg("read_voltage") = 0.1, py::arg("T_init") = 298, py::arg("R_th") = 2.1e3, - py::arg("gamma_init") = 16, py::arg("beta") = 0.8, py::arg("t_ox") = 12e-9,py::arg("F_min") = 1.4e9, py::arg("vel_0") = 10, py::arg("E_a") = 0.6, py::arg("a_0") = 0.25e-9, - py::arg("delta_g_init") = 0.02, py::arg("model_switch") = 0, py::arg("T_crit") = 450, py::arg("T_smth") = 500, py::arg("simulate_neighbours") = true); + py::arg("conductance_matrix"), py::arg("device_matrix"), + py::arg("cuda_malloc_heap_size") = 50, py::arg("rel_tol") = 0.1, + py::arg("pulse_duration") = 1e-3, py::arg("refactory_period") = 0, + 
py::arg("pos_voltage_level") = 1.0, py::arg("neg_voltage_level") = -1.0, + py::arg("timeout") = 5, py::arg("force_adjustment") = 1e-3, + py::arg("force_adjustment_rel_tol") = 1e-1, + py::arg("force_adjustment_pos_voltage_threshold") = 0, + py::arg("force_adjustment_neg_voltage_threshold") = 0, + py::arg("force_adjustment_voltage") = 0.2, + py::arg("failure_iteration_threshold") = 1000, + py::arg("time_series_resolution") = 1e-10, py::arg("r_off") = 10000, + py::arg("r_on") = 1000, py::arg("gap_init") = 2e-10, + py::arg("g_0") = 0.25e-9, py::arg("V_0") = 0.25, py::arg("I_0") = 1000e-6, + py::arg("read_voltage") = 0.1, py::arg("T_init") = 298, + py::arg("R_th") = 2.1e3, py::arg("gamma_init") = 16, + py::arg("beta") = 0.8, py::arg("t_ox") = 12e-9, py::arg("F_min") = 1.4e9, + py::arg("vel_0") = 10, py::arg("E_a") = 0.6, py::arg("a_0") = 0.25e-9, + py::arg("delta_g_init") = 0.02, py::arg("model_switch") = 0, + py::arg("T_crit") = 450, py::arg("T_smth") = 500, + py::arg("simulate_neighbours") = true); } \ No newline at end of file diff --git a/memtorch/cu/simulate_passive.cuh b/memtorch/cu/simulate_passive.cuh deleted file mode 100644 index 65ec0cc3..00000000 --- a/memtorch/cu/simulate_passive.cuh +++ /dev/null @@ -1 +0,0 @@ -//unused so far diff --git a/memtorch/cu/simulate_passive_kernels.cu b/memtorch/cu/simulate_passive_kernels.cu index d1ff7f71..efaa40ea 100644 --- a/memtorch/cu/simulate_passive_kernels.cu +++ b/memtorch/cu/simulate_passive_kernels.cu @@ -1,13 +1,13 @@ #include "cuda_runtime.h" #include #include +#include #include #include #include +#include #include #include -#include -#include #include #include @@ -20,28 +20,24 @@ __constant__ float positive_voltage; __constant__ float negative_voltage; - -__constant__ int NX; //number of rows -__constant__ int NY; //number of columns -__constant__ int NZ; //number of tiles - -__constant__ float tsr_global; //time series resolution +__constant__ int NX; // number of rows +__constant__ int NY; // number of columns 
+__constant__ int NZ; // number of tiles +__constant__ float tsr_global; // time series resolution __constant__ float r_off_global; __constant__ float r_on_global; -__constant__ float pulse_dur_global; //pulse duration -__constant__ float writing_tol_global; //rel_tol -__constant__ float res_adjustment; //adjustment made if the resistance is stuck +__constant__ float pulse_dur_global; // pulse duration +__constant__ float writing_tol_global; // rel_tol +__constant__ float res_adjustment; // adjustment made if the resistance is stuck __constant__ float res_adjustment_rel_tol; -//Data_Driven Global variables +// Data_Driven2021 model global variables __constant__ float t_n_global; __constant__ float t_p_global; - __constant__ float s_n_global; __constant__ float s_p_global; __constant__ float r_p_global; __constant__ float r_n_global; - __constant__ float r_p_0; __constant__ float r_p_1; __constant__ float r_n_0; @@ -49,15 +45,14 @@ __constant__ float r_n_1; __constant__ float A_p_global; __constant__ float A_n_global; - - /** - * Cuda kernel to simulate the devices with neighbors + * CUDA kernel to simulate the devices with neighbors * * @param device_matrix the device matrix of conductances * @param current_i the position in i for the current device to simulate * @param current_j the position in j for the current device to simulate - * @param instruction_array array of integers corresponding to the current instruction for the tile being programmed + * @param instruction_array array of integers corresponding to the current + * instruction for the tile being programmed * @param r_n_arr array of r_n values shared across devices * @param s_n_arr array of s_n values shared across devices * @param r_p_arr array of r_n values shared across devices @@ -68,64 +63,69 @@ __constant__ float A_n_global; * @param s_p_half_arr array of s_n/2 values shared across devices */ -__global__ void simulate_device_dd(float *device_matrix, int current_i, int current_j, int 
*instruction_array, float *r_n_arr, float *s_n_arr, float *r_n_half_arr, float *s_n_half_arr, float *r_p_arr, float *s_p_arr, float *r_p_half_arr, float *s_p_half_arr) -{ +__global__ void simulate_device_dd(float *device_matrix, int current_i, + int current_j, int *instruction_array, + float *r_n_arr, float *s_n_arr, + float *r_n_half_arr, float *s_n_half_arr, + float *r_p_arr, float *s_p_arr, + float *r_p_half_arr, float *s_p_half_arr) { int i = threadIdx.x + blockIdx.x * blockDim.x; // for (int i = 0; i < NX; i++) int j = threadIdx.y + blockIdx.y * blockDim.y; // for (int j = 0; j < NY; j++) int k = threadIdx.z + blockIdx.z * blockDim.z; // for (int k = 0; k < NZ; k++) - if (i < NX && j < NY && k < NZ) - { + if (i < NX && j < NY && k < NZ) { float resistance_; int index = (k * NX * NY) + (j * NX) + i; - if (i == current_i && j == current_j) //if it is the device to program + if (i == current_i && j == current_j) // if it is the device to program { float R0 = 1 / device_matrix[index]; - if (instruction_array[k] == 1) - { - resistance_ = (R0 + (s_p_arr[k] * r_p_arr[k] * (r_p_arr[k] - R0)) * pulse_dur_global) / (1 + s_p_arr[k] * (r_p_arr[k] - R0) * pulse_dur_global); - if(resistance_ < r_p_arr[k]){ - resistance_ = R0; + if (instruction_array[k] == 1) { + resistance_ = (R0 + (s_p_arr[k] * r_p_arr[k] * (r_p_arr[k] - R0)) * + pulse_dur_global) / + (1 + s_p_arr[k] * (r_p_arr[k] - R0) * pulse_dur_global); + if (resistance_ < r_p_arr[k]) { + resistance_ = R0; } - } - else if (instruction_array[k] == 2) - { - resistance_ = (R0 + (s_n_arr[k] * r_n_arr[k] * (r_n_arr[k] - R0)) * pulse_dur_global) / (1 + s_n_arr[k] * (r_n_arr[k] - R0) * pulse_dur_global); - if(resistance_ > r_n_arr[k]){ - resistance_ = R0; + } else if (instruction_array[k] == 2) { + resistance_ = (R0 + (s_n_arr[k] * r_n_arr[k] * (r_n_arr[k] - R0)) * + pulse_dur_global) / + (1 + s_n_arr[k] * (r_n_arr[k] - R0) * pulse_dur_global); + if (resistance_ > r_n_arr[k]) { + resistance_ = R0; } } - if 
(instruction_array[k] != 0) - { - if(resistance_ >= R0 - res_adjustment_rel_tol*R0 && resistance_ <= R0 + res_adjustment_rel_tol*R0){ - if(instruction_array[k] == 2 && resistance_ < r_off_global) - resistance_ += res_adjustment*resistance_; - else if(instruction_array[k] == 1 && resistance_ > r_on_global) - resistance_ -= res_adjustment*resistance_; + if (instruction_array[k] != 0) { + if (resistance_ >= R0 - res_adjustment_rel_tol * R0 && + resistance_ <= R0 + res_adjustment_rel_tol * R0) { + if (instruction_array[k] == 2 && resistance_ < r_off_global) + resistance_ += res_adjustment * resistance_; + else if (instruction_array[k] == 1 && resistance_ > r_on_global) + resistance_ -= res_adjustment * resistance_; } } - } - else if (i == current_i || j == current_j) //if the device is in the same row or column + } else if (i == current_i || + j == current_j) // if the device is in the same row or column { float R0 = 1 / device_matrix[index]; - if (instruction_array[k] == 1) - { - resistance_ = (R0 + (s_p_half_arr[k] * r_p_half_arr[k] * (r_p_half_arr[k] - R0)) * pulse_dur_global) / (1 + s_p_half_arr[k] * (r_p_half_arr[k] - R0) * pulse_dur_global); - } - else if (instruction_array[k] == 2) - { - resistance_ = (R0 + (s_n_half_arr[k] * r_n_half_arr[k] * (r_n_half_arr[k] - R0)) * pulse_dur_global) / (1 + s_n_half_arr[k] * (r_n_half_arr[k] - R0) * pulse_dur_global); + if (instruction_array[k] == 1) { + resistance_ = + (R0 + (s_p_half_arr[k] * r_p_half_arr[k] * (r_p_half_arr[k] - R0)) * + pulse_dur_global) / + (1 + s_p_half_arr[k] * (r_p_half_arr[k] - R0) * pulse_dur_global); + } else if (instruction_array[k] == 2) { + resistance_ = + (R0 + (s_n_half_arr[k] * r_n_half_arr[k] * (r_n_half_arr[k] - R0)) * + pulse_dur_global) / + (1 + s_n_half_arr[k] * (r_n_half_arr[k] - R0) * pulse_dur_global); } } - if((i == current_i && j == current_j) || (i == current_i || j == current_j)){ - if (instruction_array[k] != 0) - { - //Check to ensure that the resistance remains within possible 
range - if (resistance_ > r_off_global) - { + if ((i == current_i && j == current_j) || + (i == current_i || j == current_j)) { + if (instruction_array[k] != 0) { + // Check to ensure that the resistance remains within possible range + if (resistance_ > r_off_global) { resistance_ = r_off_global; } - if (resistance_ < r_on_global) - { + if (resistance_ < r_on_global) { resistance_ = r_on_global; } device_matrix[index] = 1 / resistance_; @@ -135,105 +135,127 @@ __global__ void simulate_device_dd(float *device_matrix, int current_i, int curr } /** - * Cuda kernel to simulate the devices without neighbors + * CUDA kernel to simulate the devices without neighbors * * @param device_matrix the device matrix of conductances * @param conductance_matrix the matrix of target conductances * @param force_adjustment_pos_voltage_threshold maximum positive voltage * @param force_adjustment_neg_voltage_threshold minimum negative voltage - + * @param force_adjustment_voltage the amplitude of the force voltage adjustment + * @param failure_iteration_threshold the number of iterations to attempt to program a device for */ -__global__ void simulate_device_dd_no_neighbours(float *device_matrix,float *conductance_matrix,float force_adjustment_pos_voltage_threshold,float force_adjustment_neg_voltage_threshold) -{ +__global__ void +simulate_device_dd_no_neighbours(float *device_matrix, + float *conductance_matrix, + float force_adjustment_pos_voltage_threshold, + float force_adjustment_neg_voltage_threshold, + float force_adjustment_voltage, + int failure_iteration_threshold + ) { int i = threadIdx.x + blockIdx.x * blockDim.x; // for (int i = 0; i < i; i++) int j = threadIdx.y + blockIdx.y * blockDim.y; // for (int j = 0; j < j; j++) int k = threadIdx.z + blockIdx.z * blockDim.z; // for (int k = 0; k < z; k++) - if (i < NX && j < NY && k < NZ) - { - int index = (k * NX * NY) + (j * NX) + i; - float R0 = 1 / device_matrix[index]; - float target_R = 1 / conductance_matrix[index]; - float 
resistance_; - float s_n; - float r_n; - float s_p; - float r_p; - float neg_voltage_level = negative_voltage; - float pos_voltage_level = positive_voltage; - int iterations = 0; - while ((R0 < target_R - writing_tol_global*target_R || R0 > target_R + writing_tol_global*target_R) && iterations < 1000) - { - iterations += 1; - if(R0 < target_R - writing_tol_global*target_R) - { - s_n = A_n_global * (exp(abs(neg_voltage_level) / t_n_global) - 1); - r_n = r_n_0 + r_n_1 * neg_voltage_level; - resistance_ = (R0 + (s_n * r_n * (r_n - R0)) * pulse_dur_global) / (1 + s_n * (r_n - R0) * pulse_dur_global); - if(resistance_ > r_n) - { - resistance_ = R0; - } - pos_voltage_level = positive_voltage; - if(neg_voltage_level > force_adjustment_neg_voltage_threshold){ - neg_voltage_level -= 0.02; - } - if(resistance_ >= R0 - res_adjustment_rel_tol*R0 && resistance_ <= R0 + res_adjustment_rel_tol*R0){ - resistance_ += res_adjustment*resistance_; - } - R0 = resistance_; - } - else if (R0 > target_R + writing_tol_global*target_R) - { - s_p = A_p_global * (exp(abs(pos_voltage_level) / t_p_global) - 1); - r_p = r_p_0 + r_p_1 * pos_voltage_level; - resistance_ = (R0 + (s_p * r_p * (r_p - R0)) * pulse_dur_global) / (1 + s_p * (r_p - R0) * pulse_dur_global); - neg_voltage_level = negative_voltage; - if (resistance_ < r_p) - { - resistance_ = R0; // Artificially confine the resistance between r_on and r_off - } - if(resistance_ >= R0 - res_adjustment_rel_tol*R0 && resistance_ <= R0 + res_adjustment_rel_tol*R0){ - resistance_ -= res_adjustment*resistance_; - } - if(pos_voltage_level < force_adjustment_pos_voltage_threshold){ - pos_voltage_level += 0.02; - } - R0 = resistance_; - } - } - //Check to ensure that the resistance remains within possible range - if (R0 > r_off_global) - { - R0 = r_off_global; - } - if (R0 < r_on_global) - { - R0 = r_on_global; + if (i < NX && j < NY && k < NZ) { + int index = (k * NX * NY) + (j * NX) + i; + float R0 = 1 / device_matrix[index]; + float target_R = 1 
/ conductance_matrix[index]; + float resistance_; + float s_n; + float r_n; + float s_p; + float r_p; + float neg_voltage_level = negative_voltage; + float pos_voltage_level = positive_voltage; + int iterations = 0; + while ((R0 < target_R - writing_tol_global * target_R || + R0 > target_R + writing_tol_global * target_R) && + iterations < failure_iteration_threshold) { + iterations += 1; + if (R0 < target_R - writing_tol_global * target_R) { + s_n = A_n_global * (exp(abs(neg_voltage_level) / t_n_global) - 1); + r_n = r_n_0 + r_n_1 * neg_voltage_level; + resistance_ = (R0 + (s_n * r_n * (r_n - R0)) * pulse_dur_global) / + (1 + s_n * (r_n - R0) * pulse_dur_global); + if (resistance_ > r_n) { + resistance_ = R0; } - device_matrix[index] = 1/R0; -} + pos_voltage_level = positive_voltage; + if (neg_voltage_level > force_adjustment_neg_voltage_threshold) { + neg_voltage_level -= force_adjustment_voltage; + } + if (resistance_ >= R0 - res_adjustment_rel_tol * R0 && + resistance_ <= R0 + res_adjustment_rel_tol * R0) { + resistance_ += res_adjustment * resistance_; + } + R0 = resistance_; + } else if (R0 > target_R + writing_tol_global * target_R) { + s_p = A_p_global * (exp(abs(pos_voltage_level) / t_p_global) - 1); + r_p = r_p_0 + r_p_1 * pos_voltage_level; + resistance_ = (R0 + (s_p * r_p * (r_p - R0)) * pulse_dur_global) / + (1 + s_p * (r_p - R0) * pulse_dur_global); + neg_voltage_level = negative_voltage; + if (resistance_ < r_p) { + resistance_ = + R0; // Artificially confine the resistance between r_on and r_off + } + if (resistance_ >= R0 - res_adjustment_rel_tol * R0 && + resistance_ <= R0 + res_adjustment_rel_tol * R0) { + resistance_ -= res_adjustment * resistance_; + } + if (pos_voltage_level < force_adjustment_pos_voltage_threshold) { + pos_voltage_level += force_adjustment_voltage; + } + R0 = resistance_; + } + } + // Check to ensure that the resistance remains within possible range + if (R0 > r_off_global) { + R0 = r_off_global; + } + if (R0 < r_on_global) { 
+ R0 = r_on_global; + } + device_matrix[index] = 1 / R0; + } } /** - * Simulates passive crossbar programming using CUDA. The amplitudes of the voltages are increased or decreased by 0.02V (TODO: make this a parameter) every time the desired resistance is not achieved - * The voltages are reset to their initial values if the desired resistance is passed. - * + * Simulates passive crossbar programming using CUDA. The amplitudes of the + * voltages are increased or decreased by force_adjustment_voltage + * every time the desired resistance is not achieved The voltages are reset to + * their initial values if the desired resistance is passed. * * @param conductance_matrix * @param device_matrix * @param cuda_malloc_heap_size * @param rel_tol acceptable tolerance on the achieved tolerance value * @param pulse_duration duration of the pulses - * @param cuda_malloc_heap_size maximum heap size for Cuda - * @param refactory_period refactory period in between every pulse sent //Not currently used// + * @param cuda_malloc_heap_size maximum heap size for CUDA + * @param refactory_period refactory period in between every pulse sent + * //Not currently used// * @param pos_voltage_level initial positive voltage level * @param neg_voltage_level initial negative voltage level - * @param timeout timeout in between every pulse sent //Not currently used// - * @param force_adjustment percentage of the current resistance value to artificially force towards target resistance if the programming has not changed the resistance significantly (+/- force_adjustment*current_resistance) - * @param force_adjustment_rel_tol percentage of the previous resistance used in determining if the resistance has not changed enough in one cycle to warrant forced adjustment using parameter force_adjustment (if current_resistance previous_resistance +/- force_adjustment_rel_tol*previous_resistance where the signs depend on polarity ) - * @param force_adjustment_pos_voltage_threshold Maximum voltage that the 
incrementation of 0.02V can lead to (voltage will always be lower than this even with incrementation) - * @param force_adjustment_neg_voltage_threshold Minimum voltage that the decrementation of 0.02V can lead to (voltage will always be higher than this even with decrementation) - * @param time_series_resolution time series resolution used in the simulation //Not currently used// + * @param timeout timeout in between every pulse sent + * //Not currently used// + * @param force_adjustment percentage of the current resistance value to + * artificially force towards target resistance if the programming has not + * changed the resistance significantly (+/- + * force_adjustment * current_resistance) + * @param force_adjustment_rel_tol percentage of the previous resistance used in + * determining if the resistance has not changed enough in one cycle to warrant + * forced adjustment using parameter force_adjustment (if current_resistance + * previous_resistance +/- force_adjustment_rel_tol * previous_resistance where + * the signs depend on polarity) + * @param force_adjustment_pos_voltage_threshold Maximum voltage that the + * incrementation of force_adjustment_voltage can lead to + (the voltage will always be lower than this, even with incrementation) + * @param force_adjustment_neg_voltage_threshold Minimum voltage that the + * decrementation of force_adjustment_voltage can lead to + * (the voltage will always be higher than this, even with decrementation) + * @param force_adjustment_voltage the amplitude of the force voltage adjustment + * @param failure_iteration_threshold the number of iterations to attempt to program a device for + * @param time_series_resolution time series resolution used in the simulation + * //Not currently used// * @param r_off Resistance at HRS * @param r_on Resistance at LRS * @param A_p A_p parameter of Data Driven model @@ -248,277 +270,333 @@ __global__ void simulate_device_dd_no_neighbours(float *device_matrix,float *con * @param a_n r_n 
parameter of Data Driven model * @param b_p r_p parameter of Data Driven model * @param b_n r_n parameter of Data Driven model - * @param simulate_neighbors boolean to determine if neighbor simulation is necessary + * @param simulate_neighbors boolean to determine if neighbor simulation is + * necessary * @return Tensor of new devices */ -at::Tensor simulate_passive_dd(at::Tensor conductance_matrix, at::Tensor device_matrix,int cuda_malloc_heap_size, float rel_tol, - float pulse_duration, float refactory_period, float pos_voltage_level, float neg_voltage_level, - float timeout, float force_adjustment, float force_adjustment_rel_tol, float force_adjustment_pos_voltage_threshold, - float force_adjustment_neg_voltage_threshold, float time_series_resolution, float r_off, float r_on, float A_p, float A_n, float t_p, float t_n, - float k_p, float k_n, std::vector r_p, std::vector r_n, float a_p, float a_n, float b_p, float b_n, bool sim_neighbors) -{ - - assert(at::cuda::is_available()); - //Assign global variables their value - float original_pos_voltage = pos_voltage_level; - float original_neg_voltage = neg_voltage_level; - const size_t sz = sizeof(float); - const size_t si = sizeof(int); - cudaDeviceSetLimit(cudaLimitMallocHeapSize, - size_t(1024) * size_t(1024) * - size_t(cuda_malloc_heap_size)); - cudaSafeCall(cudaMemcpyToSymbol(res_adjustment_rel_tol, &force_adjustment_rel_tol, sz, size_t(0), cudaMemcpyHostToDevice)); - cudaSafeCall(cudaMemcpyToSymbol(res_adjustment, &force_adjustment, sz, size_t(0), cudaMemcpyHostToDevice)); - cudaSafeCall(cudaMemcpyToSymbol(writing_tol_global, &rel_tol, sz, size_t(0), cudaMemcpyHostToDevice)); - cudaSafeCall(cudaMemcpyToSymbol(tsr_global, &time_series_resolution, sz, size_t(0), cudaMemcpyHostToDevice)); - cudaSafeCall(cudaMemcpyToSymbol(r_off_global, &r_off, sz, size_t(0), cudaMemcpyHostToDevice)); - cudaSafeCall(cudaMemcpyToSymbol(r_on_global, &r_on, sz, size_t(0), cudaMemcpyHostToDevice)); - 
cudaSafeCall(cudaMemcpyToSymbol(t_p_global, &t_p, sz, size_t(0), cudaMemcpyHostToDevice)); - cudaSafeCall(cudaMemcpyToSymbol(t_n_global, &t_n, sz, size_t(0), cudaMemcpyHostToDevice)); - cudaSafeCall(cudaMemcpyToSymbol(A_p_global, &A_p, sz, size_t(0), cudaMemcpyHostToDevice)); - cudaSafeCall(cudaMemcpyToSymbol(A_n_global, &A_n, sz, size_t(0), cudaMemcpyHostToDevice)); - cudaSafeCall(cudaMemcpyToSymbol(pulse_dur_global, &pulse_duration, sz, size_t(0), cudaMemcpyHostToDevice)); - float *device_matrix_accessor = device_matrix.data_ptr(); - float *conductance_matrix_accessor = conductance_matrix.data_ptr(); - float *device_matrix_device; - float *conductance_matrix_device; - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, 0); - const int nz = conductance_matrix.sizes()[0]; //n_tiles - const int ny = conductance_matrix.sizes()[2]; //n_columns - const int nx = conductance_matrix.sizes()[1]; //n_rows - int max_threads = prop.maxThreadsDim[0]; - dim3 grid; - dim3 block; - at::Tensor new_device_matrix; - if (nx * ny * nz > max_threads) - { - int n_grid = ceil_int_div(nx * ny * nz, max_threads); - grid = dim3(n_grid, n_grid, n_grid); - block = dim3(ceil_int_div(nx, n_grid), ceil_int_div(ny, n_grid), ceil_int_div(nz, n_grid)); - } - else - { - grid = dim3(1, 1, 1); - block = dim3(nx, ny, nz); - } - cudaMemcpyToSymbol(NX, &nx, si, size_t(0), cudaMemcpyHostToDevice); - cudaMemcpyToSymbol(NY, &ny, si, size_t(0), cudaMemcpyHostToDevice); - cudaMemcpyToSymbol(NZ, &nz, si, size_t(0), cudaMemcpyHostToDevice); - //boolean set to true when all tiles are programmed and false otherwise - bool all_tiles_programmed = false; - if (!sim_neighbors) - { - cudaMalloc(&conductance_matrix_device, sizeof(float) * nz * nx * ny); - cudaMemcpy(conductance_matrix_device, conductance_matrix_accessor, nz * ny * nx * sizeof(float), cudaMemcpyHostToDevice); - cudaMalloc(&device_matrix_device, sizeof(float) * nz * nx * ny); - cudaMemcpy(device_matrix_device, device_matrix_accessor, nz * ny * nx * 
sizeof(float), cudaMemcpyHostToDevice); - simulate_device_dd_no_neighbours<<>>(device_matrix_device,conductance_matrix_device,force_adjustment_pos_voltage_threshold,force_adjustment_neg_voltage_threshold); - cudaSafeCall(cudaDeviceSynchronize()); //Erreur ici - cudaMemcpy(device_matrix_accessor, device_matrix_device, nz * ny * nx * sizeof(float), cudaMemcpyDeviceToHost); - new_device_matrix = torch::from_blob(device_matrix_accessor, {nz,nx,ny},at::kFloat); - cudaSafeCall(cudaFree(device_matrix_device)); - cudaSafeCall(cudaFree(conductance_matrix_device)); + at::Tensor simulate_passive_dd( + at::Tensor conductance_matrix, at::Tensor device_matrix, + int cuda_malloc_heap_size, float rel_tol, float pulse_duration, + float refactory_period, float pos_voltage_level, float neg_voltage_level, + float timeout, float force_adjustment, float force_adjustment_rel_tol, + float force_adjustment_pos_voltage_threshold, + float force_adjustment_neg_voltage_threshold, + float force_adjustment_voltage, + int failure_iteration_threshold, + float time_series_resolution, + float r_off, float r_on, float A_p, float A_n, float t_p, float t_n, + float k_p, float k_n, std::vector r_p, std::vector r_n, + float a_p, float a_n, float b_p, float b_n, bool sim_neighbors) { +assert(at::cuda::is_available()); +// Assign global variables their value +float original_pos_voltage = pos_voltage_level; +float original_neg_voltage = neg_voltage_level; +const size_t sz = sizeof(float); +const size_t si = sizeof(int); +cudaDeviceSetLimit(cudaLimitMallocHeapSize, + size_t(1024) * size_t(1024) * + size_t(cuda_malloc_heap_size)); +cudaSafeCall(cudaMemcpyToSymbol(res_adjustment_rel_tol, + &force_adjustment_rel_tol, sz, size_t(0), + cudaMemcpyHostToDevice)); +cudaSafeCall(cudaMemcpyToSymbol(res_adjustment, &force_adjustment, sz, + size_t(0), cudaMemcpyHostToDevice)); +cudaSafeCall(cudaMemcpyToSymbol(writing_tol_global, &rel_tol, sz, size_t(0), + cudaMemcpyHostToDevice)); 
+cudaSafeCall(cudaMemcpyToSymbol(tsr_global, &time_series_resolution, sz, + size_t(0), cudaMemcpyHostToDevice)); +cudaSafeCall(cudaMemcpyToSymbol(r_off_global, &r_off, sz, size_t(0), + cudaMemcpyHostToDevice)); +cudaSafeCall(cudaMemcpyToSymbol(r_on_global, &r_on, sz, size_t(0), + cudaMemcpyHostToDevice)); +cudaSafeCall(cudaMemcpyToSymbol(t_p_global, &t_p, sz, size_t(0), + cudaMemcpyHostToDevice)); +cudaSafeCall(cudaMemcpyToSymbol(t_n_global, &t_n, sz, size_t(0), + cudaMemcpyHostToDevice)); +cudaSafeCall(cudaMemcpyToSymbol(A_p_global, &A_p, sz, size_t(0), + cudaMemcpyHostToDevice)); +cudaSafeCall(cudaMemcpyToSymbol(A_n_global, &A_n, sz, size_t(0), + cudaMemcpyHostToDevice)); +cudaSafeCall(cudaMemcpyToSymbol(pulse_dur_global, &pulse_duration, sz, + size_t(0), cudaMemcpyHostToDevice)); +float *device_matrix_accessor = device_matrix.data_ptr(); +float *conductance_matrix_accessor = conductance_matrix.data_ptr(); +float *device_matrix_device; +float *conductance_matrix_device; +cudaDeviceProp prop; +cudaGetDeviceProperties(&prop, 0); +const int NZ = conductance_matrix.sizes()[0]; // n_tiles +const int NY = conductance_matrix.sizes()[2]; // n_columns +const int NX = conductance_matrix.sizes()[1]; // n_rows +int max_threads = prop.maxThreadsDim[0]; +dim3 grid; +dim3 block; +at::Tensor new_device_matrix; +if (NX * NY * NZ > max_threads) { + int n_grid = ceil_int_div(NX * NY * NZ, max_threads); + grid = dim3(n_grid, n_grid, n_grid); + block = dim3(ceil_int_div(NX, n_grid), ceil_int_div(NY, n_grid), + ceil_int_div(NZ, n_grid)); +} else { + grid = dim3(1, 1, 1); + block = dim3(NX, NY, NZ); +} +cudaMemcpyToSymbol(NX, &NX, si, size_t(0), cudaMemcpyHostToDevice); +cudaMemcpyToSymbol(NY, &NY, si, size_t(0), cudaMemcpyHostToDevice); +cudaMemcpyToSymbol(NZ, &NZ, si, size_t(0), cudaMemcpyHostToDevice); +bool all_tiles_programmed = false; +if (!sim_neighbors) { + cudaMalloc(&conductance_matrix_device, sizeof(float) * NZ * NX * NY); + cudaMemcpy(conductance_matrix_device, 
conductance_matrix_accessor, + NZ * NY * NX * sizeof(float), cudaMemcpyHostToDevice); + cudaMalloc(&device_matrix_device, sizeof(float) * NZ * NX * NY); + cudaMemcpy(device_matrix_device, device_matrix_accessor, + NZ * NY * NX * sizeof(float), cudaMemcpyHostToDevice); + simulate_device_dd_no_neighbours<<>>( + device_matrix_device, conductance_matrix_device, + force_adjustment_pos_voltage_threshold, + force_adjustment_neg_voltage_threshold, + force_adjustment_voltage, + failure_iteration_threshold + ); + cudaSafeCall(cudaDeviceSynchronize()); + cudaMemcpy(device_matrix_accessor, device_matrix_device, + NZ * NY * NX * sizeof(float), cudaMemcpyDeviceToHost); + new_device_matrix = + torch::from_blob(device_matrix_accessor, {NZ, NX, NY}, at::kFloat); + cudaSafeCall(cudaFree(device_matrix_device)); + cudaSafeCall(cudaFree(conductance_matrix_device)); +} else { + int iterations = 0; + float *neg_voltage_levels; + float *pos_voltage_levels; + // Vector passed to the threads to determine the amplitude of the negative + // voltages + neg_voltage_levels = new float[NZ]; + // Vector passed to the threads to determine the amplitude of the positive + // voltages + pos_voltage_levels = new float[NZ]; + // Set all the voltages to be passed to the threads to their initial value + for (int k = 0; k < NZ; k++) { + neg_voltage_levels[k] = original_neg_voltage; + pos_voltage_levels[k] = original_pos_voltage; } - else - { - int iterations = 0; - float *neg_voltage_levels; - float *pos_voltage_levels; - //vector passed to the threads to determine the amplitude of the negative voltages - neg_voltage_levels = new float[nz]; - //vector passed to the threads to determine the amplitude of the positive voltages - pos_voltage_levels = new float[nz]; - //set all the voltages to be passed to the threads to their initial value - for (int k = 0; k < nz; k++) - { - neg_voltage_levels[k] = original_neg_voltage; - pos_voltage_levels[k] = original_pos_voltage; - } - int *instruction_array; - //vector 
passed to the threads to determine the polarity of the voltage on each tile: 0 -> no voltage, 1 -> positive_voltage, 2 -> negative_voltage - //the size of this vector corresponds to the number of tiles - instruction_array = (int *)malloc(nz * si); - float *r_n_array; - float *r_p_array; - float *s_n_array; - float *s_p_array; - //vector of data driven parameter for each tile as they will be the same for all devices - r_n_array = (float *)malloc(nz * sz); - r_p_array = (float *)malloc(nz * sz); - s_n_array = (float *)malloc(nz * sz); - s_p_array = (float *)malloc(nz * sz); - float *r_n_half_array; - float *r_p_half_array; - float *s_n_half_array; - float *s_p_half_array; - r_n_half_array = (float *)malloc(nz * sz); - r_p_half_array = (float *)malloc(nz * sz); - s_n_half_array = (float *)malloc(nz * sz); - s_p_half_array = (float *)malloc(nz * sz); - int *i_a; - //n_programmed corresponds to the number of tiles programmed - int n_programmed; - float *r_n_arr; - float *r_p_arr; - float *s_n_arr; - float *s_p_arr; - float *r_n_half_arr; - float *r_p_half_arr; - float *s_n_half_arr; - float *s_p_half_arr; - //Allocate the memory for all necessary arrays on the GPU - cudaMalloc(&i_a, sizeof(int) * nz); - cudaMalloc(&r_n_arr, sizeof(float) * nz); - cudaMalloc(&r_p_arr, sizeof(float) * nz); - cudaMalloc(&s_n_arr, sizeof(float) * nz); - cudaMalloc(&s_p_arr, sizeof(float) * nz); - cudaMalloc(&r_n_half_arr, sizeof(float) * nz); - cudaMalloc(&r_p_half_arr, sizeof(float) * nz); - cudaMalloc(&s_n_half_arr, sizeof(float) * nz); - cudaMalloc(&s_p_half_arr, sizeof(float) * nz); - cudaMalloc(&device_matrix_device, sizeof(float) * nz * nx * ny); - - for (int i = 0; i < nx; i++) - { //for all i rows - for (int j = 0; j < ny; j++) - { //for all j columns - all_tiles_programmed = false; - iterations = 0; - while (!all_tiles_programmed) - { - if (iterations == 100) - { //Safety to ensure we do not get stuck with devices TODO: make this a variable - printf("unable to program device(s) at 
row %d and column %d\n",i,j); - all_tiles_programmed = true; - iterations = 0; - } - iterations++; - n_programmed = 0; - for (int k = 0; k < nz; k++) - { //should not be a very big array (corresponds to the number of tiles) - int index = i + j * nx + k * nx * ny; - if (1/conductance_matrix[k][i][j].item() - rel_tol * 1/conductance_matrix[k][i][j].item() > 1/device_matrix_accessor[index]) - { - instruction_array[k] = 2; - s_n_array[k] = A_n * (exp(abs(neg_voltage_levels[k]) / t_n) - 1); - r_n_array[k] = r_n[0] + r_n[1] * neg_voltage_levels[k]; - s_n_half_array[k] = A_n * (exp(abs(neg_voltage_levels[k] / 2) / t_n) - 1); - r_n_half_array[k] = r_n[0] + r_n[1] * neg_voltage_levels[k] / 2; - s_p_array[k] = 0; - r_p_array[k] = 0; - s_p_half_array[k] = 0; - r_p_half_array[k] = 0; - pos_voltage_levels[k] = original_pos_voltage; - //0.02 hard coded so far - if(neg_voltage_levels[k] > force_adjustment_neg_voltage_threshold){ - neg_voltage_levels[k] -= 0.02; - } - } - else if (1/conductance_matrix[k][i][j].item() + rel_tol * 1/conductance_matrix[k][i][j].item() < 1/(device_matrix_accessor[index])) - { - instruction_array[k] = 1; - s_p_array[k] = A_p * (exp(abs(pos_voltage_levels[k]) / t_p) - 1); - r_p_array[k] = r_p[0] + r_p[1] * pos_voltage_levels[k]; - s_p_half_array[k] = A_p * (exp(abs(pos_voltage_levels[k] / 2) / t_p) - 1); - r_p_half_array[k] = r_p[0] + r_p[1] * pos_voltage_levels[k] / 2; - s_n_array[k] = 0; - r_n_array[k] = 0; - s_n_half_array[k] = 0; - r_n_half_array[k] = 0; - neg_voltage_levels[k] = original_neg_voltage; - if(pos_voltage_levels[k] < force_adjustment_pos_voltage_threshold){ - pos_voltage_levels[k] += 0.02; - } + int *instruction_array; + // Vector passed to the threads to determine the polarity of the voltage on + // each tile: 0 -> no voltage, 1 -> positive_voltage, 2 -> negative_voltage + // the size of this vector corresponds to the number of tiles + instruction_array = (int *)malloc(NZ * si); + float *r_n_array; + float *r_p_array; + float 
*s_n_array; + float *s_p_array; + // Vector of data driven parameter for each tile as they will be the same + // for all devices + r_n_array = (float *)malloc(NZ * sz); + r_p_array = (float *)malloc(NZ * sz); + s_n_array = (float *)malloc(NZ * sz); + s_p_array = (float *)malloc(NZ * sz); + float *r_n_half_array; + float *r_p_half_array; + float *s_n_half_array; + float *s_p_half_array; + r_n_half_array = (float *)malloc(NZ * sz); + r_p_half_array = (float *)malloc(NZ * sz); + s_n_half_array = (float *)malloc(NZ * sz); + s_p_half_array = (float *)malloc(NZ * sz); + int *i_a; + // n_programmed corresponds to the number of tiles programmed + int n_programmed; + float *r_n_arr; + float *r_p_arr; + float *s_n_arr; + float *s_p_arr; + float *r_n_half_arr; + float *r_p_half_arr; + float *s_n_half_arr; + float *s_p_half_arr; + // Allocate the memory for all necessary arrays on the GPU + cudaMalloc(&i_a, sizeof(int) * NZ); + cudaMalloc(&r_n_arr, sizeof(float) * NZ); + cudaMalloc(&r_p_arr, sizeof(float) * NZ); + cudaMalloc(&s_n_arr, sizeof(float) * NZ); + cudaMalloc(&s_p_arr, sizeof(float) * NZ); + cudaMalloc(&r_n_half_arr, sizeof(float) * NZ); + cudaMalloc(&r_p_half_arr, sizeof(float) * NZ); + cudaMalloc(&s_n_half_arr, sizeof(float) * NZ); + cudaMalloc(&s_p_half_arr, sizeof(float) * NZ); + cudaMalloc(&device_matrix_device, sizeof(float) * NZ * NX * NY); + for (int i = 0; i < NX; i++) { // for all i rows + for (int j = 0; j < NY; j++) { // for all j columns + all_tiles_programmed = false; + iterations = 0; + while (!all_tiles_programmed) { + if (iterations == failure_iteration_threshold) { + printf("unable to program device(s) at row %d and column %d\n", i, + j); + all_tiles_programmed = true; + iterations = 0; + } + iterations++; + n_programmed = 0; + for (int k = 0; k < NZ; k++) { + int index = i + j * NX + k * NX * NY; + if (1 / conductance_matrix[k][i][j].item() - + rel_tol * 1 / conductance_matrix[k][i][j].item() > + 1 / device_matrix_accessor[index]) { + 
instruction_array[k] = 2; + s_n_array[k] = A_n * (exp(abs(neg_voltage_levels[k]) / t_n) - 1); + r_n_array[k] = r_n[0] + r_n[1] * neg_voltage_levels[k]; + s_n_half_array[k] = + A_n * (exp(abs(neg_voltage_levels[k] / 2) / t_n) - 1); + r_n_half_array[k] = r_n[0] + r_n[1] * neg_voltage_levels[k] / 2; + s_p_array[k] = 0; + r_p_array[k] = 0; + s_p_half_array[k] = 0; + r_p_half_array[k] = 0; + pos_voltage_levels[k] = original_pos_voltage; + if (neg_voltage_levels[k] > + force_adjustment_neg_voltage_threshold) { + neg_voltage_levels[k] -= force_adjustment_voltage; } - else - { - instruction_array[k] = 0; - n_programmed++; + } else if (1 / conductance_matrix[k][i][j].item() + + rel_tol * 1 / + conductance_matrix[k][i][j].item() < + 1 / (device_matrix_accessor[index])) { + instruction_array[k] = 1; + s_p_array[k] = A_p * (exp(abs(pos_voltage_levels[k]) / t_p) - 1); + r_p_array[k] = r_p[0] + r_p[1] * pos_voltage_levels[k]; + s_p_half_array[k] = + A_p * (exp(abs(pos_voltage_levels[k] / 2) / t_p) - 1); + r_p_half_array[k] = r_p[0] + r_p[1] * pos_voltage_levels[k] / 2; + s_n_array[k] = 0; + r_n_array[k] = 0; + s_n_half_array[k] = 0; + r_n_half_array[k] = 0; + neg_voltage_levels[k] = original_neg_voltage; + if (pos_voltage_levels[k] < + force_adjustment_pos_voltage_threshold) { + pos_voltage_levels[k] += force_adjustment_voltage; } + } else { + instruction_array[k] = 0; + n_programmed++; } - if (n_programmed == nz && nz != 0) - { - all_tiles_programmed = true; - iterations = 0; - n_programmed = 0; - } - //The gain in rapidity is probably limited by these transfers of data to the GPU but is still significant. 
- cudaMemcpy(i_a, instruction_array, nz * sizeof(int), cudaMemcpyHostToDevice); - cudaMemcpy(r_n_arr, r_n_array, nz * sizeof(float), cudaMemcpyHostToDevice); - cudaMemcpy(s_n_arr, s_n_array, nz * sizeof(float), cudaMemcpyHostToDevice); - cudaMemcpy(r_n_half_arr, r_n_half_array, nz * sizeof(float), cudaMemcpyHostToDevice); - cudaMemcpy(s_n_half_arr, s_n_half_array, nz * sizeof(float), cudaMemcpyHostToDevice); - cudaMemcpy(r_p_arr, r_p_array, nz * sizeof(float), cudaMemcpyHostToDevice); - cudaMemcpy(s_p_arr, s_p_array, nz * sizeof(float), cudaMemcpyHostToDevice); - cudaMemcpy(r_p_half_arr, r_p_half_array, nz * sizeof(float), cudaMemcpyHostToDevice); - cudaMemcpy(s_p_half_arr, s_p_half_array, nz * sizeof(float), cudaMemcpyHostToDevice); - cudaMemcpy(device_matrix_device, device_matrix_accessor, nz * ny * nx * sizeof(float), cudaMemcpyHostToDevice); - simulate_device_dd<<>>(device_matrix_device, i, j, i_a, r_n_arr, s_n_arr, r_n_half_arr, s_n_half_arr, r_p_arr, s_p_arr, r_p_half_arr, s_p_half_arr); - cudaSafeCall(cudaDeviceSynchronize()); - cudaMemcpy(device_matrix_accessor, device_matrix_device, nz * ny * nx * sizeof(float), cudaMemcpyDeviceToHost); } + if (n_programmed == NZ && NZ != 0) { + all_tiles_programmed = true; + iterations = 0; + n_programmed = 0; + } + cudaMemcpy(i_a, instruction_array, NZ * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(r_n_arr, r_n_array, NZ * sizeof(float), + cudaMemcpyHostToDevice); + cudaMemcpy(s_n_arr, s_n_array, NZ * sizeof(float), + cudaMemcpyHostToDevice); + cudaMemcpy(r_n_half_arr, r_n_half_array, NZ * sizeof(float), + cudaMemcpyHostToDevice); + cudaMemcpy(s_n_half_arr, s_n_half_array, NZ * sizeof(float), + cudaMemcpyHostToDevice); + cudaMemcpy(r_p_arr, r_p_array, NZ * sizeof(float), + cudaMemcpyHostToDevice); + cudaMemcpy(s_p_arr, s_p_array, NZ * sizeof(float), + cudaMemcpyHostToDevice); + cudaMemcpy(r_p_half_arr, r_p_half_array, NZ * sizeof(float), + cudaMemcpyHostToDevice); + cudaMemcpy(s_p_half_arr, s_p_half_array, NZ * 
sizeof(float), + cudaMemcpyHostToDevice); + cudaMemcpy(device_matrix_device, device_matrix_accessor, + NZ * NY * NX * sizeof(float), cudaMemcpyHostToDevice); + simulate_device_dd<<>>( + device_matrix_device, i, j, i_a, r_n_arr, s_n_arr, r_n_half_arr, + s_n_half_arr, r_p_arr, s_p_arr, r_p_half_arr, s_p_half_arr); + cudaSafeCall(cudaDeviceSynchronize()); + cudaMemcpy(device_matrix_accessor, device_matrix_device, + NZ * NY * NX * sizeof(float), cudaMemcpyDeviceToHost); } } - new_device_matrix = torch::from_blob(device_matrix_accessor, {nz,nx,ny},at::kFloat); - cudaSafeCall(cudaFree(i_a)); - cudaSafeCall(cudaFree(r_n_arr)); - cudaSafeCall(cudaFree(s_n_arr)); - cudaSafeCall(cudaFree(r_n_half_arr)); - cudaSafeCall(cudaFree(s_n_half_arr)); - cudaSafeCall(cudaFree(r_p_arr)); - cudaSafeCall(cudaFree(s_p_arr)); - cudaSafeCall(cudaFree(r_p_half_arr)); - cudaSafeCall(cudaFree(s_p_half_arr)); - cudaSafeCall(cudaFree(device_matrix_device)); } - cudaStreamSynchronize(at::cuda::getCurrentCUDAStream()); - return new_device_matrix; + new_device_matrix = + torch::from_blob(device_matrix_accessor, {NZ, NX, NY}, at::kFloat); + cudaSafeCall(cudaFree(i_a)); + cudaSafeCall(cudaFree(r_n_arr)); + cudaSafeCall(cudaFree(s_n_arr)); + cudaSafeCall(cudaFree(r_n_half_arr)); + cudaSafeCall(cudaFree(s_n_half_arr)); + cudaSafeCall(cudaFree(r_p_arr)); + cudaSafeCall(cudaFree(s_p_arr)); + cudaSafeCall(cudaFree(r_p_half_arr)); + cudaSafeCall(cudaFree(s_p_half_arr)); + cudaSafeCall(cudaFree(device_matrix_device)); +} +cudaStreamSynchronize(at::cuda::getCurrentCUDAStream()); +return new_device_matrix; } -at::Tensor simulate_passive_linearIonDrift(at::Tensor conductance_matrix, at::Tensor device_matrix, int cuda_malloc_heap_size, float rel_tol, - float pulse_duration, float refactory_period, float pos_voltage_level, float neg_voltage_level, - float timeout, float force_adjustment, float force_adjustment_rel_tol, float force_adjustment_pos_voltage_threshold, - float force_adjustment_neg_voltage_threshold, 
float time_series_resolution, float r_off, float r_on, float u_v, - float d, float pos_write_threshold, float neg_write_threshold, float p, bool sim_neighbors) -{ - +at::Tensor simulate_passive_linearIonDrift( + at::Tensor conductance_matrix, at::Tensor device_matrix, + int cuda_malloc_heap_size, float rel_tol, float pulse_duration, + float refactory_period, float pos_voltage_level, float neg_voltage_level, + float timeout, float force_adjustment, float force_adjustment_rel_tol, + float force_adjustment_pos_voltage_threshold, + float force_adjustment_neg_voltage_threshold, + float force_adjustment_voltage, + int failure_iteration_threshold, + float time_series_resolution, + float r_off, float r_on, float u_v, float d, float pos_write_threshold, + float neg_write_threshold, float p, bool sim_neighbors) { assert(at::cuda::is_available()); cudaDeviceProp prop; cudaGetDeviceProperties(&prop, 0); int max_threads = prop.maxThreadsDim[0]; - throw std::runtime_error(std::string("Failed: simulate_passive_linearIonDrift has yet to be implemented")); + throw std::runtime_error(std::string( + "Failed: simulate_passive_linearIonDrift has yet to be implemented")); return conductance_matrix; } -at::Tensor simulate_passive_Stanford_PKU(at::Tensor conductance_matrix, at::Tensor device_matrix,int cuda_malloc_heap_size, float rel_tol, - float pulse_duration, float refactory_period, float pos_voltage_level, float neg_voltage_level, - float timeout, float force_adjustment, float force_adjustment_rel_tol, float force_adjustment_pos_voltage_threshold, - float force_adjustment_neg_voltage_threshold, float time_series_resolution, float r_off, float r_on, float gap_init, - float g_0, float V_0, float I_0, float read_voltage, float T_init, float R_th, float gamma_init, - float beta, float t_ox, float F_min, float vel_0, float E_a, float a_0, float delta_g_init, - float model_switch, float T_crit, float T_smth, bool sim_neighbors) -{ - +at::Tensor simulate_passive_Stanford_PKU( + at::Tensor 
conductance_matrix, at::Tensor device_matrix, + int cuda_malloc_heap_size, float rel_tol, float pulse_duration, + float refactory_period, float pos_voltage_level, float neg_voltage_level, + float timeout, float force_adjustment, float force_adjustment_rel_tol, + float force_adjustment_pos_voltage_threshold, + float force_adjustment_neg_voltage_threshold, + float force_adjustment_voltage, + int failure_iteration_threshold, + float time_series_resolution, + float r_off, float r_on, float gap_init, float g_0, float V_0, float I_0, + float read_voltage, float T_init, float R_th, float gamma_init, float beta, + float t_ox, float F_min, float vel_0, float E_a, float a_0, + float delta_g_init, float model_switch, float T_crit, float T_smth, + bool sim_neighbors) { assert(at::cuda::is_available()); cudaDeviceProp prop; cudaGetDeviceProperties(&prop, 0); int max_threads = prop.maxThreadsDim[0]; - throw std::runtime_error(std::string("Failed: simulate_passive_Stanford_PKU has yet to be implemented")); + throw std::runtime_error(std::string( + "Failed: simulate_passive_Stanford_PKU has yet to be implemented")); return conductance_matrix; } -at::Tensor simulate_passive_VTEAM(at::Tensor conductance_matrix, at::Tensor device_matrix, int cuda_malloc_heap_size, float rel_tol, - float pulse_duration, float refactory_period, float pos_voltage_level, float neg_voltage_level, - float timeout, float force_adjustment, float force_adjustment_rel_tol, float force_adjustment_pos_voltage_threshold, - float force_adjustment_neg_voltage_threshold, float time_series_resolution, float r_off, float r_on, float d, - float k_on, float k_off, float alpha_on, float alpha_off, float v_on, float v_off, float x_on, float x_off, bool sim_neighbors) -{ - +at::Tensor simulate_passive_VTEAM( + at::Tensor conductance_matrix, at::Tensor device_matrix, + int cuda_malloc_heap_size, float rel_tol, float pulse_duration, + float refactory_period, float pos_voltage_level, float neg_voltage_level, + float timeout, 
float force_adjustment, float force_adjustment_rel_tol, + float force_adjustment_pos_voltage_threshold, + float force_adjustment_neg_voltage_threshold, + float force_adjustment_voltage, + int failure_iteration_threshold, + float time_series_resolution, + float r_off, float r_on, float d, float k_on, float k_off, float alpha_on, + float alpha_off, float v_on, float v_off, float x_on, float x_off, + bool sim_neighbors) { assert(at::cuda::is_available()); cudaDeviceProp prop; cudaGetDeviceProperties(&prop, 0); int max_threads = prop.maxThreadsDim[0]; - throw std::runtime_error(std::string("Failed: simulate_passive_VTEAM has yet to be implemented")); + throw std::runtime_error( + std::string("Failed: simulate_passive_VTEAM has yet to be implemented")); return conductance_matrix; -} +} \ No newline at end of file diff --git a/memtorch/cu/simulate_passive_kernels.cuh b/memtorch/cu/simulate_passive_kernels.cuh index 670ba7a3..1e2883e0 100644 --- a/memtorch/cu/simulate_passive_kernels.cuh +++ b/memtorch/cu/simulate_passive_kernels.cuh @@ -1,28 +1,58 @@ -at::Tensor simulate_passive_dd(at::Tensor conductance_matrix, at::Tensor device_matrix, int cuda_malloc_heap_size, float rel_tol, - float pulse_duration, float refactory_period, float pos_voltage_level, float neg_voltage_level, - float timeout, float force_adjustment, float force_adjustment_rel_tol, float force_adjustment_pos_voltage_threshold, - float force_adjustment_neg_voltage_threshold, float time_series_resolution , float r_off, float r_on, float A_p, float A_n, float t_p, float t_n, - float k_p, float k_n, std::vector r_p, std::vector r_n, float a_p, float a_n, float b_p, float b_n, bool sim_neighbors); +at::Tensor simulate_passive_dd( + at::Tensor conductance_matrix, at::Tensor device_matrix, + int cuda_malloc_heap_size, float rel_tol, float pulse_duration, + float refactory_period, float pos_voltage_level, float neg_voltage_level, + float timeout, float force_adjustment, float force_adjustment_rel_tol, + float 
force_adjustment_pos_voltage_threshold, + float force_adjustment_neg_voltage_threshold, + float force_adjustment_voltage, + int failure_iteration_threshold, + float time_series_resolution, + float r_off, float r_on, float A_p, float A_n, float t_p, float t_n, + float k_p, float k_n, std::vector r_p, std::vector r_n, + float a_p, float a_n, float b_p, float b_n, bool sim_neighbors); -at::Tensor simulate_passive_linearIonDrift(at::Tensor conductance_matrix, at::Tensor device_matrix, int cuda_malloc_heap_size, float rel_tol, - float pulse_duration, float refactory_period, float pos_voltage_level, float neg_voltage_level, - float timeout, float force_adjustment, float force_adjustment_rel_tol, float force_adjustment_pos_voltage_threshold, - float force_adjustment_neg_voltage_threshold, float time_series_resolution , float r_off, float r_on, float u_v, - float d,float pos_write_threshold, float neg_write_threshold, float p, bool sim_neighbors); +at::Tensor simulate_passive_linearIonDrift( + at::Tensor conductance_matrix, at::Tensor device_matrix, + int cuda_malloc_heap_size, float rel_tol, float pulse_duration, + float refactory_period, float pos_voltage_level, float neg_voltage_level, + float timeout, float force_adjustment, float force_adjustment_rel_tol, + float force_adjustment_pos_voltage_threshold, + float force_adjustment_neg_voltage_threshold, + float force_adjustment_voltage, + int failure_iteration_threshold, + float time_series_resolution, + float r_off, float r_on, float u_v, float d, float pos_write_threshold, + float neg_write_threshold, float p, bool sim_neighbors); -at::Tensor simulate_passive_Stanford_PKU(at::Tensor conductance_matrix, at::Tensor device_matrix, int cuda_malloc_heap_size, float rel_tol, - float pulse_duration, float refactory_period, float pos_voltage_level, float neg_voltage_level, - float timeout, float force_adjustment, float force_adjustment_rel_tol, float force_adjustment_pos_voltage_threshold, - float 
force_adjustment_neg_voltage_threshold, float time_series_resolution , float r_off, float r_on, float gap_init, - float g_0, float V_0, float I_0, float read_voltage, float T_init, float R_th, float gamma_init, - float beta, float t_ox, float F_min, float vel_0, float E_a, float a_0, float delta_g_init, - float model_switch, float T_crit, float T_smth, bool sim_neighbors); - -at::Tensor simulate_passive_VTEAM(at::Tensor conductance_matrix, at::Tensor device_matrix, int cuda_malloc_heap_size, float rel_tol, - float pulse_duration, float refactory_period, float pos_voltage_level, float neg_voltage_level, - float timeout, float force_adjustment, float force_adjustment_rel_tol, float force_adjustment_pos_voltage_threshold, - float force_adjustment_neg_voltage_threshold, float time_series_resolution , float r_off, float r_on, float d, - float k_on, float k_off, float alpha_on, float alpha_off, float v_on, float v_off, float x_on, float x_off, bool sim_neighbors); +at::Tensor simulate_passive_Stanford_PKU( + at::Tensor conductance_matrix, at::Tensor device_matrix, + int cuda_malloc_heap_size, float rel_tol, float pulse_duration, + float refactory_period, float pos_voltage_level, float neg_voltage_level, + float timeout, float force_adjustment, float force_adjustment_rel_tol, + float force_adjustment_pos_voltage_threshold, + float force_adjustment_neg_voltage_threshold, + float force_adjustment_voltage, + int failure_iteration_threshold, + float time_series_resolution, + float r_off, float r_on, float gap_init, float g_0, float V_0, float I_0, + float read_voltage, float T_init, float R_th, float gamma_init, float beta, + float t_ox, float F_min, float vel_0, float E_a, float a_0, + float delta_g_init, float model_switch, float T_crit, float T_smth, + bool sim_neighbors); +at::Tensor simulate_passive_VTEAM( + at::Tensor conductance_matrix, at::Tensor device_matrix, + int cuda_malloc_heap_size, float rel_tol, float pulse_duration, + float refactory_period, float 
pos_voltage_level, float neg_voltage_level, + float timeout, float force_adjustment, float force_adjustment_rel_tol, + float force_adjustment_pos_voltage_threshold, + float force_adjustment_neg_voltage_threshold, + float force_adjustment_voltage, + int failure_iteration_threshold, + float time_series_resolution, + float r_off, float r_on, float d, float k_on, float k_off, float alpha_on, + float alpha_off, float v_on, float v_off, float x_on, float x_off, + bool sim_neighbors); int countOccurrences(int arr[], int n, int x); \ No newline at end of file diff --git a/memtorch/mn/Conv1d.py b/memtorch/mn/Conv1d.py index 2587d18f..8eba793a 100644 --- a/memtorch/mn/Conv1d.py +++ b/memtorch/mn/Conv1d.py @@ -167,7 +167,7 @@ def __init__( scheme=scheme, tile_shape=tile_shape, use_bindings=use_bindings, - cuda_malloc_heap_size = self.cuda_malloc_heap_size, + cuda_malloc_heap_size=self.cuda_malloc_heap_size, random_crossbar_init=random_crossbar_init, ) self.transform_output = lambda x: x diff --git a/memtorch/version.py b/memtorch/version.py index 21533d04..fb088e84 100644 --- a/memtorch/version.py +++ b/memtorch/version.py @@ -1 +1 @@ -__version__ = "1.1.5-cpu" \ No newline at end of file +__version__ = "1.1.6-cpu" diff --git a/setup.py b/setup.py index 0f9f29da..d28c8fa2 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ from setuptools import find_packages, setup from torch.utils.cpp_extension import include_paths, library_paths -version = "1.1.5" +version = "1.1.6" CUDA = False CUDA_device_idx = 0 # Optional, ignored if CUDA is False