-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
qgemm_4bit cpu support, slow! in_features x bit support fix windows minor fix, weight index rebase conflict
- Loading branch information
Showing
11 changed files
with
1,224 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include <cstdint> | ||
#include <cstdio> | ||
#include "core/common/common.h" | ||
#include "core/framework/op_kernel.h" | ||
#include "core/framework/tensor.h" | ||
#include "core/framework/tensorprotoutils.h" | ||
|
||
|
||
namespace onnxruntime { | ||
namespace contrib { | ||
|
||
class DequantizeAndUnpackWeight final : public OpKernel { | ||
public: | ||
explicit DequantizeAndUnpackWeight(const OpKernelInfo& info) : OpKernel{info} { | ||
ORT_ENFORCE(info.GetAttr<int64_t>("bits", &bits_).IsOK()); | ||
ORT_ENFORCE(info.GetAttr<int64_t>("groupsize", &groupsize_).IsOK()); | ||
in_features_ = info.GetAttrOrDefault<int64_t>("in_features", -1); | ||
|
||
ORT_ENFORCE(bits_ > 1 && bits_ < 9, "bits must be in range [2, 8]"); | ||
if (bits_ != 2 && bits_ != 4 && bits_ != 8 && in_features_ == -1) { | ||
ORT_THROW("in_features must be specified for bits other than 2, 4, 8"); | ||
} | ||
if (in_features_ == -1) { | ||
const auto& node{Node()}; | ||
const auto& input_defs = node.InputDefs(); | ||
const NodeArg& X = *input_defs[0]; | ||
auto X_shape = utils::GetTensorShapeFromTensorShapeProto(*X.Shape()); | ||
in_features_ = X_shape[0] * (32 / bits_); | ||
} | ||
} | ||
|
||
Status Compute(OpKernelContext* context) const override; | ||
|
||
private: | ||
template <typename T> | ||
struct ComputeImpl; | ||
|
||
int64_t bits_; | ||
int64_t groupsize_; | ||
int64_t in_features_; | ||
}; | ||
|
||
ONNX_OPERATOR_KERNEL_EX( | ||
DequantizeAndUnpackWeight, | ||
kMSDomain, | ||
1, | ||
kCpuExecutionProvider, | ||
(*KernelDefBuilder::Create()) | ||
.TypeConstraint("T", BuildKernelDefConstraints<uint32_t, int32_t>()), | ||
DequantizeAndUnpackWeight); | ||
|
||
void DequantNbitWeight(OpKernelContext* ctx, const Tensor* input_weight, Tensor* output, const Tensor* input_zeros, | ||
const Tensor* input_scale, const int64_t bits_, const int64_t compress_ratio, | ||
const int64_t groupsize_); | ||
|
||
/// Allocates the unpacked output and delegates the work to DequantNbitWeight.
///
/// Inputs: 0 = packed weight, 1 = scale, 2 = zero points.
/// Output 0 has the packed weight's shape with dimension 0 replaced by
/// `in_features_`, matching the CUDA implementation of this operator.
///
/// @return Status::OK() on success, or a failure status when an input is missing
///         or the shapes are inconsistent with `in_features_`.
Status DequantizeAndUnpackWeight::Compute(OpKernelContext* ctx) const {
  const auto* input_weight = ctx->Input<Tensor>(0);
  const auto* input_scale = ctx->Input<Tensor>(1);
  const auto* input_zeros = ctx->Input<Tensor>(2);
  ORT_RETURN_IF_NOT(input_weight != nullptr && input_scale != nullptr && input_zeros != nullptr,
                    "DequantizeAndUnpackWeight requires weight, scale and zero-point inputs");

  const auto& qweight_shape = input_weight->Shape();
  ORT_RETURN_IF_NOT(qweight_shape.NumDimensions() == 2, "packed weight must be a 2-D tensor");

  // Number of values the dequantization routine extracts from one 32-bit word.
  const int64_t compress_ratio = int64_t{32} / bits_;

  // The packed rows must hold at least in_features_ values of bits_ bits each;
  // otherwise part of the output could never be filled.
  const int64_t packed_capacity = qweight_shape[0] * int64_t{32} / bits_;
  ORT_RETURN_IF_NOT(in_features_ > 0 && in_features_ <= packed_capacity,
                    "in_features (", in_features_, ") is inconsistent with the packed weight shape");

  TensorShape output_shape = qweight_shape;
  output_shape[0] = in_features_;
  auto* output = ctx->Output(0, output_shape);
  ORT_RETURN_IF_NOT(output != nullptr, "failed to allocate the output tensor");

  DequantNbitWeight(ctx, input_weight, output, input_zeros, input_scale, bits_, compress_ratio, groupsize_);

  return Status::OK();
}
|
||
} // namespace contrib | ||
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
#include "core/framework/float16.h" | ||
#include "core/platform/threadpool.h" | ||
#include "core/common/common.h" | ||
#include "core/framework/op_kernel.h" | ||
#include "core/framework/tensor.h" | ||
|
||
namespace onnxruntime { | ||
namespace contrib { | ||
|
||
/// Unpacks an n-bit weight tensor stored in 32-bit words into float values.
///
/// Each word of `input_weight` yields `compress_ratio` values of `bits_` bits,
/// lowest bits first, written to consecutive output rows of the same column:
///   out[row, col] = q * scale - scale * zero
/// with `row = packed_row * compress_ratio + lane` and `group = row / groupsize_`.
///
/// @param ctx            kernel context; supplies the operator thread pool.
/// @param input_weight   packed weights, read as int32 words; dims 0 and 1 are used.
/// @param output         float destination; rows at or beyond its dim 0 are skipped.
/// @param input_zeros    packed zero points, read as int32 words.
/// @param input_scale    MLFloat16 scales, indexed as [group, col].
/// @param bits_          width in bits of one packed value.
/// @param compress_ratio number of values taken from each 32-bit word.
/// @param groupsize_     number of unpacked rows sharing one scale/zero row.
///
/// NOTE(review): values that straddle a word boundary (bit widths that do not
/// divide 32) are not reassembled here. Confirm whether the CPU path must
/// support them.
void DequantNbitWeight(OpKernelContext* ctx, const Tensor* input_weight, Tensor* output, const Tensor* input_zeros,
                       const Tensor* input_scale, const int64_t bits_, const int64_t compress_ratio,
                       const int64_t groupsize_) {
  // The previous guard `if (ctx) return;` skipped all work for every valid
  // context, leaving the output unwritten, and dereferenced a null one.
  ORT_ENFORCE(ctx != nullptr && input_weight != nullptr && output != nullptr && input_zeros != nullptr &&
                  input_scale != nullptr,
              "DequantNbitWeight requires non-null context and tensors");
  ORT_ENFORCE(bits_ > 0 && bits_ <= 32 && compress_ratio > 0 && groupsize_ > 0,
              "bits, compress_ratio and groupsize must be positive");

  const auto& qweight_shape = input_weight->Shape();
  const int64_t packed_rows = qweight_shape[0];
  const int64_t cols = qweight_shape[1];
  const int64_t out_rows = output->Shape()[0];

  const uint32_t* u32_in = reinterpret_cast<const uint32_t*>(input_weight->Data<int32_t>());
  float* f32_out = output->MutableData<float>();
  const uint32_t* u32_zeros = reinterpret_cast<const uint32_t*>(input_zeros->Data<int32_t>());
  const MLFloat16* f16_scale = input_scale->Data<MLFloat16>();

  // Mask selecting one packed value; derived from bits_ rather than fixed at 4 bits.
  const uint32_t value_mask = bits_ >= 32 ? 0xFFFFFFFFu : ((uint32_t{1} << bits_) - 1u);

  // One task per packed row. Each task writes only rows
  // [mi * compress_ratio, (mi + 1) * compress_ratio), so tasks never overlap.
  concurrency::ThreadPool::TryBatchParallelFor(
      ctx->GetOperatorThreadPool(), static_cast<ptrdiff_t>(packed_rows),
      [&](ptrdiff_t task_idx) {
        const int64_t mi = task_idx;
        const int64_t first_row = mi * compress_ratio;
        for (int64_t ki = 0; ki < cols; ki++) {
          uint32_t u32_weight = u32_in[mi * cols + ki];
          // Bit offset of this column's zero point inside its packed word.
          const int64_t zero_shift = (ki % compress_ratio) * bits_;
          for (int64_t w_idx = 0; w_idx < compress_ratio; w_idx++) {
            const int64_t row = first_row + w_idx;
            if (row >= out_rows) {
              break;  // the output may have fewer rows than the packed capacity
            }
            const int64_t group = row / groupsize_;
            const uint32_t u32_zero = u32_zeros[group * cols / compress_ratio + ki / compress_ratio];
            const uint32_t zero = (u32_zero >> zero_shift) & value_mask;
            const float f32_scale_val = f16_scale[group * cols + ki].ToFloat();
            const float scale_zero = f32_scale_val * static_cast<float>(zero);
            f32_out[row * cols + ki] = static_cast<float>(u32_weight & value_mask) * f32_scale_val - scale_zero;
            u32_weight >>= bits_;
          }
        }
      },
      0);
}
|
||
} // namespace contrib | ||
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include <cstdio> | ||
#include "core/common/common.h" | ||
#include "core/framework/op_kernel.h" | ||
#include "core/framework/tensor.h" | ||
|
||
namespace onnxruntime { | ||
namespace contrib { | ||
|
||
class QuantNbitsGemm final : public OpKernel { | ||
public: | ||
explicit QuantNbitsGemm(const OpKernelInfo& info) : OpKernel{info} { | ||
//ORT_ENFORCE(info.GetAttr("outfeatures", &outfeatures_).IsOK()); | ||
//ORT_ENFORCE(info.GetAttr("infeatures", &in_features_).IsOK()); | ||
bits_ = info.GetAttrOrDefault<int64_t>("bits", 3); | ||
groupsize_ = info.GetAttrOrDefault<int64_t>("groupsize", 128); | ||
} | ||
|
||
Status Compute(OpKernelContext* context) const override; | ||
|
||
private: | ||
|
||
template <typename T> | ||
struct ComputeImpl; | ||
|
||
int64_t outfeatures_; | ||
int64_t in_features_; | ||
int64_t bits_; | ||
int64_t groupsize_; | ||
}; | ||
|
||
// Register the CPU kernel for QuantNbitsGemm in kMSDomain (version 1),
// constraining "T" to float / MLFloat16.
ONNX_OPERATOR_KERNEL_EX(QuantNbitsGemm, kMSDomain, 1, kCpuExecutionProvider,
                        (*KernelDefBuilder::Create())
                            .TypeConstraint("T", BuildKernelDefConstraints<float, MLFloat16>()),
                        QuantNbitsGemm);
|
||
|
||
/// Work-in-progress entry point for the n-bit GEMM.
///
/// Reads inputs 0 (activations), 1 (packed weight) and 3 (zero points),
/// allocates output 0 with the activation shape whose last dimension is replaced
/// by `weight_shape[1]`, and prints a few diagnostic values.
///
/// NOTE(review): no GEMM is performed and the output tensor is never written.
/// Confirm whether this stub should report a not-implemented status instead of
/// returning OK.
///
/// @return Status::OK(), or a failure status when an input is missing or has too
///         few dimensions.
Status QuantNbitsGemm::Compute(OpKernelContext* ctx) const {
  const auto* input_x = ctx->Input<Tensor>(0);
  const auto* input_weight = ctx->Input<Tensor>(1);
  // Inputs 2 (scale), 4 (bias) and 5 (g_idx) are not consumed yet.
  const auto* input_zeros = ctx->Input<Tensor>(3);
  ORT_RETURN_IF_NOT(input_x != nullptr && input_weight != nullptr && input_zeros != nullptr,
                    "QuantNbitsGemm requires activation, weight and zero-point inputs");

  const auto& input_shape = input_x->Shape();
  const auto& weight_shape = input_weight->Shape();
  // Both shapes are indexed at [1] below, so at least two dimensions are needed.
  ORT_RETURN_IF_NOT(input_shape.NumDimensions() >= 2 && weight_shape.NumDimensions() >= 2,
                    "QuantNbitsGemm expects activation and weight tensors with at least 2 dimensions");

  TensorShapeVector output_shape = input_shape.AsShapeVector();
  output_shape[output_shape.size() - 1] = weight_shape[1];
  auto* output = ctx->Output(0, output_shape);
  ORT_RETURN_IF_NOT(output != nullptr, "failed to allocate the output tensor");

  const int64_t batch = input_shape[0] * (input_shape.NumDimensions() > 2 ? input_shape[1] : 1);

  // Result intentionally unused; kept from the original stub.
  // NOTE(review): presumably this rejects non-fp16 activations even though the
  // kernel is also registered for float; confirm against Tensor::Data.
  input_x->Data<MLFloat16>();

  // The values are int64_t, so they are printed as long long; "%zu" expects size_t.
  printf("%lld,%lld\n", static_cast<long long>(batch), static_cast<long long>(output->Shape()[1]));
  printf("%d...%d,", input_weight->Data<int32_t>()[0], input_zeros->Data<int32_t>()[0]);

  return Status::OK();
}
|
||
} // namespace contrib | ||
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include "core/providers/cuda/cuda_kernel.h" | ||
#include "core/framework/tensorprotoutils.h" | ||
|
||
namespace onnxruntime { | ||
namespace contrib { | ||
namespace cuda { | ||
|
||
class DequantizeAndUnpackWeight final : public ::onnxruntime::cuda::CudaKernel { | ||
public: | ||
explicit DequantizeAndUnpackWeight(const OpKernelInfo& info) : CudaKernel{info} { | ||
ORT_ENFORCE(info.GetAttr<int64_t>("bits", &bits_).IsOK()); | ||
ORT_ENFORCE(info.GetAttr<int64_t>("groupsize", &group_size_).IsOK()); | ||
in_features_ = info.GetAttrOrDefault<int64_t>("in_features", -1); | ||
|
||
ORT_ENFORCE(bits_ > 1 && bits_ < 9, "bits must be in range [2, 8]"); | ||
if (bits_ != 2 && bits_ != 4 && bits_ != 8 && in_features_ == -1) { | ||
ORT_THROW("in_features must be specified for bits other than 2, 4, 8"); | ||
} | ||
if (in_features_ == -1) { | ||
const auto& node{Node()}; | ||
const auto& input_defs = node.InputDefs(); | ||
const NodeArg& X = *input_defs[0]; | ||
in_features_ = X.Shape()->dim(0).dim_value() * (32 / bits_); | ||
} | ||
} | ||
|
||
Status ComputeInternal(OpKernelContext* context) const override; | ||
|
||
private: | ||
using Base = CudaKernel; | ||
int64_t bits_; | ||
int64_t group_size_; | ||
int64_t in_features_; | ||
}; | ||
|
||
ONNX_OPERATOR_KERNEL_EX( | ||
DequantizeAndUnpackWeight, | ||
kMSDomain, | ||
1, | ||
kCudaExecutionProvider, | ||
(*KernelDefBuilder::Create()) | ||
.TypeConstraint("T", BuildKernelDefConstraints<int32_t>()) | ||
.TypeConstraint("T1", BuildKernelDefConstraints<MLFloat16>()), | ||
DequantizeAndUnpackWeight); | ||
|
||
void DequantWeightNbit( | ||
cudaStream_t stream, | ||
const int32_t* qweight_i32, | ||
const void* scales_data, | ||
const int32_t* zeros_data, | ||
void* weight_out, | ||
uint32_t MATRIX_K, | ||
uint32_t MATRIX_N, | ||
uint32_t bits, | ||
uint32_t groupsize); | ||
|
||
/// Allocates the unpacked output and launches DequantWeightNbit on the kernel's stream.
///
/// Inputs: 0 = packed weight (int32), 1 = scale (MLFloat16), 2 = zero points (int32).
/// Output 0 (MLFloat16) has the packed weight's shape with dimension 0 replaced
/// by `in_features_`.
Status DequantizeAndUnpackWeight::ComputeInternal(OpKernelContext* ctx) const {
  const Tensor* packed_weight = ctx->Input<Tensor>(0);
  const Tensor* scales = ctx->Input<Tensor>(1);
  const Tensor* zero_points = ctx->Input<Tensor>(2);

  // Start from the packed shape and substitute the unpacked row count.
  TensorShape unpacked_shape = packed_weight->Shape();
  unpacked_shape[0] = in_features_;
  Tensor* unpacked = ctx->Output(0, unpacked_shape);

  // The launcher takes 32-bit unsigned sizes; convert the 64-bit values explicitly.
  const auto matrix_k = static_cast<uint32_t>(in_features_);
  const auto matrix_n = static_cast<uint32_t>(unpacked_shape[1]);
  const auto bit_width = static_cast<uint32_t>(bits_);
  const auto group_size = static_cast<uint32_t>(group_size_);

  DequantWeightNbit(Stream(ctx),
                    packed_weight->Data<int32_t>(),
                    scales->Data<MLFloat16>(),
                    zero_points->Data<int32_t>(),
                    unpacked->MutableData<MLFloat16>(),
                    matrix_k, matrix_n, bit_width, group_size);
  return Status::OK();
}
|
||
} // namespace cuda | ||
} // namespace contrib | ||
} // namespace onnxruntime |
Oops, something went wrong.