Skip to content

Commit 4a7a759

Browse files
committed
issue/1052: debug per tensor
1 parent 7444e42 commit 4a7a759

File tree

6 files changed

+55
-63
lines changed

6 files changed

+55
-63
lines changed

src/infiniop/ops/dequant/per_tensor_dequant_int8/cuda/kernel.cuh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ __device__ void perTensorDequantI8SymKernel(
1212
unsigned int gid = blockIdx.x * blockDim.x + threadIdx.x;
1313
const int grid_size = blockDim.x * gridDim.x;
1414
float x_scale_val = x_scale[0];
15-
for (int tid = gid; tid < num_elements; tid += grid_size) {
15+
for (int ind = gid; ind < num_elements; ind += grid_size) {
16+
int tid = ind;
1617
int w = tid % (int)width;
1718
tid = tid / (int)width;
1819

src/infiniop/ops/quant/per_tensor_quant_int8/cuda/kernel.cuh

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,21 @@
99

1010
#define FULL_MASK 0xffffffff
1111

12-
1312
// warp reduce max
14-
__device__ __forceinline__ float warpReduceMax(float val)
15-
{
16-
for (int offset = WARP_SIZE/2; offset > 0; offset /= 2)
13+
__device__ __forceinline__ float warpReduceMax(float val) {
14+
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
1715
val = fmaxf(val, __shfl_xor_sync(FULL_MASK, val, offset));
16+
}
1817
return val;
1918
}
2019

21-
2220
// float atomic max (safe version)
23-
__device__ __forceinline__ void atomicMaxFloat(float* addr, float val)
24-
{
25-
int* addr_i = (int*)addr;
21+
__device__ __forceinline__ void atomicMaxFloat(float *addr, float val) {
22+
int *addr_i = (int *)addr;
2623
int old = *addr_i;
2724
int assumed;
2825

29-
do
30-
{
26+
do {
3127
assumed = old;
3228
float old_f = __int_as_float(assumed);
3329
float new_f = fmaxf(val, old_f);
@@ -48,15 +44,15 @@ __device__ void perTensorAbsmaxSymKernel(float *x_scale, const Tdata *x,
4844
size_t batch_size, size_t channel, size_t hidden_dim, size_t width,
4945
ptrdiff_t strides_0, ptrdiff_t strides_1, ptrdiff_t strides_2, ptrdiff_t strides_3,
5046
int num_elements) {
51-
int tid = threadIdx.x;
52-
int gid = blockIdx.x * blockDim.x + tid;
47+
int idx = threadIdx.x;
48+
int gid = blockIdx.x * blockDim.x + idx;
5349
int grid_size = blockDim.x * gridDim.x;
5450

5551
float local_max = 0.f;
5652

5753
// grid-stride loop
58-
for (int tid = gid; tid < num_elements; tid += grid_size)
59-
{
54+
for (int ind = gid; ind < num_elements; ind += grid_size) {
55+
int tid = ind;
6056
int w = tid % (int)width;
6157
tid = tid / (int)width;
6258

@@ -78,11 +74,9 @@ __device__ void perTensorAbsmaxSymKernel(float *x_scale, const Tdata *x,
7874
// warp reduction
7975
local_max = warpReduceMax(local_max);
8076
// 每个 warp 只 atomic 一次
81-
if ((tid & (WARP_SIZE - 1)) == 0)
82-
{
77+
if ((idx & (WARP_SIZE - 1)) == 0) {
8378
atomicMaxFloat(x_scale, local_max / 127.0f);
8479
}
85-
8680
}
8781

8882
template <typename Tdata, unsigned int BLOCK_SIZE>
@@ -98,7 +92,8 @@ __device__ void perTensorQuantI8SymKernel(
9892

9993
float scale_val = 1.0f / x_scale[0];
10094

101-
for (int tid = gid; tid < num_elements; tid += grid_size) {
95+
for (int ind = gid; ind < num_elements; ind += grid_size) {
96+
int tid = ind;
10297
int w = tid % (int)width;
10398
tid = tid / (int)width;
10499

test/infiniop/libinfiniop/op_register.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,7 @@ def per_tensor_dequant_int8_(lib):
833833
infiniopOperatorDescriptor_t,
834834
]
835835

836+
836837
@OpRegister.operator
837838
def softplus_(lib):
838839
lib.infiniopCreateSoftplusDescriptor.restype = c_int32
@@ -1127,47 +1128,43 @@ def scaled_mm_int8_(lib):
11271128
]
11281129

11291130

1130-
11311131
@OpRegister.operator
11321132
def kv_caching_(lib):
11331133
lib.infiniopCreateKVCachingDescriptor.restype = c_int32
11341134
lib.infiniopCreateKVCachingDescriptor.argtypes = [
11351135
infiniopHandle_t,
11361136
POINTER(infiniopOperatorDescriptor_t),
1137-
infiniopTensorDescriptor_t,
1138-
infiniopTensorDescriptor_t,
1139-
infiniopTensorDescriptor_t,
1140-
infiniopTensorDescriptor_t,
1141-
infiniopTensorDescriptor_t,
1137+
infiniopTensorDescriptor_t,
1138+
infiniopTensorDescriptor_t,
1139+
infiniopTensorDescriptor_t,
1140+
infiniopTensorDescriptor_t,
1141+
infiniopTensorDescriptor_t,
11421142
]
11431143

1144-
11451144
lib.infiniopGetKVCachingWorkspaceSize.restype = c_int32
11461145
lib.infiniopGetKVCachingWorkspaceSize.argtypes = [
11471146
infiniopOperatorDescriptor_t,
11481147
POINTER(c_size_t),
11491148
]
11501149

1151-
11521150
lib.infiniopKVCaching.restype = c_int32
11531151
lib.infiniopKVCaching.argtypes = [
11541152
infiniopOperatorDescriptor_t,
1155-
c_void_p,
1156-
c_size_t,
1157-
c_void_p,
1158-
c_void_p,
1159-
c_void_p,
1160-
c_void_p,
1161-
c_void_p,
1162-
c_void_p,
1153+
c_void_p,
1154+
c_size_t,
1155+
c_void_p,
1156+
c_void_p,
1157+
c_void_p,
1158+
c_void_p,
1159+
c_void_p,
1160+
c_void_p,
11631161
]
11641162

1165-
11661163
lib.infiniopDestroyKVCachingDescriptor.restype = c_int32
11671164
lib.infiniopDestroyKVCachingDescriptor.argtypes = [
11681165
infiniopOperatorDescriptor_t,
11691166
]
1170-
1167+
11711168

11721169
@OpRegister.operator
11731170
def paged_attention_(lib):

test/infiniop/per_channel_quant_int8.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def per_token_quant_int8_torch(x, symmetric):
7878

7979
return w_packed, w_scale, w_zero
8080

81+
8182
def test(
8283
handle,
8384
device,
@@ -86,12 +87,12 @@ def test(
8687
dtype=InfiniDtype.F16,
8788
sync=None,
8889
):
89-
90+
9091
print(
9192
f"Testing Per Channel Quant Int8 on {InfiniDeviceNames[device]} with x_shape:{x_shape}, symmetric:{symmetric} , dtype:{InfiniDtypeNames[dtype]}"
9293
)
9394
M, K = x_shape
94-
95+
9596
x = TestTensor(x_shape, None, dtype, device)
9697
x_p, x_s, x_z = per_token_quant_int8_torch(x.torch_tensor(), symmetric)
9798
x_packed = TestTensor(x_shape, None, InfiniDtype.I8, device, mode="zeros")
@@ -129,7 +130,7 @@ def test(
129130
)
130131
)
131132
workspace = TestWorkspace(workspace_size.value, x.device)
132-
133+
133134
def lib_per_channel_quant_int8():
134135
check_error(
135136
LIBINFINIOP.infiniopPerChannelQuantI8(
@@ -145,7 +146,7 @@ def lib_per_channel_quant_int8():
145146
)
146147

147148
lib_per_channel_quant_int8()
148-
149+
149150
if sync is not None:
150151
sync()
151152

@@ -157,12 +158,15 @@ def lib_per_channel_quant_int8():
157158
debug(x_zero.actual_tensor(), x_z, atol=atol, rtol=rtol)
158159

159160
if symmetric:
160-
assert (torch.allclose(x_packed.actual_tensor(), x_p, atol=2, rtol=0) and
161-
torch.allclose(x_scale.actual_tensor(), x_s, atol=atol, rtol=rtol))
161+
assert torch.allclose(
162+
x_packed.actual_tensor(), x_p, atol=2, rtol=0
163+
) and torch.allclose(x_scale.actual_tensor(), x_s, atol=atol, rtol=rtol)
162164
else:
163-
assert (torch.allclose(x_packed.actual_tensor(), x_p, atol=2, rtol=0) and
164-
torch.allclose(x_scale.actual_tensor(), x_s, atol=atol, rtol=rtol) and
165-
torch.allclose(x_zero.actual_tensor(), x_z, atol=atol, rtol=rtol))
165+
assert (
166+
torch.allclose(x_packed.actual_tensor(), x_p, atol=2, rtol=0)
167+
and torch.allclose(x_scale.actual_tensor(), x_s, atol=atol, rtol=rtol)
168+
and torch.allclose(x_zero.actual_tensor(), x_z, atol=atol, rtol=rtol)
169+
)
166170

167171
# Profiling workflow
168172
if PROFILE:
@@ -185,5 +189,5 @@ def lib_per_channel_quant_int8():
185189

186190
for device in get_test_devices(args):
187191
test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
188-
192+
189193
print("\033[92mTest passed!\033[0m")

test/infiniop/per_tensor_dequant_int8.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,8 @@
3131
((16, 5632), (13312, 1), (13312, 1), True),
3232
((4, 4, 5632), None, None, True),
3333
((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), True),
34-
((1, 1, 8, 1), None, None, True),
35-
((1, 8, 32, 32), None, None, True),
36-
((8, 16, 64, 128), (8388608, 524288, 8192, 1), None, True),
37-
((1, 2, 2304, 128), (589824, 294912, 128, 1), (589824, 294912, 128, 1), True),
34+
((1, 4, 132, 128), (67584, 16896, 128, 1), (67584, 16896, 128, 1), True),
35+
((1, 4, 132, 128), None, None, True),
3836
]
3937

4038

test/infiniop/per_tensor_quant_int8.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,8 @@
3131
((16, 5632), (13312, 1), (13312, 1), True, True),
3232
((4, 4, 5632), None, None, True, False),
3333
((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), True, True),
34-
((1, 1, 8, 1), None, None, True, False),
35-
((1, 8, 32, 32), None, None, True, True),
36-
((8, 16, 64, 128), (8388608, 524288, 8192, 1), None, True, False),
37-
((1, 2, 2304, 128), (589824, 294912, 128, 1), (589824, 294912, 128, 1), True, True),
34+
((1, 32, 4, 128), (147456, 4608, 128, 1), (147456, 4608, 128, 1), True, False),
35+
((1, 32, 4, 128), (16384, 512, 128, 1), (16384, 512, 128, 1), True, True),
3836
]
3937

4038

@@ -61,8 +59,8 @@ def per_tensor_quant_int8_torch(x, x_scale, symmetric, is_static):
6159
x = x.float()
6260
if is_static:
6361
x_q = x.mul(1 / x_scale)
64-
x_q = torch.round(x_q).to(torch.int8)
65-
return x_q, x_scale, None
62+
x_packed = torch.clamp(x_q, -127, 127).to(torch.int8)
63+
return x_packed, x_scale, None
6664
else:
6765
absmax = x.flatten().abs().max()
6866
if absmax == 0:
@@ -71,9 +69,8 @@ def per_tensor_quant_int8_torch(x, x_scale, symmetric, is_static):
7169
return q, scale, None
7270
scale = absmax / 127
7371
x_q = x.mul(127 / absmax)
74-
x_q = torch.round(x_q).to(torch.int8)
75-
76-
return x_q, scale, None
72+
x_packed = torch.clamp(x_q, -127, 127).to(torch.int8)
73+
return x_packed, scale, None
7774

7875

7976
def test(
@@ -154,17 +151,17 @@ def lib_per_tensor_quant_int8():
154151
)
155152

156153
lib_per_tensor_quant_int8()
157-
154+
158155
if sync is not None:
159156
sync()
160-
157+
161158
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
162159
if DEBUG:
163160
debug(x_packed.actual_tensor(), x_p, atol=2, rtol=0)
164161
debug(x_scale.actual_tensor(), x_s, atol=atol, rtol=rtol)
165162
if symmetric == False:
166163
debug(x_zero.actual_tensor(), x_z, atol=atol, rtol=rtol)
167-
164+
168165
if symmetric:
169166
assert torch.allclose(
170167
x_packed.actual_tensor(), x_p, atol=2, rtol=0

0 commit comments

Comments
 (0)