-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Closed
Labels
bug: Something isn't working
Description
Description
This fails with `NotImplementedError: unreduced rule for transpose is not implemented. Please file an issue at https://github.com/jax-ml/jax/issues`. Reproducer:
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
import dataclasses
import jax
from jax import numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec
@dataclasses.dataclass
class JAXComms:
    """Record pairing a communication flag with an output sharding.

    NOTE(review): nothing else in this reproducer references the class;
    the field meanings below are inferred from their names only.
    """

    # Presumably whether the op requires cross-device communication — TODO confirm.
    needs_comms: bool
    # Presumably the sharding of the op's output — TODO confirm.
    output_sharding: NamedSharding
def get_hlo(fn, *args: jax.Array, **kwargs) -> str:
    """Lower and compile ``fn`` for the given arguments and return its text form.

    Args:
        fn: Object exposing ``lower(*args, **kwargs)``; in this file it is a
            ``jax.jit``-wrapped function.
        *args: Positional arguments forwarded unchanged to ``fn.lower``.
        **kwargs: Keyword arguments forwarded unchanged to ``fn.lower``.

    Returns:
        Whatever ``as_text()`` yields for the compiled computation.
    """
    lowered = fn.lower(*args, **kwargs)
    compiled = lowered.compile()
    return compiled.as_text()
def get_hlo_for_einsum(
    shapes: list[tuple[int, ...]],
    specs: list[PartitionSpec],
    op: str,
    out_spec: PartitionSpec | None = None,
) -> str:
    """Compile a sharded ``jnp.einsum`` on a 2x2 mesh and return its compiled text.

    Args:
        shapes: Shape of each einsum operand; operands are filled with ones.
        specs: Partition spec for each operand, paired one-to-one with
            ``shapes`` (a length mismatch raises via ``zip(strict=True)``).
        op: Einsum subscript string, e.g. ``"heb,eba->hea"``.
        out_spec: Optional partition spec for the result. When given, the mesh
            axes are created as ``AxisType.Explicit`` and the spec is passed to
            einsum as ``out_sharding``; when ``None`` the axes are
            ``AxisType.Auto`` and ``out_sharding=None`` is passed.

    Returns:
        The text returned by ``get_hlo`` for the jitted einsum.
    """
    devices = mesh_utils.create_device_mesh((2, 2))
    # Both mesh axes share one axis type, selected by whether an output spec
    # was requested.
    axis_type = AxisType.Auto if out_spec is None else AxisType.Explicit
    mesh = Mesh(
        devices,
        axis_names=("X", "Y"),
        axis_types=(axis_type,) * 2,
    )
    out_sharding = None if out_spec is None else NamedSharding(mesh, out_spec)

    with mesh:
        # Place every operand on the mesh with its requested input sharding.
        operands = [
            jax.device_put(jnp.ones(shape), NamedSharding(mesh, spec))
            for shape, spec in zip(shapes, specs, strict=True)
        ]

        @jax.jit
        def f(*args):
            return jnp.einsum(op, *args, out_sharding=out_sharding)

        # Lowering (and therefore tracing of ``f``) runs inside the mesh context.
        return get_hlo(f, *operands)
# Reproducer driver: contract over B while asking for the output to stay
# unreduced over mesh axis "Y". Per the issue report, this call raises
# NotImplementedError ("unreduced rule for transpose is not implemented").
print(
    get_hlo_for_einsum(
        shapes=[(6, 4, 8), (4, 8, 10)],  # [H, E, B] @ [E, B, A]
        specs=[
            PartitionSpec(None, "X", "Y"),  # [H, E{X}, B{Y}]
            PartitionSpec("X", "Y", None),  # [E{X}, B{Y}, A]
        ],
        op="heb,eba->hea",  # Reduce over B
        out_spec=PartitionSpec(None, "X", None, unreduced=frozenset(["Y"])),
    )
)
Full traceback:
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/pjit.py", line 317, in jit_lower
return jit_trace(jit_func, *args, **kwargs).lower()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/pjit.py", line 310, in jit_trace
p, args_flat = _infer_params(jit_func._fun, jit_func._jit_info, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/pjit.py", line 635, in _infer_params
return _infer_params_internal(fun, ji, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/pjit.py", line 659, in _infer_params_internal
p, args_flat = _infer_params_impl(
^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/pjit.py", line 556, in _infer_params_impl
jaxpr, consts, out_avals = _create_pjit_jaxpr(
^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/linear_util.py", line 504, in memoized_fun
ans = call(fun, *args)
^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/pjit.py", line 1343, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_type)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/profiler.py", line 364, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 2387, in trace_to_jaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/linear_util.py", line 212, in call_wrapped
return self.f_transformed(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/api_util.py", line 73, in flatten_fun
ans = f(*py_args, **py_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/linear_util.py", line 429, in _get_result_paths_thunk
ans = _fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/root/src/tree/op-simulator/temp2.py", line 45, in f
return jnp.einsum(
^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/numpy/einsum.py", line 318, in einsum
return jit_einsum(operand_arrays, contractions, precision,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/pjit.py", line 268, in cache_miss
executable, pgle_profiler, const_args) = _python_pjit_helper(
^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/pjit.py", line 137, in _python_pjit_helper
p, args_flat = _infer_params(fun, jit_info, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/pjit.py", line 635, in _infer_params
return _infer_params_internal(fun, ji, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/pjit.py", line 659, in _infer_params_internal
p, args_flat = _infer_params_impl(
^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/pjit.py", line 556, in _infer_params_impl
jaxpr, consts, out_avals = _create_pjit_jaxpr(
^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/linear_util.py", line 504, in memoized_fun
ans = call(fun, *args)
^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/pjit.py", line 1343, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_type)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/profiler.py", line 364, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 2387, in trace_to_jaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/linear_util.py", line 212, in call_wrapped
return self.f_transformed(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/api_util.py", line 73, in flatten_fun
ans = f(*py_args, **py_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/api_util.py", line 318, in _argnames_partial
return _fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/api_util.py", line 292, in _argnums_partial
return _fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/linear_util.py", line 429, in _get_result_paths_thunk
ans = _fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/numpy/einsum.py", line 588, in _einsum
operand = lax.transpose(operand, perm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 2971, in transpose
return transpose_p.bind(operand, permutation=permutation)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/core.py", line 634, in bind
return self._true_bind(*args, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/core.py", line 650, in _true_bind
return self.bind_with_trace(prev_trace, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/core.py", line 662, in bind_with_trace
return trace.process_primitive(self, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 2109, in process_primitive
return self.default_process_primitive(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 2123, in default_process_primitive
out_avals, effs = _cached_abstract_eval(primitive, *aval_qdds, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/util.py", line 460, in wrapper
return cached_call(_multi_weakref_placeholder,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/util.py", line 444, in cache_miss
return call(*orig_args, **orig_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 1940, in _cached_abstract_eval
return primitive.abstract_eval(*aval_qdds, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/core.py", line 704, in abstract_eval_
return abstract_eval(*args, **kwargs), no_effects
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/lax/utils.py", line 161, in standard_abstract_eval
out_shape, out_dtype, out_sharding = call_shape_dtype_sharding_rule(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/lax/utils.py", line 108, in call_shape_dtype_sharding_rule
out_shardings = call_sharding_rule(
^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/lax/utils.py", line 98, in call_sharding_rule
out_sharding = call_unreduced_rule(prim, unreduced_rule, out_sharding,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/lax/utils.py", line 77, in call_unreduced_rule
raise NotImplementedError(
System info (python version, jaxlib version, accelerator, etc.)
learning/45eac/tfrc/runtime/libtpu_init_utils.cc:285
jax: 0.7.2
jaxlib: 0.7.2
numpy: 2.3.2
python: 3.11.13 (main, Sep 10 2025, 12:53:42) [GCC 11.4.0]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='coder-maxb-maxb-0', release='6.1.119-129.201.amzn2023.x86_64', version='#1 SMP PREEMPT_DYNAMIC Tue Dec 3 21:07:35 UTC 2024', machine='x86_64')
Metadata
Metadata
Assignees
Labels
bug: Something isn't working