gptq_4bit
qgemm_4bit

cpu support, slow!

in_features

x bit support

fix windows

minor fix, weight index

rebase conflict
wejoncy committed Oct 11, 2024
1 parent 6ada97c commit 35c9d88
Showing 11 changed files with 1,224 additions and 0 deletions.
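
For context, the new kernels assume a GPTQ-style packing: each int32 of qweight holds 32 / bits quantized values along the K (in_features) dimension, so in_features = qweight.shape[0] * (32 / bits); zeros are packed the same way along the N dimension, and scales hold one fp16 value per groupsize rows and output column. Below is a minimal C++ sketch of expanding one packed word; the helper name and the 4-bit default are illustrative, not part of the commit.

#include <cstdint>
#include <vector>

// Illustrative only: expand one GPTQ-packed uint32 (bits = 4, so 32 / 4 = 8 values)
// into dequantized floats, mirroring the per-element work of DequantizeAndUnpackWeight.
std::vector<float> unpack_word(uint32_t packed, float scale, uint32_t zero_point, int bits = 4) {
  const uint32_t mask = (1u << bits) - 1;  // 0xF for 4-bit
  const int compress_ratio = 32 / bits;    // 8 values per uint32 for 4-bit
  std::vector<float> out(compress_ratio);
  for (int i = 0; i < compress_ratio; ++i) {
    uint32_t q = (packed >> (i * bits)) & mask;  // value i lives at bit offset i * bits
    out[i] = (static_cast<float>(q) - static_cast<float>(zero_point)) * scale;
  }
  return out;
}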
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -30,6 +30,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuantNbitsGemm);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, DequantizeAndUnpackWeight);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordConvEmbedding);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul); // backward compatibility
@@ -302,6 +304,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuantNbitsGemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, DequantizeAndUnpackWeight)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordConvEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND)>,
#if !defined(DISABLE_SPARSE_TENSORS)
75 changes: 75 additions & 0 deletions onnxruntime/contrib_ops/cpu/dequant_weight_unpack.cc
@@ -0,0 +1,75 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <cstdint>
#include <cstdio>
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/framework/tensor.h"
#include "core/framework/tensorprotoutils.h"


namespace onnxruntime {
namespace contrib {

class DequantizeAndUnpackWeight final : public OpKernel {
public:
explicit DequantizeAndUnpackWeight(const OpKernelInfo& info) : OpKernel{info} {
ORT_ENFORCE(info.GetAttr<int64_t>("bits", &bits_).IsOK());
ORT_ENFORCE(info.GetAttr<int64_t>("groupsize", &groupsize_).IsOK());
in_features_ = info.GetAttrOrDefault<int64_t>("in_features", -1);

ORT_ENFORCE(bits_ > 1 && bits_ < 9, "bits must be in range [2, 8]");
if (bits_ != 2 && bits_ != 4 && bits_ != 8 && in_features_ == -1) {
ORT_THROW("in_features must be specified for bits other than 2, 4, 8");
}
if (in_features_ == -1) {
const auto& node{Node()};
const auto& input_defs = node.InputDefs();
const NodeArg& X = *input_defs[0];
auto X_shape = utils::GetTensorShapeFromTensorShapeProto(*X.Shape());
in_features_ = X_shape[0] * (32 / bits_);
}
}

Status Compute(OpKernelContext* context) const override;

private:
template <typename T>
struct ComputeImpl;

int64_t bits_;
int64_t groupsize_;
int64_t in_features_;
};

ONNX_OPERATOR_KERNEL_EX(
DequantizeAndUnpackWeight,
kMSDomain,
1,
kCpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", BuildKernelDefConstraints<uint32_t, int32_t>()),
DequantizeAndUnpackWeight);

void DequantNbitWeight(OpKernelContext* ctx, const Tensor* input_weight, Tensor* output, const Tensor* input_zeros,
const Tensor* input_scale, const int64_t bits_, const int64_t compress_ratio,
const int64_t groupsize_);

Status DequantizeAndUnpackWeight::Compute(OpKernelContext* ctx) const {
const auto* input_weight = ctx->Input<Tensor>(0);
const auto* input_scale = ctx->Input<Tensor>(1);
const auto* input_zeros = ctx->Input<Tensor>(2);
// const auto* input_gidx = ctx->Input<Tensor>(5);
const auto& qweight_shape = input_weight->Shape();
const int64_t compress_ratio = sizeof(int32_t)*8 / bits_;
TensorShape output_shape = qweight_shape;
output_shape[0] = output_shape[0] * compress_ratio;
auto* output = ctx->Output(0, output_shape);
DequantNbitWeight(ctx, input_weight, output, input_zeros, input_scale, bits_, compress_ratio, groupsize_);

return Status::OK();
}

} // namespace contrib
} // namespace onnxruntime
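
As a concrete illustration of the shape relations above (hypothetical sizes, not taken from the commit): with bits = 4 and groupsize = 128, compress_ratio = 32 / 4 = 8, so a layer with in_features K = 256 and out_features N = 8 stores qweight as [32, 8] int32, scales as [2, 8] fp16, zeros as [2, 1] int32, and the op unpacks back to a [256, 8] float matrix. A small sketch that derives these shapes:

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical layer sizes; the relations mirror DequantizeAndUnpackWeight's shape math.
  const int64_t K = 256, N = 8, bits = 4, groupsize = 128;
  const int64_t compress_ratio = 32 / bits;  // values packed per int32
  std::printf("qweight: [%lld, %lld] int32\n", (long long)(K / compress_ratio), (long long)N);
  std::printf("scales : [%lld, %lld] fp16\n", (long long)(K / groupsize), (long long)N);
  std::printf("zeros  : [%lld, %lld] int32\n", (long long)(K / groupsize), (long long)(N / compress_ratio));
  std::printf("output : [%lld, %lld] float (= qweight rows * %lld)\n",
              (long long)K, (long long)N, (long long)compress_ratio);
  return 0;
}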
42 changes: 42 additions & 0 deletions onnxruntime/contrib_ops/cpu/dequant_weight_unpack_impl.cc
@@ -0,0 +1,42 @@
#include "core/framework/float16.h"
#include "core/platform/threadpool.h"
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/framework/tensor.h"

namespace onnxruntime {
namespace contrib {

void DequantNbitWeight(OpKernelContext* ctx, const Tensor* input_weight, Tensor* output, const Tensor* input_zeros,
const Tensor* input_scale, const int64_t bits_, const int64_t compress_ratio,
const int64_t groupsize_) {
  if (ctx == nullptr) return;  // nothing to do without an execution context
  const auto& qweight_shape = input_weight->Shape();
  const uint32_t* u32_in = reinterpret_cast<const uint32_t*>(input_weight->Data<int32_t>());
  float* f32_out = output->MutableData<float>();
  const uint32_t* u32_zeros = reinterpret_cast<const uint32_t*>(input_zeros->Data<int32_t>());
  const MLFloat16* f16_scale = input_scale->Data<MLFloat16>();

  const uint32_t mask = static_cast<uint32_t>((1u << bits_) - 1);  // e.g. 0xF for 4-bit
  int64_t task_count = qweight_shape[0];  // one task per packed row of qweight
  concurrency::ThreadPool::TryBatchParallelFor(
      ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
      [&](ptrdiff_t task_idx) {
        int64_t mi = task_idx;  // packed row index; expands into compress_ratio output rows
        // Assumes groupsize_ is a multiple of compress_ratio, so every value packed in one
        // word shares the same quantization group (and therefore the same scale/zero).
        int64_t group_idx = mi * compress_ratio / groupsize_;
        for (int64_t ki = 0; ki < qweight_shape[1]; ki++) {
          uint32_t u32_weight = u32_in[mi * qweight_shape[1] + ki];
          // zeros are packed along the column dimension: one word per compress_ratio columns
          uint32_t u32_zero = u32_zeros[group_idx * (qweight_shape[1] / compress_ratio) + ki / compress_ratio];
          uint32_t zero_point = (u32_zero >> ((ki % compress_ratio) * bits_)) & mask;
          float f32_scale_val = (f16_scale[group_idx * qweight_shape[1] + ki]).ToFloat();
          float scale_zero = f32_scale_val * zero_point;
          for (int64_t w_idx = 0; w_idx < compress_ratio; w_idx++) {
            f32_out[(mi * compress_ratio + w_idx) * qweight_shape[1] + ki] =
                (u32_weight & mask) * f32_scale_val - scale_zero;
            u32_weight = u32_weight >> bits_;
          }
        }
      },
      0);
}

} // namespace contrib
} // namespace onnxruntime
72 changes: 72 additions & 0 deletions onnxruntime/contrib_ops/cpu/quant_nbit_gemm.cc
@@ -0,0 +1,72 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <cstdio>
#include <vector>
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/framework/tensor.h"

namespace onnxruntime {
namespace contrib {

class QuantNbitsGemm final : public OpKernel {
public:
explicit QuantNbitsGemm(const OpKernelInfo& info) : OpKernel{info} {
//ORT_ENFORCE(info.GetAttr("outfeatures", &outfeatures_).IsOK());
//ORT_ENFORCE(info.GetAttr("infeatures", &in_features_).IsOK());
bits_ = info.GetAttrOrDefault<int64_t>("bits", 3);
groupsize_ = info.GetAttrOrDefault<int64_t>("groupsize", 128);
}

Status Compute(OpKernelContext* context) const override;

private:

template <typename T>
struct ComputeImpl;

  int64_t outfeatures_ = -1;  // attribute read currently disabled above
  int64_t in_features_ = -1;  // attribute read currently disabled above
int64_t bits_;
int64_t groupsize_;
};

ONNX_OPERATOR_KERNEL_EX(
QuantNbitsGemm,
kMSDomain,
1,
kCpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", BuildKernelDefConstraints<float, MLFloat16>()),
QuantNbitsGemm);


Status QuantNbitsGemm::Compute(OpKernelContext* ctx) const {
  const auto* input_x = ctx->Input<Tensor>(0);
  const auto* input_weight = ctx->Input<Tensor>(1);
  // const auto* input_scale = ctx->Input<Tensor>(2);
  const auto* input_zeros = ctx->Input<Tensor>(3);
  // const auto* input_bias = ctx->Input<Tensor>(4);
  // const auto* input_gidx = ctx->Input<Tensor>(5);
  const auto& input_shape = input_x->Shape();
  const auto& weight_shape = input_weight->Shape();
  TensorShapeVector output_shape = input_shape.AsShapeVector();
  output_shape[output_shape.size() - 1] = weight_shape[1];
  auto* output = ctx->Output(0, output_shape);
  auto batch = input_shape[0] * (input_shape.NumDimensions() > 2 ? input_shape[1] : 1);
  // int64_t in_features = input_shape[input_shape.NumDimensions() - 1];

  // NOTE: the CPU path is still a debug placeholder ("cpu support, slow!"); it only
  // allocates the output and prints a few values, it does not compute the matmul yet.
  printf("%lld,%lld\n", static_cast<long long>(batch), static_cast<long long>(output->Shape()[1]));

  size_t sz = weight_shape[0] * weight_shape[1] * 2;
  std::vector<int32_t> buf(sz);
  printf("%d...%d,", input_weight->Data<int32_t>()[0], input_zeros->Data<int32_t>()[0]);

  return Status::OK();
}

} // namespace contrib
} // namespace onnxruntime
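
Since the CPU Compute above only shapes the output for now, here is a naive sketch of what the quantized GEMM would compute once implemented: dequantize each weight on the fly and accumulate, Y[b][n] = sum_k X[b][k] * dequant(W[k][n]). The function name, fp32 accumulation, and 4-bit GPTQ layout assumptions are illustrative, not the committed implementation.

#include <cstdint>
#include <vector>

// Naive reference: assumes GPTQ packing with qweight [K/cr, N], scales [K/groupsize, N],
// zeros [K/groupsize, N/cr], where cr = 32 / bits.
void qgemm_ref(const float* X, const uint32_t* qweight, const float* scales,
               const uint32_t* qzeros, float* Y,
               int64_t batch, int64_t K, int64_t N, int64_t bits, int64_t groupsize) {
  const uint32_t mask = (1u << bits) - 1;
  const int64_t cr = 32 / bits;  // values per packed int32
  for (int64_t b = 0; b < batch; ++b) {
    for (int64_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int64_t k = 0; k < K; ++k) {
        uint32_t packed = qweight[(k / cr) * N + n];
        uint32_t q = (packed >> ((k % cr) * bits)) & mask;          // weight packed along K
        int64_t g = k / groupsize;                                   // quantization group
        uint32_t z = (qzeros[g * (N / cr) + n / cr] >> ((n % cr) * bits)) & mask;  // zeros packed along N
        float w = (static_cast<float>(q) - static_cast<float>(z)) * scales[g * N + n];
        acc += X[b * K + k] * w;
      }
      Y[b * N + n] = acc;
    }
  }
}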
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -144,6 +144,8 @@ class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MatMulNBits);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, MatMulBnb4);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MatMulBnb4);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MatMulBnb4);
class CUDA_MS_OP_CLASS_NAME(1, QuantNbitsGemm);
class CUDA_MS_OP_CLASS_NAME(1, DequantizeAndUnpackWeight);
class CUDA_MS_OP_CLASS_NAME(1, Trilu);
class CUDA_MS_OP_CLASS_NAME(1, UnfoldTensor);
class CUDA_MS_OP_CLASS_NAME(1, DynamicTimeWarping);
@@ -348,6 +350,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, MatMulBnb4)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MatMulBnb4)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MatMulBnb4)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QuantNbitsGemm)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, DequantizeAndUnpackWeight)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, BiasSoftmax)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, BiasDropout)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, BitmaskDropout)>,
79 changes: 79 additions & 0 deletions onnxruntime/contrib_ops/cuda/dequant_weight_unpack.cc
@@ -0,0 +1,79 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cuda/cuda_kernel.h"
#include "core/framework/tensorprotoutils.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

class DequantizeAndUnpackWeight final : public ::onnxruntime::cuda::CudaKernel {
public:
explicit DequantizeAndUnpackWeight(const OpKernelInfo& info) : CudaKernel{info} {
ORT_ENFORCE(info.GetAttr<int64_t>("bits", &bits_).IsOK());
ORT_ENFORCE(info.GetAttr<int64_t>("groupsize", &group_size_).IsOK());
in_features_ = info.GetAttrOrDefault<int64_t>("in_features", -1);

ORT_ENFORCE(bits_ > 1 && bits_ < 9, "bits must be in range [2, 8]");
if (bits_ != 2 && bits_ != 4 && bits_ != 8 && in_features_ == -1) {
ORT_THROW("in_features must be specified for bits other than 2, 4, 8");
}
if (in_features_ == -1) {
const auto& node{Node()};
const auto& input_defs = node.InputDefs();
const NodeArg& X = *input_defs[0];
in_features_ = X.Shape()->dim(0).dim_value() * (32 / bits_);
}
}

Status ComputeInternal(OpKernelContext* context) const override;

private:
using Base = CudaKernel;
int64_t bits_;
int64_t group_size_;
int64_t in_features_;
};

ONNX_OPERATOR_KERNEL_EX(
DequantizeAndUnpackWeight,
kMSDomain,
1,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", BuildKernelDefConstraints<int32_t>())
.TypeConstraint("T1", BuildKernelDefConstraints<MLFloat16>()),
DequantizeAndUnpackWeight);

void DequantWeightNbit(
cudaStream_t stream,
const int32_t* qweight_i32,
const void* scales_data,
const int32_t* zeros_data,
void* weight_out,
uint32_t MATRIX_K,
uint32_t MATRIX_N,
uint32_t bits,
uint32_t groupsize);

Status DequantizeAndUnpackWeight::ComputeInternal(OpKernelContext* ctx) const {
const auto* qweight = ctx->Input<Tensor>(0);
const auto* input_scale = ctx->Input<Tensor>(1);
const auto* input_zeros = ctx->Input<Tensor>(2);

auto output_shape = qweight->Shape();
output_shape[0] = in_features_;

auto* output = ctx->Output(0, output_shape);
DequantWeightNbit(Stream(ctx), qweight->Data<int32_t>(),
input_scale->Data<MLFloat16>(),
input_zeros->Data<int32_t>(),
output->MutableData<MLFloat16>(),
in_features_, output_shape[1], bits_, group_size_);
return Status::OK();
}

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime