Skip to content

Commit

Permalink
Merge pull request #10711 from mattjj:closed-call
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 448720052
  • Loading branch information
jax authors committed May 14, 2022
2 parents ba0a2b3 + 05dda56 commit 86899ee
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 15 deletions.
21 changes: 18 additions & 3 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from collections import namedtuple
from contextlib import contextmanager
import functools
from functools import partialmethod, total_ordering
from functools import partial, partialmethod, total_ordering
import gc
import itertools as it
import operator
Expand Down Expand Up @@ -207,8 +207,6 @@ def replace(self, *args, **kwargs):
return self._replace(*args, **kwargs)

def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None):
if primitive.call_primitive:
assert len(outvars) == len(params["call_jaxpr"].outvars)
source_info = source_info or source_info_util.new_source_info()
return JaxprEqn(invars, outvars, primitive, params, effects, source_info)

Expand Down Expand Up @@ -1822,6 +1820,17 @@ def call_impl(f: lu.WrappedFun, *args, **params):
named_call_p.def_impl(call_impl)


class ClosedCallPrimitive(CallPrimitive):
def get_bind_params(self, params):
new_params = dict(params)
jaxpr = new_params.pop('call_jaxpr')
subfun = lu.wrap_init(partial(eval_jaxpr, jaxpr.jaxpr, jaxpr.consts))
return [subfun], new_params

closed_call_p: ClosedCallPrimitive = ClosedCallPrimitive('closed_call')
closed_call_p.def_impl(call_impl)


outfeed_primitives: Set[Primitive] = set()
def jaxpr_uses_outfeed(jaxpr: Jaxpr) -> bool:
"""Finds if there are outfeed primitives anywhere inside a Jaxpr."""
Expand Down Expand Up @@ -2169,6 +2178,12 @@ class JaxprTypeError(TypeError): pass

custom_typechecks: Dict[Primitive, Callable] = {}

def _check_closed_call(*in_avals, call_jaxpr):
if list(in_avals) != list(call_jaxpr.in_avals):
raise JaxprTypeError("Closed call in_avals mismatch")
return call_jaxpr.out_avals, call_jaxpr.effects
custom_typechecks[closed_call_p] = _check_closed_call

def check_jaxpr(jaxpr: Jaxpr):
"""Checks well-formedness of a jaxpr.
Expand Down
1 change: 1 addition & 0 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
"reduce_precision",
"schur",
"name",
"closed_call",
"unreachable",
"bint",
"getslice",
Expand Down
10 changes: 10 additions & 0 deletions jax/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,16 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
out_flat = primitive.bind(fun, *all_args, **params)
return tree_unflatten(out_tree(), out_flat)
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
primitive_transposes[core.named_call_p] = \
partial(call_transpose, core.named_call_p)


def _closed_call_transpose(params, jaxpr, args, ct, cts_in_avals, reduce_axes):
jaxpr_, consts = jaxpr.jaxpr, jaxpr.consts
jaxpr_ = pe.convert_constvars_jaxpr(jaxpr_)
return call_transpose(core.closed_call_p, params, jaxpr_, (*consts, *args),
ct, cts_in_avals, reduce_axes)
primitive_transposes[core.closed_call_p] = _closed_call_transpose


def remat_transpose(params, call_jaxpr, primals_in, cotangents_in,
Expand Down
8 changes: 5 additions & 3 deletions jax/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,13 +986,13 @@ def f_lowered(ctx, *args, **params):

def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
avals_out, tokens_in, *args):
if isinstance(call_jaxpr, core.Jaxpr):
call_jaxpr = core.ClosedJaxpr(call_jaxpr, ())
xla.check_backend_matches(backend, ctx.platform)
output_types = map(aval_to_ir_types, avals_out)
flat_output_types = util.flatten(output_types)
effects = tokens_in.effects()
symbol_name = lower_jaxpr_to_fun(ctx, fn_name,
core.ClosedJaxpr(call_jaxpr, ()),
effects).name.value
symbol_name = lower_jaxpr_to_fun(ctx, fn_name, call_jaxpr, effects).name.value
args = [*tokens_in.tokens(), *args]
call = func_dialect.CallOp(flat_output_types,
ir.FlatSymbolRefAttr.get(symbol_name),
Expand Down Expand Up @@ -1024,6 +1024,8 @@ def _named_call_lowering(ctx, *args, name, backend=None,

register_lowering(core.named_call_p, _named_call_lowering)
register_lowering(core.call_p, partial(_named_call_lowering, name="core_call"))
register_lowering(core.closed_call_p,
partial(_named_call_lowering, name="core_closed_call"))


def full_like_aval(value, aval: core.ShapedArray) -> ir.Value:
Expand Down
30 changes: 25 additions & 5 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,15 +222,17 @@ def process_call(self, primitive, f, tracers, params):
unknown_arg_tracers = [t for t in tracers if not t.is_known()]
# Adjust parameters (e.g. donated_invars) for the staged-out call's args.
num_new_args = len(const_tracers) + len(env_tracers)
staged_params = update_params(params, map(op.not_, in_knowns), num_new_args)
staged_params = dict(staged_params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
staged_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
staged_params = update_params(staged_params, map(op.not_, in_knowns),
num_new_args)
# The outputs of the staged-out call are Tracers with the new eqn as recipe.
out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None)
for a in out_avals]
name_stack = self._current_truncated_name_stack()
source = source_info_util.current().replace(name_stack=name_stack)
eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers),
out_tracers, primitive, staged_params, jaxpr.effects, source)
out_tracers, primitive, staged_params, jaxpr.effects,
source)
for t in out_tracers: t.recipe = eqn
return merge_lists(out_knowns, out_tracers, out_consts)

Expand Down Expand Up @@ -511,6 +513,12 @@ def partial_eval_wrapper_nounits(
call_partial_eval_rules: Dict[Primitive, Callable] = {}
call_param_updaters: Dict[Primitive, Callable] = {}

def _closed_call_param_updater(params, _, __):
jaxpr = params.get('call_jaxpr')
if jaxpr is None: return params
assert type(jaxpr) is core.Jaxpr
return dict(params, call_jaxpr=core.ClosedJaxpr(jaxpr, ()))
call_param_updaters[core.closed_call_p] = _closed_call_param_updater

def abstract_eval_fun(fun, *avals, debug_info=None, **params):
_, avals_out, _ = trace_to_jaxpr_dynamic(
Expand Down Expand Up @@ -666,8 +674,6 @@ def new_eqn_recipe(in_tracers: Sequence[JaxprTracer],
# TODO(necula): move these checks to core.check_jaxpr, and call in more places
if primitive.call_primitive or primitive.map_primitive:
assert "call_jaxpr" in params
# assert len(invars) == len(params["call_jaxpr"].invars) # TODO constvars?
assert len(out_tracers) == len(params["call_jaxpr"].outvars)
assert ("donated_invars" not in params or
len(params["donated_invars"]) == len(params["call_jaxpr"].invars))
if primitive.map_primitive:
Expand Down Expand Up @@ -1254,6 +1260,20 @@ def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn
dce_rules[remat_call_p] = dce_jaxpr_call_rule


def dce_jaxpr_closed_call_rule(used_outputs: List[bool], eqn: JaxprEqn
) -> Tuple[List[bool], JaxprEqn]:
# TODO(mattjj): de-duplicate with above rule?
jaxpr_ = eqn.params['call_jaxpr']
jaxpr, consts = jaxpr_.jaxpr, jaxpr_.consts
new_jaxpr, used_inputs = dce_jaxpr(jaxpr, used_outputs)
new_params = dict(eqn.params, call_jaxpr=core.ClosedJaxpr(new_jaxpr, consts))
new_eqn = new_jaxpr_eqn(
[v for v, used in zip(eqn.invars, used_inputs) if used],
[v for v, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
return used_inputs, new_eqn
dce_rules[core.closed_call_p] = dce_jaxpr_closed_call_rule

def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool]
) -> ClosedJaxpr:
"""Reorder `invars` by moving those indicated in `to_move` to the front."""
Expand Down
4 changes: 0 additions & 4 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,3 @@ def f(*args, **kw):
"Add an MLIR (MHLO) lowering via jax.interpreters.mlir "
"instead.")
return f


ad.primitive_transposes[core.named_call_p] = partial(ad.call_transpose,
core.named_call_p)
10 changes: 10 additions & 0 deletions tests/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ def core_call(f, *args):
out = core.call_p.bind(f, *args)
return tree_unflatten(out_tree(), out)

@util.curry
def core_closed_call(f, *args):
args, in_tree = tree_flatten(args)
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
out = core.closed_call_p.bind(f, *args)
return tree_unflatten(out_tree(), out)

def simple_fun(x, y):
return jnp.sin(x * y)

Expand Down Expand Up @@ -147,6 +154,9 @@ def jvp_unlinearized(f, primals, tangents):
test_specs.append(CallSpec(core_call(ts.fun), ts.args))
test_specs.append(CallSpec(core_call(jit(ts.fun)), ts.args))
test_specs.append(CallSpec(core_call(core_call(ts.fun)), ts.args))
test_specs.append(CallSpec(core_closed_call(ts.fun), ts.args))
test_specs.append(CallSpec(core_closed_call(jit(ts.fun)), ts.args))
test_specs.append(CallSpec(core_closed_call(core_closed_call(ts.fun)), ts.args))
test_specs.append(CallSpec(partial(jvp_unlinearized, ts.fun),
(ts.args, ts.args)))

Expand Down

0 comments on commit 86899ee

Please sign in to comment.