fuse_min_max_observer_and_matmul_quant
clackhan committed Sep 22, 2023
1 parent c1c859b commit 4ba3cc5
Showing 15 changed files with 488 additions and 62 deletions.
10 changes: 9 additions & 1 deletion oneflow/core/functional/functional_api.yaml
@@ -1064,6 +1064,11 @@
]
bind_python: True

- name: "redistribute"
signature:
'Tensor (Tensor in) => Redistribute'
bind_python: True

- name: "matmul_quant"
signature:
[
@@ -1089,7 +1094,10 @@
'TensorTuple (TensorTuple as, TensorTuple bs, TensorTuple in_zero_points, TensorTuple in_scales, TensorTuple weight_scales,
TensorTuple weight_accs, TensorTuple biases,
Bool transpose_a=False, Bool transpose_b=False,
Double alpha=1.0, DataType output_dtype=None) => GroupedMatmulQuant'
Double alpha=1.0, DataType output_dtype=None) => GroupedMatmulQuant',
'TensorTuple (TensorTuple as, TensorTuple bs, TensorTuple in_zero_points, TensorTuple in_scales, TensorTuple weight_scales,
TensorTuple weight_accs, TensorTuple biases, TensorTuple add_to_outputs,
Bool transpose_a=False, Bool transpose_b=False, Double alpha=1.0, DataType output_dtype=None) => GroupedMatmulQuant'
]
bind_python: True
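
The entries above are also exported to Python (bind_python: True). Below is a minimal call sketch of the new redistribute entry through the C++ side that OneFlow generates from the '=> Redistribute' name; the functional::Redistribute symbol, the header path, and the wrapper are assumptions used for illustration, not part of this commit.

#include <memory>

#include "oneflow/core/functional/functional.h"  // assumed generated header

namespace oneflow {

// Illustrative wrapper only. The RedistributeFunctor added in nn_functor.cpp
// (next file) expects a 2-D input whose inner dimension is a multiple of 16.
Maybe<one::Tensor> PrepackQuantWeight(const std::shared_ptr<one::Tensor>& weight) {
  return one::functional::Redistribute(weight);
}

}  // namespace oneflow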

73 changes: 72 additions & 1 deletion oneflow/core/functional/impl/nn_functor.cpp
@@ -576,6 +576,23 @@ class BatchMatMulFunctor {
std::shared_ptr<OpExpr> batch_matmul_op_;
};

class RedistributeFunctor {
public:
RedistributeFunctor() {
redistribute_op_ = CHECK_JUST(one::OpBuilder("redistribute").Input("in").Output("out").Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& in) const {
const auto& in_size = in->shape();
const int n = in_size->At(0);
const int k = in_size->At(1);
CHECK_EQ_OR_RETURN(k % 16, 0);
return OpInterpUtil::Dispatch<Tensor>(*redistribute_op_, {in});
}

private:
std::shared_ptr<OpExpr> redistribute_op_;
};

class MatMulQuantFunctor {
public:
MatMulQuantFunctor() {
@@ -808,6 +825,58 @@ class GroupedMatMulBiasQuantWithFilterScaleFunctor {
std::vector<std::shared_ptr<OpExpr>> grouped_matmul_bias_quant_with_filter_bias_op_;
};

class GroupedMatMulBiasQuantWithFilterScaleResidualFunctor {
public:
GroupedMatMulBiasQuantWithFilterScaleResidualFunctor() {
grouped_matmul_bias_quant_with_filter_bias_op_.resize(kMaxInputCount);
for (int n = 1; n < kMaxInputCount; ++n) {
grouped_matmul_bias_quant_with_filter_bias_op_[n] =
CHECK_JUST(one::OpBuilder("grouped_matmul_quant")
.Input("as", n)
.Input("bs", n)
.Input("in_zero_points", n)
.Input("in_scales", n)
.Input("weight_scales", n)
.Input("weight_accs", n)
.Input("biases", n)
.Input("_add_to_outputs", n)
.Output("outputs", n)
.Build());
}
}
Maybe<TensorTuple> operator()(const TensorTuple& as, const TensorTuple& bs,
const TensorTuple& in_zero_points, const TensorTuple& in_scales,
const TensorTuple& weight_scales, const TensorTuple& weight_accs,
const TensorTuple& biases, const TensorTuple& add_to_outputs,
const bool& transpose_a, const bool& transpose_b,
const double& alpha,
const Optional<Symbol<DType>>& output_dtype) const {
CHECK_OR_RETURN(!transpose_a)
<< "the first input should not be transposed for quantized matmul.";
CHECK_OR_RETURN(transpose_b) << "the second input should be transposed for quantized matmul.";
CHECK_EQ_OR_RETURN(alpha, 1) << "alpha should be 1 for quantized matmul.";
auto& attrs =
THREAD_CACHED_MUTABLE_ATTR_MAP("transpose_a", "transpose_b", "alpha", "out_dtype");
attrs.SetAllAttrs(transpose_a, transpose_b, alpha,
output_dtype.value_or(DType::Float())->data_type());
int input_size = as.size();
TensorTuple input(8 * input_size);
std::copy(as.begin(), as.end(), input.begin() + 0 * input_size);
std::copy(bs.begin(), bs.end(), input.begin() + 1 * input_size);
std::copy(in_zero_points.begin(), in_zero_points.end(), input.begin() + 2 * input_size);
std::copy(in_scales.begin(), in_scales.end(), input.begin() + 3 * input_size);
std::copy(weight_scales.begin(), weight_scales.end(), input.begin() + 4 * input_size);
std::copy(weight_accs.begin(), weight_accs.end(), input.begin() + 5 * input_size);
std::copy(biases.begin(), biases.end(), input.begin() + 6 * input_size);
std::copy(add_to_outputs.begin(), add_to_outputs.end(), input.begin() + 7 * input_size);
return OpInterpUtil::Dispatch<TensorTuple>(
*grouped_matmul_bias_quant_with_filter_bias_op_[input_size], input, attrs);
}

private:
std::vector<std::shared_ptr<OpExpr>> grouped_matmul_bias_quant_with_filter_bias_op_;
};

class VectorMatrixProductFunctor {
public:
VectorMatrixProductFunctor() {
@@ -5805,7 +5874,9 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::MatMulFunctor>("MatMul");
m.add_functor<impl::MatMulQuantFunctor, impl::MatMulQuantWithFilterScaleFunctor>("MatmulQuant");
m.add_functor<impl::GroupedMatMulQuantFunctor, impl::GroupedMatMulQuantWithFilterScaleFunctor,
impl::GroupedMatMulBiasQuantWithFilterScaleFunctor>("GroupedMatmulQuant");
impl::GroupedMatMulBiasQuantWithFilterScaleFunctor,
impl::GroupedMatMulBiasQuantWithFilterScaleResidualFunctor>("GroupedMatmulQuant");
m.add_functor<impl::RedistributeFunctor>("Redistribute");
m.add_functor<impl::MatMulNoBroadCastFunctor>("MatMulNoBroadCast");
m.add_functor<impl::BatchMatMulFunctor>("BatchMatMul");
m.add_functor<impl::MatrixVectorProductFunctor>("MatrixVectorProduct");
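
For orientation, here is a minimal call sketch of the residual overload registered above, going through the C++ entry point generated from functional_api.yaml. The functional::GroupedMatmulQuant name, the header paths, and the wrapper function are assumptions used for illustration; they are not part of this commit.

#include "oneflow/core/framework/tensor_tuple.h"  // assumed framework header
#include "oneflow/core/functional/functional.h"   // assumed generated header

namespace oneflow {

Maybe<one::TensorTuple> FusedResidualGroupedGemm(
    const one::TensorTuple& as, const one::TensorTuple& bs,
    const one::TensorTuple& in_zero_points, const one::TensorTuple& in_scales,
    const one::TensorTuple& weight_scales, const one::TensorTuple& weight_accs,
    const one::TensorTuple& biases, const one::TensorTuple& residuals) {
  // Passing the extra TensorTuple selects the "_add_to_outputs" overload; the
  // functor then lays the eight groups out back to back
  // (as | bs | in_zero_points | in_scales | weight_scales | weight_accs |
  //  biases | residuals) before dispatching a single grouped_matmul_quant op.
  return one::functional::GroupedMatmulQuant(
      as, bs, in_zero_points, in_scales, weight_scales, weight_accs, biases, residuals,
      /*transpose_a=*/false, /*transpose_b=*/true, /*alpha=*/1.0, DType::Float16());
}

}  // namespace oneflow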
15 changes: 14 additions & 1 deletion oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -5375,6 +5375,19 @@ def OneFlow_ErfcGradOp : OneFlow_BaseOp<"erfc_grad", [NoMemoryEffect, DeclareOpI
let has_data_type_infer_fn = 1;
}

def OneFlow_RedistributeOp : OneFlow_BaseOp<"redistribute", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$in
);
let output = (outs
OneFlow_Tensor:$out
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_MatmulQuantOp : OneFlow_BaseOp<"matmul_quant", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$a,
@@ -8310,7 +8323,7 @@ def OneFlow_FakeQuantizationOp : OneFlow_BaseOp<"fake_quantization", [NoMemoryEf
let has_input_arg_modify_fn = 1;
}

def OneFlow_MinMaxObserverOp : OneFlow_BaseOp<"min_max_observer", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
def OneFlow_MinMaxObserverOp : OneFlow_BaseOp<"min_max_observer", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>, DeclareOpInterfaceMethods<NCHWCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$in
);
31 changes: 31 additions & 0 deletions oneflow/ir/lib/OneFlow/PDLL/ForwardOpPatterns.pdll
@@ -249,3 +249,34 @@ Pattern {
replace dynamic_quantization with (quantization.0, fused_layer_norm_min_max_observer.1, fused_layer_norm_min_max_observer.2);
};
}

Pattern {
let center: Attr;
let scale: Attr;
let begin_norm_axis: Attr;
let begin_params_axis: Attr;
let epsilon: Attr;
let quantization_formula: Attr;
let quantization_bit: Attr;
let quantization_scheme: Attr;
let per_layer_quantization: Attr;

let layer_norm = op<oneflow.layer_norm>(x: Value, beta: Value, gamma: Value)
{center = center, scale = scale, begin_norm_axis = begin_norm_axis, begin_params_axis = begin_params_axis, epsilon = epsilon}
-> (y: Type, mean: Type, inv_variance: Type);
let min_max_observer = op<oneflow.min_max_observer>(layer_norm.0)
{quantization_formula = quantization_formula, quantization_bit = quantization_bit, quantization_scheme = quantization_scheme,
per_layer_quantization = per_layer_quantization} -> (in_scale: Type, in_zero_point: Type);

rewrite min_max_observer with {
let fused_layer_norm_min_max_observer = op<oneflow.fused_layer_norm_min_max_observer>(x, beta, gamma)
{center = center, scale = scale, begin_norm_axis = begin_norm_axis, begin_params_axis = begin_params_axis, epsilon = epsilon,
quantization_formula = quantization_formula, quantization_bit = quantization_bit, quantization_scheme = quantization_scheme,
per_layer_quantization = per_layer_quantization,
operand_segment_sizes = attr<"array<i32: 1, 1, 1>">} -> (y, in_scale, in_zero_point);

CopyUserOpAttrs(layer_norm, fused_layer_norm_min_max_observer);

replace min_max_observer with (fused_layer_norm_min_max_observer.1, fused_layer_norm_min_max_observer.2);
};
}
21 changes: 21 additions & 0 deletions oneflow/ir/lib/OneFlow/Transform/AutoNHWCOps.cpp
@@ -399,6 +399,27 @@ llvm::SmallVector<Value, 4> DynamicQuantizationOp::NchwToNhwc(llvm::SmallVector<
return results;
}

bool MinMaxObserverOp::IsNCHW() { return false; }

llvm::DenseSet<Value> MinMaxObserverOp::OperandsToTranspose() { return {this->getIn()}; }

llvm::DenseSet<Value> MinMaxObserverOp::ResultsToTranspose() { return {}; }

llvm::SmallVector<Value, 4> MinMaxObserverOp::NchwToNhwc(llvm::SmallVector<Value, 4> value,
PatternRewriter& rewriter) {
auto min_max_observer_op = *this;
SmallVector<Value, 4> operands{value[0]};
auto res = rewriter
.create<oneflow::MinMaxObserverOp>(min_max_observer_op.getLoc(),
getNHWCResultTypes(min_max_observer_op),
operands, min_max_observer_op->getAttrs())
->getResults();
llvm::SmallVector<Value, 4> results;
results.push_back(res[0]);
results.push_back(res[1]);
return results;
}

} // namespace oneflow

} // namespace mlir
27 changes: 16 additions & 11 deletions oneflow/user/kernels/conv_quant_kernels.cu
@@ -148,6 +148,9 @@ class Conv2dQuantKernel final : public user_op::OpKernel, public user_op::CudaGr
cutlass::library::NumericTypeID::kS8, cutlass::library::LayoutTypeID::kTensorNHWC,
cutlass::library::NumericTypeID::kS32, cutlass::library::LayoutTypeID::kTensorNHWC,
cutlass::library::NumericTypeID::kS32, cutlass::library::NumericTypeID::kS32);
if (in->data_type() == DataType::kFloat16) {
key.element_A = cutlass::library::NumericTypeID::kF16;
}
if (out->data_type() == DataType::kFloat) {
key.element_C = cutlass::library::NumericTypeID::kF32;
key.element_compute = cutlass::library::NumericTypeID::kF32;
@@ -173,17 +176,19 @@
}
};

REGISTER_USER_KERNEL("conv2d_quant")
.SetCreateFn<Conv2dQuantKernel>()
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)
&& (user_op::HobAttr<std::string>("data_format") == "channels_last")
&& (user_op::HobAttr<int32_t>("groups") == 1)
&& (user_op::HobDataType("in", 0) == DataType::kInt8))
.SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {
// use static workspace size
return 128 * 1024 * 1024;
})
.SetPriority(user_op::kKernelPriorityOptimized);
#define REGISTER_CONV_2D_QUANT_KERNEL(data_type) \
REGISTER_USER_KERNEL("conv2d_quant") \
.SetCreateFn<Conv2dQuantKernel>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
&& (user_op::HobAttr<std::string>("data_format") == "channels_last") \
&& (user_op::HobAttr<int32_t>("groups") == 1) \
&& (user_op::HobDataType("in", 0) == data_type) \
&& (user_op::HobDataType("weight", 0) == DataType::kInt8)) \
.SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { return 128 * 1024 * 1024; }) \
.SetPriority(user_op::kKernelPriorityOptimized);

REGISTER_CONV_2D_QUANT_KERNEL(DataType::kInt8)
REGISTER_CONV_2D_QUANT_KERNEL(DataType::kFloat16)

} // namespace

6 changes: 6 additions & 0 deletions oneflow/user/kernels/fused_glu_quant_kernel.cu
@@ -305,6 +305,9 @@ class GpuFusedGluQuantKernel final : public user_op::OpKernel, public user_op::C
cutlass::library::NumericTypeID::kS32, // element_D
cutlass::library::LayoutTypeID::kRowMajor // layout_D
);
if (input_x->data_type() == DataType::kFloat16) {
key.element_A = cutlass::library::NumericTypeID::kF16;
}
if (data_type == DataType::kFloat) {
key.element_scalar = cutlass::library::NumericTypeID::kF32;
key.element_C = cutlass::library::NumericTypeID::kF32;
@@ -354,6 +357,9 @@ class GpuFusedGluQuantKernel final : public user_op::OpKernel, public user_op::C
REGISTER_GPU_FUSED_GLU_QUANT_KERNEL(int8_t, float);
REGISTER_GPU_FUSED_GLU_QUANT_KERNEL(int8_t, half);
REGISTER_GPU_FUSED_GLU_QUANT_KERNEL(half, float);
REGISTER_GPU_FUSED_GLU_QUANT_KERNEL(half, half);
} // namespace oneflow
#endif // CUDA_VERSION >= 11020
29 changes: 18 additions & 11 deletions oneflow/user/kernels/grouped_matmul_quant_kernel.cu
@@ -213,6 +213,10 @@ class GroupedMatmulQuantKernel final : public user_op::OpKernel, public user_op:
cutlass::library::NumericTypeID::kS32, // element_D
cutlass::library::LayoutTypeID::kRowMajor // layout_D
);
const user_op::Tensor* a = ctx->Tensor4ArgNameAndIndex("as", 0);
if (a->data_type() == DataType::kFloat16) {
key.element_A = cutlass::library::NumericTypeID::kF16;
}

if (GetDataType<OutType>::value == DataType::kFloat) {
key.element_scalar = cutlass::library::NumericTypeID::kF32;
@@ -279,19 +283,22 @@ class GroupedMatmulQuantKernel final : public user_op::OpKernel, public user_op:
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

#define REGISTER_GROUPED_MATMUL_BIAS_KERNEL_GPU(out_cpp_type, out_data_type) \
REGISTER_USER_KERNEL("grouped_matmul_quant") \
.SetCreateFn<GroupedMatmulQuantKernel<out_cpp_type>>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
&& (user_op::HobDataType("as", 0) == DataType::kInt8) \
&& (user_op::HobDataType("bs", 0) == DataType::kInt8) \
&& (user_op::HobDataType("outputs", 0) == out_data_type)) \
.SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \
return kMaxProblemBatch * 10 * sizeof(void*) + 3 * 1024 * 1024; \
#define REGISTER_GROUPED_MATMUL_BIAS_KERNEL_GPU(a_data_type, out_cpp_type, out_data_type) \
REGISTER_USER_KERNEL("grouped_matmul_quant") \
.SetCreateFn<GroupedMatmulQuantKernel<out_cpp_type>>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
&& (user_op::HobDataType("as", 0) == a_data_type) \
&& (user_op::HobDataType("bs", 0) == DataType::kInt8) \
&& (user_op::HobDataType("outputs", 0) == out_data_type)) \
.SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \
return kMaxProblemBatch * 10 * sizeof(void*) + 3 * 1024 * 1024; \
});

REGISTER_GROUPED_MATMUL_BIAS_KERNEL_GPU(half, DataType::kFloat16)
REGISTER_GROUPED_MATMUL_BIAS_KERNEL_GPU(float, DataType::kFloat)
REGISTER_GROUPED_MATMUL_BIAS_KERNEL_GPU(DataType::kInt8, half, DataType::kFloat16)
REGISTER_GROUPED_MATMUL_BIAS_KERNEL_GPU(DataType::kInt8, float, DataType::kFloat)

REGISTER_GROUPED_MATMUL_BIAS_KERNEL_GPU(DataType::kFloat16, half, DataType::kFloat16)
REGISTER_GROUPED_MATMUL_BIAS_KERNEL_GPU(DataType::kFloat16, float, DataType::kFloat)

} // namespace

25 changes: 15 additions & 10 deletions oneflow/user/kernels/matmul_quant_kernels.cu
@@ -145,6 +145,10 @@ class MatmulQuantKernel final : public user_op::OpKernel {
cutlass::library::LayoutTypeID::kRowMajor // layout_D
);

if (a->data_type() == DataType::kFloat16) {
key.element_A = cutlass::library::NumericTypeID::kF16;
}

if (out->data_type() == DataType::kFloat) {
key.element_scalar = cutlass::library::NumericTypeID::kF32;
key.element_C = cutlass::library::NumericTypeID::kF32;
@@ -170,16 +174,17 @@
}
};

REGISTER_USER_KERNEL("matmul_quant")
.SetCreateFn<MatmulQuantKernel>()
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)
&& (user_op::HobDataType("a", 0) == DataType::kInt8)
&& (user_op::HobDataType("b", 0) == DataType::kInt8))
.SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {
// use static workspace size
return 128 * 1024 * 1024;
})
.SetPriority(user_op::kKernelPriorityOptimized);
#define REGISTER_MATMUL_QUANT_KERNEL(data_type) \
REGISTER_USER_KERNEL("matmul_quant") \
.SetCreateFn<MatmulQuantKernel>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
&& (user_op::HobDataType("a", 0) == data_type) \
&& (user_op::HobDataType("b", 0) == DataType::kInt8)) \
.SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { return 128 * 1024 * 1024; }) \
.SetPriority(user_op::kKernelPriorityOptimized);

REGISTER_MATMUL_QUANT_KERNEL(DataType::kInt8)
REGISTER_MATMUL_QUANT_KERNEL(DataType::kFloat16)

} // namespace oneflow
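
The quantized matmul kernels above consume in_zero_points, in_scales, weight_scales and weight_accs next to their int8 operands, which is the shape of the usual zero-point-corrected int8 GEMM. The self-contained reference sketch below shows that arithmetic for one output element; it is an assumption about the intended math, written for illustration, and is not taken from this commit.

#include <cstdint>
#include <vector>

// Assumed quantization scheme: a_q = round(a / in_scale) + in_zero_point for the
// activation and w_q = round(w / weight_scale) for the weight. With
// weight_acc = sum_k w_q[k] precomputed per output column, one output element is
//   out = in_scale * weight_scale * (sum_k a_q[k] * w_q[k]
//                                    - in_zero_point * weight_acc) + bias.
float QuantGemmElement(const std::vector<int8_t>& a_q, const std::vector<int8_t>& w_q,
                       int32_t in_zero_point, float in_scale, float weight_scale,
                       int32_t weight_acc, float bias) {
  int32_t acc = 0;
  for (size_t k = 0; k < a_q.size(); ++k) {
    acc += static_cast<int32_t>(a_q[k]) * static_cast<int32_t>(w_q[k]);
  }
  return in_scale * weight_scale * static_cast<float>(acc - in_zero_point * weight_acc) + bias;
}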
