@@ -50,11 +50,16 @@ class QuantizationMode(Enum):
5050def get_exported_graph (
5151 model ,
5252 sample_inputs ,
53+ sample_kwargs = None ,
5354 dynamic_shapes = None ,
5455 qmode = QuantizationMode .NONE ,
5556) -> torch .fx .GraphModule :
5657 export_training_graph = export (
57- model , sample_inputs , dynamic_shapes = dynamic_shapes , strict = True
58+ model ,
59+ sample_inputs ,
60+ kwargs = sample_kwargs ,
61+ dynamic_shapes = dynamic_shapes ,
62+ strict = True ,
5863 ).module ()
5964
6065 if qmode == QuantizationMode .NONE :
@@ -82,6 +87,7 @@ def random_uniform_tensor(shape, low=0.0, high=1.0, device=None, dtype=None):
8287def export_model_to_vulkan (
8388 model ,
8489 sample_inputs ,
90+ sample_kwargs = None ,
8591 dynamic_shapes = None ,
8692 operator_blocklist = None ,
8793 operator_allowlist = None ,
@@ -91,11 +97,16 @@ def export_model_to_vulkan(
9197):
9298 compile_options = {}
9399 exported_graph = get_exported_graph (
94- model , sample_inputs , dynamic_shapes = dynamic_shapes , qmode = qmode
100+ model ,
101+ sample_inputs ,
102+ sample_kwargs = sample_kwargs ,
103+ dynamic_shapes = dynamic_shapes ,
104+ qmode = qmode ,
95105 )
96106 program = export (
97107 exported_graph ,
98108 sample_inputs ,
109+ kwargs = sample_kwargs ,
99110 dynamic_shapes = dynamic_shapes ,
100111 strict = True ,
101112 )
@@ -422,6 +433,7 @@ def save_bundled_program(
422433 sample_inputs : Tuple [torch .Tensor ],
423434 output_path : str ,
424435 method_name : str = "forward" ,
436+ sample_kwargs = None ,
425437 et_program : Optional [ExecutorchProgramManager ] = None ,
426438 dynamic_shapes = None ,
427439) -> str :
@@ -441,13 +453,21 @@ def save_bundled_program(
441453 """
442454 # If no ExecutorchProgramManager provided, export to Vulkan
443455 if et_program is None :
444- et_program = export_model_to_vulkan (model , sample_inputs , dynamic_shapes )
456+ et_program = export_model_to_vulkan (
457+ model ,
458+ sample_inputs ,
459+ sample_kwargs = sample_kwargs ,
460+ dynamic_shapes = dynamic_shapes ,
461+ )
462+
463+ if sample_kwargs is None :
464+ sample_kwargs = {}
445465
446466 # Generate expected outputs by running the model
447- expected_outputs = [getattr (model , method_name )(* sample_inputs )]
467+ expected_outputs = [getattr (model , method_name )(* sample_inputs , ** sample_kwargs )]
448468
449469 # Flatten sample inputs to match expected format
450- inputs_flattened , _ = tree_flatten (sample_inputs )
470+ inputs_flattened , _ = tree_flatten (( sample_inputs , sample_kwargs ) )
451471
452472 # Create test suite with the sample inputs and expected outputs
453473 test_suites = [
0 commit comments