Commit 5c64169

Merge branch 'dev_int8_conv' of https://github.com/Oneflow-Inc/oneflow into dev_int8_conv

clackhan committed Sep 8, 2023
2 parents 54b3dd5 + f262013
Showing 9 changed files with 1,284 additions and 92 deletions.
796 changes: 796 additions & 0 deletions oneflow/core/cuda/layer_norm_min_max_observer.cuh

Large diffs are not rendered by default.
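
(Not expanded on this page. Judging from the kernel file added below, which includes this header and calls cuda::layer_norm::DispatchLayerNormMinMaxObserver, it implements a layer-norm CUDA kernel variant that also records the min/max of each normalized row.)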

32 changes: 32 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -7059,6 +7059,38 @@ def OneFlow_LayerNormOp : OneFlow_BaseOp<"layer_norm", [NoMemoryEffect, AttrSize
let has_data_type_infer_fn = 1;
}

def OneFlow_FusedLayerNormMinMaxObserverOp : OneFlow_BaseOp<"fused_layer_norm_min_max_observer", [NoMemoryEffect, NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$x,
Optional<OneFlow_Tensor>:$beta,
Optional<OneFlow_Tensor>:$gamma
);
let output = (outs
OneFlow_Tensor:$y,
OneFlow_Tensor:$y_scale,
OneFlow_Tensor:$y_zero_point
);
let attrs = (ins
DefaultValuedAttr<BoolAttr, "false">:$center,
DefaultValuedAttr<BoolAttr, "false">:$scale,
DefaultValuedAttr<SI64Attr, "0">:$begin_norm_axis,
DefaultValuedAttr<SI64Attr, "0">:$begin_params_axis,
DefaultValuedAttr<F64Attr, "0.">:$epsilon,
DefaultValuedAttr<StrAttr, "\"google\"">:$quantization_formula,
DefaultValuedAttr<SI32Attr, "8">:$quantization_bit,
DefaultValuedAttr<StrAttr, "\"symmetric\"">:$quantization_scheme,
DefaultValuedAttr<BoolAttr, "true">:$per_layer_quantization
);
let trait_attrs = (ins
DenseI32ArrayAttr:$operand_segment_sizes
);
let has_check_fn = 1;
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_SkipLayerNormOp : OneFlow_BaseOp<"skip_layer_norm", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$x,
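
The new op mirrors the normalization attributes of OneFlow_LayerNormOp (center, scale, begin_norm_axis, begin_params_axis, epsilon) and adds the attributes of the quantization ops; besides the normalized output y it emits y_scale and y_zero_point for consumption by a downstream quantization op. Note that although the declared defaults are quantization_formula = "google" and quantization_scheme = "symmetric", the CUDA kernels added in this commit accept only the "oneflow" formula with the "affine" scheme and 8-bit quantization.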
38 changes: 38 additions & 0 deletions oneflow/ir/lib/OneFlow/PDLL/ForwardOpPatterns.pdll
@@ -211,3 +211,41 @@ Pattern {

replace root with input;
}

Pattern {
let center: Attr;
let scale: Attr;
let begin_norm_axis: Attr;
let begin_params_axis: Attr;
let epsilon: Attr;
let quantization_formula: Attr;
let quantization_bit: Attr;
let quantization_scheme: Attr;
let per_layer_quantization: Attr;

let layer_norm = op<oneflow.layer_norm>(x: Value, beta: Value, gamma: Value)
{center = center, scale = scale, begin_norm_axis = begin_norm_axis, begin_params_axis = begin_params_axis, epsilon = epsilon}
-> (y: Type, mean: Type, inv_variance: Type);
let dynamic_quantization = op<oneflow.dynamic_quantization>(layer_norm.0)
{quantization_formula = quantization_formula, quantization_bit = quantization_bit, quantization_scheme = quantization_scheme,
per_layer_quantization = per_layer_quantization} -> (out: Type, in_scale: Type, in_zero_point: Type);

rewrite dynamic_quantization with {
let fused_layer_norm_min_max_observer = op<oneflow.fused_layer_norm_min_max_observer>(x, beta, gamma)
{center = center, scale = scale, begin_norm_axis = begin_norm_axis, begin_params_axis = begin_params_axis, epsilon = epsilon,
quantization_formula = quantization_formula, quantization_bit = quantization_bit, quantization_scheme = quantization_scheme,
per_layer_quantization = per_layer_quantization,
operand_segment_sizes = attr<"array<i32: 1, 1, 1>">} -> (y, in_scale, in_zero_point);

CopyUserOpAttrs(layer_norm, fused_layer_norm_min_max_observer);

let quantization = op<oneflow.quantization>(fused_layer_norm_min_max_observer.0,
fused_layer_norm_min_max_observer.1,
fused_layer_norm_min_max_observer.2) {
quantization_formula = quantization_formula, quantization_bit = quantization_bit, quantization_scheme = quantization_scheme} -> (out);

CopyUserOpAttrs(dynamic_quantization, quantization);

replace dynamic_quantization with (quantization.0, fused_layer_norm_min_max_observer.1, fused_layer_norm_min_max_observer.2);
};
}
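
In effect, this pattern folds the dynamic-quantization observer into the preceding layer norm. The dataflow before and after is roughly:

before:  x -> layer_norm -> y -> dynamic_quantization -> (out, scale, zero_point)
after:   x -> fused_layer_norm_min_max_observer -> (y, y_scale, y_zero_point) -> quantization -> out

The fused op's y_scale and y_zero_point feed a static oneflow.quantization op and also replace the dynamic_quantization op's scale and zero_point results, so downstream consumers are rewired without further changes.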
3 changes: 0 additions & 3 deletions oneflow/ir/lib/OneFlow/Passes.cpp
@@ -522,7 +522,6 @@ LogicalResult PruneReduntantQuantization(OpType op, PatternRewriter& rewriter) {
for (auto q : it.second) {
if (q != q0) {
q->replaceAllUsesWith(q0->getResults());
-  q->erase();
pruned = true;
}
}
@@ -532,7 +531,6 @@ LogicalResult PruneReduntantQuantization(OpType op, PatternRewriter& rewriter) {
for (auto q : it.second) {
if (q != q0) {
q->replaceAllUsesWith(q0->getResults());
-  q->erase();
pruned = true;
}
}
@@ -626,7 +624,6 @@ struct AutoNhwcPattern : public OpInterfaceRewritePattern<NCHWCompatible> {
}
num_transposed_result += 1;
}
-  op->erase();
}
return success();
}
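
All three deletions (marked with "-" above) drop a direct Operation::erase() call issued right after replaceAllUsesWith. The commit gives no rationale; a plausible reading, offered here as an assumption rather than fact, is that erasing an op directly from inside a rewrite pattern bypasses the PatternRewriter, so the greedy driver is never notified and may revisit a freed op, whereas an op left without uses can be cleaned up safely later. A minimal sketch of the rewriter-aware form, assuming standard MLIR APIs:

#include "mlir/IR/PatternMatch.h"

// Hypothetical helper, not part of this commit: route erasure through the
// rewriter so the pattern driver's worklist stays consistent, which a raw
// Operation::erase() would not guarantee.
static void PruneDuplicate(mlir::Operation* q, mlir::Operation* q0,
                           mlir::PatternRewriter& rewriter) {
  q->replaceAllUsesWith(q0->getResults());
  rewriter.eraseOp(q);
}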
96 changes: 96 additions & 0 deletions oneflow/user/kernels/dynamic_quantization_gpu_kernel.cu
@@ -0,0 +1,96 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/cuda/elementwise.cuh"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/kernel_util.cuh"
#include "oneflow/user/kernels/quantization_utils.cuh"

namespace oneflow {

template<typename T>
class GpuDynamicQuantizationKernel final : public user_op::OpKernel {
public:
GpuDynamicQuantizationKernel() = default;
~GpuDynamicQuantizationKernel() = default;

private:
using user_op::OpKernel::Compute;
void Compute(user_op::KernelComputeContext* ctx) const override {
const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);

user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
user_op::Tensor* scale = ctx->Tensor4ArgNameAndIndex("scale", 0);
user_op::Tensor* zero_point = ctx->Tensor4ArgNameAndIndex("zero_point", 0);
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);

const std::string quantization_scheme = ctx->Attr<std::string>("quantization_scheme");
const int32_t quantization_bit = ctx->Attr<int32_t>("quantization_bit");
const bool per_layer_quantization = ctx->Attr<bool>("per_layer_quantization");
const std::string quantization_formula = ctx->Attr<std::string>("quantization_formula");

CHECK(quantization_scheme == "affine");

const int64_t elements = in->shape_view().elem_cnt();

constexpr int pack_size = cuda::elementwise::PackSize<T>();
int64_t pack_num = (elements + pack_size - 1) / pack_size;
int grid_size = 0;
cuda::elementwise::GetNumBlocks(pack_num, &grid_size);
grid_size = grid_size > 2048 ? 2048 : grid_size;

size_t element_bytes = GetSizeOfDataType(GetDataType<T>::value);
CHECK_GE(tmp_buffer->shape_view().elem_cnt(), grid_size * element_bytes * 2);

T* min_max = reinterpret_cast<T*>(tmp_buffer->mut_dptr());
auto stream = ctx->stream()->As<ep::CudaStream>()->cuda_stream();
if (per_layer_quantization) {
quantization::ReduceMinMaxPerTensor<pack_size, T>
<<<grid_size, cuda::elementwise::kBlockSize,
cuda::elementwise::kBlockSize * element_bytes * 2, stream>>>(elements, in->dptr<T>(),
min_max);
} else {
UNIMPLEMENTED() << "dynamic_quantization does not support per-channel quantization";
}

if (quantization_formula == "oneflow") {
if (quantization_bit == 8) {
quantization::ApplyDynamicQuantization<T, int8_t>(
stream, grid_size, min_max, elements, in->dptr<T>(), quantization_bit,
out->mut_dptr<int8_t>(), scale->mut_dptr<float>(), zero_point->mut_dptr<int8_t>());
} else {
UNIMPLEMENTED();
}
} else {
UNIMPLEMENTED() << "dynamic_quantization only support oneflow quantization formula";
}
}

bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

#define REGISTER_DYNAMIC_QUANTIZATION_KERNEL(dtype) \
REGISTER_USER_KERNEL("dynamic_quantization") \
.SetCreateFn<GpuDynamicQuantizationKernel<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
&& (user_op::HobDataType("in", 0) == GetDataType<dtype>::value)) \
.SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { return 128 * 1024 * 1024; })

REGISTER_DYNAMIC_QUANTIZATION_KERNEL(float);
REGISTER_DYNAMIC_QUANTIZATION_KERNEL(double);
REGISTER_DYNAMIC_QUANTIZATION_KERNEL(half);

} // namespace oneflow
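
For reference, a host-side sketch of what this kernel computes per tensor, assuming the usual min/max-based 8-bit affine recipe (the device helpers ReduceMinMaxPerTensor and ApplyDynamicQuantization live in quantization_utils.cuh, which this page does not show; all names below are illustrative):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// CPU sketch of 8-bit dynamic affine quantization: observe min/max, derive
// scale and zero point, then round and clamp each element. Assumes !in.empty().
void DynamicQuantizeReference(const std::vector<float>& in,
                              std::vector<int8_t>* out, float* scale,
                              int8_t* zero_point) {
  const auto [mn, mx] = std::minmax_element(in.begin(), in.end());
  const float min_v = std::min(*mn, 0.0f);  // quantized range must cover zero
  const float max_v = std::max(*mx, 0.0f);
  constexpr int32_t qmin = -128, qmax = 127;
  float s = (max_v - min_v) / static_cast<float>(qmax - qmin);
  if (s == 0.0f) s = 1.0f;  // degenerate all-zero input
  const int32_t zp = static_cast<int32_t>(std::lround(qmin - min_v / s));
  *scale = s;
  *zero_point = static_cast<int8_t>(zp);
  out->resize(in.size());
  for (size_t i = 0; i < in.size(); ++i) {
    const int32_t q = static_cast<int32_t>(std::lround(in[i] / s)) + zp;
    (*out)[i] = static_cast<int8_t>(std::clamp(q, qmin, qmax));
  }
}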
182 changes: 182 additions & 0 deletions oneflow/user/kernels/fused_layer_norm_min_max_observer_gpu_kernel.cu
@@ -0,0 +1,182 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/cuda/elementwise.cuh"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/kernel_util.cuh"
#include "oneflow/core/cuda/layer_norm.cuh"
#include "oneflow/core/cuda/layer_norm_min_max_observer.cuh"
#include "oneflow/core/ndarray/binary_func.h"
#include "oneflow/core/kernel/util/numeric_limits.cuh"
#include "oneflow/user/kernels/quantization_utils.cuh"

namespace oneflow {

namespace {

template<typename SRC, typename DST, bool do_scale, bool do_center>
struct AffineStore {
AffineStore(DST* y, int64_t row_size, const DST* gamma, const DST* beta)
: y(y), row_size(row_size), gamma(gamma), beta(beta) {}
template<int N>
__device__ void store(const SRC* src, int64_t row, int64_t col) {
cuda::layer_norm::Pack<DST, N> y_pack;
cuda::layer_norm::Pack<DST, N> gamma_pack;
cuda::layer_norm::Pack<DST, N> beta_pack;
const int64_t offset = (row * row_size + col) / N;
const int64_t gamma_offset = col / N;
if (do_scale) {
gamma_pack.storage =
*(reinterpret_cast<const cuda::layer_norm::PackType<DST, N>*>(gamma) + gamma_offset);
} else {
#pragma unroll
for (int i = 0; i < N; ++i) { gamma_pack.elem[i] = static_cast<DST>(1.f); }
}
if (do_center) {
beta_pack.storage =
*(reinterpret_cast<const cuda::layer_norm::PackType<DST, N>*>(beta) + gamma_offset);
} else {
#pragma unroll
for (int i = 0; i < N; ++i) { beta_pack.elem[i] = static_cast<DST>(0.f); }
}
#pragma unroll
for (int i = 0; i < N; ++i) {
DST normalized_i = static_cast<DST>(src[i]);
if (do_scale || do_center) {
y_pack.elem[i] = normalized_i * gamma_pack.elem[i] + beta_pack.elem[i];
} else {
y_pack.elem[i] = normalized_i;
}
}
*(reinterpret_cast<cuda::layer_norm::PackType<DST, N>*>(y) + offset) = y_pack.storage;
}
DST* y;
int64_t row_size;
const DST* gamma;
const DST* beta;
};

template<typename T, bool do_scale, bool do_center>
void LayerNormMinMaxObserverGpu(ep::Stream* stream, const int64_t num_instances,
const int64_t norm_size, const double epsilon, const T* x_ptr,
const T* gamma_ptr, const T* beta_ptr, T* y_ptr, T* min_max_ptr) {
using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;
cuda::layer_norm::DirectLoad<T, T> load(x_ptr, norm_size);
AffineStore<ComputeType, T, do_scale, do_center> store(y_ptr, norm_size, gamma_ptr, beta_ptr);
cuda::layer_norm::DispatchLayerNormMinMaxObserver<decltype(load), decltype(store), T,
ComputeType>(
stream->As<ep::CudaStream>()->cuda_stream(), load, store, num_instances, norm_size, epsilon,
min_max_ptr);
}

template<typename T>
void DispatchFusedLayerNormMinMaxObserverGpu(ep::Stream* stream, const int64_t num_instances,
const int64_t norm_size, const double epsilon,
const T* x_ptr, const T* gamma_ptr, const T* beta_ptr,
T* y_ptr, T* min_max_ptr) {
if (gamma_ptr != nullptr && beta_ptr != nullptr) {
LayerNormMinMaxObserverGpu<T, true, true>(stream, num_instances, norm_size, epsilon, x_ptr,
gamma_ptr, beta_ptr, y_ptr, min_max_ptr);
} else if (gamma_ptr != nullptr && beta_ptr == nullptr) {
LayerNormMinMaxObserverGpu<T, true, false>(stream, num_instances, norm_size, epsilon, x_ptr,
gamma_ptr, beta_ptr, y_ptr, min_max_ptr);
} else if (gamma_ptr == nullptr && beta_ptr != nullptr) {
LayerNormMinMaxObserverGpu<T, false, true>(stream, num_instances, norm_size, epsilon, x_ptr,
gamma_ptr, beta_ptr, y_ptr, min_max_ptr);
} else {
LayerNormMinMaxObserverGpu<T, false, false>(stream, num_instances, norm_size, epsilon, x_ptr,
gamma_ptr, beta_ptr, y_ptr, min_max_ptr);
}
}

template<typename T>
class GpuFusedLayerNormMinMaxObserverKernel final : public user_op::OpKernel {
public:
GpuFusedLayerNormMinMaxObserverKernel() = default;
~GpuFusedLayerNormMinMaxObserverKernel() = default;

private:
using user_op::OpKernel::Compute;
void Compute(user_op::KernelComputeContext* ctx) const override {
const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0);
user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0);
const double epsilon = ctx->Attr<double>("epsilon");
CHECK_GE(epsilon, CUDNN_BN_MIN_EPSILON);

int64_t begin_norm_axis = ctx->Attr<int64_t>("begin_norm_axis");
if (begin_norm_axis < 0) { begin_norm_axis += x->shape_view().NumAxes(); }
const int64_t num_instances = x->shape_view().Count(0, begin_norm_axis);
const int64_t norm_size = x->shape_view().elem_cnt() / num_instances;
const T* gamma_ptr = nullptr;
const T* beta_ptr = nullptr;
if (ctx->has_input("gamma", 0)) {
const user_op::Tensor* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0);
gamma_ptr = gamma->dptr<T>();
CHECK_EQ(gamma->shape_view().elem_cnt(), norm_size);
}
if (ctx->has_input("beta", 0)) { beta_ptr = ctx->Tensor4ArgNameAndIndex("beta", 0)->dptr<T>(); }

size_t element_bytes = GetSizeOfDataType(GetDataType<T>::value);
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
CHECK_GE(tmp_buffer->shape_view().elem_cnt(), num_instances * 2 * element_bytes);
T* min_max = reinterpret_cast<T*>(tmp_buffer->mut_dptr());

DispatchFusedLayerNormMinMaxObserverGpu<T>(ctx->stream(), num_instances, norm_size, epsilon,
x->dptr<T>(), gamma_ptr, beta_ptr, y->mut_dptr<T>(),
min_max);

const std::string quantization_scheme = ctx->Attr<std::string>("quantization_scheme");
const int32_t quantization_bit = ctx->Attr<int32_t>("quantization_bit");
const std::string quantization_formula = ctx->Attr<std::string>("quantization_formula");
CHECK(quantization_scheme == "affine");

user_op::Tensor* y_scale = ctx->Tensor4ArgNameAndIndex("y_scale", 0);
user_op::Tensor* y_zero_point = ctx->Tensor4ArgNameAndIndex("y_zero_point", 0);

auto stream = ctx->stream()->As<ep::CudaStream>()->cuda_stream();
if (quantization_formula == "oneflow") {
if (quantization_bit == 8) {
int8_t upper_bound = (1 << (quantization_bit - 1)) - 1;
int8_t lower_bound = -upper_bound - 1;
quantization::ComputeScaleAndZeroPointBlock<T, int8_t>
<<<1, cuda::elementwise::kBlockSize, cuda::elementwise::kBlockSize * element_bytes * 2,
stream>>>(num_instances, min_max, upper_bound, lower_bound,
y_scale->mut_dptr<float>(), y_zero_point->mut_dptr<int8_t>());
} else {
UNIMPLEMENTED();
}
} else {
UNIMPLEMENTED() << "only support oneflow quantization formula";
}
}

bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

#define REGISTER_FUSED_LAYER_NORM_MIN_MAX_OBSERVER_KERNEL(dtype) \
REGISTER_USER_KERNEL("fused_layer_norm_min_max_observer") \
.SetCreateFn<GpuFusedLayerNormMinMaxObserverKernel<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
&& (user_op::HobDataType("x", 0) == GetDataType<dtype>::value)) \
.SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { return 128 * 1024 * 1024; })

REGISTER_FUSED_LAYER_NORM_MIN_MAX_OBSERVER_KERNEL(double);
REGISTER_FUSED_LAYER_NORM_MIN_MAX_OBSERVER_KERNEL(float);
REGISTER_FUSED_LAYER_NORM_MIN_MAX_OBSERVER_KERNEL(half);

} // namespace

} // namespace oneflow
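
The fused kernel stores one (min, max) pair per normalized row into tmp_buffer, then launches quantization::ComputeScaleAndZeroPointBlock with a single block and kBlockSize * element_bytes * 2 bytes of shared memory to collapse those pairs into one per-layer scale and zero point. That helper is defined in quantization_utils.cuh, which this page does not show; the sketch below is an assumption about its general shape (float only, with the buffer assumed to hold all row minima followed by all row maxima):

#include <cstdint>

// Hypothetical one-block min/max reduction plus affine parameter computation.
// Assumes n >= 1 and a power-of-two blockDim.x (e.g. kBlockSize = 256).
__global__ void ComputeScaleZeroPointSketch(int64_t n, const float* min_max,
                                            int32_t upper, int32_t lower,
                                            float* scale, int8_t* zero_point) {
  extern __shared__ float smem[];  // 2 * blockDim.x floats
  float* s_min = smem;
  float* s_max = smem + blockDim.x;
  float lo = min_max[0], hi = min_max[n];
  for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
    lo = fminf(lo, min_max[i]);
    hi = fmaxf(hi, min_max[n + i]);
  }
  s_min[threadIdx.x] = lo;
  s_max[threadIdx.x] = hi;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {  // shared-memory tree reduction
    if (threadIdx.x < s) {
      s_min[threadIdx.x] = fminf(s_min[threadIdx.x], s_min[threadIdx.x + s]);
      s_max[threadIdx.x] = fmaxf(s_max[threadIdx.x], s_max[threadIdx.x + s]);
    }
    __syncthreads();
  }
  if (threadIdx.x == 0) {
    const float min_v = fminf(s_min[0], 0.0f);  // range must cover zero
    const float max_v = fmaxf(s_max[0], 0.0f);
    float sc = (max_v - min_v) / static_cast<float>(upper - lower);
    if (sc == 0.0f) sc = 1.0f;  // degenerate all-zero input
    *scale = sc;
    *zero_point = static_cast<int8_t>(lrintf(lower - min_v / sc));
  }
}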