Commit

feat: Improve Dynamo partitioning System Performance on Large Models (#…
gs-olive authored Aug 15, 2023
1 parent 32d905b commit b57d83e
Showing 15 changed files with 634 additions and 146 deletions.
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -10,3 +10,4 @@
OPTIMIZATION_LEVEL = None
TRUNCATE_LONG_AND_DOUBLE = False
USE_PYTHON_RUNTIME = False
USE_FAST_PARTITIONER = True
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -10,6 +10,7 @@
PASS_THROUGH_BUILD_FAILURES,
PRECISION,
TRUNCATE_LONG_AND_DOUBLE,
USE_FAST_PARTITIONER,
USE_PYTHON_RUNTIME,
VERSION_COMPATIBLE,
WORKSPACE_SIZE,
@@ -29,3 +30,4 @@ class CompilationSettings:
optimization_level: Optional[int] = OPTIMIZATION_LEVEL
use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
use_fast_partitioner: bool = USE_FAST_PARTITIONER
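With the new dataclass field in place, the fast partitioner is toggled through the settings object. A minimal sketch, not part of the diff, assuming CompilationSettings is importable from torch_tensorrt.dynamo (as it is in backends.py below) and that the remaining fields also carry defaults:

from torch_tensorrt.dynamo import CompilationSettings

# Defaults pick up USE_FAST_PARTITIONER = True from _defaults.py
settings = CompilationSettings()
assert settings.use_fast_partitioner

# Opt back into the global, capability-based partitioner explicitly
settings = CompilationSettings(use_fast_partitioner=False, truncate_long_and_double=True)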
65 changes: 56 additions & 9 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -7,13 +7,12 @@
import torch
import torch._dynamo as td
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
from torch_tensorrt.dynamo import CompilationSettings
from torch_tensorrt.dynamo import CompilationSettings, partitioning
from torch_tensorrt.dynamo.conversion import (
convert_module,
repair_long_or_double_inputs,
)
from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
from torch_tensorrt.dynamo.lowering._partition import get_submod_inputs, partition
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs

@@ -111,24 +110,68 @@ def _compile_module(
Returns:
Compiled FX GraphModule
"""
# Partition module into components that can be TRT-accelerated
partitioned_module = partition(
gm,
verbose=settings.debug,
min_block_size=settings.min_block_size,
torch_executed_ops=settings.torch_executed_ops,
# Check the number of supported operations in the graph
num_supported_ops, total_ops = partitioning.get_graph_converter_support(
gm, settings.debug, settings.torch_executed_ops
)

# If the number of supported operations is 0 or less than the block size, skip the subgraph
# TODO: Add condition to second expression below when require_full_compilation is added
if num_supported_ops == 0 or (num_supported_ops < settings.min_block_size):
logger.warning(
f"{num_supported_ops} supported operations detected in subgraph containing {total_ops} computational nodes. "
f"Skipping this subgraph, since min_block_size was detected to be {settings.min_block_size}"
)
return gm
else:
logger.debug(
f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph."
)

# Partition module into components that can be TRT-accelerated
fast_partitioner_failed = False

# If specified, try using the fast partitioner and fall back to the global one on failure
if settings.use_fast_partitioner:
try:
partitioned_module = partitioning.fast_partition(
gm,
verbose=settings.debug,
min_block_size=settings.min_block_size,
torch_executed_ops=settings.torch_executed_ops,
)
except torch.fx.passes.splitter_base.FxNetSplitterInternalError:
logger.error(
"Partitioning failed on the subgraph with fast partition. See trace above. "
+ "Retrying with global partition.",
exc_info=True,
)

fast_partitioner_failed = True
settings.use_fast_partitioner = False

if not settings.use_fast_partitioner:
partitioned_module = partitioning.global_partition(
gm,
verbose=settings.debug,
min_block_size=settings.min_block_size,
torch_executed_ops=settings.torch_executed_ops,
)

# Store TRT replicas of Torch subgraphs
trt_modules = {}

# Iterate over all components that can be accelerated
# Generate the corresponding TRT Module for those
for name, _ in partitioned_module.named_children():
# Criteria for a module to be convertible to TRT
if settings.use_fast_partitioner and "_run_on_acc" not in name:
continue

submodule = getattr(partitioned_module, name)

# Get submodule inputs
submodule_inputs = get_submod_inputs(
submodule_inputs = partitioning.get_submod_inputs(
partitioned_module, submodule, sample_inputs
)

@@ -153,4 +196,8 @@ def _compile_module(
for name, trt_mod in trt_modules.items():
setattr(partitioned_module, name, trt_mod)

# Reset settings object to user specification after fallback to global partitioning mode
if fast_partitioner_failed:
settings.use_fast_partitioner = True

return partitioned_module
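The control flow added above (count supported ops, skip small subgraphs, try the fast partitioner, fall back to the global one on FxNetSplitterInternalError) can be condensed into a standalone helper for illustration. This is a sketch only, using just the partitioning calls shown in this file; the helper name partition_with_fallback and the default min_block_size of 5 are placeholders, not part of the commit:

from typing import Collection, Optional

import torch
from torch_tensorrt.dynamo import partitioning


def partition_with_fallback(
    gm: torch.fx.GraphModule,
    debug: bool = False,
    min_block_size: int = 5,
    torch_executed_ops: Optional[Collection[str]] = None,
) -> torch.fx.GraphModule:
    # Hypothetical helper mirroring _compile_module above
    torch_executed_ops = torch_executed_ops or set()

    # Skip partitioning entirely when too few operators are supported
    num_supported_ops, total_ops = partitioning.get_graph_converter_support(
        gm, debug, torch_executed_ops
    )
    if num_supported_ops == 0 or num_supported_ops < min_block_size:
        return gm

    try:
        # Adjacency-based splitter (the new default path)
        return partitioning.fast_partition(
            gm,
            verbose=debug,
            min_block_size=min_block_size,
            torch_executed_ops=torch_executed_ops,
        )
    except torch.fx.passes.splitter_base.FxNetSplitterInternalError:
        # Fallback: capability-based global partitioner
        return partitioning.global_partition(
            gm,
            verbose=debug,
            min_block_size=min_block_size,
            torch_executed_ops=torch_executed_ops,
        )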
54 changes: 5 additions & 49 deletions py/torch_tensorrt/dynamo/compile.py
@@ -7,7 +7,6 @@
import torch
import torch_tensorrt
from torch.fx.passes.pass_manager import PassManager
from torch.fx.passes.splitter_base import SplitResult
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import ( # TODO: Should probably be the TRT EngineCapability Enum
EngineCapability,
@@ -21,18 +20,17 @@
PASS_THROUGH_BUILD_FAILURES,
PRECISION,
TRUNCATE_LONG_AND_DOUBLE,
USE_FAST_PARTITIONER,
USE_PYTHON_RUNTIME,
VERSION_COMPATIBLE,
WORKSPACE_SIZE,
)
from torch_tensorrt.dynamo.backend.backends import _compile_module
from torch_tensorrt.dynamo.conversion import convert_module
from torch_tensorrt.dynamo.lowering._fusers import (
fuse_permute_linear,
fuse_permute_matmul,
)
from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting

logger = logging.getLogger(__name__)

@@ -64,6 +62,7 @@ def compile(
version_compatible: bool = VERSION_COMPATIBLE,
optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
use_python_runtime: bool = USE_PYTHON_RUNTIME,
use_fast_partitioner: bool = USE_FAST_PARTITIONER,
**kwargs: Any,
) -> torch.fx.GraphModule:
if debug:
@@ -75,7 +74,7 @@
"The Dynamo backend is an experimental feature, for which only the "
+ "following arguments are supported: "
+ "{enabled_precisions, debug, workspace_size, min_block_size, "
+ "torch_executed_ops, pass_through_build_failures}"
+ "torch_executed_ops, pass_through_build_failures, use_fast_partitioner}"
)

if not isinstance(inputs, collections.abc.Sequence):
@@ -115,55 +114,12 @@ def compile(
"optimization_level": optimization_level,
"use_python_runtime": use_python_runtime,
"truncate_long_and_double": truncate_long_and_double,
"use_fast_partitioner": use_fast_partitioner,
}

settings = CompilationSettings(**compilation_options)
if kwargs.get("use_capability_partitioner", None):
model = lower_model(gm, torch_inputs)
return _compile_module(model, torch_inputs, settings)
else:
split_result = lower_model_using_trt_splitter(gm, torch_inputs)
trt_module = _compile_graph(split_result, torch_inputs, settings)

return trt_module


def _compile_graph(
split_result: SplitResult,
inputs: Any,
settings: CompilationSettings = CompilationSettings(),
**kwargs: Any,
) -> torch.fx.GraphModule:
for submod_name, submod_inputs in split_result.submodule_inputs.items():
submod = getattr(split_result.split_module, submod_name)
# Only acc submodules will be lowered.
if not submod_name.startswith(split_result.non_acc_submodule_prefix):
# Create TRT Module from submodule
trt_mod = convert_module(
submod,
submod_inputs,
settings=settings,
name=submod_name,
)
setattr(split_result.split_module, submod_name, trt_mod)

return split_result.split_module


def lower_model_using_trt_splitter(
model: torch.nn.Module, inputs: Any, **kwargs: Any
) -> SplitResult:
# Perform basic lowering
model = lower_model(model, inputs)
splitter_setting = TRTSplitterSetting()
splitter_setting.use_implicit_batch_dim = False
splitter_setting.min_acc_module_size = 1
splitter_setting.use_experimental_rt = False
splitter = TRTSplitter(model, inputs, settings=splitter_setting)
splitter.node_support_preview()
split_result = splitter.generate_split_results()

return split_result
return _compile_module(gm, torch_inputs, settings)


def lower_model(
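From the user side, the new flag surfaces as a keyword argument on the Dynamo compile entry point. A usage sketch under stated assumptions: the function is reachable as torch_tensorrt.dynamo.compile, the first two arguments are the module and its sample inputs, and ToyModel is a hypothetical stand-in; none of this appears verbatim in the hunk above:

import torch
import torch_tensorrt


class ToyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1


model = ToyModel().eval().cuda()
inputs = [torch.randn(1, 3, 224, 224, device="cuda")]

# Fast (adjacency-based) partitioner is now the default
trt_gm = torch_tensorrt.dynamo.compile(model, inputs, min_block_size=5)

# Revert to the global capability-based partitioner for a given build
trt_gm = torch_tensorrt.dynamo.compile(model, inputs, use_fast_partitioner=False)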
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/converter_registry.py
@@ -349,8 +349,8 @@ def unique_targets(self) -> Set[Target]:
"""Returns the set of unique converter targets stored across all registries"""
return set.union(*[set(registry.keys()) for registry in self.registries])

# TODO: Make this a static method since it does not need state
def qualified_name_or_str(self, target: Target) -> str:
@staticmethod
def qualified_name_or_str(target: Target) -> str:
"""Returns string representation of an FX Node target"""
if isinstance(target, str):
return target
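Because qualified_name_or_str is now a @staticmethod, it can be called on the class without constructing a registry. A small sketch; the enclosing class name ConverterRegistry is an assumption, since it is not visible in this hunk:

import torch
from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry

# String targets are returned unchanged
print(ConverterRegistry.qualified_name_or_str("my_custom_target"))

# Callable FX node targets get a readable string representation
print(ConverterRegistry.qualified_name_or_str(torch.add))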
9 changes: 2 additions & 7 deletions py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -1,10 +1,5 @@
from ._decompositions import get_decompositions # noqa: F401
from ._fusers import * # noqa: F403
from ._partition import ( # noqa: F401
DEFAULT_SINGLE_NODE_PARTITIONS,
get_submod_inputs,
partition,
)
from ._fusers import * # noqa: F401
from ._pre_aot_lowering import SUBSTITUTION_REGISTRY # noqa: F401
from ._pre_aot_lowering import register_substitution # noqa: F401
from .substitutions import * # noqa: F403
from .substitutions import * # noqa: F401
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/partitioning/__init__.py
@@ -0,0 +1,3 @@
from ._adjacency_partitioner import partition as fast_partition
from ._global_partitioner import partition as global_partition
from .common import get_graph_converter_support, get_submod_inputs
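This new partitioning package consolidates what previously lived in lowering._partition (removed from lowering/__init__.py above). A quick orientation sketch of the relocated names; the mapping of the old partition function onto global_partition is inferred from the backends.py changes rather than stated in the diff:

# Old location (removed in this commit):
#   from torch_tensorrt.dynamo.lowering._partition import partition, get_submod_inputs

# New location:
from torch_tensorrt.dynamo import partitioning

# partitioning.fast_partition    -- adjacency-based splitter, the new default
# partitioning.global_partition  -- capability-based partitioner, used as the fallback
# partitioning.get_submod_inputs / partitioning.get_graph_converter_support
#                                -- shared helpers exported from partitioning/common.py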