Skip to content

Commit

Permalink
Fix inline C kernels
Browse files Browse the repository at this point in the history
This fixes multiple issues in the inline C implementation of reference
approximate logarithmic multipliers
  • Loading branch information
etrommer committed Dec 6, 2023
1 parent b655c53 commit 60c36a1
Showing 1 changed file with 20 additions and 15 deletions.
35 changes: 20 additions & 15 deletions src/torchapprox/operators/kernels/cuda/ta_gemm_cuda.cu
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "ta_gemm_cuda.h"

const auto BLOCK_SIZE = 16;
const auto K = 3;
const auto K = 5;

__device__ inline int32_t DRUM(int16_t op1, int16_t op2) {
if (op1 == 0 || op2 == 0)
Expand Down Expand Up @@ -46,7 +46,7 @@ __device__ inline int32_t DRUM(int16_t op1, int16_t op2) {
return y;
}

__device__ inline int32_t mitchell_trunc(int16_t op1, int16_t op2, uint8_t w) {
__device__ inline int32_t mitchell_trunc(int16_t op1, int16_t op2) {
// Same as DRUM, except that the lowest non-truncated bit position is not
// de-biased by setting it to one.
if (op1 == 0 || op2 == 0)
Expand All @@ -56,25 +56,30 @@ __device__ inline int32_t mitchell_trunc(int16_t op1, int16_t op2, uint8_t w) {
if (op2 == -1)
return -op1;

const bool sgn1 = op1 > 0;
const bool sgn2 = op2 > 0;
// Sign extraction
const bool sgn1 = op1 < 0;
const bool sgn2 = op2 < 0;

uint32_t abs1 = sgn1 ? -(op1)-1 : op1;
uint32_t abs2 = sgn2 ? -(op2)-1 : op2;
uint32_t abs1 = sgn1 ? -op1 : op1;
uint32_t abs2 = sgn2 ? -op2 : op2;

const auto lz1 = 31 - __clz(abs1);
const auto lz2 = 31 - __clz(abs2);
// Find leading one
const auto lead1_1 = 31 - __clz(abs1);
const auto lead1_2 = 31 - __clz(abs2);

const auto mask = (1 << w) - 1;
if (lz1 > w) {
abs1 &= (mask << (lz1 - w));
// Mask with the lowest `K` bits set, zero otherwise
const auto mask = (1 << K) - 1;
if (lead1_1 > K) {
// Truncate to the most-significant `K` bits
abs1 &= (mask << (lead1_1 - K + 1));
}
if (lz2 > w) {
abs2 &= (mask << (lz2 - w));
if (lead1_2 > K) {
abs2 &= (mask << (lead1_2 - K + 1));
}

auto y0 = abs1 * abs2;
auto y = (sgn1 ^ sgn2) ? -y0 : y0;

return y;
}

Expand Down Expand Up @@ -123,7 +128,7 @@ ta_gemm_kernel(cudaTextureObject_t tex,
auto i1 = a_shared[threadIdx.y][n];
auto i2 = b_shared[threadIdx.x][n];
/* auto val = lut_operator<uint8_t>(tex, i1, i2);*/
auto val = DRUM((int16_t)i1, (int16_t)i2);
auto val = mitchell_trunc((int16_t)i1, (int16_t)i2);
acc += val;
}
__syncthreads();
Expand Down Expand Up @@ -171,7 +176,7 @@ ta_gemm_kernel_batchb(cudaTextureObject_t tex,
auto i1 = a_shared[threadIdx.y][n];
auto i2 = b_shared[threadIdx.x][n];
/* auto val = lut_operator<uint8_t>(tex, i2, i1);*/
auto val = DRUM((int16_t)i2, (int16_t)i1);
auto val = mitchell_trunc((int16_t)i2, (int16_t)i1);
acc += val;
}
__syncthreads();
Expand Down

0 comments on commit 60c36a1

Please sign in to comment.