21
21
import torch
22
22
from torch .export import ExportedProgram
23
23
24
- from tico .serialize .quant_param import QPARAM_KEY , QuantParam , to_qparam_dtype
24
+ from tico .serialize .quant_param import QPARAM_KEY , QuantParam
25
25
from tico .utils import logging
26
- from tico .utils .graph import create_node
27
26
from tico .utils .passes import PassBase , PassResult
28
27
from tico .utils .trace_decorators import trace_graph_diff_on_pass
29
28
from tico .utils .utils import get_quant_dtype
@@ -42,11 +41,33 @@ class FoldQuantOps(PassBase):
42
41
To export quantized circle, this pass removes (Q - DQ) nodes and saves those quantization info
43
42
to previous op's metadata.
44
43
45
- [BEFORE]
46
- op (float) - Quantize - Dequantize - (float)
47
-
48
- [AFTER]
49
- op (float with meta[QPARAM_KEY])
44
+ ────────────────────────────────────────────────────────────────
45
+ BEFORE AFTER
46
+ ────────────────────────────────────────────────────────────────
47
+ op(float) ─ Q ─ DQ ─ … op(float, meta[QPARAM])
48
+
49
+ op ─ Q1 ─ DQ1 ─ Q2 ─ DQ2 op(meta[QPARAM]) ─ Q2
50
+ ▲ ▲
51
+ │ (Q1, DQ1 folded) │ (re-quantization kept)
52
+
53
+ op ─ Q ─┬─ DQ0 op(meta[QPARAM])
54
+ ├─ DQ1 (each DQ* folded, Q dropped when orphaned)
55
+ └─ DQ2
56
+ ────────────────────────────────────────────────────────────────
57
+
58
+ Algorithm
59
+ ---------
60
+ 1. Iterate over *all* Dequantize nodes.
61
+ 2. For each DQ, verify it is driven by a Quantize node `q` and that
62
+ `q` and `dq` share identical (scale, zero-point, dtype).
63
+ 3. a) If the producer op has **no** QPARAM, attach one, then replace
64
+ *this* DQ's usages with the producer op.
65
+ b) If the producer is already quantized with a different dtype,
66
+ this is a *re-quantization*: attach QPARAM to `q` and keep it,
67
+ but still remove the DQ.
68
+ 4. After all replacements, run `graph.eliminate_dead_code()`.
69
+ Any Quantize that became orphaned because *all* its DQs were folded
70
+ is deleted automatically.
50
71
"""
51
72
52
73
def __init__ (self ):
@@ -81,15 +102,9 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
81
102
if q_args .dtype != dq_args .dtype :
82
103
continue
83
104
84
- # Case 1. op is not quantized
85
- # - Quantize op
86
- # Case 2. op is quantized
87
- # 2.1. op_dtype == qdq_dtype
88
- # - Just skip (NOTE Need requantization?)
89
- # 2.2. op_dtype != qdq_dtype
90
- # - Insert Quantize operator
91
-
92
- # Case 1
105
+ # ───────────────────────────────────────────
106
+ # Case 1: op not yet quantized
107
+ # ───────────────────────────────────────────
93
108
if QPARAM_KEY not in op .meta :
94
109
qparam = QuantParam ()
95
110
qparam .scale = [q_args .scale ]
@@ -100,21 +115,36 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
100
115
dq .replace_all_uses_with (op , propagate_meta = False )
101
116
102
117
logger .debug (f"{ q .name } and { dq .name } are folded to { op .name } ." )
118
+ # ───────────────────────────────────────────
119
+ # Case 2: op already quantized
120
+ # 2.1 same dtype → nothing to do
121
+ # 2.2 diff dtype → leave Q in place
122
+ # ───────────────────────────────────────────
103
123
else :
104
124
op_qparam : QuantParam = op .meta [QPARAM_KEY ]
105
125
qdq_dtype = get_quant_dtype (q_args .quant_min , q_args .quant_max )
106
- # Case 2.2
126
+
107
127
if op_qparam .dtype != qdq_dtype :
108
- # If op is already quantized with a dtype different from qdq, leave quantize
109
- qparam = QuantParam ()
110
- qparam .scale = [q_args .scale ]
111
- qparam .zero_point = [q_args .zero_p ]
112
- qparam .dtype = qdq_dtype
113
- q .meta [QPARAM_KEY ] = qparam
114
- assert len (q .users ) == 1 , "Fix me unless"
128
+ # Attach QPARAM to Q once
129
+ if QPARAM_KEY not in q .meta :
130
+ qparam = QuantParam ()
131
+ qparam .scale = [q_args .scale ]
132
+ qparam .zero_point = [q_args .zero_p ]
133
+ qparam .dtype = qdq_dtype
134
+ q .meta [QPARAM_KEY ] = qparam
135
+ assert len (q .users ) == 1 , "Fix me unless"
115
136
116
137
dq .replace_all_uses_with (q , propagate_meta = False )
117
138
logger .debug (f"{ dq .name } is folded ({ q .name } is left)." )
139
+ else :
140
+ # Same dtype → the Quantize–Dequantize pair is redundant.
141
+ assert op_qparam .scale and op_qparam .scale [0 ] == q_args .scale
142
+ assert (
143
+ op_qparam .zero_point
144
+ and op_qparam .zero_point [0 ] == q_args .zero_p
145
+ )
146
+ dq .replace_all_uses_with (op , propagate_meta = False )
147
+ logger .debug (f"Removed redundant { dq .name } " )
118
148
119
149
graph .eliminate_dead_code ()
120
150
graph .lint ()
0 commit comments