diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index a7062f2e1c..7604cadf7c 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -212,12 +212,15 @@ def visit_FunCall(self, node: itir.FunCall) -> itir.FunCall: assert isinstance(node.args[0], itir.FunCall) first_axis_literal = node.args[0].args[0] assert isinstance(first_axis_literal, itir.AxisLiteral) - if first_axis_literal.kind == itir.DimensionKind.VERTICAL: - assert len(node.args) == 2 - assert isinstance(node.args[1], itir.FunCall) - assert isinstance(node.args[1].args[0], itir.AxisLiteral) - assert node.args[1].args[0].kind == itir.DimensionKind.HORIZONTAL - return itir.FunCall(fun=node.fun, args=[node.args[1], node.args[0]]) + if len(node.args) <= 2: + if len(node.args) == 2 and first_axis_literal.kind == itir.DimensionKind.VERTICAL: + assert isinstance(node.args[1], itir.FunCall) + assert isinstance(node.args[1].args[0], itir.AxisLiteral) + assert node.args[1].args[0].kind == itir.DimensionKind.HORIZONTAL + return itir.FunCall(fun=node.fun, args=[node.args[1], node.args[0]]) + return node + else: + raise NotImplementedError("Only up to two dimensional domains are supported.") return node @classmethod diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 647a94c8a2..bc0d84f34c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -64,6 +64,14 @@ def testee(a: cases.IJKField) -> cases.IJKField: cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a) +def test_copy_vertical(unstructured_case_3d): + @gtx.field_operator + def testee(a: cases.KField) -> cases.KField: + return a + + cases.verify_with_default_data(unstructured_case_3d, testee, ref=lambda a: a) + + @pytest.mark.uses_tuple_returns def test_multicopy(cartesian_case): @gtx.field_operator