Skip to content

Commit

Permalink
fix/feat: Add Dynamo-only converter registry
Browse files Browse the repository at this point in the history
- Add Dynamo converter registry which functions as a superset of the
standard FX converter registry
- For use with new + experimental converters
- Uses custom decorator `dynamo_tensorrt_converter`
- Update references within Dynamo functions to use the converter
registry `DYNAMO_CONVERTERS`
  • Loading branch information
gs-olive committed Jul 7, 2023
1 parent 44e4ffa commit 4239a7b
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 2 deletions.
5 changes: 5 additions & 0 deletions py/torch_tensorrt/dynamo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
from .converter_registry import (
DYNAMO_CONVERTERS,
dynamo_tensorrt_converter,
)

from torch_tensorrt.dynamo import fx_ts_compat
from .backend import compile
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/backend/lowering/_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torch.fx.node import _get_qualified_name
from torch.fx.passes.operator_support import OperatorSupport

from torch_tensorrt.fx.converter_registry import CONVERTERS
from torch_tensorrt.dynamo import DYNAMO_CONVERTERS as CONVERTERS


logger = logging.getLogger(__name__)
Expand Down
23 changes: 23 additions & 0 deletions py/torch_tensorrt/dynamo/converter_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Any, Callable, Dict

from torch.fx.node import Target
from torch_tensorrt.fx.converter_registry import CONVERTERS

DYNAMO_CONVERTERS: Dict[Target, Any] = dict(CONVERTERS)


def dynamo_tensorrt_converter(
key: Target,
enabled: bool = True,
) -> Callable[[Any], Any]:
def register_converter(converter):
DYNAMO_CONVERTERS[key] = converter
return converter

def disable_converter(converter):
return converter

if enabled:
return register_converter
else:
return disable_converter
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torch.fx.node import _get_qualified_name
from torch.fx.passes.shape_prop import TensorMetadata

from torch_tensorrt.dynamo.fx_ts_compat import CONVERTERS
from torch_tensorrt.dynamo import DYNAMO_CONVERTERS as CONVERTERS
from .input_tensor_spec import InputTensorSpec
from torch_tensorrt.fx.observer import Observer
from torch_tensorrt.fx.utils import (
Expand Down

0 comments on commit 4239a7b

Please sign in to comment.