From a2fd0250d8acf03589182c7cd5964687bd4d96d9 Mon Sep 17 00:00:00 2001
From: Mehrdad Malekmohammadi
Date: Tue, 21 Oct 2025 14:26:44 -0400
Subject: [PATCH] extend StableHLO dialect and operations

This commit migrates several new StableHLO dialect operations to the
xdsl_jax project. Additionally, it includes the necessary attributes and
types per the StableHLO specification, plus some general traits and
constraints, collected in the xdsl_extras module, that can be upstreamed
to xdsl.

Moreover, it adds pytest tests covering generic-format parse/print
round-trips for the new operations.
---
 pyproject.toml                                |    8 +
 .../{stablehlo.py => _stablehlo_upstream.py}  |    0
 src/xdsl_jax/dialects/stablehlo/__init__.py   |  158 +++
 src/xdsl_jax/dialects/stablehlo/attributes.py |  411 +++++++
 .../dialects/stablehlo/control_flow.py        |  166 +++
 .../dialects/stablehlo/data_movement.py       |  451 ++++++++
 src/xdsl_jax/dialects/stablehlo/dialect.py    |  206 ++++
 src/xdsl_jax/dialects/stablehlo/dynamism.py   |  202 ++++
 .../dialects/stablehlo/elementwise_binary.py  |  227 ++++
 .../dialects/stablehlo/elementwise_other.py   |  219 ++++
 .../dialects/stablehlo/elementwise_unary.py   |  581 ++++++++++
 .../dialects/stablehlo/extensibility.py       |  176 +++
 src/xdsl_jax/dialects/stablehlo/reduction.py  |  166 +++
 src/xdsl_jax/dialects/stablehlo/types.py      |  261 +++++
 src/xdsl_jax/xdsl_extras/__init__.py          |   35 +
 src/xdsl_jax/xdsl_extras/constraints.py       |   84 ++
 src/xdsl_jax/xdsl_extras/traits.py            |  251 ++++
 tests/conftest.py                             |  102 ++
 tests/pytest/__init__.py                      |    1 +
 tests/pytest/test_stablehlo_dialect.py        | 1018 +++++++++++++++++
 20 files changed, 4723 insertions(+)
 rename src/xdsl_jax/dialects/{stablehlo.py => _stablehlo_upstream.py} (100%)
 create mode 100644 src/xdsl_jax/dialects/stablehlo/__init__.py
 create mode 100644 src/xdsl_jax/dialects/stablehlo/attributes.py
 create mode 100644 src/xdsl_jax/dialects/stablehlo/control_flow.py
 create mode 100644 src/xdsl_jax/dialects/stablehlo/data_movement.py
 create mode 100644 src/xdsl_jax/dialects/stablehlo/dialect.py
 create mode 100644 src/xdsl_jax/dialects/stablehlo/dynamism.py
 create mode 100644 src/xdsl_jax/dialects/stablehlo/elementwise_binary.py
 create mode 100644 src/xdsl_jax/dialects/stablehlo/elementwise_other.py
 create mode 100644 src/xdsl_jax/dialects/stablehlo/elementwise_unary.py
 create mode 100644 src/xdsl_jax/dialects/stablehlo/extensibility.py
 create mode 100644 src/xdsl_jax/dialects/stablehlo/reduction.py
 create mode 100644 src/xdsl_jax/dialects/stablehlo/types.py
 create mode 100644 src/xdsl_jax/xdsl_extras/__init__.py
 create mode 100644 src/xdsl_jax/xdsl_extras/constraints.py
 create mode 100644 src/xdsl_jax/xdsl_extras/traits.py
 create mode 100644 tests/conftest.py
 create mode 100644 tests/pytest/__init__.py
 create mode 100644 tests/pytest/test_stablehlo_dialect.py

diff --git a/pyproject.toml b/pyproject.toml
index 6826013..6d1fda3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -65,6 +65,14 @@ ignore = [
     "PYI041", # https://docs.astral.sh/ruff/rules/redundant-numeric-union
 ]

+[tool.ruff.lint.per-file-ignores]
+# Allow long lines in StableHLO dialect files (contain MLIR examples in docstrings)
+"src/xdsl_jax/dialects/stablehlo/**/*.py" = ["E501"]
+# Allow long lines in xdsl_extras files (contain detailed docstrings)
+"src/xdsl_jax/xdsl_extras/**/*.py" = ["E501"]
+# Allow long lines in test files (contain MLIR test cases)
+"tests/**/*.py" = ["E501"]
+
 [tool.ruff.lint.mccabe]
 max-complexity = 10

diff --git a/src/xdsl_jax/dialects/stablehlo.py b/src/xdsl_jax/dialects/_stablehlo_upstream.py
similarity index 100%
rename from src/xdsl_jax/dialects/stablehlo.py rename to src/xdsl_jax/dialects/_stablehlo_upstream.py diff --git a/src/xdsl_jax/dialects/stablehlo/__init__.py b/src/xdsl_jax/dialects/stablehlo/__init__.py new file mode 100644 index 0000000..c99f392 --- /dev/null +++ b/src/xdsl_jax/dialects/stablehlo/__init__.py @@ -0,0 +1,158 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +StableHLO dialect package for xdsl-jax. + +This package contains organized elementwise operations and other StableHLO-related +functionality. +""" + +# Import all elementwise operations explicitly +from .attributes import ( + CustomCallApiVersion, + CustomCallApiVersionAttr, + GatherDimensionNumbers, + OutputOperandAlias, + ResultAccuracyModeAttr, + ScatterDimensionNumbers, +) +from .control_flow import ( + IfOp, + OptimizationBarrierOp, + WhileOp, +) +from .data_movement import ( + BroadcastInDimOp, + ConcatenateOp, + DynamicSliceOp, + GatherOp, + ReshapeOp, + ScatterOp, + SliceOp, +) + +# Import the main StableHLO dialect +from .dialect import StableHLO +from .dynamism import ( + DynamicBroadcastInDimOp, +) +from .elementwise_binary import ( + ComplexOp, + DivideOp, + MaximumOp, + MinimumOp, + PowerOp, + RemainderOp, +) +from .elementwise_other import ( + ClampOp, + CompareOp, + MapOp, + ReducePrecisionOp, + SelectOp, +) +from .elementwise_unary import ( + ConvertOp, + CosineOp, + ExponentialMinusOneOp, + ExponentialOp, + FloorOp, + ImagOp, + IsFiniteOp, + LogisticOp, + LogOp, + LogPlusOneOp, + NegateOp, + RealOp, + RoundNearestAfzOp, + RoundNearestEvenOp, + RsqrtOp, + SignOp, + SineOp, + SqrtOp, + TanhOp, + TanOp, +) +from .extensibility import ( + CustomCallOp, +) +from .reduction import ( + ReduceOp, +) + +# Export all operations and the dialect for external use +__all__ = [ + # Main dialect + "StableHLO", + # Elementwise unary operations + "ConvertOp", + "CosineOp", + "ExponentialMinusOneOp", + "ExponentialOp", + "FloorOp", + "ImagOp", + "IsFiniteOp", + "LogOp", + "LogPlusOneOp", + "LogisticOp", + "NegateOp", + "RealOp", + "RoundNearestAfzOp", + "RoundNearestEvenOp", + "RsqrtOp", + "SignOp", + "SineOp", + "SqrtOp", + "TanOp", + "TanhOp", + # Elementwise binary operations + "ComplexOp", + "DivideOp", + "MaximumOp", + "MinimumOp", + "PowerOp", + "RemainderOp", + # Elementwise other operations + "ClampOp", + "CompareOp", + "MapOp", + "ReducePrecisionOp", + "SelectOp", + # Control flow operations + "IfOp", + "WhileOp", + "OptimizationBarrierOp", + # Data movement operations + "BroadcastInDimOp", + "ConcatenateOp", + "DynamicSliceOp", + "GatherOp", + "ReshapeOp", + "ScatterOp", + "SliceOp", + # Dynamism operations + "DynamicBroadcastInDimOp", + # Reduction operations + "ReduceOp", + # Extensibility operations + "CustomCallOp", + # Attributes + "GatherDimensionNumbers", + "ResultAccuracyModeAttr", + "ScatterDimensionNumbers", + "CustomCallApiVersion", + "CustomCallApiVersionAttr", + "OutputOperandAlias", +] diff --git 
a/src/xdsl_jax/dialects/stablehlo/attributes.py b/src/xdsl_jax/dialects/stablehlo/attributes.py
new file mode 100644
index 0000000..d5cf62b
--- /dev/null
+++ b/src/xdsl_jax/dialects/stablehlo/attributes.py
@@ -0,0 +1,411 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+StableHLO attribute definitions for PennyLane's compiler infrastructure.
+
+This module provides attribute definitions based on the StableHLO specification
+(https://github.com/openxla/stablehlo/blob/main/docs/spec.md), including
+attributes for StableHLO operations.
+"""
+
+# pylint: disable=too-few-public-methods
+
+from collections.abc import Sequence
+
+from xdsl.dialects.builtin import I64, ArrayAttr, IntegerAttr, i64
+from xdsl.ir import (
+    Attribute,
+    EnumAttribute,
+    ParametrizedAttribute,
+    SpacedOpaqueSyntaxAttribute,
+    StrEnum,
+)
+from xdsl.irdl import irdl_attr_definition
+from xdsl.parser import AttrParser
+from xdsl.printer import Printer
+
+
+# Utility functions for dimension array parsing/printing
+def parse_dims(parser: AttrParser) -> ArrayAttr[IntegerAttr[I64]]:
+    """Parse a dimension array in [1, 2, 3] format."""
+    value = parser.parse_comma_separated_list(
+        AttrParser.Delimiter.SQUARE,
+        lambda: IntegerAttr(parser.parse_integer(), i64),
+    )
+    return ArrayAttr(value)
+
+
+def print_dims(printer: Printer, dims: ArrayAttr[IntegerAttr[I64]]):
+    """Print a dimension array in [1, 2, 3] format."""
+    printer.print_string("[")
+    printer.print_list(
+        dims.data,
+        lambda dim: printer.print_string(f"{dim.value.data}"),
+    )
+    printer.print_string("]")
+
+
+class ResultAccuracyMode(StrEnum):
+    """
+    XLA result accuracy mode.
+    """
+
+    DEFAULT = "DEFAULT"
+    HIGHEST = "HIGHEST"
+    TOLERANCE = "TOLERANCE"
+
+
+@irdl_attr_definition
+class ResultAccuracyModeAttr(
+    EnumAttribute[ResultAccuracyMode], SpacedOpaqueSyntaxAttribute
+):
+    """
+    XLA result accuracy mode.
+
+    See external [documentation](https://github.com/openxla/stablehlo/blob/7c50d4efeaea30bff6aa5e46c7f71170f5aa06af/stablehlo/dialect/StablehloEnums.td#L49-L70).
+    """
+
+    name = "stablehlo.result_accuracy_mode"
+
+
+@irdl_attr_definition
+class GatherDimensionNumbers(ParametrizedAttribute):
+    """
+    XLA gather dimension numbers.
+
+    This attribute models the dimension information for gather operations.
+    See external [documentation](https://github.com/openxla/stablehlo/blob/b075e948092d8a27ed0be48f4f8dbaa6df7e2e3e/stablehlo/dialect/StablehloAttrs.td#L42).
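+
+    For illustration, the custom syntax handled by ``print_parameters`` and
+    ``parse_parameters`` below looks like this (values borrowed from the gather
+    example in ``data_movement.py``):
+
+    ```mlir
+    #stablehlo.gather<
+        offset_dims = [3, 4],
+        collapsed_slice_dims = [1],
+        operand_batching_dims = [0],
+        start_indices_batching_dims = [1],
+        start_index_map = [2, 1],
+        index_vector_dim = 3
+    >
+    ```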
+ """ + + name = "stablehlo.gather" + + offset_dims: ArrayAttr[IntegerAttr[I64]] + collapsed_slice_dims: ArrayAttr[IntegerAttr[I64]] + operand_batching_dims: ArrayAttr[IntegerAttr[I64]] + start_indices_batching_dims: ArrayAttr[IntegerAttr[I64]] + start_index_map: ArrayAttr[IntegerAttr[I64]] + index_vector_dim: IntegerAttr[I64] + + def print_parameters(self, printer: Printer) -> None: + """Print gather dimension numbers in structured format""" + with printer.in_angle_brackets(): + with printer.indented(): + # Print offset_dims + printer.print_string("\noffset_dims = ") + print_dims(printer, self.offset_dims) + printer.print_string(",") + + # Print collapsed_slice_dims + printer.print_string("\ncollapsed_slice_dims = ") + print_dims(printer, self.collapsed_slice_dims) + printer.print_string(",") + + # Print operand_batching_dims + printer.print_string("\noperand_batching_dims = ") + print_dims(printer, self.operand_batching_dims) + printer.print_string(",") + + # Print start_indices_batching_dims + printer.print_string("\nstart_indices_batching_dims = ") + print_dims(printer, self.start_indices_batching_dims) + printer.print_string(",") + + # Print start_index_map + printer.print_string("\nstart_index_map = ") + print_dims(printer, self.start_index_map) + printer.print_string(",") + + # Print index_vector_dim + printer.print_string( + f"\nindex_vector_dim = {self.index_vector_dim.value.data}" + ) + printer.print_string("\n") + + @classmethod + def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]: + """Parse gather dimension numbers from structured format""" + with parser.in_angle_brackets(): + # Initialize default values for all fields + offset_dims = ArrayAttr([]) + collapsed_slice_dims = ArrayAttr([]) + operand_batching_dims = ArrayAttr([]) + start_indices_batching_dims = ArrayAttr([]) + start_index_map = ArrayAttr([]) + index_vector_dim = IntegerAttr(0, i64) + + # Try to parse offset_dims + if parser.parse_optional_characters("offset_dims") is not None: + parser.parse_punctuation("=") + offset_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse collapsed_slice_dims + if parser.parse_optional_characters("collapsed_slice_dims") is not None: + parser.parse_punctuation("=") + collapsed_slice_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse operand_batching_dims + if parser.parse_optional_characters("operand_batching_dims") is not None: + parser.parse_punctuation("=") + operand_batching_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse start_indices_batching_dims + if ( + parser.parse_optional_characters("start_indices_batching_dims") + is not None + ): + parser.parse_punctuation("=") + start_indices_batching_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse start_index_map + if parser.parse_optional_characters("start_index_map") is not None: + parser.parse_punctuation("=") + start_index_map = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse index_vector_dim + if parser.parse_optional_characters("index_vector_dim") is not None: + parser.parse_punctuation("=") + index_vector_dim = IntegerAttr(parser.parse_integer(), i64) + + return ( + offset_dims, + collapsed_slice_dims, + operand_batching_dims, + start_indices_batching_dims, + start_index_map, + index_vector_dim, + ) + + +@irdl_attr_definition +class ScatterDimensionNumbers(ParametrizedAttribute): + """ + XLA scatter dimension numbers. 
+ + This attribute models the dimension information for scatter operations. + See external [documentation](https://github.com/openxla/stablehlo/blob/b075e948092d8a27ed0be48f4f8dbaa6df7e2e3e/stablehlo/dialect/StablehloAttrs.td#L28). + """ + + name = "stablehlo.scatter" + + update_window_dims: ArrayAttr[IntegerAttr[I64]] + inserted_window_dims: ArrayAttr[IntegerAttr[I64]] + input_batching_dims: ArrayAttr[IntegerAttr[I64]] + scatter_indices_batching_dims: ArrayAttr[IntegerAttr[I64]] + scatter_dims_to_operand_dims: ArrayAttr[IntegerAttr[I64]] + index_vector_dim: IntegerAttr[I64] + + def print_parameters(self, printer: Printer) -> None: + """Print scatter dimension numbers in structured format""" + with printer.in_angle_brackets(): + with printer.indented(): + # Print update_window_dims + printer.print_string("\nupdate_window_dims = ") + print_dims(printer, self.update_window_dims) + printer.print_string(",") + + # Print inserted_window_dims + printer.print_string("\ninserted_window_dims = ") + print_dims(printer, self.inserted_window_dims) + printer.print_string(",") + + # Print input_batching_dims + printer.print_string("\ninput_batching_dims = ") + print_dims(printer, self.input_batching_dims) + printer.print_string(",") + + # Print scatter_indices_batching_dims + printer.print_string("\nscatter_indices_batching_dims = ") + print_dims(printer, self.scatter_indices_batching_dims) + printer.print_string(",") + + # Print scatter_dims_to_operand_dims + printer.print_string("\nscatter_dims_to_operand_dims = ") + print_dims(printer, self.scatter_dims_to_operand_dims) + printer.print_string(",") + + # Print index_vector_dim + printer.print_string( + f"\nindex_vector_dim = {self.index_vector_dim.value.data}" + ) + printer.print_string("\n") + + @classmethod + def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]: + """Parse scatter dimension numbers from structured format""" + with parser.in_angle_brackets(): + # Initialize default values for all fields + update_window_dims = ArrayAttr([]) + inserted_window_dims = ArrayAttr([]) + input_batching_dims = ArrayAttr([]) + scatter_indices_batching_dims = ArrayAttr([]) + scatter_dims_to_operand_dims = ArrayAttr([]) + index_vector_dim = IntegerAttr(0, i64) + + # Try to parse update_window_dims + if parser.parse_optional_characters("update_window_dims") is not None: + parser.parse_punctuation("=") + update_window_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse inserted_window_dims + if parser.parse_optional_characters("inserted_window_dims") is not None: + parser.parse_punctuation("=") + inserted_window_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse input_batching_dims + if parser.parse_optional_characters("input_batching_dims") is not None: + parser.parse_punctuation("=") + input_batching_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse scatter_indices_batching_dims + if ( + parser.parse_optional_characters("scatter_indices_batching_dims") + is not None + ): + parser.parse_punctuation("=") + scatter_indices_batching_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse scatter_dims_to_operand_dims + if ( + parser.parse_optional_characters("scatter_dims_to_operand_dims") + is not None + ): + parser.parse_punctuation("=") + scatter_dims_to_operand_dims = parse_dims(parser) + parser.parse_optional_punctuation(",") + + # Try to parse index_vector_dim + if 
parser.parse_optional_characters("index_vector_dim") is not None: + parser.parse_punctuation("=") + index_vector_dim = IntegerAttr(parser.parse_integer(), i64) + + return ( + update_window_dims, + inserted_window_dims, + input_batching_dims, + scatter_indices_batching_dims, + scatter_dims_to_operand_dims, + index_vector_dim, + ) + + +# ===== CustomCall and layout-related attributes ===== + + +class CustomCallApiVersion(StrEnum): + """StableHLO CustomCall API version.""" + + API_VERSION_UNSPECIFIED = "API_VERSION_UNSPECIFIED" + API_VERSION_ORIGINAL = "API_VERSION_ORIGINAL" + API_VERSION_STATUS_RETURNING = "API_VERSION_STATUS_RETURNING" + API_VERSION_STATUS_RETURNING_UNIFIED = "API_VERSION_STATUS_RETURNING_UNIFIED" + API_VERSION_TYPED_FFI = "API_VERSION_TYPED_FFI" + + +@irdl_attr_definition +class CustomCallApiVersionAttr( + EnumAttribute[CustomCallApiVersion], SpacedOpaqueSyntaxAttribute +): + """StableHLO custom call API version attribute. + + Mirrors StableHLO enum for CustomCall API versions. + """ + + name = "stablehlo.custom_call_api_version" + + +@irdl_attr_definition +class OutputOperandAlias(ParametrizedAttribute): + """ + This attribute captures the alias relationship of the output to one of the + operands for a ``CustomCall`` op, denoted by ``operand_index``. The + ``output_tuple_indices`` and ``operand_tuple_indices`` are used to index into + output and operand types. These indices lists are empty if the corresponding + types are not tuple types, and can be arbitrarily long in case of + arbitrarily nested tuple types. + + See https://www.tensorflow.org/xla/aliasing. + + Example when used as array with in stablehlo.custom-call: + + ```mlir + %0 = "stablehlo.custom_call"(%arg0, %arg1) { + // other attributes + output_operand_alias = [ + #stablehlo.output_operand_alias + ] + } : (tuple, tensor<2x3xf32>>, tensor<5x5xf32>) -> tuple> + + The output and the 0th operand are both tuples. The aliasing shows the + relationship between the 0th element in output tuple with the 1st element in + the 0th operand. And both of them are of the same type: ``tensor<2x3xf32>``. 
+ ``` + """ + + name = "stablehlo.output_operand_alias" + + output_tuple_indices: ArrayAttr[IntegerAttr[I64]] + operand_index: IntegerAttr[I64] + operand_tuple_indices: ArrayAttr[IntegerAttr[I64]] + + def print_parameters(self, printer: Printer) -> None: + """Print the OutputOperandAlias attribute.""" + with printer.in_angle_brackets(): + with printer.indented(): + printer.print_string("\noutput_tuple_indices = ") + print_dims(printer, self.output_tuple_indices) + printer.print_string(",") + + printer.print_string("\noperand_index = ") + printer.print_string(f"{self.operand_index.value.data}") + printer.print_string(",") + + printer.print_string("\noperand_tuple_indices = ") + print_dims(printer, self.operand_tuple_indices) + printer.print_string("\n") + + @classmethod + def parse_parameters(cls, parser: AttrParser): + """Parse the OutputOperandAlias attribute.""" + with parser.in_angle_brackets(): + output_tuple_indices = ArrayAttr([]) + operand_index = IntegerAttr(0, i64) + operand_tuple_indices = ArrayAttr([]) + + if parser.parse_optional_characters("output_tuple_indices") is not None: + parser.parse_punctuation("=") + output_tuple_indices = parse_dims(parser) + parser.parse_optional_punctuation(",") + + if parser.parse_optional_characters("operand_index") is not None: + parser.parse_punctuation("=") + operand_index = IntegerAttr(parser.parse_integer(), i64) + parser.parse_optional_punctuation(",") + + if parser.parse_optional_characters("operand_tuple_indices") is not None: + parser.parse_punctuation("=") + operand_tuple_indices = parse_dims(parser) + + return (output_tuple_indices, operand_index, operand_tuple_indices) diff --git a/src/xdsl_jax/dialects/stablehlo/control_flow.py b/src/xdsl_jax/dialects/stablehlo/control_flow.py new file mode 100644 index 0000000..c5fbcac --- /dev/null +++ b/src/xdsl_jax/dialects/stablehlo/control_flow.py @@ -0,0 +1,166 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=too-few-public-methods +# pyright: reportUnknownArgumentType=false, reportUnknownVariableType=false + +""" +Control flow operations for the StableHLO dialect. 
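+
+A rough construction sketch using xdsl's generic ``build`` API (illustrative
+only: ``pred`` and ``res_type`` are placeholder values, and each branch region
+still needs a single block terminated by ``stablehlo.return`` to verify):
+
+```python
+from xdsl.ir import Block, Region
+
+if_op = IfOp.build(
+    operands=[pred],
+    result_types=[[res_type]],
+    regions=[Region(Block()), Region(Block())],
+)
+```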
+""" + +from typing import TypeVar + +from xdsl.dialects.builtin import AnyTensorType +from xdsl.irdl import ( + IRDLOperation, + irdl_op_definition, + operand_def, + region_def, + traits_def, + var_operand_def, + var_result_def, +) +from xdsl.traits import ( + Pure, + RecursivelySpeculatable, + RecursiveMemoryEffect, + SingleBlockImplicitTerminator, +) + +from xdsl_jax.dialects._stablehlo_upstream import ReturnOp + +# Import our custom StableHLO types +from .types import ( + HLO_PredTensor, + HLO_TensorOrPerAxisQuantizedTensorOrToken, + HLO_TensorOrToken, +) + +# Generic type variables for templating +T_IN = TypeVar("T_IN", bound=AnyTensorType) +T_OUT = TypeVar("T_OUT", bound=AnyTensorType) + + +@irdl_op_definition +class IfOp(IRDLOperation): + """ + Produces the output from executing exactly one branch from `true_branch` or + `false_branch` depending on the value of `pred`. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#if + + Example: + %result = "stablehlo.if"(%pred) ({ + "stablehlo.return"(%result_true_branch) : (tensor) -> () + }, { + "stablehlo.return"(%result_false_branch) : (tensor) -> () + }) : (tensor) -> tensor + """ + + name = "stablehlo.if" + + pred = operand_def(HLO_PredTensor) + + res = var_result_def(HLO_TensorOrPerAxisQuantizedTensorOrToken) + + true_branch = region_def("single_block") + + false_branch = region_def("single_block") + + traits = traits_def( + RecursiveMemoryEffect(), + RecursivelySpeculatable(), + SingleBlockImplicitTerminator(ReturnOp), + # TODO: InferTypeOpInterface + # TODO: OpAsmOpInterface + ) + + # TODO: Add custom assembly format + + +@irdl_op_definition +class WhileOp(IRDLOperation): + """ + Produces the output from executing `body` function 0 or more times while the + `cond` function outputs `true`. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#while + + Example: + ```mlir + %results0, %results1 = stablehlo.while(%arg0 = %init_i, %arg1 = %init_sum) : tensor, tensor + cond { + %cond = stablehlo.compare LT, %arg0, %ten : (tensor, tensor) -> tensor + stablehlo.return %cond : tensor + } do { + %new_sum = stablehlo.add %arg1, %one : tensor + %new_i = stablehlo.add %arg0, %one : tensor + stablehlo.return %new_i, %new_sum : tensor, tensor + } + """ + + name = "stablehlo.while" + + operand = var_operand_def(HLO_TensorOrPerAxisQuantizedTensorOrToken) + + res = var_result_def(HLO_TensorOrPerAxisQuantizedTensorOrToken) + + cond = region_def("single_block") + + body = region_def("single_block") + + traits = traits_def( + RecursiveMemoryEffect(), + RecursivelySpeculatable(), + SingleBlockImplicitTerminator(ReturnOp), + # TODO: InferTypeOpInterface + # TODO: OpAsmOpInterface + ) + + +@irdl_op_definition +class OptimizationBarrierOp(IRDLOperation): + """ + Ensures that the operations that produce the `operand` are executed before any + operations that depend on the `result` and prevents compiler transformations + from moving operations across the barrier. Other than that, the operation is + an identity, i.e. `result` = `operand`. 
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#optimization_barrier
+
+    Example:
+    ```mlir
+    %result0, %result1 = stablehlo.optimization_barrier %operand0, %operand1 : tensor<f32>, tensor<f32>
+    ```
+    """
+
+    name = "stablehlo.optimization_barrier"
+
+    operand = var_operand_def(HLO_TensorOrToken)
+
+    res = var_result_def(HLO_TensorOrToken)
+
+    traits = traits_def(
+        Pure(),
+        # TODO: HLO_PairwiseSameOperandAndResultType
+        # TODO: InferTypeOpInterface
+    )
+
+    # TODO: Add custom assembly format
+    # assembly_format = """
+    # attr-dict ($operand^ `:` custom(type($operand), type($result))):(`(` `)`)?
+    # """
diff --git a/src/xdsl_jax/dialects/stablehlo/data_movement.py b/src/xdsl_jax/dialects/stablehlo/data_movement.py
new file mode 100644
index 0000000..c8437ca
--- /dev/null
+++ b/src/xdsl_jax/dialects/stablehlo/data_movement.py
@@ -0,0 +1,451 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=too-few-public-methods
+# pyright: reportUnknownArgumentType=false, reportUnknownVariableType=false
+
+"""
+Data movement operations for the StableHLO dialect.
+"""
+
+from xdsl.dialects.builtin import (
+    ArrayAttr,
+    BoolAttr,
+    DenseArrayBase,
+    IntAttrConstraint,
+    IntegerAttr,
+    TensorType,
+    i64,
+)
+from xdsl.irdl import (
+    AnyInt,
+    EqIntConstraint,
+    IRDLOperation,
+    RangeLengthConstraint,
+    RangeOf,
+    irdl_op_definition,
+    operand_def,
+    opt_prop_def,
+    prop_def,
+    region_def,
+    result_def,
+    traits_def,
+    var_operand_def,
+    var_result_def,
+)
+from xdsl.irdl.attributes import eq
+from xdsl.irdl.constraints import AtLeast
+from xdsl.irdl.operations import SameVariadicOperandSize
+from xdsl.traits import (
+    ConditionallySpeculatable,
+    NoMemoryEffect,
+    Pure,
+    RecursiveMemoryEffect,
+)
+from xdsl.utils.exceptions import VerifyException
+from xdsl.utils.type import get_element_type_or_self
+
+from xdsl_jax.xdsl_extras import (
+    AllMatchSameOperatorTrait,
+    SameOperandsAndResultElementType,
+)
+
+from .attributes import GatherDimensionNumbers, ScatterDimensionNumbers
+from .types import (
+    HLO_AnyIntegerOrIndexTensor,
+    HLO_AnyTensor,
+    HLO_Int,
+    HLO_IntTensor,
+    HLO_Tensor,
+)
+
+
+@irdl_op_definition
+class BroadcastInDimOp(IRDLOperation):
+    """
+    Expands the dimensions and/or rank of an input tensor by duplicating the
+    data in the ``operand`` tensor and produces a ``result`` tensor.
+ + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#broadcast_in_dim + + Example: + ```mlir + %result = stablehlo.broadcast_in_dim %operand, dims = [2, 1] : (tensor<1x3xi32>) -> tensor<2x3x2xi32> + ``` + """ + + name = "stablehlo.broadcast_in_dim" + operand = operand_def(HLO_AnyTensor) + broadcast_dimensions = prop_def(DenseArrayBase.constr(i64)) + result = result_def(HLO_AnyTensor) + + assembly_format = """ + $operand `,` `dims` `=` $broadcast_dimensions + attr-dict `:` functional-type(operands, results) + """ + + traits = traits_def( + NoMemoryEffect(), + # TODO: HLO_SpeculatableIfAllInputsStatic, + # TODO: HLO_CompatibleOperandsAndResultElementType, + ) + + def verify_(self) -> None: + """Verify non-quantized broadcast_in_dim constraints.""" + o_type = self.operand_types[0] + r_type = self.result_types[0] + + # These are constrained to tensors by the op definition + assert isinstance(o_type, TensorType) + assert isinstance(r_type, TensorType) + + # broadcast_in_dim_c2: broadcast_dimensions size == operand rank + dims = tuple(self.broadcast_dimensions.get_values()) + operand_rank = o_type.get_num_dims() + if len(dims) != operand_rank: + raise VerifyException( + "broadcast_dimensions size (" + f"{len(dims)}" + ") does not match operand rank (" + f"{operand_rank}" + ")" + ) + + # broadcast_in_dim_c4: broadcast_dimensions should not have duplicates + if len(set(dims)) != len(dims): + raise VerifyException("broadcast_dimensions should not have duplicates") + + # Result rank and per-dimension checks + result_rank = r_type.get_num_dims() + o_shape = o_type.get_shape() + r_shape = r_type.get_shape() + + for i, dim_index in enumerate(dims): + # broadcast_in_dim_c3: each dim index in bounds of result rank + if dim_index < 0 or dim_index >= result_rank: + raise VerifyException( + "broadcast_dimensions contains invalid value " + f"{dim_index} for result with rank {result_rank}" + ) + + # If operand dim is static, enforce broadcast_in_dim_c5 + if o_shape[i] != -1: + dim_size = o_shape[i] + result_dim_size = r_shape[dim_index] + if dim_size not in (1, result_dim_size): + raise VerifyException( + "size of operand dimension " + f"{i} ({dim_size}) is not equal to 1 or size of result dimension " + f"{dim_index} ({result_dim_size})" + ) + + +@irdl_op_definition +class ConcatenateOp(IRDLOperation): + """ + Concatenates a variadic number of tensors in ``inputs`` along ``dimension`` + dimension in the same order as the given arguments and produces a ``result`` + tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#concatenate + + Example: + ```mlir + %result = stablehlo.concatenate %input0, %input1, dim = 0 : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64> + ``` + """ + + name = "stablehlo.concatenate" + + inputs = var_operand_def(HLO_Tensor) + result = result_def(HLO_Tensor) + dimension = prop_def(IntegerAttr.constr(type=eq(i64), value=AtLeast(0))) + + traits = traits_def( + NoMemoryEffect(), + ConditionallySpeculatable(), + SameOperandsAndResultElementType(), + # InferTypeOpInterface(), + ) + + # TODO: Implement CustomDirective + # assembly_format = """ + # custom($inputs) `dim` `=` $dimension attr-dict `:` functional-type(operands, results) + # """ + + +@irdl_op_definition +class DynamicSliceOp(IRDLOperation): + """ + Extracts a slice from the ``operand`` using dynamically-computed starting + indices and produces a ``result`` tensor. 
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_slice
+
+    Example:
+    ```mlir
+    %result = stablehlo.dynamic_slice %operand, %start_indices0, %start_indices1, sizes = [2, 2]
+      : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
+    ```
+    """
+
+    name = "stablehlo.dynamic_slice"
+    operand = operand_def(HLO_Tensor)
+    start_indices = var_operand_def(
+        TensorType.constr(
+            element_type=HLO_Int,
+            shape=ArrayAttr.constr(
+                RangeLengthConstraint(
+                    constraint=RangeOf(IntAttrConstraint(AnyInt())),
+                    length=EqIntConstraint(0),
+                )
+            ),
+        )
+    )
+    slice_sizes = prop_def(DenseArrayBase.constr(i64))
+    result = result_def(HLO_Tensor)
+
+    # TODO: Implement CustomDirective
+    # assembly_format = """
+    # $operand `,` custom($start_indices)
+    # `sizes` `=` $slice_sizes attr-dict `:` functional-type(operands, results)
+    # """
+
+    traits = traits_def(
+        Pure(),
+        AllMatchSameOperatorTrait(
+            ("operand", "result"),
+            lambda x: get_element_type_or_self(x.type),
+            "element type",
+        ),
+        # TODO: InferTensorType(),
+    )
+
+
+@irdl_op_definition
+class GatherOp(IRDLOperation):
+    """
+    Gathers slices from the ``operand`` tensor at offsets specified in
+    ``start_indices`` and produces a ``result`` tensor.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#gather
+
+    Example:
+    ```mlir
+    %result = "stablehlo.gather"(%operand, %start_indices) {
+      dimension_numbers = #stablehlo.gather<
+        offset_dims = [3, 4],
+        collapsed_slice_dims = [1],
+        operand_batching_dims = [0],
+        start_indices_batching_dims = [1],
+        start_index_map = [2, 1],
+        index_vector_dim = 3>,
+      slice_sizes = array<i64: 1, 1, 2, 2>,
+      indices_are_sorted = false
+    } : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi64>
+    ```
+    """
+
+    name = "stablehlo.gather"
+    operand = operand_def(HLO_Tensor)
+    start_indices = operand_def(HLO_IntTensor)
+    dimension_numbers = prop_def(GatherDimensionNumbers)
+    slice_sizes = prop_def(DenseArrayBase.constr(i64))
+    indices_are_sorted = opt_prop_def(BoolAttr, default_value=BoolAttr.from_bool(False))
+    result = result_def(HLO_Tensor)
+
+    traits = traits_def(
+        NoMemoryEffect(),
+        ConditionallySpeculatable(),
+        AllMatchSameOperatorTrait(
+            ("operand", "result"),
+            lambda x: get_element_type_or_self(x.type),
+            "element type",
+        ),
+        # TODO: InferTensorTypeWithReify(),
+    )
+
+    # TODO: Implement CustomDirective
+    # assembly_format = """
+    # custom($inputs) `dim` `=` $dimension attr-dict `:` functional-type(operands, results)
+    # """
+
+
+@irdl_op_definition
+class ReshapeOp(IRDLOperation):
+    """
+    Performs reshape of the ``operand`` tensor to a ``result`` tensor.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reshape
+
+    Example:
+    ```mlir
+    %result = stablehlo.reshape %operand : (tensor<2xf32>) -> tensor<1x2xf32>
+    ```
+    """
+
+    name = "stablehlo.reshape"
+    operand = operand_def(HLO_AnyTensor)
+    result = result_def(HLO_AnyTensor)
+
+    assembly_format = """
+    operands attr-dict `:` functional-type(operands, results)
+    """
+
+    traits = traits_def(
+        NoMemoryEffect(),
+        ConditionallySpeculatable(),
+        # TODO: HLO_CompatibleOperandsAndResultElementType,
+    )
+
+    def verify_(self) -> None:
+        """Verify that the operand and result have the same number of elements."""
+        o_type = self.operand_types[0]
+        r_type = self.result_types[0]
+
+        # These are constrained to tensors by the op definition
+        assert isinstance(o_type, TensorType)
+        assert isinstance(r_type, TensorType)
+
+        # If o_type or r_type is dynamically shaped there is nothing to verify.
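+        # (xDSL encodes dynamic dimensions as -1, so element counts are only
+        # meaningful when both shapes are fully static.)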
+        if not o_type.has_static_shape() or not r_type.has_static_shape():
+            return
+
+        # If both types are statically shaped, the number of operand elements
+        # must match the number of result elements.
+        num_operand_elements = 1
+        for dim in o_type.get_shape():
+            num_operand_elements *= dim
+
+        num_result_elements = 1
+        for dim in r_type.get_shape():
+            num_result_elements *= dim
+
+        if num_result_elements != num_operand_elements:
+            raise VerifyException(
+                "number of output elements ("
+                f"{num_result_elements}"
+                ") doesn't match expected number of elements ("
+                f"{num_operand_elements}"
+                ")"
+            )
+
+
+@irdl_op_definition
+class ScatterOp(IRDLOperation):
+    """
+    Produces ``results`` tensors which are equal to ``inputs`` tensors except that
+    several slices specified by ``scatter_indices`` are updated with the values in
+    ``updates`` using ``update_computation``.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter
+
+    Example:
+    ```mlir
+    %result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
+      ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
+        %0 = stablehlo.add %arg0, %arg1 : tensor<i64>
+        stablehlo.return %0 : tensor<i64>
+    }) {
+      scatter_dimension_numbers = #stablehlo.scatter<
+        update_window_dims = [3, 4],
+        inserted_window_dims = [1],
+        input_batching_dims = [0],
+        scatter_indices_batching_dims = [1],
+        scatter_dims_to_operand_dims = [2, 1],
+        index_vector_dim = 3>,
+      indices_are_sorted = false,
+      unique_indices = false
+    } : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
+    ```
+    """
+
+    name = "stablehlo.scatter"
+    inputs = var_operand_def(HLO_Tensor)
+    scatter_indices = operand_def(HLO_AnyIntegerOrIndexTensor)
+    updates = var_operand_def(HLO_Tensor)
+    scatter_dimension_numbers = prop_def(ScatterDimensionNumbers)
+    indices_are_sorted = opt_prop_def(BoolAttr, default_value=BoolAttr.from_bool(False))
+    unique_indices = opt_prop_def(BoolAttr, default_value=BoolAttr.from_bool(False))
+    result = var_result_def(HLO_Tensor)
+    update_computation = region_def("single_block")
+    # TODO: The MLIR implementation doesn't have the SingleBlockImplicitTerminator
+    # trait. However, the region is checked to have a terminator in the verifier,
+    # which does not specifically require the terminator to be stablehlo.return.
+
+    traits = traits_def(
+        RecursiveMemoryEffect(),
+        ConditionallySpeculatable(),
+        # TODO: InferTypeOpInterface(),
+    )
+
+    irdl_options = [SameVariadicOperandSize()]
+
+    # TODO: MLIR has a custom verifier for the scatter operation.
+
+
+@irdl_op_definition
+class SliceOp(IRDLOperation):
+    """
+    Extracts a slice from the ``operand`` using statically-computed starting
+    indices and produces a ``result`` tensor.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#slice
+
+    Example:
+    ```mlir
+    %result = stablehlo.slice %operand [1:3, 4:8:2]
+      : (tensor<3x8xi64>) -> tensor<2x2xi64>
+
+    // Same in generic form: the `1:3` above is mapped to the first entry in
+    // `start_indices` and `limit_indices`, while `strides` is implicitly 1.
+    // The `4:8:2` above is parsed into the second entry of `start_indices`,
+    // `limit_indices` and `strides` respectively.
+ %result = "stablehlo.slice" (%operand) { + start_indices = array, + limit_indices = array, + strides = array + } : (tensor<3x8xi64>) -> tensor<2x2xi64> + ``` + """ + + name = "stablehlo.slice" + + operand = operand_def(HLO_Tensor) + start_indices = prop_def(DenseArrayBase.constr(i64)) + limit_indices = prop_def(DenseArrayBase.constr(i64)) + strides = prop_def(DenseArrayBase.constr(i64)) + result = result_def(HLO_Tensor) + + # TODO: Implement CustomDirective + # assembly_format = """ + # $operand custom($start_indices, $limit_indices, $strides) + # attr-dict `:` functional-type(operands, results) + # """ + + traits = traits_def( + NoMemoryEffect(), + ConditionallySpeculatable(), + AllMatchSameOperatorTrait( + ("start_indices", "limit_indices", "strides"), len, "size" + ), + SameOperandsAndResultElementType(), + ) diff --git a/src/xdsl_jax/dialects/stablehlo/dialect.py b/src/xdsl_jax/dialects/stablehlo/dialect.py new file mode 100644 index 0000000..ea42ec0 --- /dev/null +++ b/src/xdsl_jax/dialects/stablehlo/dialect.py @@ -0,0 +1,206 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyright: reportUnknownParameterType=false, reportMissingParameterType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnknownMemberType=false + +""" +Extended StableHLO dialect that dynamically includes all upstream operations +plus custom operations for PennyLane's compiler infrastructure. + +This module automatically imports all operations and attributes from the upstream +xdsl_jax.dialects.stablehlo and adds custom ones without needing to hardcode +the upstream operation list. 
+""" + +from xdsl.ir import Dialect + +import xdsl_jax.dialects._stablehlo_upstream as xstablehlo + +from .attributes import ( + CustomCallApiVersionAttr, + GatherDimensionNumbers, + OutputOperandAlias, + ResultAccuracyModeAttr, + ScatterDimensionNumbers, +) +from .control_flow import ( + IfOp, + OptimizationBarrierOp, + WhileOp, +) +from .data_movement import ( + BroadcastInDimOp, + ConcatenateOp, + DynamicSliceOp, + GatherOp, + ReshapeOp, + ScatterOp, + SliceOp, +) +from .dynamism import ( + DynamicBroadcastInDimOp, +) +from .elementwise_binary import ( + ComplexOp, + DivideOp, + MaximumOp, + MinimumOp, + PowerOp, + RemainderOp, +) +from .elementwise_other import ( + ClampOp, + CompareOp, + MapOp, + ReducePrecisionOp, + SelectOp, +) + +# Import all elementwise operations from organized files +from .elementwise_unary import ( + ConvertOp, + CosineOp, + ExponentialMinusOneOp, + ExponentialOp, + FloorOp, + ImagOp, + IsFiniteOp, + LogisticOp, + LogOp, + LogPlusOneOp, + NegateOp, + RealOp, + RoundNearestAfzOp, + RoundNearestEvenOp, + RsqrtOp, + SignOp, + SineOp, + SqrtOp, + TanhOp, + TanOp, +) +from .extensibility import ( + CustomCallOp, +) +from .reduction import ( + ReduceOp, +) +from .types import UniformQuantizedPerAxisType, UniformQuantizedType + +# Operations to add to the dialect +OPERATIONS = [ + ClampOp, + CompareOp, + ComplexOp, + ConvertOp, + CosineOp, + DivideOp, + ExponentialMinusOneOp, + ExponentialOp, + FloorOp, + ImagOp, + IsFiniteOp, + LogOp, + LogPlusOneOp, + LogisticOp, + MapOp, + MaximumOp, + MinimumOp, + NegateOp, + PowerOp, + RealOp, + ReducePrecisionOp, + RemainderOp, + RoundNearestAfzOp, + RoundNearestEvenOp, + RsqrtOp, + SelectOp, + SignOp, + SineOp, + SqrtOp, + TanOp, + TanhOp, + # Data movement operations + BroadcastInDimOp, + ConcatenateOp, + DynamicSliceOp, + GatherOp, + ReshapeOp, + ScatterOp, + SliceOp, + # Control flow operations + IfOp, + WhileOp, + OptimizationBarrierOp, + # Dynamism operations + DynamicBroadcastInDimOp, + # Reduction operations + ReduceOp, + # Extensibility operations + CustomCallOp, +] + +# Attributes to add to the dialect +ATTRIBUTES = [ + CustomCallApiVersionAttr, + GatherDimensionNumbers, + ResultAccuracyModeAttr, + OutputOperandAlias, + ScatterDimensionNumbers, + UniformQuantizedPerAxisType, + UniformQuantizedType, +] + +# Operations/attributes from upstream that should be deleted/replaced in the local version +UPSTREAM_OPERATIONS_TO_DELETE = [] +UPSTREAM_ATTRIBUTES_TO_DELETE = [] + + +def filter_and_extend_upstream(upstream_list, to_delete, to_add): + """Filter out operations/attributes from upstream list and add new ones. 
+ + Args: + upstream_list: List of operations/attributes to filter + to_delete: List of operations/attributes to remove + to_add: List of operations/attributes to add + + Returns: + Modified list of operations/attributes + """ + filtered_ops = list(upstream_list) + + # Remove operations that should be deleted + for op_to_delete in to_delete: + if op_to_delete in filtered_ops: + filtered_ops.remove(op_to_delete) + + # Add new operations + filtered_ops.extend(to_add) + + return filtered_ops + + +all_operations = filter_and_extend_upstream( + xstablehlo.StableHLO.operations, UPSTREAM_OPERATIONS_TO_DELETE, OPERATIONS +) +all_attributes = filter_and_extend_upstream( + xstablehlo.StableHLO.attributes, UPSTREAM_ATTRIBUTES_TO_DELETE, ATTRIBUTES +) + +# Create the extended StableHLO dialect by dynamically getting upstream components +StableHLO = Dialect( + "stablehlo", + all_operations, + all_attributes, +) diff --git a/src/xdsl_jax/dialects/stablehlo/dynamism.py b/src/xdsl_jax/dialects/stablehlo/dynamism.py new file mode 100644 index 0000000..08c41ea --- /dev/null +++ b/src/xdsl_jax/dialects/stablehlo/dynamism.py @@ -0,0 +1,202 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=too-few-public-methods +# pyright: reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnknownMemberType=false + +""" +Dynamism operations for the StableHLO dialect. +""" + +from xdsl.dialects.builtin import ( + ArrayAttr, + DenseArrayBase, + IntAttrConstraint, + TensorType, + i64, +) +from xdsl.irdl import ( + AnyInt, + EqIntConstraint, + IRDLOperation, + ParsePropInAttrDict, + RangeLengthConstraint, + RangeOf, + irdl_op_definition, + operand_def, + opt_prop_def, + prop_def, + result_def, + traits_def, +) +from xdsl.traits import ( + ConditionallySpeculatable, + NoMemoryEffect, +) +from xdsl.utils.exceptions import VerifyException + +from .types import HLO_AnyTensor, HLO_DimensionValue + + +@irdl_op_definition +class DynamicBroadcastInDimOp(IRDLOperation): + """ + This operation is functionally identical to + [broadcast_in_dim](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#broadcast_in_dim) + op, but the result shape is specified dynamically via ``output_dimensions``. + + It also accepts optional attributes to express static knowledge about the + expanding behavior of dimensions. If not specified, all dimensions are + assumed to be possibly expanding. The sets of dimensions that are known to + be expanding and the set of dimensions that are known to be non-expanding + must be disjoint and they must be a subset of the operand's dimensions. 
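+    (Both requirements are enforced by ``verify_`` below.)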
+
+    See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_broadcast_in_dim
+
+    Example:
+    ```mlir
+    %operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
+    %output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
+    %result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
+      broadcast_dimensions = array<i64: 2, 1>,
+      known_expanding_dimensions = array<i64: 0>,
+      known_nonexpanding_dimensions = array<i64: 1>
+    } : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
+    ```
+    """
+
+    name = "stablehlo.dynamic_broadcast_in_dim"
+
+    operand = operand_def(HLO_AnyTensor)
+    output_dimensions = operand_def(
+        TensorType.constr(
+            element_type=HLO_DimensionValue,
+            shape=ArrayAttr.constr(
+                RangeLengthConstraint(
+                    constraint=RangeOf(IntAttrConstraint(AnyInt())),
+                    length=EqIntConstraint(1),
+                )
+            ),
+        )
+    )
+    broadcast_dimensions = prop_def(DenseArrayBase.constr(i64))
+    known_expanding_dimensions = opt_prop_def(DenseArrayBase.constr(i64))
+    known_nonexpanding_dimensions = opt_prop_def(DenseArrayBase.constr(i64))
+    result = result_def(HLO_AnyTensor)
+
+    assembly_format = (
+        "$operand `,` $output_dimensions `,` `dims` `=` $broadcast_dimensions "
+        "attr-dict `:` functional-type(operands, results)"
+    )
+
+    traits = traits_def(
+        ConditionallySpeculatable(),
+        NoMemoryEffect(),
+        # TODO: InferShapedTypeOpInterface(),
+    )
+
+    irdl_options = [ParsePropInAttrDict()]
+
+    # pylint: disable=too-many-branches
+    def verify_(self):
+        """Verify the operation."""
+        # Operand and result must be tensors
+        operand_ty = self.operand_types[0]
+        result_ty = self.result_types[0]
+        assert isinstance(operand_ty, TensorType)
+        assert isinstance(result_ty, TensorType)
+
+        # dynamic_broadcast_in_dim_c2: broadcast_dimensions size == operand rank
+        bcast_dims = tuple(self.broadcast_dimensions.get_values())  # pylint: disable=no-member
+        operand_rank = operand_ty.get_num_dims()
+        if len(bcast_dims) != operand_rank:
+            raise VerifyException(
+                "broadcast_dimensions size ("
+                f"{len(bcast_dims)}"
+                ") does not match operand rank ("
+                f"{operand_rank}"
+                ")"
+            )
+
+        # dynamic_broadcast_in_dim_c3: result rank >= operand rank
+        result_rank = result_ty.get_num_dims()
+        if result_rank < operand_rank:
+            raise VerifyException(
+                "result rank ("
+                f"{result_rank}"
+                ") is less than operand rank ("
+                f"{operand_rank}"
+                ")"
+            )
+
+        # dynamic_broadcast_in_dim_c4: broadcast_dimensions should not have duplicates
+        if len(set(bcast_dims)) != len(bcast_dims):
+            raise VerifyException("broadcast_dimensions should not have duplicates")
+
+        # dynamic_broadcast_in_dim_c5: bounds and per-dimension compatibility
+        operand_shape = operand_ty.get_shape()
+        result_shape = result_ty.get_shape()
+        for i, dim_index in enumerate(bcast_dims):
+            if dim_index < 0 or dim_index >= result_rank:
+                raise VerifyException(
+                    "broadcast_dimensions contains invalid value "
+                    f"{dim_index} for result with rank {result_rank}"
+                )
+            op_dim = operand_shape[i]
+            res_dim = result_shape[dim_index]
+            # If operand dim is static and not size-1, require compatibility with result dim
+            if op_dim not in (-1, 1):
+                if res_dim not in (-1, op_dim):
+                    raise VerifyException(
+                        "size of operand dimension "
+                        f"{i} ({op_dim}) is not compatible with size of result dimension "
+                        f"{dim_index} ({res_dim})"
+                    )
+
+        # dynamic_broadcast_in_dim_c7: output_dimensions shape compatible with result rank
+        out_dims_ty = self.output_dimensions.type  # pylint: disable=no-member
+        assert isinstance(out_dims_ty, TensorType)
+        # Must be rank-1 tensor (enforced by
type constraint), and length must match result rank when statically known + out_shape = out_dims_ty.get_shape() + if len(out_shape) != 1: + raise VerifyException("output_dimensions must be a 1D tensor") + if out_shape[0] != -1 and out_shape[0] != result_rank: + raise VerifyException( + "length of output_dimensions (" + f"{out_shape[0]}" + ") is not compatible with result rank (" + f"{result_rank}" + ")" + ) + + # dynamic_broadcast_in_dim_c8: no duplicate expansion hints across both lists + hints = [] + if self.known_expanding_dimensions is not None: + hints.extend(self.known_expanding_dimensions.get_values()) # pylint: disable=no-member + if self.known_nonexpanding_dimensions is not None: + hints.extend( + self.known_nonexpanding_dimensions.get_values() # pylint: disable=no-member + ) + if len(set(hints)) != len(hints): + raise VerifyException( + "duplicate expansion hint for at least one operand dimension" + ) + + # dynamic_broadcast_in_dim_c9/c10: each hint must reference a valid operand dimension + for h in set(hints): + if h < 0 or h >= operand_rank: + raise VerifyException( + "hint for expanding dimension " + f"{h} does not refer to a valid operand dimension" + ) diff --git a/src/xdsl_jax/dialects/stablehlo/elementwise_binary.py b/src/xdsl_jax/dialects/stablehlo/elementwise_binary.py new file mode 100644 index 0000000..d28d88d --- /dev/null +++ b/src/xdsl_jax/dialects/stablehlo/elementwise_binary.py @@ -0,0 +1,227 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Binary elementwise operations for the StableHLO dialect. +""" + +# pylint: disable=too-few-public-methods +# pyright: reportUnknownVariableType=false + +import abc +from typing import Generic, TypeVar + +from xdsl.dialects.builtin import ( + AnyTensorType, + ComplexType, + Float32Type, + Float64Type, + TensorType, +) +from xdsl.ir import Attribute, SSAValue +from xdsl.irdl import ( + IRDLOperation, + irdl_op_definition, + operand_def, + result_def, + traits_def, +) +from xdsl.traits import NoMemoryEffect + +from xdsl_jax.xdsl_extras import ( + Elementwise, + SameOperandsAndResultShape, + SameOperandsElementType, +) + +from .types import ( + HLO_ComplexTensor, + HLO_Fp32Or64Tensor, + HLO_IntFpOrComplexOrQuantizedIntTensor, + HLO_Tensor, +) + +# Type aliases +F32Or64Type = Float32Type | Float64Type +F32Or64TensorType = TensorType[F32Or64Type] +ComplexTensorType = TensorType[ComplexType] + +# Generic type variables for templating +T_LHS = TypeVar("T_LHS", bound=AnyTensorType) +T_RHS = TypeVar("T_RHS", bound=AnyTensorType) +T_OUT = TypeVar("T_OUT", bound=AnyTensorType) + + +class ElementwiseBinaryOperation(IRDLOperation, abc.ABC, Generic[T_LHS, T_RHS, T_OUT]): + """ + Templated base class for elementwise binary operations. + + This class provides a flexible template for binary operations that can work + with different tensor types. 
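+
+    Concrete operations pin the operand and result constraints through the type
+    parameters; a minimal sketch mirroring the definitions below (the op name is
+    hypothetical):
+
+    ```python
+    @irdl_op_definition
+    class MyMaxLikeOp(ElementwiseBinaryOperation[HLO_Tensor, HLO_Tensor, HLO_Tensor]):
+        name = "stablehlo.my_max_like"
+    ```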
+
+    For more information about the semantics, see:
+    https://openxla.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
+    """
+
+    lhs = operand_def(T_LHS)
+    rhs = operand_def(T_RHS)
+    result = result_def(T_OUT)
+
+    traits = traits_def(
+        NoMemoryEffect(),
+        SameOperandsAndResultShape(),
+        Elementwise(),
+        # TODO: HLO_SpeculatableIfAllInputsStatic(),
+    )
+
+    # TODO: Implement CustomDirective
+    # assembly_format = """
+    # $lhs `,` $rhs attr-dict
+    # `:` custom(type($lhs), type($rhs), type($result))
+    # """
+
+    def __init__(
+        self, lhs: SSAValue, rhs: SSAValue, result_type: Attribute | None = None
+    ):
+        if result_type is None:
+            result_type = lhs.type
+        super().__init__(operands=(lhs, rhs), result_types=(result_type,))
+
+
+@irdl_op_definition
+class ComplexOp(
+    ElementwiseBinaryOperation[
+        HLO_Fp32Or64Tensor, HLO_Fp32Or64Tensor, HLO_ComplexTensor
+    ]
+):
+    """
+    Performs element-wise conversion to a complex value from a pair of real and
+    imaginary values, `lhs` and `rhs`, and produces a `result` tensor.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#complex
+
+    Example:
+    ```mlir
+    %result = stablehlo.complex %lhs, %rhs : tensor<2xcomplex<f32>>
+    ```
+    """
+
+    name = "stablehlo.complex"
+
+    # assembly_format = """
+    # operands attr-dict
+    # `:` custom(type($lhs), type($rhs), type($result))
+    # """
+
+    traits = traits_def(
+        NoMemoryEffect(),
+        SameOperandsElementType(),
+        SameOperandsAndResultShape(),
+        # TODO: HLO_SpeculatableIfAllInputsStatic(),
+    )
+
+
+@irdl_op_definition
+class DivideOp(
+    ElementwiseBinaryOperation[
+        HLO_IntFpOrComplexOrQuantizedIntTensor,
+        HLO_IntFpOrComplexOrQuantizedIntTensor,
+        HLO_IntFpOrComplexOrQuantizedIntTensor,
+    ]
+):
+    """
+    Performs element-wise division of dividend `lhs` and divisor `rhs` tensors
+    and produces a `result` tensor.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#divide
+
+    Example:
+    ```mlir
+    %result = stablehlo.divide %lhs, %rhs : tensor<4xf32>
+    ```
+    """
+
+    name = "stablehlo.divide"
+
+
+@irdl_op_definition
+class MaximumOp(ElementwiseBinaryOperation[HLO_Tensor, HLO_Tensor, HLO_Tensor]):
+    """
+    Performs element-wise max operation on tensors `lhs` and `rhs` and produces
+    a `result` tensor.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#maximum
+
+    Example:
+    ```mlir
+    %result = stablehlo.maximum %lhs, %rhs : tensor<4xf32>
+    ```
+    """
+
+    name = "stablehlo.maximum"
+
+
+@irdl_op_definition
+class MinimumOp(ElementwiseBinaryOperation[HLO_Tensor, HLO_Tensor, HLO_Tensor]):
+    """
+    Performs element-wise min operation on tensors `lhs` and `rhs` and produces a
+    `result` tensor.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#minimum
+
+    Example:
+    ```mlir
+    %result = stablehlo.minimum %lhs, %rhs : tensor<4xf32>
+    ```
+    """
+
+    name = "stablehlo.minimum"
+
+
+@irdl_op_definition
+class PowerOp(ElementwiseBinaryOperation[HLO_Tensor, HLO_Tensor, HLO_Tensor]):
+    """
+    Performs element-wise exponentiation of `lhs` tensor by `rhs` tensor and
+    produces a `result` tensor.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#power
+
+    Example:
+    ```mlir
+    %result = stablehlo.power %lhs, %rhs : tensor<6xf64>
+    ```
+    """
+
+    name = "stablehlo.power"
+
+
+@irdl_op_definition
+class RemainderOp(ElementwiseBinaryOperation[HLO_Tensor, HLO_Tensor, HLO_Tensor]):
+    """
+    Performs element-wise remainder of dividend `lhs` and divisor `rhs` tensors
+    and produces a `result` tensor.
+ + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#remainder + + Example: + ```mlir + %result = stablehlo.remainder %lhs, %rhs : tensor<4xi64> + ``` + """ + + name = "stablehlo.remainder" diff --git a/src/xdsl_jax/dialects/stablehlo/elementwise_other.py b/src/xdsl_jax/dialects/stablehlo/elementwise_other.py new file mode 100644 index 0000000..c4b8f2f --- /dev/null +++ b/src/xdsl_jax/dialects/stablehlo/elementwise_other.py @@ -0,0 +1,219 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Other elementwise operations for the StableHLO dialect. +""" + +# pylint: disable=too-few-public-methods +# pyright: reportUnknownArgumentType=false, reportUnknownVariableType=false + +from xdsl.dialects.builtin import ( + AnyFloat, + DenseArrayBase, + IntegerAttr, + TensorType, + i32, + i64, +) +from xdsl.irdl import ( + IRDLOperation, + attr_def, + irdl_op_definition, + operand_def, + opt_attr_def, + result_def, + traits_def, + var_operand_def, + var_region_def, +) +from xdsl.irdl.attributes import eq +from xdsl.irdl.constraints import AtLeast +from xdsl.traits import ( + NoMemoryEffect, + RecursiveMemoryEffect, + SingleBlockImplicitTerminator, +) + +import xdsl_jax.dialects._stablehlo_upstream as xstablehlo +from xdsl_jax.xdsl_extras import Elementwise, SameOperandsAndResultShape + +from .types import HLO_FpOrQuantizedIntTensor, HLO_PredTensor, HLO_Tensor + +# Type aliases +FloatTensorType = TensorType[AnyFloat] + + +@irdl_op_definition +class ClampOp(IRDLOperation): + """Element-wise clamp with min and max bounds. + + See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#clamp + """ + + name = "stablehlo.clamp" + + min = operand_def(HLO_Tensor) + operand = operand_def(HLO_Tensor) + max = operand_def(HLO_Tensor) + result = result_def(HLO_Tensor) + + # TODO: Implement CustomDirective + # assembly_format = """ + # $min `,` $operand `,` $max attr-dict + # `:` custom(type($min), type($operand), type($max), type($result)) + # """ + + traits = traits_def( + NoMemoryEffect(), + # TODO: HLO_SpeculatableIfAllInputsStatic(), + # TODO: HLO_CompatibleOperandsAndResultElementType(), + # TODO: HLO_BroadcastingElementwise(), + # TODO: InferTensorType(), + # TODO: InferShapedTypeOpInterface(), + ) + + +@irdl_op_definition +class CompareOp(IRDLOperation): + """Element-wise compare with direction and type attributes.""" + + name = "stablehlo.compare" + + assembly_format = """ + $comparison_direction `,` $lhs `,` $rhs (`,` $comparison_type^)? 
attr-dict `:` functional-type(operands, results)
+    """
+
+    lhs = operand_def(HLO_Tensor)
+    rhs = operand_def(HLO_Tensor)
+    result = result_def(HLO_PredTensor)
+    comparison_direction = attr_def(xstablehlo.ComparisonDirectionAttr)
+    comparison_type = opt_attr_def(xstablehlo.ComparisonTypeAttr)
+
+    traits = traits_def(
+        NoMemoryEffect(),
+        Elementwise(),
+        SameOperandsAndResultShape(),
+        # TODO: HLO_SpeculatableIfAllInputsStatic(),
+        # TODO: HLO_CompatibleOperandsElementType(),
+        # TODO: InferTensorTypeWithReify(),
+    )
+
+
+@irdl_op_definition
+class MapOp(IRDLOperation):
+    """
+    Applies a map function `computation` to `inputs` along the `dimensions` and
+    produces a `result` tensor.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#map
+
+    Example:
+    ```mlir
+    %result = "stablehlo.map"(%input0, %input1) ({
+      ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
+        %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
+        stablehlo.return %0 : tensor<i64>
+    }) {
+      dimensions = array<i64: 0, 1>
+    } : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
+    ```
+    """
+
+    name = "stablehlo.map"
+
+    inputs = var_operand_def(HLO_Tensor)
+    result = result_def(HLO_Tensor)
+    dimensions = attr_def(DenseArrayBase.constr(i64))
+    computation = var_region_def("single_block")
+
+    traits = traits_def(
+        RecursiveMemoryEffect(),
+        SameOperandsAndResultShape(),
+        SingleBlockImplicitTerminator(xstablehlo.ReturnOp),
+        # TODO: HLO_RecursivelySpeculatableIfAllInputsStatic(),
+        # TODO: InferTypeOpInterface
+        # TODO: InferShapedTypeOpInterface(),
+    )
+
+
+@irdl_op_definition
+class ReducePrecisionOp(IRDLOperation):
+    """
+    Performs element-wise conversion of `operand` to another floating-point type
+    that uses `exponent_bits` and `mantissa_bits` and back to the original
+    floating-point type and produces an `output` tensor.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_precision
+
+    Example:
+    ```mlir
+    %output = stablehlo.reduce_precision %operand, format = e5m10 : tensor<6xf64>
+    ```
+    """
+
+    name = "stablehlo.reduce_precision"
+
+    # TODO: Implement CustomDirective
+    # assembly_format = """
+    #     $operand `,` `format` `=` custom($exponent_bits, $mantissa_bits)
+    #     attr-dict `:` custom(type($operand), type($output))
+    # """
+
+    operand = operand_def(HLO_FpOrQuantizedIntTensor)
+    result = result_def(HLO_FpOrQuantizedIntTensor)
+
+    exponent_bits = attr_def(IntegerAttr.constr(type=eq(i32), value=AtLeast(1)))
+    mantissa_bits = attr_def(IntegerAttr.constr(type=eq(i32), value=AtLeast(0)))
+
+    traits = traits_def(
+        NoMemoryEffect(),
+        Elementwise(),
+        # TODO: HLO_CompatibleOperandsAndResultType(),
+        # TODO: HLO_SpeculatableIfStaticDimInOutputIsStaticInInput(),
+    )
+
+
+@irdl_op_definition
+class SelectOp(IRDLOperation):
+    """
+    Produces a `result` tensor where each element is selected from `on_true` or
+    `on_false` tensor based on the value of the corresponding element of `pred`.
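+
+    Per the spec, ``pred`` may also be a rank-0 tensor, in which case the
+    single predicate value selects between ``on_true`` and ``on_false`` as a
+    whole, e.g.:
+
+    ```mlir
+    %result = stablehlo.select %pred, %on_true, %on_false : tensor<i1>, tensor<2x2xi32>
+    ```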
+ + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#select + + Example: + ```mlir + %result = stablehlo.select %pred, %on_true, %on_false : tensor<2x2xi1>, tensor<2x2xi32> + ``` + """ + + name = "stablehlo.select" + + # assembly_format = """ + # operands attr-dict `:` + # custom(type($pred), type($on_true), type($on_false), type($result)) + # """ + + pred = operand_def(HLO_PredTensor) + on_true = operand_def(HLO_Tensor) + on_false = operand_def(HLO_Tensor) + result = result_def(HLO_Tensor) + + traits = traits_def( + NoMemoryEffect(), + ) diff --git a/src/xdsl_jax/dialects/stablehlo/elementwise_unary.py b/src/xdsl_jax/dialects/stablehlo/elementwise_unary.py new file mode 100644 index 0000000..e1c1e2e --- /dev/null +++ b/src/xdsl_jax/dialects/stablehlo/elementwise_unary.py @@ -0,0 +1,581 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=too-few-public-methods +# pyright: reportUnknownVariableType=false + +""" +Unary elementwise operations for the StableHLO dialect. +""" + +import abc +from typing import Generic, TypeVar + +from xdsl.dialects.builtin import ( + I1, + AnyFloat, + AnyTensorType, + ComplexType, + TensorType, +) +from xdsl.ir import Attribute, SSAValue +from xdsl.irdl import ( + IRDLOperation, + irdl_op_definition, + operand_def, + opt_attr_def, + result_def, + traits_def, +) +from xdsl.traits import NoMemoryEffect + +from xdsl_jax.xdsl_extras import Elementwise, SameOperandsAndResultShape + +from .attributes import ResultAccuracyMode, ResultAccuracyModeAttr +from .types import ( + HLO_FloatTensor, + HLO_FpComplexOrQuantizedIntTensor, + HLO_FpOrComplexTensor, + HLO_FpOrQuantizedIntTensor, + HLO_IntFpOrComplexOrQuantizedIntTensor, + HLO_NonQuantizedTensor, + HLO_PredTensor, + HLO_SIntFpComplexOrQuantizedIntTensor, +) + +# Type aliases +I1TensorType = TensorType[I1] +FloatTensorType = TensorType[AnyFloat] +FloatOrComplexType = AnyFloat | ComplexType +FloatOrComplexTensorType = TensorType[FloatOrComplexType] +ComplexTensorType = TensorType[ComplexType] + +# Generic type variables for templating +T_IN = TypeVar("T_IN", bound=AnyTensorType) +T_OUT = TypeVar("T_OUT", bound=AnyTensorType) + + +class ElementwiseUnaryOperation(IRDLOperation, abc.ABC, Generic[T_IN, T_OUT]): + """ + Templated base class for elementwise unary operations. + + This class provides a flexible template for unary operations that can work + with different tensor types. 
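+
+    A concrete operation only needs to pick the operand/result constraints and
+    a name. A minimal sketch (hypothetical op, shown for illustration only):
+
+    ```python
+    @irdl_op_definition
+    class MyUnaryOp(ElementwiseUnaryOperation[HLO_FloatTensor, HLO_FloatTensor]):
+        name = "stablehlo.my_unary"  # hypothetical op name
+    ```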
+
+    For more information about the semantics, see:
+    https://openxla.org/xla/operation_semantics#element-wise_unary_functions
+    """
+
+    operand = operand_def(T_IN)
+    result = result_def(T_OUT)
+
+    # TODO: Implement CustomDirective
+    # assembly_format = """
+    #     $operand attr-dict `:` custom(type($operand), type($result))
+    # """
+
+    traits = traits_def(
+        NoMemoryEffect(),
+        SameOperandsAndResultShape(),
+        Elementwise(),
+        # TODO: InferShapedTypeOpInterface(),
+        # TODO: HLO_SpeculatableIfStaticDimInOutputIsStaticInInput(),
+    )
+
+    def __init__(self, operand: SSAValue, result_type: Attribute | None = None):
+        if result_type is None:
+            result_type = operand.type
+        super().__init__(operands=(operand,), result_types=(result_type,))
+
+
+@irdl_op_definition
+class ConvertOp(
+    ElementwiseUnaryOperation[HLO_NonQuantizedTensor, HLO_NonQuantizedTensor]
+):
+    """
+    Performs an element-wise conversion from one element type to another on
+    `operand` tensor and produces a `result` tensor.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convert
+
+    Example:
+    ```mlir
+    %result = stablehlo.convert %operand : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
+    ```
+    """
+
+    name = "stablehlo.convert"
+
+    traits = traits_def(SameOperandsAndResultShape())
+
+
+@irdl_op_definition
+class CosineOp(
+    ElementwiseUnaryOperation[
+        HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor
+    ]
+):
+    """
+    Performs element-wise cosine operation on `operand` tensor and produces a
+    `result` tensor.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#cosine
+
+    Example:
+    ```mlir
+    %result = stablehlo.cosine %operand : tensor<2xf32>
+    ```
+    """
+
+    name = "stablehlo.cosine"
+
+    result_accuracy = opt_attr_def(
+        ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+    )
+    # TODO: implement HLO_CompatibleOperandsAndResultType()
+    # traits = traits_def(
+    #     HLO_CompatibleOperandsAndResultType()
+    # )
+
+
+@irdl_op_definition
+class ExponentialMinusOneOp(
+    ElementwiseUnaryOperation[
+        HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor
+    ]
+):
+    """
+    Performs element-wise exponential minus one operation on `operand` tensor
+    and produces a `result` tensor.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#exponential_minus_one
+
+    Example:
+    ```mlir
+    %result = stablehlo.exponential_minus_one %operand : tensor<2xf64>
+    ```
+    """
+
+    name = "stablehlo.exponential_minus_one"
+
+    result_accuracy = opt_attr_def(
+        ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+    )
+
+    # TODO: implement HLO_CompatibleOperandsAndResultType()
+    # traits = traits_def(
+    #     HLO_CompatibleOperandsAndResultType()
+    # )
+
+
+@irdl_op_definition
+class ExponentialOp(
+    ElementwiseUnaryOperation[
+        HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor
+    ]
+):
+    """
+    Performs element-wise exponential operation on `operand` tensor and produces
+    a `result` tensor.
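+
+    Computes ``e^x`` per element; for example, an operand of
+    ``[[0.0, 1.0], [2.0, 3.0]]`` yields approximately
+    ``[[1.0, 2.71828], [7.38906, 20.08554]]``.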
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#exponential
+
+    Example:
+    ```mlir
+    %result = stablehlo.exponential %operand : tensor<2x2xf64>
+    ```
+    """
+
+    name = "stablehlo.exponential"
+
+    result_accuracy = opt_attr_def(
+        ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+    )
+
+    # TODO: implement HLO_CompatibleOperandsAndResultType()
+    # traits = traits_def(
+    #     HLO_CompatibleOperandsAndResultType()
+    # )
+
+
+@irdl_op_definition
+class FloorOp(
+    ElementwiseUnaryOperation[HLO_FpOrQuantizedIntTensor, HLO_FpOrQuantizedIntTensor]
+):
+    """
+    Performs element-wise floor of `operand` tensor and produces a `result`
+    tensor.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#floor
+
+    Example:
+    ```mlir
+    %result = stablehlo.floor %operand : tensor<2xf32>
+    ```
+    """
+
+    name = "stablehlo.floor"
+
+
+@irdl_op_definition
+class ImagOp(ElementwiseUnaryOperation[HLO_FpOrComplexTensor, HLO_FloatTensor]):
+    """
+    Extracts the imaginary part, element-wise, from the `operand` and produces a
+    `result` tensor.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#imag
+
+    Example:
+    ```mlir
+    %result = stablehlo.imag %operand : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
+    ```
+    """
+
+    name = "stablehlo.imag"
+
+
+@irdl_op_definition
+class IsFiniteOp(ElementwiseUnaryOperation[HLO_FpOrQuantizedIntTensor, HLO_PredTensor]):
+    """
+    Performs element-wise check whether the value in `x` is finite (i.e. is
+    neither +Inf, -Inf, nor NaN) and produces a `y` tensor.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#is_finite
+
+    Example:
+    ```mlir
+    %y = stablehlo.is_finite %x : (tensor<7xf64>) -> tensor<7xi1>
+    ```
+    """
+
+    name = "stablehlo.is_finite"
+
+
+@irdl_op_definition
+class LogOp(
+    ElementwiseUnaryOperation[
+        HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor
+    ]
+):
+    """
+    Performs element-wise logarithm operation on `operand` tensor and produces a
+    `result` tensor.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#log
+
+    Example:
+    ```mlir
+    %result = stablehlo.log %operand : tensor<2x2xf64>
+    ```
+    """
+
+    name = "stablehlo.log"
+
+    result_accuracy = opt_attr_def(
+        ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+    )
+
+
+@irdl_op_definition
+class LogPlusOneOp(
+    ElementwiseUnaryOperation[
+        HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor
+    ]
+):
+    """
+    Performs element-wise logarithm plus one operation on `operand` tensor and
+    produces a `result` tensor.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#log_plus_one
+
+    Example:
+    ```mlir
+    %result = stablehlo.log_plus_one %operand : tensor<5xf64>
+    ```
+    """
+
+    name = "stablehlo.log_plus_one"
+
+    result_accuracy = opt_attr_def(
+        ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+    )
+
+
+@irdl_op_definition
+class LogisticOp(
+    ElementwiseUnaryOperation[
+        HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor
+    ]
+):
+    """
+    Performs element-wise logistic operation on `operand` tensor and produces a
+    `result` tensor.
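+
+    The logistic (sigmoid) function ``logistic(x) = 1 / (1 + exp(-x))`` is
+    applied per element, so e.g. ``logistic(0.0) == 0.5``.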
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#logistic
+
+    Example:
+    ```mlir
+    %result = stablehlo.logistic %operand : tensor<2x2xf64>
+    ```
+    """
+
+    name = "stablehlo.logistic"
+
+    result_accuracy = opt_attr_def(
+        ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+    )
+
+
+@irdl_op_definition
+class NegateOp(
+    ElementwiseUnaryOperation[
+        HLO_IntFpOrComplexOrQuantizedIntTensor, HLO_IntFpOrComplexOrQuantizedIntTensor
+    ]
+):
+    """
+    Performs element-wise negation of `operand` tensor and produces a `result`
+    tensor.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#negate
+
+    Example:
+    ```mlir
+    %result = stablehlo.negate %operand : tensor<2x3xi32>
+    ```
+    """
+
+    name = "stablehlo.negate"
+
+
+@irdl_op_definition
+class RealOp(ElementwiseUnaryOperation[HLO_FpOrComplexTensor, HLO_FloatTensor]):
+    """
+    Extracts the real part, element-wise, from the `operand` and produces a
+    `result` tensor.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#real
+
+    Example:
+    ```mlir
+    %result = stablehlo.real %operand : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
+    ```
+    """
+
+    name = "stablehlo.real"
+
+
+@irdl_op_definition
+class RoundNearestAfzOp(
+    ElementwiseUnaryOperation[HLO_FpOrQuantizedIntTensor, HLO_FpOrQuantizedIntTensor]
+):
+    """
+    Performs element-wise rounding towards the nearest integer, breaking ties
+    away from zero, on the `operand` tensor and produces a `result` tensor.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#round_nearest_afz
+
+    Example:
+    ```mlir
+    %result = stablehlo.round_nearest_afz %operand : tensor<5xf64>
+    ```
+    """
+
+    name = "stablehlo.round_nearest_afz"
+
+
+@irdl_op_definition
+class RoundNearestEvenOp(
+    ElementwiseUnaryOperation[HLO_FpOrQuantizedIntTensor, HLO_FpOrQuantizedIntTensor]
+):
+    """
+    Performs element-wise rounding towards the nearest integer, breaking ties
+    towards the even integer, on the `operand` tensor and produces a `result`
+    tensor.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#round_nearest_even
+
+    Example:
+    ```mlir
+    %result = stablehlo.round_nearest_even %operand : tensor<5xf64>
+    ```
+    """
+
+    name = "stablehlo.round_nearest_even"
+
+
+@irdl_op_definition
+class RsqrtOp(
+    ElementwiseUnaryOperation[
+        HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor
+    ]
+):
+    """
+    Performs element-wise reciprocal square root operation on `operand` tensor
+    and produces a `result` tensor, implementing the `rSqrt` operation from the
+    IEEE-754 specification.
+
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#rsqrt
+
+    Example:
+    ```mlir
+    %result = stablehlo.rsqrt %operand : tensor<2x2xf32>
+    ```
+    """
+
+    name = "stablehlo.rsqrt"
+
+    result_accuracy = opt_attr_def(
+        ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT)
+    )
+
+
+@irdl_op_definition
+class SignOp(
+    ElementwiseUnaryOperation[
+        HLO_SIntFpComplexOrQuantizedIntTensor, HLO_SIntFpComplexOrQuantizedIntTensor
+    ]
+):
+    """
+    Returns the sign of the `operand` element-wise and produces a `result`
+    tensor.
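+
+    Each element maps to -1, 0, or +1 according to its sign; for floating-point
+    operands, signed zeros keep their sign and NaN maps to NaN.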
+ + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sign + + Example: + ```mlir + %result = stablehlo.sign %operand : tensor<5xf64> + ``` + """ + + name = "stablehlo.sign" + + +@irdl_op_definition +class SineOp( + ElementwiseUnaryOperation[ + HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor + ] +): + """ + Performs element-wise sine operation on `operand` tensor and produces a + `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sine + + Example: + ```mlir + %result = stablehlo.sine %operand : tensor<2xf32> + ``` + """ + + name = "stablehlo.sine" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) + + +@irdl_op_definition +class SqrtOp( + ElementwiseUnaryOperation[ + HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor + ] +): + """ + Performs element-wise square root operation on `operand` tensor and produces + a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sqrt + + Example: + ```mlir + %result = stablehlo.sqrt %operand : tensor<2x2xf32> + ``` + """ + + name = "stablehlo.sqrt" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) + + +@irdl_op_definition +class TanOp( + ElementwiseUnaryOperation[ + HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor + ] +): + """ + Performs element-wise tangent operation on `operand` tensor and + produces a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#tan + + Example: + ```mlir + %result = stablehlo.tan %operand : tensor<2x2xf64> + ``` + """ + + name = "stablehlo.tan" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) + + +@irdl_op_definition +class TanhOp( + ElementwiseUnaryOperation[ + HLO_FpComplexOrQuantizedIntTensor, HLO_FpComplexOrQuantizedIntTensor + ] +): + """ + Performs element-wise hyperbolic tangent operation on `operand` tensor and + produces a `result` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#tanh + + Example: + ```mlir + %result = stablehlo.tanh %operand : tensor<2xf32> + ``` + """ + + name = "stablehlo.tanh" + + result_accuracy = opt_attr_def( + ResultAccuracyModeAttr, ResultAccuracyModeAttr(ResultAccuracyMode.DEFAULT) + ) diff --git a/src/xdsl_jax/dialects/stablehlo/extensibility.py b/src/xdsl_jax/dialects/stablehlo/extensibility.py new file mode 100644 index 0000000..680322c --- /dev/null +++ b/src/xdsl_jax/dialects/stablehlo/extensibility.py @@ -0,0 +1,176 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=too-few-public-methods +# pyright: reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnknownMemberType=false, reportAttributeAccessIssue=false + +""" +Extensibility operations for the StableHLO dialect. 
+""" + +from xdsl.dialects.builtin import ( + ArrayAttr, + BoolAttr, + DenseIntElementsAttr, + DictionaryAttr, + FlatSymbolRefAttr, + StringAttr, + TensorType, + TupleType, +) +from xdsl.ir import Attribute +from xdsl.irdl import ( + AnyAttr, + IRDLOperation, + irdl_op_definition, + opt_prop_def, + prop_def, + traits_def, + var_operand_def, + var_result_def, +) +from xdsl.traits import ( + MemoryEffect, +) +from xdsl.utils.exceptions import VerifyException + +from .attributes import ( + CustomCallApiVersion, + CustomCallApiVersionAttr, + OutputOperandAlias, +) + + +@irdl_op_definition +class CustomCallOp(IRDLOperation): + """ + Encapsulates an implementation-defined operation ``call_target_name`` that + takes ``inputs`` and ``called_computations`` and produces ``results``. + + Depending on the API version there are two ways to pass extra bits of static + information to the external function: + 1. Use ``API_VERSION_TYPED_FFI`` which allows passing a dictionary attribute. + 2. Use a previous API version with a ``StringAttr`` to encode backend config. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#custom_call + + Example: + ```mlir + %results = stablehlo.custom_call @foo(%input0) { + backend_config = {bar = 42 : i32}, + api_version = 4 : i32, + called_computations = [@foo] + } : (tensor) -> tensor + ``` + """ + + name = "stablehlo.custom_call" + + inputs = var_operand_def(AnyAttr()) + call_target_name = prop_def(StringAttr) + has_side_effect = prop_def(BoolAttr, default_value=BoolAttr.from_bool(False)) + backend_config = opt_prop_def(DictionaryAttr | StringAttr) + api_version = prop_def( + CustomCallApiVersionAttr, + default_value=CustomCallApiVersionAttr( + CustomCallApiVersion.API_VERSION_ORIGINAL + ), + ) + called_computations = opt_prop_def( + ArrayAttr[FlatSymbolRefAttr], default_value=ArrayAttr([]) + ) + operand_layouts = opt_prop_def(ArrayAttr[DenseIntElementsAttr]) + result_layouts = opt_prop_def(ArrayAttr[DenseIntElementsAttr]) + output_operand_aliases = prop_def(ArrayAttr[OutputOperandAlias]) + + result = var_result_def(AnyAttr()) + + traits = traits_def( + MemoryEffect(), + ) + + # TODO: Implement CustomDirective + # assembly_format = """ + # custom($call_target_name) `(` $inputs `)` + # attr-dict `:` functional-type(operands, results) + # """ + + def verify_(self) -> None: + """Verify the CustomCallOp.""" + # If both operand and result layout attributes are not specified then nothing to verify. + if self.operand_layouts is None and self.result_layouts is None: + return + + # Layout constraints for either both operands & results or none should be specified. + if (self.operand_layouts is None) != (self.result_layouts is None): + raise VerifyException( + "Layout attributes should be specified for either both operands and results or none." 
+ ) + + assert self.operand_layouts is not None + assert self.result_layouts is not None + + def verify_types_and_layouts( + types: tuple[Attribute, ...], layouts: ArrayAttr, value_name: str + ): + if len(types) != len(layouts.data): + raise VerifyException( + "Number of " + f"{value_name}s must match the number of {value_name} layouts, " + f"{len(types)} != {len(layouts.data)}" + ) + + for index, (ty, layout_attr) in enumerate(zip(types, layouts.data)): + # Tuple types are not fully supported with layout constraints yet + if isinstance(ty, TupleType): + raise VerifyException( + "Tuple types are not fully supported with layout constraints yet" + ) + + try: + dims = list(layout_attr.get_values()) + except Exception as exc: + raise VerifyException("invalid layout attribute") from exc + + # For non-tensor types, layout must be empty + if not isinstance(ty, TensorType): + if len(dims) == 0: + continue + raise VerifyException( + "Only tensor types can have non-empty layout: " + f"{value_name} #{index} of type {ty} has layout {dims}" + ) + + # For ranked tensors, require permutation of [0, rank) + rank = ty.get_num_dims() + if rank != len(dims) or sorted(dims) != list(range(rank)): + raise VerifyException( + f"incorrect layout {dims} for type {ty}, layout must be a permutation of [0, {rank})" + ) + + # Operand types + operand_types: tuple[Attribute, ...] = tuple(op.type for op in self.operands) + + # Result types: if single tuple result, use its element types + if len(self.result_types) == 1 and isinstance(self.result_types[0], TupleType): + tuple_ty: TupleType = self.result_types[0] + result_types = tuple(tuple_ty.types.data) + else: + result_types = tuple(self.result_types) + + # Verify that operands and operand layouts match. + verify_types_and_layouts(operand_types, self.operand_layouts, "operand") + # Verify that results and result layouts match. + verify_types_and_layouts(result_types, self.result_layouts, "result") diff --git a/src/xdsl_jax/dialects/stablehlo/reduction.py b/src/xdsl_jax/dialects/stablehlo/reduction.py new file mode 100644 index 0000000..fa60537 --- /dev/null +++ b/src/xdsl_jax/dialects/stablehlo/reduction.py @@ -0,0 +1,166 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=too-few-public-methods +# pyright: reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnknownMemberType=false, reportAttributeAccessIssue=false, reportUnnecessaryComparison=false, reportOptionalMemberAccess=false + +""" +Reduction operations for the StableHLO dialect. 
+""" + +from xdsl.dialects.builtin import DenseArrayBase, i64 +from xdsl.irdl import ( + IRDLOperation, + irdl_op_definition, + prop_def, + region_def, + traits_def, + var_operand_def, + var_result_def, +) +from xdsl.irdl.operations import SameVariadicOperandSize +from xdsl.traits import ( + RecursiveMemoryEffect, + SingleBlockImplicitTerminator, +) +from xdsl.utils.exceptions import VerifyException + +from xdsl_jax.dialects import _stablehlo_upstream as xstablehlo + +from .types import HLO_Tensor + + +@irdl_op_definition +class ReduceOp(IRDLOperation): + """ + Applies a reduction function ``body`` to ``inputs`` and ``init_values`` along the + ``dimensions`` and produces a ``result`` tensor. + + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce + + Example: + ```mlir + %result = "stablehlo.reduce"(%input, %init_value) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = stablehlo.add %arg0, %arg1 : tensor + stablehlo.return %0 : tensor + }) { + dimensions = array + } : (tensor<1x6xi64>, tensor) -> tensor<1xi64> + ``` + """ + + name = "stablehlo.reduce" + + inputs = var_operand_def(HLO_Tensor) + init_values = var_operand_def(HLO_Tensor) + dimensions = prop_def(DenseArrayBase.constr(i64)) + result = var_result_def(HLO_Tensor) + body = region_def("single_block") + + irdl_options = [SameVariadicOperandSize()] + + traits = traits_def( + RecursiveMemoryEffect(), + # TODO: InferShapedTypeOpInterface(), + # TODO: HLO_RecursivelySpeculatableIfAllInputsStatic, + # TODO: InferTensorTypeWithReify(), + SingleBlockImplicitTerminator(xstablehlo.ReturnOp), + ) + + # pylint: disable=no-member + # pylint: disable=too-many-branches + def verify_(self): + """Verify the ReduceOp.""" + # Gather shaped operand/result types + input_types = [op.type for op in self.inputs] + init_types = [op.type for op in self.init_values] + + # reduce_c1/c4/c5/i3: verify inputs and infer shape compatibility + dims_attr = self.dimensions + dims = tuple(dims_attr.get_values()) if dims_attr is not None else tuple() + + # Basic structural checks mirroring verifyReduceOpInputsAndInferShape + if len(input_types) == 0: + raise VerifyException("expected at least 1 input for reduce") + if len(input_types) != len(init_types): + raise VerifyException("number of inputs must match number of init_values") + + # All inputs must have equal rank; dimensions must be within rank and unique + # and not empty. 
+        ranks = []
+        for t in input_types:
+            # Tensors by op definition
+            assert hasattr(t, "get_num_dims")
+            ranks.append(t.get_num_dims())
+        rank0 = ranks[0]
+        if any(r != rank0 for r in ranks):
+            raise VerifyException("all inputs must have the same rank")
+
+        if len(dims) == 0:
+            raise VerifyException("dimensions cannot be empty for reduce")
+        if len(set(dims)) != len(dims):
+            raise VerifyException("dimensions should not have duplicates")
+        if any(d < 0 or d >= rank0 for d in dims):
+            raise VerifyException("dimensions contains an invalid value")
+
+        # Element type compatibility between each input and its init value
+        for it, iv in zip(input_types, init_types):
+            it_elem = it.get_element_type()
+            iv_elem = iv.get_element_type()
+            if it_elem != iv_elem:
+                raise VerifyException(
+                    "input and init_value must have the same element type"
+                )
+
+        # reduce_c2/c6: verify reducer region shape
+        # Expect block with arity 2 * number of inputs, with matching tensor element types and 0D tensors
+        if len(self.body.blocks) != 1:
+            raise VerifyException("reducer must have a single block")
+        block = self.body.blocks[0]
+
+        expected_args = 2 * len(input_types)
+        if len(block.args) != expected_args:
+            raise VerifyException(
+                f"reducer must take {expected_args} arguments, got {len(block.args)}"
+            )
+
+        # Each pair (arg_i, arg_{i+N}) must be 0D tensors of the input element type
+        for i, it in enumerate(input_types):
+            it_elem = it.get_element_type()
+            acc = block.args[i]
+            val = block.args[i + len(input_types)]
+            for a in (acc, val):
+                a_ty = a.type
+                if not hasattr(a_ty, "get_num_dims") or a_ty.get_num_dims() != 0:
+                    raise VerifyException("reducer arguments must be rank-0 tensors")
+                if a_ty.get_element_type() != it_elem:
+                    raise VerifyException(
+                        "reducer argument element types must match input element type"
+                    )
+
+        # Region must terminate with exactly len(inputs) results
+        ret = block.ops.last
+        if len(ret.operands) != len(input_types):
+            raise VerifyException("reducer must return exactly one value per input")
+        for i, it in enumerate(input_types):
+            it_elem = it.get_element_type()
+            rty = ret.operands[i].type
+            if not hasattr(rty, "get_num_dims") or rty.get_num_dims() != 0:
+                raise VerifyException("reducer return values must be rank-0 tensors")
+            if rty.get_element_type() != it_elem:
+                raise VerifyException(
+                    "reducer return element types must match input element type"
+                )
diff --git a/src/xdsl_jax/dialects/stablehlo/types.py b/src/xdsl_jax/dialects/stablehlo/types.py
new file mode 100644
index 0000000..25dd588
--- /dev/null
+++ b/src/xdsl_jax/dialects/stablehlo/types.py
@@ -0,0 +1,261 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pyright: reportPrivateImportUsage=false, reportUnknownParameterType=false, reportUnknownVariableType=false, reportInvalidTypeForm=false, reportUnknownArgumentType=false, reportArgumentType=false
+
+"""
+StableHLO type definitions for the xdsl-jax compiler infrastructure.
+
+This module provides type definitions based on the StableHLO specification
+(https://github.com/openxla/stablehlo/blob/main/docs/spec.md), including
+token types and other necessary type definitions for StableHLO operations.
+"""
+
+# pylint: disable=too-few-public-methods
+
+from typing import TypeAlias
+
+from xdsl.dialects.builtin import (
+    AnyFloatConstr,
+    ComplexType,
+    Float32Type,
+    Float64Type,
+    IndexType,
+    IntAttr,
+    IntAttrConstraint,
+    IntegerType,
+    ParametrizedAttribute,
+    Signedness,
+    SignednessAttr,
+    TensorType,
+    i1,
+)
+from xdsl.irdl import eq, irdl_attr_definition
+from xdsl.irdl.attributes import EqAttrConstraint, ParamAttrConstraint
+from xdsl.irdl.constraints import IntSetConstraint
+
+from xdsl_jax.dialects._stablehlo_upstream import TokenType
+from xdsl_jax.xdsl_extras.constraints import (
+    NestedTupleOfConstraint,
+)
+
+
+def _create_param_constrained_type(
+    base_attr: type, widths: list[int], signedness: Signedness | None = None
+):
+    """Create a width- and signedness-constrained attribute constraint for `base_attr`
+    using ParamAttrConstraint with IntSetConstraint."""
+    width_constraint = IntAttrConstraint(IntSetConstraint(frozenset(widths)))
+
+    if signedness is None:
+        signedness_constraint = None
+    else:
+        signedness_constraint = EqAttrConstraint(SignednessAttr(signedness))
+
+    return ParamAttrConstraint(base_attr, [width_constraint, signedness_constraint])
+
+
+# =============================================================================
+# Core StableHLO type constraints
+# =============================================================================
+
+HLO_Pred = eq(i1)
+HLO_PredTensor: TypeAlias = TensorType[HLO_Pred]
+
+# NOTE: IntegerType is defined in the StableHLO spec as:
+# IntegerType ::= SignedIntegerType | UnsignedIntegerType,
+# but the MLIR implementation is using signless integers instead of signed,
+# and there is a TODO to fix it.
+
+_HLO_INT_WIDTHS = [2, 4, 8, 16, 32, 64]
+HLO_SignedInt = _create_param_constrained_type(
+    IntegerType, _HLO_INT_WIDTHS, Signedness.SIGNED
+)
+HLO_UnsignedInt = _create_param_constrained_type(
+    IntegerType, _HLO_INT_WIDTHS, Signedness.UNSIGNED
+)
+HLO_SignlessInt = _create_param_constrained_type(IntegerType, _HLO_INT_WIDTHS, None)
+
+HLO_Int: TypeAlias = HLO_UnsignedInt | HLO_SignlessInt
+HLO_IntTensor: TypeAlias = TensorType[HLO_Int]
+
+_HLO_INT_OR_PRED_WIDTHS = [1, 2, 4, 8, 16, 32, 64]
+HLO_IntOrPred = _create_param_constrained_type(
+    IntegerType, _HLO_INT_OR_PRED_WIDTHS, None
+)
+
+
+HLO_AnyIntegerOrIndex: TypeAlias = IntegerType | IndexType
+HLO_AnyIntegerOrIndexTensor: TypeAlias = TensorType.constr(HLO_AnyIntegerOrIndex)
+
+HLO_DimensionValue: TypeAlias = HLO_Int | IndexType
+
+# Constraint variants for use in unions with ParamAttrConstraint
+HLO_Float: TypeAlias = AnyFloatConstr
+HLO_Float32Or64: TypeAlias = Float32Type | Float64Type
+HLO_FloatTensor: TypeAlias = TensorType.constr(HLO_Float)
+HLO_Fp32Or64Tensor: TypeAlias = TensorType.constr(HLO_Float32Or64)
+
+# Complex as a constraint over element types {f32,f64}
+HLO_Complex: TypeAlias = ComplexType[HLO_Float32Or64]
+HLO_ComplexTensor: TypeAlias = TensorType.constr(HLO_Complex)
+
+# =============================================================================
+# Quantized element type definitions
+# =============================================================================
+
+
+@irdl_attr_definition
+class UniformQuantizedType(ParametrizedAttribute):
+    """
+    Placeholder for StableHLO per-tensor uniform quantized types.
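+
+    (Intended as a lightweight stand-in for MLIR's ``!quant.uniform`` element
+    types, which are not modelled in xDSL yet; only the storage width and
+    signedness are tracked here.)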
+ + Parameterized by width to support different quantized integer widths + (e.g., 8-bit, 16-bit quantization). + """ + + name = "stablehlo.uniform_quantized" + width: IntAttr + signedness: SignednessAttr + + +@irdl_attr_definition +class UniformQuantizedPerAxisType(ParametrizedAttribute): + """ + Placeholder for StableHLO per-axis uniform quantized types. + + Parameterized by width to support different quantized integer widths + (e.g., 8-bit, 16-bit quantization). + """ + + name = "stablehlo.uniform_quantized_per_axis" + width: IntAttr + signedness: SignednessAttr + + +# ============================================================================= +# StableHLO quantized type aliases +# ============================================================================= + +_HLO_QUANTIZED_WIDTHS = [2, 4, 8, 16, 32] + +# Constraint-based types for operation definitions +HLO_QuantizedSignedInt = _create_param_constrained_type( + UniformQuantizedType, _HLO_QUANTIZED_WIDTHS, Signedness.SIGNED +) +HLO_QuantizedUnsignedInt = _create_param_constrained_type( + UniformQuantizedType, _HLO_QUANTIZED_WIDTHS, Signedness.UNSIGNED +) +HLO_QuantizedAnySignednessInt = _create_param_constrained_type( + UniformQuantizedType, _HLO_QUANTIZED_WIDTHS, None +) +HLO_QuantizedInt: TypeAlias = HLO_QuantizedSignedInt | HLO_QuantizedUnsignedInt + +HLO_PerAxisQuantizedSignedInt = _create_param_constrained_type( + UniformQuantizedPerAxisType, _HLO_QUANTIZED_WIDTHS, Signedness.SIGNED +) +HLO_PerAxisQuantizedUnsignedInt = _create_param_constrained_type( + UniformQuantizedPerAxisType, _HLO_QUANTIZED_WIDTHS, Signedness.UNSIGNED +) +HLO_PerAxisQuantizedAnySignednessInt = _create_param_constrained_type( + UniformQuantizedPerAxisType, _HLO_QUANTIZED_WIDTHS, None +) +HLO_PerAxisQuantizedInt: TypeAlias = ( + HLO_PerAxisQuantizedSignedInt | HLO_PerAxisQuantizedUnsignedInt +) + +# ============================================================================= +# Main tensor type definitions +# ============================================================================= + +HLO_Tensor: TypeAlias = TensorType[ + HLO_Float | HLO_Complex | HLO_IntOrPred | HLO_QuantizedInt +] +HLO_NonQuantizedTensor: TypeAlias = TensorType[HLO_Float | HLO_Complex | HLO_IntOrPred] + +# Note: There is a discrepancy between the StableHLO spec and the MLIR implementation. +# The spec does not allow unranked tensors, but the MLIR implementation +# defines it as a tensor of any type and rank. There is a TODO to fix this in MLIR. +# Therefore, we use the correct ranked tensor type. 
+HLO_AnyTensor: TypeAlias = TensorType[ + HLO_Float | HLO_Complex | HLO_IntOrPred | HLO_QuantizedInt | HLO_PerAxisQuantizedInt +] +HLO_TensorOrToken: TypeAlias = HLO_Tensor | TokenType +HLO_TensorOrPerAxisQuantizedTensorOrToken: TypeAlias = HLO_AnyTensor | TokenType + +# HLO_AnyTuple : NestedTupleOf<[HLO_AnyTensor, HLO_Token]> +HLO_AnyTuple = NestedTupleOfConstraint([HLO_AnyTensor, TokenType]) + +HLO_CustomCallValue: TypeAlias = HLO_Tensor | TokenType | HLO_AnyTuple + +# ============================================================================= +# HLO combined type definitions +# ============================================================================= + +HLO_PredOrIntTensor: TypeAlias = TensorType.constr(HLO_IntOrPred) + +HLO_FpOrComplexTensor: TypeAlias = TensorType.constr(HLO_Float | HLO_Complex) +HLO_FpOrQuantizedIntTensor: TypeAlias = TensorType.constr(HLO_Float | HLO_QuantizedInt) +HLO_FpComplexOrQuantizedIntTensor: TypeAlias = TensorType.constr( + HLO_Float | HLO_Complex | HLO_QuantizedInt +) +HLO_IntFpOrComplexOrQuantizedIntTensor: TypeAlias = TensorType.constr( + HLO_Int | HLO_Float | HLO_Complex | HLO_QuantizedInt +) +HLO_SIntFpComplexOrQuantizedIntTensor: TypeAlias = TensorType.constr( + HLO_SignedInt | HLO_Float | HLO_Complex | HLO_QuantizedInt +) + + +__all__ = [ + # Core types + "HLO_Pred", + "HLO_PredTensor", + "HLO_Int", + "HLO_IntTensor", + "HLO_AnyIntegerOrIndex", + "HLO_AnyIntegerOrIndexTensor", + "HLO_DimensionValue", + "HLO_Float", + "HLO_Float32Or64", + "HLO_FloatTensor", + "HLO_Fp32Or64Tensor", + "HLO_ComplexTensor", + "HLO_SignedInt", + "HLO_UnsignedInt", + "HLO_SignlessInt", + "HLO_QuantizedSignedInt", + "HLO_QuantizedUnsignedInt", + "HLO_QuantizedAnySignednessInt", + "HLO_QuantizedInt", + "HLO_PerAxisQuantizedSignedInt", + "HLO_PerAxisQuantizedUnsignedInt", + "HLO_PerAxisQuantizedAnySignednessInt", + "HLO_PerAxisQuantizedInt", + # Quantized types + "UniformQuantizedType", + "UniformQuantizedPerAxisType", + "HLO_Tensor", + "HLO_NonQuantizedTensor", + "HLO_AnyTensor", + "HLO_TensorOrToken", + "HLO_TensorOrPerAxisQuantizedTensorOrToken", + "HLO_CustomCallValue", + # Combined types + "HLO_PredOrIntTensor", + "HLO_FpOrComplexTensor", + "HLO_FpOrQuantizedIntTensor", + "HLO_FpComplexOrQuantizedIntTensor", + "HLO_IntFpOrComplexOrQuantizedIntTensor", + "HLO_SIntFpComplexOrQuantizedIntTensor", +] diff --git a/src/xdsl_jax/xdsl_extras/__init__.py b/src/xdsl_jax/xdsl_extras/__init__.py new file mode 100644 index 0000000..9fe36c9 --- /dev/null +++ b/src/xdsl_jax/xdsl_extras/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""This module contains additional utilities and functionality not available upstream in xDSL.""" + +from .constraints import NestedTupleOfConstraint +from .traits import ( + AllMatchSameOperatorTrait, + Elementwise, + SameOperandsAndResultElementType, + SameOperandsAndResultShape, + SameOperandsElementType, +) + +__all__ = [ + # Constraints + "NestedTupleOfConstraint", + # Traits + "AllMatchSameOperatorTrait", + "Elementwise", + "SameOperandsAndResultElementType", + "SameOperandsAndResultShape", + "SameOperandsElementType", +] diff --git a/src/xdsl_jax/xdsl_extras/constraints.py b/src/xdsl_jax/xdsl_extras/constraints.py new file mode 100644 index 0000000..67816f8 --- /dev/null +++ b/src/xdsl_jax/xdsl_extras/constraints.py @@ -0,0 +1,84 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains additional type and attribute constraints that are currently not available +upstream in xDSL.""" + +# pyright: reportGeneralTypeIssues=false, reportIncompatibleMethodOverride=false, reportUnknownParameterType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false, reportUnknownMemberType=false, reportAttributeAccessIssue=false, reportAssignmentType=false, reportArgumentType=false + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass +from typing import TypeVar + +from xdsl.dialects.builtin import TupleType +from xdsl.ir import Attribute +from xdsl.irdl import ( + AttrConstraint, + ConstraintContext, + IntConstraint, + irdl_to_attr_constraint, +) +from xdsl.utils.exceptions import VerifyException + + +@dataclass(frozen=True, init=False) +class NestedTupleOfConstraint(AttrConstraint[TupleType]): + """Constrain a nested tuple whose flattened leaves all match any allowed constraints.""" + + elem_constraints: tuple[AttrConstraint, ...] 
+ + def __init__(self, elem_constraints: Sequence[object]): + object.__setattr__( + self, + "elem_constraints", + tuple(irdl_to_attr_constraint(c) for c in elem_constraints), + ) + + def get_flattened(self, a: Attribute): + """Get the flattened leaves of a tuple.""" + if isinstance(a, TupleType): + for t in a.types.data: + yield from self.get_flattened(t) + else: + yield a + + def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None: + """Verify that the attribute is a tuple of allowed types.""" + if not isinstance(attr, TupleType): + raise VerifyException(f"expected TupleType, got {type(attr)}") + + leaves = list(self.get_flattened(attr)) + + for i, leaf in enumerate(leaves): + matched = False + for constr in self.elem_constraints: + try: + constr.verify(leaf, constraint_context) + matched = True + break + except VerifyException: + # Try next allowed constraint + pass + if not matched: + raise VerifyException( + f"tuple leaf {i} failed all allowed constraints: {leaf}" + ) + + def mapping_type_vars( + self, + type_var_mapping: Mapping[TypeVar, AttrConstraint | IntConstraint], + ) -> AttrConstraint: + """Map type variables to constraints.""" + # pylint: disable=unused-argument + return self diff --git a/src/xdsl_jax/xdsl_extras/traits.py b/src/xdsl_jax/xdsl_extras/traits.py new file mode 100644 index 0000000..0d341aa --- /dev/null +++ b/src/xdsl_jax/xdsl_extras/traits.py @@ -0,0 +1,251 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Traits for xDSL operations. + +This module provides operation traits that can be used to define operation invariants, +additional semantic information, or to group operations that have similar properties. +""" + +# pyright: reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +from xdsl.dialects.builtin import TensorType, VectorType +from xdsl.ir import Attribute, Operation +from xdsl.traits import OpTrait +from xdsl.utils.exceptions import VerifyException +from xdsl.utils.type import get_element_type_or_self, have_compatible_shape + + +@dataclass(frozen=True) +class SameOperandsAndResultShape(OpTrait): + """Constrain the operation to have the same shape for all operands and results.""" + + # TODO: This trait should be added to ElementwiseBinaryOperation and + # ElementwiseUnaryOperation operations when upstreaming to xdsl. 
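+    # Note: only shapes are compared (via have_compatible_shape), not element
+    # types; e.g. tensor<2x3xf32> and tensor<2x3xi32> are accepted together,
+    # while a tensor<3x2xf32> operand or result would be rejected.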
+
+    def verify(self, op: Operation) -> None:
+        """Verify that the operation has the same shape for all operands and results."""
+
+        if len(op.results) < 1 or len(op.operands) < 1:
+            raise VerifyException(
+                f"'{op.name}' requires at least one operand and one result"
+            )
+
+        # Get all types (operands and results) to check for compatible shapes
+        all_types = list(op.operand_types) + list(op.result_types)
+
+        # Check that all types have compatible shapes
+        for type_to_check in all_types[1:]:
+            if not have_compatible_shape(all_types[0], type_to_check):
+                raise VerifyException(
+                    f"'{op.name}' requires the same shape for all operands and results"
+                )
+
+
+@dataclass(frozen=True)
+class SameOperandsElementType(OpTrait):
+    """Constrain the operation to have the same element type for all operands."""
+
+    # TODO: This trait should be added to ElementwiseBinaryOperation and
+    # ElementwiseUnaryOperation operations when upstreaming to xdsl.
+
+    def verify(self, op: Operation) -> None:
+        """Verify that the operation has the same element type for all operands."""
+
+        if len(op.operands) <= 1:
+            return
+
+        # Get the element type of the first operand
+        first_elem_type = get_element_type_or_self(op.operand_types[0])
+
+        # Check that all other operands have the same element type
+        for operand_type in op.operand_types[1:]:
+            elem_type = get_element_type_or_self(operand_type)
+            if elem_type != first_elem_type:
+                raise VerifyException(
+                    f"'{op.name}' requires the same element type for all operands"
+                )
+
+
+@dataclass(frozen=True)
+class SameOperandsAndResultElementType(OpTrait):
+    """Constrain the operation to have the same element type for all operands and results."""
+
+    def verify(self, op: Operation) -> None:
+        """Verify that the operation has the same element type for all operands and results."""
+
+        if len(op.results) < 1 or len(op.operands) < 1:
+            raise VerifyException(
+                f"'{op.name}' requires at least one operand and one result"
+            )
+
+        # Get the element type of the first operand
+        first_elem_type = get_element_type_or_self(op.operand_types[0])
+
+        all_types = list(op.operand_types) + list(op.result_types)
+
+        # Check that all other operands and results have the same element type
+        for type_to_check in all_types[1:]:
+            elem_type = get_element_type_or_self(type_to_check)
+            if elem_type != first_elem_type:
+                raise VerifyException(
+                    f"'{op.name}' requires the same element type for all operands and results"
+                )
+
+
+@dataclass(frozen=True)
+class Elementwise(OpTrait):
+    """
+    The following is the definition of the `Elementwise` trait from MLIR:
+
+    https://github.com/llvm/llvm-project/blob/f8cb7987c64dcffb72414a40560055cb717dbf74/mlir/include/mlir/IR/OpDefinition.h#L1378-L1409
+
+    TODO: Add this trait to all the elementwise operations in xdsl when upstreaming.
+
+    Tags elementwise operations on vectors or tensors.
+
+    NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this trait.
+    In particular, broadcasting behavior is not allowed.
+
+    An `Elementwise` op must satisfy the following properties:
+
+    1. If any result is a vector/tensor then at least one operand must also be a
+       vector/tensor.
+    2. If any operand is a vector/tensor then there must be at least one result
+       and all results must be vectors/tensors.
+    3. All operand and result vector/tensor types must be of the same shape. The
+       shape may be dynamic in which case the op's behaviour is undefined for
+       non-matching shapes.
+    4. The operation must be elementwise on its vector/tensor operands and
+       results.
When applied to single-element vectors/tensors, the result must + be the same per element. + + Rationale: + - 1. and 2. guarantee a well-defined iteration space and exclude the cases + of 0 non-scalar operands or 0 non-scalar results, which complicate a + generic definition of the iteration space. + - 3. guarantees that folding can be done across scalars/vectors/tensors with + the same pattern, as otherwise lots of special handling for type + mismatches would be needed. + - 4. guarantees that no error handling is needed. Higher-level dialects + should reify any needed guards or error handling code before lowering to + an Elementwise op. + """ + + def verify(self, op: Operation) -> None: + """Verify that the operation is elementwise.""" + + # Filter mappable types from results and operands (vectors/tensors only) + result_mappable_types = [ + t for t in op.result_types if Elementwise.is_mappable_type(t) + ] + operand_mappable_types = [ + t for t in op.operand_types if Elementwise.is_mappable_type(t) + ] + + # If the op only has scalar operand/result types, then we have nothing to check + if not result_mappable_types and not operand_mappable_types: + return + + # If a result is non-scalar, then at least one operand must be non-scalar + if result_mappable_types and not operand_mappable_types: + raise VerifyException( + f"'{op.name}': if a result is non-scalar, then at least one " + "operand must be non-scalar" + ) + + # At this point, operand_mappable_types should not be empty + assert operand_mappable_types, "At least one operand must be a vector or tensor" + + # If an operand is non-scalar, then there must be at least one non-scalar result + if not result_mappable_types: + raise VerifyException( + f"'{op.name}': if an operand is non-scalar, then there must be at " + "least one non-scalar result" + ) + + # If an operand is non-scalar, then all results must be non-scalar + if len(result_mappable_types) != len(op.results): + raise VerifyException( + f"'{op.name}': if an operand is non-scalar, then all results must be non-scalar" + ) + + # All non-scalar operands/results must have the same shape and base type + all_types = operand_mappable_types + result_mappable_types + + # Check that all types have compatible shapes + for type_to_check in all_types[1:]: + if not have_compatible_shape(all_types[0], type_to_check): + raise VerifyException( + f"'{op.name}': all non-scalar operands/results must have the " + "same shape and base type" + ) + + @staticmethod + def is_mappable_type(attr_type: Attribute) -> bool: + """Return True if the type is elementwise-mappable (vector or tensor). + + There is a TODO in MLIR to generalize this trait to avoid hardcoding vector/tensor. + We should update this when the TODO is resolved. + """ + return isinstance(attr_type, (VectorType, TensorType)) + + +@dataclass(frozen=True) +class AllMatchSameOperatorTrait(OpTrait): + """ + Verify that a list of operation attributes all match under the same operator + (e.g., size, rank, type, shape, element type). + + Parameters: + - attr_names: attribute names on the op to compare + - operator: callable taking the attribute value and returning a comparable value + - summary: human-readable name of the property used in error messages + """ + + attr_names: tuple[str, ...] 
+ operator: Callable[[Any], Any] + summary: str + + def verify(self, op: Operation) -> None: + """Verify that the operation attributes all match under the same operator.""" + attributes = [] + for name in self.attr_names: + value = getattr(op, name, None) + if value is None: + return + attributes.append(value) + + if len(attributes) <= 1: + return + + names_str = ", ".join(self.attr_names) + try: + results = [self.operator(attr) for attr in attributes] + except (TypeError, ValueError, AttributeError) as e: + raise VerifyException( + f"cannot compute {self.summary} for {{{names_str}}}: {e}" + ) from e + + first = results[0] + if any(res != first for res in results[1:]): + results_str = ", ".join(str(r) for r in results) + raise VerifyException( + f"all of {{{names_str}}} must have the same {self.summary}: got {self.summary}s {results_str}" + ) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..94dc452 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,102 @@ +"""Pytest configuration for xdsl-jax tests.""" +# pyright: reportUnknownMemberType=false, reportAttributeAccessIssue=false, reportUntypedFunctionDecorator=false + +from io import StringIO + +from filecheck.finput import FInput +from filecheck.matcher import Matcher +from filecheck.options import parse_argv_options +from filecheck.parser import Parser, pattern_for_opts +from xdsl.context import Context +from xdsl.dialects import test +from xdsl.parser import Parser as XDSLParser +from xdsl.passes import ModulePass, PassPipeline +from xdsl.printer import Printer + +import pytest + + +def _run_filecheck_impl( + program_str: str, + pipeline: tuple[ModulePass, ...] = (), + verify: bool = False, + roundtrip: bool = False, +) -> None: + """Run filecheck on an xDSL module, comparing it to a program string containing + filecheck directives.""" + ctx = Context() + # Register all needed dialects + ctx.load_dialect(test.Test) + + # Load StableHLO dialect from xdsl-jax + from xdsl_jax.dialects.stablehlo import StableHLO + + ctx.load_dialect(StableHLO) + + # Load commonly used dialects + from xdsl.dialects import arith, builtin, func + + ctx.load_dialect(builtin.Builtin) + ctx.load_dialect(func.Func) + ctx.load_dialect(arith.Arith) + + parser = XDSLParser(ctx, program_str) + xdsl_module = parser.parse_module() + + if roundtrip: + # Print generic format + stream = StringIO() + Printer(stream=stream, print_generic_format=True).print_op(xdsl_module) + parser = XDSLParser(ctx, stream.getvalue()) + xdsl_module = parser.parse_module() + + if verify: + xdsl_module.verify() + + pass_pipeline = PassPipeline(pipeline) + pass_pipeline.apply(ctx, xdsl_module) + + if verify: + xdsl_module.verify() + + stream = StringIO() + Printer(stream).print_op(xdsl_module) + opts = parse_argv_options(["filecheck", __file__]) + matcher = Matcher( + opts, + FInput("no-name", stream.getvalue()), + Parser(opts, StringIO(program_str), *pattern_for_opts(opts)), + ) + + exit_code = matcher.run() + assert exit_code == 0, f""" + filecheck failed with exit code {exit_code}. + + Original program string: + {program_str} + + Parsed module: + {stream.getvalue()} + """ + + +@pytest.fixture +def run_filecheck(): + """Fixture to run filecheck on an xDSL module. + + This fixture uses FileCheck to verify the correctness of a parsed MLIR string. Testers + can provide a pass pipeline to transform the IR, and verify correctness by including + FileCheck directives as comments in the input program string. 
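+
+    A minimal usage sketch:
+
+    ```python
+    def test_negate(run_filecheck):
+        program = '''
+        // CHECK: stablehlo.negate
+        %0 = "test.op"() : () -> tensor<f32>
+        %1 = "stablehlo.negate"(%0) : (tensor<f32>) -> tensor<f32>
+        '''
+        run_filecheck(program, verify=True, roundtrip=True)
+    ```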
+ + Args: + program_str (str): The MLIR string containing the input program and FileCheck directives + pipeline (tuple[ModulePass]): A sequence containing all passes that should be applied + before running FileCheck + verify (bool): Whether or not to verify the IR after parsing and transforming. + ``False`` by default. + roundtrip (bool): Whether or not to use round-trip testing. This is useful for dialect + tests to verify that xDSL both parses and prints the IR correctly. If ``True``, we parse + the program string into an xDSL module, print it in generic format, and then parse the + generic program string back to an xDSL module. ``False`` by default. + """ + return _run_filecheck_impl diff --git a/tests/pytest/__init__.py b/tests/pytest/__init__.py new file mode 100644 index 0000000..536a6df --- /dev/null +++ b/tests/pytest/__init__.py @@ -0,0 +1 @@ +"""Pytest tests for xdsl-jax.""" diff --git a/tests/pytest/test_stablehlo_dialect.py b/tests/pytest/test_stablehlo_dialect.py new file mode 100644 index 0000000..cb5a4c2 --- /dev/null +++ b/tests/pytest/test_stablehlo_dialect.py @@ -0,0 +1,1018 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit test module for xdsl_jax.dialects.stablehlo.""" +# pyright: reportUnknownVariableType=false, reportUnknownMemberType=false, reportAttributeAccessIssue=false + +from typing import Any + +import pytest + +# pylint: disable=wrong-import-position + +xdsl = pytest.importorskip("xdsl") +filecheck = pytest.importorskip("filecheck") + + +def test_all_unary_operations(run_filecheck: Any) -> None: + """Test all unary elementwise operations.""" + program = r""" + // CHECK: %[[tf32:.*]] = "test.op"() : () -> tensor + %tf32 = "test.op"() : () -> tensor + + // CHECK: %[[tf64:.*]] = "test.op"() : () -> tensor + %tf64 = "test.op"() : () -> tensor + + // CHECK: %[[tcomplex:.*]] = "test.op"() : () -> tensor> + %tcomplex = "test.op"() : () -> tensor> + + // CHECK: %convert = "stablehlo.convert"(%[[tf32]]) : (tensor) -> tensor + %convert = "stablehlo.convert"(%tf32) : (tensor) -> tensor + + // CHECK: %cos = "stablehlo.cosine"(%[[tf32]]) : (tensor) -> tensor + %cos = "stablehlo.cosine"(%tf32) : (tensor) -> tensor + + // CHECK: %exp = "stablehlo.exponential"(%[[tf32]]) : (tensor) -> tensor + %exp = "stablehlo.exponential"(%tf32) : (tensor) -> tensor + + // CHECK: %exponential_minus_one = "stablehlo.exponential_minus_one"(%[[tf32]]) : (tensor) -> tensor + %exponential_minus_one = "stablehlo.exponential_minus_one"(%tf32) : (tensor) -> tensor + + // CHECK: %floor = "stablehlo.floor"(%[[tf64]]) : (tensor) -> tensor + %floor = "stablehlo.floor"(%tf64) : (tensor) -> tensor + + // CHECK: %imag = "stablehlo.imag"(%[[tcomplex]]) : (tensor>) -> tensor + %imag = "stablehlo.imag"(%tcomplex) : (tensor>) -> tensor + + // CHECK: %is_finite = "stablehlo.is_finite"(%[[tf32]]) : (tensor) -> tensor + %is_finite = "stablehlo.is_finite"(%tf32) : (tensor) -> tensor + + // CHECK: %log = "stablehlo.log"(%[[tf32]]) : (tensor) -> tensor + %log 
= "stablehlo.log"(%tf32) : (tensor) -> tensor + + // CHECK: %log_plus_one = "stablehlo.log_plus_one"(%[[tf64]]) : (tensor) -> tensor + %log_plus_one = "stablehlo.log_plus_one"(%tf64) : (tensor) -> tensor + + // CHECK: %logistic = "stablehlo.logistic"(%[[tf32]]) : (tensor) -> tensor + %logistic = "stablehlo.logistic"(%tf32) : (tensor) -> tensor + + // CHECK: %negate = "stablehlo.negate"(%[[tf32]]) : (tensor) -> tensor + %negate = "stablehlo.negate"(%tf32) : (tensor) -> tensor + + // CHECK: %real = "stablehlo.real"(%[[tcomplex]]) : (tensor>) -> tensor + %real = "stablehlo.real"(%tcomplex) : (tensor>) -> tensor + + // CHECK: %round_afz = "stablehlo.round_nearest_afz"(%[[tf64]]) : (tensor) -> tensor + %round_afz = "stablehlo.round_nearest_afz"(%tf64) : (tensor) -> tensor + + // CHECK: %round_even = "stablehlo.round_nearest_even"(%[[tf64]]) : (tensor) -> tensor + %round_even = "stablehlo.round_nearest_even"(%tf64) : (tensor) -> tensor + + // CHECK: %rsqrt = "stablehlo.rsqrt"(%[[tf32]]) : (tensor) -> tensor + %rsqrt = "stablehlo.rsqrt"(%tf32) : (tensor) -> tensor + + // CHECK: %sign = "stablehlo.sign"(%[[tf32]]) : (tensor) -> tensor + %sign = "stablehlo.sign"(%tf32) : (tensor) -> tensor + + // CHECK: %sin = "stablehlo.sine"(%[[tf32]]) : (tensor) -> tensor + %sin = "stablehlo.sine"(%tf32) : (tensor) -> tensor + + // CHECK: %sqrt = "stablehlo.sqrt"(%[[tf32]]) : (tensor) -> tensor + %sqrt = "stablehlo.sqrt"(%tf32) : (tensor) -> tensor + + // CHECK: %tan = "stablehlo.tan"(%[[tf64]]) : (tensor) -> tensor + %tan = "stablehlo.tan"(%tf64) : (tensor) -> tensor + + // CHECK: %tanh = "stablehlo.tanh"(%[[tf32]]) : (tensor) -> tensor + %tanh = "stablehlo.tanh"(%tf32) : (tensor) -> tensor + """ + + run_filecheck(program, roundtrip=True, verify=True) + + +def test_all_binary_operations( + run_filecheck: Any, +) -> None: + """Test all binary elementwise operations.""" + program = r""" + // CHECK: %[[tf32_1:.*]] = "test.op"() : () -> tensor + %tf32_1 = "test.op"() : () -> tensor + + // CHECK: %[[tf32_2:.*]] = "test.op"() : () -> tensor + %tf32_2 = "test.op"() : () -> tensor + + // CHECK: %[[tf64_1:.*]] = "test.op"() : () -> tensor + %tf64_1 = "test.op"() : () -> tensor + + // CHECK: %[[tf64_2:.*]] = "test.op"() : () -> tensor + %tf64_2 = "test.op"() : () -> tensor + + // CHECK: %complex = "stablehlo.complex"(%[[tf32_1]], %[[tf32_2]]) : (tensor, tensor) -> tensor> + %complex = "stablehlo.complex"(%tf32_1, %tf32_2) : (tensor, tensor) -> tensor> + + // CHECK: %divide = "stablehlo.divide"(%[[tf32_1]], %[[tf32_2]]) : (tensor, tensor) -> tensor + %divide = "stablehlo.divide"(%tf32_1, %tf32_2) : (tensor, tensor) -> tensor + + // CHECK: %maximum = "stablehlo.maximum"(%[[tf32_1]], %[[tf32_2]]) : (tensor, tensor) -> tensor + %maximum = "stablehlo.maximum"(%tf32_1, %tf32_2) : (tensor, tensor) -> tensor + + // CHECK: %minimum = "stablehlo.minimum"(%[[tf32_1]], %[[tf32_2]]) : (tensor, tensor) -> tensor + %minimum = "stablehlo.minimum"(%tf32_1, %tf32_2) : (tensor, tensor) -> tensor + + // CHECK: %power = "stablehlo.power"(%[[tf64_1]], %[[tf64_2]]) : (tensor, tensor) -> tensor + %power = "stablehlo.power"(%tf64_1, %tf64_2) : (tensor, tensor) -> tensor + + // CHECK: %remainder = "stablehlo.remainder"(%[[tf32_1]], %[[tf32_2]]) : (tensor, tensor) -> tensor + %remainder = "stablehlo.remainder"(%tf32_1, %tf32_2) : (tensor, tensor) -> tensor + """ + + run_filecheck(program, roundtrip=True, verify=True) + + +def test_all_other_operations(run_filecheck: Any) -> None: + """Test all other elementwise operations.""" + program = r""" + // 
+
+
+def test_all_other_operations(run_filecheck: Any) -> None:
+    """Test all other elementwise operations."""
+    program = r"""
+    // CHECK: %[[tf32:.*]] = "test.op"() : () -> tensor<f32>
+    %tf32 = "test.op"() : () -> tensor<f32>
+
+    // CHECK: %[[tf64:.*]] = "test.op"() : () -> tensor<f64>
+    %tf64 = "test.op"() : () -> tensor<f64>
+
+    // CHECK: %[[ti1:.*]] = "test.op"() : () -> tensor<i1>
+    %ti1 = "test.op"() : () -> tensor<i1>
+
+    // CHECK: %clamp = "stablehlo.clamp"(%[[tf32]], %[[tf32]], %[[tf32]]) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
+    %clamp = "stablehlo.clamp"(%tf32, %tf32, %tf32) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
+    // CHECK: %compare = stablehlo.compare EQ, %[[tf32]], %[[tf32]] : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    %compare = "stablehlo.compare"(%tf32, %tf32) {comparison_direction = #stablehlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+
+    // CHECK: %map = "stablehlo.map"(%[[tf32]], %[[tf32]]) ({
+    // CHECK: ^[[bb0:.*]](%arg0 : tensor<f32>, %arg1 : tensor<f32>):
+    // CHECK: %0 = "stablehlo.multiply"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    // CHECK: "stablehlo.return"(%0) : (tensor<f32>) -> ()
+    // CHECK: }) {dimensions = array<i64>} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %map = "stablehlo.map"(%tf32, %tf32) ({
+    ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+        %0 = "stablehlo.multiply"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+        "stablehlo.return"(%0) : (tensor<f32>) -> ()
+    }) {
+        dimensions = array<i64>
+    } : (tensor<f32>, tensor<f32>) -> tensor<f32>
+
+    // CHECK: %reduce_precision = "stablehlo.reduce_precision"(%[[tf64]]) {exponent_bits = 5 : i32, mantissa_bits = 10 : i32} : (tensor<f64>) -> tensor<f64>
+    %reduce_precision = "stablehlo.reduce_precision"(%tf64) {exponent_bits = 5 : i32, mantissa_bits = 10 : i32} : (tensor<f64>) -> tensor<f64>
+
+    // CHECK: %select = "stablehlo.select"(%[[ti1]], %[[tf32]], %[[tf32]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+    %select = "stablehlo.select"(%ti1, %tf32, %tf32) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+    """
+
+    run_filecheck(program, roundtrip=True, verify=True)
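The invalid-IR tests that follow rely on `pytest.raises(..., match=...)`. The `match`
argument is applied with `re.search`, so literal regex metacharacters in a diagnostic
(braces, parentheses, brackets) must be escaped, as the slice and gather tests below
do. A minimal, self-contained illustration with a hypothetical message:

    import re

    import pytest

    def raises_match_is_a_regex() -> None:
        # Escape literal braces; otherwise they are parsed as regex syntax.
        with pytest.raises(ValueError, match=re.escape("all of {a, b}")):
            raise ValueError("all of {a, b} must have the same size")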
+
+
+def test_invalid_ir_shape_mismatch(
+    run_filecheck: Any,
+) -> None:
+    """Test that operations with shape mismatches are properly rejected."""
+    program = r"""
+    %tf32_2x3 = "test.op"() : () -> tensor<2x3xf32>
+    %tf64_3x2 = "test.op"() : () -> tensor<3x2xf64>
+
+    // This should fail verification due to the shape mismatch
+    %convert = "stablehlo.convert"(%tf32_2x3) : (tensor<2x3xf32>) -> tensor<3x2xf64>
+    """
+
+    with pytest.raises(
+        Exception,
+        match="all non-scalar operands/results must have the same shape and base type",
+    ):
+        run_filecheck(program, roundtrip=True, verify=True)
+
+
+def test_invalid_ir_type_mismatch(
+    run_filecheck: Any,
+) -> None:
+    """Test that operations with type mismatches are properly rejected."""
+    program = r"""
+    %ti32 = "test.op"() : () -> tensor<2x3xi32>
+
+    // This should fail verification due to the type mismatch (cosine expects float/complex)
+    %cos = "stablehlo.cosine"(%ti32) : (tensor<2x3xi32>) -> tensor<2x3xi32>
+    """
+
+    with pytest.raises(Exception, match="operand at position 0 does not verify"):
+        run_filecheck(program, roundtrip=True, verify=True)
+
+
+def test_invalid_ir_missing_operands(
+    run_filecheck: Any,
+) -> None:
+    """Test that operations with missing operands are properly rejected."""
+    program = r"""
+    %result = "stablehlo.convert"() : () -> tensor<2x3xf64>
+    """
+
+    with pytest.raises(Exception, match="Expected 1 operand"):
+        run_filecheck(program, roundtrip=True, verify=True)
+
+
+def test_invalid_ir_trait_verification_failure(
+    run_filecheck: Any,
+) -> None:
+    """Test that operations that violate trait constraints are properly rejected."""
+    program = r"""
+    %tf32_2x3 = "test.op"() : () -> tensor<2x3xf32>
+    %tf64_3x2 = "test.op"() : () -> tensor<3x2xf64>
+
+    // This should fail verification due to the shape mismatch between operands
+    %complex = "stablehlo.complex"(%tf32_2x3, %tf64_3x2) : (tensor<2x3xf32>, tensor<3x2xf64>) -> tensor<2x3xcomplex<f32>>
+    """
+
+    with pytest.raises(Exception, match="requires the same shape"):
+        run_filecheck(program, roundtrip=True, verify=True)
+
+
+def test_invalid_ir_operand_result_shape_mismatch(
+    run_filecheck: Any,
+) -> None:
+    """Test that operations with operand vs result shape mismatches are properly rejected."""
+    program = r"""
+    %tf32_2x3 = "test.op"() : () -> tensor<2x3xf32>
+
+    // This should fail verification due to the shape mismatch between operand and result
+    %convert = "stablehlo.convert"(%tf32_2x3) : (tensor<2x3xf32>) -> tensor<3x2xf64>
+    """
+
+    with pytest.raises(
+        Exception,
+        match="all non-scalar operands/results must have the same shape and base type",
+    ):
+        run_filecheck(program, roundtrip=True, verify=True)
+
+
+def test_control_flow_operations(
+    run_filecheck: Any,
+) -> None:
+    """Test the control flow operations (IfOp, WhileOp, OptimizationBarrierOp)."""
+    program = r"""
+    // Test IfOp:
+
+    // CHECK: %[[pred:.*]] = "test.op"() : () -> tensor<i1>
+    %pred = "test.op"() : () -> tensor<i1>
+
+    // CHECK: %[[result:.*]] = "stablehlo.if"(%[[pred]]) ({
+    // CHECK: "stablehlo.return"(%[[pred]]) : (tensor<i1>) -> ()
+    // CHECK: }, {
+    // CHECK: "stablehlo.return"(%[[pred]]) : (tensor<i1>) -> ()
+    // CHECK: }) : (tensor<i1>) -> tensor<i1>
+    %result = "stablehlo.if"(%pred) ({
+        "stablehlo.return"(%pred) : (tensor<i1>) -> ()
+    }, {
+        "stablehlo.return"(%pred) : (tensor<i1>) -> ()
+    }) : (tensor<i1>) -> tensor<i1>
+
+    // Test WhileOp:
+
+    // CHECK: %[[init_i:.*]] = "test.op"() : () -> tensor<i64>
+    %init_i = "test.op"() : () -> tensor<i64>
+
+    // CHECK: %[[init_sum:.*]] = "test.op"() : () -> tensor<i64>
+    %init_sum = "test.op"() : () -> tensor<i64>
+
+    // CHECK: %[[ten:.*]] = "test.op"() : () -> tensor<i64>
+    %ten = "test.op"() : () -> tensor<i64>
+
+    // CHECK: %[[one:.*]] = "test.op"() : () -> tensor<i64>
+    %one = "test.op"() : () -> tensor<i64>
+
+    // CHECK: %[[results:.*]], %[[results_1:.*]] = "stablehlo.while"(%[[init_i]], %[[init_sum]]) ({
+    // CHECK: ^{{.*}}(%[[arg0:.*]] : tensor<i64>, %[[arg1:.*]] : tensor<i64>):
+    // CHECK: %[[cond:.*]] = stablehlo.compare LT, %[[arg0]], %[[ten]] : (tensor<i64>, tensor<i64>) -> tensor<i1>
+    // CHECK: "stablehlo.return"(%[[cond]]) : (tensor<i1>) -> ()
+    // CHECK: }, {
+    // CHECK: ^{{.*}}(%[[arg0_1:.*]] : tensor<i64>, %[[arg1_1:.*]] : tensor<i64>):
+    // CHECK: %[[new_sum:.*]] = "stablehlo.add"(%[[arg1_1]], %[[one]]) : (tensor<i64>, tensor<i64>) -> tensor<i64>
+    // CHECK: %[[new_i:.*]] = "stablehlo.add"(%[[arg0_1]], %[[one]]) : (tensor<i64>, tensor<i64>) -> tensor<i64>
+    // CHECK: "stablehlo.return"(%[[new_i]], %[[new_sum]]) : (tensor<i64>, tensor<i64>) -> ()
+    // CHECK: }) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
+    %results:2 = "stablehlo.while"(%init_i, %init_sum) ({
+    ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
+        %cond = "stablehlo.compare"(%arg0, %ten) {comparison_direction = #stablehlo<comparison_direction LT>} : (tensor<i64>, tensor<i64>) -> tensor<i1>
+        "stablehlo.return"(%cond) : (tensor<i1>) -> ()
+    }, {
+    ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
+        %new_sum = "stablehlo.add"(%arg1, %one) : (tensor<i64>, tensor<i64>) -> tensor<i64>
+        %new_i = "stablehlo.add"(%arg0, %one) : (tensor<i64>, tensor<i64>) -> tensor<i64>
+        "stablehlo.return"(%new_i, %new_sum) : (tensor<i64>, tensor<i64>) -> ()
+    }) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
+
+    // Test OptimizationBarrierOp:
+
+    // CHECK: %[[operand:.*]] = "test.op"() : () -> tensor<f32>
+    %operand = "test.op"() : () -> tensor<f32>
+
+    // CHECK: %[[result2:.*]] = "stablehlo.optimization_barrier"(%[[operand]]) : (tensor<f32>) -> tensor<f32>
+    %result2 = "stablehlo.optimization_barrier"(%operand) : (tensor<f32>) -> tensor<f32>
+    """
+
+    run_filecheck(program, roundtrip=True, verify=True)
+
+
+def test_data_movement_operations(
+    run_filecheck: Any,
+) -> None:
+    """Test all data movement operations."""
+    program = r"""
+    ////////////////// Setup test operations //////////////////
+    // CHECK: %[[input1:.*]] = "test.op"() : () -> tensor<3x2xi64>
+    %input1 = "test.op"() : () -> tensor<3x2xi64>
+
+    // CHECK: %[[input2:.*]] = "test.op"() : () -> tensor<1x2xi64>
+    %input2 = "test.op"() : () -> tensor<1x2xi64>
+
+    // CHECK: %[[operand:.*]] = "test.op"() : () -> tensor<2x3x4x2xi32>
+    %operand = "test.op"() : () -> tensor<2x3x4x2xi32>
+
+    // CHECK: %[[start_indices:.*]] = "test.op"() : () -> tensor<2x2x3x2xi64>
+    %start_indices = "test.op"() : () -> tensor<2x2x3x2xi64>
+
+    // CHECK: %[[reshape_input:.*]] = "test.op"() : () -> tensor<2xf32>
+    %reshape_input = "test.op"() : () -> tensor<2xf32>
+
+    // CHECK: %[[scatter_input:.*]] = "test.op"() : () -> tensor<2x3x4x2xi64>
+    %scatter_input = "test.op"() : () -> tensor<2x3x4x2xi64>
+
+    // CHECK: %[[scatter_indices:.*]] = "test.op"() : () -> tensor<2x2x3x2xi64>
+    %scatter_indices = "test.op"() : () -> tensor<2x2x3x2xi64>
+
+    // CHECK: %[[scatter_updates:.*]] = "test.op"() : () -> tensor<2x2x3x2x2xi64>
+    %scatter_updates = "test.op"() : () -> tensor<2x2x3x2x2xi64>
+
+    // CHECK: %[[slice_input:.*]] = "test.op"() : () -> tensor<3x4xi64>
+    %slice_input = "test.op"() : () -> tensor<3x4xi64>
+
+    // CHECK: %[[broadcast_input:.*]] = "test.op"() : () -> tensor<1x3xi32>
+    %broadcast_input = "test.op"() : () -> tensor<1x3xi32>
+
+    ////////////////// Test ConcatenateOp //////////////////
+    // CHECK: %concatenate = "stablehlo.concatenate"(%[[input1]], %[[input2]]) <{dimension = 0 : i64}> : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
+    %concatenate = "stablehlo.concatenate"(%input1, %input2) {dimension = 0 : i64} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
+
+    ////////////////// Test GatherOp //////////////////
+    // CHECK: %gather = "stablehlo.gather"(%[[operand]], %[[start_indices]])
+    // CHECK-SAME: dimension_numbers = #stablehlo.gather<
+    // CHECK-NEXT: offset_dims = [3, 4],
+    // CHECK-NEXT: collapsed_slice_dims = [1],
+    // CHECK-NEXT: operand_batching_dims = [0],
+    // CHECK-NEXT: start_indices_batching_dims = [1],
+    // CHECK-NEXT: start_index_map = [2, 1],
+    // CHECK-NEXT: index_vector_dim = 3
+    // CHECK-NEXT: slice_sizes = array<i64: 1, 1, 2, 2>, indices_are_sorted = false
+    %gather = "stablehlo.gather"(%operand, %start_indices) {
+        dimension_numbers = #stablehlo.gather<
+            offset_dims = [3, 4],
+            collapsed_slice_dims = [1],
+            operand_batching_dims = [0],
+            start_indices_batching_dims = [1],
+            start_index_map = [2, 1],
+            index_vector_dim = 3>,
+        slice_sizes = array<i64: 1, 1, 2, 2>,
+        indices_are_sorted = false
+    } : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
+
+    ////////////////// Test ReshapeOp //////////////////
+    // CHECK: %reshape = stablehlo.reshape %[[reshape_input]] : (tensor<2xf32>) -> tensor<1x2xf32>
+    %reshape = "stablehlo.reshape"(%reshape_input) : (tensor<2xf32>) -> tensor<1x2xf32>
+
+    ////////////////// Test ScatterOp //////////////////
+    // CHECK: %scatter = "stablehlo.scatter"(%[[scatter_input]], %[[scatter_indices]], %[[scatter_updates]])
+    // CHECK-SAME: scatter_dimension_numbers = #stablehlo.scatter<
+    // CHECK-NEXT: update_window_dims = [3, 4],
+    // CHECK-NEXT: inserted_window_dims = [1],
+    // CHECK-NEXT: input_batching_dims = [0],
+    // CHECK-NEXT: scatter_indices_batching_dims = [1],
+    // CHECK-NEXT: scatter_dims_to_operand_dims = [2, 1],
+    // CHECK-NEXT: index_vector_dim = 3
+    // CHECK-NEXT: indices_are_sorted = false, unique_indices = false
+    // CHECK-NEXT: ^[[bb0:.*]](%arg0 : tensor<i64>, %arg1 : tensor<i64>):
+    // CHECK-NEXT: %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
+    // CHECK-NEXT: "stablehlo.return"(%0) : (tensor<i64>) -> ()
+    %scatter = "stablehlo.scatter"(%scatter_input, %scatter_indices, %scatter_updates) ({
+    ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
+        %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
+        "stablehlo.return"(%0) : (tensor<i64>) -> ()
+    }) {
+        scatter_dimension_numbers = #stablehlo.scatter<
+            update_window_dims = [3, 4],
+            inserted_window_dims = [1],
+            input_batching_dims = [0],
+            scatter_indices_batching_dims = [1],
+            scatter_dims_to_operand_dims = [2, 1],
+            index_vector_dim = 3>,
+        indices_are_sorted = false,
+        unique_indices = false
+    } : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
+
+    ////////////////// Test SliceOp //////////////////
+    // CHECK: %slice = "stablehlo.slice"(%[[slice_input]])
+    // CHECK-SAME: start_indices = array<i64: 1, 0>,
+    // CHECK-SAME: limit_indices = array<i64: 3, 4>,
+    // CHECK-SAME: strides = array<i64: 1, 2>
+    // CHECK-SAME: : (tensor<3x4xi64>) -> tensor<2x2xi64>
+    %slice = "stablehlo.slice"(%slice_input) {
+        start_indices = array<i64: 1, 0>,
+        limit_indices = array<i64: 3, 4>,
+        strides = array<i64: 1, 2>
+    } : (tensor<3x4xi64>) -> tensor<2x2xi64>
+
+    ////////////////// Test BroadcastInDimOp //////////////////
+    // CHECK: %broadcast = stablehlo.broadcast_in_dim %[[broadcast_input]], dims = [2, 1] : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
+    %broadcast = "stablehlo.broadcast_in_dim"(%broadcast_input) {broadcast_dimensions = array<i64: 2, 1>} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
+
+    ////////////////// Test DynamicSliceOp //////////////////
+    // CHECK: %[[dyn_operand:.*]] = "test.op"() : () -> tensor<4x4xi32>
+    %dyn_operand = "test.op"() : () -> tensor<4x4xi32>
+
+    // CHECK: %[[start0:.*]] = "test.op"() : () -> tensor<i64>
+    %start0 = "test.op"() : () -> tensor<i64>
+
+    // CHECK: %[[start1:.*]] = "test.op"() : () -> tensor<i64>
+    %start1 = "test.op"() : () -> tensor<i64>
+
+    // CHECK: %dynamic_slice = "stablehlo.dynamic_slice"(%[[dyn_operand]], %[[start0]], %[[start1]])
+    // CHECK-SAME: slice_sizes = array<i64: 2, 3>
+    // CHECK-SAME: : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x3xi32>
+    %dynamic_slice = "stablehlo.dynamic_slice"(%dyn_operand, %start0, %start1) {
+        slice_sizes = array<i64: 2, 3>
+    } : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x3xi32>
+    """
+
+    run_filecheck(program, roundtrip=True, verify=True)
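For reference, the gather result shape used above can be reproduced by hand: the
batch dimensions of `start_indices` (every dim except `index_vector_dim`) are
interleaved with the slice sizes of the non-collapsed, non-batching operand
dimensions at the positions named by `offset_dims`. A small sketch (the helper
name is hypothetical, not part of the dialect):

    def gather_result_shape(
        start_indices_shape: list[int],
        index_vector_dim: int,
        offset_dims: list[int],
        offset_sizes: list[int],
    ) -> list[int]:
        """Interleave start_indices batch dims with the offset slice sizes."""
        batch = [d for i, d in enumerate(start_indices_shape) if i != index_vector_dim]
        rank = len(batch) + len(offset_sizes)
        result = [0] * rank
        for dim, size in zip(offset_dims, offset_sizes):
            result[dim] = size
        batch_iter = iter(batch)
        for i in range(rank):
            if i not in offset_dims:
                result[i] = next(batch_iter)
        return result

    # batch dims [2, 2, 3] + offset sizes [2, 2] at dims [3, 4] -> 2x2x3x2x2
    assert gather_result_shape([2, 2, 3, 2], 3, [3, 4], [2, 2]) == [2, 2, 3, 2, 2]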
+
+
+def test_invalid_slice_operations(
+    run_filecheck: Any,
+) -> None:
+    """Test invalid slice operations that should fail verification."""
+    program_slice_mismatch = r"""
+    // CHECK: %input = "test.op"() : () -> tensor<3x8xi64>
+    %input = "test.op"() : () -> tensor<3x8xi64>
+
+    // This should fail verification due to the mismatched array sizes
+    // CHECK: %slice = "stablehlo.slice"(%input) {start_indices = array<i64: 1, 2>, limit_indices = array<i64: 3, 8, 1>, strides = array<i64: 1, 2>} : (tensor<3x8xi64>) -> tensor<2x2xi64>
+    %slice = "stablehlo.slice"(%input) {
+        start_indices = array<i64: 1, 2>,
+        limit_indices = array<i64: 3, 8, 1>,
+        strides = array<i64: 1, 2>
+    } : (tensor<3x8xi64>) -> tensor<2x2xi64>
+    """
+
+    with pytest.raises(
+        Exception,
+        match="all of \\{start_indices, limit_indices, strides\\} must have the same size: got sizes 2, 3, 2",
+    ):
+        run_filecheck(program_slice_mismatch, roundtrip=True, verify=True)
+
+
+def test_invalid_slice_element_type_mismatch(
+    run_filecheck: Any,
+) -> None:
+    """Test that SliceOp rejects mismatched operand/result element types."""
+    program = r"""
+    // CHECK: %slice_input = "test.op"() : () -> tensor<3x4xi64>
+    %slice_input = "test.op"() : () -> tensor<3x4xi64>
+
+    // Mismatched element type: operand is i64, result is f32
+    %slice = "stablehlo.slice"(%slice_input) {
+        start_indices = array<i64: 1, 0>,
+        limit_indices = array<i64: 3, 4>,
+        strides = array<i64: 1, 2>
+    } : (tensor<3x4xi64>) -> tensor<2x2xf32>
+    """
+
+    # Expect verification failure due to the element type mismatch
+    with pytest.raises(
+        Exception, match="requires the same element type for all operands and results"
+    ):
+        run_filecheck(program, roundtrip=True, verify=True)
+
+
+def test_invalid_gather_element_type_mismatch(
+    run_filecheck: Any,
+) -> None:
+    """Test that GatherOp rejects mismatched operand/result element types."""
+    program = r"""
+    %operand = "test.op"() : () -> tensor<2x3x4x2xi32>
+    %start_indices = "test.op"() : () -> tensor<2x2x3x2xi64>
+
+    // Mismatched element type: operand is i32, result is f32
+    %gather_bad = "stablehlo.gather"(%operand, %start_indices) {
+        dimension_numbers = #stablehlo.gather<
+            offset_dims = [3, 4],
+            collapsed_slice_dims = [1],
+            operand_batching_dims = [0],
+            start_indices_batching_dims = [1],
+            start_index_map = [2, 1],
+            index_vector_dim = 3>,
+        slice_sizes = array<i64: 1, 1, 2, 2>,
+        indices_are_sorted = false
+    } : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xf32>
+    """
+
+    # Expect verification failure due to the element type mismatch between operand and result
+    with pytest.raises(
+        Exception, match=r"all of \{operand, result\} must have the same element type"
+    ):
+        run_filecheck(program, roundtrip=True, verify=True)
+
+
+def test_invalid_reshape_operations(
+    run_filecheck: Any,
+) -> None:
+    """Test invalid reshape operations that should fail verification."""
+    program_reshape_mismatch = r"""
+    %reshape_input = "test.op"() : () -> tensor<2xf32>
+
+    // This should fail verification due to the element count mismatch (2 != 4)
+    %reshape_bad = "stablehlo.reshape"(%reshape_input) : (tensor<2xf32>) -> tensor<2x2xf32>
+    """
+
+    with pytest.raises(Exception, match="number of output elements"):
+        run_filecheck(program_reshape_mismatch, roundtrip=True, verify=True)
+
+
+def test_invalid_broadcast_in_dim_operations(
+    run_filecheck: Any,
+) -> None:
+    """Test invalid broadcast_in_dim operations that should fail verification."""
+    # Test dims size mismatch.
+    program_broadcast_dims_size_mismatch = r"""
+    %broadcast_input = "test.op"() : () -> tensor<1x3xi32>
+
+    // dims has size 1, but operand rank is 2
+    %broadcast_bad = "stablehlo.broadcast_in_dim"(%broadcast_input) {broadcast_dimensions = array<i64: 1>} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
+    """
+
+    with pytest.raises(
+        Exception, match="broadcast_dimensions size .* does not match operand rank"
+    ):
+        run_filecheck(program_broadcast_dims_size_mismatch, roundtrip=True, verify=True)
+
+    # Test duplicate dims.
+    program_broadcast_duplicate_dims = r"""
+    %broadcast_input = "test.op"() : () -> tensor<1x3xi32>
+
+    // duplicate entries in broadcast_dimensions are not allowed
+    %broadcast_bad = "stablehlo.broadcast_in_dim"(%broadcast_input) {broadcast_dimensions = array<i64: 1, 1>} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
+    """
+
+    with pytest.raises(
+        Exception, match="broadcast_dimensions should not have duplicates"
+    ):
+        run_filecheck(program_broadcast_duplicate_dims, roundtrip=True, verify=True)
+
+    # Test dim index out of bounds.
+    program_broadcast_dim_oob = r"""
+    %broadcast_input = "test.op"() : () -> tensor<1x3xi32>
+
+    // result rank is 2, but dims contains 2 (out of bounds)
+    %broadcast_bad = "stablehlo.broadcast_in_dim"(%broadcast_input) {broadcast_dimensions = array<i64: 0, 2>} : (tensor<1x3xi32>) -> tensor<2x3xi32>
+    """
+
+    with pytest.raises(Exception, match="broadcast_dimensions contains invalid value"):
+        run_filecheck(program_broadcast_dim_oob, roundtrip=True, verify=True)
+
+    # Test operand dim not 1 and not equal to result dim.
+    program_broadcast_dim_mismatch = r"""
+    %broadcast_input = "test.op"() : () -> tensor<2x3xi32>
+
+    // operand[0] = 2, result[0] = 4; dims = [0, 2] -> mismatch on dim 0
+    %broadcast_bad = "stablehlo.broadcast_in_dim"(%broadcast_input) {broadcast_dimensions = array<i64: 0, 2>} : (tensor<2x3xi32>) -> tensor<4x3x2xi32>
+    """
+
+    with pytest.raises(
+        Exception,
+        match="size of operand dimension .* is not equal to 1 or size of result dimension",
+    ):
+        run_filecheck(program_broadcast_dim_mismatch, roundtrip=True, verify=True)
+
+
+def test_dynamism_operations(run_filecheck: Any) -> None:
+    """Test all dynamism operations."""
+    program = r"""
+    ////////////////// Setup //////////////////
+    // CHECK: %[[operand:.*]] = "test.op"() : () -> tensor<1x3xi64>
+    %operand = "test.op"() : () -> tensor<1x3xi64>
+
+    // CHECK: %[[out_dims:.*]] = "test.op"() : () -> tensor<3xi64>
+    %out_dims = "test.op"() : () -> tensor<3xi64>
+
+    ////////////////// Test DynamicBroadcastInDimOp //////////////////
+    // CHECK: %dynamic_bcast = stablehlo.dynamic_broadcast_in_dim %[[operand]], %[[out_dims]], dims = [2, 1] : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
+    %dynamic_bcast = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out_dims) {
+        broadcast_dimensions = array<i64: 2, 1>
+    } : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
+    """
+
+    run_filecheck(program, roundtrip=True, verify=True)
+
+
+def test_reduction_operations(run_filecheck: Any) -> None:
+    """Test all reduction operations."""
+    program = r"""
+    ////////////////// Setup //////////////////
+    // CHECK: %[[input:.*]] = "test.op"() : () -> tensor<1x6xi64>
+    %input = "test.op"() : () -> tensor<1x6xi64>
+
+    // CHECK: %[[init:.*]] = "test.op"() : () -> tensor<i64>
+    %init = "test.op"() : () -> tensor<i64>
+
+    ////////////////// Test ReduceOp //////////////////
+    // CHECK: %reduce = "stablehlo.reduce"(%[[input]], %[[init]]) <{dimensions = array<i64: 1>}> ({
+    // CHECK: ^[[bb0:.*]](%arg0 : tensor<i64>, %arg1 : tensor<i64>):
+    // CHECK: %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
+    // CHECK: "stablehlo.return"(%0) : (tensor<i64>) -> ()
+    // CHECK: }) : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
+    %reduce = "stablehlo.reduce"(%input, %init) ({
+    ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
+        %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
+        "stablehlo.return"(%0) : (tensor<i64>) -> ()
+    }) {dimensions = array<i64: 1>} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
+    """
+
+    run_filecheck(program, roundtrip=True, verify=True)
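As a sanity check on the reduction semantics above, the same computation can be
phrased in NumPy (illustrative only; the init value participates once per reduced
slice):

    import numpy as np

    x = np.arange(6, dtype=np.int64).reshape(1, 6)  # tensor<1x6xi64>
    init = np.int64(0)                              # the init_value operand
    result = init + np.add.reduce(x, axis=1)        # tensor<1xi64>
    assert result.shape == (1,) and result[0] == 15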
match=r"dimensions should not have duplicates"): + run_filecheck(program_dup_dims, roundtrip=True, verify=True) + + # Dimension out of range + program_dim_oob = r""" + %input = "test.op"() : () -> tensor<1x6xi64> + %init = "test.op"() : () -> tensor + + %reduce = "stablehlo.reduce"(%input, %init) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) {dimensions = array} : (tensor<1x6xi64>, tensor) -> tensor<1xi64> + """ + + with pytest.raises(Exception, match=r"dimensions contains an invalid value"): + run_filecheck(program_dim_oob, roundtrip=True, verify=True) + + # Input/init element type mismatch + program_elem_mismatch = r""" + %input = "test.op"() : () -> tensor<1x6xi64> + %init = "test.op"() : () -> tensor + + %reduce = "stablehlo.reduce"(%input, %init) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) {dimensions = array} : (tensor<1x6xi64>, tensor) -> tensor<1xi64> + """ + + with pytest.raises( + Exception, match=r"input and init_value must have the same element type" + ): + run_filecheck(program_elem_mismatch, roundtrip=True, verify=True) + + # Reducer wrong arity (expects 2 args per input; give 1) + program_wrong_arity = r""" + %input = "test.op"() : () -> tensor<1x6xi64> + %init = "test.op"() : () -> tensor + + %reduce = "stablehlo.reduce"(%input, %init) ({ + ^bb0(%acc: tensor): + "stablehlo.return"(%acc) : (tensor) -> () + }) {dimensions = array} : (tensor<1x6xi64>, tensor) -> tensor<1xi64> + """ + + with pytest.raises(Exception, match=r"reducer must take 2 arguments, got 1"): + run_filecheck(program_wrong_arity, roundtrip=True, verify=True) + + # Reducer arg wrong rank (should be 0D) + program_arg_rank = r""" + %input = "test.op"() : () -> tensor<1x6xi64> + %init = "test.op"() : () -> tensor + + %reduce = "stablehlo.reduce"(%input, %init) ({ + ^bb0(%arg0: tensor<2xi64>, %arg1: tensor<2xi64>): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi64> + "stablehlo.return"(%0) : (tensor<2xi64>) -> () + }) {dimensions = array} : (tensor<1x6xi64>, tensor) -> tensor<1xi64> + """ + + with pytest.raises(Exception, match=r"reducer arguments must be rank-0 tensors"): + run_filecheck(program_arg_rank, roundtrip=True, verify=True) + + # Reducer return wrong count + program_return_count = r""" + %input = "test.op"() : () -> tensor<1x6xi64> + %init = "test.op"() : () -> tensor + + %reduce = "stablehlo.reduce"(%input, %init) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + "stablehlo.return"() : () -> () + }) {dimensions = array} : (tensor<1x6xi64>, tensor) -> tensor<1xi64> + """ + + with pytest.raises( + Exception, match=r"reducer must return exactly one value per input" + ): + run_filecheck(program_return_count, roundtrip=True, verify=True) + + +def test_custom_call_basic(run_filecheck: Any) -> None: + """CustomCallOp minimal form without layouts should verify.""" + program = r""" + // CHECK: %[[ARG:.*]] = "test.op"() : () -> tensor<2x3xi32> + %arg = "test.op"() : () -> tensor<2x3xi32> + + // CHECK: %[[RES:.*]] = "stablehlo.custom_call"(%[[ARG]]) + // CHECK-SAME: call_target_name = "foo" + // CHECK-SAME: api_version = #stablehlo + %res = "stablehlo.custom_call"(%arg) { + call_target_name = "foo", + api_version = #stablehlo, + output_operand_aliases = [] + } : (tensor<2x3xi32>) -> tensor<2x3xi32> + """ + + run_filecheck(program, roundtrip=True, verify=True) + + +def 
+
+
+def test_custom_call_with_layouts(
+    run_filecheck: Any,
+) -> None:
+    """CustomCallOp with matching operand/result layouts should verify."""
+    program = r"""
+    // CHECK: %[[ARG:.*]] = "test.op"() : () -> tensor<2x3xi32>
+    %arg = "test.op"() : () -> tensor<2x3xi32>
+
+    // CHECK: %[[RES:.*]] = "stablehlo.custom_call"(%[[ARG]])
+    // CHECK-SAME: operand_layouts = [dense<[1, 0]> : tensor<2xindex>]
+    // CHECK-SAME: result_layouts = [dense<[1, 0]> : tensor<2xindex>]
+    %res = "stablehlo.custom_call"(%arg) {
+        call_target_name = "foo",
+        api_version = #stablehlo<api_version API_VERSION_ORIGINAL>,
+        operand_layouts = [dense<[1, 0]> : tensor<2xindex>],
+        result_layouts = [dense<[1, 0]> : tensor<2xindex>],
+        output_operand_aliases = []
+    } : (tensor<2x3xi32>) -> tensor<2x3xi32>
+    """
+
+    run_filecheck(program, roundtrip=True, verify=True)
+
+
+def test_custom_call_missing_result_layouts(
+    run_filecheck: Any,
+) -> None:
+    """Providing only operand_layouts should fail (must provide both or none)."""
+    program = r"""
+    %arg = "test.op"() : () -> tensor<2x3xi32>
+
+    %res = "stablehlo.custom_call"(%arg) {
+        call_target_name = "foo",
+        api_version = #stablehlo<api_version API_VERSION_ORIGINAL>,
+        operand_layouts = [dense<[1, 0]> : tensor<2xindex>],
+        output_operand_aliases = []
+    } : (tensor<2x3xi32>) -> tensor<2x3xi32>
+    """
+
+    with pytest.raises(
+        Exception,
+        match=r"either both operands and results or none",
+    ):
+        run_filecheck(program, roundtrip=True, verify=True)
+
+
+def test_custom_call_layouts_mismatch(
+    run_filecheck: Any,
+) -> None:
+    """The number of layouts must match the number of operands/results."""
+    program = r"""
+    %arg0 = "test.op"() : () -> tensor<2x3xi32>
+    %arg1 = "test.op"() : () -> tensor<2x3xi32>
+
+    %res = "stablehlo.custom_call"(%arg0, %arg1) {
+        call_target_name = "foo",
+        api_version = #stablehlo<api_version API_VERSION_ORIGINAL>,
+        operand_layouts = [dense<[1, 0]> : tensor<2xindex>],
+        result_layouts = [dense<[1, 0]> : tensor<2xindex>],
+        output_operand_aliases = []
+    } : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+    """
+
+    with pytest.raises(
+        Exception, match=r"Number of operands must match the number of operand layouts"
+    ):
+        run_filecheck(program, roundtrip=True, verify=True)
+
+
+def test_custom_call_incorrect_layout_perm(
+    run_filecheck: Any,
+) -> None:
+    """A layout must be a permutation of [0, rank)."""
+    program = r"""
+    %arg = "test.op"() : () -> tensor<2x3xi32>
+
+    %res = "stablehlo.custom_call"(%arg) {
+        call_target_name = "foo",
+        api_version = #stablehlo<api_version API_VERSION_ORIGINAL>,
+        operand_layouts = [dense<[0]> : tensor<1xindex>],
+        result_layouts = [dense<[0]> : tensor<1xindex>],
+        output_operand_aliases = []
+    } : (tensor<2x3xi32>) -> tensor<2x3xi32>
+    """
+
+    with pytest.raises(Exception, match=r"layout must be a permutation of \[0, 2\)"):
+        run_filecheck(program, roundtrip=True, verify=True)
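The permutation rule exercised by the test above is easy to state directly: a layout
for a rank-r tensor must contain each of 0..r-1 exactly once. A minimal sketch (the
helper name is hypothetical):

    def is_valid_layout(layout: list[int], rank: int) -> bool:
        """A layout must be a permutation of [0, rank)."""
        return sorted(layout) == list(range(rank))

    assert is_valid_layout([1, 0], 2)    # fine for tensor<2x3xi32>
    assert not is_valid_layout([0], 2)   # triggers "permutation of [0, 2)"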
+
+
+def test_custom_call_single_tuple_result_with_element_layouts(
+    run_filecheck: Any,
+) -> None:
+    """Single tuple result with element-wise layouts should verify (common case)."""
+    program = r"""
+    // CHECK: %[[ARG0:.*]] = "test.op"() : () -> tensor<2x3xi32>
+    // CHECK: %[[ARG1:.*]] = "test.op"() : () -> tensor<1xi32>
+    %arg0 = "test.op"() : () -> tensor<2x3xi32>
+    %arg1 = "test.op"() : () -> tensor<1xi32>
+
+    // CHECK: %[[RES:.*]] = "stablehlo.custom_call"(%[[ARG0]])
+    // CHECK-SAME: call_target_name = "foo"
+    // CHECK-SAME: api_version = #stablehlo<api_version API_VERSION_ORIGINAL>
+    // CHECK-SAME: operand_layouts = [dense<[1, 0]> : tensor<2xindex>]
+    // CHECK-SAME: result_layouts = [dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>]
+    %res = "stablehlo.custom_call"(%arg0) {
+        call_target_name = "foo",
+        api_version = #stablehlo<api_version API_VERSION_ORIGINAL>,
+        operand_layouts = [dense<[1, 0]> : tensor<2xindex>],
+        result_layouts = [dense<[1, 0]> : tensor<2xindex>, dense<[0]> : tensor<1xindex>],
+        output_operand_aliases = []
+    } : (tensor<2x3xi32>) -> tuple<tensor<2x3xi32>, tensor<1xi32>>
+    """
+    run_filecheck(program, roundtrip=True, verify=True)
+
+
+def test_invalid_dynamic_broadcast_in_dim_operations(
+    run_filecheck: Any,
+) -> None:
+    """Test invalid dynamic_broadcast_in_dim cases that should fail verification."""
+
+    # dims size mismatch (c2)
+    program_dims_size_mismatch = r"""
+    %operand = "test.op"() : () -> tensor<1x3xi64>
+    %out = "test.op"() : () -> tensor<3xi64>
+
+    %bad = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out) {
+        broadcast_dimensions = array<i64: 1>
+    } : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
+    """
+
+    with pytest.raises(
+        Exception,
+        match=r"broadcast_dimensions size \(1\) does not match operand rank \(2\)",
+    ):
+        run_filecheck(program_dims_size_mismatch, roundtrip=True, verify=True)
+
+    # result rank < operand rank (c3)
+    program_result_rank_too_small = r"""
+    %operand = "test.op"() : () -> tensor<1x3xi64>
+    %out = "test.op"() : () -> tensor<1xi64>
+
+    %bad = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out) {
+        broadcast_dimensions = array<i64: 0, 1>
+    } : (tensor<1x3xi64>, tensor<1xi64>) -> tensor<3xi64>
+    """
+
+    with pytest.raises(
+        Exception, match=r"result rank \(1\) is less than operand rank \(2\)"
+    ):
+        run_filecheck(program_result_rank_too_small, roundtrip=True, verify=True)
+
+    # duplicate dims (c4)
+    program_duplicate_dims = r"""
+    %operand = "test.op"() : () -> tensor<1x3xi64>
+    %out = "test.op"() : () -> tensor<2xi64>
+
+    %bad = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out) {
+        broadcast_dimensions = array<i64: 1, 1>
+    } : (tensor<1x3xi64>, tensor<2xi64>) -> tensor<2x3xi64>
+    """
+
+    with pytest.raises(
+        Exception, match=r"broadcast_dimensions should not have duplicates"
+    ):
+        run_filecheck(program_duplicate_dims, roundtrip=True, verify=True)
+
+    # dim index out of bounds (c5 bounds)
+    program_dim_oob = r"""
+    %operand = "test.op"() : () -> tensor<1x3xi64>
+    %out = "test.op"() : () -> tensor<2xi64>
+
+    %bad = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out) {
+        broadcast_dimensions = array<i64: 0, 2>
+    } : (tensor<1x3xi64>, tensor<2xi64>) -> tensor<2x3xi64>
+    """
+
+    with pytest.raises(
+        Exception,
+        match=r"broadcast_dimensions contains invalid value 2 for result with rank 2",
+    ):
+        run_filecheck(program_dim_oob, roundtrip=True, verify=True)
+
+    # per-dimension size compatibility (c5 compatibility)
+    program_dim_incompatible = r"""
+    %operand = "test.op"() : () -> tensor<2x3xi32>
+    %out = "test.op"() : () -> tensor<3xi64>
+
+    %bad = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out) {
+        broadcast_dimensions = array<i64: 0, 1>
+    } : (tensor<2x3xi32>, tensor<3xi64>) -> tensor<4x3x2xi32>
+    """
+
+    with pytest.raises(
+        Exception,
+        match=r"size of operand dimension 0 \(2\) is not compatible with size of result dimension 0 \(4\)",
+    ):
+        run_filecheck(program_dim_incompatible, roundtrip=True, verify=True)
+
+    # output_dimensions length incompatible with result rank when static (c7)
+    program_outlen_mismatch = r"""
+    %operand = "test.op"() : () -> tensor<1x3xi64>
+    %out = "test.op"() : () -> tensor<2xi64>
+
+    %bad = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out) {
+        broadcast_dimensions = array<i64: 2, 1>
+    } : (tensor<1x3xi64>, tensor<2xi64>) -> tensor<2x3x2xi64>
+    """
+
+    with pytest.raises(
+        Exception,
+        match=r"length of output_dimensions \(2\) is not compatible with result rank \(3\)",
+    ):
+        run_filecheck(program_outlen_mismatch, roundtrip=True, verify=True)
+
+    # duplicate expansion hints across both lists (c8)
+    program_dup_hints = r"""
+    %operand = "test.op"() : () -> tensor<1x1xi64>
+    %out = "test.op"() : () -> tensor<2xi64>
+
+    %bad = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out) {
+        broadcast_dimensions = array<i64: 0, 1>,
+        known_expanding_dimensions = array<i64: 0>,
+        known_nonexpanding_dimensions = array<i64: 0>
+    } : (tensor<1x1xi64>, tensor<2xi64>) -> tensor<2x1xi64>
+    """
+
+    with pytest.raises(
+        Exception, match=r"duplicate expansion hint for at least one operand dimension"
+    ):
+        run_filecheck(program_dup_hints, roundtrip=True, verify=True)
+
+    # hint refers to an invalid operand dimension (c9/c10)
+    program_hint_oob = r"""
+    %operand = "test.op"() : () -> tensor<1x3xi64>
+    %out = "test.op"() : () -> tensor<2xi64>
+
+    %bad = "stablehlo.dynamic_broadcast_in_dim"(%operand, %out) {
+        broadcast_dimensions = array<i64: 0, 1>,
+        known_expanding_dimensions = array<i64: 5>
+    } : (tensor<1x3xi64>, tensor<2xi64>) -> tensor<2x3xi64>
+    """
+
+    with pytest.raises(
+        Exception,
+        match=r"hint for expanding dimension 5 does not refer to a valid operand dimension",
+    ):
+        run_filecheck(program_hint_oob, roundtrip=True, verify=True)
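The per-dimension compatibility rule checked throughout these dynamic_broadcast_in_dim
tests reduces to a one-liner: an operand dimension is compatible with the result
dimension it maps to when it is 1 (expandable) or equal in size. A sketch (the helper
name is hypothetical):

    def dim_compatible(operand_dim: int, result_dim: int) -> bool:
        """Mirror of the "is not compatible" verifier diagnostic."""
        return operand_dim == 1 or operand_dim == result_dim

    assert dim_compatible(1, 4)      # size-1 dims may expand
    assert dim_compatible(3, 3)      # equal sizes pass through
    assert not dim_compatible(2, 4)  # the failing case tested above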