
Commit 9b1b44b

cccclai authored and facebook-github-bot committed
Fix mobile bert fine tune (#9915)
Summary: As title, it's broken in #9643
Differential Revision: D72472098
1 parent 56c8dc2 commit 9b1b44b

File tree

1 file changed (+27 / -15 lines changed)


examples/qualcomm/scripts/mobilebert_fine_tune.py (+27 / -15)
@@ -17,17 +17,17 @@
     generate_htp_compiler_spec,
     generate_qnn_executorch_compiler_spec,
     skip_annotation,
+    to_edge_transform_and_lower_to_qnn,
 )
 from executorch.examples.qualcomm.utils import (
     build_executorch_binary,
     make_output_dir,
     make_quantizer,
     parse_skip_delegation_node,
-    QnnPartitioner,
     setup_common_args_and_variables,
     SimpleADB,
 )
-from executorch.exir import to_edge
+from executorch.exir import ExecutorchBackendConfig
 from transformers import BertTokenizer, MobileBertForSequenceClassification
 
 
@@ -273,30 +273,42 @@ def calibrator(gm):
 
     quantizer = make_quantizer(quant_dtype=quant_dtype)
     backend_options = generate_htp_compiler_spec(quant_dtype is not None)
-    partitioner = QnnPartitioner(
-        generate_qnn_executorch_compiler_spec(
-            soc_model=getattr(QcomChipset, args.model),
-            backend_options=backend_options,
-        ),
-        skip_node_id_set=skip_node_id_set,
-        skip_node_op_set=skip_node_op_set,
+    # partitioner = QnnPartitioner(
+    #     generate_qnn_executorch_compiler_spec(
+    #         soc_model=getattr(QcomChipset, args.model),
+    #         backend_options=backend_options,
+    #     ),
+    #     skip_node_id_set=skip_node_id_set,
+    #     skip_node_op_set=skip_node_op_set,
+    # )
+    backend_options = generate_htp_compiler_spec(
+        use_fp16=False,
+    )
+    compile_spec = generate_qnn_executorch_compiler_spec(
+        soc_model=QcomChipset.SM8550,
+        backend_options=backend_options,
     )
     # skip embedding layer cause it's quantization sensitive
     graph_module, _ = skip_annotation(
         nn_module=model,
         quantizer=quantizer,
-        partitioner=partitioner,
+        compiler_specs=compile_spec,
         sample_input=inputs[0],
         calibration_cb=calibrator,
         fp_node_op_set={torch.ops.aten.embedding.default},
     )
     # lower all graph again, the skipped operators will be left in CPU
-    exec_prog = to_edge(
-        torch.export.export(graph_module, inputs[0], strict=True),
-    ).to_executorch()
-
+    # exec_prog = to_edge(
+    #     torch.export.export(graph_module, inputs[0], strict=True),
+    # ).to_executorch()
+    delegated_program = to_edge_transform_and_lower_to_qnn(
+        graph_module, inputs[0], compile_spec
+    )
+    executorch_program = delegated_program.to_executorch(
+        config=ExecutorchBackendConfig(extract_delegate_segments=True)
+    )
     with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file:
-        file.write(exec_prog.buffer)
+        file.write(executorch_program.buffer)
 
     if args.compile_only:
         return
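
To see what the example does after this fix, below is a condensed sketch of the new lowering flow, pieced together from the diff above. It is not the complete script: model, inputs, calibrator, quant_dtype, args, pte_filename, and QcomChipset are assumed to be set up by the surrounding mobilebert_fine_tune.py exactly as before, and the import path for the QNN helpers is an assumption, since the diff only shows names added to an existing import block, not the module they come from.

import torch

# Assumed import path for the QNN helpers; the diff shows only the names
# added to an existing import block, not the module they come from.
from executorch.backends.qualcomm.utils.utils import (
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    skip_annotation,
    to_edge_transform_and_lower_to_qnn,
)
from executorch.examples.qualcomm.utils import make_quantizer
from executorch.exir import ExecutorchBackendConfig

# model, inputs, calibrator, quant_dtype, args, pte_filename, and QcomChipset
# are assumed to be prepared earlier in the example script.
quantizer = make_quantizer(quant_dtype=quant_dtype)

# Build the HTP backend options and compile spec directly; the old code
# wrapped an equivalent spec inside a QnnPartitioner.
backend_options = generate_htp_compiler_spec(use_fp16=False)
compile_spec = generate_qnn_executorch_compiler_spec(
    soc_model=QcomChipset.SM8550,
    backend_options=backend_options,
)

# Skip the embedding layer during quantization since it is quantization
# sensitive; skip_annotation now takes compiler_specs instead of a partitioner.
graph_module, _ = skip_annotation(
    nn_module=model,
    quantizer=quantizer,
    compiler_specs=compile_spec,
    sample_input=inputs[0],
    calibration_cb=calibrator,
    fp_node_op_set={torch.ops.aten.embedding.default},
)

# Export, convert to edge, and lower to the QNN backend in one call; the
# skipped operators stay on CPU.
delegated_program = to_edge_transform_and_lower_to_qnn(
    graph_module, inputs[0], compile_spec
)
executorch_program = delegated_program.to_executorch(
    config=ExecutorchBackendConfig(extract_delegate_segments=True)
)

with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file:
    file.write(executorch_program.buffer)

The net effect of the commit is that skip_annotation receives compiler_specs instead of a QnnPartitioner, and the separate torch.export.export plus to_edge step is replaced by a single to_edge_transform_and_lower_to_qnn call whose output is serialized with extract_delegate_segments enabled.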
