add quantization conv and matmul #10327

Closed
wants to merge 84 commits
Changes from 59 commits
Commits
2cfde9d
change for sd torch graph compile
strint Aug 18, 2023
f04cc42
auto format by CI
oneflow-ci-bot Aug 18, 2023
c30278e
Update python/oneflow/nn/graph/graph.py
strint Aug 18, 2023
b0d72a3
auto format by CI
oneflow-ci-bot Aug 18, 2023
ed27149
restore md
strint Aug 18, 2023
df4a31d
update cutlass and support int8 conv
hjchen2 Aug 21, 2023
cc82f66
tuning int8 conv kernels
hjchen2 Aug 22, 2023
1e3c555
fix
hjchen2 Aug 22, 2023
19770b1
rm incompatible int8 conv implementation
hjchen2 Aug 22, 2023
7bfd55a
Merge branch 'master' into support_sd_torch_compile
hjchen2 Aug 22, 2023
c8110e8
half reduce_min and reduce_max
hjchen2 Aug 22, 2023
dc5db78
Merge remote-tracking branch 'origin/support_sd_torch_compile' into d…
hjchen2 Aug 22, 2023
30435f9
fake dynamo
strint Aug 23, 2023
6cfa893
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
strint Aug 23, 2023
69bbc0e
auto format by CI
oneflow-ci-bot Aug 23, 2023
6158db3
warmup int8 conv algo
hjchen2 Aug 23, 2023
998c501
general argstree
strint Aug 23, 2023
eccb5f1
Merge branch 'fix_dynamo_mock_error' of https://github.com/Oneflow-In…
strint Aug 23, 2023
44a3710
auto format by CI
oneflow-ci-bot Aug 23, 2023
3619406
refine repr of argstree
strint Aug 24, 2023
93d0fe3
Merge branch 'fix_dynamo_mock_error' of https://github.com/Oneflow-In…
strint Aug 24, 2023
ccac35e
auto format by CI
oneflow-ci-bot Aug 24, 2023
025d4c5
graph allow any input
strint Aug 24, 2023
679e0d5
Merge branch 'fix_dynamo_mock_error' of https://github.com/Oneflow-In…
strint Aug 24, 2023
354cf21
save support any input
strint Aug 24, 2023
99b3d60
fix
jackalcooper Aug 25, 2023
8347ed8
fix gsize>64
jackalcooper Aug 25, 2023
6705366
Merge branch 'add-kMaxInputCount-in-GroupedMatmulFunctor' of https://…
strint Aug 25, 2023
dde053c
compile with cutlass extension
hjchen2 Aug 26, 2023
cb84faf
fix
hjchen2 Aug 26, 2023
5558ed3
add conv2d quant op
hjchen2 Aug 27, 2023
a947f44
add conv2d quant kernel
hjchen2 Aug 27, 2023
7a38a9f
refactor quantization kernel and fix tuning warmup pass
hjchen2 Aug 28, 2023
e48f738
fix
hjchen2 Aug 28, 2023
63e92d3
Merge branch 'sdxl_dev' of https://github.com/Oneflow-Inc/oneflow int…
hjchen2 Aug 29, 2023
5500840
fix compilation error for cuda12.2
hjchen2 Aug 30, 2023
cc7c49c
eliminate scalar math op
hjchen2 Aug 30, 2023
f1a4397
refactor
hjchen2 Aug 30, 2023
a889fc5
fuse add to output completely
hjchen2 Aug 31, 2023
a80b11d
add_quant_matmul
clackhan Aug 31, 2023
cd745cc
fix unused-but-set-variable error
hjchen2 Sep 1, 2023
4e499e1
conv2d_quant support add_to_output
hjchen2 Sep 1, 2023
c792366
Merge branch 'dev_int8_conv' of https://github.com/Oneflow-Inc/oneflo…
hjchen2 Sep 1, 2023
63b449c
fix and reformat
hjchen2 Sep 1, 2023
eef55aa
update cutlass extension branch
hjchen2 Sep 1, 2023
d62f6e4
matmul_quant supports add_to_output
hjchen2 Sep 2, 2023
0fff5e0
Merge branch 'master' into dev_int8_conv
hjchen2 Sep 2, 2023
418598e
auto format by CI
oneflow-ci-bot Sep 2, 2023
9031d5c
support to find the fastest kernel for matmul_quant
hjchen2 Sep 3, 2023
f31f2b1
Merge branch 'dev_int8_conv' of https://github.com/Oneflow-Inc/oneflo…
hjchen2 Sep 3, 2023
a857441
fuse gelu quant
hjchen2 Sep 4, 2023
9ac9a1c
add_prune_redundant_quantization_op_pass
clackhan Sep 4, 2023
2d45cc6
auto format by CI
oneflow-ci-bot Sep 4, 2023
9af9b8d
optimize activation dynamic quantization
hjchen2 Sep 4, 2023
60b8611
Merge branch 'dev_int8_conv' of https://github.com/Oneflow-Inc/oneflo…
hjchen2 Sep 4, 2023
4a5a5c6
optimize
hjchen2 Sep 5, 2023
42c1ee9
impl mlir pass
clackhan Sep 5, 2023
af2f7d4
Merge branch 'dev_int8_conv' of https://github.com/Oneflow-Inc/oneflo…
clackhan Sep 5, 2023
7a5ba06
auto format by CI
oneflow-ci-bot Sep 5, 2023
d294ef2
refine
clackhan Sep 5, 2023
c384f31
Merge branch 'dev_int8_conv' of https://github.com/Oneflow-Inc/oneflo…
clackhan Sep 5, 2023
d414feb
update
hjchen2 Sep 5, 2023
1388490
Merge branch 'dev_int8_conv' of https://github.com/Oneflow-Inc/oneflo…
hjchen2 Sep 5, 2023
13c0669
fmt
jackalcooper Sep 6, 2023
14c4bba
prune_reduntant_quant_from_input_op
clackhan Sep 6, 2023
28ec052
Merge branch 'dev_int8_conv' of https://github.com/Oneflow-Inc/oneflo…
clackhan Sep 6, 2023
d30b33a
refine
clackhan Sep 6, 2023
78b51cd
fuse dynamic quant conv
hjchen2 Sep 6, 2023
0059708
Merge branch 'dev_int8_conv' of https://github.com/Oneflow-Inc/oneflo…
hjchen2 Sep 6, 2023
a102cc0
matmul quant with fiter scale
hjchen2 Sep 6, 2023
ef397d6
refine
hjchen2 Sep 7, 2023
f262013
fuse layer norm and dynamic quant
hjchen2 Sep 7, 2023
54b3dd5
add_grouped_matmul_quant
clackhan Sep 8, 2023
5c64169
Merge branch 'dev_int8_conv' of https://github.com/Oneflow-Inc/oneflo…
clackhan Sep 8, 2023
66f50e5
fix
hjchen2 Sep 8, 2023
bfd5319
optimize
hjchen2 Sep 8, 2023
89acc4c
Merge branch 'dev_int8_conv' of https://github.com/Oneflow-Inc/oneflo…
hjchen2 Sep 10, 2023
a153075
update xformers fmha
hjchen2 Sep 13, 2023
af2da94
Revert "update xformers fmha"
hjchen2 Sep 13, 2023
26dcc1d
optimize quant
hjchen2 Sep 13, 2023
397304f
add_cutlass_gemm_array_tuner
clackhan Sep 15, 2023
c1c859b
fix
hjchen2 Sep 15, 2023
4ba3cc5
fuse_min_max_observer_and_matmul_quant
clackhan Sep 22, 2023
7e154e1
refine
clackhan Sep 22, 2023
4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -212,9 +212,9 @@ if(BUILD_PYTHON)
endif(BUILD_PYTHON)

set(CUTLASS_URL
https://github.com/Oneflow-Inc/cutlass/archive/e6f548d80bfdf1167d66adbbbcfc2ee3394f4777.zip)
https://github.com/Oneflow-Inc/cutlass/archive/d47b8883b5e3661b41cc8a7a6f4c240c5524647f.zip)
use_mirror(VARIABLE CUTLASS_URL URL ${CUTLASS_URL})
set(CUTLASS_MD5 425f8cf064ff47c81124e55490135f5c)
set(CUTLASS_MD5 7b417720240a443276ce4bb9ef169db1)

include(cuda)
add_subdirectory(external)
6 changes: 6 additions & 0 deletions cmake/third_party.cmake
@@ -144,6 +144,7 @@ if(BUILD_CUDA)
endif()
include(nccl)
include(cutlass)
include(cutlass-extension)
include(trt_flash_attention)

list(APPEND oneflow_third_party_libs ${NCCL_LIBRARIES})
@@ -161,6 +162,11 @@ if(BUILD_CUDA)
list(APPEND oneflow_third_party_libs ${CUTLASS_LIBRARIES})
list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${CUTLASS_INCLUDE_DIR})
endif()
if(WITH_CUTLASS_EXTENSION)
list(APPEND oneflow_third_party_dependencies cutlass-extension)
list(APPEND oneflow_third_party_libs ${CUTLASS_EXTENSION_LIBRARIES})
list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${CUTLASS_EXTENSION_INCLUDE_DIR})
endif()
list(APPEND oneflow_third_party_dependencies trt_flash_attention)
list(APPEND oneflow_third_party_libs ${TRT_FLASH_ATTENTION_LIBRARIES})
list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${TRT_FLASH_ATTENTION_INCLUDE_DIR})
60 changes: 60 additions & 0 deletions cmake/third_party/cutlass-extension.cmake
@@ -0,0 +1,60 @@
include(ExternalProject)

set(WITH_CUTLASS_EXTENSION OFF CACHE BOOL "")

if(WITH_CUTLASS_EXTENSION)

add_definitions(-DWITH_CUTLASS_EXTENSION)

find_package(Threads)

set(CUTLASS_EXTENSION_PROJECT cutlass-extension)

set(CUTLASS_EXTENSION_INSTALL_DIR ${THIRD_PARTY_DIR}/cutlass-extension)

set(CUTLASS_EXTENSION_INCLUDE_DIR ${CUTLASS_EXTENSION_INSTALL_DIR}/include CACHE PATH "" FORCE)
set(CUTLASS_EXTENSION_LIBRARY_DIR ${CUTLASS_EXTENSION_INSTALL_DIR}/lib CACHE PATH "" FORCE)
set(CUTLASS_EXTENSION_LIBRARIES ${CUTLASS_EXTENSION_LIBRARY_DIR}/libcutlass_extension.so)
set(CUTLASS_EXTENSION_SOURCE_DIR
${CMAKE_CURRENT_BINARY_DIR}/cutlass-extension/src/cutlass-extension/)
set(CUTLASS_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cutlass/src/cutlass)

foreach(arch ${CUDA_REAL_ARCHS_LIST})
if(arch GREATER_EQUAL 70)
list(APPEND CUTLASS_REAL_ARCHS ${arch})
endif()
endforeach()

if(THIRD_PARTY)
ExternalProject_Add(
${CUTLASS_EXTENSION_PROJECT}
PREFIX cutlass-extension
GIT_REPOSITORY https://github.com/Oneflow-Inc/oneflow-cutlass-extension.git
GIT_TAG master
UPDATE_COMMAND ""
BUILD_BYPRODUCTS ${CUTLASS_EXTENSION_LIBRARIES}
CMAKE_ARGS -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}
-DCMAKE_CXX_FLAGS:STRING=${CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG}
-DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE}
CMAKE_CACHE_ARGS
-DCMAKE_CUDA_COMPILER:STRING=${CUDAToolkit_NVCC_EXECUTABLE}
-DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER}
-DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER}
-DCMAKE_INSTALL_PREFIX:PATH=${CUTLASS_EXTENSION_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR:PATH=${CUTLASS_EXTENSION_LIBRARY_DIR}
-DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE}
-DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}
-DCUTLASS_ENABLE_EXAMPLES:BOOL=OFF
-DCUTLASS_ENABLE_PROFILER:BOOL=OFF
-DCUTLASS_ENABLE_LIBRARY:BOOL=ON
-DCUTLASS_NVCC_ARCHS:STRING=${CUTLASS_REAL_ARCHS}
-DCUTLASS_ENABLE_TESTS:BOOL=OFF
-DCUTLASS_UNITY_BUILD_ENABLED:BOOL=ON
-DCUTLASS_LIBRARY_DEBUG_POSTFIX:STRING=
-DCUTLASS_NVCC_EMBED_PTX:BOOL=OFF
-DCUTLASS_DIR:STRING=${CUTLASS_SOURCE_DIR}
DEPENDS cutlass)

endif(THIRD_PARTY)
endif(WITH_CUTLASS_EXTENSION)
1 change: 1 addition & 0 deletions cmake/third_party/cutlass.cmake
@@ -67,6 +67,7 @@ if(WITH_CUTLASS)
"45_dual_gemm/test_run.h"
"45_dual_gemm/kernel/dual_gemm.h"
"45_dual_gemm/device/dual_gemm.h"
"45_dual_gemm/dual_gemm_common.h"
"45_dual_gemm/dual_gemm_run.h"
"45_dual_gemm/thread/left_silu_and_mul.h"
"45_dual_gemm/threadblock/dual_mma_multistage.h"
20 changes: 20 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -1051,6 +1051,20 @@
String channel_pos="channels_first") => Conv2d'
bind_python: True

- name: "conv2d_quant"
signature:
'Tensor (Tensor input, Tensor weight, Tensor input_zero_point, Tensor scale=None, Tensor bias=None, Int32List[2] stride=1,
Int32List[2] padding=0, Int32List[2] dilation=1, Int32 groups=1,
String channel_pos="channels_first", DataType output_dtype=None) => Conv2dQuant'
bind_python: True

- name: "matmul_quant"
signature:
'Tensor (Tensor a, Tensor b, Tensor scale=None, Tensor bias=None,
Bool transpose_a=False, Bool transpose_b=False,
Double alpha=1.0, DataType output_dtype=None) => MatmulQuant'
bind_python: True
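The `matmul_quant` entry above binds a quantized matmul to Python. As a hedged sketch of the arithmetic such an op performs (assuming standard per-tensor symmetric int8 quantization — the PR's real kernels run on CUTLASS, and `matmul_quant_ref`/`quantize_sym` below are illustrative names, not part of this codebase), with `b` pre-transposed as the functor's `transpose_b=True` check requires:

```python
# Illustrative reference for int8 quantized matmul (assumption: symmetric
# per-tensor quantization; not the actual CUTLASS kernel implementation).

def quantize_sym(mat):
    """Per-tensor symmetric int8 quantization: returns (q, scale)."""
    amax = max(abs(v) for row in mat for v in row) or 1.0
    scale = amax / 127.0
    q = [[int(round(v / scale)) for v in row] for row in mat]
    return q, scale

def matmul_quant_ref(a, b_t, bias=None):
    """Quantized a @ b_t.T with b already transposed (transpose_b=True)."""
    qa, sa = quantize_sym(a)
    qb, sb = quantize_sym(b_t)
    out = []
    for ra in qa:
        row = []
        for j, rb in enumerate(qb):
            acc = sum(x * y for x, y in zip(ra, rb))  # int32 accumulate
            y = acc * sa * sb                         # dequantize
            if bias is not None:
                y += bias[j]
            row.append(y)
        out.append(row)
    return out
```

The int32 accumulation followed by a single scale multiply is why the op also takes an optional `scale` tensor: dequantization can be fused into the epilogue instead of run as a separate elementwise pass.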

- name: "conv3d"
signature:
'Tensor (Tensor input, Tensor weight, Tensor bias=None, Int32List[3] stride=1,
@@ -1092,6 +1106,12 @@
signature: 'Tensor (Tensor x, Tensor w, Tensor w_scale, *, Tensor w_zero=None, Tensor b=None, Int32 num_bits=8, Bool symmetric=True, Int64 group_dim=-1, Int64 group_size=-1) => FusedLinearWithGroupwiseQuantizedWeight'
bind_python: True

- name: "fused_activation_min_max_observer"
signature:
"TensorTuple (Tensor in, Tensor weight_scale, Tensor weight_acc, Tensor bias=None, String quantization_formula, Int32 quantization_bit,
String quantization_scheme, Bool per_layer_quantization=True) => FusedActivationMinMaxObserver"
bind_python: True
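The observer entry above outputs an activation scale and zero point. A minimal sketch of how affine (asymmetric) quantization parameters are conventionally derived from an observed min/max — an assumption about the scheme, not this op's verified formula, and `affine_qparams` is a hypothetical name:

```python
# Affine quantization parameters from an activation's min/max, as a
# min-max observer typically computes them (assumed unsigned 8-bit scheme).

def affine_qparams(x_min, x_max, quantization_bit=8):
    qmax = (1 << quantization_bit) - 1               # 255 for 8-bit
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)  # range must include 0
    scale = (x_max - x_min) / qmax or 1.0            # avoid zero scale
    zero_point = int(round(-x_min / scale))
    return scale, zero_point
```

Forcing the range to include zero guarantees that a real 0.0 maps exactly onto an integer code, which keeps zero-padding exact under quantization.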

- name: "conv_data_grad"
signature:
'Tensor (Tensor dy, Tensor weight, Tensor x, Int32 num_spatial_dims,
107 changes: 107 additions & 0 deletions oneflow/core/functional/impl/nn_functor.cpp
@@ -163,6 +163,70 @@ class Conv3dFunctor : public ConvBaseFunctor {
}
};

class ConvQuantBaseFunctor {
public:
explicit ConvQuantBaseFunctor(const int& num_spatial_dims)
: num_spatial_dims_(num_spatial_dims) {}
virtual ~ConvQuantBaseFunctor() = default;

Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
const std::shared_ptr<one::Tensor>& weight,
const std::shared_ptr<one::Tensor>& input_zero_point,
const Optional<one::Tensor>& scale, const Optional<one::Tensor>& bias,
const std::vector<int32_t>& stride, const std::vector<int32_t>& padding,
const std::vector<int32_t>& dilation, const int32_t& groups,
const std::string& channel_pos,
const Optional<Symbol<DType>>& output_dtype) const {
if (scale || bias) {
CHECK_OR_RETURN(scale && bias) << "scale and bias must both be given or not.";
}
std::vector<int32_t> kernel_size_vec(num_spatial_dims_);
int32_t kernel_idx_offset = 2;
if (channel_pos == "channels_last") { kernel_idx_offset = 1; }

for (int i = 0; i < num_spatial_dims_; i++) {
kernel_size_vec.at(i) = ((weight->shape())->At(i + kernel_idx_offset));
}
auto& conv_attrs =
THREAD_CACHED_MUTABLE_ATTR_MAP("filters", "kernel_size", "padding_before", "strides",
"dilation_rate", "groups", "data_format", "out_dtype");
conv_attrs.SetAllAttrs(static_cast<int32_t>(weight->shape()->At(0)), kernel_size_vec, padding,
stride, dilation, groups, channel_pos,
output_dtype.value_or(DType::Float())->data_type());
if (scale) {
return OpInterpUtil::Dispatch<Tensor>(
*conv_scale_bias_op_, {input, weight, input_zero_point, JUST(scale), JUST(bias)},
conv_attrs);
}
return OpInterpUtil::Dispatch<Tensor>(*conv_op_, {input, weight, input_zero_point}, conv_attrs);
}

protected:
std::shared_ptr<OpExpr> conv_op_;
std::shared_ptr<OpExpr> conv_scale_bias_op_;
int32_t num_spatial_dims_;
};

class Conv2dQuantFunctor : public ConvQuantBaseFunctor {
public:
Conv2dQuantFunctor() : ConvQuantBaseFunctor(/*num_spatial_dims_=*/2) {
conv_op_ = CHECK_JUST(one::OpBuilder("conv2d_quant")
.Input("in")
.Input("weight")
.Input("in_zero_point")
.Output("out")
.Build());
conv_scale_bias_op_ = CHECK_JUST(one::OpBuilder("conv2d_quant")
.Input("in")
.Input("weight")
.Input("in_zero_point")
.Input("scale")
.Input("bias")
.Output("out")
.Build());
}
};
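`ConvQuantBaseFunctor` above reads the kernel size out of the weight shape with an index offset that depends on the layout string. The same lookup in a standalone sketch (assuming OIHW for `channels_first` and OHWI for `channels_last`; the function name is illustrative):

```python
# Spatial kernel dims start at index 2 in channels_first (OIHW) weights
# and at index 1 in channels_last (OHWI) weights.

def kernel_size_from_weight(weight_shape, num_spatial_dims, channel_pos):
    offset = 1 if channel_pos == "channels_last" else 2
    return [weight_shape[i + offset] for i in range(num_spatial_dims)]
```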

class DeConvBaseFunctor {
public:
explicit DeConvBaseFunctor(const int& num_spatial_dims) : num_spatial_dims_(num_spatial_dims) {
@@ -314,6 +378,47 @@ class MatMulNoBroadCastFunctor {
}
};

class MatMulQuantFunctor {
public:
MatMulQuantFunctor() {
matmul_op_ =
CHECK_JUST(one::OpBuilder("matmul_quant").Input("a").Input("b").Output("out").Build());
matmul_scale_bias_op_ = CHECK_JUST(one::OpBuilder("matmul_quant")
.Input("a")
.Input("b")
.Input("scale")
.Input("bias")
.Output("out")
.Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& a,
const std::shared_ptr<one::Tensor>& b,
const Optional<one::Tensor>& scale, const Optional<one::Tensor>& bias,
const bool& transpose_a, const bool& transpose_b, const double& alpha,
const Optional<Symbol<DType>>& output_dtype) const {
CHECK_OR_RETURN(!transpose_a)
<< "the first input should not be transposed for quantized matmul.";
CHECK_OR_RETURN(transpose_b) << "the second input should be transposed for quantized matmul.";
CHECK_EQ_OR_RETURN(alpha, 1) << "alpha should be 1 for quantized matmul.";
if (scale || bias) {
CHECK_OR_RETURN(scale && bias) << "scale and bias must both be given or not.";
}
auto& attrs =
THREAD_CACHED_MUTABLE_ATTR_MAP("transpose_a", "transpose_b", "alpha", "out_dtype");
attrs.SetAllAttrs(transpose_a, transpose_b, alpha,
output_dtype.value_or(DType::Float())->data_type());
if (scale) {
return OpInterpUtil::Dispatch<Tensor>(*matmul_scale_bias_op_, {a, b, JUST(scale), JUST(bias)},
attrs);
}
return OpInterpUtil::Dispatch<Tensor>(*matmul_op_, {a, b}, attrs);
}

private:
std::shared_ptr<OpExpr> matmul_op_;
std::shared_ptr<OpExpr> matmul_scale_bias_op_;
};

class MatMulFunctor {
public:
MatMulFunctor() {
@@ -5430,9 +5535,11 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::DeConv1dFunctor>("Deconv1d");
m.add_functor<impl::DeConv2dFunctor>("Deconv2d");
m.add_functor<impl::DeConv3dFunctor>("Deconv3d");
m.add_functor<impl::Conv2dQuantFunctor>("Conv2dQuant");
m.add_functor<impl::EmbeddingReNormFunctor>("EmbeddingReNorm");
m.add_functor<impl::EmbeddingFunctor>("Embedding");
m.add_functor<impl::MatMulFunctor>("MatMul");
m.add_functor<impl::MatMulQuantFunctor>("MatmulQuant");
m.add_functor<impl::MatMulNoBroadCastFunctor>("MatMulNoBroadCast");
m.add_functor<impl::BatchMatMulFunctor>("BatchMatMul");
m.add_functor<impl::MatrixVectorProductFunctor>("MatrixVectorProduct");
46 changes: 46 additions & 0 deletions oneflow/core/functional/impl/quantization.cpp
@@ -90,6 +90,51 @@ class MovingAverageMinMaxObserverFunctor {
std::shared_ptr<OpExpr> op_;
};

class FusedActivationMinMaxObserverFunctor {
public:
FusedActivationMinMaxObserverFunctor() {
op_ = CHECK_JUST(one::OpBuilder("fused_activation_min_max_observer")
.Input("in")
.Input("weight_scale")
.Input("weight_acc")
.Output("in_scale")
.Output("in_zero_point")
.Output("out_scale")
.Output("out_bias")
.Build());
op_with_bias_ = CHECK_JUST(one::OpBuilder("fused_activation_min_max_observer")
.Input("in")
.Input("weight_scale")
.Input("weight_acc")
.Input("bias")
.Output("in_scale")
.Output("in_zero_point")
.Output("out_scale")
.Output("out_bias")
.Build());
}
Maybe<TensorTuple> operator()(
const std::shared_ptr<one::Tensor>& in, const std::shared_ptr<one::Tensor>& weight_scale,
const std::shared_ptr<one::Tensor>& weight_acc, const Optional<one::Tensor>& bias,
const std::string& quantization_formula, const int32_t& quantization_bit,
const std::string& quantization_scheme, const bool& per_layer_quantization) const {
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("quantization_formula", "quantization_bit",
"quantization_scheme", "per_layer_quantization");
attrs.SetAllAttrs(quantization_formula, quantization_bit, quantization_scheme,
per_layer_quantization);
if (bias) {
return OpInterpUtil::Dispatch<TensorTuple>(*op_with_bias_,
{in, weight_scale, weight_acc, JUST(bias)}, attrs);
} else {
return OpInterpUtil::Dispatch<TensorTuple>(*op_, {in, weight_scale, weight_acc}, attrs);
}
}

private:
std::shared_ptr<OpExpr> op_;
std::shared_ptr<OpExpr> op_with_bias_;
};
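A plausible reading of why the observer above takes a precomputed `weight_acc`: with an affine input `q_a = round(a/s_a) + z_a` and symmetric weights, the exact product expands as `a·w = s_a·s_w·(q_a·q_w − z_a·Σ q_w)`, so the zero-point cross term only needs the per-channel sum of the quantized weights and can be folded into a bias. This is the standard identity, hedged as an interpretation rather than the op's confirmed formula; the function names are hypothetical:

```python
# Zero-point correction for an affine-input, symmetric-weight dot product:
#   a * w = s_a * s_w * ((q_a - z_a) . q_w)
#         = s_a * s_w * (q_a . q_w - z_a * sum(q_w))
# sum(q_w) is what a per-output-channel weight_acc can precompute.

def corrected_dot(q_a, q_w, s_a, s_w, z_a):
    weight_acc = sum(q_w)                       # precomputable per channel
    raw = sum(x * y for x, y in zip(q_a, q_w))  # pure integer dot product
    return s_a * s_w * (raw - z_a * weight_acc)

def exact_dot(a, w):
    return sum(x * y for x, y in zip(a, w))
```

Because the correction depends only on `z_a` and the weight sums, it can be emitted as the observer's `out_bias` and added in the GEMM epilogue, keeping the inner loop integer-only.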

class FakeQuantizationFunctor {
public:
FakeQuantizationFunctor() {
@@ -390,6 +435,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::GroupwiseDequantizeFunctor>("GroupwiseDequantize");
m.add_functor<impl::FusedLinearWithGroupwiseQuantizedWeightFunctor>(
"FusedLinearWithGroupwiseQuantizedWeight");
m.add_functor<impl::FusedActivationMinMaxObserverFunctor>("FusedActivationMinMaxObserver");
};

} // namespace functional