From 5af1e03dbe519673bac61d6657190b0a63c96f1d Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Tue, 23 Sep 2025 12:02:33 -0700 Subject: [PATCH] Custom where_Scalar op (#14470) Summary: Continued support of custom cadence ops Reviewed By: hsharma35 Differential Revision: D82703256 --- backends/cadence/aot/ref_implementations.py | 12 ++++++++++++ .../cadence/aot/tests/test_ref_implementations.py | 11 +++++++++++ 2 files changed, 23 insertions(+) diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index fe012837870..89abb006f6f 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -1092,3 +1092,15 @@ def rms_norm( eps: float, ) -> torch.Tensor: return W * nn.RMSNorm(list(normalized_shape), eps=eps, dtype=X.dtype)(X) + + +@impl(m, "where_Scalar") +def where_Scalar( + condition: torch.Tensor, + if_true: float, + if_false: float, +) -> torch.Tensor: + if condition.dtype != torch.bool: + raise ValueError("condition must be a bool tensor") + + return torch.where(condition, if_true, if_false) diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index bc025f4c894..26281b70216 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -1145,3 +1145,14 @@ def test_quantized_relu( torch.equal(output, expected_output), f"Output values don't match expected in {name}. Got {output}, expected {expected_output}", ) + + def test_where_Scalar(self) -> None: + input_tensor = torch.tensor([1, 2, 3, 4], dtype=torch.int8) + out = torch.ops.cadence.where_Scalar(input_tensor > 2, 1.0, 0.0) + self.assertTrue( + torch.equal(out, torch.tensor([0.0, 0.0, 1.0, 1.0], dtype=torch.float32)) + ) + with self.assertRaises(ValueError) as context: + torch.ops.cadence.where_Scalar(input_tensor, 1.0, 0.0) + + self.assertIn("condition must be a bool tensor", str(context.exception))