Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/gt4py/next/iterator/ir_utils/domain_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ class SymbolicRange:
stop: itir.Expr

def translate(self, distance: int) -> SymbolicRange:
return SymbolicRange(im.plus(self.start, distance), im.plus(self.stop, distance))
start = im.plus(self.start, distance)
# TODO(tehrengruber): temporary solution to avoid oob-access without concat_where
start = im.call("maximum")(0, start)
return SymbolicRange(start, im.plus(self.stop, distance))


@dataclasses.dataclass
Expand Down
9 changes: 8 additions & 1 deletion src/gt4py/next/iterator/transforms/fuse_as_fieldop.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,15 @@ def apply(
if not uids:
uids = eve_utils.UIDGenerator()

return cls(uids=uids, enabled_transformations=enabled_transformations).visit(
new_node = cls(uids=uids, enabled_transformations=enabled_transformations).visit(
node, within_set_at_expr=within_set_at_expr
)
new_node = type_inference.infer(
new_node,
offset_provider_type=offset_provider_type,
allow_undeclared_symbols=allow_undeclared_symbols,
)
return new_node

def transform_fuse_make_tuple(self, node: itir.Node, **kwargs):
if not cpm.is_call_to(node, "make_tuple"):
Expand Down Expand Up @@ -429,6 +435,7 @@ def transform_fuse_as_fieldop(self, node: itir.Node, **kwargs):
return None

def transform_inline_let_vars_opcount_preserving(self, node: itir.Node, **kwargs):
return None
# when multiple `as_fieldop` calls are fused that use the same argument, this argument
# might become referenced once only. In order to be able to continue fusing such arguments
# try inlining here.
Expand Down
6 changes: 4 additions & 2 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,10 @@ def _transform_by_pattern(
# able to eliminate all tuples, e.g., by propagating the scalar ifs to the top-level
# of a SetAt, the CollapseTuple pass will eliminate most of this cases.
if isinstance(domain, tuple):
flattened_domains: tuple[domain_utils.SymbolicDomain] = (
next_utils.flatten_nested_tuple(domain) # type: ignore[assignment] # mypy not smart enough
flattened_domains: tuple[domain_utils.SymbolicDomain] = tuple(
domain
for domain in next_utils.flatten_nested_tuple(domain)
if domain is not infer_domain.DomainAccessDescriptor.NEVER # type: ignore[assignment] # mypy not smart enough
)
if not all(d == flattened_domains[0] for d in flattened_domains):
raise NotImplementedError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from gt4py.next.iterator.transforms import trace_shifts
from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda
from gt4py.next.iterator.transforms.inline_lifts import InlineLifts
from gt4py.next.iterator.type_system import inference as type_inference


def is_center_derefed_only(node: itir.Node) -> bool:
Expand Down Expand Up @@ -95,17 +96,21 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs):
for i, (param, arg) in enumerate(zip(node.fun.params, node.args)):
if cpm.is_applied_lift(arg) and is_center_derefed_only(param):
eligible_params[i] = True
bound_arg_evaluator = self.uids.sequential_id(prefix="_icdlv")
capture_lift = im.promote_to_const_iterator(im.call(bound_arg_evaluator)())
trace_shifts.copy_recorded_shifts(from_=param, to=capture_lift)
new_args.append(capture_lift)
# since we deref an applied lift here we can (but don't need to) immediately
# inline
evaluators[bound_arg_evaluator] = im.lambda_()(
bound_arg_evaluator_name = self.uids.sequential_id(prefix="__icdlv")
bound_arg_evaluator = im.lambda_()(
InlineLifts(flags=InlineLifts.Flag.INLINE_DEREF_LIFT).visit(
im.deref(arg), recurse=False
)
)
capture_lift = im.promote_to_const_iterator(
im.call(im.ref(bound_arg_evaluator_name, bound_arg_evaluator.type))()
)
trace_shifts.copy_recorded_shifts(from_=param, to=capture_lift)

new_args.append(capture_lift)
# since we deref an applied lift here we can (but don't need to) immediately
# inline
evaluators[bound_arg_evaluator_name] = bound_arg_evaluator
else:
new_args.append(arg)

Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/type_system/type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def is_compatible_type(type_a: ts.TypeSpec, type_b: ts.TypeSpec) -> bool:
is_compatible &= is_compatible_type(arg_a, arg_b)
is_compatible &= is_compatible_type(type_a.returns, type_b.returns)
else:
is_compatible &= is_concretizable(type_a, type_b)
is_compatible &= is_concretizable(type_a, type_b) or is_concretizable(type_b, type_a)

return is_compatible

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,8 @@ def test_laplace(offset_provider):
im.deref(im.shift("Joff", -1)("arg0")),
)
)
domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (0, 7)})
expected_domains = {"in_field1": {IDim: (-1, 12), JDim: (-1, 8)}}
domain = im.domain(common.GridType.CARTESIAN, {IDim: (1, 11), JDim: (1, 7)})
expected_domains = {"in_field1": {IDim: (0, 12), JDim: (0, 8)}}

testee, expected = setup_test_as_fieldop(stencil, domain)
run_test_expr(testee, expected, domain, expected_domains, offset_provider)
Expand All @@ -238,10 +238,10 @@ def test_shift_x_y_two_inputs(offset_provider):
im.deref(im.shift("Joff", 1)("arg1")),
)
)
domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (0, 7)})
domain = im.domain(common.GridType.CARTESIAN, {IDim: (1, 11), JDim: (0, 7)})
expected_domains = {
"in_field1": {IDim: (-1, 10), JDim: (0, 7)},
"in_field2": {IDim: (0, 11), JDim: (1, 8)},
"in_field1": {IDim: (0, 10), JDim: (0, 7)},
"in_field2": {IDim: (1, 11), JDim: (1, 8)},
}
testee, expected = setup_test_as_fieldop(
stencil,
Expand All @@ -257,9 +257,9 @@ def test_shift_x_y_two_inputs_literal(offset_provider):
im.deref(im.shift("Joff", 1)("arg1")),
)
)
domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (0, 7)})
domain = im.domain(common.GridType.CARTESIAN, {IDim: (1, 11), JDim: (0, 7)})
expected_domains = {
"in_field1": {IDim: (-1, 10), JDim: (0, 7)},
"in_field1": {IDim: (0, 10), JDim: (0, 7)},
}
testee, expected = setup_test_as_fieldop(
stencil,
Expand All @@ -279,11 +279,11 @@ def test_shift_x_y_z_three_inputs(offset_provider):
im.deref(im.shift("Koff", -1)("arg2")),
)
)
domain_dict = {IDim: (0, 11), JDim: (0, 7), KDim: (0, 3)}
domain_dict = {IDim: (0, 11), JDim: (0, 7), KDim: (1, 3)}
expected_domains = {
"in_field1": {IDim: (1, 12), JDim: (0, 7), KDim: (0, 3)},
"in_field2": {IDim: (0, 11), JDim: (1, 8), KDim: (0, 3)},
"in_field3": {IDim: (0, 11), JDim: (0, 7), KDim: (-1, 2)},
"in_field1": {IDim: (1, 12), JDim: (0, 7), KDim: (1, 3)},
"in_field2": {IDim: (0, 11), JDim: (1, 8), KDim: (1, 3)},
"in_field3": {IDim: (0, 11), JDim: (0, 7), KDim: (0, 2)},
}
testee, expected = setup_test_as_fieldop(
stencil,
Expand Down Expand Up @@ -329,7 +329,7 @@ def test_nested_stencils(offset_provider):
tmp = im.as_fieldop(inner_stencil)(im.ref("in_field1"), im.ref("in_field2"))
testee = im.as_fieldop(stencil)(im.ref("in_field1"), tmp)

domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (0, 7)})
domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11), JDim: (1, 7)})
domain_inner = translate_domain(domain, {"Ioff": 0, "Joff": -1}, offset_provider)

expected_inner = im.as_fieldop(inner_stencil, domain_inner)(
Expand All @@ -338,7 +338,7 @@ def test_nested_stencils(offset_provider):
expected = im.as_fieldop(stencil, domain)(im.ref("in_field1"), expected_inner)

expected_domains = {
"in_field1": im.domain(common.GridType.CARTESIAN, {IDim: (1, 12), JDim: (-1, 7)}),
"in_field1": im.domain(common.GridType.CARTESIAN, {IDim: (1, 12), JDim: (0, 7)}),
"in_field2": translate_domain(domain, {"Ioff": 0, "Joff": -2}, offset_provider),
}
actual_call, actual_domains = infer_domain.infer_expr(
Expand Down Expand Up @@ -515,9 +515,9 @@ def test_cond(offset_provider):

testee = im.if_(cond, field_1, field_2)

domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)})
domain = im.domain(common.GridType.CARTESIAN, {IDim: (2, 13)})
domain_tmp = translate_domain(domain, {"Ioff": -1}, offset_provider)
expected_domains_dict = {"in_field1": {IDim: (0, 12)}, "in_field2": {IDim: (-2, 12)}}
expected_domains_dict = {"in_field1": {IDim: (2, 14)}, "in_field2": {IDim: (0, 14)}}
expected_tmp2 = im.as_fieldop(tmp_stencil2, domain_tmp)(
im.ref("in_field1"), im.ref("in_field2")
)
Expand Down Expand Up @@ -731,7 +731,7 @@ def test_nested_let_args(offset_provider):
),
)(premap_field("inner", "Ioff", -1))

domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)})
domain = im.domain(common.GridType.CARTESIAN, {IDim: (2, 11)})
domain_m1 = translate_domain(domain, {"Ioff": -1}, offset_provider)
domain_m2 = translate_domain(domain, {"Ioff": -2}, offset_provider)

Expand All @@ -754,9 +754,9 @@ def test_program_let(offset_provider):
let_tmp = im.let("inner", premap_field("outer", "Ioff", -1))(premap_field("inner", "Ioff", -1))
as_fieldop = im.as_fieldop(stencil_tmp)(im.ref("tmp"))

domain_lm2_rm1 = im.domain(common.GridType.CARTESIAN, {IDim: (-2, 10)})
domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)})
domain_lm1 = im.domain(common.GridType.CARTESIAN, {IDim: (-1, 11)})
domain_lm2_rm1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10)})
domain = im.domain(common.GridType.CARTESIAN, {IDim: (2, 11)})
domain_lm1 = im.domain(common.GridType.CARTESIAN, {IDim: (1, 11)})

params = [im.sym(name) for name in ["in_field", "out_field", "outer"]]

Expand Down
Loading