diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS
index 903a4a92b8e..453b4814637 100644
--- a/backends/vulkan/_passes/TARGETS
+++ b/backends/vulkan/_passes/TARGETS
@@ -104,19 +104,6 @@ runtime.python_library(
     ],
 )
 
-runtime.python_library(
-    name = "replace_qdq",
-    srcs = ["replace_qdq.py"],
-    visibility = [
-        "//executorch/backends/...",
-    ],
-    deps = [
-        "//caffe2:torch",
-        "//executorch/backends/vulkan:utils_lib",
-        "//executorch/exir:pass_base",
-    ],
-)
-
 runtime.python_library(
     name = "fuse_patterns",
     srcs = ["fuse_patterns.py"],
@@ -149,7 +136,6 @@ runtime.python_library(
         ":insert_prepack_nodes",
         ":remove_asserts",
         ":remove_redundant_ops",
-        ":replace_qdq",
         ":squeeze_unsqueeze_inputs",
         ":tag_memory_meta_pass",
     ]
diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py
index 8d305ababe4..d6a6823ca88 100644
--- a/backends/vulkan/_passes/__init__.py
+++ b/backends/vulkan/_passes/__init__.py
@@ -19,7 +19,6 @@
 from executorch.backends.vulkan._passes.remove_redundant_ops import (
     RemoveRedundantOpsTransform,
 )
-from executorch.backends.vulkan._passes.replace_qdq import ReplaceQDQPass
 from executorch.backends.vulkan._passes.squeeze_unsqueeze_inputs import (
     SqueezeUnsqueezeInputs,
 )
@@ -33,7 +32,6 @@
     "remove_asserts",
     "RemoveAssertsTransform",
     "RemoveRedundantOpsTransform",
-    "ReplaceQDQPass",
     "SqueezeUnsqueezeInputs",
     "TagMemoryMetaPass",
 ]
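For context: with the build target and the re-export removed, any downstream import of ReplaceQDQPass now fails at import time. A quick sanity check (hypothetical snippet, not part of this change):

    import importlib

    vk_passes = importlib.import_module("executorch.backends.vulkan._passes")
    assert hasattr(vk_passes, "RemoveRedundantOpsTransform")  # still exported
    assert not hasattr(vk_passes, "ReplaceQDQPass")  # removed by this diff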
- """ - - def __init__(self): - super(ReplaceQDQPass, self).__init__() - - def call(self, graph_module: torch.fx.GraphModule): - # Track nodes that need to be replaced - nodes_to_replace = [] - - for node in graph_module.graph.nodes: - # Check if this is the custom quantized conv2d op - if node.target in [ - exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to.default, - exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to_dw.default, - exir_ops.edge.et_vk.add_q8ta_q8ta_q8to.default, - ]: - for quantized_input_node in node.args: - if isinstance( - quantized_input_node, torch.fx.Node - ) and utils.is_quant_node(quantized_input_node): - # Get the arguments from the original quantize node - input_tensor = quantized_input_node.args[0] - scale = quantized_input_node.args[1] - zero_point = quantized_input_node.args[2] - - nodes_to_replace.append( - { - "old_node": quantized_input_node, - "new_target": exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default, - "args": (input_tensor, scale, zero_point), - "node_type": "quantize_input", - } - ) - - # Find dequantize ops that consume the output of this conv2d - for user in node.users: - if utils.is_dequant_node(user): - # Get the arguments from the original dequantize node - scale = user.args[1] - zero_point = user.args[2] - - nodes_to_replace.append( - { - "old_node": user, - "new_target": exir_ops.edge.et_vk.dequantize_q8to_from_conv2d.default, - "args": ( - node, - scale, - zero_point, - ), # node is the conv2d output - "node_type": "dequantize_output", - } - ) - - # Apply the replacements - for replacement in nodes_to_replace: - old_node = replacement["old_node"] - new_target = replacement["new_target"] - new_args = replacement["args"] - - with graph_module.graph.inserting_before(old_node): - new_node = graph_module.graph.create_node( - "call_function", new_target, args=new_args - ) - new_node.meta = old_node.meta.copy() - old_node.replace_all_uses_with(new_node) - - # Clean up the graph - graph_module.graph.eliminate_dead_code() - graph_module.recompile() - - # Re-trace to validate everything is ok - graph_module = super().call(graph_module).graph_module - - return PassResult(graph_module, True) diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 682087585ef..aed8b591fea 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -539,42 +539,6 @@ def apply_rotary_emb_impl( lib.impl(name, apply_rotary_emb_impl, "CompositeExplicitAutograd") apply_rotary_emb_op = getattr(getattr(torch.ops, namespace), name) -############################# -## quantize/dequantize ops ## -############################# - - -def quantize_q8ta_for_conv2d_impl( - input: torch.Tensor, - scale: float, - zero_point: int, -): - return torch.ops.quantized_decomposed.quantize_per_tensor( - input, scale, zero_point, -128, 127, torch.int8 - ) - - -name = "quantize_q8ta_for_conv2d" -lib.define(f"{name}(Tensor input, float scale, int zero_point) -> Tensor") -lib.impl(name, quantize_q8ta_for_conv2d_impl, "CompositeExplicitAutograd") -quantize_q8ta_for_conv2d_op = getattr(getattr(torch.ops, namespace), name) - - -def dequantize_q8to_from_conv2d_impl( - input: torch.Tensor, - scale: float, - zero_point: int, -): - return torch.ops.quantized_decomposed.dequantize_per_tensor( - input, scale, zero_point, -128, 127, input.dtype - ) - - -name = "dequantize_q8to_from_conv2d" -lib.define(f"{name}(Tensor input, float scale, int zero_point) -> Tensor") -lib.impl(name, dequantize_q8to_from_conv2d_impl, "CompositeExplicitAutograd") 
diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py
index 682087585ef..aed8b591fea 100644
--- a/backends/vulkan/custom_ops_lib.py
+++ b/backends/vulkan/custom_ops_lib.py
@@ -539,42 +539,6 @@ def apply_rotary_emb_impl(
     ],
 )
 lib.impl(name, apply_rotary_emb_impl, "CompositeExplicitAutograd")
 apply_rotary_emb_op = getattr(getattr(torch.ops, namespace), name)
 
-#############################
-## quantize/dequantize ops ##
-#############################
-
-
-def quantize_q8ta_for_conv2d_impl(
-    input: torch.Tensor,
-    scale: float,
-    zero_point: int,
-):
-    return torch.ops.quantized_decomposed.quantize_per_tensor(
-        input, scale, zero_point, -128, 127, torch.int8
-    )
-
-
-name = "quantize_q8ta_for_conv2d"
-lib.define(f"{name}(Tensor input, float scale, int zero_point) -> Tensor")
-lib.impl(name, quantize_q8ta_for_conv2d_impl, "CompositeExplicitAutograd")
-quantize_q8ta_for_conv2d_op = getattr(getattr(torch.ops, namespace), name)
-
-
-def dequantize_q8to_from_conv2d_impl(
-    input: torch.Tensor,
-    scale: float,
-    zero_point: int,
-):
-    return torch.ops.quantized_decomposed.dequantize_per_tensor(
-        input, scale, zero_point, -128, 127, input.dtype
-    )
-
-
-name = "dequantize_q8to_from_conv2d"
-lib.define(f"{name}(Tensor input, float scale, int zero_point) -> Tensor")
-lib.impl(name, dequantize_q8to_from_conv2d_impl, "CompositeExplicitAutograd")
-dequantize_q8to_from_conv2d_op = getattr(getattr(torch.ops, namespace), name)
-
 ########################
 ## add_q8ta_q8ta_q8to ##
 ########################
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index e51bc8ea12a..461278500a6 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -144,13 +144,9 @@ def register_ephemeral_op():
 
 @update_features(
     [
-        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
-        exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
         exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
-        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
-        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
-        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
         exir_ops.edge.quantized_decomposed.quantize_per_token.default,
+        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
         exir_ops.edge.quantized_decomposed.dequantize_per_token.default,
     ]
 )
@@ -630,10 +626,11 @@ def register_quantized_binary_op():
 
 @update_features(
     [
-        exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default,
+        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+        exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
     ]
 )
-def register_quantize_for_conv2d_op():
+def register_quantize_op():
     return OpFeatures(
         inputs_storage=[
             utils.CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER,
         ],
         outputs_storage=[
             utils.PACKED_INT8_4W4C_BUFFER,
         ],
-        supports_resize=False,
     )
@@ -641,16 +638,16 @@
 
 @update_features(
     [
-        exir_ops.edge.et_vk.dequantize_q8to_from_conv2d.default,
+        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
     ]
 )
-def register_dequantize_for_conv2d_op():
+def register_dequantize_op():
     return OpFeatures(
         inputs_storage=[
             utils.PACKED_INT8_4W4C_BUFFER,
         ],
         outputs_storage=[
             utils.CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER,
         ],
-        supports_resize=False,
     )
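With the registrations above, the Vulkan delegate now claims the standard quantized_decomposed per-tensor ops directly, holding the quantized side in the packed 4W4C buffer layout. At the eager level these ops behave exactly like the composite impls deleted from custom_ops_lib.py; a quick round trip with illustrative values (the _decomposed import registers the op library):

    import torch
    import torch.ao.quantization.fx._decomposed  # noqa: F401

    x = torch.randn(1, 4, 8, 8)
    scale, zero_point = 0.05, 3

    q = torch.ops.quantized_decomposed.quantize_per_tensor(
        x, scale, zero_point, -128, 127, torch.int8
    )
    y = torch.ops.quantized_decomposed.dequantize_per_tensor(
        q, scale, zero_point, -128, 127, q.dtype
    )
    assert torch.allclose(x, y, atol=scale)  # error bounded by one quantization step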
diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.cpp
index 81eaf9e02a5..ee8f8a1afb4 100644
--- a/backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.cpp
@@ -366,30 +366,52 @@ void add_unpack_4w4c_and_dequantize_node(
 // Operator Entrypoints
 //
 
-void quantize_q8ta_for_conv2d(
+void quantize_per_tensor_impl(
     ComputeGraph& graph,
     const std::vector<ValueRef>& args) {
-  int32_t idx = 0;
-  const ValueRef fp_input = args.at(idx++);
-  const ValueRef scale = args.at(idx++);
-  const ValueRef zero_point = args.at(idx++);
-  const ValueRef packed_int8_input = args.at(idx++);
+  int32_t arg_idx = 0;
+  const ValueRef fp_input = args[arg_idx++];
+  const ValueRef scale = args[arg_idx++];
+  const ValueRef zero_point = args[arg_idx++];
+  const ValueRef quant_min = args[arg_idx++];
+  (void)quant_min;
+  const ValueRef quant_max = args[arg_idx++];
+  (void)quant_max;
+  const ValueRef dtype = args[arg_idx++];
+  (void)dtype;
+
+  const ValueRef int8_output = args[arg_idx++];
+
+  VK_CHECK_COND(
+      graph.estimate_memory_layout_of(int8_output) == utils::kPackedInt8_4W4C);
 
   add_quantize_and_pack_4w4c_node(
-      graph, fp_input, scale, zero_point, packed_int8_input);
+      graph, fp_input, scale, zero_point, int8_output);
 }
 
-void dequantize_q8to_from_conv2d(
+void dequantize_per_tensor_impl(
     ComputeGraph& graph,
     const std::vector<ValueRef>& args) {
-  int32_t idx = 0;
-  const ValueRef packed_int8_output = args.at(idx++);
-  const ValueRef scale = args.at(idx++);
-  const ValueRef zero_point = args.at(idx++);
-  const ValueRef fp_output = args.at(idx++);
+  int32_t arg_idx = 0;
+  const ValueRef int8_input = args[arg_idx++];
+  const ValueRef scale = args[arg_idx++];
+  const ValueRef zero_point = args[arg_idx++];
+  const ValueRef quant_min = args[arg_idx++];
+  (void)quant_min;
+  const ValueRef quant_max = args[arg_idx++];
+  (void)quant_max;
+  const ValueRef dtype = args[arg_idx++];
+  (void)dtype;
+  const ValueRef output_dtype = args[arg_idx++];
+  (void)output_dtype;
+
+  const ValueRef fp_output = args[arg_idx++];
+
+  VK_CHECK_COND(
+      graph.estimate_memory_layout_of(int8_input) == utils::kPackedInt8_4W4C);
 
   add_unpack_4w4c_and_dequantize_node(
-      graph, packed_int8_output, scale, zero_point, fp_output);
+      graph, int8_input, scale, zero_point, fp_output);
 }
 
 void qdq8ta_conv2d_input(
@@ -416,11 +438,13 @@ void qdq8ta_conv2d_input(
 }
 
 REGISTER_OPERATORS {
-  VK_REGISTER_OP(etvk.qdq8ta_conv2d_input.default, qdq8ta_conv2d_input);
   VK_REGISTER_OP(
-      et_vk.quantize_q8ta_for_conv2d.default, quantize_q8ta_for_conv2d);
+      quantized_decomposed.quantize_per_tensor.default,
+      quantize_per_tensor_impl);
   VK_REGISTER_OP(
-      et_vk.dequantize_q8to_from_conv2d.default, dequantize_q8to_from_conv2d);
+      quantized_decomposed.dequantize_per_tensor.default,
+      dequantize_per_tensor_impl);
+  VK_REGISTER_OP(etvk.qdq8ta_conv2d_input.default, qdq8ta_conv2d_input);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py
index 81ee67a596c..3a3f6cdf4fe 100644
--- a/backends/vulkan/vulkan_preprocess.py
+++ b/backends/vulkan/vulkan_preprocess.py
@@ -21,7 +21,6 @@
     FuseQuantizedOpsTransform,
     insert_prepack_nodes,
     RemoveRedundantOpsTransform,
-    ReplaceQDQPass,
     SqueezeUnsqueezeInputs,
     TagMemoryMetaPass,
 )
@@ -162,7 +161,6 @@ def preprocess(  # noqa: C901
         AddmmToLinearTransform(),
         RemoveRedundantOpsTransform(),
         FuseQuantizedOpsTransform(),
-        ReplaceQDQPass(),
         FoldQDQPass(),
         SqueezeUnsqueezeInputs(),
         FuseViewCopyTransform(),
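One schema detail worth noting: dequantize_per_tensor carries an optional out_dtype argument, which is why dequantize_per_tensor_impl above pops one more ValueRef (output_dtype) than its quantize counterpart before reading the output tensor. The VK_CHECK_COND guards make the storage claims from op_registry.py explicit at runtime: both entrypoints only accept the packed 4W4C layout. The extra argument at the PyTorch level, with illustrative values:

    import torch
    import torch.ao.quantization.fx._decomposed  # noqa: F401

    q = torch.randint(-128, 128, (2, 3), dtype=torch.int8)
    y = torch.ops.quantized_decomposed.dequantize_per_tensor(
        q, 0.1, 0, -128, 127, torch.int8, out_dtype=torch.float32
    )
    assert y.dtype == torch.float32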