14 changes: 0 additions & 14 deletions backends/vulkan/_passes/TARGETS
@@ -104,19 +104,6 @@ runtime.python_library(
],
)

runtime.python_library(
name = "replace_qdq",
srcs = ["replace_qdq.py"],
visibility = [
"//executorch/backends/...",
],
deps = [
"//caffe2:torch",
"//executorch/backends/vulkan:utils_lib",
"//executorch/exir:pass_base",
],
)

runtime.python_library(
name = "fuse_patterns",
srcs = ["fuse_patterns.py"],
@@ -149,7 +136,6 @@ runtime.python_library(
":insert_prepack_nodes",
":remove_asserts",
":remove_redundant_ops",
":replace_qdq",
":squeeze_unsqueeze_inputs",
":tag_memory_meta_pass",
]
2 changes: 0 additions & 2 deletions backends/vulkan/_passes/__init__.py
@@ -19,7 +19,6 @@
from executorch.backends.vulkan._passes.remove_redundant_ops import (
RemoveRedundantOpsTransform,
)
from executorch.backends.vulkan._passes.replace_qdq import ReplaceQDQPass
from executorch.backends.vulkan._passes.squeeze_unsqueeze_inputs import (
SqueezeUnsqueezeInputs,
)
@@ -33,7 +32,6 @@
"remove_asserts",
"RemoveAssertsTransform",
"RemoveRedundantOpsTransform",
"ReplaceQDQPass",
"SqueezeUnsqueezeInputs",
"TagMemoryMetaPass",
]
93 changes: 0 additions & 93 deletions backends/vulkan/_passes/replace_qdq.py

This file was deleted.

36 changes: 0 additions & 36 deletions backends/vulkan/custom_ops_lib.py
@@ -539,42 +539,6 @@ def apply_rotary_emb_impl(
lib.impl(name, apply_rotary_emb_impl, "CompositeExplicitAutograd")
apply_rotary_emb_op = getattr(getattr(torch.ops, namespace), name)

#############################
## quantize/dequantize ops ##
#############################


def quantize_q8ta_for_conv2d_impl(
input: torch.Tensor,
scale: float,
zero_point: int,
):
return torch.ops.quantized_decomposed.quantize_per_tensor(
input, scale, zero_point, -128, 127, torch.int8
)


name = "quantize_q8ta_for_conv2d"
lib.define(f"{name}(Tensor input, float scale, int zero_point) -> Tensor")
lib.impl(name, quantize_q8ta_for_conv2d_impl, "CompositeExplicitAutograd")
quantize_q8ta_for_conv2d_op = getattr(getattr(torch.ops, namespace), name)


def dequantize_q8to_from_conv2d_impl(
input: torch.Tensor,
scale: float,
zero_point: int,
):
return torch.ops.quantized_decomposed.dequantize_per_tensor(
input, scale, zero_point, -128, 127, input.dtype
)


name = "dequantize_q8to_from_conv2d"
lib.define(f"{name}(Tensor input, float scale, int zero_point) -> Tensor")
lib.impl(name, dequantize_q8to_from_conv2d_impl, "CompositeExplicitAutograd")
dequantize_q8to_from_conv2d_op = getattr(getattr(torch.ops, namespace), name)

########################
## add_q8ta_q8ta_q8to ##
########################
18 changes: 7 additions & 11 deletions backends/vulkan/op_registry.py
@@ -144,13 +144,9 @@ def register_ephemeral_op():

@update_features(
[
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
exir_ops.edge.quantized_decomposed.quantize_per_token.default,
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
exir_ops.edge.quantized_decomposed.dequantize_per_token.default,
]
)
@@ -630,35 +626,35 @@ def register_quantized_binary_op():

@update_features(
[
exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
]
)
def register_quantize_for_conv2d_op():
def register_quantize_op():
return OpFeatures(
inputs_storage=[
utils.CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER,
],
outputs_storage=[
utils.PACKED_INT8_4W4C_BUFFER,
],
supports_resize=False,
)


@update_features(
[
exir_ops.edge.et_vk.dequantize_q8to_from_conv2d.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
]
)
def register_dequantize_for_conv2d_op():
def register_dequantize_op():
return OpFeatures(
inputs_storage=[
utils.PACKED_INT8_4W4C_BUFFER,
],
outputs_storage=[
utils.CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER,
],
supports_resize=False,
)


58 changes: 41 additions & 17 deletions backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.cpp
@@ -366,30 +366,52 @@ void add_unpack_4w4c_and_dequantize_node(
// Operator Entrypoints
//

void quantize_q8ta_for_conv2d(
void quantize_per_tensor_impl(
ComputeGraph& graph,
const std::vector<ValueRef>& args) {
int32_t idx = 0;
const ValueRef fp_input = args.at(idx++);
const ValueRef scale = args.at(idx++);
const ValueRef zero_point = args.at(idx++);
const ValueRef packed_int8_input = args.at(idx++);
int32_t arg_idx = 0;
const ValueRef fp_input = args[arg_idx++];
const ValueRef scale = args[arg_idx++];
const ValueRef zero_point = args[arg_idx++];
const ValueRef quant_min = args[arg_idx++];
(void)quant_min;
const ValueRef quant_max = args[arg_idx++];
(void)quant_max;
const ValueRef dtype = args[arg_idx++];
(void)dtype;

const ValueRef int8_output = args[arg_idx++];

VK_CHECK_COND(
graph.estimate_memory_layout_of(int8_output) == utils::kPackedInt8_4W4C);

add_quantize_and_pack_4w4c_node(
graph, fp_input, scale, zero_point, packed_int8_input);
graph, fp_input, scale, zero_point, int8_output);
}

void dequantize_q8to_from_conv2d(
void dequantize_per_tensor_impl(
ComputeGraph& graph,
const std::vector<ValueRef>& args) {
int32_t idx = 0;
const ValueRef packed_int8_output = args.at(idx++);
const ValueRef scale = args.at(idx++);
const ValueRef zero_point = args.at(idx++);
const ValueRef fp_output = args.at(idx++);
int32_t arg_idx = 0;
const ValueRef int8_input = args[arg_idx++];
const ValueRef scale = args[arg_idx++];
const ValueRef zero_point = args[arg_idx++];
const ValueRef quant_min = args[arg_idx++];
(void)quant_min;
const ValueRef quant_max = args[arg_idx++];
(void)quant_max;
const ValueRef dtype = args[arg_idx++];
(void)dtype;
const ValueRef output_dtype = args[arg_idx++];
(void)output_dtype;

const ValueRef fp_output = args[arg_idx++];

VK_CHECK_COND(
graph.estimate_memory_layout_of(int8_input) == utils::kPackedInt8_4W4C);

add_unpack_4w4c_and_dequantize_node(
graph, packed_int8_output, scale, zero_point, fp_output);
graph, int8_input, scale, zero_point, fp_output);
}

void qdq8ta_conv2d_input(
@@ -416,11 +438,13 @@ void qdq8ta_conv2d_input(
}

REGISTER_OPERATORS {
VK_REGISTER_OP(etvk.qdq8ta_conv2d_input.default, qdq8ta_conv2d_input);
VK_REGISTER_OP(
et_vk.quantize_q8ta_for_conv2d.default, quantize_q8ta_for_conv2d);
quantized_decomposed.quantize_per_tensor.default,
quantize_per_tensor_impl);
VK_REGISTER_OP(
et_vk.dequantize_q8to_from_conv2d.default, dequantize_q8to_from_conv2d);
quantized_decomposed.dequantize_per_tensor.default,
dequantize_per_tensor_impl);
VK_REGISTER_OP(etvk.qdq8ta_conv2d_input.default, qdq8ta_conv2d_input);
}

} // namespace vkcompute
2 changes: 0 additions & 2 deletions backends/vulkan/vulkan_preprocess.py
@@ -21,7 +21,6 @@
FuseQuantizedOpsTransform,
insert_prepack_nodes,
RemoveRedundantOpsTransform,
ReplaceQDQPass,
SqueezeUnsqueezeInputs,
TagMemoryMetaPass,
)
@@ -162,7 +161,6 @@ def preprocess( # noqa: C901
AddmmToLinearTransform(),
RemoveRedundantOpsTransform(),
FuseQuantizedOpsTransform(),
ReplaceQDQPass(),
FoldQDQPass(),
SqueezeUnsqueezeInputs(),
FuseViewCopyTransform(),