Skip to content

Commit

Permalink
feat: Add ExportedProgram as an IR (#2191)
Browse files Browse the repository at this point in the history
Signed-off-by: Dheeraj Peri <[email protected]>
  • Loading branch information
peri044 authored Aug 16, 2023
1 parent 08a2ee4 commit 6814350
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch
import torch.fx
import torch_tensorrt.ts
from torch._export import ExportedProgram
from torch_tensorrt._enums import dtype
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo.compile import compile as dynamo_compile
Expand Down Expand Up @@ -43,6 +44,7 @@ class _IRType(Enum):
fx = 1
dynamo = 2
torch_compile = 3
exported_program = 4


class _ModuleType(Enum):
Expand All @@ -51,6 +53,7 @@ class _ModuleType(Enum):
nn = 0
ts = 1
fx = 2
ep = 3


def _parse_module_type(module: Any) -> _ModuleType:
Expand All @@ -61,6 +64,8 @@ def _parse_module_type(module: Any) -> _ModuleType:
return _ModuleType.ts
elif isinstance(module, torch.fx.GraphModule):
return _ModuleType.fx
elif isinstance(module, ExportedProgram):
return _ModuleType.ep
elif isinstance(module, torch.nn.Module):
return _ModuleType.nn
else:
Expand All @@ -70,6 +75,7 @@ def _parse_module_type(module: Any) -> _ModuleType:
def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
module_is_tsable = any(module_type == t for t in [_ModuleType.nn, _ModuleType.ts])
module_is_fxable = any(module_type == t for t in [_ModuleType.nn, _ModuleType.fx])
module_is_exportable = module_type == _ModuleType.ep

ir_targets_torchscript = any(ir == opt for opt in ["torchscript", "ts"])
ir_targets_fx = ir == "fx"
Expand All @@ -95,8 +101,16 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
"Input graph is a Torchscript module but the ir provided is default (dynamo). Please set ir=torchscript to suppress the warning. Compiling the module with ir=torchscript"
)
return _IRType.ts
elif module_is_exportable:
raise ValueError(
"Input graph is an ExportedProgram which is not currently supported. Please provide torch.nn.Module or torch.fx.GraphModule as input."
)
else:
raise ValueError("Module was provided in an unsupported format")
elif ir == "exported_program":
raise ValueError(
"ir=exported_program is not currently supported. Supported ir options : ts|fx|dynamo"
)
else:
raise ValueError("Unknown ir was requested")

Expand Down

0 comments on commit 6814350

Please sign in to comment.