Unnecessary all_reduce introduced to unary einsum #31889

@batterseapower

Description

An einsum that reduces ab->b, with only the a dimension sharded and an explicit unreduced annotation on the result, emits an all-reduce. However, the more complex scaled reduction ab,a->b does not require one. No communication should be necessary in either case, because each device can reduce its half of the a axis purely locally and then mark the result as carrying an unreduced mesh axis.

Because the complex case works but the simple case does not, my guess is that there is some special-case code handling the simple case which hasn't been updated to know about unreduced.
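
To make the expected behaviour concrete, here is a rough sketch (my own illustration, not part of the reproducer below) of the per-device work the partitioner could emit for ab->b when the a axis is sharded two ways: each device reduces only its local shard, and the unreduced annotation simply records that the X-axis partial sums have not yet been combined.

import numpy as np

# Hypothetical illustration (names are mine): simulate the two X-devices.
full = np.ones((4, 8), dtype=np.float32)
shard_dev0, shard_dev1 = full[:2], full[2:]  # each device holds 2 rows

# Purely local reduction on each device, no cross-device communication.
partial_dev0 = shard_dev0.sum(axis=0)
partial_dev1 = shard_dev1.sum(axis=0)

# With unreduced={'X'} the partial sums ARE the result; combining them
# here only demonstrates that they recover the full reduction.
np.testing.assert_allclose(partial_dev0 + partial_dev1, full.sum(axis=0))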

"""
Minimal reproducer showing unnecessary all-reduce in JAX einsum with unreduced dimensions.

This demonstrates a clear bug by contrasting two nearly identical cases:
1. Simple sum (ab->b): INCORRECTLY introduces all-reduce
2. Scaled sum (ab,a->b): CORRECTLY avoids all-reduce

Both should behave the same way with unreduced dimensions, but they don't!
"""

import os

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec

# Enable shardy partitioner
jax.config.update("jax_use_shardy_partitioner", True)


def check_for_all_reduce(fn, *args) -> bool:
    """Check if a JAX function introduces all-reduce operations."""
    lowered = fn.lower(*args)
    compiled = lowered.compile()
    hlo_text = compiled.as_text()
    return " all-reduce(" in hlo_text


# Create 2x2 device mesh
devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(
    devices, axis_names=("X", "Y"), axis_types=(AxisType.Explicit, AxisType.Explicit)
)

print("CONTRASTING TEST CASES: Bug in Simple Sum vs Correct Scaled Sum")
print("=" * 70)
print()
print("Setup:")
print("  - Device mesh: 2x2 (axes X, Y)")
print("  - Matrix: 4x8 sharded on X axis (each X-device has 2 rows)")
print("  - Vector: 4 elements sharded on X axis (for scaled sum)")
print("  - Both operations sum over the first dimension")
print("  - Both specify unreduced=['X'] to keep partial sums local")
print()

with mesh:
    # Shared input: 4x8 matrix, sharded along X axis
    matrix = jax.device_put(
        jnp.ones((4, 8)), NamedSharding(mesh, PartitionSpec("X", None))
    )

    # Additional input for scaled sum: 4-element vector, also sharded on X
    vector = jax.device_put(jnp.ones((4,)), NamedSharding(mesh, PartitionSpec("X")))

    # CASE 1: Simple sum (ab->b) - THE BUG
    @jax.jit
    def simple_sum_unreduced(arr):
        """Simple sum over rows with unreduced X."""
        return jnp.einsum(
            "ab->b",
            arr,
            out_sharding=NamedSharding(
                mesh, PartitionSpec(None, unreduced=frozenset(["X"]))
            ),
        )

    # CASE 2: Scaled sum (ab,a->b) - WORKS CORRECTLY
    @jax.jit
    def scaled_sum_unreduced(arr, vec):
        """Scaled sum over rows with unreduced X."""
        return jnp.einsum(
            "ab,a->b",
            arr,
            vec,
            out_sharding=NamedSharding(
                mesh, PartitionSpec(None, unreduced=frozenset(["X"]))
            ),
        )

    # Check both cases
    simple_has_allreduce = check_for_all_reduce(simple_sum_unreduced, matrix)
    scaled_has_allreduce = check_for_all_reduce(scaled_sum_unreduced, matrix, vector)

    print("CASE 1: Simple Sum (ab->b)")
    print("-" * 40)
    print("  Operation: Sum matrix rows to get vector")
    print("  Inputs: Matrix sharded on X")
    print("  Output spec: PartitionSpec(None, unreduced=['X'])")
    print(
        f"  Has all-reduce: {simple_has_allreduce} {'❌ BUG!' if simple_has_allreduce else '✓'}"
    )
    print()

    print("CASE 2: Scaled Sum (ab,a->b)")
    print("-" * 40)
    print("  Operation: Weighted sum of matrix rows")
    print("  Inputs: Matrix sharded on X, vector sharded on X")
    print("  Output spec: PartitionSpec(None, unreduced=['X'])")
    print(
        f"  Has all-reduce: {scaled_has_allreduce} {'❌' if scaled_has_allreduce else '✓ CORRECT!'}"
    )
    print()

    print("ANALYSIS: Why Case 1 is Clearly a Bug")
    print("=" * 70)
    print()
    print("Both operations are fundamentally the same:")
    print("  1. Both reduce over dimension 'a' (rows)")
    print("  2. Both have inputs sharded on X axis")
    print("  3. Both specify unreduced=['X'] in output")
    print("  4. Both can compute partial sums locally without communication")
    print()
    print("The ONLY difference:")
    print("  - Simple sum: multiplies each row by implicit 1.0")
    print("  - Scaled sum: multiplies each row by explicit weight")
    print()
    print("Since the scaled sum correctly avoids all-reduce, the simple sum")
    print("should too. The fact that it doesn't is a clear compiler bug.")
    print()

    # Verify correctness
    simple_result = simple_sum_unreduced(matrix)
    scaled_result = scaled_sum_unreduced(matrix, vector)

    print("Output verification:")
    print(f"  Simple sum shape: {simple_result.shape}")
    print(f"  Scaled sum shape: {scaled_result.shape}")
    print()

    assert simple_has_allreduce and not scaled_has_allreduce
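
As an aside (this check is mine, not part of the reproducer, and is independent of any sharding), the "implicit 1.0" point from the analysis above can be verified directly: "ab->b" is exactly "ab,a->b" with an all-ones weight vector.

import numpy as np

m = np.arange(32, dtype=np.float32).reshape(4, 8)
w = np.ones(4, dtype=np.float32)

# "ab->b" is just "ab,a->b" with an all-ones weight vector, so the two
# reproducer cases differ only in which partitioner path they take.
np.testing.assert_allclose(np.einsum("ab->b", m), np.einsum("ab,a->b", m, w))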

HLO for simple_sum:

HloModule jit_simple_sum_unreduced, is_scheduled=true, entry_computation_layout={(f32[2,8]{1,0})->f32[8]{0}}, num_partitions=4

%region_0.0 (reduce_sum.0: f32[], reduce_sum.1: f32[]) -> f32[] {
  %reduce_sum.0 = f32[] parameter(0), metadata={op_name="jit(simple_sum_unreduced)/ab->b/reduce_sum"}
  %reduce_sum.1 = f32[] parameter(1), metadata={op_name="jit(simple_sum_unreduced)/ab->b/reduce_sum"}
  ROOT %reduce_sum.2 = f32[] add(%reduce_sum.0, %reduce_sum.1), metadata={op_name="jit(simple_sum_unreduced)/ab->b/reduce_sum" source_file="/root/src/tree/op-simulator/sandbox/sandbox/maxb/simulator_sandbox/model2/minimal_reproducer.py" source_line=66 source_end_line=66 source_column=15 source_end_column=15}
}

%region_0.0.clone (reduce_sum.7: f32[], reduce_sum.12: f32[]) -> f32[] {
  %reduce_sum.7 = f32[] parameter(0), metadata={op_name="jit(simple_sum_unreduced)/ab->b/reduce_sum"}
  %reduce_sum.12 = f32[] parameter(1), metadata={op_name="jit(simple_sum_unreduced)/ab->b/reduce_sum"}
  ROOT %reduce_sum.13 = f32[] add(%reduce_sum.7, %reduce_sum.12), metadata={op_name="jit(simple_sum_unreduced)/ab->b/reduce_sum" source_file="/root/src/tree/op-simulator/sandbox/sandbox/maxb/simulator_sandbox/model2/minimal_reproducer.py" source_line=66 source_end_line=66 source_column=15 source_end_column=15}
}

ENTRY %main.0_spmd (param: f32[2,8]) -> f32[8] {
  %param = f32[2,8]{1,0} parameter(0), sharding={devices=[2,1,2]<=[4] last_tile_dim_replicate}, metadata={op_name="arr"}
  %constant.2 = f32[] constant(0)
  %reduce = f32[8]{0} reduce(%param, %constant.2), dimensions={0}, to_apply=%region_0.0, metadata={op_name="jit(simple_sum_unreduced)/ab->b/reduce_sum" source_file="/root/src/tree/op-simulator/sandbox/sandbox/maxb/simulator_sandbox/model2/minimal_reproducer.py" source_line=66 source_end_line=66 source_column=15 source_end_column=15}
  ROOT %all-reduce = f32[8]{0} all-reduce(%reduce), channel_id=1, replica_groups=[2,2]<=[2,2]T(1,0), use_global_device_ids=true, to_apply=%region_0.0.clone, metadata={op_name="jit(simple_sum_unreduced)/ab->b/reduce_sum" source_file="/root/src/tree/op-simulator/sandbox/sandbox/maxb/simulator_sandbox/model2/minimal_reproducer.py" source_line=66 source_end_line=66 source_column=15 source_end_column=15}
}

HLO for scaled_sum:

HloModule jit_scaled_sum_unreduced, is_scheduled=true, entry_computation_layout={(f32[2,8]{1,0}, f32[2]{0})->f32[8]{0}}, allow_spmd_sharding_propagation_to_parameters={false,false}, num_partitions=4

ENTRY %main.0_spmd (param: f32[2,8], param.1: f32[2]) -> f32[8] {
  %param = f32[2,8]{1,0} parameter(0), sharding={devices=[2,1,2]<=[4] last_tile_dim_replicate}, metadata={op_name="arr"}
  %param.1 = f32[2]{0} parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate}, metadata={op_name="vec"}
  ROOT %dot = f32[8]{0} dot(%param, %param.1), lhs_contracting_dims={0}, rhs_contracting_dims={0}, metadata={op_name="jit(scaled_sum_unreduced)/ab,a->b/dot_general" source_file="/root/src/tree/op-simulator/sandbox/sandbox/maxb/simulator_sandbox/model2/minimal_reproducer.py" source_line=78 source_end_line=78 source_column=15 source_end_column=15}
}

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.7.2
jaxlib: 0.7.2
numpy:  2.3.2
python: 3.11.13 (main, Sep 10 2025, 12:53:42) [GCC 11.4.0]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='coder-maxb-maxb-0', release='6.1.119-129.201.amzn2023.x86_64', version='#1 SMP PREEMPT_DYNAMIC Tue Dec  3 21:07:35 UTC 2024', machine='x86_64')

Labels

bug (Something isn't working), type:Bug
