From 4dddabd27a1b11c4d9aadfa2b134488c18ffca51 Mon Sep 17 00:00:00 2001
From: Andrew Grebenisan
Date: Tue, 23 Sep 2025 12:03:52 -0700
Subject: [PATCH 1/2] Custom where_Scalar op (#14470)

Summary: Continued support of custom cadence ops

Reviewed By: hsharma35

Differential Revision: D82703256
---
 backends/cadence/aot/ref_implementations.py       | 12 ++++++++++++
 .../cadence/aot/tests/test_ref_implementations.py | 11 +++++++++++
 2 files changed, 23 insertions(+)

diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index fe012837870..89abb006f6f 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -1092,3 +1092,15 @@ def rms_norm(
     eps: float,
 ) -> torch.Tensor:
     return W * nn.RMSNorm(list(normalized_shape), eps=eps, dtype=X.dtype)(X)
+
+
+@impl(m, "where_Scalar")
+def where_Scalar(
+    condition: torch.Tensor,
+    if_true: float,
+    if_false: float,
+) -> torch.Tensor:
+    if condition.dtype != torch.bool:
+        raise ValueError("condition must be a bool tensor")
+
+    return torch.where(condition, if_true, if_false)
diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index bc025f4c894..26281b70216 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -1145,3 +1145,14 @@ def test_quantized_relu(
             torch.equal(output, expected_output),
             f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
         )
+
+    def test_where_Scalar(self) -> None:
+        input_tensor = torch.tensor([1, 2, 3, 4], dtype=torch.int8)
+        out = torch.ops.cadence.where_Scalar(input_tensor > 2, 1.0, 0.0)
+        self.assertTrue(
+            torch.equal(out, torch.tensor([0.0, 0.0, 1.0, 1.0], dtype=torch.float32))
+        )
+        with self.assertRaises(ValueError) as context:
+            torch.ops.cadence.where_Scalar(input_tensor, 1.0, 0.0)
+
+        self.assertIn("condition must be a bool tensor", str(context.exception))

From 840a36ce0ed049c9bc9591950c11b18d16ae89ae Mon Sep 17 00:00:00 2001
From: Andrew Grebenisan
Date: Tue, 23 Sep 2025 12:03:52 -0700
Subject: [PATCH 2/2] Update requantize to requantize_per_tensor since we
 don't have a non-per-tensor variant (#14482)

Summary: Register the reference implementation under requantize.per_tensor
and take scalar scale/zero-point arguments instead of tensors, since the op
has no non-per-tensor variant.

Reviewed By: hsharma35

Differential Revision: D82995376
---
 backends/cadence/aot/ref_implementations.py | 25 +++++++++------------
 1 file changed, 10 insertions(+), 15 deletions(-)

diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index 89abb006f6f..781f04ae1da 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -1041,13 +1041,13 @@ def quantized_relu_asym8s_asym8s_per_tensor() -> torch.Tensor: ...
 def quantized_relu_asym8u_asym8u_per_tensor() -> torch.Tensor: ...
 
 
-@impl(m, "requantize")
-def requantize(
+@impl(m, "requantize.per_tensor")
+def requantize_per_tensor(
     input: torch.Tensor,
-    in_scale: torch.Tensor,
-    in_zero_point: torch.Tensor,
-    out_scale: torch.Tensor,
-    out_zero_point: torch.Tensor,
+    in_scale: float,
+    in_zero_point: int,
+    out_scale: float,
+    out_zero_point: int,
     dtype: ScalarType,
 ) -> torch.Tensor:
     if dtype in qdtype_map:
@@ -1056,11 +1056,6 @@ def requantize(
             torch.dequantize(input), out_scale, out_zero_point, qdtype_map[dtype]
         )
 
-    # For in_scale or out_scale other than scalar, it requires quant/dequant
-    # per channel, but the channel dimension value is missing
-    if in_scale.numel() > 1 or out_scale.numel() > 1:
-        raise NotImplementedError("Only scalar scales are supported")
-
     quant_min = torch.iinfo(input.dtype).min
     quant_max = torch.iinfo(input.dtype).max
     # pyre-fixme[6]: This dtype is actually the right one.
@@ -1070,14 +1065,14 @@ def requantize(
     return torch.ops.quantized_decomposed.quantize_per_tensor(
         torch.ops.quantized_decomposed.dequantize_per_tensor(
             input,
-            in_scale.flatten()[0],
-            in_zero_point.flatten()[0],
+            in_scale,
+            in_zero_point,
             quant_min,
             quant_max,
             input.dtype,
        ),
-        out_scale.flatten()[0],
-        out_zero_point.flatten()[0],
+        out_scale,
+        out_zero_point,
         out_quant_min,
         out_quant_max,
         dtype,
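
Reviewer note: a minimal numeric sketch of the dequantize -> requantize round
trip that requantize.per_tensor performs, using the standard affine
quantization formulas behind quantize_per_tensor / dequantize_per_tensor. The
helper name and the example values are illustrative only, and the sketch omits
the qdtype_map dispatch path for quantized output dtypes:

    import torch

    def requantize_sketch(
        q_in: torch.Tensor,  # already-quantized integer values (e.g. int8)
        in_scale: float,
        in_zero_point: int,
        out_scale: float,
        out_zero_point: int,
    ) -> torch.Tensor:
        # Dequantize with the input params: real = (q - zero_point) * scale
        real = (q_in.to(torch.float32) - in_zero_point) * in_scale
        # Requantize with the output params, clamped to the dtype's range
        info = torch.iinfo(q_in.dtype)
        q_out = torch.round(real / out_scale) + out_zero_point
        return q_out.clamp(info.min, info.max).to(q_in.dtype)

    # Example: rescale int8 values quantized at scale 0.05 to scale 0.1
    q_in = torch.tensor([-20, 0, 40, 120], dtype=torch.int8)
    print(requantize_sketch(q_in, 0.05, 0, 0.1, 0))  # -> [-10, 0, 20, 60]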