Skip to content

Commit e9e7c23

Browse files
authored
[passes] Consider multiple DQ case (#169)
* [passes] Consider multiple DQ case This commit considers multiple DQ case in SimpleAdd. TICO-DCO-1.0-Signed-off-by: seongwoo <[email protected]> * rename variable.
1 parent 90a3650 commit e9e7c23

File tree

1 file changed

+54
-24
lines changed

1 file changed

+54
-24
lines changed

tico/experimental/quantization/passes/fold_quant_ops.py

Lines changed: 54 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@
2121
import torch
2222
from torch.export import ExportedProgram
2323

24-
from tico.serialize.quant_param import QPARAM_KEY, QuantParam, to_qparam_dtype
24+
from tico.serialize.quant_param import QPARAM_KEY, QuantParam
2525
from tico.utils import logging
26-
from tico.utils.graph import create_node
2726
from tico.utils.passes import PassBase, PassResult
2827
from tico.utils.trace_decorators import trace_graph_diff_on_pass
2928
from tico.utils.utils import get_quant_dtype
@@ -42,11 +41,33 @@ class FoldQuantOps(PassBase):
4241
To export quantized circle, this pass removes (Q - DQ) nodes and saves those quantization info
4342
to previous op's metadata.
4443
45-
[BEFORE]
46-
op (float) - Quantize - Dequantize - (float)
47-
48-
[AFTER]
49-
op (float with meta[QPARAM_KEY])
44+
────────────────────────────────────────────────────────────────
45+
BEFORE AFTER
46+
────────────────────────────────────────────────────────────────
47+
op(float) ─ Q ─ DQ ─ … op(float, meta[QPARAM])
48+
49+
op ─ Q1 ─ DQ1 ─ Q2 ─ DQ2 op(meta[QPARAM]) ─ Q2
50+
▲ ▲
51+
│ (Q1, DQ1 folded) │ (re-quantization kept)
52+
53+
op ─ Q ─┬─ DQ0 op(meta[QPARAM])
54+
├─ DQ1 (each DQ* folded, Q dropped when orphaned)
55+
└─ DQ2
56+
────────────────────────────────────────────────────────────────
57+
58+
Algorithm
59+
---------
60+
1. Iterate over *all* Dequantize nodes.
61+
2. For each DQ, verify it is driven by a Quantize node `q` and that
62+
`q` and `dq` share identical (scale, zero-point, dtype).
63+
3. a) If the producer op has **no** QPARAM, attach one, then replace
64+
*this* DQ's usages with the producer op.
65+
b) If the producer is already quantized with a different dtype,
66+
this is a *re-quantization*: attach QPARAM to `q` and keep it,
67+
but still remove the DQ.
68+
4. After all replacements, run `graph.eliminate_dead_code()`.
69+
Any Quantize that became orphaned because *all* its DQs were folded
70+
is deleted automatically.
5071
"""
5172

5273
def __init__(self):
@@ -81,15 +102,9 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
81102
if q_args.dtype != dq_args.dtype:
82103
continue
83104

84-
# Case 1. op is not quantized
85-
# - Quantize op
86-
# Case 2. op is quantized
87-
# 2.1. op_dtype == qdq_dtype
88-
# - Just skip (NOTE Need requantization?)
89-
# 2.2. op_dtype != qdq_dtype
90-
# - Insert Quantize operator
91-
92-
# Case 1
105+
# ───────────────────────────────────────────
106+
# Case 1: op not yet quantized
107+
# ───────────────────────────────────────────
93108
if QPARAM_KEY not in op.meta:
94109
qparam = QuantParam()
95110
qparam.scale = [q_args.scale]
@@ -100,21 +115,36 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
100115
dq.replace_all_uses_with(op, propagate_meta=False)
101116

102117
logger.debug(f"{q.name} and {dq.name} are folded to {op.name}.")
118+
# ───────────────────────────────────────────
119+
# Case 2: op already quantized
120+
# 2.1 same dtype → nothing to do
121+
# 2.2 diff dtype → leave Q in place
122+
# ───────────────────────────────────────────
103123
else:
104124
op_qparam: QuantParam = op.meta[QPARAM_KEY]
105125
qdq_dtype = get_quant_dtype(q_args.quant_min, q_args.quant_max)
106-
# Case 2.2
126+
107127
if op_qparam.dtype != qdq_dtype:
108-
# If op is already quantized with a dtype different from qdq, leave quantize
109-
qparam = QuantParam()
110-
qparam.scale = [q_args.scale]
111-
qparam.zero_point = [q_args.zero_p]
112-
qparam.dtype = qdq_dtype
113-
q.meta[QPARAM_KEY] = qparam
114-
assert len(q.users) == 1, "Fix me unless"
128+
# Attach QPARAM to Q once
129+
if QPARAM_KEY not in q.meta:
130+
qparam = QuantParam()
131+
qparam.scale = [q_args.scale]
132+
qparam.zero_point = [q_args.zero_p]
133+
qparam.dtype = qdq_dtype
134+
q.meta[QPARAM_KEY] = qparam
135+
assert len(q.users) == 1, "Fix me unless"
115136

116137
dq.replace_all_uses_with(q, propagate_meta=False)
117138
logger.debug(f"{dq.name} is folded ({q.name} is left).")
139+
else:
140+
# Same dtype → the Quantize–Dequantize pair is redundant.
141+
assert op_qparam.scale and op_qparam.scale[0] == q_args.scale
142+
assert (
143+
op_qparam.zero_point
144+
and op_qparam.zero_point[0] == q_args.zero_p
145+
)
146+
dq.replace_all_uses_with(op, propagate_meta=False)
147+
logger.debug(f"Removed redundant {dq.name}")
118148

119149
graph.eliminate_dead_code()
120150
graph.lint()

0 commit comments

Comments
 (0)