Skip to content

Commit f3b50bc

Browse files
committed
issue/843: add description in kernel
1 parent 0e81911 commit f3b50bc

File tree

1 file changed

+52
-4
lines changed
  • src/infiniop/ops/quant/per_channel_quant_int8/cuda

1 file changed

+52
-4
lines changed

src/infiniop/ops/quant/per_channel_quant_int8/cuda/kernel.cuh

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,51 @@
22
#define __PERCHANNEL_QUANTINT8_KERNEL_CUH__
33

44
#include <cub/block/block_reduce.cuh>
5+
/**
6+
* Rounds a floating-point value to the nearest integer using
7+
* the "half away from zero" tie-breaking rule.
8+
*
9+
* This rounding mode rounds to the nearest whole number, with ties
10+
* (values exactly halfway between integers) rounded away from zero.
11+
* For positive numbers: 1.5 rounds to 2, 2.5 rounds to 3
12+
* For negative numbers: -1.5 rounds to -2, -2.5 rounds to -3
13+
* This differs from the standard "round to nearest, ties to even" (banker's rounding) mode.
14+
*
15+
* @param x The floating-point value to round.
16+
* @return The rounded integer value as an int.
17+
*
18+
* @note This is a CUDA device function designed to execute on GPU hardware.
19+
* @note Uses floorf() and fabsf() from the CUDA math library.
20+
*/
521
/**
 * Rounds a floating-point value to the nearest integer using the
 * "half away from zero" tie-breaking rule (1.5 -> 2, -1.5 -> -2).
 *
 * @param x The floating-point value to round.
 * @return The rounded value as an int.
 *
 * @note CUDA device function; requires the CUDA math library.
 * @note Result is undefined if the rounded value does not fit in an int
 *       (same as the previous fabsf/floorf implementation).
 */
__device__ inline int round_half_away_from_zero(float x) {
    // roundf() implements exactly "round to nearest, ties away from zero"
    // (C99 semantics), replacing the hand-rolled fabsf/floorf sequence.
    return (int)roundf(x);
}
10-
26+
/**
27+
* Performs per-channel asymmetric quantization to int8 precision for large matrices.
28+
*
29+
* This kernel quantizes input matrix x (M x K) to int8 using channel-wise (column-wise)
30+
* quantization parameters, optimized for cases where K >= 1024. Each channel (column)
31+
* has independently computed scale and zero point to minimize quantization error.
32+
*
33+
* The quantization follows: x_quantized = round((x - zero) / scale)
34+
* where zero points shift the range and scales normalize to int8 range [-128, 127].
35+
*
36+
* @tparam Tdata Input data type (typically float or half)
37+
* @tparam BLOCK_SIZE CUDA block size for thread cooperation
38+
*
39+
* @param x_packed Output buffer for packed int8 quantized values
40+
* @param x_scale Output buffer for per-channel scale factors
41+
* @param x_zero Output buffer for per-channel zero points
42+
* @param x Input matrix in row-major layout (M rows, K columns)
43+
* @param M Number of rows in input matrix
44+
* @param K Number of columns in input matrix (channels)
45+
*
46+
* @note This is a CUDA device function optimized for GPU execution
47+
* @note Designed for large channel dimensions (K >= 1024) to maximize parallelization
48+
* @note Uses block-level reductions for efficient min/max computation per channel
49+
*/
1150
template <typename Tdata, unsigned int BLOCK_SIZE>
1251
__device__ void blockPerChannelQuantI8Kernel(
1352
int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x,
@@ -72,7 +111,10 @@ __device__ void blockPerChannelQuantI8Kernel(
72111
x_packed[tid + ind] = (int8_t)q;
73112
}
74113
}
75-
114+
/**
115+
* Performs per-channel symmetric quantization to int8 for large matrices (K >= 1024).
116+
* Uses zero-centered scaling only, no zero point, and packs quantized data.
117+
*/
76118
template <typename Tdata, unsigned int BLOCK_SIZE>
77119
__device__ void blockPerChannelQuantI8SymKernel(
78120
int8_t *x_packed, float *x_scale, const Tdata *x,
@@ -145,7 +187,10 @@ __inline__ __device__ T WarpAllReduce(T val) {
145187
}
146188
return val;
147189
}
148-
190+
/**
191+
* Performs per-channel asymmetric quantization to int8 for matrices with small channel dimensions (K < 1024).
192+
* Computes scale/zero point per channel (column) and packs quantized data.
193+
*/
149194
template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
150195
__device__ void warpPerChannelQuantI8Kernel(
151196
int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x,
@@ -208,7 +253,10 @@ __device__ void warpPerChannelQuantI8Kernel(
208253
}
209254
}
210255
}
211-
256+
/**
257+
* Performs per-channel symmetric quantization to int8 for matrices with small channel dimensions (K < 1024).
258+
* Uses zero-centered scaling only, no zero point, and packs quantized data.
259+
*/
212260
template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
213261
__device__ void warpPerChannelQuantI8SymKernel(
214262
int8_t *x_packed, float *x_scale, const Tdata *x,

0 commit comments

Comments
 (0)