Merge pull request #2189 from pytorch/py38_compatibility
Py38 compatibility
narendasan authored Aug 10, 2023
2 parents b3089bf + f53a823 commit 81d6bcc
Showing 15 changed files with 47 additions and 15 deletions.
12 changes: 7 additions & 5 deletions py/torch_tensorrt/_Input.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from enum import Enum
 from typing import Any, Dict, List, Optional, Sequence, Tuple
@@ -32,11 +34,11 @@ class _ShapeMode(Enum):
     shape: Optional[
         Tuple[int, ...] | Dict[str, Tuple[int, ...]]
     ] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
-    dtype: _enums.dtype = (  # type: ignore[name-defined]
+    dtype: _enums.dtype = (
         _enums.dtype.unknown
     )  #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
     _explicit_set_dtype: bool = False
-    format: _enums.TensorFormat = (  # type: ignore[name-defined]
+    format: _enums.TensorFormat = (
         _enums.TensorFormat.contiguous
     )  #: The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
@@ -208,7 +210,7 @@ def _supported_input_size_type(input_size: Any) -> bool:
         return False

     @staticmethod
-    def _parse_dtype(dtype: Any) -> _enums.dtype:  # type: ignore[name-defined]
+    def _parse_dtype(dtype: Any) -> _enums.dtype:
         if isinstance(dtype, torch.dtype):
             if dtype == torch.long:
                 return _enums.dtype.long
@@ -236,7 +238,7 @@ def _parse_dtype(dtype: Any) -> _enums.dtype: # type: ignore[name-defined]
         )

     @staticmethod
-    def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:  # type: ignore[name-defined]
+    def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:
         if dtype == _enums.dtype.long:
             return torch.long
         elif dtype == _enums.dtype.int32:
@@ -255,7 +257,7 @@ def is_trt_dtype(self) -> bool:
         return bool(self.dtype != _enums.dtype.long)

     @staticmethod
-    def _parse_format(format: Any) -> _enums.TensorFormat:  # type: ignore[name-defined]
+    def _parse_format(format: Any) -> _enums.TensorFormat:
         if isinstance(format, torch.memory_format):
             if format == torch.contiguous_format:
                 return _enums.TensorFormat.contiguous
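Worth noting for readers skimming the diff: the two-line addition repeated across these files works because `from __future__ import annotations` (PEP 563) stores every annotation as a string at runtime. PEP 604 unions such as `Tuple[int, ...] | Dict[str, Tuple[int, ...]]` therefore import cleanly on Python 3.8, where `|` is not yet defined between types. A minimal standalone sketch of the mechanism (illustrative names, not this module's code):

from __future__ import annotations

from typing import Dict, Optional, Tuple


class Example:
    # Under PEP 563 this annotation is kept as the literal string
    # "Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]]";
    # the `|` expression is never evaluated, so Python 3.8 accepts it.
    shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = None


print(Example.__annotations__["shape"])  # the annotation survives as plain text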
5 changes: 4 additions & 1 deletion py/torch_tensorrt/_compile.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 from enum import Enum
-from typing import Any, Callable, List, Optional, Sequence, Set, TypeGuard
+from typing import Any, Callable, List, Optional, Sequence, Set

 import torch
 import torch.fx
@@ -12,6 +14,7 @@
 from torch_tensorrt.fx.lower import compile as fx_compile
 from torch_tensorrt.fx.utils import LowerPrecision
 from torch_tensorrt.ts._compiler import compile as torchscript_compile
+from typing_extensions import TypeGuard


 def _non_fx_input_interface(
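`TypeGuard` only joined the standard `typing` module in Python 3.10, so on 3.8 it has to come from the `typing_extensions` backport, as the new import above does. A small self-contained sketch of the narrowing pattern it supports (example function, not this module's API):

from __future__ import annotations

from typing import Any, List

from typing_extensions import TypeGuard


def is_str_list(values: List[Any]) -> TypeGuard[List[str]]:
    # Returning True lets the type checker narrow `values` to List[str].
    return all(isinstance(v, str) for v in values)


items: List[Any] = ["a", "b"]
if is_str_list(items):
    print(",".join(items))  # the checker treats items as List[str] here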
6 changes: 4 additions & 2 deletions py/torch_tensorrt/dynamo/aten_tracer.py
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
 import copy
 import sys
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

 import torch
 import torch._dynamo as torchdynamo
@@ -22,7 +24,7 @@
 )
 from typing_extensions import TypeAlias

-Value: TypeAlias = Tuple["Value", ...] | List["Value"] | Dict[str, "Value"]
+Value: TypeAlias = Union[Tuple["Value", ...], List["Value"], Dict[str, "Value"]]


 class DynamoConfig:
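The `Value` alias is the one spot the future import cannot fix: PEP 563 defers only annotations, while the right-hand side of an assignment like `Value: TypeAlias = ...` is an ordinary expression evaluated at import time, and `|` between `typing` generics raises TypeError before Python 3.10. Hence the switch to `Union[...]`. A standalone sketch of the distinction (assumed example mirroring the alias above):

from __future__ import annotations

from typing import Dict, List, Tuple, Union

from typing_extensions import TypeAlias

# Evaluated eagerly at import time, so it must use Union[...] on Python 3.8:
Value: TypeAlias = Union[Tuple["Value", ...], List["Value"], Dict[str, "Value"]]


def first(v: Tuple[Value, ...] | List[Value]) -> Value:
    # The `|` here lives inside an annotation, which PEP 563 keeps as a
    # string, so this definition is fine on 3.8.
    return v[0]


# Uncommenting the next line would raise TypeError on 3.8, because the
# right-hand side is evaluated immediately:
# Broken: TypeAlias = Tuple[int, ...] | List[int]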
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 from functools import partial
 from typing import Any, Callable, Sequence
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/compile.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import collections.abc
 import logging
 from typing import Any, List, Optional, Set, Tuple
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/conversion.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import io
 from typing import Sequence
4 changes: 3 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/converter_registry.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 from dataclasses import dataclass, field
 from enum import Enum, auto
@@ -28,7 +30,7 @@
         Dict[str, Argument],
         str,
     ],
-    TRTTensor | Sequence[TRTTensor],
+    Union[TRTTensor, Sequence[TRTTensor]],
 ]
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/shape.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import List, Optional, Tuple

 import numpy as np
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Optional, Sequence, Set

 import torch
5 changes: 4 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py
@@ -1,10 +1,13 @@
+from __future__ import annotations
+
 import logging
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, Optional, Type, TypeAlias
+from typing import Any, Callable, Dict, Optional, Type

 import torch
 from torch._ops import OpOverload
 from torch.fx import GraphModule, Node
+from typing_extensions import TypeAlias

 logger = logging.getLogger(__name__)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Any, Dict, List, Optional, Sequence, Tuple

 import torch
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 from typing import Any, List, Optional, Tuple
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/utils.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 from dataclasses import fields, replace
 from typing import Any, Callable, Dict, Optional, Sequence
12 changes: 7 additions & 5 deletions py/torch_tensorrt/ts/_compile_spec.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from copy import deepcopy
 from typing import Any, Dict, List, Optional, Set
@@ -39,7 +41,7 @@ def _supported_input_size_type(input_size: Any) -> bool:
     )


-def _parse_op_precision(precision: Any) -> _enums.dtype:  # type: ignore[name-defined]
+def _parse_op_precision(precision: Any) -> _enums.dtype:
     if isinstance(precision, torch.dtype):
         if precision == torch.int8:
             return _enums.dtype.int8
@@ -63,7 +65,7 @@ def _parse_op_precision(precision: Any) -> _enums.dtype: # type: ignore[name-defined]
     )


-def _parse_enabled_precisions(precisions: Any) -> Set[_enums.dtype]:  # type: ignore[name-defined]
+def _parse_enabled_precisions(precisions: Any) -> Set[_enums.dtype]:
     parsed_precisions = set()
     if any(isinstance(precisions, type) for type in [list, tuple, set]):
         for p in precisions:
@@ -73,7 +75,7 @@ def _parse_enabled_precisions(precisions: Any) -> Set[_enums.dtype]: # type: ignore[name-defined]
     return parsed_precisions


-def _parse_device_type(device: Any) -> _enums.DeviceType:  # type: ignore[name-defined]
+def _parse_device_type(device: Any) -> _enums.DeviceType:
     if isinstance(device, torch.device):
         if device.type == "cuda":
             return _C.DeviceType.gpu
@@ -346,10 +348,10 @@ def TensorRTCompileSpec(
     device: torch.device | Device = Device._current_device(),
     disable_tf32: bool = False,
     sparse_weights: bool = False,
-    enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None,  # type: ignore[name-defined]
+    enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None,
     refit: bool = False,
     debug: bool = False,
-    capability: _enums.EngineCapability = _enums.EngineCapability.default,  # type: ignore[name-defined]
+    capability: _enums.EngineCapability = _enums.EngineCapability.default,
     num_avg_timing_iters: int = 1,
     workspace_size: int = 0,
     dla_sram_size: int = 1048576,
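One last wrinkle: the PEP 604 unions left in the `TensorRTCompileSpec` signature (`torch.device | Device`, `torch.dtype | _enums.dtype`) are safe on 3.8 precisely because the future import keeps them as strings; only default values such as `Device._current_device()` still execute at definition time. A quick demonstration under that assumption (hypothetical function, not this module's API):

from __future__ import annotations

from typing import Optional, Set

import torch


def compile_spec(
    enabled_precisions: Optional[Set[torch.dtype | str]] = None,
) -> None:
    # The union in the annotation above is never evaluated on 3.8;
    # only the default value (None here) runs at definition time.
    pass


# Under PEP 563 the annotation is stored as plain text:
print(compile_spec.__annotations__["enabled_precisions"])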
2 changes: 2 additions & 0 deletions py/torch_tensorrt/ts/_compiler.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Any, List, Optional, Sequence, Set, Tuple

 import torch
