Skip to content
17 changes: 8 additions & 9 deletions src/jace/translator/jaxpr_translator_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ class JaxprTranslationBuilder:
- it lacks the special `__return` variable,
- the `arg_names` parameter is not set,
- for all scalar values a `Scalar` SDFG variable is used, thus they cannot
be used for return values,
be used for returning values,
- for every transient there is exactly one access node that writes to it,
except the name of the array starts with `__jace_mutable_`, which can
be written to multiple times.
except if the name of the array starts with `__jace_mutable_`, in which case
it can be written to multiple times.

For these reasons the SDFG is not directly usable, and further manipulations
have to be performed. Especially, DaCe's validation function will fail and
Expand Down Expand Up @@ -502,7 +502,8 @@ def _allocate_translation_ctx(
@property
def _ctx(self) -> TranslationContext:
"""Returns the currently active translation context."""
assert len(self._ctx_stack) != 0, "No context is active."
if not self.is_allocated():
raise RuntimeError("The context is not allocated.")
return self._ctx_stack[-1]

def _clear_translation_ctx(self) -> TranslationContext | None:
Expand Down Expand Up @@ -580,10 +581,9 @@ def _translate_single_eqn(self, eqn: jax_core.JaxprEqn) -> None:
prev_terminal_state,
new_sdfg_term_state,
)
self._ctx.validate()

# Modify terminal root state of 'self'
self._ctx.terminal_state = new_sdfg_term_state
self._ctx.validate()

def _translate_jaxpr_internal(self, jaxpr: jax_core.ClosedJaxpr) -> TranslationContext:
"""
Expand Down Expand Up @@ -712,7 +712,7 @@ def _propagate_memlets_in_new_states(
]

while nodes_to_process:
currently_processing = nodes_to_process.pop(-1)
currently_processing = nodes_to_process.pop()
if (
self.sdfg.out_degree(currently_processing) == 0
and currently_processing != new_terminal_state
Expand Down Expand Up @@ -790,7 +790,7 @@ def __init__(self, name: str | None, jaxpr: jax_core.ClosedJaxpr) -> None:
self.terminal_state = self.start_state
self.jaxpr = jaxpr

def validate(self) -> bool:
def validate(self) -> None:
"""
Validate internal state of `self`.

Expand Down Expand Up @@ -829,4 +829,3 @@ def validate(self) -> bool:
self.sdfg,
None,
)
return True
86 changes: 43 additions & 43 deletions src/jace/translator/mapped_operation_base_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#
# SPDX-License-Identifier: BSD-3-Clause

"""Module containing all translators related to arithmetic logical operations."""
"""Module implementing the `MappedOperationTranslatorBase` helper class."""

from __future__ import annotations

Expand Down Expand Up @@ -37,11 +37,11 @@ class MappedOperationTranslatorBase(translator.PrimitiveTranslator):
```
where `__in*` are the connector names of the Tasklet and `__out` is the
output connector. For problems such as this, the SDFG API provides the
`SDFGState.add_mapped_tasklet()` function, however, in most cases it can not
be directly used, for various reasons. Thus this class acts like a
convenience wrapper around it.
`SDFGState.add_mapped_tasklet()` function. However, because the function
operates on a very low level and is very verbose to use, this class acts
as a convenience wrapper around it.

To use this class a user has to overwrite the `write_tasklet_code()` function.
To use this class a user has to define the abstract `write_tasklet_code()` method.
This function generates the entire code that should be put into the Tasklet,
include the assignment to `__out`. If needed the translator will perform
literal substitution on the returned code and broadcast the inputs to match
Expand All @@ -51,7 +51,7 @@ class MappedOperationTranslatorBase(translator.PrimitiveTranslator):
to generate custom input Memlets, such as adding an offset.

Args:
primitive_name: The name of the primitive `self` should bind to.
primitive_name: The name of the primitive `self` should bind to.

Note:
This class will always generate a mapped Tasklet, even if a scalar is handled.
Expand All @@ -78,7 +78,7 @@ def __call__(
"""
Create the mapped Tasklet.

The function will create the map ranges and based on the shape of the
The function will create the map ranges based on the shape of the
output array. It will then call `make_input_memlets()` to get the input
Memlets. After that it calls `write_tasklet_code()` to get the Tasklet
code and perform literal substitution by forwarding it to
Expand All @@ -88,7 +88,7 @@ def __call__(
For a description of the arguments see `PrimitiveTranslatorCallable`.
"""
assert len(out_var_names) == 1
if util.get_jax_var_shape(eqn.outvars[0]) != ():
if util.get_jax_var_shape(eqn.outvars[0]):
tskl_ranges: list[tuple[str, str]] = [
(f"__i{dim}", f"0:{N}")
for dim, N in enumerate(util.get_jax_var_shape(eqn.outvars[0]))
Expand Down Expand Up @@ -130,20 +130,20 @@ def write_tasklet_code(
eqn: jax_core.JaxprEqn,
) -> str:
"""
Return the (Python) code that should be put inside the Tasklet.
Return the Python code that should be put inside the Tasklet.

This also includes the assignment statement, i.e. `__out`.
However, the base will do literal substitution on the returned object.

Args:
tskl_ranges: List of pairs used as map parameter, first element
tskl_ranges: List of pairs used as map parameter, first element
is the name iteration index of the dimension, second is its range.
in_var_names: The list of SDFG variables used as input, `None` if literal.
eqn: The equation.
in_var_names: The list of SDFG variables used as input, `None` if literal.
eqn: The equation.
"""
...

def make_input_memlets( # noqa: PLR6301 # Subclasses might need them.
def make_input_memlets( # noqa: PLR6301 [no-self-use] # Subclasses might need them.
self,
tskl_ranges: Sequence[tuple[str, str]],
in_var_names: Sequence[str | None],
Expand All @@ -156,13 +156,13 @@ def make_input_memlets( # noqa: PLR6301 # Subclasses might need them.
that is used to connect it to the Map entry node.

Args:
tskl_ranges: List of pairs used as map parameter, first element
tskl_ranges: List of pairs used as map parameter, first element
is the name iteration index of the dimension, second is its range
in_var_names: The list of SDFG variables used as input, `None` if literal.
eqn: The equation object.
in_var_names: The list of SDFG variables used as input, `None` if literal.
eqn: The equation object.
"""
out_shp = tuple(util.get_jax_var_shape(eqn.outvars[0])) # Shape of the output
out_rank = len(out_shp)
out_shape = tuple(util.get_jax_var_shape(eqn.outvars[0]))
out_rank = len(out_shape)
if any(len(util.get_jax_var_shape(invar)) not in {0, out_rank} for invar in eqn.invars):
raise NotImplementedError(
f"'MappedOperationTranslatorBase' Inputs must have the same rank as the output! "
Expand All @@ -171,44 +171,44 @@ def make_input_memlets( # noqa: PLR6301 # Subclasses might need them.

# Now we will generate the input Memlets.
tskl_inputs: dict[str, dace.Memlet] = {}
for i, (in_var_name, inp_shp) in enumerate(
for i, (in_var_name, in_shape) in enumerate(
zip(in_var_names, (util.get_jax_var_shape(invar) for invar in eqn.invars))
):
if in_var_name is None: # Input is a literal: No Memlet needed
continue

if inp_shp == (): # Scalars
tskl_inputs[f"__in{i}"] = dace.Memlet.simple(in_var_name, "0") # Scalar
continue

# We have to to broadcasting (combine yes and no together)
dims_to_bcast: Sequence[int] = [dim for dim in range(out_rank) if inp_shp[dim] == 1]
tskl_inputs[f"__in{i}"] = dace.Memlet.simple(
in_var_name,
", ".join(
("0" if i in dims_to_bcast else it_var)
for i, (it_var, _) in enumerate(tskl_ranges)
),
)
if in_var_name is None:
pass

elif in_shape == ():
tskl_inputs[f"__in{i}"] = dace.Memlet.simple(in_var_name, "0")

else:
dims_to_bcast = [
dim for dim in range(out_rank) if in_shape[dim] == 1 and out_shape[dim] != 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional suggestion:

Suggested change
dim for dim in range(out_rank) if in_shape[dim] == 1 and out_shape[dim] != 1
dim for dim in range(out_rank) if in_shape[dim] == 1 != out_shape[dim]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not know that this is possible, however, I do not think that it is very readable the a < x < b is useful but here I do not like it that much,

]
tskl_inputs[f"__in{i}"] = dace.Memlet.simple(
in_var_name,
", ".join(
("0" if i in dims_to_bcast else it_var)
for i, (it_var, _) in enumerate(tskl_ranges)
),
)
return tskl_inputs

def literal_substitution( # noqa: PLR6301 # Subclasses might need it.
def literal_substitution( # noqa: PLR6301 [no-self-use] # Subclasses might need it.
self, tskl_code: str, in_var_names: Sequence[str | None], eqn: jax_core.JaxprEqn
) -> str:
"""
Perform literal substitution on the proto Tasklet code `tskl_code`.

Args:
tskl_code: The proto Tasklet code with literal.
in_var_names: The list of SDFG variables used as input.
eqn: The equation.
tskl_code: The proto Tasklet code with literal.
in_var_names: The list of SDFG variables used as input.
eqn: The equation.

Note:
It is allowed but not recommended to override this function.
"""
for i, in_var_name in enumerate(in_var_names):
if in_var_name is not None:
continue
t_val = util.get_jax_literal_value(eqn.invars[i])
tskl_code = tskl_code.replace(f"__in{i}", str(t_val))
if in_var_name is None:
t_val = util.get_jax_literal_value(eqn.invars[i])
tskl_code = tskl_code.replace(f"__in{i}", str(t_val))
return tskl_code
51 changes: 49 additions & 2 deletions src/jace/translator/post_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

if TYPE_CHECKING:
from dace.sdfg import nodes as dace_nodes
from jax import core as jax_core

from jace import translator

Expand Down Expand Up @@ -271,7 +272,8 @@ def add_nested_sdfg(
will first pass it to `finalize_translation_context()` and operates on the
return values. This means that `child_ctx` will be modified in place, and
a copy will be added to `parent_ctx`.
It is highly recommended that `state` is empty.
It is highly recommended that `state` is empty, this makes subsequent
inlining of the nested SDFG simpler.
"""
if child_ctx.sdfg.free_symbols:
raise NotImplementedError("Symbol Mapping is not implemented.")
Expand All @@ -298,7 +300,6 @@ def add_nested_sdfg(
nested_sdfg: dace_nodes.NestedSDFG = state.add_nested_sdfg(
sdfg=final_child_ctx.sdfg,
parent=parent_ctx.sdfg,
# Bug in DaCe must be a set.
inputs=set(final_child_ctx.input_names),
outputs=set(final_child_ctx.output_names),
)
Expand Down Expand Up @@ -326,3 +327,49 @@ def add_nested_sdfg(
)

return nested_sdfg


def promote_literals_to_constants(
builder: translator.JaxprTranslationBuilder,
var_names: Sequence[str | None],
jax_vars: Sequence[jax_core.Atom],
name_pattern: str,
) -> list[str]:
"""
Promotes all literals in `var_names` to DaCe constants and add them to the SDFG.

The function assumes that `var_names` are the SDFG variables equivalents of
`jax_vars`, as by convention `None` indicates a literal. The function will create
a constant for each literal and return `var_names` cleared of all literals.
For naming the variables the function will use `name_pattern`.

Args:
builder: The builder that is used for translation.
var_names: Names of the SDFG variables, `None` indicates a literal.
jax_vars: The JAX variables, in the same order than `var_names`.
name_pattern: A pattern to generate a unique name for the variables.

Todo:
Is a constant the right idea or should we generate a symbol?
"""
promoted_var_names: list[str] = []
for i, var_name in enumerate(var_names):
if var_name is None:
promoted_var_name = f"__const_{name_pattern}_literal_promotion_{i}"
jax_var = jax_vars[i]
promoted_jace_var = util.JaCeVar.from_atom(
jax_var=jax_var,
name=promoted_var_name,
)
builder.add_array(promoted_jace_var)
builder.sdfg.add_constant(
promoted_var_name,
util.get_jax_literal_value(jax_var),
builder.arrays[promoted_var_name],
)

else:
# Already an SDFG variable, so nothing to do.
promoted_var_name = var_name
promoted_var_names.append(promoted_var_name)
return promoted_var_names
2 changes: 1 addition & 1 deletion src/jace/translator/primitive_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __call__(
Args:
builder: The builder object of the translation.
in_var_names: List of the names of the arrays created inside the
SDFG for the inpts or `None` in case of a literal.
SDFG for the inputs or `None` in case of a literal.
out_var_names: List of the names of the arrays created inside the
SDFG for the outputs.
eqn: The JAX primitive that should be translated.
Expand Down
25 changes: 13 additions & 12 deletions src/jace/translator/primitive_translators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,34 @@
LogicalOperationTranslator,
)
from .broadcast_in_dim_translator import BroadcastInDimTranslator
from .concatenate_translator import ConcatenateTranslator
from .concatenate_translator import concatenate_translator
from .conditions import condition_translator
from .convert_element_type_translator import ConvertElementTypeTranslator
from .copy_translator import CopyTranslator, DevicePutTranslator
from .gather_translator import GatherTranslator
from .copy_translator import copy_translator, device_put_translator
from .gather_translator import gather_translator
from .iota_translator import IotaTranslator
from .pjit_translator import PJITTranslator
from .reshape_translator import ReshapeTranslator
from .pjit_translator import pjit_translator
from .reshape_translator import reshape_translator
from .select_n_translator import SelectNTranslator
from .slicing import SlicingTranslator
from .slicing import SlicingTranslator, dynamic_slicing_translator
from .squeeze_translator import SqueezeTranslator


__all__ = [
"ArithmeticOperationTranslator",
"BroadcastInDimTranslator",
"ConcatenateTranslator",
"ConvertElementTypeTranslator",
"CopyTranslator",
"DevicePutTranslator",
"GatherTranslator",
"IotaTranslator",
"LogicalOperationTranslator",
"PJITTranslator",
"ReshapeTranslator",
"SelectNTranslator",
"SlicingTranslator",
"SqueezeTranslator",
"concatenate_translator",
"condition_translator",
"copy_translator",
"device_put_translator",
"dynamic_slicing_translator",
"gather_translator",
"pjit_translator",
"reshape_translator",
]
Loading