feat: support more elementwise and unary dynamo converters (#2429)
zewenli98 authored Nov 28, 2023
1 parent cd158b6 commit 53401dd
Showing 14 changed files with 876 additions and 67 deletions.
256 changes: 253 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1739,9 +1739,177 @@ def aten_ops_logical_xor(
    )


def bitwise_type_validator(node: Node) -> bool:
    supported_type = [torch.bool, bool]

    tensor_targets = [
        torch.ops.aten.bitwise_and.Tensor,
        torch.ops.aten.bitwise_or.Tensor,
        torch.ops.aten.bitwise_xor.Tensor,
    ]
    scalar_targets = [
        torch.ops.aten.bitwise_and.Scalar,
        torch.ops.aten.bitwise_or.Scalar,
        torch.ops.aten.bitwise_xor.Scalar,
    ]
    scalar_tensor_targets = [
        torch.ops.aten.bitwise_and.Scalar_Tensor,
        torch.ops.aten.bitwise_or.Scalar_Tensor,
        torch.ops.aten.bitwise_xor.Scalar_Tensor,
    ]

    if node.target in tensor_targets:
        lhs_val = node.args[0]
        rhs_val = node.args[1]
        lhs_meta = lhs_val.meta.get("tensor_meta")
        rhs_meta = rhs_val.meta.get("tensor_meta")
        if lhs_meta is None or rhs_meta is None:
            return False
        return lhs_meta.dtype in supported_type and rhs_meta.dtype in supported_type

    elif node.target in scalar_targets:
        lhs_val = node.args[0]
        rhs_val = node.args[1]
        lhs_meta = lhs_val.meta.get("tensor_meta")
        if lhs_meta is None:
            return False
        return lhs_meta.dtype in supported_type and isinstance(rhs_val, bool)

    elif node.target in scalar_tensor_targets:
        lhs_val = node.args[0]
        rhs_val = node.args[1]
        rhs_meta = rhs_val.meta.get("tensor_meta")
        if rhs_meta is None:
            return False
        return isinstance(lhs_val, bool) and rhs_meta.dtype in supported_type

    else:
        return False

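Not part of this commit: the validator above reads node.meta["tensor_meta"], which records each value's dtype during tracing. The sketch below shows how that metadata gets populated on an FX graph using ShapeProp; it is only an approximation, since the dynamo frontend lowers to aten.bitwise_*.Tensor targets rather than the plain torch.fx targets produced by symbolic_trace.

# Illustration only (not from this commit): populate "tensor_meta" and inspect
# the dtypes the bitwise validators would check.
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

class BoolAnd(torch.nn.Module):
    def forward(self, a, b):
        return torch.bitwise_and(a, b)

gm = symbolic_trace(BoolAnd())
ShapeProp(gm).propagate(
    torch.ones(4, dtype=torch.bool), torch.zeros(4, dtype=torch.bool)
)
for node in gm.graph.nodes:
    meta = node.meta.get("tensor_meta")
    if meta is not None:
        # torch.bool inputs would satisfy the validator; integer dtypes would not.
        print(node.op, node.target, meta.dtype)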

@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_and.Tensor, capability_validator=bitwise_type_validator
)
@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_and.Scalar, capability_validator=bitwise_type_validator
)
@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_and.Scalar_Tensor,
    capability_validator=bitwise_type_validator,
)
def aten_ops_bitwise_and(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.elementwise.bitwise_and(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
    )


@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_or.Tensor, capability_validator=bitwise_type_validator
)
@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_or.Scalar, capability_validator=bitwise_type_validator
)
@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_or.Scalar_Tensor, capability_validator=bitwise_type_validator
)
def aten_ops_bitwise_or(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.elementwise.bitwise_or(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
    )


@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_xor.Tensor, capability_validator=bitwise_type_validator
)
@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_xor.Scalar, capability_validator=bitwise_type_validator
)
@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_xor.Scalar_Tensor,
    capability_validator=bitwise_type_validator,
)
def aten_ops_bitwise_xor(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.elementwise.bitwise_xor(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
    )

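Not part of the diff: a hedged end-to-end sketch of exercising the new binary bitwise converters through the dynamo frontend. It assumes a CUDA build of torch_tensorrt and passes min_block_size=1 so the small example graph is not skipped by the partitioner.

# Usage sketch (not from this commit): boolean elementwise ops lowered through
# ir="dynamo" should now hit the bitwise converters above.
import torch
import torch_tensorrt

class BitwiseOps(torch.nn.Module):
    def forward(self, a, b):
        return torch.bitwise_xor(torch.bitwise_and(a, b), torch.bitwise_or(a, b))

model = BitwiseOps().eval().cuda()
inputs = [
    torch.randint(0, 2, (2, 3), dtype=torch.bool, device="cuda"),
    torch.randint(0, 2, (2, 3), dtype=torch.bool, device="cuda"),
]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, min_block_size=1)
print(trt_model(*inputs))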

def bitwise_not_type_validator(node: Node) -> bool:
    val = node.args[0]
    val_meta = val.meta.get("tensor_meta")

    if val_meta is None:
        return False

    supported_type = [torch.bool, bool]
    return val_meta.dtype in supported_type


@dynamo_tensorrt_converter(
    torch.ops.aten.bitwise_not.default, capability_validator=bitwise_not_type_validator
)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)
def aten_ops_bitwise_not(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.unary.bitwise_not(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
    )

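Also not from this commit: a sketch of the fallback behavior the validators enable. When a validator returns False (for example, non-bool inputs to bitwise_not), the node is not claimed by a converter and is left to run in PyTorch after graph partitioning. The backend name and options dict below are assumptions about the torch.compile integration, not something stated in this diff.

# Hedged sketch: validator-gated fallback through the torch.compile backend.
import torch
import torch_tensorrt  # assumed to register the "torch_tensorrt" backend

def invert(x: torch.Tensor) -> torch.Tensor:
    return torch.bitwise_not(x)

compiled = torch.compile(invert, backend="torch_tensorrt", options={"min_block_size": 1})
print(compiled(torch.zeros(4, dtype=torch.bool, device="cuda")))    # eligible for TRT conversion
print(compiled(torch.arange(4, dtype=torch.int32, device="cuda")))  # validator rejects; stays in PyTorch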

@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)
-def aten_ops_equal(
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_eq(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
@@ -1758,9 +1926,38 @@ def aten_ops_equal(
    )


@dynamo_tensorrt_converter(torch.ops.aten.ne.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.ne.Scalar)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)
def aten_ops_ne(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.elementwise.ne(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
    )


@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)
-def aten_ops_greater(
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_gt(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
@@ -1777,9 +1974,38 @@ def aten_ops_greater(
    )


@dynamo_tensorrt_converter(torch.ops.aten.ge.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.ge.Scalar)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)
def aten_ops_ge(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.elementwise.ge(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
    )


@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)
-def aten_ops_less(
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_lt(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
@@ -1796,6 +2022,30 @@ def aten_ops_less(
    )


@dynamo_tensorrt_converter(torch.ops.aten.le.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.le.Scalar)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)
def aten_ops_le(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.elementwise.le(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
    )

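Not part of the diff: the comparison converters above (renamed to aten_ops_eq/ne/gt/ge/lt/le) cover both the .Tensor and .Scalar overloads, and @enforce_tensor_types requires argument 0 to be a TRTTensor, so the left operand must be a tensor in the graph while the right operand may be a Python scalar. A hedged sketch of exercising them, under the same assumptions as the earlier example:

# Usage sketch (not from this commit): tensor-tensor and tensor-scalar comparisons.
import torch
import torch_tensorrt

class Compare(torch.nn.Module):
    def forward(self, x, y):
        return (x > y), (x <= 0.5), (x != y)

model = Compare().eval().cuda()
inputs = [torch.randn(8, device="cuda"), torch.randn(8, device="cuda")]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, min_block_size=1)
print(trt_model(*inputs))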

def conv_param_validator(conv_node: Node) -> bool:
    return conv_node.args[7] in ([0], [0, 0], [0, 0, 0])

8 changes: 4 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
@@ -58,8 +58,8 @@ def convert_binary_elementwise(
    source_ir: Optional[SourceIR],
    name: str,
    op_type: trt.ElementWiseOperation,
-    lhs_val: Union[int, float, TRTTensor, torch.Tensor],
-    rhs_val: Union[int, float, TRTTensor, torch.Tensor],
+    lhs_val: Union[int, float, bool, TRTTensor, torch.Tensor],
+    rhs_val: Union[int, float, bool, TRTTensor, torch.Tensor],
) -> TRTTensor:
    """
    This function adds a TensorRT elementwise layer. We allow both operands to be
@@ -120,11 +120,11 @@ def convert_binary_elementwise(
    # Note that the dtype here is supposed to be the same as the scalar
    # dtype but we don't have a way to detect whether it makes sense for the
    # scalar to be float or half. Hence we go with the lhs dtype.
-    if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
+    if is_lhs_trt_tensor and isinstance(rhs_val, (float, int, bool)):
        rhs_val = np.array(
            [rhs_val], dtype=unified_dtype_converter(lhs_dtype, Frameworks.NUMPY)
        )
-    if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
+    if is_rhs_trt_tensor and isinstance(lhs_val, (float, int, bool)):
        lhs_val = np.array(
            [lhs_val], dtype=unified_dtype_converter(rhs_dtype, Frameworks.NUMPY)
        )
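The base.py change simply adds bool to the scalar types that get promoted to a one-element numpy array in the other operand's dtype. Not from this commit: a tiny illustration of that promotion in plain numpy; np.bool_ here stands in for whatever dtype unified_dtype_converter would return for a torch.bool operand (an assumption for illustration).

# Illustration only: a Python bool scalar promoted to the lhs operand's dtype,
# mirroring the branch above, before being handed to TensorRT as a constant.
import numpy as np

rhs_val = True                 # scalar operand, e.g. from aten.bitwise_and.Scalar
lhs_numpy_dtype = np.bool_     # assumed numpy dtype for a torch.bool lhs tensor
promoted = np.array([rhs_val], dtype=lhs_numpy_dtype)
print(promoted.dtype, promoted)  # bool [ True]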
