Skip to content

Commit da487c2

Browse files
committed
issue/1083: gptq_marlin_gemm
1 parent f7b2511 commit da487c2

23 files changed

+8418
-1
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#ifndef __INFINIOP_GPTQ_MARLIN_GEMM_API_H__
2+
#define __INFINIOP_GPTQ_MARLIN_GEMM_API_H__
3+
4+
#include "../operator_descriptor.h"
5+
#include <cstdint>
6+
7+
typedef struct InfiniopDescriptor *infiniopGptqMarlinGemmDescriptor_t;
8+
9+
__INFINI_C __export infiniStatus_t infiniopCreateGptqMarlinGemmDescriptor(infiniopHandle_t handle,
10+
infiniopGptqMarlinGemmDescriptor_t *desc_ptr,
11+
infiniopTensorDescriptor_t out_desc,
12+
infiniopTensorDescriptor_t a_desc,
13+
infiniopTensorDescriptor_t b_desc,
14+
infiniopTensorDescriptor_t b_scales_desc,
15+
infiniopTensorDescriptor_t global_scale_desc,
16+
infiniopTensorDescriptor_t b_zeros_desc,
17+
infiniopTensorDescriptor_t g_idx_desc,
18+
infiniopTensorDescriptor_t perm_desc);
19+
20+
__INFINI_C __export infiniStatus_t infiniopGetGptqMarlinGemmWorkspaceSize(infiniopGptqMarlinGemmDescriptor_t desc, size_t *size);
21+
22+
__INFINI_C __export infiniStatus_t infiniopGptqMarlinGemm(infiniopGptqMarlinGemmDescriptor_t desc,
23+
void *workspace,
24+
size_t workspace_size,
25+
void *out,
26+
const void *a,
27+
const void *b,
28+
void *b_scales,
29+
void *global_scale,
30+
void *b_zeros,
31+
void *g_idx,
32+
void *perm,
33+
int64_t b_q_type_id,
34+
bool is_k_full,
35+
bool use_atomic_add,
36+
bool use_fp32_reduce,
37+
bool is_zp_float,
38+
void *stream);
39+
40+
__INFINI_C __export infiniStatus_t infiniopDestroyGptqMarlinGemmDescriptor(infiniopGptqMarlinGemmDescriptor_t desc);
41+
42+
#endif
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#ifndef __GPTQ_MARLIN_GEMM_H__
2+
#define __GPTQ_MARLIN_GEMM_H__
3+
4+
#include "../../../utils.h"
5+
#include "../../operator.h"
6+
#include "../../tensor.h"
7+
#include "info.h"
8+
9+
#define DESCRIPTOR(NAMESPACE) \
10+
\
11+
namespace op::gptq_marlin_gemm::NAMESPACE { \
12+
class Descriptor final : public InfiniopDescriptor { \
13+
struct Opaque; \
14+
Opaque *_opaque; \
15+
GptqMarlinGemmInfo _info; \
16+
size_t _workspace_size; \
17+
\
18+
Descriptor( \
19+
size_t workspace_size_, \
20+
Opaque *opaque, \
21+
GptqMarlinGemmInfo info, \
22+
infiniDevice_t device_type, \
23+
int device_id) \
24+
: InfiniopDescriptor{device_type, device_id}, \
25+
_opaque(opaque), \
26+
_info(info), \
27+
_workspace_size(workspace_size_) {} \
28+
\
29+
public: \
30+
~Descriptor(); \
31+
\
32+
size_t workspaceSize() const { return _workspace_size; } \
33+
\
34+
static infiniStatus_t create( \
35+
infiniopHandle_t handle, \
36+
Descriptor **desc_ptr, \
37+
infiniopTensorDescriptor_t out_desc, \
38+
infiniopTensorDescriptor_t a_desc, \
39+
infiniopTensorDescriptor_t b_desc, \
40+
infiniopTensorDescriptor_t b_scales_desc, \
41+
infiniopTensorDescriptor_t global_scale_desc, \
42+
infiniopTensorDescriptor_t b_zeros_desc, \
43+
infiniopTensorDescriptor_t g_idx_desc, \
44+
infiniopTensorDescriptor_t perm_desc); \
45+
\
46+
infiniStatus_t calculate( \
47+
void *workspace, \
48+
size_t workspace_size, \
49+
void *out, \
50+
const void *a, \
51+
const void *b, \
52+
void *b_scales, \
53+
void *global_scale, \
54+
void *b_zeros, \
55+
void *g_idx, \
56+
void *perm, \
57+
int64_t b_q_type_id, \
58+
bool is_k_full, \
59+
bool use_atomic_add, \
60+
bool use_fp32_reduce, \
61+
bool is_zp_float, \
62+
void *stream) const; \
63+
}; \
64+
}
65+
66+
#endif //__GPTQ_MARLIN_GEMM_H__
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#ifndef __GPTQ_MARLIN_GEMM_INFO_H__
2+
#define __GPTQ_MARLIN_GEMM_INFO_H__
3+
4+
#include "../../../utils.h"
5+
#include "../../tensor.h"
6+
#include <vector>
7+
8+
#include <cassert>
9+
10+
namespace op::gptq_marlin_gemm {
11+
12+
class GptqMarlinGemmInfo {
13+
GptqMarlinGemmInfo() = default;
14+
15+
public:
16+
infiniDtype_t dtype;
17+
size_t M, K, N, b_q_size_1;
18+
int num_groups;
19+
ptrdiff_t a_stride_0;
20+
21+
static utils::Result<GptqMarlinGemmInfo> create(
22+
infiniopTensorDescriptor_t out_desc,
23+
infiniopTensorDescriptor_t a_desc,
24+
infiniopTensorDescriptor_t b_desc,
25+
infiniopTensorDescriptor_t b_scales_desc,
26+
infiniopTensorDescriptor_t global_scale_desc,
27+
infiniopTensorDescriptor_t b_zeros_desc,
28+
infiniopTensorDescriptor_t g_idx_desc,
29+
infiniopTensorDescriptor_t perm_desc) {
30+
CHECK_OR_RETURN(
31+
out_desc != nullptr && a_desc != nullptr && b_desc != nullptr && b_scales_desc != nullptr,
32+
INFINI_STATUS_NULL_POINTER);
33+
const infiniDtype_t dtype = a_desc->dtype();
34+
size_t M = out_desc->dim(0);
35+
size_t N = out_desc->dim(1);
36+
size_t K = a_desc->dim(1);
37+
size_t b_q_size_1 = b_desc->dim(1);
38+
int num_groups = static_cast<int>(b_scales_desc->dim(0));
39+
ptrdiff_t a_stride_0 = a_desc->strides()[0];
40+
41+
auto ndim = out_desc->ndim();
42+
CHECK_OR_RETURN(ndim == 2
43+
&& a_desc->ndim() == ndim
44+
&& b_desc->ndim() == ndim
45+
&& b_scales_desc->ndim() == ndim,
46+
INFINI_STATUS_BAD_TENSOR_SHAPE);
47+
48+
CHECK_OR_RETURN(b_scales_desc->shape()[1] == N
49+
&& a_stride_0 % 8 == 0,
50+
INFINI_STATUS_BAD_TENSOR_SHAPE);
51+
52+
return utils::Result<GptqMarlinGemmInfo>(
53+
GptqMarlinGemmInfo{dtype, M, K, N, b_q_size_1, num_groups, a_stride_0});
54+
}
55+
};
56+
57+
} // namespace op::gptq_marlin_gemm
58+
59+
#endif // __GPTQ_MARLIN_GEMM_INFO_H__

0 commit comments

Comments
 (0)