diff --git a/py/torch_tensorrt/dynamo/_DryRunTracker.py b/py/torch_tensorrt/dynamo/_DryRunTracker.py index 625db8753d..85cc23165a 100644 --- a/py/torch_tensorrt/dynamo/_DryRunTracker.py +++ b/py/torch_tensorrt/dynamo/_DryRunTracker.py @@ -1,9 +1,9 @@ import logging import math from dataclasses import dataclass, field -from typing import List, Tuple +from typing import Any, Dict, List -import torch +from torch_tensorrt.dynamo._settings import CompilationSettings logger = logging.getLogger(__name__) @@ -15,18 +15,18 @@ class PerSubgraphData: Args: subgraph_name (str): Name of the subgraph in the GraphModule subgraph_op_count (int): Number of operations in the subgraph - subgraph_input_shapes (List[Tuple[int, ...]]): Shapes of input Tensors of the subgraph - subgraph_input_dtypes (List[torch.device]): Input data types of the subgraph - subgraph_output_shapes (List[Tuple[int, ...]]): Shapes of output Tensors of the subgraph - subgraph_output_dtypes (List[torch.device]): Output data types of the subgraph + subgraph_input_shapes (Any): Shapes of input Tensors of the subgraph + subgraph_input_dtypes (Any): Input data types of the subgraph + subgraph_output_shapes (Any): Shapes of output Tensors of the subgraph + subgraph_output_dtypes (Any): Output data types of the subgraph """ subgraph_name: str = "" subgraph_op_count: int = 0 - subgraph_input_shapes: List[Tuple[int, ...]] = field(default_factory=list) - subgraph_input_dtypes: List[torch.device] = field(default_factory=list) - subgraph_output_shapes: List[Tuple[int, ...]] = field(default_factory=list) - subgraph_output_dtypes: List[torch.device] = field(default_factory=list) + subgraph_input_shapes: Any = field(default_factory=list) + subgraph_input_dtypes: Any = field(default_factory=list) + subgraph_output_shapes: Any = field(default_factory=list) + subgraph_output_dtypes: Any = field(default_factory=list) @dataclass @@ -36,81 +36,72 @@ class DryRunTracker: Args: total_ops_in_graph (int): Total number of operators in graph supported_ops_in_graph (int): Number of supported operators in graph - graph_input_shapes (List[Tuple[int, ...]]): Shapes of input Tensors of the graph - graph_input_dtypes (List[torch.device]): Input data types of the graph - graph_output_shapes (List[Tuple[int, ...]]): Shapes of output Tensors of the graph - graph_output_dtypes (List[torch.device]): Output data types of the graph + graph_input_shapes (Any): Shapes of input Tensors of the graph + graph_input_dtypes (Any): Input data types of the graph + graph_output_shapes (Any): Shapes of output Tensors of the graph + graph_output_dtypes (Any): Output data types of the graph per_subgraph_data (List[PerSubgraphData]): Per-subgraph data, see above class tensorrt_graph_count (int): Number of TensorRT engines to be generated - truncated_long_and_double (bool): Whether truncate_long_and_double was enabled + compilation_settings (CompilationSettings): User Compilation Settings + unsupported_ops (Dict[str, int]): Set of operators not supported in TRT """ total_ops_in_graph: int = 0 supported_ops_in_graph: int = 0 - graph_input_shapes: List[Tuple[int, ...]] = field(default_factory=list) - graph_input_dtypes: List[torch.device] = field(default_factory=list) - graph_output_shapes: List[Tuple[int, ...]] = field(default_factory=list) - graph_output_dtypes: List[torch.device] = field(default_factory=list) + graph_input_shapes: Any = field(default_factory=list) + graph_input_dtypes: Any = field(default_factory=list) + graph_output_shapes: Any = field(default_factory=list) + graph_output_dtypes: Any = field(default_factory=list) per_subgraph_data: List[PerSubgraphData] = field(default_factory=list) tensorrt_graph_count: int = 0 - truncated_long_and_double: bool = False + compilation_settings: CompilationSettings = field( + default_factory=CompilationSettings + ) + unsupported_ops: Dict[str, int] = field(default_factory=dict) def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> None: - """Displays statistics about the dryrun either to debug logs or info logs""" - # If user specified "dryrun=True", print to info logs, else debug - if dryrun_enabled: - dryrun_logger = logger.info - else: - dryrun_logger = logger.debug - + """Displays statistics about the dryrun either to debug logs or stdout""" formatted_stats = "\n" # Print overall stats about the graph, operator counts, etc. - formatted_stats += "+" * 50 + " Dry-Run Results for Graph " + "+" * 50 + "\n" + formatted_stats += "+" * 50 + " Dry-Run Results for Graph " + "+" * 50 + "\n\n" formatted_stats += ( f"The graph consists of {dryrun_tracker.total_ops_in_graph} Total Operators, " f"of which {dryrun_tracker.supported_ops_in_graph} operators are supported, " - f"{round(dryrun_tracker.supported_ops_in_graph*100/dryrun_tracker.total_ops_in_graph, 2)}% coverage\n" - ) - formatted_stats += f"Long and double inputs were {'' if dryrun_tracker.truncated_long_and_double else 'not'} truncated (truncate_long_and_double={dryrun_tracker.truncated_long_and_double})\n" - formatted_stats += ( - f"{dryrun_tracker.tensorrt_graph_count} TRT Engine(s) were generated\n" + f"{round(dryrun_tracker.supported_ops_in_graph*100/dryrun_tracker.total_ops_in_graph, 2)}% coverage\n\n" ) + formatted_stats += f"The following ops are currently unsupported and set to run in Torch: {dryrun_tracker.unsupported_ops}\n\n" + formatted_stats += f"Compiled with: {dryrun_tracker.compilation_settings}\n\n" assert len(dryrun_tracker.per_subgraph_data) == dryrun_tracker.tensorrt_graph_count # Print schematic of the graph structure, as in: # - # Inputs: [Tensor: (1, 3, 224, 224)@float32] + # Inputs: List[Tensor: (1, 3, 224, 224)@float32] # ... - # TRT Engine #1: _run_on_acc_0 - # Engine Inputs: [Tensor: (1, 3, 224, 224)@float32] - # Number of Operators in Engine: 1 - # Engine Outputs: [Tensor: (1, 64, 112, 112)@float32] + # TRT Engine #1 - Submodule name: _run_on_acc_0 + # Engine Inputs: List[Tensor: (1, 3, 224, 224)@float32] + # Number of Operators in Engine: 1 + # Engine Outputs: Tensor: (1, 64, 112, 112)@float32 # ... - # Outputs: [Tensor: (1, 1000)@float32] + # Outputs: List[Tensor: (1, 1000)@float32] # formatted_stats += " " * 2 + "Graph Structure:\n\n" formatted_stats += ( " " * 3 - + f"Inputs: [{input_formatter(dryrun_tracker.graph_input_shapes, dryrun_tracker.graph_input_dtypes)}]\n" + + f"Inputs: {input_formatter(dryrun_tracker.graph_input_shapes, dryrun_tracker.graph_input_dtypes)}\n" ) for i, trt_subgraph_data in enumerate(dryrun_tracker.per_subgraph_data): - assert len(trt_subgraph_data.subgraph_input_dtypes) == len( - trt_subgraph_data.subgraph_input_shapes - ) - assert len(trt_subgraph_data.subgraph_output_dtypes) == len( - trt_subgraph_data.subgraph_output_shapes - ) formatted_stats += " " * 4 + "...\n" formatted_stats += ( - " " * 4 + f"TRT Engine #{i+1}: {trt_subgraph_data.subgraph_name}\n" + " " * 4 + + f"TRT Engine #{i+1} - Submodule name: {trt_subgraph_data.subgraph_name}\n" ) formatted_stats += ( " " * 5 - + f"Engine Inputs: [{input_formatter(trt_subgraph_data.subgraph_input_shapes, trt_subgraph_data.subgraph_input_dtypes)}]\n" + + f"Engine Inputs: {input_formatter(trt_subgraph_data.subgraph_input_shapes, trt_subgraph_data.subgraph_input_dtypes)}\n" ) formatted_stats += ( " " * 5 @@ -118,13 +109,13 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> ) formatted_stats += ( " " * 5 - + f"Engine Outputs: [{input_formatter(trt_subgraph_data.subgraph_output_shapes, trt_subgraph_data.subgraph_output_dtypes)}]\n" + + f"Engine Outputs: {input_formatter(trt_subgraph_data.subgraph_output_shapes, trt_subgraph_data.subgraph_output_dtypes)}\n" ) formatted_stats += " " * 4 + "...\n" formatted_stats += ( " " * 3 - + f"Outputs: [{input_formatter(dryrun_tracker.graph_output_shapes, dryrun_tracker.graph_output_dtypes)}]\n" + + f"Outputs: {input_formatter(dryrun_tracker.graph_output_shapes, dryrun_tracker.graph_output_dtypes)}\n" ) # Print aggregate statistics about the graph structure, including recommended "min_block_size" options @@ -167,7 +158,7 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> + " " * 3 + "- For minimal graph segmentation, select min_block_size=" + f"{most_ops_in_an_engine} which would generate " - + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= most_ops_in_an_engine])} TRT engines" + + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= most_ops_in_an_engine])} TRT engine(s)" ) if math.ceil(avg_ops_per_engine) != most_ops_in_an_engine: formatted_stats += ( @@ -175,7 +166,7 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> + " " * 3 + "- For moderate graph segmentation, select min_block_size=" + f"{math.ceil(avg_ops_per_engine)} which would generate " - + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= math.ceil(avg_ops_per_engine)])} TRT engines" + + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= math.ceil(avg_ops_per_engine)])} TRT engine(s)" ) formatted_stats += ( @@ -183,7 +174,7 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> + " " * 3 + "- The current level of graph segmentation is equivalent to selecting min_block_size=" + f"{min_ops_in_an_engine} which generates " - + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= min_ops_in_an_engine])} TRT engines" + + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= min_ops_in_an_engine])} TRT engine(s)" ) else: formatted_stats += ( @@ -192,14 +183,45 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> + "Aggregate stats not available since no TRT Engines were generated." ) - dryrun_logger(formatted_stats) + # If user specified "dryrun=True", print to stdout, else debug + if dryrun_enabled: + print(formatted_stats) + else: + logger.debug(formatted_stats) -def input_formatter(shapes: List[Tuple[int, ...]], dtypes: List[torch.dtype]) -> str: +def input_formatter(shapes: Any, dtypes: Any) -> str: """Format shapes and dtypes of input Tensors into a readable string""" - formatted_str = ", " - for shape, dtype in zip(shapes, dtypes): - formatted_str += f"Tensor: {shape}@{str(dtype)[6:]}, " + def input_formatter_helper(shapes: Any, dtypes: Any) -> str: + """Helper for input formatter""" + # Base case - single shape, single dtype + if isinstance(shapes, tuple) and all(isinstance(elt, int) for elt in shapes): + return f"Tensor: {shapes}@{str(dtypes)[6:]}, " + + # Shapes is a sequence + elif isinstance(shapes, (list, tuple)): + formatted_str = "List[" if isinstance(shapes, list) else "Tuple(" + for shape, dtype in zip(shapes, dtypes): + formatted_str += input_formatter_helper(shape, dtype) + formatted_str = formatted_str[:-2] + ( + "], " if isinstance(shapes, list) else "), " + ) + return formatted_str + + # Shapes is a dictionary + elif isinstance(shapes, dict): + formatted_str = "Dict{" + + for key, shape in shapes.items(): + formatted_str += input_formatter_helper(shape, dtypes[key]) + + formatted_str = formatted_str[:-2] + "}, " + return formatted_str + + else: + raise ValueError( + f"Invalid input type {type(shapes)} encountered in parse_complex_tensor_structs parsing." + ) - return formatted_str[2:-2] + return input_formatter_helper(shapes, dtypes)[:-2] diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 78942086a0..d382f26db6 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -50,6 +50,7 @@ from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions from torch_tensorrt.dynamo.utils import ( get_torch_inputs, + parse_complex_tensor_structs, prepare_inputs, set_log_level, to_torch_device, @@ -257,11 +258,13 @@ def compile_module( dryrun_tracker.total_ops_in_graph = total_ops dryrun_tracker.supported_ops_in_graph = num_supported_ops - dryrun_tracker.graph_input_shapes = [ - tuple(input_.shape) for input_ in sample_inputs - ] - dryrun_tracker.graph_input_dtypes = [input_.torch_dtype for input_ in sample_inputs] - dryrun_tracker.truncated_long_and_double = settings.truncate_long_and_double + dryrun_tracker.graph_input_shapes = parse_complex_tensor_structs( + sample_inputs, "shape", tuple + ) + dryrun_tracker.graph_input_dtypes = parse_complex_tensor_structs( + sample_inputs, "torch_dtype" + ) + dryrun_tracker.compilation_settings = settings if settings.dryrun and settings.min_block_size > 1: logger.info( @@ -290,7 +293,7 @@ def compile_module( # If specified, try using the fast partitioner and fall back to the global one on failure if settings.use_fast_partitioner: try: - partitioned_module = partitioning.fast_partition( + partitioned_module, supported_ops = partitioning.fast_partition( gm, verbose=settings.debug, min_block_size=settings.min_block_size, @@ -307,13 +310,15 @@ def compile_module( settings.use_fast_partitioner = False if not settings.use_fast_partitioner: - partitioned_module = partitioning.global_partition( + partitioned_module, supported_ops = partitioning.global_partition( gm, verbose=settings.debug, min_block_size=settings.min_block_size, torch_executed_ops=settings.torch_executed_ops, ) + dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators + # Store TRT replicas of Torch subgraphs trt_modules = {} # Iterate over all components that can be accelerated @@ -360,25 +365,23 @@ def compile_module( name, ) - subgraph_data.subgraph_input_dtypes = [ - submodule_input.torch_dtype for submodule_input in submodule_inputs - ] - subgraph_data.subgraph_input_shapes = [ - tuple(submodule_input.shape) for submodule_input in submodule_inputs - ] + subgraph_data.subgraph_input_shapes = parse_complex_tensor_structs( + submodule_inputs, "shape", tuple + ) + subgraph_data.subgraph_input_dtypes = parse_complex_tensor_structs( + submodule_inputs, "torch_dtype" + ) submodule_outputs = submodule( *get_torch_inputs(submodule_inputs, to_torch_device(settings.device)) ) - if not isinstance(submodule_outputs, (list, tuple)): - submodule_outputs = [submodule_outputs] - subgraph_data.subgraph_output_dtypes = [ - submodule_output.dtype for submodule_output in submodule_outputs - ] - subgraph_data.subgraph_output_shapes = [ - tuple(submodule_output.shape) for submodule_output in submodule_outputs - ] + subgraph_data.subgraph_output_shapes = parse_complex_tensor_structs( + submodule_outputs, "shape", tuple + ) + subgraph_data.subgraph_output_dtypes = parse_complex_tensor_structs( + submodule_outputs, "dtype" + ) dryrun_tracker.tensorrt_graph_count += 1 dryrun_tracker.per_subgraph_data.append(subgraph_data) @@ -401,10 +404,12 @@ def compile_module( if not isinstance(sample_outputs, (list, tuple)): sample_outputs = [sample_outputs] - dryrun_tracker.graph_output_shapes = [ - tuple(output_.shape) for output_ in sample_outputs - ] - dryrun_tracker.graph_output_dtypes = [output_.dtype for output_ in sample_outputs] + dryrun_tracker.graph_output_shapes = parse_complex_tensor_structs( + sample_outputs, "shape", tuple + ) + dryrun_tracker.graph_output_dtypes = parse_complex_tensor_structs( + sample_outputs, "dtype" + ) # Replace all FX Modules with TRT Modules for name, trt_module in trt_modules.items(): diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py index 5bdbb8919b..5ec5293474 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py @@ -248,7 +248,7 @@ def partition( min_block_size: int = MIN_BLOCK_SIZE, torch_executed_ops: Collection[Target] = set(), require_full_compilation: bool = REQUIRE_FULL_COMPILATION, -) -> torch.fx.GraphModule: +) -> Tuple[torch.fx.GraphModule, OpSupportTester]: """Partition an FX GraphModule with aten ops into TRT engines Partitioning is based on converter operator support @@ -259,7 +259,7 @@ def partition( torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage require_full_compilation: Require that all computational operators be run in TRT Returns: - torch.fx.GraphModule + torch.fx.GraphModule, OpSupportTester """ # Ensure graph is clean prior to partitioning gm.graph.eliminate_dead_code() @@ -280,4 +280,4 @@ def partition( if verbose: supported_ops.print_support_overview(partitioner.num_trt_accelerated_subgraphs) - return partitioned_graph + return partitioned_graph, supported_ops diff --git a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py index 092bdabfd0..4c8efb234e 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py @@ -1,5 +1,5 @@ import logging -from typing import Collection, Dict, List, Mapping, Optional, Sequence, Set +from typing import Collection, Dict, List, Mapping, Optional, Sequence, Set, Tuple import torch from torch.fx.graph_module import GraphModule @@ -203,7 +203,7 @@ def partition( min_block_size: int = MIN_BLOCK_SIZE, torch_executed_ops: Optional[Set[str]] = None, require_full_compilation: bool = REQUIRE_FULL_COMPILATION, -) -> torch.fx.GraphModule: +) -> Tuple[torch.fx.GraphModule, TorchTensorRTOperatorSupport]: """Partition an FX GraphModule with aten ops into TRT engines Partitioning is based on converter operator support @@ -214,7 +214,7 @@ def partition( torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage require_full_compilation: Whether to require that all operators be run in TRT Returns: - torch.fx.GraphModule + torch.fx.GraphModule, TorchTensorRTOperatorSupport """ supported_ops = TorchTensorRTOperatorSupport( torch_executed_ops=torch_executed_ops @@ -236,4 +236,4 @@ def partition( if verbose: supported_ops.print_support_overview(len(partitions)) - return fused_graph + return fused_graph, supported_ops diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 26de1fcb27..22590fe73d 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -5,12 +5,12 @@ from typing import Any, Callable, Dict, Optional, Sequence, Union import torch -import torch_tensorrt from torch_tensorrt._Device import Device from torch_tensorrt._Input import Input from torch_tensorrt.dynamo._defaults import PRECISION from torch_tensorrt.dynamo._settings import CompilationSettings +import torch_tensorrt from packaging import version logger = logging.getLogger(__name__) @@ -114,7 +114,7 @@ def prepare_inputs( inputs, disable_memory_format_check=disable_memory_format_check ) - elif isinstance(inputs, list): + elif isinstance(inputs, (list, tuple)): torchtrt_input_list = [] for input_obj in inputs: torchtrt_input = prepare_inputs( @@ -122,24 +122,62 @@ def prepare_inputs( ) torchtrt_input_list.append(torchtrt_input) - return torchtrt_input_list + return ( + torchtrt_input_list + if isinstance(inputs, list) + else tuple(torchtrt_input_list) + ) - elif isinstance(inputs, tuple): - torchtrt_inputs_tup = [] - for input_obj in inputs: + elif isinstance(inputs, dict): + torchtrt_inputs_dict: Dict[Any, Any] = dict() + + for key, input_obj in inputs.items(): torchtrt_input = prepare_inputs( input_obj, disable_memory_format_check=disable_memory_format_check ) - torchtrt_inputs_tup.append(torchtrt_input) + torchtrt_inputs_dict[key] = torchtrt_input + + return torchtrt_inputs_dict - return tuple(torchtrt_inputs_tup) + else: + raise ValueError( + f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. " + + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}" + ) + + +def parse_complex_tensor_structs( + inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any], + attribute_to_extract: str, + apply_fn: Callable[[Any], Any] = lambda x: x, +) -> Any: + """Parses complex structures of Tensors and returns a mirrored structure + Extracts key attributes of each singular element, while reconstructing the struct + Optionally applies a function to each attribute before returning + """ + if isinstance(inputs, (torch.Tensor, Input)): + return apply_fn(getattr(inputs, attribute_to_extract, None)) + + elif isinstance(inputs, (list, tuple)): + torchtrt_input_list = [] + for input_obj in inputs: + torchtrt_input = parse_complex_tensor_structs( + input_obj, attribute_to_extract, apply_fn + ) + torchtrt_input_list.append(torchtrt_input) + + return ( + torchtrt_input_list + if isinstance(inputs, list) + else tuple(torchtrt_input_list) + ) elif isinstance(inputs, dict): torchtrt_inputs_dict: Dict[Any, Any] = dict() for key, input_obj in inputs.items(): - torchtrt_input = prepare_inputs( - input_obj, disable_memory_format_check=disable_memory_format_check + torchtrt_input = parse_complex_tensor_structs( + input_obj, attribute_to_extract, apply_fn ) torchtrt_inputs_dict[key] = torchtrt_input @@ -147,7 +185,7 @@ def prepare_inputs( else: raise ValueError( - f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. " + f"Invalid input type {type(inputs)} encountered in parse_complex_tensor_structs parsing. " + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}" ) diff --git a/tests/py/dynamo/backend/test_backend_compiler.py b/tests/py/dynamo/backend/test_backend_compiler.py index 0038412c30..a958d03120 100644 --- a/tests/py/dynamo/backend/test_backend_compiler.py +++ b/tests/py/dynamo/backend/test_backend_compiler.py @@ -1,10 +1,11 @@ from copy import deepcopy import torch -import torch_tensorrt from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt.dynamo.partitioning import fast_partition +import torch_tensorrt + from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing @@ -19,7 +20,7 @@ def forward(self, x, y): return torch.mean(out, dim=1) fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp()) - partitioned_graph = fast_partition(deepcopy(fx_graph), min_block_size=3) + partitioned_graph, _ = fast_partition(deepcopy(fx_graph), min_block_size=3) self.assertEquals( len( @@ -198,7 +199,7 @@ def forward(self, x, y): ) fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp()) - partitioned_graph = fast_partition(deepcopy(fx_graph), min_block_size=3) + partitioned_graph, _ = fast_partition(deepcopy(fx_graph), min_block_size=3) self.assertEquals( len(list(partitioned_graph.named_children())), diff --git a/tests/py/dynamo/partitioning/test_fast_partitioning.py b/tests/py/dynamo/partitioning/test_fast_partitioning.py index a271df1e72..2ff5433b22 100644 --- a/tests/py/dynamo/partitioning/test_fast_partitioning.py +++ b/tests/py/dynamo/partitioning/test_fast_partitioning.py @@ -18,7 +18,7 @@ def forward(self, x, y): return torch.ops.aten.add.Tensor(x, y) fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp()) - partitioned_graph = partitioning.fast_partition(deepcopy(fx_graph)) + partitioned_graph, _ = partitioning.fast_partition(deepcopy(fx_graph)) self.assertEquals( len( [ @@ -40,7 +40,7 @@ def forward(self, x, y): return torch.ops.aten.add.Tensor(x, y) fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp()) - partitioned_graph = partitioning.fast_partition( + partitioned_graph, _ = partitioning.fast_partition( deepcopy(fx_graph), require_full_compilation=True ) self.assertEquals( @@ -68,7 +68,7 @@ def forward(self, x, y): return pow_ fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp()) - partitioned_graph = partitioning.fast_partition( + partitioned_graph, _ = partitioning.fast_partition( deepcopy(fx_graph), min_block_size=2 ) self.assertEquals( @@ -97,7 +97,7 @@ def forward(self, x, y): return pow_ fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp()) - partitioned_graph = partitioning.fast_partition( + partitioned_graph, _ = partitioning.fast_partition( deepcopy(fx_graph), min_block_size=2 ) self.assertEquals( diff --git a/tests/py/dynamo/partitioning/test_global_partitioning.py b/tests/py/dynamo/partitioning/test_global_partitioning.py index 6701b33ccf..51d93a40f2 100644 --- a/tests/py/dynamo/partitioning/test_global_partitioning.py +++ b/tests/py/dynamo/partitioning/test_global_partitioning.py @@ -18,7 +18,7 @@ def forward(self, x, y): return torch.ops.aten.add.Tensor(x, y) fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp()) - partitioned_graph = partitioning.global_partition(deepcopy(fx_graph)) + partitioned_graph, _ = partitioning.global_partition(deepcopy(fx_graph)) self.assertEquals( len(list(partitioned_graph.named_children())), 0, @@ -34,7 +34,7 @@ def forward(self, x, y): return torch.ops.aten.add.Tensor(x, y) fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp()) - partitioned_graph = partitioning.global_partition( + partitioned_graph, _ = partitioning.global_partition( deepcopy(fx_graph), require_full_compilation=True ) self.assertEquals( @@ -56,7 +56,7 @@ def forward(self, x, y): return pow_ fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp()) - partitioned_graph = partitioning.global_partition( + partitioned_graph, _ = partitioning.global_partition( deepcopy(fx_graph), min_block_size=2 ) self.assertEquals( @@ -79,7 +79,7 @@ def forward(self, x, y): return pow_ fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp()) - partitioned_graph = partitioning.global_partition( + partitioned_graph, _ = partitioning.global_partition( deepcopy(fx_graph), min_block_size=2 ) self.assertEquals( diff --git a/tests/py/dynamo/testing_utilities.py b/tests/py/dynamo/testing_utilities.py index 9ec0fcf58e..742b9fc1a3 100644 --- a/tests/py/dynamo/testing_utilities.py +++ b/tests/py/dynamo/testing_utilities.py @@ -70,13 +70,13 @@ def compile_module_testing( ) -> torch.fx.GraphModule: """Helper compiler exclusively for testing""" if use_fast_partitioner: - partitioned_module = partitioning.fast_partition( + partitioned_module, _ = partitioning.fast_partition( gm, min_block_size=min_block_size, torch_executed_ops=torch_executed_ops, ) else: - partitioned_module = partitioning.global_partition( + partitioned_module, _ = partitioning.global_partition( gm, min_block_size=min_block_size, torch_executed_ops=torch_executed_ops,