Updated CUDA Bindings for Passive Crossbar Inference Routines (#123)
Updated CUDA bindings for passive crossbar inference routines to avoid OOM errors (#120).
coreylammie committed Feb 1, 2022
1 parent 913fce1 commit aec1e25
Showing 13 changed files with 259 additions and 333 deletions.
2 changes: 1 addition & 1 deletion memtorch/bh/memristor/LinearIonDrift.py
@@ -131,7 +131,7 @@ def dxdt(self, current):
"""
return (
self.u_v
* (self.r_on / (self.d ** 2))
* (self.r_on / (self.d**2))
* current
* memtorch.bh.memristor.window.Jogelkar(self.x, self.p)
)
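For context, the expression touched by this hunk is the linear ion drift state derivative. In the notation of the surrounding method (a restatement for readability, not part of the commit), it evaluates

    dx/dt = u_v * (r_on / d^2) * current * Jogelkar(x, p)

where Jogelkar(x, p) is the window function applied to the state variable x with exponent p. Only the spacing of the exponentiation changes; the model itself is untouched.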
2 changes: 1 addition & 1 deletion memtorch/cpp/quantize.h
@@ -49,7 +49,7 @@ T det_integral(at::Tensor tensor, T overflow_rate, T min, T max) {
return ceil(
log2(data_ptr[std::min<int>((int)round(overflow_rate * tensor_numel),
tensor_numel - 1)] +
1e-12));
1e-12f));
}
}

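The only change in this hunk is the "f" suffix on the epsilon, which keeps the literal in single precision when T is float. As a minimal, purely illustrative sketch of why the epsilon is there at all (not part of the commit): without it, an all-zero tensor would drive log2 to negative infinity.

#include <cmath>
#include <cstdio>

int main() {
  float largest = 0.0f;                                         // e.g. an all-zero tensor
  std::printf("%f\n", std::ceil(std::log2(largest)));           // -inf
  std::printf("%f\n", std::ceil(std::log2(largest + 1e-12f)));  // finite (-39)
  return 0;
}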
28 changes: 22 additions & 6 deletions memtorch/cu/quantize.cuh
@@ -1,4 +1,7 @@
__device__ float det_integral(float *tensor, int tensor_numel,
#ifndef _QUANTIZE_
#define _QUANTIZE_

__device__ inline float det_integral(float *tensor, int tensor_numel,
float overflow_rate, float min, float max) {
if ((min != NULL) || (max != NULL)) {
float max_bound;
@@ -13,18 +16,17 @@ __device__ float det_integral(float *tensor, int tensor_numel,
tensor[0] = max_bound;
}
}
return ceilf(
log2f(tensor[(int)round(overflow_rate * tensor_numel)] + 1e-12f));
return ceilf(log2f(tensor[min_((int)round(overflow_rate * tensor_numel), tensor_numel - 1)] + 1e-12f));
}

__device__ float det_sf(float *tensor, int tensor_numel, int bits,
__device__ inline float det_sf(float *tensor, int tensor_numel, int bits,
float overflow_rate, float min, float max) {
assert(overflow_rate <= 1.0);
sort_<float>(tensor, tensor_numel);
return 1 - bits + det_integral(tensor, tensor_numel, overflow_rate, min, max);
}

__device__ Eigen::VectorXf linear_quantize(Eigen::VectorXf tensor, float sf,
__device__ inline Eigen::VectorXf linear_quantize(Eigen::VectorXf tensor, float sf,
int bits, float overflow_rate) {
float delta = powf(2.0f, sf);
float bound = powf(2.0f, bits - 1);
@@ -36,9 +38,21 @@ __device__ Eigen::VectorXf linear_quantize(Eigen::VectorXf tensor, float sf,
return x_;
}
});
} // To remove overflow rate TODO!

__device__ inline void
linear_quantize(float *tensor, int i, float *sf, int bits) {
float delta = powf(2.0f, sf[0]);
float bound = powf(2.0f, bits - 1);
float x_ = clamp_<float>(floorf((tensor[i] / delta) + 0.5f), -bound, bound - 1) * delta;
if (isnan(x_)) {
tensor[i] = 0.0f;
} else {
tensor[i] = x_;
}
}

__device__ Eigen::VectorXf quantize(Eigen::VectorXf tensor, int bits,
__device__ inline Eigen::VectorXf quantize(Eigen::VectorXf tensor, int bits,
float overflow_rate, int quant_method) {
if (quant_method == 0) {
// linear
@@ -75,3 +89,5 @@ __device__ Eigen::VectorXf quantize(Eigen::VectorXf tensor, int bits,
return tensor;
}
}

#endif
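As a rough host-side sketch of the arithmetic performed by the new per-element linear_quantize overload (illustrative only; linear_quantize_host is a hypothetical helper, not part of this commit):

#include <algorithm>
#include <cmath>
#include <cstdio>

// Mirrors the device overload: snap v onto a grid of step 2^sf with
// 2^bits representable levels, mapping NaN results to 0.
static float linear_quantize_host(float v, float sf, int bits) {
  float delta = std::pow(2.0f, sf);                 // quantization step
  float bound = std::pow(2.0f, (float)(bits - 1));  // half the number of levels
  float x = std::floor(v / delta + 0.5f);           // round to the nearest step
  x = std::min(std::max(x, -bound), bound - 1.0f) * delta;
  return std::isnan(x) ? 0.0f : x;
}

int main() {
  // With bits = 8 and sf = -4 the step is 2^-4 = 0.0625, so 0.33 maps to 0.3125.
  std::printf("%f\n", linear_quantize_host(0.33f, -4.0f, 8));
  return 0;
}

The device version differs only in writing the result back into tensor[i], which lets the quantization kernels added in solve_passive_kernels.cu call it element-wise.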
23 changes: 20 additions & 3 deletions memtorch/cu/solve_passive.cpp
@@ -14,11 +14,28 @@ void solve_passive_bindings(py::module_ &m) {
m.def(
"solve_passive",
[&](at::Tensor conductance_matrix, at::Tensor V_WL, at::Tensor V_BL,
int ADC_resolution, float overflow_rate, int quant_method,
float R_source, float R_line, bool det_readout_currents) {
return solve_passive(conductance_matrix, V_WL, V_BL, R_source, R_line,
return solve_passive(conductance_matrix, V_WL, V_BL, ADC_resolution,
overflow_rate, quant_method, R_source, R_line,
det_readout_currents);
},
py::arg("conductance_matrix"), py::arg("V_WL"), py::arg("V_BL"),
py::arg("R_source"), py::arg("R_line"),
py::arg("det_readout_currents") = true);
py::arg("ADC_resolution") = -1, py::arg("overflow_rate") = -1,
py::arg("quant_method") = -1, py::arg("R_source"),
py::arg("R_line"), py::arg("det_readout_currents") = true);

m.def(
"solve_passive",
[&](at::Tensor conductance_matrix, at::Tensor V_WL, at::Tensor V_BL,
int ADC_resolution, float overflow_rate, int quant_method,
float R_source, float R_line, bool det_readout_currents) {
return solve_passive(conductance_matrix, V_WL, V_BL, ADC_resolution,
overflow_rate, quant_method, R_source, R_line,
det_readout_currents);
},
py::arg("conductance_matrix"), py::arg("V_WL"), py::arg("V_BL"),
py::arg("ADC_resolution"), py::arg("overflow_rate"),
py::arg("quant_method"), py::arg("R_source"),
py::arg("R_line"), py::arg("det_readout_currents") = true);
}
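A minimal host-side sketch of calling the updated routine with the new ADC arguments (the tensor shapes and resistance values below are invented for illustration; the Python bindings above expose the same parameters, with ADC_resolution = -1 meaning no quantization):

#include <torch/torch.h>
#include "solve_passive_kernels.cuh"

int main() {
  int m = 32, n = 32;                            // crossbar rows (WL) x columns (BL)
  at::Tensor G    = torch::rand({m, n}) * 1e-4;  // conductance matrix [S]
  at::Tensor V_WL = torch::ones({m});            // word-line voltages [V]
  at::Tensor V_BL = torch::zeros({n});           // bit-line voltages [V]
  // Requires a CUDA device; solve_passive moves its inputs to cuda:0 internally.
  at::Tensor I_BL = solve_passive(G, V_WL, V_BL,
                                  /*ADC_resolution=*/8,   // 8-bit ADC model
                                  /*overflow_rate=*/0.0f,
                                  /*quant_method=*/0,     // 0 = linear, 1 = log
                                  /*R_source=*/10.0f, /*R_line=*/5.0f,
                                  /*det_readout_currents=*/true);
  return 0;
}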
156 changes: 103 additions & 53 deletions memtorch/cu/solve_passive_kernels.cu
@@ -15,14 +15,13 @@
#include "solve_passive.h"
#include "solve_sparse_linear.h"
#include "utils.cuh"
#include "quantize.cuh"

class Triplet {
public:
__host__ __device__ Triplet() : m_row(0), m_col(0), m_value(0) {}

__host__ __device__ Triplet(int i, int j, float v)
: m_row(i), m_col(j), m_value(v) {}

__host__ __device__ const int &row() { return m_row; }
__host__ __device__ const int &col() { return m_col; }
__host__ __device__ const float &value() { return m_value; }
Expand Down Expand Up @@ -64,17 +63,20 @@ __global__ void gen_ABE_kernel(
conductance_matrix_accessor[i][0] + 1.0f / R_source +
1.0f / R_line);
}
} else if (j == (n - 1)) {
if (R_line != 0) {
sparse_element(i * n, i * n,
conductance_matrix_accessor[i][0] + 1.0f / R_line);
}
} else {
ABCD_matrix[index] = sparse_element(0, 0, 0.0f);
}
index++;
if (R_line == 0) {
ABCD_matrix[index] = sparse_element(i * n + j, i * n + j,
conductance_matrix_accessor[i][j]);
} else {
ABCD_matrix[index] =
sparse_element(i * n + j, i * n + j,
conductance_matrix_accessor[i][j] + 2.0f / R_line);
if (R_line == 0) {
ABCD_matrix[index] = sparse_element(i * n + j, i * n + j,
conductance_matrix_accessor[i][j]);
} else {
ABCD_matrix[index] =
sparse_element(i * n + j, i * n + j,
conductance_matrix_accessor[i][j] + 2.0f / R_line);
}
}
index++;
if (j < n - 1) {
Expand Down Expand Up @@ -112,12 +114,12 @@ __global__ void gen_CDE_kernel(
torch::PackedTensorAccessor32<float, 2> conductance_matrix_accessor,
float *V_WL_accessor, float *V_BL_accessor, int m, int n, float R_source,
float R_line, sparse_element *ABCD_matrix, float *E_matrix) {
int j = threadIdx.x + blockIdx.x * blockDim.x; // for (int j = 0; j < n; j++)
int i = threadIdx.y + blockIdx.y * blockDim.y; // for (int i = 0; i < m; i++)
if (j < n && i < m) {
int index = (5 * m * n) + ((j * m + i) * 4);
int j = threadIdx.x + blockIdx.x * blockDim.x; // for (int j = 0; j < m; j++)
int i = threadIdx.y + blockIdx.y * blockDim.y; // for (int i = 0; i < n; i++)
if (j < m && i < n) {
int index = (5 * m * n) + 4 * (j * n + i);
// D matrix
if (i == 0) {
if (j == 0) {
// E matrix (partial)
if (E_matrix != NULL) {
if (R_source == 0) {
@@ -127,77 +129,106 @@
}
}
if (R_line == 0) {
ABCD_matrix[index] = sparse_element(m * n + (j * m), m * n + j,
-conductance_matrix_accessor[0][j]);
ABCD_matrix[index] = sparse_element(m * n + (i * m), m * n + i,
-conductance_matrix_accessor[i][0]);
index++;
ABCD_matrix[index] = sparse_element(m * n + (j * m), m * n + j + n, 0);
ABCD_matrix[index] = sparse_element(m * n + (i * m), m * n + i + n, 0);
} else {
ABCD_matrix[index] =
sparse_element(m * n + (j * m), m * n + j,
-1.0f / R_line - conductance_matrix_accessor[0][j]);
sparse_element(m * n + (i * m), m * n + i,
-1.0f / R_line - conductance_matrix_accessor[i][0]);
index++;
ABCD_matrix[index] =
sparse_element(m * n + (j * m), m * n + j + n, 1.0f / R_line);
sparse_element(m * n + (i * m), m * n + i + n, 1.0f / R_line);
}
index++;
ABCD_matrix[index] = sparse_element(0, 0, 0.0f);
} else if (i < m - 1) {
} else if (j < (m - 1)) {
if (R_line == 0) {
ABCD_matrix[index] =
sparse_element(m * n + (j * m) + i, m * n + (n * (i - 1)) + j, 0);
ABCD_matrix[index] = sparse_element(
m * n + (i * m) + j, m * n + (n * (j - 1)) + i, 0);
index++;
ABCD_matrix[index] =
sparse_element(m * n + (j * m) + i, m * n + (n * (i + 1)) + j, 0);
ABCD_matrix[index] = sparse_element(
m * n + (i * m) + j, m * n + (n * j) + i, -conductance_matrix_accessor[i][j]);
index++;
ABCD_matrix[index] =
sparse_element(m * n + (j * m) + i, m * n + (n * i) + j,
-conductance_matrix_accessor[i][j]);
ABCD_matrix[index] = sparse_element(
m * n + (i * m) + j, m * n + (n * (j + 1)) + i, 0);
} else {
ABCD_matrix[index] = sparse_element(
m * n + (j * m) + i, m * n + (n * (i - 1)) + j, 1.0f / R_line);
m * n + (i * m) + j, m * n + (n * (j - 1)) + i, 1.0f / R_line);
index++;
ABCD_matrix[index] = sparse_element(
m * n + (j * m) + i, m * n + (n * (i + 1)) + j, 1.0f / R_line);
m * n + (i * m) + j, m * n + (n * j) + i, -conductance_matrix_accessor[i][j] - 2.0f / R_line);
index++;
ABCD_matrix[index] =
sparse_element(m * n + (j * m) + i, m * n + (n * i) + j,
-conductance_matrix_accessor[i][j] - 2.0f / R_line);
ABCD_matrix[index] = sparse_element(
m * n + (i * m) + j, m * n + (n * (j + 1)) + i, 1.0f / R_line);
}
} else {
if (R_line == 0) {
ABCD_matrix[index] = sparse_element(m * n + (j * m) + m - 1,
m * n + (n * (m - 2)) + j, 0);
} else {
if (R_line != 0) {
ABCD_matrix[index] = sparse_element(
m * n + (j * m) + m - 1, m * n + (n * (m - 2)) + j, 1 / R_line);
m * n + (i * m) + m - 1, m * n + (n * (j - 1)) + i, 1 / R_line);
} else {
ABCD_matrix[index] = sparse_element(m * n + (i * m) + m - 1, m * n + (n * (j - 1)) + i, 0.0f);
}
index++;
if (R_source == 0) {
ABCD_matrix[index] = sparse_element(
m * n + (j * m) + m - 1, m * n + (n * (m - 1)) + j,
-conductance_matrix_accessor[m - 1][j] - 1.0f / R_line);
m * n + (i * m) + m - 1, m * n + (n * j) + i,
-conductance_matrix_accessor[i][m - 1] - 1.0f / R_line);
} else if (R_line == 0) {
ABCD_matrix[index] = sparse_element(
m * n + (j * m) + m - 1, m * n + (n * (m - 1)) + j,
-1.0f / R_source - conductance_matrix_accessor[m - 1][j]);
m * n + (i * m) + m - 1, m * n + (n * j) + i,
-1.0f / R_source - conductance_matrix_accessor[i][m - 1]);
} else {
ABCD_matrix[index] = sparse_element(
m * n + (j * m) + m - 1, m * n + (n * (m - 1)) + j,
-1.0f / R_source - conductance_matrix_accessor[m - 1][j] -
m * n + (i * m) + m - 1, m * n + (n * j) + i,
-1.0f / R_source - conductance_matrix_accessor[i][m - 1] -
1.0f / R_line);
}
index++;
ABCD_matrix[index] = sparse_element(0, 0, 0.0f);
}
index++;
// C matrix
ABCD_matrix[index] = sparse_element(j * m + i + (m * n), n * i + j,
ABCD_matrix[index] = sparse_element((m * n) + (i * m) + j, n * j + i,
conductance_matrix_accessor[i][j]);
}
}

__global__ void
construct_V_applied(torch::PackedTensorAccessor32<float, 2> V_applied_accessor,
det_sf_kernel(float *tensor, int numel, int bits, float overflow_rate, float* sf) {
float *tensor_copy;
tensor_copy = (float *)malloc(numel * sizeof(float));
#pragma unroll 4
for (int i = 0; i < numel; i++) {
tensor_copy[i] = tensor[i];
}
sf[0] = det_sf(tensor_copy, numel, bits, overflow_rate, NULL, NULL);
free(tensor_copy);
}

__global__ void
quantize_kernel(float *tensor, int numel, int bits, float* sf, int quant_method) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
if (quant_method == 0) {
// linear
linear_quantize(tensor, i, sf, bits);
} else if (quant_method == 1) {
// log
bool s = tensor[i] >= 0.0f;
tensor[i] = max_<float>(logf(abs_<float>(tensor[i])), 1e-20f);
linear_quantize(tensor, i, sf, bits);
if (s) {
tensor[i] = expf(tensor[i]);
} else {
tensor[i] = -expf(tensor[i]);
}
}
}
}

__global__ void
construct_V_applied_kernel(torch::PackedTensorAccessor32<float, 2> V_applied_accessor,
float *V_accessor, int m, int n) {
int i = threadIdx.x + blockIdx.x * blockDim.x; // for (int i = 0; i < m; i++)
int j = threadIdx.y + blockIdx.y * blockDim.y; // for (int j = 0; j < n; j++)
@@ -208,7 +239,9 @@ construct_V_applied(torch::PackedTensorAccessor32<float, 2> V_applied_accessor,
}

at::Tensor solve_passive(at::Tensor conductance_matrix, at::Tensor V_WL,
at::Tensor V_BL, float R_source, float R_line,
at::Tensor V_BL, int ADC_resolution,
float overflow_rate, int quant_method,
float R_source, float R_line,
bool det_readout_currents) {
assert(at::cuda::is_available());
conductance_matrix = conductance_matrix.to(torch::Device("cuda:0"));
@@ -252,15 +285,16 @@ at::Tensor solve_passive(at::Tensor conductance_matrix, at::Tensor V_WL,
V_BL_accessor, m, n, R_source, R_line,
ABCD_matrix, E_matrix);
cudaSafeCall(cudaDeviceSynchronize());
Eigen::SparseMatrix<float> ABCD(2 * m * n, 2 * m * n);
cudaMemcpy(ABCD_matrix_host, ABCD_matrix,
sizeof(sparse_element) * non_zero_elements,
cudaMemcpyDeviceToHost);
Eigen::SparseMatrix<float> ABCD(2 * m * n, 2 * m * n);
ABCD.setFromTriplets(&ABCD_matrix_host[0],
&ABCD_matrix_host[non_zero_elements]);
ABCD.makeCompressed();
cudaMemcpy(E_matrix_host, E_matrix, sizeof(float) * 2 * m * n,
cudaMemcpyDeviceToHost);
solve_sparse_linear(ABCD, E_matrix_host, 2 * m * n);
Eigen::Map<Eigen::VectorXf> V(E_matrix_host, 2 * m * n);
at::Tensor V_applied_tensor =
at::zeros({m, n}, torch::TensorOptions().device(torch::kCUDA, 0));
@@ -270,7 +304,7 @@ at::Tensor solve_passive(at::Tensor conductance_matrix, at::Tensor V_WL,
cudaMalloc(&V_accessor, sizeof(float) * V.size());
cudaMemcpy(V_accessor, V.data(), sizeof(float) * V.size(),
cudaMemcpyHostToDevice);
construct_V_applied<<<grid, block>>>(V_applied_accessor, V_accessor, m, n);
construct_V_applied_kernel<<<grid, block>>>(V_applied_accessor, V_accessor, m, n);
cudaSafeCall(cudaDeviceSynchronize());
cudaSafeCall(cudaFree(ABCD_matrix));
cudaSafeCall(cudaFree(E_matrix));
@@ -279,6 +313,22 @@ at::Tensor solve_passive(at::Tensor conductance_matrix, at::Tensor V_WL,
if (!det_readout_currents) {
return V_applied_tensor;
} else {
return at::sum(at::mul(V_applied_tensor, conductance_matrix), 0);
at::Tensor I_tensor = at::sum(at::mul(V_applied_tensor, conductance_matrix), 0);
V_applied_tensor.resize_(at::IntArrayRef{0});
if (ADC_resolution != -1) {
float *I_tensor_accessor = I_tensor.data_ptr<float>();
float *sf;
cudaMalloc(&sf, sizeof(float));
det_sf_kernel<<<dim3(1, 1, 1), dim3(1, 1, 1)>>>(I_tensor_accessor, n, ADC_resolution, overflow_rate, sf);
cudaSafeCall(cudaDeviceSynchronize());
int numSMs;
cudaDeviceGetAttribute(&numSMs, cudaDevAttrMultiProcessorCount, 0);
int numBlocks = ceil_int_div(n, numSMs);
int numThreads = numSMs;
quantize_kernel<<<numBlocks, numThreads>>>(I_tensor_accessor, n, ADC_resolution, sf, quant_method);
cudaSafeCall(cudaDeviceSynchronize());
cudaSafeCall(cudaFree(sf));
}
return I_tensor;
}
}
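For reference, when det_readout_currents is true the returned vector is the per-bit-line current obtained by reducing the applied-voltage solution over the word lines,

    I_BL[j] = sum over i = 0..m-1 of V_applied[i][j] * G[i][j],   for j = 0..n-1,

and, when ADC_resolution != -1, that n-element vector is quantized in place on the GPU: det_sf_kernel derives a single scale factor from a sorted copy of the currents, then quantize_kernel applies the selected linear or logarithmic quantizer element-wise. The in-place resize of V_applied_tensor to zero elements appears intended to release the intermediate solution early, consistent with the out-of-memory fix referenced in #120.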
6 changes: 4 additions & 2 deletions memtorch/cu/solve_passive_kernels.cuh
@@ -1,3 +1,5 @@
at::Tensor solve_passive(at::Tensor conductance_matrix, at::Tensor V_WL,
at::Tensor V_BL, float R_source, float R_line,
bool det_readout_currents);
at::Tensor V_BL, int ADC_resolution,
float overflow_rate, int quant_method,
float R_source, float R_line,
bool det_readout_currents);