Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add fused group norm minmax observer #10332

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3541,3 +3541,7 @@
- name: "fused_clip_grad"
signature: "Tensor (TensorTuple model_diff, Float max_norm, Float norm_type) => FusedClipGrad"
bind_python: True

- name: "fused_group_norm_min_max_observer"
signature: 'TensorTuple[y, y_scale, y_zero_point] (Tensor x, Tensor gamma=None, Tensor beta=None, Bool affine, Int32 num_groups, Double epsilon=1e-5, String data_format="channels_first", String activation="none", String quantization_scheme="affine", Int32 quantization_bit=8, String quantization_formula="oneflow") => FusedGroupNormMinMaxObserver'
bind_python: True
48 changes: 46 additions & 2 deletions oneflow/core/functional/impl/quantization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

#include "oneflow/core/common/optional.h"
#include "oneflow/core/functional/impl/binary_functor.h"

#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/mutable_attr_map.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h"
Expand Down Expand Up @@ -404,6 +403,50 @@ class FusedLinearWithGroupwiseQuantizedWeightFunctor {
std::shared_ptr<OpExpr> asymmetric_without_bias_op_;
};

class FusedGroupNormMinMaxObserverFunctor {
public:
FusedGroupNormMinMaxObserverFunctor() {
op_ = CHECK_JUST(one::OpBuilder("fused_group_norm_min_max_observer")
.Input("x")
.Output("y")
.Output("y_scale")
.Output("y_zero_point")
.Attr("affine", false)
.Build());
affine_op_ = CHECK_JUST(one::OpBuilder("fused_group_norm_min_max_observer")
.Input("x")
.Input("gamma")
.Input("beta")
.Output("y")
.Output("y_scale")
.Output("y_zero_point")
.Attr("affine", true)
.Build());
}

Maybe<TensorTuple> operator()(const std::shared_ptr<Tensor>& x, const Optional<Tensor>& gamma,
const Optional<Tensor>& beta, bool affine, int32_t num_groups,
double epsilon, const std::string& data_format,
const std::string& activation,
const std::string& quantization_scheme, int32_t quantization_bit,
const std::string& quantization_formula) const {
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP(
"affine", "num_groups", "epsilon", "data_format", "activation", "quantization_formula",
"quantization_bit", "quantization_scheme", "per_layer_quantization");
attrs.SetAllAttrs(affine, num_groups, epsilon, data_format, activation, quantization_formula,
quantization_bit, quantization_scheme, true);
if (affine) {
return OpInterpUtil::Dispatch<TensorTuple>(*affine_op_, {x, JUST(gamma), JUST(beta)}, attrs);
} else {
return OpInterpUtil::Dispatch<TensorTuple>(*op_, {x}, attrs);
}
}

private:
std::shared_ptr<OpExpr> op_;
std::shared_ptr<OpExpr> affine_op_;
};

} // namespace impl

ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<impl::FakeQuantizationFunctor>("FakeQuantization"); };
Expand All @@ -417,6 +460,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::FusedLinearWithGroupwiseQuantizedWeightFunctor>(
"FusedLinearWithGroupwiseQuantizedWeight");
m.add_functor<impl::DynamicQuantizationFunctor>("DynamicQuantization");
m.add_functor<impl::FusedGroupNormMinMaxObserverFunctor>("FusedGroupNormMinMaxObserver");
};

} // namespace functional
Expand Down
32 changes: 32 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -7091,6 +7091,38 @@ def OneFlow_FusedLayerNormMinMaxObserverOp : OneFlow_BaseOp<"fused_layer_norm_mi
let has_data_type_infer_fn = 1;
}

def OneFlow_FusedGroupNormMinMaxObserverOp : OneFlow_BaseOp<"fused_group_norm_min_max_observer", [NoMemoryEffect, NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$x,
Optional<OneFlow_Tensor>:$beta,
Optional<OneFlow_Tensor>:$gamma
);
let output = (outs
OneFlow_Tensor:$y,
OneFlow_Tensor:$y_scale,
OneFlow_Tensor:$y_zero_point
);
let attrs = (ins
DefaultValuedAttr<BoolAttr, "false">:$affine,
DefaultValuedAttr<SI32Attr, "0">:$num_groups,
DefaultValuedAttr<F64Attr, "0.">:$epsilon,
DefaultValuedAttr<StrAttr, "\"channels_first\"">:$data_format,
DefaultValuedAttr<StrAttr, "\"none\"">:$activation,
DefaultValuedAttr<StrAttr, "\"google\"">:$quantization_formula,
DefaultValuedAttr<SI32Attr, "8">:$quantization_bit,
DefaultValuedAttr<StrAttr, "\"symmetric\"">:$quantization_scheme,
DefaultValuedAttr<BoolAttr, "true">:$per_layer_quantization
);
let trait_attrs = (ins
DenseI32ArrayAttr:$operand_segment_sizes
);
let has_check_fn = 1;
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_SkipLayerNormOp : OneFlow_BaseOp<"skip_layer_norm", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$x,
Expand Down
76 changes: 76 additions & 0 deletions oneflow/ir/lib/OneFlow/PDLL/ForwardOpPatterns.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,79 @@ Pattern {
replace dynamic_quantization with (quantization.0, fused_layer_norm_min_max_observer.1, fused_layer_norm_min_max_observer.2);
};
}

Pattern {
let affine: Attr;
let num_groups: Attr;
let epsilon: Attr;
let data_format: Attr;
let activation: Attr;
let quantization_formula: Attr;
let quantization_bit: Attr;
let quantization_scheme: Attr;
let per_layer_quantization: Attr;

let group_norm = op<oneflow.group_norm>(x: Value, beta: Value, gamma: Value)
{affine = affine, num_groups = num_groups, epsilon = epsilon, data_format = data_format, activation = activation}
-> (y: Type, mean: Type, inv_variance: Type);
let dynamic_quantization = op<oneflow.dynamic_quantization>(group_norm.0)
{quantization_formula = quantization_formula, quantization_bit = quantization_bit, quantization_scheme = quantization_scheme,
per_layer_quantization = per_layer_quantization} -> (out: Type, in_scale: Type, in_zero_point: Type);

rewrite dynamic_quantization with {
let fused_group_norm_min_max_observer = op<oneflow.fused_group_norm_min_max_observer>(x, beta, gamma)
{affine = affine, num_groups = num_groups, epsilon = epsilon, data_format = data_format, activation = activation,
quantization_formula = quantization_formula, quantization_bit = quantization_bit, quantization_scheme = quantization_scheme,
per_layer_quantization = per_layer_quantization,
operand_segment_sizes = attr<"array<i32: 1, 1, 1>">} -> (y, in_scale, in_zero_point);

CopyUserOpAttrs(group_norm, fused_group_norm_min_max_observer);

let quantization = op<oneflow.quantization>(fused_group_norm_min_max_observer.0,
fused_group_norm_min_max_observer.1,
fused_group_norm_min_max_observer.2) {
quantization_formula = quantization_formula, quantization_bit = quantization_bit, quantization_scheme = quantization_scheme} -> (out);

CopyUserOpAttrs(dynamic_quantization, quantization);

replace dynamic_quantization with (quantization.0, fused_group_norm_min_max_observer.1, fused_group_norm_min_max_observer.2);
};
}

Pattern {
let affine: Attr;
let num_groups: Attr;
let epsilon: Attr;
let data_format: Attr;
let activation: Attr;
let quantization_formula: Attr;
let quantization_bit: Attr;
let quantization_scheme: Attr;
let per_layer_quantization: Attr;

let group_norm = op<oneflow.group_norm>(x: Value)
{affine = affine, num_groups = num_groups, epsilon = epsilon, data_format = data_format, activation = activation}
-> (y: Type, mean: Type, inv_variance: Type);
let dynamic_quantization = op<oneflow.dynamic_quantization>(group_norm.0)
{quantization_formula = quantization_formula, quantization_bit = quantization_bit, quantization_scheme = quantization_scheme,
per_layer_quantization = per_layer_quantization} -> (out: Type, in_scale: Type, in_zero_point: Type);

rewrite dynamic_quantization with {
let fused_group_norm_min_max_observer = op<oneflow.fused_group_norm_min_max_observer>(x)
{affine = affine, num_groups = num_groups, epsilon = epsilon, data_format = data_format, activation = activation,
quantization_formula = quantization_formula, quantization_bit = quantization_bit, quantization_scheme = quantization_scheme,
per_layer_quantization = per_layer_quantization,
operand_segment_sizes = attr<"array<i32: 1, 0, 0>">} -> (y, in_scale, in_zero_point);

CopyUserOpAttrs(group_norm, fused_group_norm_min_max_observer);

let quantization = op<oneflow.quantization>(fused_group_norm_min_max_observer.0,
fused_group_norm_min_max_observer.1,
fused_group_norm_min_max_observer.2) {
quantization_formula = quantization_formula, quantization_bit = quantization_bit, quantization_scheme = quantization_scheme} -> (out);

CopyUserOpAttrs(dynamic_quantization, quantization);

replace dynamic_quantization with (quantization.0, fused_group_norm_min_max_observer.1, fused_group_norm_min_max_observer.2);
};
}
95 changes: 95 additions & 0 deletions oneflow/ir/test/OneFlow/fuse/test_group_norm_quant_fuse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# RUN: python3 -m oneflow.test_utils.throttle --with-cuda=%with_cuda python3 %s | FileCheck %s
# CHECK-NOT: oneflow.cast

import os
import unittest
import numpy as np
import random


import oneflow as flow
import oneflow.unittest


def _cast_fuse_gn_dynamic_quant_pass(test_case, affine):
num_channels = 8
inp = flow.randn(4, num_channels, 32, 32).cuda()
# channels_last = bool(random.randint(0, 1))
# if channels_last:
# inp = flow.randn(4, 32, 32, num_channels).cuda()
# else:
# inp = flow.randn(4, num_channels, 32, 32).cuda()
gn = flow.nn.GroupNorm(2, num_channels, affine=affine).cuda()
kwargs = {
"quantization_formula": "oneflow",
"quantization_bit": 8,
"quantization_scheme": "affine",
}

dynamic_quantization = flow._oneflow_internal._C.dynamic_quantization

def fused_gn_dynamic_quant(inp, gamma, beta, affine, num_groups):
(
y,
y_scale,
y_zero_point,
) = flow._oneflow_internal._C.fused_group_norm_min_max_observer(
inp, gamma, beta, affine, num_groups, **kwargs
)
return (
((y / y_scale).round() + y_zero_point).to(flow.int8),
y_scale,
y_zero_point,
)

ref_result = fused_gn_dynamic_quant(
inp, gn.weight, gn.bias, gn.affine, gn.num_groups
)

class FusedGnDynamicQuantPass(flow.nn.Graph):
def __init__(self):
super().__init__()
self.gn = gn
self.dynamic_quant = dynamic_quantization

def build(self, x):
return self.dynamic_quant(self.gn(x), **kwargs)

lazy_b = FusedGnDynamicQuantPass()(inp)
test_case.assertTrue(np.allclose(ref_result[0].numpy(), lazy_b[0].numpy()))
test_case.assertTrue(np.allclose(ref_result[1].numpy(), lazy_b[1].numpy()))
test_case.assertTrue(np.allclose(ref_result[2].numpy(), lazy_b[2].numpy()))


@flow.unittest.skip_unless_1n1d()
class TestFusedGnDynamicQuantPass(flow.unittest.MLIRTestCase):
def setUp(self):
os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1"
os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1"
os.environ["ONEFLOW_MLIR_STDOUT"] = "1"
os.environ["ONEFLOW_MLIR_ENABLE_TIMING"] = "1"
os.environ["ONEFLOW_MLIR_PRINT_STATS"] = "1"
os.environ["ONEFLOW_MLIR_ENABLE_IR_PRINTING"] = "1"

def test_cast_fuse_gn_dynamic_quant_pass(test_case):
for affine in [False, True]:
_cast_fuse_gn_dynamic_quant_pass(test_case, affine)


if __name__ == "__main__":
unittest.main()
Loading