37 changes: 22 additions & 15 deletions backends/cadence/aot/ref_implementations.py

@@ -1041,13 +1041,13 @@ def quantized_relu_asym8s_asym8s_per_tensor() -> torch.Tensor: ...
 def quantized_relu_asym8u_asym8u_per_tensor() -> torch.Tensor: ...
 
 
-@impl(m, "requantize")
-def requantize(
+@impl(m, "requantize.per_tensor")
+def requantize_per_tensor(
     input: torch.Tensor,
-    in_scale: torch.Tensor,
-    in_zero_point: torch.Tensor,
-    out_scale: torch.Tensor,
-    out_zero_point: torch.Tensor,
+    in_scale: float,
+    in_zero_point: int,
+    out_scale: float,
+    out_zero_point: int,
     dtype: ScalarType,
 ) -> torch.Tensor:
     if dtype in qdtype_map:
@@ -1056,11 +1056,6 @@ def requantize(
             torch.dequantize(input), out_scale, out_zero_point, qdtype_map[dtype]
         )
 
-    # For in_scale or out_scale other than scalar, it requires quant/dequant
-    # per channel, but the channel dimension value is missing
-    if in_scale.numel() > 1 or out_scale.numel() > 1:
-        raise NotImplementedError("Only scalar scales are supported")
-
     quant_min = torch.iinfo(input.dtype).min
     quant_max = torch.iinfo(input.dtype).max
     # pyre-fixme[6]: This dtype is actually the right one.
@@ -1070,14 +1065,14 @@ def requantize(
     return torch.ops.quantized_decomposed.quantize_per_tensor(
         torch.ops.quantized_decomposed.dequantize_per_tensor(
             input,
-            in_scale.flatten()[0],
-            in_zero_point.flatten()[0],
+            in_scale,
+            in_zero_point,
             quant_min,
             quant_max,
             input.dtype,
         ),
-        out_scale.flatten()[0],
-        out_zero_point.flatten()[0],
+        out_scale,
+        out_zero_point,
         out_quant_min,
         out_quant_max,
         dtype,
@@ -1092,3 +1087,15 @@ def rms_norm(
     eps: float,
 ) -> torch.Tensor:
     return W * nn.RMSNorm(list(normalized_shape), eps=eps, dtype=X.dtype)(X)
+
+
+@impl(m, "where_Scalar")
+def where_Scalar(
+    condition: torch.Tensor,
+    if_true: float,
+    if_false: float,
+) -> torch.Tensor:
+    if condition.dtype != torch.bool:
+        raise ValueError("condition must be a bool tensor")
+
+    return torch.where(condition, if_true, if_false)
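Reviewer note: a minimal sketch of what the renamed requantize.per_tensor op computes, expressed with the same decomposed quant/dequant ops the reference implementation calls. Scales and zero points are now plain Python scalars, so the old .flatten()[0] indexing and the per-channel guard are no longer needed. All values below (input tensor, scales, zero points, int8 in and out) are illustrative assumptions, not taken from this PR.

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401 -- registers the quantized_decomposed ops

# Hypothetical quantization parameters, chosen only for illustration.
x = torch.tensor([10, 20, 30], dtype=torch.int8)  # already-quantized input
in_scale, in_zero_point = 0.5, 0
out_scale, out_zero_point = 0.25, 0

# Dequantize with the input params, then re-quantize with the output params.
deq = torch.ops.quantized_decomposed.dequantize_per_tensor(
    x, in_scale, in_zero_point, -128, 127, torch.int8
)
out = torch.ops.quantized_decomposed.quantize_per_tensor(
    deq, out_scale, out_zero_point, -128, 127, torch.int8
)
print(out)  # tensor([20, 40, 60], dtype=torch.int8)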
11 changes: 11 additions & 0 deletions backends/cadence/aot/tests/test_ref_implementations.py

@@ -1145,3 +1145,14 @@ def test_quantized_relu(
             torch.equal(output, expected_output),
             f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
         )
+
+    def test_where_Scalar(self) -> None:
+        input_tensor = torch.tensor([1, 2, 3, 4], dtype=torch.int8)
+        out = torch.ops.cadence.where_Scalar(input_tensor > 2, 1.0, 0.0)
+        self.assertTrue(
+            torch.equal(out, torch.tensor([0.0, 0.0, 1.0, 1.0], dtype=torch.float32))
+        )
+        with self.assertRaises(ValueError) as context:
+            torch.ops.cadence.where_Scalar(input_tensor, 1.0, 0.0)
+
+        self.assertIn("condition must be a bool tensor", str(context.exception))
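Usage note on the test above: because both branches of where_Scalar are Python floats, torch.where promotes the result to the default float dtype, which is why the expected output is float32 even though the condition came from comparing an int8 tensor. A plain-PyTorch equivalent (values are illustrative):

import torch

cond = torch.tensor([1, 2, 3, 4], dtype=torch.int8) > 2  # bool tensor
print(torch.where(cond, 1.0, 0.0))  # tensor([0., 0., 1., 1.]) -- float32 by default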