Skip to content

Commit ddc7cdb

Browse files
authored
[ET-VK][ez] Accept sample_kwargs as an argument in several test util functions (#15314)
Title says it all! This makes it possible to export models that require kwargs to be defined instead of args. Differential Revision: [D84716455](https://our.internmc.facebook.com/intern/diff/D84716455/) [ghstack-poisoned]
1 parent 86c1a9a commit ddc7cdb

File tree

1 file changed

+25
-5
lines changed

1 file changed

+25
-5
lines changed

backends/vulkan/test/utils.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,16 @@ class QuantizationMode(Enum):
5050
def get_exported_graph(
5151
model,
5252
sample_inputs,
53+
sample_kwargs=None,
5354
dynamic_shapes=None,
5455
qmode=QuantizationMode.NONE,
5556
) -> torch.fx.GraphModule:
5657
export_training_graph = export(
57-
model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
58+
model,
59+
sample_inputs,
60+
kwargs=sample_kwargs,
61+
dynamic_shapes=dynamic_shapes,
62+
strict=True,
5863
).module()
5964

6065
if qmode == QuantizationMode.NONE:
@@ -82,6 +87,7 @@ def random_uniform_tensor(shape, low=0.0, high=1.0, device=None, dtype=None):
8287
def export_model_to_vulkan(
8388
model,
8489
sample_inputs,
90+
sample_kwargs=None,
8591
dynamic_shapes=None,
8692
operator_blocklist=None,
8793
operator_allowlist=None,
@@ -91,11 +97,16 @@ def export_model_to_vulkan(
9197
):
9298
compile_options = {}
9399
exported_graph = get_exported_graph(
94-
model, sample_inputs, dynamic_shapes=dynamic_shapes, qmode=qmode
100+
model,
101+
sample_inputs,
102+
sample_kwargs=sample_kwargs,
103+
dynamic_shapes=dynamic_shapes,
104+
qmode=qmode,
95105
)
96106
program = export(
97107
exported_graph,
98108
sample_inputs,
109+
kwargs=sample_kwargs,
99110
dynamic_shapes=dynamic_shapes,
100111
strict=True,
101112
)
@@ -422,6 +433,7 @@ def save_bundled_program(
422433
sample_inputs: Tuple[torch.Tensor],
423434
output_path: str,
424435
method_name: str = "forward",
436+
sample_kwargs=None,
425437
et_program: Optional[ExecutorchProgramManager] = None,
426438
dynamic_shapes=None,
427439
) -> str:
@@ -441,13 +453,21 @@ def save_bundled_program(
441453
"""
442454
# If no ExecutorchProgramManager provided, export to Vulkan
443455
if et_program is None:
444-
et_program = export_model_to_vulkan(model, sample_inputs, dynamic_shapes)
456+
et_program = export_model_to_vulkan(
457+
model,
458+
sample_inputs,
459+
sample_kwargs=sample_kwargs,
460+
dynamic_shapes=dynamic_shapes,
461+
)
462+
463+
if sample_kwargs is None:
464+
sample_kwargs = {}
445465

446466
# Generate expected outputs by running the model
447-
expected_outputs = [getattr(model, method_name)(*sample_inputs)]
467+
expected_outputs = [getattr(model, method_name)(*sample_inputs, **sample_kwargs)]
448468

449469
# Flatten sample inputs to match expected format
450-
inputs_flattened, _ = tree_flatten(sample_inputs)
470+
inputs_flattened, _ = tree_flatten((sample_inputs, sample_kwargs))
451471

452472
# Create test suite with the sample inputs and expected outputs
453473
test_suites = [

0 commit comments

Comments
 (0)