From aec1e257457db0c7b01645736703b86d3166e831 Mon Sep 17 00:00:00 2001
From: Corey Lammie
Date: Tue, 1 Feb 2022 14:12:48 +1000
Subject: [PATCH] Updated CUDA Bindings for Passive Crossbar Inference Routines (#123)

Updated CUDA bindings for passive crossbar inference routines to avoid OOM
errors (#120).
---
 memtorch/bh/memristor/LinearIonDrift.py |   2 +-
 memtorch/cpp/quantize.h                 |   2 +-
 memtorch/cu/quantize.cuh                |  28 ++-
 memtorch/cu/solve_passive.cpp           |  23 +-
 memtorch/cu/solve_passive_kernels.cu    | 156 +++++++++-----
 memtorch/cu/solve_passive_kernels.cuh   |   6 +-
 memtorch/cu/solve_sparse_linear.cpp     |  30 +--
 memtorch/cu/tile_matmul_kernels.cu      | 269 ++++--------------------
 memtorch/cu/utils.cuh                   |   9 +-
 memtorch/mn/Conv1d.py                   |  15 +-
 memtorch/mn/Conv2d.py                   |  21 +-
 memtorch/mn/Conv3d.py                   |  27 +--
 setup.py                                |   4 +-
 13 files changed, 259 insertions(+), 333 deletions(-)

diff --git a/memtorch/bh/memristor/LinearIonDrift.py b/memtorch/bh/memristor/LinearIonDrift.py
index 067db658..f528a4a7 100644
--- a/memtorch/bh/memristor/LinearIonDrift.py
+++ b/memtorch/bh/memristor/LinearIonDrift.py
@@ -131,7 +131,7 @@ def dxdt(self, current):
         """
         return (
             self.u_v
-            * (self.r_on / (self.d ** 2))
+            * (self.r_on / (self.d**2))
             * current
             * memtorch.bh.memristor.window.Jogelkar(self.x, self.p)
         )
diff --git a/memtorch/cpp/quantize.h b/memtorch/cpp/quantize.h
index 49e46529..772d0bbd 100644
--- a/memtorch/cpp/quantize.h
+++ b/memtorch/cpp/quantize.h
@@ -49,7 +49,7 @@ T det_integral(at::Tensor tensor, T overflow_rate, T min, T max) {
     return ceil(
         log2(data_ptr[std::min((int)round(overflow_rate * tensor_numel),
                                tensor_numel - 1)] +
-             1e-12));
+             1e-12f));
   }
 }
diff --git a/memtorch/cu/quantize.cuh b/memtorch/cu/quantize.cuh
index e5217fb9..1aad034c 100644
--- a/memtorch/cu/quantize.cuh
+++ b/memtorch/cu/quantize.cuh
@@ -1,4 +1,7 @@
-__device__ float det_integral(float *tensor, int tensor_numel,
+#ifndef _QUANTIZE_
+#define _QUANTIZE_
+
+__device__ inline float det_integral(float *tensor, int tensor_numel,
                               float overflow_rate, float min, float max) {
   if ((min != NULL) || (max != NULL)) {
     float max_bound;
@@ -13,18 +16,17 @@ __device__ float det_integral(float *tensor, int tensor_numel,
       tensor[0] = max_bound;
     }
   }
-  return ceilf(
-      log2f(tensor[(int)round(overflow_rate * tensor_numel)] + 1e-12f));
+  return ceilf(log2f(tensor[min_((int)round(overflow_rate * tensor_numel), tensor_numel - 1)] + 1e-12f));
 }
 
-__device__ float det_sf(float *tensor, int tensor_numel, int bits,
+__device__ inline float det_sf(float *tensor, int tensor_numel, int bits,
                         float overflow_rate, float min, float max) {
   assert(overflow_rate <= 1.0);
   sort_(tensor, tensor_numel);
   return 1 - bits + det_integral(tensor, tensor_numel, overflow_rate, min, max);
 }
 
-__device__ Eigen::VectorXf linear_quantize(Eigen::VectorXf tensor, float sf,
+__device__ inline Eigen::VectorXf linear_quantize(Eigen::VectorXf tensor, float sf,
                                            int bits, float overflow_rate) {
   float delta = powf(2.0f, sf);
   float bound = powf(2.0f, bits - 1);
@@ -36,9 +38,21 @@ __device__ Eigen::VectorXf linear_quantize(Eigen::VectorXf tensor, float sf,
       return x_;
     }
   });
+} // To remove overflow rate TODO!
+
+__device__ inline void
+linear_quantize(float *tensor, int i, float *sf, int bits) {
+  float delta = powf(2.0f, sf[0]);
+  float bound = powf(2.0f, bits - 1);
+  float x_ = clamp_(floorf((tensor[i] / delta) + 0.5f), -bound, bound - 1) * delta;
+  if (isnan(x_)) {
+    tensor[i] = 0.0f;
+  } else {
+    tensor[i] = x_;
+  }
 }
 
-__device__ Eigen::VectorXf quantize(Eigen::VectorXf tensor, int bits,
+__device__ inline Eigen::VectorXf quantize(Eigen::VectorXf tensor, int bits,
                                     float overflow_rate, int quant_method) {
   if (quant_method == 0) {
     // linear
@@ -75,3 +89,5 @@ __device__ Eigen::VectorXf quantize(Eigen::VectorXf tensor, int bits,
     return tensor;
   }
 }
+
+#endif
\ No newline at end of file
diff --git a/memtorch/cu/solve_passive.cpp b/memtorch/cu/solve_passive.cpp
index e7deb505..248b74ba 100644
--- a/memtorch/cu/solve_passive.cpp
+++ b/memtorch/cu/solve_passive.cpp
@@ -14,11 +14,28 @@ void solve_passive_bindings(py::module_ &m) {
   m.def(
       "solve_passive",
       [&](at::Tensor conductance_matrix, at::Tensor V_WL, at::Tensor V_BL,
+          int ADC_resolution, float overflow_rate, int quant_method,
          float R_source, float R_line, bool det_readout_currents) {
-        return solve_passive(conductance_matrix, V_WL, V_BL, R_source, R_line,
+        return solve_passive(conductance_matrix, V_WL, V_BL, ADC_resolution,
+                             overflow_rate, quant_method, R_source, R_line,
                              det_readout_currents);
       },
       py::arg("conductance_matrix"), py::arg("V_WL"), py::arg("V_BL"),
-      py::arg("R_source"), py::arg("R_line"),
-      py::arg("det_readout_currents") = true);
+      py::arg("ADC_resolution") = -1, py::arg("overflow_rate") = -1,
+      py::arg("quant_method") = -1, py::arg("R_source"),
+      py::arg("R_line"), py::arg("det_readout_currents") = true);
+
+  m.def(
+      "solve_passive",
+      [&](at::Tensor conductance_matrix, at::Tensor V_WL, at::Tensor V_BL,
+          int ADC_resolution, float overflow_rate, int quant_method,
+          float R_source, float R_line, bool det_readout_currents) {
+        return solve_passive(conductance_matrix, V_WL, V_BL, ADC_resolution,
+                             overflow_rate, quant_method, R_source, R_line,
+                             det_readout_currents);
+      },
+      py::arg("conductance_matrix"), py::arg("V_WL"), py::arg("V_BL"),
+      py::arg("ADC_resolution"), py::arg("overflow_rate"),
+      py::arg("quant_method"), py::arg("R_source"),
+      py::arg("R_line"), py::arg("det_readout_currents") = true);
 }
\ No newline at end of file
diff --git a/memtorch/cu/solve_passive_kernels.cu b/memtorch/cu/solve_passive_kernels.cu
index 56177564..69109391 100644
--- a/memtorch/cu/solve_passive_kernels.cu
+++ b/memtorch/cu/solve_passive_kernels.cu
@@ -15,14 +15,13 @@
 #include "solve_passive.h"
 #include "solve_sparse_linear.h"
 #include "utils.cuh"
+#include "quantize.cuh"
 
 class Triplet {
 public:
   __host__ __device__ Triplet() : m_row(0), m_col(0), m_value(0) {}
-
   __host__ __device__ Triplet(int i, int j, float v)
       : m_row(i), m_col(j), m_value(v) {}
-
   __host__ __device__ const int &row() { return m_row; }
   __host__ __device__ const int &col() { return m_col; }
   __host__ __device__ const float &value() { return m_value; }
@@ -64,17 +63,20 @@ __global__ void gen_ABE_kernel(
                          conductance_matrix_accessor[i][0] + 1.0f / R_source +
                              1.0f / R_line);
       }
+    } else if (j == (n - 1)) {
+      if (R_line != 0) {
+        sparse_element(i * n, i * n,
+                       conductance_matrix_accessor[i][0] + 1.0f / R_line);
+      }
     } else {
-      ABCD_matrix[index] = sparse_element(0, 0, 0.0f);
-    }
-    index++;
-    if (R_line == 0) {
-      ABCD_matrix[index] = sparse_element(i * n + j, i * n + j,
-                                          conductance_matrix_accessor[i][j]);
-    } else {
-      ABCD_matrix[index] =
-          sparse_element(i * n + j, i * n + j,
-                         conductance_matrix_accessor[i][j] + 2.0f / R_line);
+      if (R_line == 0) {
+        ABCD_matrix[index] = sparse_element(i * n + j, i * n + j,
+                                            conductance_matrix_accessor[i][j]);
+      } else {
+        ABCD_matrix[index] =
+            sparse_element(i * n + j, i * n + j,
+                           conductance_matrix_accessor[i][j] + 2.0f / R_line);
+      }
     }
     index++;
     if (j < n - 1) {
@@ -112,12 +114,12 @@ __global__ void gen_CDE_kernel(
     torch::PackedTensorAccessor32 conductance_matrix_accessor,
     float *V_WL_accessor, float *V_BL_accessor, int m, int n, float R_source,
    float R_line, sparse_element *ABCD_matrix, float *E_matrix) {
-  int j = threadIdx.x + blockIdx.x * blockDim.x; // for (int j = 0; j < n; j++)
-  int i = threadIdx.y + blockIdx.y * blockDim.y; // for (int i = 0; i < m; i++)
-  if (j < n && i < m) {
-    int index = (5 * m * n) + ((j * m + i) * 4);
+  int j = threadIdx.x + blockIdx.x * blockDim.x; // for (int j = 0; j < m; j++)
+  int i = threadIdx.y + blockIdx.y * blockDim.y; // for (int i = 0; i < n; i++)
+  if (j < m && i < n) {
+    int index = (5 * m * n) + 4 * (j * n + i);
     // D matrix
-    if (i == 0) {
+    if (j == 0) {
       // E matrix (partial)
       if (E_matrix != NULL) {
         if (R_source == 0) {
@@ -127,63 +129,60 @@ __global__ void gen_CDE_kernel(
         }
       }
       if (R_line == 0) {
-        ABCD_matrix[index] = sparse_element(m * n + (j * m), m * n + j,
-                                            -conductance_matrix_accessor[0][j]);
+        ABCD_matrix[index] = sparse_element(m * n + (i * m), m * n + i,
+                                            -conductance_matrix_accessor[i][0]);
         index++;
-        ABCD_matrix[index] = sparse_element(m * n + (j * m), m * n + j + n, 0);
+        ABCD_matrix[index] = sparse_element(m * n + (i * m), m * n + i + n, 0);
       } else {
         ABCD_matrix[index] =
-            sparse_element(m * n + (j * m), m * n + j,
-                           -1.0f / R_line - conductance_matrix_accessor[0][j]);
+            sparse_element(m * n + (i * m), m * n + i,
+                           -1.0f / R_line - conductance_matrix_accessor[i][0]);
         index++;
         ABCD_matrix[index] =
-            sparse_element(m * n + (j * m), m * n + j + n, 1.0f / R_line);
+            sparse_element(m * n + (i * m), m * n + i + n, 1.0f / R_line);
       }
       index++;
       ABCD_matrix[index] = sparse_element(0, 0, 0.0f);
-    } else if (i < m - 1) {
+    } else if (j < (m - 1)) {
       if (R_line == 0) {
-        ABCD_matrix[index] =
-            sparse_element(m * n + (j * m) + i, m * n + (n * (i - 1)) + j, 0);
+        ABCD_matrix[index] = sparse_element(
+            m * n + (i * m) + j, m * n + (n * (j - 1)) + i, 0);
         index++;
-        ABCD_matrix[index] =
-            sparse_element(m * n + (j * m) + i, m * n + (n * (i + 1)) + j, 0);
+        ABCD_matrix[index] = sparse_element(
+            m * n + (i * m) + j, m * n + (n * j) + i, -conductance_matrix_accessor[i][j]);
        index++;
-        ABCD_matrix[index] =
-            sparse_element(m * n + (j * m) + i, m * n + (n * i) + j,
-                           -conductance_matrix_accessor[i][j]);
+        ABCD_matrix[index] = sparse_element(
+            m * n + (i * m) + j, m * n + (n * (j + 1)) + i, 0);
       } else {
         ABCD_matrix[index] = sparse_element(
-            m * n + (j * m) + i, m * n + (n * (i - 1)) + j, 1.0f / R_line);
+            m * n + (i * m) + j, m * n + (n * (j - 1)) + i, 1.0f / R_line);
         index++;
         ABCD_matrix[index] = sparse_element(
-            m * n + (j * m) + i, m * n + (n * (i + 1)) + j, 1.0f / R_line);
+            m * n + (i * m) + j, m * n + (n * j) + i, -conductance_matrix_accessor[i][j] - 2.0f / R_line);
         index++;
-        ABCD_matrix[index] =
-            sparse_element(m * n + (j * m) + i, m * n + (n * i) + j,
-                           -conductance_matrix_accessor[i][j] - 2.0f / R_line);
+        ABCD_matrix[index] = sparse_element(
+            m * n + (i * m) + j, m * n + (n * (j + 1)) + i, 1.0f / R_line);
       }
     } else {
-      if (R_line == 0) {
-        ABCD_matrix[index] = sparse_element(m * n + (j * m) + m - 1,
-                                            m * n + (n * (m - 2)) + j, 0);
-      } else {
+      if (R_line != 0) {
         ABCD_matrix[index] = sparse_element(
-            m * n + (j * m) + m - 1, m * n + (n * (m - 2)) + j, 1 / R_line);
+            m * n + (i * m) + m - 1, m * n + (n * (j - 1)) + i, 1 / R_line);
+      } else {
+        ABCD_matrix[index] = sparse_element(m * n + (i * m) + m - 1, m * n + (n * (j - 1)) + i, 0.0f);
       }
       index++;
       if (R_source == 0) {
         ABCD_matrix[index] = sparse_element(
-            m * n + (j * m) + m - 1, m * n + (n * (m - 1)) + j,
-            -conductance_matrix_accessor[m - 1][j] - 1.0f / R_line);
+            m * n + (i * m) + m - 1, m * n + (n * j) + i,
+            -conductance_matrix_accessor[i][m - 1] - 1.0f / R_line);
       } else if (R_line == 0) {
         ABCD_matrix[index] = sparse_element(
-            m * n + (j * m) + m - 1, m * n + (n * (m - 1)) + j,
-            -1.0f / R_source - conductance_matrix_accessor[m - 1][j]);
+            m * n + (i * m) + m - 1, m * n + (n * j) + i,
+            -1.0f / R_source - conductance_matrix_accessor[i][m - 1]);
       } else {
         ABCD_matrix[index] = sparse_element(
-            m * n + (j * m) + m - 1, m * n + (n * (m - 1)) + j,
-            -1.0f / R_source - conductance_matrix_accessor[m - 1][j] -
+            m * n + (i * m) + m - 1, m * n + (n * j) + i,
+            -1.0f / R_source - conductance_matrix_accessor[i][m - 1] -
                 1.0f / R_line);
       }
       index++;
@@ -191,13 +190,45 @@ __global__ void gen_CDE_kernel(
     }
     index++;
     // C matrix
-    ABCD_matrix[index] = sparse_element(j * m + i + (m * n), n * i + j,
+    ABCD_matrix[index] = sparse_element((m * n) + (i * m) + j, n * j + i,
                                         conductance_matrix_accessor[i][j]);
   }
 }
 
 __global__ void
-construct_V_applied(torch::PackedTensorAccessor32 V_applied_accessor,
+det_sf_kernel(float *tensor, int numel, int bits, float overflow_rate, float* sf) {
+  float *tensor_copy;
+  tensor_copy = (float *)malloc(numel * sizeof(float));
+  #pragma unroll 4
+  for (int i = 0; i < numel; i++) {
+    tensor_copy[i] = tensor[i];
+  }
+  sf[0] = det_sf(tensor_copy, numel, bits, overflow_rate, NULL, NULL);
+  free(tensor_copy);
+}
+
+__global__ void
+quantize_kernel(float *tensor, int numel, int bits, float* sf, int quant_method) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+    if (quant_method == 0) {
+      // linear
+      linear_quantize(tensor, i, sf, bits);
+    } else if (quant_method == 1) {
+      // log
+      bool s = tensor[i] >= 0.0f;
+      tensor[i] = max_(logf(abs_(tensor[i])), 1e-20f);
+      linear_quantize(tensor, i, sf, bits);
+      if (s) {
+        tensor[i] = expf(tensor[i]);
+      } else {
+        tensor[i] = -expf(tensor[i]);
+      }
+    }
+  }
+}
+
+__global__ void
+construct_V_applied_kernel(torch::PackedTensorAccessor32 V_applied_accessor,
                     float *V_accessor, int m, int n) {
   int i = threadIdx.x + blockIdx.x * blockDim.x; // for (int i = 0; i < m; i++)
   int j = threadIdx.y + blockIdx.y * blockDim.y; // for (int j = 0; j < n; j++)
@@ -208,7 +239,9 @@ construct_V_applied(torch::PackedTensorAccessor32 V_applied_accessor,
 }
 
 at::Tensor solve_passive(at::Tensor conductance_matrix, at::Tensor V_WL,
-                         at::Tensor V_BL, float R_source, float R_line,
+                         at::Tensor V_BL, int ADC_resolution,
+                         float overflow_rate, int quant_method,
+                         float R_source, float R_line,
                          bool det_readout_currents) {
   assert(at::cuda::is_available());
   conductance_matrix = conductance_matrix.to(torch::Device("cuda:0"));
@@ -252,15 +285,16 @@ at::Tensor solve_passive(at::Tensor conductance_matrix, at::Tensor V_WL,
       V_BL_accessor, m, n, R_source, R_line, ABCD_matrix, E_matrix);
   cudaSafeCall(cudaDeviceSynchronize());
-  Eigen::SparseMatrix ABCD(2 * m * n, 2 * m * n);
   cudaMemcpy(ABCD_matrix_host, ABCD_matrix,
             sizeof(sparse_element) * non_zero_elements, cudaMemcpyDeviceToHost);
+  Eigen::SparseMatrix ABCD(2 * m * n, 2 * m * n);
   ABCD.setFromTriplets(&ABCD_matrix_host[0],
                        &ABCD_matrix_host[non_zero_elements]);
   ABCD.makeCompressed();
   cudaMemcpy(E_matrix_host, E_matrix, sizeof(float) * 2 * m * n,
              cudaMemcpyDeviceToHost);
+  solve_sparse_linear(ABCD, E_matrix_host, 2 * m * n);
   Eigen::Map V(E_matrix_host, 2 * m * n);
   at::Tensor V_applied_tensor =
       at::zeros({m, n}, torch::TensorOptions().device(torch::kCUDA, 0));
@@ -270,7 +304,7 @@ at::Tensor solve_passive(at::Tensor conductance_matrix, at::Tensor V_WL,
   cudaMalloc(&V_accessor, sizeof(float) * V.size());
   cudaMemcpy(V_accessor, V.data(), sizeof(float) * V.size(),
              cudaMemcpyHostToDevice);
-  construct_V_applied<<>>(V_applied_accessor, V_accessor, m, n);
+  construct_V_applied_kernel<<>>(V_applied_accessor, V_accessor, m, n);
   cudaSafeCall(cudaDeviceSynchronize());
   cudaSafeCall(cudaFree(ABCD_matrix));
   cudaSafeCall(cudaFree(E_matrix));
@@ -279,6 +313,22 @@ at::Tensor solve_passive(at::Tensor conductance_matrix, at::Tensor V_WL,
   if (!det_readout_currents) {
     return V_applied_tensor;
   } else {
-    return at::sum(at::mul(V_applied_tensor, conductance_matrix), 0);
+    at::Tensor I_tensor = at::sum(at::mul(V_applied_tensor, conductance_matrix), 0);
+    V_applied_tensor.resize_(at::IntArrayRef{0});
+    if (ADC_resolution != -1) {
+      float *I_tensor_accessor = I_tensor.data_ptr();
+      float *sf;
+      cudaMalloc(&sf, sizeof(float));
+      det_sf_kernel<<>>(I_tensor_accessor, n, ADC_resolution, overflow_rate, sf);
+      cudaSafeCall(cudaDeviceSynchronize());
+      int numSMs;
+      cudaDeviceGetAttribute(&numSMs, cudaDevAttrMultiProcessorCount, 0);
+      int numBlocks = ceil_int_div(n, numSMs);
+      int numThreads = numSMs;
+      quantize_kernel<<>>(I_tensor_accessor, n, ADC_resolution, sf, quant_method);
+      cudaSafeCall(cudaDeviceSynchronize());
+      cudaSafeCall(cudaFree(sf));
+    }
+    return I_tensor;
   }
 }
\ No newline at end of file
diff --git a/memtorch/cu/solve_passive_kernels.cuh b/memtorch/cu/solve_passive_kernels.cuh
index eb829a25..ca8e4915 100644
--- a/memtorch/cu/solve_passive_kernels.cuh
+++ b/memtorch/cu/solve_passive_kernels.cuh
@@ -1,3 +1,5 @@
 at::Tensor solve_passive(at::Tensor conductance_matrix, at::Tensor V_WL,
-                         at::Tensor V_BL, float R_source, float R_line,
-                         bool det_readout_currents);
\ No newline at end of file
+                         at::Tensor V_BL, int ADC_resolution,
+                         float overflow_rate, int quant_method,
+                         float R_source, float R_line,
+                         bool det_readout_currents);
\ No newline at end of file
diff --git a/memtorch/cu/solve_sparse_linear.cpp b/memtorch/cu/solve_sparse_linear.cpp
index cd421963..7364ed4e 100644
--- a/memtorch/cu/solve_sparse_linear.cpp
+++ b/memtorch/cu/solve_sparse_linear.cpp
@@ -1,24 +1,30 @@
 #include
 #include
-#include
+#include
 #include
+
 void solve_sparse_linear(Eigen::SparseMatrix A, double *B_values,
                          int n) {
-  Eigen::SparseQR, Eigen::COLAMDOrdering> QR(
-      A);
-  QR.analyzePattern(A);
-  QR.factorize(A);
+  Eigen::SparseLU, Eigen::COLAMDOrdering> LU(A);
+  LU.analyzePattern(A);
+  LU.factorize(A);
   Eigen::Map B(B_values, n);
-  Eigen::VectorXd X = QR.solve(B);
-  memcpy(B_values, X.data(), sizeof(double) * n);
+  Eigen::VectorXd X = LU.solve(B);
+  #pragma omp parallel for
+  for (int i = 0; i < n; i++) {
+    B_values[i] = X[i];
+  }
 }
 
 void solve_sparse_linear(Eigen::SparseMatrix A, float *B_values,
                          int n) {
-  Eigen::SparseQR, Eigen::COLAMDOrdering> QR(A);
-  QR.analyzePattern(A);
-  QR.factorize(A);
+  Eigen::SparseLU, Eigen::COLAMDOrdering> LU(A);
+  LU.analyzePattern(A);
+  LU.factorize(A);
   Eigen::Map B(B_values, n);
-  Eigen::VectorXf X = QR.solve(B);
-  memcpy(B_values, X.data(), sizeof(float) * n);
+  Eigen::VectorXf X = LU.solve(B);
+  #pragma omp parallel for
+  for (int i = 0; i < n; i++) {
+    B_values[i] = X[i];
+  }
 }
\ No newline at end of file
diff --git a/memtorch/cu/tile_matmul_kernels.cu b/memtorch/cu/tile_matmul_kernels.cu
index a85e23d9..fa799537 100644
--- a/memtorch/cu/tile_matmul_kernels.cu
+++ b/memtorch/cu/tile_matmul_kernels.cu
@@ -11,9 +11,11 @@
 #include
 #include
 
+#include "utils.cuh"
 #include "quantize.cuh"
-#include "solve_passive.cuh"
-#include "solve_sparse_linear.h"
+#include "solve_passive_kernels.cuh"
+
+using namespace torch::indexing;
 
 __global__ void tile_matmul_kernel(
     float *mat_a_tiles_accessor,
     torch::PackedTensorAccessor32 mat_a_tiles_map_accessor,
@@ -47,85 +49,6 @@ __global__ void tile_matmul_kernel(
   }
 }
 
-__global__ void tile_matmul_kernel_A(
-    float *mat_a_tiles_accessor,
-    torch::PackedTensorAccessor32 mat_a_tiles_map_accessor,
-    int64_t *mat_a_tiles_shape, float *mat_b_tiles_accessor,
-    torch::PackedTensorAccessor32 mat_b_tiles_map_accessor,
-    int64_t *mat_b_tiles_shape, int mat_b_shape_back,
-    int *ABCD_matrix_indices_x, int *ABCD_matrix_indices_y,
-    double *ABCD_matrix_values, int *ABCD_matrix_compressed_rows,
-    int *ABCD_matrix_compressed_columns, double *ABCD_matrix_compressed_values,
-    double *E_matrix, float source_resistance, float line_resistance,
-    int limit_i, int limit_j, int limit_k) {
-  int i = threadIdx.x + blockIdx.x * blockDim.x;
-  int j = threadIdx.y + blockIdx.y * blockDim.y;
-  int k = threadIdx.z + blockIdx.z * blockDim.z;
-  if (i < limit_i && j < limit_j && k < limit_k) {
-    Eigen::Map tile_a(
-        &mat_a_tiles_accessor[transform_3d_index(mat_a_tiles_map_accessor[k], i,
-                                                 0, mat_a_tiles_shape[1],
-                                                 mat_a_tiles_shape[2])],
-        mat_a_tiles_shape[1]);
-    Eigen::Map>
-        tile_b(&mat_b_tiles_accessor[transform_3d_index(
-                   mat_b_tiles_map_accessor[k][j], 0, 0, mat_b_tiles_shape[1],
-                   mat_b_tiles_shape[2])],
-               mat_b_tiles_shape[1], mat_b_tiles_shape[2],
-               Eigen::Stride<1, Eigen::Dynamic>(1, mat_b_tiles_shape[2]));
-    int m = (int)mat_b_tiles_shape[1];
-    int n = (int)mat_b_tiles_shape[2];
-    int nonzero_elements = 8 * m * n - 2 * m - 2 * n;
-    int kernel_index = transform_3d_index(i, j, k, limit_j, limit_k);
-    construct_ABCD_E(
-        tile_b, tile_a, Eigen::VectorXf::Zero(n), source_resistance,
-        line_resistance,
-        &ABCD_matrix_indices_x[kernel_index * nonzero_elements],
-        &ABCD_matrix_indices_y[kernel_index * nonzero_elements],
-        &ABCD_matrix_values[kernel_index * nonzero_elements],
-        &ABCD_matrix_compressed_rows[kernel_index * nonzero_elements],
-        &ABCD_matrix_compressed_columns[kernel_index * (2 * m * n)],
-        &ABCD_matrix_compressed_values[kernel_index * nonzero_elements],
-        &E_matrix[kernel_index * (2 * m * n)]);
-  }
-}
-
-__global__ void tile_matmul_kernel_B(
-    double *E_matrix, float *mat_b_tiles_accessor,
-    torch::PackedTensorAccessor32 mat_b_tiles_map_accessor,
-    int64_t *mat_b_tiles_shape, int mat_b_shape_back, int m, int n, int limit_i,
-    int limit_j, int limit_k, float *result) {
-  int i = threadIdx.x + blockIdx.x * blockDim.x;
-  int j = threadIdx.y + blockIdx.y * blockDim.y;
-  int k = threadIdx.z + blockIdx.z * blockDim.z;
-  if (i < limit_i && j < limit_j && k < limit_k) {
-    int kernel_index = transform_3d_index(i, j, k, limit_j, limit_k);
-    Eigen::Map>
-        tile_b(&mat_b_tiles_accessor[transform_3d_index(
-                   mat_b_tiles_map_accessor[k][j], 0, 0, mat_b_tiles_shape[1],
-                   mat_b_tiles_shape[2])],
-               mat_b_tiles_shape[1], mat_b_tiles_shape[2],
-               Eigen::Stride<1, Eigen::Dynamic>(1, mat_b_tiles_shape[2]));
-    Eigen::MatrixXf I_applied_tensor = Eigen::MatrixXf::Zero(m, n);
-    for (int ii = 0; ii < m; ii++) {
-      for (int jj = 0; jj < n; jj++) {
-        I_applied_tensor(ii, jj) =
-            ((float)E_matrix[kernel_index * (2 * m * n) + n * ii + jj] -
-             (float)
-                 E_matrix[kernel_index * (2 * m * n) + m * n + n * ii + jj]) *
-            tile_b(ii, jj);
-      }
-    }
-    Eigen::VectorXf I_tensor = I_applied_tensor.colwise().sum();
-    for (int ii = 0; ii < n; ii++) {
-      result[transform_2d_index(i, j * mat_b_tiles_shape[2] + ii,
-                                mat_b_shape_back)] += I_tensor[ii];
-    }
-  }
-}
-
 __global__ void tile_matmul_kernel(
     float *mat_a_tiles_accessor,
     torch::PackedTensorAccessor32 mat_a_tiles_map_accessor,
@@ -143,7 +66,6 @@ __global__ void tile_matmul_kernel(
                                                0, mat_a_tiles_shape[1],
                                                mat_a_tiles_shape[2])],
       1, mat_a_tiles_shape[2]);
-
   Eigen::Map>
       tile_b(&mat_b_tiles_accessor[transform_3d_index(
@@ -163,43 +85,6 @@ __global__ void tile_matmul_kernel(
   }
 }
 
-__global__ void tile_matmul_kernel_B(
-    double *E_matrix, float *mat_b_tiles_accessor,
-    torch::PackedTensorAccessor32 mat_b_tiles_map_accessor,
-    int64_t *mat_b_tiles_shape, int mat_b_shape_back, int ADC_resolution,
-    float overflow_rate, int quant_method, int m, int n, int limit_i,
-    int limit_j, int limit_k, float *result) {
-  int i = threadIdx.x + blockIdx.x * blockDim.x;
-  int j = threadIdx.y + blockIdx.y * blockDim.y;
-  int k = threadIdx.z + blockIdx.z * blockDim.z;
-  if (i < limit_i && j < limit_j && k < limit_k) {
-    int kernel_index = transform_3d_index(i, j, k, limit_j, limit_k);
-    Eigen::Map>
-        tile_b(&mat_b_tiles_accessor[transform_3d_index(
-                   mat_b_tiles_map_accessor[k][j], 0, 0, mat_b_tiles_shape[1],
-                   mat_b_tiles_shape[2])],
-               mat_b_tiles_shape[1], mat_b_tiles_shape[2],
-               Eigen::Stride<1, Eigen::Dynamic>(1, mat_b_tiles_shape[2]));
-    Eigen::MatrixXf I_applied_tensor = Eigen::MatrixXf::Zero(m, n);
-    for (int ii = 0; ii < m; ii++) {
-      for (int jj = 0; jj < n; jj++) {
-        I_applied_tensor(ii, jj) =
-            ((float)E_matrix[kernel_index * (2 * m * n) + n * ii + jj] -
-             (float)
-                 E_matrix[kernel_index * (2 * m * n) + m * n + n * ii + jj]) *
-            tile_b(ii, jj);
-      }
-    }
-    Eigen::VectorXf I_tensor = I_applied_tensor.colwise().sum();
-    I_tensor = quantize(I_tensor, ADC_resolution, overflow_rate, quant_method);
-    for (int ii = 0; ii < n; ii++) {
-      result[transform_2d_index(i, j * mat_b_tiles_shape[2] + ii,
-                                mat_b_shape_back)] += I_tensor[ii];
-    }
-  }
-}
-
 at::Tensor tile_matmul(at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map,
                        int mat_a_shape[2], at::Tensor mat_b_tiles,
                        at::Tensor mat_b_tiles_map, int mat_b_shape[2],
@@ -234,31 +119,28 @@ at::Tensor tile_matmul(at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map,
           mat_a_tiles_map.packed_accessor32();
   torch::PackedTensorAccessor32 mat_b_tiles_map_accessor =
       mat_b_tiles_map.packed_accessor32();
-  int limit_i = mat_a_tiles.sizes().end()[-2];
-  int limit_j = mat_b_tiles_map.sizes()[1];
-  int limit_k = mat_b_tiles_map.sizes()[0];
   at::Tensor result = at::zeros({mat_a_shape[0], mat_b_shape[1]},
                                 torch::device(torch::kCUDA));
-  cudaDeviceSetLimit(cudaLimitMallocHeapSize,
-                     size_t(1024) * size_t(1024) *
-                         size_t(cuda_malloc_heap_size));
-  dim3 grid;
-  dim3 block;
-  if (max_threads_dim[0] >= limit_i && max_threads_dim[1] >= limit_j &&
-      max_threads_dim[2] >= limit_k) {
-    // If multiple blocks are not required
-    grid = {(unsigned int)limit_i, (unsigned int)limit_j,
-            (unsigned int)limit_k};
-    block = {1, 1, 1};
-  } else {
-    // If multiple blocks are required
-    grid = {(unsigned int)max_threads_dim[0], (unsigned int)max_threads_dim[1],
-            (unsigned int)max_threads_dim[2]};
-    block = {(unsigned int)ceil_int_div(limit_i, max_threads_dim[0]),
-             (unsigned int)ceil_int_div(limit_j, max_threads_dim[1]),
-             (unsigned int)ceil_int_div(limit_k, max_threads_dim[2])};
-  }
   if (line_resistance == -1) {
+    int limit_i = mat_a_tiles.sizes().end()[-2];
+    int limit_j = mat_b_tiles_map.sizes()[1];
+    int limit_k = mat_b_tiles_map.sizes()[0];
+    dim3 grid;
+    dim3 block;
+    if (max_threads_dim[0] >= limit_i && max_threads_dim[1] >= limit_j &&
+        max_threads_dim[2] >= limit_k) {
+      // If multiple blocks are not required
+      grid = {(unsigned int)limit_i, (unsigned int)limit_j,
+              (unsigned int)limit_k};
+      block = {1, 1, 1};
+    } else {
+      // If multiple blocks are required
+      grid = {(unsigned int)max_threads_dim[0], (unsigned int)max_threads_dim[1],
+              (unsigned int)max_threads_dim[2]};
+      block = {(unsigned int)ceil_int_div(limit_i, max_threads_dim[0]),
+               (unsigned int)ceil_int_div(limit_j, max_threads_dim[1]),
+               (unsigned int)ceil_int_div(limit_k, max_threads_dim[2])};
+    }
     if (ADC_resolution == -1) {
       tile_matmul_kernel<<>>(
           mat_a_tiles_accessor, mat_a_tiles_map_accessor, mat_a_tiles_shape,
@@ -275,94 +157,31 @@ at::Tensor tile_matmul(at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map,
     int m = mat_b_tiles_shape_host[1];
     int n = mat_b_tiles_shape_host[2];
     int non_zero_elements = 8 * m * n - 2 * m - 2 * n;
-    int n_kernels = grid.x * block.x * grid.y * block.y * grid.z * block.z;
-    int *ABCD_matrix_indices_x;
-    int *ABCD_matrix_indices_y;
-    double *ABCD_matrix_values;
-    int *ABCD_matrix_compressed_columns;
-    int *ABCD_matrix_compressed_rows;
-    double *ABCD_matrix_compressed_values;
-    double *E_matrix;
-    cudaSafeCall(cudaMalloc(&ABCD_matrix_indices_x,
-                            sizeof(int) * non_zero_elements * n_kernels));
-    cudaSafeCall(cudaMalloc(&ABCD_matrix_indices_y,
-                            sizeof(int) * non_zero_elements * n_kernels));
-    cudaSafeCall(cudaMalloc(&ABCD_matrix_values,
-                            sizeof(double) * non_zero_elements * n_kernels));
-    cudaSafeCall(cudaMalloc(&ABCD_matrix_compressed_columns,
-                            sizeof(int) * (2 * n * m) * n_kernels));
-    cudaSafeCall(cudaMalloc(&ABCD_matrix_compressed_rows,
-                            sizeof(int) * non_zero_elements * n_kernels));
-    cudaSafeCall(cudaMalloc(&ABCD_matrix_compressed_values,
-                            sizeof(double) * non_zero_elements * n_kernels));
-    cudaSafeCall(
-        cudaMalloc(&E_matrix, sizeof(double) * (2 * m * n) * n_kernels));
-    tile_matmul_kernel_A<<>>(
-        mat_a_tiles_accessor, mat_a_tiles_map_accessor, mat_a_tiles_shape,
-        mat_b_tiles_accessor, mat_b_tiles_map_accessor, mat_b_tiles_shape,
-        mat_b_shape[1], ABCD_matrix_indices_x, ABCD_matrix_indices_y,
-        ABCD_matrix_values, ABCD_matrix_compressed_rows,
-        ABCD_matrix_compressed_columns, ABCD_matrix_compressed_values, E_matrix,
-        source_resistance, line_resistance, limit_i, limit_j, limit_k);
-    cudaSafeCall(cudaDeviceSynchronize());
-    cudaSafeCall(cudaFree(ABCD_matrix_indices_x));
-    cudaSafeCall(cudaFree(ABCD_matrix_indices_y));
-    cudaSafeCall(cudaFree(ABCD_matrix_values));
-    int *ABCD_matrix_compressed_rows_host =
-        (int *)malloc(sizeof(int) * non_zero_elements);
-    int *ABCD_matrix_compressed_columns_host =
-        (int *)malloc(sizeof(int) * (2 * m * n));
-    double *ABCD_matirx_compressed_values_host =
-        (double *)malloc(sizeof(double) * non_zero_elements);
-    double *E_matrix_host =
-        (double *)malloc(sizeof(double) * (2 * m * n) * n_kernels);
-    cudaSafeCall(cudaMemcpy(E_matrix_host, E_matrix,
-                            sizeof(double) * (2 * m * n) * n_kernels,
-                            cudaMemcpyDeviceToHost));
-#pragma omp parallel for
-    for (int i = 0; i < n_kernels; i++) {
-      cudaSafeCall(
-          cudaMemcpy(ABCD_matrix_compressed_rows_host,
-                     &ABCD_matrix_compressed_rows[i * non_zero_elements],
-                     sizeof(int) * non_zero_elements, cudaMemcpyDeviceToHost));
-      cudaSafeCall(cudaMemcpy(ABCD_matrix_compressed_columns_host,
-                              &ABCD_matrix_compressed_columns[i * (2 * n * m)],
-                              sizeof(int) * (2 * m * n),
-                              cudaMemcpyDeviceToHost));
-      cudaSafeCall(cudaMemcpy(
-          ABCD_matirx_compressed_values_host,
-          &ABCD_matrix_compressed_values[i * non_zero_elements],
-          sizeof(double) * non_zero_elements, cudaMemcpyDeviceToHost));
-      Eigen::Map> A(
-          (2 * m * n), (2 * m * n), non_zero_elements,
-          ABCD_matrix_compressed_columns_host, ABCD_matrix_compressed_rows_host,
-          ABCD_matirx_compressed_values_host);
-      solve_sparse_linear(A, &E_matrix_host[i * (2 * n * m)], 2 * m * n);
-    }
-    free(ABCD_matrix_compressed_rows_host);
-    free(ABCD_matrix_compressed_columns_host);
-    free(ABCD_matirx_compressed_values_host);
-    cudaSafeCall(cudaMemcpy(E_matrix, E_matrix_host,
-                            sizeof(double) * (2 * n * m) * n_kernels,
-                            cudaMemcpyHostToDevice));
-    free(E_matrix_host);
-    if (ADC_resolution == -1) {
-      tile_matmul_kernel_B<<>>(
-          E_matrix, mat_b_tiles_accessor, mat_b_tiles_map_accessor,
-          mat_b_tiles_shape, mat_b_shape[1], m, n, limit_i, limit_j, limit_k,
-          result.data_ptr());
-    } else {
-      tile_matmul_kernel_B<<>>(
-          E_matrix, mat_b_tiles_accessor, mat_b_tiles_map_accessor,
-          mat_b_tiles_shape, mat_b_shape[1], ADC_resolution, overflow_rate,
-          quant_method, m, n, limit_i, limit_j, limit_k,
-          result.data_ptr());
+    int mat_a_rows = mat_a_tiles.sizes().end()[-2];
+    at::Tensor partial_sum =
+        at::zeros({mat_b_tiles_map.sizes()[1], mat_b_tiles_shape_host[2]}, torch::device(torch::kCUDA));
+    at::Tensor V_BL = at::zeros(n, torch::device(torch::kCUDA));
+    for (int i = 0; i < mat_a_rows; i++) {
+      at::Tensor mat_a_row_tiles = mat_a_tiles.index({Slice(), i, Slice()});
+      for (int j = 0; j < mat_b_tiles_map.sizes()[0]; j++) {
+        at::Tensor tile_a = mat_a_row_tiles[mat_a_tiles_map[j].item()];
+        for (int k = 0; k < mat_b_tiles_map.sizes()[1]; k++) {
+          at::Tensor tile_b = mat_b_tiles[mat_b_tiles_map[j][k].item()];
+          partial_sum[k] +=
+              solve_passive(tile_b, tile_a, V_BL,
+                            ADC_resolution, overflow_rate, quant_method,
+                            source_resistance, line_resistance, true)
+                  .squeeze();
+        }
+        result.index_put_({i, Slice()}, result.index({i, Slice()}) +
+                                            partial_sum.flatten().index(
+                                                {Slice(0, mat_b_shape[1])}));
+        partial_sum = partial_sum.zero_();
+      }
     }
-    cudaSafeCall(cudaFree(E_matrix));
   }
   cudaSafeCall(cudaDeviceSynchronize());
   cudaSafeCall(cudaFree(mat_a_tiles_shape));
   cudaSafeCall(cudaFree(mat_b_tiles_shape));
-  cudaStreamSynchronize(at::cuda::getCurrentCUDAStream());
   return result;
 }
\ No newline at end of file
diff --git a/memtorch/cu/utils.cuh b/memtorch/cu/utils.cuh
index 82b6d17b..9b57dafc 100644
--- a/memtorch/cu/utils.cuh
+++ b/memtorch/cu/utils.cuh
@@ -1,3 +1,6 @@
+#ifndef _UTILS_
+#define _UTILS_
+
 #define cudaSafeCall(call)                                                     \
   do {                                                                         \
     cudaError_t err = call;                                                    \
@@ -39,7 +42,7 @@ template __host__ __device__ T abs_(T x) {
   return x;
 }
 
-template __device__ void sort_(T *tensor, int tensor_numel) {
+template __device__ __host__ void sort_(T *tensor, int tensor_numel) {
   T temp;
 #pragma omp parallel for
   for (int i = 0; i < tensor_numel; i++) {
@@ -62,4 +65,6 @@ inline __device__ int transform_3d_index(int x, int y, int z, int len_y,
   return x * len_y * len_z + y * len_z + z;
 }
 
-inline int ceil_int_div(int a, int b) { return (a + b - 1) / b; }
\ No newline at end of file
+inline int ceil_int_div(int a, int b) { return (a + b - 1) / b; }
+
+#endif
\ No newline at end of file
diff --git a/memtorch/mn/Conv1d.py b/memtorch/mn/Conv1d.py
index cd15d98c..777fcabd 100644
--- a/memtorch/mn/Conv1d.py
+++ b/memtorch/mn/Conv1d.py
@@ -311,10 +311,13 @@ def tune(self, input_batch_size=8, input_shape=32):
         )
 
     def __str__(self):
-        return "bh.Conv1d(in_channels=%d, out_channels=%d, kernel_size=%d, stride=%d, padding=%d)" % (
-            self.in_channels,
-            self.out_channels,
-            self.kernel_size[0],
-            self.stride[0],
-            self.padding[0],
+        return (
+            "bh.Conv1d(in_channels=%d, out_channels=%d, kernel_size=%d, stride=%d, padding=%d)"
+            % (
+                self.in_channels,
+                self.out_channels,
+                self.kernel_size[0],
+                self.stride[0],
+                self.padding[0],
+            )
         )
diff --git a/memtorch/mn/Conv2d.py b/memtorch/mn/Conv2d.py
index c448c4a0..3ff1cb46 100644
--- a/memtorch/mn/Conv2d.py
+++ b/memtorch/mn/Conv2d.py
@@ -347,13 +347,16 @@ def tune(self, input_batch_size=8, input_shape=32):
         )
 
     def __str__(self):
-        return "bh.Conv2d(in_channels=%d, out_channels=%d, kernel_size=(%d, %d), stride=(%d, %d), padding=(%d, %d))" % (
-            self.in_channels,
-            self.out_channels,
-            self.kernel_size[0],
-            self.kernel_size[1],
-            self.stride[0],
-            self.stride[1],
-            self.padding[0],
-            self.padding[1],
+        return (
+            "bh.Conv2d(in_channels=%d, out_channels=%d, kernel_size=(%d, %d), stride=(%d, %d), padding=(%d, %d))"
+            % (
+                self.in_channels,
+                self.out_channels,
+                self.kernel_size[0],
+                self.kernel_size[1],
+                self.stride[0],
+                self.stride[1],
+                self.padding[0],
+                self.padding[1],
+            )
         )
diff --git a/memtorch/mn/Conv3d.py b/memtorch/mn/Conv3d.py
index 1099296a..c95148aa 100644
--- a/memtorch/mn/Conv3d.py
+++ b/memtorch/mn/Conv3d.py
@@ -357,16 +357,19 @@ def tune(self, input_batch_size=4, input_shape=32):
         )
 
     def __str__(self):
-        return "bh.Conv3d(in_channels=%d, out_channels=%d, kernel_size=(%d, %d, %d), stride=(%d, %d, %d), padding=(%d, %d, %d))" % (
-            self.in_channels,
-            self.out_channels,
-            self.kernel_size[0],
-            self.kernel_size[1],
-            self.kernel_size[2],
-            self.stride[0],
-            self.stride[1],
-            self.stride[2],
-            self.padding[0],
-            self.padding[1],
-            self.padding[2],
+        return (
+            "bh.Conv3d(in_channels=%d, out_channels=%d, kernel_size=(%d, %d, %d), stride=(%d, %d, %d), padding=(%d, %d, %d))"
+            % (
+                self.in_channels,
+                self.out_channels,
+                self.kernel_size[0],
+                self.kernel_size[1],
+                self.kernel_size[2],
+                self.stride[0],
+                self.stride[1],
+                self.stride[2],
+                self.padding[0],
+                self.padding[1],
+                self.padding[2],
+            )
         )
diff --git a/setup.py b/setup.py
index 6170734f..0f9f29da 100644
--- a/setup.py
+++ b/setup.py
@@ -43,7 +43,7 @@ def create_version_py(version, CUDA):
                     "memtorch/submodules/eigen/",
                 ]
             ],
-            extra_compile_args=["-lineinfo"],
+            extra_compile_args=["-lineinfo", "-use_fast_math"],
         ),
         CppExtension(
             name="memtorch_bindings",
@@ -55,6 +55,7 @@ def create_version_py(version, CUDA):
                     "memtorch/submodules/eigen/",
                 ]
             ],
+            extra_compile_args=["-O3"],
         ),
     ]
     name = "memtorch"
@@ -72,6 +73,7 @@ def create_version_py(version, CUDA):
                     "memtorch/submodules/eigen/",
                 ]
             ],
+            extra_compile_args=["-O3"],
        )
    ]
    name = "memtorch-cpu"
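
Reference sketch of the ADC quantization math: the new det_sf_kernel/quantize_kernel pair in solve_passive_kernels.cu reuses the scale-factor and linear-quantization logic from quantize.cuh. A minimal NumPy rendering of that logic is given below for orientation; it assumes the descending-magnitude ordering produced by sort_, and it is illustrative only, not code from this patch.

    import numpy as np

    def det_sf(x, bits, overflow_rate):
        # Mirrors det_sf/det_integral: take the magnitude at the overflow index
        # of the descending-sorted tensor and derive the scale factor from it.
        sorted_abs = np.sort(np.abs(np.asarray(x, dtype=np.float32).ravel()))[::-1]
        k = min(int(round(overflow_rate * sorted_abs.size)), sorted_abs.size - 1)
        return 1 - bits + np.ceil(np.log2(sorted_abs[k] + 1e-12))

    def linear_quantize(x, sf, bits):
        # Mirrors the element-wise linear_quantize overload added to quantize.cuh:
        # clamp round(x / delta) to the signed integer range, then rescale.
        delta = 2.0 ** sf
        bound = 2.0 ** (bits - 1)
        q = np.clip(np.floor(np.asarray(x) / delta + 0.5), -bound, bound - 1) * delta
        return np.nan_to_num(q)  # NaN maps to 0.0f in the CUDA kernel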
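
With the updated solve_passive_bindings, ADC_resolution, overflow_rate and quant_method are exposed to Python alongside R_source and R_line. A hedged usage sketch follows; the extension module name (memtorch_cuda_bindings) and the resistance values are assumptions for illustration, not taken from this patch.

    import torch
    import memtorch_cuda_bindings as bindings  # module name assumed

    conductance_matrix = torch.rand(32, 32)  # example 32 x 32 crossbar
    V_WL = torch.rand(32)                    # word-line voltages
    V_BL = torch.zeros(32)                   # bit-line voltages

    # Arguments follow the fully explicit overload registered above.
    # ADC_resolution = -1 skips quantization; quant_method 0 = linear, 1 = log.
    I = bindings.solve_passive(
        conductance_matrix, V_WL, V_BL,
        8,     # ADC_resolution (bits)
        0.0,   # overflow_rate
        0,     # quant_method
        10.0,  # R_source (assumed value)
        5.0,   # R_line (assumed value)
        True,  # det_readout_currents
    )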
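
The OOM fix itself (#120) is the restructured host path in tile_matmul_kernels.cu: instead of allocating ABCD/E buffers for every tile pair at once and launching tile_matmul_kernel_A/_B, each (tile_b, tile_a) pair is now solved on its own via solve_passive, so only one sparse system is resident at a time. A rough Python rendering of that loop, assuming MemTorch's usual tile/tile-map layout, is sketched below; it is not the exact API.

    import torch

    def tile_matmul_passive(mat_a_tiles, mat_a_tiles_map, mat_b_tiles,
                            mat_b_tiles_map, mat_b_shape, solve_passive,
                            **adc_kwargs):
        # mat_a_tiles: (n_tiles, n_rows, tile_cols); mat_b_tiles: (n_tiles, tile_rows, tile_cols)
        result = torch.zeros(mat_a_tiles.shape[-2], mat_b_shape[1])
        partial = torch.zeros(mat_b_tiles_map.shape[1], mat_b_tiles.shape[-1])
        V_BL = torch.zeros(mat_b_tiles.shape[-1])
        for i in range(mat_a_tiles.shape[-2]):
            for j in range(mat_b_tiles_map.shape[0]):
                tile_a = mat_a_tiles[mat_a_tiles_map[j], i, :]
                for k in range(mat_b_tiles_map.shape[1]):
                    tile_b = mat_b_tiles[mat_b_tiles_map[j][k]]
                    # One nodal-analysis solve per tile pair.
                    partial[k] += solve_passive(tile_b, tile_a, V_BL, **adc_kwargs)
                result[i] += partial.flatten()[: mat_b_shape[1]]
                partial.zero_()
        return result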