Implementation of CUDA accelerated passive crossbar programming simulation for the 2021 Data Driven model #125

Merged: 38 commits, Feb 10, 2022

Commits (38)
83f4e95
random_crossbar_init implemented
Philippe-Drolet Nov 4, 2021
ddb86ce
Forgot this initialization
Philippe-Drolet Nov 4, 2021
6f4e50a
Merge remote-tracking branch 'origin/master'
victor-yon Nov 11, 2021
6f9da4d
Simulate_passive_optimizationV0.1
Philippe-Drolet Nov 26, 2021
fa6479f
partially working
Philippe-Drolet Dec 7, 2021
5e43d4d
Functional Kernels DD
Philippe-Drolet Dec 21, 2021
816a480
Data_Driven simulate no neighbours implemented
Philippe-Drolet Jan 3, 2022
8dec0e0
Update gitignore.
Philippe-Drolet Jan 3, 2022
1e20254
Initial Testing + Documentation
Philippe-Drolet Jan 14, 2022
2b9bd1d
Adapted Crossbar.py to the current implementation of passive simulation
Philippe-Drolet Feb 6, 2022
e76c91b
Updated pre-commit configs
Philippe-Drolet Feb 6, 2022
cf4aa8f
Reformatted according to the docs
Philippe-Drolet Feb 6, 2022
a3e8218
random_crossbar_init implemented
Philippe-Drolet Nov 4, 2021
cf4fed0
Simulate_passive_optimizationV0.1
Philippe-Drolet Nov 26, 2021
c6460c8
partially working
Philippe-Drolet Dec 7, 2021
833877a
Functional Kernels DD
Philippe-Drolet Dec 21, 2021
9a46e56
Data_Driven simulate no neighbours implemented
Philippe-Drolet Jan 3, 2022
493fbe1
Update gitignore.
Philippe-Drolet Jan 3, 2022
3b5f552
Initial Testing + Documentation
Philippe-Drolet Jan 14, 2022
432f147
Adapted Crossbar.py to the current implementation of passive simulation
Philippe-Drolet Feb 6, 2022
10ad7d3
Updated pre-commit configs
Philippe-Drolet Feb 6, 2022
e1f8629
Reformatted according to the docs
Philippe-Drolet Feb 6, 2022
beebe98
Merge branch 'master' of https://github.com/3it-nano/MemTorch
Philippe-Drolet Feb 6, 2022
55dbdb4
Removed test files
Philippe-Drolet Feb 6, 2022
34a7862
Fixed updated files that should not have been
Philippe-Drolet Feb 6, 2022
de8dc2d
fixed small bug in Crossbar.py
Philippe-Drolet Feb 7, 2022
0db6b0c
Update memtorch/cu/simulate_passive_kernels.cu
Philippe-Drolet Feb 8, 2022
917f0ab
Update setup.py
Philippe-Drolet Feb 8, 2022
76c218d
Update memtorch/bh/crossbar/Crossbar.py
Philippe-Drolet Feb 8, 2022
bb1b76a
Update memtorch/bh/crossbar/Crossbar.py
Philippe-Drolet Feb 8, 2022
426e551
Update memtorch/version.py
Philippe-Drolet Feb 8, 2022
2233da0
Update memtorch/cu/simulate_passive_kernels.cu
Philippe-Drolet Feb 8, 2022
9b8c226
Update memtorch/cu/simulate_passive_kernels.cu
Philippe-Drolet Feb 8, 2022
054e5dc
Update memtorch/cu/simulate_passive_kernels.cu
Philippe-Drolet Feb 8, 2022
cd7cd87
Update based on recommendations
Philippe-Drolet Feb 8, 2022
aab336b
remove accidental commit of utils
Philippe-Drolet Feb 8, 2022
f08b1e1
Delete train.py
Philippe-Drolet Feb 8, 2022
e742b7b
Delete tmp_data
Philippe-Drolet Feb 8, 2022
73 changes: 56 additions & 17 deletions memtorch/bh/crossbar/Crossbar.py
@@ -10,9 +10,15 @@
import torch.nn as nn

import memtorch
from memtorch.bh.memristor import Data_Driven2021

if "cpu" not in memtorch.__version__:
import memtorch_cuda_bindings

from .Tile import gen_tiles

CUDA_supported_memristor_models = [Data_Driven2021]


@unique
class Scheme(Enum):
@@ -48,14 +54,17 @@ def __init__(
shape,
tile_shape=None,
use_bindings=True,
cuda_malloc_heap_size=50,
random_crossbar_init=False,
):
self.memristor_model_params = memristor_model_params
self.time_series_resolution = memristor_model_params.get(
"time_series_resolution"
)
self.device = torch.device("cpu" if "cpu" in memtorch.__version__ else "cuda")
self.tile_shape = tile_shape
self.use_bindings = use_bindings
self.cuda_malloc_heap_size = cuda_malloc_heap_size
if hasattr(memristor_model_params, "r_off"):
self.r_off_mean = memristor_model_params["r_off"]
if callable(self.r_off_mean):
@@ -201,7 +210,6 @@ def write_conductance_matrix(
)
else:
raise Exception("Unsupported crossbar shape.")

if self.tile_shape is not None:
conductance_matrix, tiles_map = gen_tiles(
conductance_matrix,
@@ -231,27 +239,53 @@ def write_conductance_matrix(
            )
            self.update(from_devices=False)
        else:
            if (
                self.use_bindings
                and type(self.devices.any()) in CUDA_supported_memristor_models
                and "cpu" not in memtorch.__version__
            ):
                device_matrix = torch.FloatTensor(self.g_np(self.devices))
                device_matrix_aug = device_matrix
                conductance_matrix_aug = conductance_matrix
                if (
                    len(device_matrix.shape) == 2
                ):  # To ensure compatibility with CUDA code
                    device_matrix_aug = device_matrix[:, :, None]
                    conductance_matrix_aug = conductance_matrix[:, :, None]

                self.conductance_matrix = memtorch_cuda_bindings.simulate_passive(
                    conductance_matrix_aug,
                    device_matrix_aug,
                    self.cuda_malloc_heap_size,
                    **programming_routine_params,
                    **self.memristor_model_params
                )
                self.max_abs_conductance = (
                    torch.abs(self.conductance_matrix).flatten().max()
                )
                self.update(from_devices=False)
            else:
                if self.tile_shape is not None:
                    for i in range(0, self.devices.shape[0]):
                        for j in range(0, self.devices.shape[1]):
                            for k in range(0, self.devices.shape[2]):
                                self.devices = programming_routine(
                                    self,
                                    (i, j, k),
                                    conductance_matrix[i][j][k],
                                    **programming_routine_params
                                )
                else:
                    for i in range(0, self.rows):
                        for j in range(0, self.columns):
                            self.devices = programming_routine(
                                self,
                                (i, j),
                                conductance_matrix[i][j],
                                **programming_routine_params
                            )

                self.update(from_devices=True)


def init_crossbar(
@@ -266,6 +300,7 @@
scheme=Scheme.DoubleColumn,
tile_shape=(128, 128),
use_bindings=True,
cuda_malloc_heap_size=50,
random_crossbar_init=False,
):
"""Method to initialise and construct memristive crossbars.
@@ -319,6 +354,7 @@
channel_weights.shape,
tile_shape,
use_bindings=use_bindings,
cuda_malloc_heap_size=cuda_malloc_heap_size,
random_crossbar_init=random_crossbar_init,
)
)
@@ -329,6 +365,7 @@
channel_weights.shape,
tile_shape,
use_bindings=use_bindings,
cuda_malloc_heap_size=cuda_malloc_heap_size,
random_crossbar_init=random_crossbar_init,
)
)
@@ -413,6 +450,7 @@ def out(crossbars, operation, idx=(0, 1), **kwargs):
channel_weights.shape,
tile_shape,
use_bindings=use_bindings,
random_crossbar_init=random_crossbar_init,
)
)
conductance_matrix = mapping_routine(
@@ -437,6 +475,7 @@ def out(crossbars, operation, idx=(0, 1), **kwargs):
weights.shape,
tile_shape,
use_bindings=use_bindings,
random_crossbar_init=random_crossbar_init,
)
)
conductance_matrix = mapping_routine(
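To illustrate the new Python-facing path added to Crossbar.py above, a minimal usage sketch follows (not part of this diff); the weight shape, model parameters, and routine choices are illustrative assumptions, and a CUDA build of MemTorch is assumed:

# Illustrative sketch only: program a passive crossbar with the Data_Driven2021
# model, routing the programming step through the new CUDA bindings.
import torch
from memtorch.bh.crossbar.Crossbar import init_crossbar, Scheme
from memtorch.bh.crossbar.Program import naive_program
from memtorch.bh.map.Parameter import naive_map
from memtorch.bh.memristor import Data_Driven2021

weights = torch.rand(64, 64)  # assumed layer weights
crossbars, crossbar_operation = init_crossbar(
    weights=weights,
    memristor_model=Data_Driven2021,
    memristor_model_params={"time_series_resolution": 1e-10},
    transistor=False,  # passive crossbar, so programming_routine is exercised
    mapping_routine=naive_map,
    programming_routine=naive_program,
    scheme=Scheme.DoubleColumn,
    tile_shape=(128, 128),
    use_bindings=True,  # dispatch to memtorch_cuda_bindings.simulate_passive
    cuda_malloc_heap_size=50,  # new argument introduced by this PR
)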
1 change: 1 addition & 0 deletions memtorch/bh/memristor/Data_Driven2021.py
@@ -4,6 +4,7 @@

import memtorch
from memtorch.utils import clip

from .Memristor import Memristor as Memristor


2 changes: 2 additions & 0 deletions memtorch/cu/bindings.cpp
@@ -6,11 +6,13 @@
#include "inference.h"
#include "solve_passive.h"
#include "tile_matmul.h"
#include "simulate_passive.h"


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
gen_tiles_bindings_gpu(m);
tile_matmul_bindings(m);
inference_bindings(m);
simulate_passive_bindings(m);
solve_passive_bindings(m);
}
109 changes: 109 additions & 0 deletions memtorch/cu/simulate_passive.cpp
@@ -0,0 +1,109 @@
#include <ATen/ATen.h>
#include <cmath>
#include <torch/extension.h>

#include <Eigen/Core>

#include <Eigen/SparseCore>

#include <Eigen/SparseLU>

#include "simulate_passive_kernels.cuh"

//Default values
std::vector<float> r_p{2699.2336, -672.930205};
std::vector<float> r_n{649.413746, -1474.32358};

void simulate_passive_bindings(py::module_ &m) {

//Data_Driven2021 model
m.def(
"simulate_passive",
[&](at::Tensor conductance_matrix, at::Tensor device_matrix,int cuda_malloc_heap_size, float rel_tol,
float pulse_duration, float refactory_period, float pos_voltage_level, float neg_voltage_level,
float timeout, float force_adjustment, float force_adjustment_rel_tol, float force_adjustment_pos_voltage_threshold,
float force_adjustment_neg_voltage_threshold, float time_series_resolution , float r_off, float r_on, float A_p, float A_n, float t_p, float t_n,
float k_p, float k_n, std::vector<float> r_p, std::vector<float> r_n, float a_p, float a_n, float b_p, float b_n, bool sim_neighbors) {
return simulate_passive_dd(conductance_matrix, device_matrix,cuda_malloc_heap_size, rel_tol,
pulse_duration, refactory_period, pos_voltage_level, neg_voltage_level,
timeout, force_adjustment, force_adjustment_rel_tol, force_adjustment_pos_voltage_threshold,
force_adjustment_neg_voltage_threshold,time_series_resolution,r_off,r_on,A_p,A_n,t_p,t_n,k_p,k_n,r_p,r_n,a_p,a_n,b_p,b_n, sim_neighbors);
},
py::arg("conductance_matrix"), py::arg("device_matrix"),py::arg("cuda_malloc_heap_size")=50, py::arg("rel_tol")=0.1,
py::arg("pulse_duration") = 1e-3, py::arg("refactory_period") = 0, py::arg("pos_voltage_level") = 1.0,
py::arg("neg_voltage_level") = -1.0, py::arg("timeout") = 5, py::arg("force_adjustment") = 1e-3,
py::arg("force_adjustment_rel_tol") = 1e-1, py::arg("force_adjustment_pos_voltage_threshold") = 0,
py::arg("force_adjustment_neg_voltage_threshold") = 0, py::arg("time_series_resolution") = 1e-10, py::arg("r_off") = 10000, py::arg("r_on") = 1000, py::arg("A_p") = 600.10075,
py::arg("A_n")=-34.5988399, py::arg("t_p") = -0.0212028, py::arg("t_n") = -0.05343997, py::arg("k_p") = 5.11e-4, py::arg("k_n") = 1.17e-3,
py::arg("r_p") = r_p, py::arg("r_n") = r_n, py::arg("a_p")=0.32046175,
py::arg("a_n")=0.32046175, py::arg("b_p")=2.71689828, py::arg("b_n")=2.71689828, py::arg("simulate_neighbours") = true); //Maybe change order of simulate_neighbours to before memristor args

//Linear Ion Drift
m.def(
"simulate_passive",
[&](at::Tensor conductance_matrix, at::Tensor device_matrix,int cuda_malloc_heap_size, float rel_tol,
float pulse_duration, float refactory_period, float pos_voltage_level, float neg_voltage_level,
float timeout, float force_adjustment, float force_adjustment_rel_tol, float force_adjustment_pos_voltage_threshold,
float force_adjustment_neg_voltage_threshold, float time_series_resolution , float r_off, float r_on, float u_v,
float d,float pos_write_threshold, float neg_write_threshold, float p, bool sim_neighbors) {
return simulate_passive_linearIonDrift(conductance_matrix, device_matrix,cuda_malloc_heap_size, rel_tol,
pulse_duration, refactory_period, pos_voltage_level, neg_voltage_level,
timeout, force_adjustment, force_adjustment_rel_tol, force_adjustment_pos_voltage_threshold,
force_adjustment_neg_voltage_threshold,time_series_resolution,r_off,r_on, u_v,
d, pos_write_threshold, neg_write_threshold, p,sim_neighbors);
},
py::arg("conductance_matrix"), py::arg("device_matrix"),py::arg("cuda_malloc_heap_size")=50, py::arg("rel_tol")=0.1,
py::arg("pulse_duration") = 1e-3, py::arg("refactory_period") = 0, py::arg("pos_voltage_level") = 1.0,
py::arg("neg_voltage_level") = -1.0, py::arg("timeout") = 5, py::arg("force_adjustment") = 1e-3,
py::arg("force_adjustment_rel_tol") = 1e-1, py::arg("force_adjustment_pos_voltage_threshold") = 0,
py::arg("force_adjustment_neg_voltage_threshold") = 0, py::arg("time_series_resolution") = 1e-4, py::arg("r_off") = 10000, py::arg("r_on") = 1000, py::arg("u_v") = 1e-14,
py::arg("d") = 10e-9, py::arg("pos_write_threshold") = 0.55, py::arg("neg_write_threshold") = -0.55, py::arg("p") = 1, py::arg("simulate_neighbours") = true);

//VTEAM
m.def(
"simulate_passive",
[&](at::Tensor conductance_matrix, at::Tensor device_matrix,int cuda_malloc_heap_size, float rel_tol,
float pulse_duration, float refactory_period, float pos_voltage_level, float neg_voltage_level,
float timeout, float force_adjustment, float force_adjustment_rel_tol, float force_adjustment_pos_voltage_threshold,
float force_adjustment_neg_voltage_threshold, float time_series_resolution , float r_off, float r_on, float d,
float k_on, float k_off, float alpha_on, float alpha_off, float v_on, float v_off, float x_on, float x_off, bool sim_neighbors) {
return simulate_passive_VTEAM(conductance_matrix, device_matrix,cuda_malloc_heap_size, rel_tol,
pulse_duration, refactory_period, pos_voltage_level, neg_voltage_level,
timeout, force_adjustment, force_adjustment_rel_tol, force_adjustment_pos_voltage_threshold,
force_adjustment_neg_voltage_threshold,time_series_resolution,r_off,r_on,d,
k_on, k_off, alpha_on, alpha_off, v_on, v_off, x_on, x_off, sim_neighbors);
},
py::arg("conductance_matrix"), py::arg("device_matrix"),py::arg("cuda_malloc_heap_size")=50, py::arg("rel_tol")=0.1,
py::arg("pulse_duration") = 1e-3, py::arg("refactory_period") = 0, py::arg("pos_voltage_level") = 1.0,
py::arg("neg_voltage_level") = -1.0, py::arg("timeout") = 5, py::arg("force_adjustment") = 1e-3,
py::arg("force_adjustment_rel_tol") = 1e-1, py::arg("force_adjustment_pos_voltage_threshold") = 0,
py::arg("force_adjustment_neg_voltage_threshold") = 0, py::arg("time_series_resolution") = 1e-10, py::arg("r_off") = 10000, py::arg("r_on") = 1000, py::arg("d") = 3e-9,
py::arg("k_on") =-10, py::arg("k_off") = 5e-4, py::arg("alpha_on") =3, py::arg("alpha_off") = 1, py::arg("v_on") = 0.2, py::arg("v_off") = 0.02, py::arg("x_on") = 0,
py::arg("x_off") = 3e-9, py::arg("simulate_neighbours") = true);

//Stanford_PKU
m.def(
"simulate_passive",
[&](at::Tensor conductance_matrix, at::Tensor device_matrix,int cuda_malloc_heap_size, float rel_tol,
float pulse_duration, float refactory_period, float pos_voltage_level, float neg_voltage_level,
float timeout, float force_adjustment, float force_adjustment_rel_tol, float force_adjustment_pos_voltage_threshold,
float force_adjustment_neg_voltage_threshold, float time_series_resolution , float r_off, float r_on, float gap_init,
float g_0, float V_0, float I_0, float read_voltage, float T_init, float R_th, float gamma_init,
float beta, float t_ox, float F_min, float vel_0, float E_a, float a_0, float delta_g_init,
float model_switch, float T_crit, float T_smth, bool sim_neighbors) {
return simulate_passive_Stanford_PKU(conductance_matrix, device_matrix,cuda_malloc_heap_size, rel_tol,
pulse_duration, refactory_period, pos_voltage_level, neg_voltage_level,
timeout, force_adjustment, force_adjustment_rel_tol, force_adjustment_pos_voltage_threshold,
force_adjustment_neg_voltage_threshold,time_series_resolution,r_off,r_on, gap_init,
g_0, V_0, I_0, read_voltage, T_init, R_th, gamma_init, beta, t_ox, F_min, vel_0, E_a, a_0,
delta_g_init, model_switch, T_crit, T_smth, sim_neighbors);
},
py::arg("conductance_matrix"), py::arg("device_matrix"),py::arg("cuda_malloc_heap_size")=50, py::arg("rel_tol")=0.1,
py::arg("pulse_duration") = 1e-3, py::arg("refactory_period") = 0, py::arg("pos_voltage_level") = 1.0,
py::arg("neg_voltage_level") = -1.0, py::arg("timeout") = 5, py::arg("force_adjustment") = 1e-3,
py::arg("force_adjustment_rel_tol") = 1e-1, py::arg("force_adjustment_pos_voltage_threshold") = 0,
py::arg("force_adjustment_neg_voltage_threshold") = 0, py::arg("time_series_resolution") = 1e-10, py::arg("r_off") = 10000, py::arg("r_on") = 1000, py::arg("gap_init") = 2e-10,
py::arg("g_0") = 0.25e-9, py::arg("V_0") = 0.25, py::arg("I_0") = 1000e-6, py::arg("read_voltage") = 0.1, py::arg("T_init") = 298, py::arg("R_th") = 2.1e3,
py::arg("gamma_init") = 16, py::arg("beta") = 0.8, py::arg("t_ox") = 12e-9,py::arg("F_min") = 1.4e9, py::arg("vel_0") = 10, py::arg("E_a") = 0.6, py::arg("a_0") = 0.25e-9,
py::arg("delta_g_init") = 0.02, py::arg("model_switch") = 0, py::arg("T_crit") = 450, py::arg("T_smth") = 500, py::arg("simulate_neighbours") = true);
}
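For reference, a sketch of how the Data_Driven2021 overload registered above could be invoked directly from Python (tensor shapes and values are illustrative assumptions; keyword names and defaults mirror the py::arg definitions in this file):

# Illustrative sketch only: direct call into the new binding, bypassing Crossbar.py.
import torch
import memtorch_cuda_bindings  # only present in CUDA builds of MemTorch

rows, columns = 64, 64
target_g = torch.rand(rows, columns, 1) * 1e-3  # target conductances (S), 3D as expected by the kernels
device_g = torch.ones(rows, columns, 1) * 1e-4  # current device conductances (S)

# With the memristor parameters left at their defaults, the first registered
# overload (Data_Driven2021) is selected.
programmed_g = memtorch_cuda_bindings.simulate_passive(
    target_g,
    device_g,
    cuda_malloc_heap_size=50,
    rel_tol=0.1,
    pos_voltage_level=1.0,
    neg_voltage_level=-1.0,
    simulate_neighbours=True,  # account for neighbouring devices during programming
)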
1 change: 1 addition & 0 deletions memtorch/cu/simulate_passive.cuh
@@ -0,0 +1 @@
//unused so far
1 change: 1 addition & 0 deletions memtorch/cu/simulate_passive.h
@@ -0,0 +1 @@
void simulate_passive_bindings(py::module_ &m);