Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated CUDA Bindings for Passive Crossbar Inference Routines #123

Merged
merged 9 commits into from
Feb 1, 2022
2 changes: 1 addition & 1 deletion memtorch/bh/memristor/LinearIonDrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def dxdt(self, current):
"""
return (
self.u_v
* (self.r_on / (self.d ** 2))
* (self.r_on / (self.d**2))
* current
* memtorch.bh.memristor.window.Jogelkar(self.x, self.p)
)
Expand Down
2 changes: 1 addition & 1 deletion memtorch/cpp/quantize.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ T det_integral(at::Tensor tensor, T overflow_rate, T min, T max) {
return ceil(
log2(data_ptr[std::min<int>((int)round(overflow_rate * tensor_numel),
tensor_numel - 1)] +
1e-12));
1e-12f));
}
}

Expand Down
28 changes: 22 additions & 6 deletions memtorch/cu/quantize.cuh
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
__device__ float det_integral(float *tensor, int tensor_numel,
#ifndef _QUANTIZE_
#define _QUANTIZE_

__device__ inline float det_integral(float *tensor, int tensor_numel,
float overflow_rate, float min, float max) {
if ((min != NULL) || (max != NULL)) {
float max_bound;
Expand All @@ -13,18 +16,17 @@ __device__ float det_integral(float *tensor, int tensor_numel,
tensor[0] = max_bound;
}
}
return ceilf(
log2f(tensor[(int)round(overflow_rate * tensor_numel)] + 1e-12f));
return ceilf(log2f(tensor[min_((int)round(overflow_rate * tensor_numel), tensor_numel - 1)] + 1e-12f));
}

__device__ float det_sf(float *tensor, int tensor_numel, int bits,
/**
 * Determines the scaling-factor exponent used to quantize `tensor` to `bits`
 * bits.
 *
 * The buffer is passed through sort_<float> (presumably an in-place ascending
 * sort -- defined outside this header, confirm) and the result is
 * `1 - bits + det_integral(tensor, tensor_numel, overflow_rate, min, max)`.
 *
 * @param tensor        Buffer of `tensor_numel` floats; handed to sort_, so
 *                      callers should expect its element order to change.
 * @param tensor_numel  Number of elements in `tensor`.
 * @param bits          Target bit width.
 * @param overflow_rate Fraction in [0, 1]; det_integral multiplies it by
 *                      `tensor_numel` to select an element of `tensor`.
 * @param min           Forwarded unchanged to det_integral.
 * @param max           Forwarded unchanged to det_integral.
 * @return              `1 - bits + det_integral(...)` as a float.
 */
__device__ inline float det_sf(float *tensor, int tensor_numel, int bits,
                               float overflow_rate, float min, float max) {
  // det_integral only clamps the upper end of the index it derives from
  // overflow_rate, so a negative rate would index before the start of the
  // buffer. Reject it here. Float literals keep the comparison in single
  // precision on the device.
  assert(overflow_rate >= 0.0f && overflow_rate <= 1.0f);
  sort_<float>(tensor, tensor_numel);
  return 1 - bits + det_integral(tensor, tensor_numel, overflow_rate, min, max);
}

__device__ Eigen::VectorXf linear_quantize(Eigen::VectorXf tensor, float sf,
__device__ inline Eigen::VectorXf linear_quantize(Eigen::VectorXf tensor, float sf,
int bits, float overflow_rate) {
float delta = powf(2.0f, sf);
float bound = powf(2.0f, bits - 1);
Expand All @@ -36,9 +38,21 @@ __device__ Eigen::VectorXf linear_quantize(Eigen::VectorXf tensor, float sf,
return x_;
}
});
} // To remove overflow rate TODO!

/**
 * Quantizes the single element `tensor[i]` in place.
 *
 * The element is divided by a step of 2^sf[0], rounded with
 * floorf(x + 0.5f), clamped through clamp_<float> to
 * [-2^(bits - 1), 2^(bits - 1) - 1], and multiplied by the step again.
 * If the result is NaN, 0.0f is stored instead.
 *
 * @param tensor Buffer holding the value to quantize; only index `i` is
 *               read and written.
 * @param i      Index of the element to quantize.
 * @param sf     Pointer whose first element is the scaling-factor exponent.
 * @param bits   Bit width that determines the clamp range.
 */
__device__ inline void
linear_quantize(float *tensor, int i, float *sf, int bits) {
  const float step = powf(2.0f, sf[0]);
  const float limit = powf(2.0f, bits - 1);
  const float rounded = floorf((tensor[i] / step) + 0.5f);
  const float quantized = clamp_<float>(rounded, -limit, limit - 1) * step;
  // A NaN result (e.g. from a NaN input or step) is replaced by zero.
  tensor[i] = isnan(quantized) ? 0.0f : quantized;
}

__device__ Eigen::VectorXf quantize(Eigen::VectorXf tensor, int bits,
__device__ inline Eigen::VectorXf quantize(Eigen::VectorXf tensor, int bits,
float overflow_rate, int quant_method) {
if (quant_method == 0) {
// linear
Expand Down Expand Up @@ -75,3 +89,5 @@ __device__ Eigen::VectorXf quantize(Eigen::VectorXf tensor, int bits,
return tensor;
}
}

#endif
23 changes: 20 additions & 3 deletions memtorch/cu/solve_passive.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,28 @@ void solve_passive_bindings(py::module_ &m) {
m.def(
"solve_passive",
[&](at::Tensor conductance_matrix, at::Tensor V_WL, at::Tensor V_BL,
int ADC_resolution, float overflow_rate, int quant_method,
float R_source, float R_line, bool det_readout_currents) {
return solve_passive(conductance_matrix, V_WL, V_BL, R_source, R_line,
return solve_passive(conductance_matrix, V_WL, V_BL, ADC_resolution,
overflow_rate, quant_method, R_source, R_line,
det_readout_currents);
},
py::arg("conductance_matrix"), py::arg("V_WL"), py::arg("V_BL"),
py::arg("R_source"), py::arg("R_line"),
py::arg("det_readout_currents") = true);
py::arg("ADC_resolution") = -1, py::arg("overflow_rate") = -1,
py::arg("quant_method") = -1, py::arg("R_source"),
py::arg("R_line"), py::arg("det_readout_currents") = true);

m.def(
"solve_passive",
[&](at::Tensor conductance_matrix, at::Tensor V_WL, at::Tensor V_BL,
int ADC_resolution, float overflow_rate, int quant_method,
float R_source, float R_line, bool det_readout_currents) {
return solve_passive(conductance_matrix, V_WL, V_BL, ADC_resolution,
overflow_rate, quant_method, R_source, R_line,
det_readout_currents);
},
py::arg("conductance_matrix"), py::arg("V_WL"), py::arg("V_BL"),
py::arg("ADC_resolution"), py::arg("overflow_rate"),
py::arg("quant_method"), py::arg("R_source"),
py::arg("R_line"), py::arg("det_readout_currents") = true);
}
156 changes: 103 additions & 53 deletions memtorch/cu/solve_passive_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@
#include "solve_passive.h"
#include "solve_sparse_linear.h"
#include "utils.cuh"
#include "quantize.cuh"

class Triplet {
public:
__host__ __device__ Triplet() : m_row(0), m_col(0), m_value(0) {}

__host__ __device__ Triplet(int i, int j, float v)
: m_row(i), m_col(j), m_value(v) {}

__host__ __device__ const int &row() { return m_row; }
__host__ __device__ const int &col() { return m_col; }
__host__ __device__ const float &value() { return m_value; }
Expand Down Expand Up @@ -64,17 +63,20 @@ __global__ void gen_ABE_kernel(
conductance_matrix_accessor[i][0] + 1.0f / R_source +
1.0f / R_line);
}
} else if (j == (n - 1)) {
if (R_line != 0) {
sparse_element(i * n, i * n,
conductance_matrix_accessor[i][0] + 1.0f / R_line);
}
} else {
ABCD_matrix[index] = sparse_element(0, 0, 0.0f);
}
index++;
if (R_line == 0) {
ABCD_matrix[index] = sparse_element(i * n + j, i * n + j,
conductance_matrix_accessor[i][j]);
} else {
ABCD_matrix[index] =
sparse_element(i * n + j, i * n + j,
conductance_matrix_accessor[i][j] + 2.0f / R_line);
if (R_line == 0) {
ABCD_matrix[index] = sparse_element(i * n + j, i * n + j,
conductance_matrix_accessor[i][j]);
} else {
ABCD_matrix[index] =
sparse_element(i * n + j, i * n + j,
conductance_matrix_accessor[i][j] + 2.0f / R_line);
}
}
index++;
if (j < n - 1) {
Expand Down Expand Up @@ -112,12 +114,12 @@ __global__ void gen_CDE_kernel(
torch::PackedTensorAccessor32<float, 2> conductance_matrix_accessor,
float *V_WL_accessor, float *V_BL_accessor, int m, int n, float R_source,
float R_line, sparse_element *ABCD_matrix, float *E_matrix) {
int j = threadIdx.x + blockIdx.x * blockDim.x; // for (int j = 0; j < n; j++)
int i = threadIdx.y + blockIdx.y * blockDim.y; // for (int i = 0; i < m; i++)
if (j < n && i < m) {
int index = (5 * m * n) + ((j * m + i) * 4);
int j = threadIdx.x + blockIdx.x * blockDim.x; // for (int j = 0; j < m; j++)
int i = threadIdx.y + blockIdx.y * blockDim.y; // for (int i = 0; i < n; i++)
if (j < m && i < n) {
int index = (5 * m * n) + 4 * (j * n + i);
// D matrix
if (i == 0) {
if (j == 0) {
// E matrix (partial)
if (E_matrix != NULL) {
if (R_source == 0) {
Expand All @@ -127,77 +129,106 @@ __global__ void gen_CDE_kernel(
}
}
if (R_line == 0) {
ABCD_matrix[index] = sparse_element(m * n + (j * m), m * n + j,
-conductance_matrix_accessor[0][j]);
ABCD_matrix[index] = sparse_element(m * n + (i * m), m * n + i,
-conductance_matrix_accessor[i][0]);
index++;
ABCD_matrix[index] = sparse_element(m * n + (j * m), m * n + j + n, 0);
ABCD_matrix[index] = sparse_element(m * n + (i * m), m * n + i + n, 0);
} else {
ABCD_matrix[index] =
sparse_element(m * n + (j * m), m * n + j,
-1.0f / R_line - conductance_matrix_accessor[0][j]);
sparse_element(m * n + (i * m), m * n + i,
-1.0f / R_line - conductance_matrix_accessor[i][0]);
index++;
ABCD_matrix[index] =
sparse_element(m * n + (j * m), m * n + j + n, 1.0f / R_line);
sparse_element(m * n + (i * m), m * n + i + n, 1.0f / R_line);
}
index++;
ABCD_matrix[index] = sparse_element(0, 0, 0.0f);
} else if (i < m - 1) {
} else if (j < (m - 1)) {
if (R_line == 0) {
ABCD_matrix[index] =
sparse_element(m * n + (j * m) + i, m * n + (n * (i - 1)) + j, 0);
ABCD_matrix[index] = sparse_element(
m * n + (i * m) + j, m * n + (n * (j - 1)) + i, 0);
index++;
ABCD_matrix[index] =
sparse_element(m * n + (j * m) + i, m * n + (n * (i + 1)) + j, 0);
ABCD_matrix[index] = sparse_element(
m * n + (i * m) + j, m * n + (n * j) + i, -conductance_matrix_accessor[i][j]);
index++;
ABCD_matrix[index] =
sparse_element(m * n + (j * m) + i, m * n + (n * i) + j,
-conductance_matrix_accessor[i][j]);
ABCD_matrix[index] = sparse_element(
m * n + (i * m) + j, m * n + (n * (j + 1)) + i, 0);
} else {
ABCD_matrix[index] = sparse_element(
m * n + (j * m) + i, m * n + (n * (i - 1)) + j, 1.0f / R_line);
m * n + (i * m) + j, m * n + (n * (j - 1)) + i, 1.0f / R_line);
index++;
ABCD_matrix[index] = sparse_element(
m * n + (j * m) + i, m * n + (n * (i + 1)) + j, 1.0f / R_line);
m * n + (i * m) + j, m * n + (n * j) + i, -conductance_matrix_accessor[i][j] - 2.0f / R_line);
index++;
ABCD_matrix[index] =
sparse_element(m * n + (j * m) + i, m * n + (n * i) + j,
-conductance_matrix_accessor[i][j] - 2.0f / R_line);
ABCD_matrix[index] = sparse_element(
m * n + (i * m) + j, m * n + (n * (j + 1)) + i, 1.0f / R_line);
}
} else {
if (R_line == 0) {
ABCD_matrix[index] = sparse_element(m * n + (j * m) + m - 1,
m * n + (n * (m - 2)) + j, 0);
} else {
if (R_line != 0) {
ABCD_matrix[index] = sparse_element(
m * n + (j * m) + m - 1, m * n + (n * (m - 2)) + j, 1 / R_line);
m * n + (i * m) + m - 1, m * n + (n * (j - 1)) + i, 1 / R_line);
} else {
ABCD_matrix[index] = sparse_element(m * n + (i * m) + m - 1, m * n + (n * (j - 1)) + i, 0.0f);
}
index++;
if (R_source == 0) {
ABCD_matrix[index] = sparse_element(
m * n + (j * m) + m - 1, m * n + (n * (m - 1)) + j,
-conductance_matrix_accessor[m - 1][j] - 1.0f / R_line);
m * n + (i * m) + m - 1, m * n + (n * j) + i,
-conductance_matrix_accessor[i][m - 1] - 1.0f / R_line);
} else if (R_line == 0) {
ABCD_matrix[index] = sparse_element(
m * n + (j * m) + m - 1, m * n + (n * (m - 1)) + j,
-1.0f / R_source - conductance_matrix_accessor[m - 1][j]);
m * n + (i * m) + m - 1, m * n + (n * j) + i,
-1.0f / R_source - conductance_matrix_accessor[i][m - 1]);
} else {
ABCD_matrix[index] = sparse_element(
m * n + (j * m) + m - 1, m * n + (n * (m - 1)) + j,
-1.0f / R_source - conductance_matrix_accessor[m - 1][j] -
m * n + (i * m) + m - 1, m * n + (n * j) + i,
-1.0f / R_source - conductance_matrix_accessor[i][m - 1] -
1.0f / R_line);
}
index++;
ABCD_matrix[index] = sparse_element(0, 0, 0.0f);
}
index++;
// C matrix
ABCD_matrix[index] = sparse_element(j * m + i + (m * n), n * i + j,
ABCD_matrix[index] = sparse_element((m * n) + (i * m) + j, n * j + i,
conductance_matrix_accessor[i][j]);
}
}

/**
 * Computes the quantization scaling-factor exponent of `tensor` and stores it
 * in `sf[0]`.
 *
 * det_sf is run on a scratch copy taken from the device heap, so `tensor`
 * itself is never written by this kernel. Every launched thread executes the
 * whole computation; the launch visible in this file uses a single thread
 * (<<<1, 1>>>).
 *
 * @param tensor        Input values (read only here).
 * @param numel         Number of elements in `tensor`.
 * @param bits          Target bit width, forwarded to det_sf.
 * @param overflow_rate Overflow rate, forwarded to det_sf.
 * @param sf            Output; `sf[0]` receives the scaling-factor exponent.
 */
__global__ void
det_sf_kernel(float *tensor, int numel, int bits, float overflow_rate, float* sf) {
  if (numel <= 0) {
    // Nothing to analyse. det_integral would otherwise index element
    // min_(0, numel - 1), which is before the start of the buffer.
    sf[0] = 0.0f;
    return;
  }
  float *tensor_copy = (float *)malloc(numel * sizeof(float));
  if (tensor_copy == NULL) {
    // In-kernel malloc returns NULL when the device heap is exhausted. Abort
    // the kernel so the host-side synchronize reports a failure instead of
    // this thread writing through a null pointer.
    __trap();
  }
#pragma unroll 4
  for (int i = 0; i < numel; i++) {
    tensor_copy[i] = tensor[i];
  }
  // NULL converts to 0.0f for the float min/max parameters; det_integral
  // tests `(min != NULL) || (max != NULL)`, so this selects its unbounded path.
  sf[0] = det_sf(tensor_copy, numel, bits, overflow_rate, NULL, NULL);
  free(tensor_copy);
}

/**
 * Quantizes `numel` floats of `tensor` in place using a precomputed scaling
 * factor.
 *
 * Indexing is a 1-D grid-stride loop: each thread starts at its global x
 * index and advances by blockDim.x * gridDim.x, so any 1-D launch
 * configuration covers all elements.
 *
 * quant_method == 0 (linear): linear_quantize is applied directly.
 * quant_method == 1 (log):    the element is replaced by
 *                             max_(logf(|x|), 1e-20f), linearly quantized,
 *                             exponentiated, and given the sign the input had
 *                             (inputs >= 0 stay positive).
 * Any other value leaves the data unchanged.
 *
 * @param tensor       Values to quantize in place.
 * @param numel        Number of elements in `tensor`.
 * @param bits         Bit width forwarded to linear_quantize.
 * @param sf           Pointer to the scaling-factor exponent (sf[0]).
 * @param quant_method 0 for linear, 1 for log.
 */
__global__ void
quantize_kernel(float *tensor, int numel, int bits, float* sf, int quant_method) {
  const unsigned int stride = blockDim.x * gridDim.x;
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < numel; idx += stride) {
    switch (quant_method) {
    case 0:
      // linear
      linear_quantize(tensor, idx, sf, bits);
      break;
    case 1: {
      // log
      const bool non_negative = tensor[idx] >= 0.0f;
      tensor[idx] = max_<float>(logf(abs_<float>(tensor[idx])), 1e-20f);
      linear_quantize(tensor, idx, sf, bits);
      const float magnitude = expf(tensor[idx]);
      tensor[idx] = non_negative ? magnitude : -magnitude;
      break;
    }
    default:
      break;
    }
  }
}

__global__ void
construct_V_applied_kernel(torch::PackedTensorAccessor32<float, 2> V_applied_accessor,
float *V_accessor, int m, int n) {
int i = threadIdx.x + blockIdx.x * blockDim.x; // for (int i = 0; i < m; i++)
int j = threadIdx.y + blockIdx.y * blockDim.y; // for (int j = 0; j < n; j++)
Expand All @@ -208,7 +239,9 @@ construct_V_applied(torch::PackedTensorAccessor32<float, 2> V_applied_accessor,
}

at::Tensor solve_passive(at::Tensor conductance_matrix, at::Tensor V_WL,
at::Tensor V_BL, float R_source, float R_line,
at::Tensor V_BL, int ADC_resolution,
float overflow_rate, int quant_method,
float R_source, float R_line,
bool det_readout_currents) {
assert(at::cuda::is_available());
conductance_matrix = conductance_matrix.to(torch::Device("cuda:0"));
Expand Down Expand Up @@ -252,15 +285,16 @@ at::Tensor solve_passive(at::Tensor conductance_matrix, at::Tensor V_WL,
V_BL_accessor, m, n, R_source, R_line,
ABCD_matrix, E_matrix);
cudaSafeCall(cudaDeviceSynchronize());
Eigen::SparseMatrix<float> ABCD(2 * m * n, 2 * m * n);
cudaMemcpy(ABCD_matrix_host, ABCD_matrix,
sizeof(sparse_element) * non_zero_elements,
cudaMemcpyDeviceToHost);
Eigen::SparseMatrix<float> ABCD(2 * m * n, 2 * m * n);
ABCD.setFromTriplets(&ABCD_matrix_host[0],
&ABCD_matrix_host[non_zero_elements]);
ABCD.makeCompressed();
cudaMemcpy(E_matrix_host, E_matrix, sizeof(float) * 2 * m * n,
cudaMemcpyDeviceToHost);
solve_sparse_linear(ABCD, E_matrix_host, 2 * m * n);
Eigen::Map<Eigen::VectorXf> V(E_matrix_host, 2 * m * n);
at::Tensor V_applied_tensor =
at::zeros({m, n}, torch::TensorOptions().device(torch::kCUDA, 0));
Expand All @@ -270,7 +304,7 @@ at::Tensor solve_passive(at::Tensor conductance_matrix, at::Tensor V_WL,
cudaMalloc(&V_accessor, sizeof(float) * V.size());
cudaMemcpy(V_accessor, V.data(), sizeof(float) * V.size(),
cudaMemcpyHostToDevice);
construct_V_applied<<<grid, block>>>(V_applied_accessor, V_accessor, m, n);
construct_V_applied_kernel<<<grid, block>>>(V_applied_accessor, V_accessor, m, n);
cudaSafeCall(cudaDeviceSynchronize());
cudaSafeCall(cudaFree(ABCD_matrix));
cudaSafeCall(cudaFree(E_matrix));
Expand All @@ -279,6 +313,22 @@ at::Tensor solve_passive(at::Tensor conductance_matrix, at::Tensor V_WL,
if (!det_readout_currents) {
return V_applied_tensor;
} else {
return at::sum(at::mul(V_applied_tensor, conductance_matrix), 0);
at::Tensor I_tensor = at::sum(at::mul(V_applied_tensor, conductance_matrix), 0);
V_applied_tensor.resize_(at::IntArrayRef{0});
if (ADC_resolution != -1) {
float *I_tensor_accessor = I_tensor.data_ptr<float>();
float *sf;
cudaMalloc(&sf, sizeof(float));
det_sf_kernel<<<dim3(1, 1, 1), dim3(1, 1, 1)>>>(I_tensor_accessor, n, ADC_resolution, overflow_rate, sf);
cudaSafeCall(cudaDeviceSynchronize());
int numSMs;
cudaDeviceGetAttribute(&numSMs, cudaDevAttrMultiProcessorCount, 0);
int numBlocks = ceil_int_div(n, numSMs);
int numThreads = numSMs;
quantize_kernel<<<numBlocks, numThreads>>>(I_tensor_accessor, n, ADC_resolution, sf, quant_method);
cudaSafeCall(cudaDeviceSynchronize());
cudaSafeCall(cudaFree(sf));
}
return I_tensor;
}
}
6 changes: 4 additions & 2 deletions memtorch/cu/solve_passive_kernels.cuh
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
at::Tensor solve_passive(at::Tensor conductance_matrix, at::Tensor V_WL,
at::Tensor V_BL, float R_source, float R_line,
bool det_readout_currents);
at::Tensor V_BL, int ADC_resolution,
float overflow_rate, int quant_method,
float R_source, float R_line,
bool det_readout_currents);
Loading