Skip to content

NotImplementedError: unreduced rule for transpose is not implemented #31895

@batterseapower

Description

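The following script fails with `NotImplementedError: unreduced rule for transpose is not implemented. Please file an issue at https://github.com/jax-ml/jax/issues`. As the traceback below shows, the error is raised from a `lax.transpose` call that `jnp.einsum` makes internally (in `jax/_src/numpy/einsum.py`) when the output is requested as unreduced over `Y`: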

import os

# Force the host (CPU) platform to expose 4 devices so that the 2x2 device
# mesh built in get_hlo_for_einsum can be created on a single machine.
# This assignment deliberately precedes `import jax`: XLA flags are
# generally only picked up when the backend initializes, so keep this
# ordering when editing the imports below.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import dataclasses

import jax
from jax import numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec


@dataclasses.dataclass()
class JAXComms:
    """Plain record pairing a communication flag with an output sharding.

    No validation is performed on construction; both fields are stored as
    given. NOTE(review): this class is not referenced anywhere in the
    visible script.
    """

    # Presumably True when the sharded computation needs collectives;
    # confirm against whatever code populates this record.
    needs_comms: bool
    # Sharding of the computation's result.
    output_sharding: NamedSharding


def get_hlo(fn, *args: jax.Array, **kwargs) -> str:
    """Return the text of the compiled executable for ``fn`` on ``args``.

    Args:
        fn: A jitted callable exposing ``.lower(...)``, such as the result of
            ``jax.jit``.
        *args: Positional arguments forwarded to ``fn.lower``.
        **kwargs: Keyword arguments forwarded to ``fn.lower``.

    Returns:
        The string produced by ``as_text()`` on the compiled object.
    """
    # Ahead-of-time pipeline: trace + lower, compile, then dump as text.
    return fn.lower(*args, **kwargs).compile().as_text()


def get_hlo_for_einsum(
    shapes: list[tuple[int, ...]],
    specs: list[PartitionSpec],
    op: str,
    out_spec: PartitionSpec | None = None,
    *,
    mesh_shape: tuple[int, ...] = (2, 2),
    axis_names: tuple[str, ...] = ("X", "Y"),
) -> str:
    """Compile ``jnp.einsum(op, ...)`` on a sharded device mesh; return its HLO.

    One all-ones operand is created per entry of ``shapes`` and placed on the
    mesh with the matching entry of ``specs``.

    Args:
        shapes: Shape of each einsum operand.
        specs: PartitionSpec for each operand, in the same order as
            ``shapes``.
        op: The einsum subscript string, e.g. ``"heb,eba->hea"``.
        out_spec: Optional PartitionSpec for the result. When ``None``, every
            mesh axis is ``AxisType.Auto`` and no ``out_sharding`` is
            requested. Otherwise every mesh axis is ``AxisType.Explicit`` and
            the einsum is asked for exactly this output sharding.
        mesh_shape: Shape of the device mesh. The default ``(2, 2)`` needs
            four devices.
        axis_names: One name per entry of ``mesh_shape``. The axis names used
            in ``specs`` and ``out_spec`` must come from this tuple.

    Returns:
        The text of the compiled executable (see ``get_hlo``).

    Raises:
        ValueError: If ``mesh_shape`` and ``axis_names`` differ in length, or
            (via ``zip(..., strict=True)``) if ``shapes`` and ``specs`` do.
    """
    if len(mesh_shape) != len(axis_names):
        raise ValueError(
            f"mesh_shape {mesh_shape!r} and axis_names {axis_names!r} must "
            "have the same length"
        )

    # All axes share one type; derive the tuple length from the mesh rank
    # instead of a hard-coded 2 so it can never drift from the mesh shape.
    axis_type = AxisType.Auto if out_spec is None else AxisType.Explicit
    mesh = Mesh(
        mesh_utils.create_device_mesh(mesh_shape),
        axis_names=axis_names,
        axis_types=(axis_type,) * len(axis_names),
    )
    with mesh:
        operands = [
            jax.device_put(jnp.ones(shape), NamedSharding(mesh, spec))
            for shape, spec in zip(shapes, specs, strict=True)
        ]
        out_sharding = None if out_spec is None else NamedSharding(mesh, out_spec)

        # The name `f` is kept so the lowered module name stays unchanged.
        @jax.jit
        def f(*xs):
            return jnp.einsum(op, *xs, out_sharding=out_sharding)

        return get_hlo(f, *operands)


# Repro: contract over B, which is sharded along mesh axis Y in both operands,
# and request the result left unreduced over Y.
hlo_text = get_hlo_for_einsum(
    shapes=[(6, 4, 8), (4, 8, 10)],  # [H, E, B] @ [E, B, A]
    specs=[
        PartitionSpec(None, "X", "Y"),  # [H, E{X}, B{Y}]
        PartitionSpec("X", "Y", None),  # [E{X}, B{Y}, A]
    ],
    op="heb,eba->hea",  # Reduce over B
    out_spec=PartitionSpec(None, "X", None, unreduced=frozenset({"Y"})),
)
print(hlo_text)

Full traceback:

  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/pjit.py", line 317, in jit_lower
    return jit_trace(jit_func, *args, **kwargs).lower()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/pjit.py", line 310, in jit_trace
    p, args_flat = _infer_params(jit_func._fun, jit_func._jit_info, args, kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/pjit.py", line 635, in _infer_params
    return _infer_params_internal(fun, ji, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/pjit.py", line 659, in _infer_params_internal
    p, args_flat = _infer_params_impl(
                   ^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/pjit.py", line 556, in _infer_params_impl
    jaxpr, consts, out_avals = _create_pjit_jaxpr(
                               ^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/linear_util.py", line 504, in memoized_fun
    ans = call(fun, *args)
          ^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/pjit.py", line 1343, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_type)
                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/profiler.py", line 364, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 2387, in trace_to_jaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/linear_util.py", line 212, in call_wrapped
    return self.f_transformed(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/api_util.py", line 73, in flatten_fun
    ans = f(*py_args, **py_kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/linear_util.py", line 429, in _get_result_paths_thunk
    ans = _fun(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/root/src/tree/op-simulator/temp2.py", line 45, in f
    return jnp.einsum(
           ^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/numpy/einsum.py", line 318, in einsum
    return jit_einsum(operand_arrays, contractions, precision,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/pjit.py", line 268, in cache_miss
    executable, pgle_profiler, const_args) = _python_pjit_helper(
                                             ^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/pjit.py", line 137, in _python_pjit_helper
    p, args_flat = _infer_params(fun, jit_info, args, kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/pjit.py", line 635, in _infer_params
    return _infer_params_internal(fun, ji, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/pjit.py", line 659, in _infer_params_internal
    p, args_flat = _infer_params_impl(
                   ^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/pjit.py", line 556, in _infer_params_impl
    jaxpr, consts, out_avals = _create_pjit_jaxpr(
                               ^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/linear_util.py", line 504, in memoized_fun
    ans = call(fun, *args)
          ^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/pjit.py", line 1343, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_type)
                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/profiler.py", line 364, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 2387, in trace_to_jaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/linear_util.py", line 212, in call_wrapped
    return self.f_transformed(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/api_util.py", line 73, in flatten_fun
    ans = f(*py_args, **py_kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/api_util.py", line 318, in _argnames_partial
    return _fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/api_util.py", line 292, in _argnums_partial
    return _fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/linear_util.py", line 429, in _get_result_paths_thunk
    ans = _fun(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/numpy/einsum.py", line 588, in _einsum
    operand = lax.transpose(operand, perm)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 2971, in transpose
    return transpose_p.bind(operand, permutation=permutation)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/core.py", line 634, in bind
    return self._true_bind(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/core.py", line 650, in _true_bind
    return self.bind_with_trace(prev_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/core.py", line 662, in bind_with_trace
    return trace.process_primitive(self, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 2109, in process_primitive
    return self.default_process_primitive(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 2123, in default_process_primitive
    out_avals, effs = _cached_abstract_eval(primitive, *aval_qdds, **params)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/util.py", line 460, in wrapper
    return cached_call(_multi_weakref_placeholder,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/util.py", line 444, in cache_miss
    return call(*orig_args, **orig_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 1940, in _cached_abstract_eval
    return primitive.abstract_eval(*aval_qdds, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/core.py", line 704, in abstract_eval_
    return abstract_eval(*args, **kwargs), no_effects
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/lax/utils.py", line 161, in standard_abstract_eval
    out_shape, out_dtype, out_sharding = call_shape_dtype_sharding_rule(
                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/lax/utils.py", line 108, in call_shape_dtype_sharding_rule
    out_shardings = call_sharding_rule(
                    ^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/lax/utils.py", line 98, in call_sharding_rule
    out_sharding = call_unreduced_rule(prim, unreduced_rule, out_sharding,
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.13/lib/python3.11/site-packages/jax/_src/lax/utils.py", line 77, in call_unreduced_rule
    raise NotImplementedError(

System info (python version, jaxlib version, accelerator, etc.)

learning/45eac/tfrc/runtime/libtpu_init_utils.cc:285
jax:    0.7.2
jaxlib: 0.7.2
numpy:  2.3.2
python: 3.11.13 (main, Sep 10 2025, 12:53:42) [GCC 11.4.0]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='coder-maxb-maxb-0', release='6.1.119-129.201.amzn2023.x86_64', version='#1 SMP PREEMPT_DYNAMIC Tue Dec  3 21:07:35 UTC 2024', machine='x86_64')

Metadata

Assignees

Labels

bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions