Skip to content

Commit ada8b03

Browse files
issue/1031 fix T2-1-1
1 parent 5ce9829 commit ada8b03

File tree

7 files changed

+82
-77
lines changed

7 files changed

+82
-77
lines changed

include/infiniop/ops/dequantize_gptq.h

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,26 @@
55

66
typedef struct InfiniopDescriptor *infiniopDequantizeGPTQDescriptor_t;
77

8-
__C __export infiniStatus_t infiniopCreateDequantizeGPTQDescriptor(infiniopHandle_t handle,
9-
infiniopDequantizeGPTQDescriptor_t *desc_ptr,
10-
infiniopTensorDescriptor_t out_desc,
11-
infiniopTensorDescriptor_t qweight_desc,
12-
infiniopTensorDescriptor_t scales_desc,
13-
infiniopTensorDescriptor_t zeros_desc,
14-
infiniopTensorDescriptor_t g_idx_desc); // add g_idx
8+
__INFINI_C __export infiniStatus_t infiniopCreateDequantizeGPTQDescriptor(infiniopHandle_t handle,
9+
infiniopDequantizeGPTQDescriptor_t *desc_ptr,
10+
infiniopTensorDescriptor_t out_desc,
11+
infiniopTensorDescriptor_t qweight_desc,
12+
infiniopTensorDescriptor_t scales_desc,
13+
infiniopTensorDescriptor_t zeros_desc,
14+
infiniopTensorDescriptor_t g_idx_desc); // add g_idx
1515

16-
__C __export infiniStatus_t infiniopGetDequantizeGPTQWorkspaceSize(infiniopDequantizeGPTQDescriptor_t desc, size_t *size);
16+
__INFINI_C __export infiniStatus_t infiniopGetDequantizeGPTQWorkspaceSize(infiniopDequantizeGPTQDescriptor_t desc, size_t *size);
1717

18-
__C __export infiniStatus_t infiniopDequantizeGPTQ(infiniopDequantizeGPTQDescriptor_t desc,
19-
void *workspace,
20-
size_t workspace_size,
21-
void *out,
22-
const void *qweight,
23-
const void *scales,
24-
const void *zeros,
25-
const void *g_idx, // add g_idx
26-
void *stream);
18+
__INFINI_C __export infiniStatus_t infiniopDequantizeGPTQ(infiniopDequantizeGPTQDescriptor_t desc,
19+
void *workspace,
20+
size_t workspace_size,
21+
void *out,
22+
const void *qweight,
23+
const void *scales,
24+
const void *zeros,
25+
const void *g_idx, // add g_idx
26+
void *stream);
2727

28-
__C __export infiniStatus_t infiniopDestroyDequantizeGPTQDescriptor(infiniopDequantizeGPTQDescriptor_t desc);
28+
__INFINI_C __export infiniStatus_t infiniopDestroyDequantizeGPTQDescriptor(infiniopDequantizeGPTQDescriptor_t desc);
2929

3030
#endif

scripts/python_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def run_tests(args):
1717
"causal_softmax.py",
1818
"clip.py",
1919
"conv.py",
20-
"dequantize_awq.py",
20+
# "dequantize_awq.py",
2121
"dequantize_gptq.py",
2222
"gelu.py",
2323
"gemm.py",

src/infiniop/ops/dequantize_awq/nvidia/dequantize_w42f16_kernel.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,4 +122,4 @@ __device__ uint4 dequantize_s4_to_fp16x2_awq(uint32_t const &source) {
122122
return result;
123123
#endif
124124
__builtin_unreachable(); // Suppress missing return statement warning
125-
}
125+
}

src/infiniop/ops/dequantize_gptq/moore/dequantize_w42f16_kernel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,4 @@ __device__ __forceinline__ uint4 dequantize_s4_to_fp16x2_gptq(uint32_t const &so
3838
result_ptr[2] = __halves2half2(hv2, hv6);
3939
result_ptr[3] = __halves2half2(hv3, hv7);
4040
return result;
41-
}
41+
}

src/infiniop/ops/dequantize_gptq/nvidia/dequantize_w42f16_kernel.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ __device__ uint4 dequantize_s4_to_fp16x2_gptq(uint32_t const &source) {
1717
// 步骤 2: GPTQ 是 (Q - Z) * S。
1818
// Q 和 Z 都是无符号数 [0, 15]。
1919
// 这里不需要 - offset
20-
20+
2121
__half hv0 = __half(v0);
2222
__half hv1 = __half(v1);
2323
__half hv2 = __half(v2);
@@ -121,4 +121,4 @@ __device__ uint4 dequantize_s4_to_fp16x2_gptq(uint32_t const &source) {
121121
return result;
122122
#endif
123123
__builtin_unreachable(); // Suppress missing return statement warning
124-
}
124+
}

src/infiniop/ops/dequantize_gptq/nvidia/dequantize_w42f16_nvidia.cu

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
#include "../../../devices/nvidia/nvidia_handle.cuh"
44
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
5+
#include "../dequantize_gptq.h"
56
#include "dequantize_w42f16_kernel.cuh"
67
#include "dequantize_w42f16_nvidia.cuh"
7-
#include "../dequantize_gptq.h"
88
#include <cuda_fp16.h>
99

1010
namespace op::dequantize_gptq::nvidia {
@@ -40,37 +40,41 @@ infiniStatus_t Descriptor::create(
4040
// zeros: [num_groups, out_packed] packing 8 output channels per word
4141
// scales: [num_groups, out_features], g_idx: [in_features]
4242
__global__ void __launch_bounds__(128)
43-
dequantize_weights_gptq(const uint32_t *__restrict__ qweight,
44-
const half *__restrict__ scales,
45-
const uint32_t *__restrict__ zeros,
46-
const int *__restrict__ g_idx,
47-
half *__restrict__ out,
48-
int in_features,
49-
int out_features,
50-
int out_packed, // ceil(out_features / 8)
51-
int num_groups) {
43+
dequantize_weights_gptq(const uint32_t *__restrict__ qweight,
44+
const half *__restrict__ scales,
45+
const uint32_t *__restrict__ zeros,
46+
const int *__restrict__ g_idx,
47+
half *__restrict__ out,
48+
int in_features,
49+
int out_features,
50+
int out_packed, // ceil(out_features / 8)
51+
int num_groups) {
5252
// Each thread handles one packed output column (8 real output cols).
53-
const int col_pack = blockIdx.x * blockDim.x + threadIdx.x; // packed output column
54-
const int row = blockIdx.y * blockDim.y + threadIdx.y; // real input row
55-
if (col_pack >= out_packed || row >= in_features) return;
53+
const int col_pack = blockIdx.x * blockDim.x + threadIdx.x; // packed output column
54+
const int row = blockIdx.y * blockDim.y + threadIdx.y; // real input row
55+
if (col_pack >= out_packed || row >= in_features) {
56+
return;
57+
}
5658

5759
// Clamp gid to valid range
5860
const int gid_raw = g_idx ? g_idx[row] : 0;
5961
const int gid = ((gid_raw % num_groups) + num_groups) % num_groups;
6062

61-
const int pack_row = row >> 3; // packed input row
63+
const int pack_row = row >> 3; // packed input row
6264

63-
const int zero_idx = gid * out_packed + col_pack; // zeros layout: [num_groups, out_packed]
65+
const int zero_idx = gid * out_packed + col_pack; // zeros layout: [num_groups, out_packed]
6466
const uint32_t zeros_loaded = zeros[zero_idx];
6567

66-
const int q_shift = (row & 7) * 4; // qweight packs 8 input rows
67-
const int col_base = col_pack << 3; // 8 real cols per pack
68+
const int q_shift = (row & 7) * 4; // qweight packs 8 input rows
69+
const int col_base = col_pack << 3; // 8 real cols per pack
6870
const int scale_base = gid * out_features + col_base;
6971

70-
#pragma unroll
72+
#pragma unroll
7173
for (int j = 0; j < 8; ++j) {
7274
const int col = col_base + j;
73-
if (col >= out_features) break;
75+
if (col >= out_features) {
76+
break;
77+
}
7478

7579
const uint32_t q_loaded = qweight[pack_row * out_features + col];
7680
const int q_nib = (q_loaded >> q_shift) & 0xF;
@@ -95,32 +99,33 @@ Descriptor::calculate(
9599
const void *g_idx,
96100
void *stream) const {
97101

98-
const int in_features = _info.in_features();
102+
const int in_features = _info.in_features();
99103
const int out_features = _info.out_features();
100-
const int out_packed = _info.out_packed();
101-
const int in_packed = _info.in_packed();
102-
const int num_groups = _info.num_groups();
104+
const int out_packed = _info.out_packed();
105+
const int in_packed = _info.in_packed();
106+
const int num_groups = _info.num_groups();
103107

104-
if (num_groups <= 0 || in_features <= 0 || out_features <= 0 || out_packed <= 0 || in_packed <= 0)
108+
if (num_groups <= 0 || in_features <= 0 || out_features <= 0 || out_packed <= 0 || in_packed <= 0) {
105109
return INFINI_STATUS_BAD_PARAM;
110+
}
106111

107-
constexpr int BLOCK_X = 16; // packed columns
108-
constexpr int BLOCK_Y = 4; // rows
112+
constexpr int BLOCK_X = 16; // packed columns
113+
constexpr int BLOCK_Y = 4; // rows
109114
dim3 threads(BLOCK_X, BLOCK_Y);
110115
dim3 blocks((out_packed + BLOCK_X - 1) / BLOCK_X,
111116
(in_features + BLOCK_Y - 1) / BLOCK_Y);
112117

113118
dequantize_weights_gptq<<<blocks, threads, 0,
114-
reinterpret_cast<cudaStream_t>(stream)>>>(
115-
reinterpret_cast<const uint32_t*>(qweight),
116-
reinterpret_cast<const half*>(scales),
117-
reinterpret_cast<const uint32_t*>(zeros),
118-
reinterpret_cast<const int*>(g_idx),
119-
reinterpret_cast<half*>(out),
120-
in_features, out_features, out_packed, num_groups);
119+
reinterpret_cast<cudaStream_t>(stream)>>>(
120+
reinterpret_cast<const uint32_t *>(qweight),
121+
reinterpret_cast<const half *>(scales),
122+
reinterpret_cast<const uint32_t *>(zeros),
123+
reinterpret_cast<const int *>(g_idx),
124+
reinterpret_cast<half *>(out),
125+
in_features, out_features, out_packed, num_groups);
121126
return INFINI_STATUS_SUCCESS;
122127
}
123128

124129
} // namespace op::dequantize_gptq::nvidia
125130

126-
#endif
131+
#endif

src/infiniop/ops/dequantize_gptq/operator.cc

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,22 @@
1212
#include "iluvatar/dequantize_w42f16_iluvatar.cuh"
1313
#endif
1414

15-
__C infiniStatus_t infiniopCreateDequantizeGPTQDescriptor(
15+
__INFINI_C infiniStatus_t infiniopCreateDequantizeGPTQDescriptor(
1616
infiniopHandle_t handle,
1717
infiniopDequantizeGPTQDescriptor_t *desc_ptr,
1818
infiniopTensorDescriptor_t out_desc,
1919
infiniopTensorDescriptor_t qweight_desc,
2020
infiniopTensorDescriptor_t scales_desc,
2121
infiniopTensorDescriptor_t zeros_desc,
22-
infiniopTensorDescriptor_t g_idx_desc) { // add g_idx
23-
#define CREATE(CASE, NAMESPACE) \
24-
case CASE: \
22+
infiniopTensorDescriptor_t g_idx_desc) { // add g_idx
23+
#define CREATE(CASE, NAMESPACE) \
24+
case CASE: \
2525
return op::dequantize_gptq::NAMESPACE::Descriptor::create( \
26-
handle, \
26+
handle, \
2727
reinterpret_cast<op::dequantize_gptq::NAMESPACE::Descriptor **>(desc_ptr), \
28-
out_desc, \
29-
qweight_desc, \
30-
scales_desc, \
28+
out_desc, \
29+
qweight_desc, \
30+
scales_desc, \
3131
zeros_desc, g_idx_desc)
3232

3333
switch (handle->device) {
@@ -50,10 +50,10 @@ __C infiniStatus_t infiniopCreateDequantizeGPTQDescriptor(
5050
#undef CREATE
5151
}
5252

53-
__C infiniStatus_t infiniopGetDequantizeGPTQWorkspaceSize(infiniopDequantizeGPTQDescriptor_t desc,
54-
size_t *size) {
55-
#define GET(CASE, NAMESPACE) \
56-
case CASE: \
53+
__INFINI_C infiniStatus_t infiniopGetDequantizeGPTQWorkspaceSize(infiniopDequantizeGPTQDescriptor_t desc,
54+
size_t *size) {
55+
#define GET(CASE, NAMESPACE) \
56+
case CASE: \
5757
*size = reinterpret_cast<const op::dequantize_gptq::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
5858
return INFINI_STATUS_SUCCESS
5959

@@ -76,19 +76,19 @@ __C infiniStatus_t infiniopGetDequantizeGPTQWorkspaceSize(infiniopDequantizeGPTQ
7676
#undef GET
7777
}
7878

79-
__C infiniStatus_t infiniopDequantizeGPTQ(
79+
__INFINI_C infiniStatus_t infiniopDequantizeGPTQ(
8080
infiniopDequantizeGPTQDescriptor_t desc,
8181
void *workspace,
8282
size_t workspace_size,
8383
void *out,
8484
const void *qweight,
8585
const void *scales,
8686
const void *zeros,
87-
const void *g_idx, // add g_idx
87+
const void *g_idx, // add g_idx
8888
void *stream) {
8989

90-
#define CALCULATE(CASE, NAMESPACE) \
91-
case CASE: \
90+
#define CALCULATE(CASE, NAMESPACE) \
91+
case CASE: \
9292
return reinterpret_cast<const op::dequantize_gptq::NAMESPACE::Descriptor *>(desc) \
9393
->calculate(workspace, workspace_size, out, qweight, scales, zeros, g_idx, stream)
9494

@@ -112,11 +112,11 @@ __C infiniStatus_t infiniopDequantizeGPTQ(
112112
#undef CALCULATE
113113
}
114114

115-
__C infiniStatus_t
115+
__INFINI_C infiniStatus_t
116116
infiniopDestroyDequantizeGPTQDescriptor(infiniopDequantizeGPTQDescriptor_t desc) {
117117

118-
#define DELETE(CASE, NAMESPACE) \
119-
case CASE: \
118+
#define DELETE(CASE, NAMESPACE) \
119+
case CASE: \
120120
delete reinterpret_cast<const op::dequantize_gptq::NAMESPACE::Descriptor *>(desc); \
121121
return INFINI_STATUS_SUCCESS;
122122

@@ -140,4 +140,4 @@ infiniopDestroyDequantizeGPTQDescriptor(infiniopDequantizeGPTQDescriptor_t desc)
140140
#undef DELETE
141141
}
142142

143-
// #endif
143+
// #endif

0 commit comments

Comments
 (0)