Skip to content

output_shardings of compiled function has wrong dimensionality #31888

@batterseapower

Description

#!/usr/bin/env python3
"""
Minimal reproducer for JAX bug where compiled.output_shardings.spec returns
an empty PartitionSpec() instead of a PartitionSpec with the correct dimensions.

Bug Description:
When compiling a JAX function that performs a matmul with sharded inputs,
the compiled.output_shardings.spec attribute returns an empty PartitionSpec()
instead of a PartitionSpec with elements corresponding to each output dimension.
"""

# NOTE: XLA_FLAGS must be set BEFORE jax is imported — the flag is read at
# backend-initialization time. It forces the CPU platform to expose 4 virtual
# devices so a 2x2 mesh can be built on a machine with no accelerators.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Create a 2x2 device mesh with named axes "X" and "Y"; both test cases below
# shard their operands over these axes.
devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(devices, axis_names=("X", "Y"))

banner = "=" * 60
print(banner)
print("CASE 1: Working case - simple sharding without reduction")
print(banner)

with mesh:
    # Operands for (8, 16) @ (16, 32) -> (8, 32). Neither input shards the
    # contracting (size-16) dimension, so the output sharding is well defined:
    # rows split over "X", columns over "Y".
    lhs = jax.device_put(jnp.ones((8, 16)), NamedSharding(mesh, PartitionSpec("X", None)))
    rhs = jax.device_put(jnp.ones((16, 32)), NamedSharding(mesh, PartitionSpec(None, "Y")))

    @jax.jit
    def matmul_simple(a, b):
        return a @ b

    compiled_simple = matmul_simple.lower(lhs, rhs).compile()
    output_simple = matmul_simple(lhs, rhs)

    # In this case the compiled spec has one entry per output dimension.
    spec = compiled_simple.output_shardings.spec
    print(f"Output shape: {output_simple.shape}")
    print(f"compiled.output_shardings.spec: {spec}")
    print(f"Spec length: {len(spec)} (expected: {len(output_simple.shape)})")

    verdict = (
        "✓ PASS: Spec length matches output shape"
        if len(spec) == len(output_simple.shape)
        else "✗ FAIL: Spec length does NOT match output shape"
    )
    print(verdict)

print()
print("=" * 60)
print("CASE 2: Failing case - sharding on reduction dimension")
print("=" * 60)

with mesh:
    # Same matmul shapes, (8, 16) @ (16, 32) -> (8, 32), but now the
    # contracting dimension (dim 1 of the lhs / dim 0 of the rhs) is the one
    # sharded on "X" — this is what triggers the empty-spec bug.
    lhs_sharding = NamedSharding(mesh, PartitionSpec(None, "X"))
    rhs_sharding = NamedSharding(mesh, PartitionSpec("X", None))

    a = jax.device_put(jnp.ones((8, 16)), lhs_sharding)
    b = jax.device_put(jnp.ones((16, 32)), rhs_sharding)

    @jax.jit
    def matmul(a, b):
        return a @ b  # or jnp.einsum("ab,bc->ac", a, b)

    # `compiled` and `output` are deliberately module-level: the final
    # assertion at the bottom of the script inspects them.
    compiled = matmul.lower(a, b).compile()
    output = matmul(a, b)

    out_spec = compiled.output_shardings.spec
    print(f"Output shape: {output.shape}")  # (8, 32) - 2D array
    print(f"compiled.output_shardings.spec: {out_spec}")  # PartitionSpec() - empty!
    print(f"Spec length: {len(out_spec)} (expected: {len(output.shape)})")

    if len(out_spec) == len(output.shape):
        print("✓ PASS: Spec length matches output shape")
    else:
        print("✗ FAIL: Spec length does NOT match output shape")

print()
print("=" * 60)
print("SUMMARY")
print("=" * 60)
print("The bug occurs when the reduction dimension of a matmul is sharded.")
print("In Case 1, the inputs are sharded on non-reduction dimensions and it works.")
print("In Case 2, the reduction dimension is sharded and output_shardings.spec is empty.")
print()

# This assertion demonstrates the bug: the spec should carry one entry per
# output dimension, but comes back empty when the contraction axis is sharded.
expected_rank = len(output.shape)
spec_len = len(compiled.output_shardings.spec)
assert spec_len == expected_rank, \
    f"BUG: Expected spec with {expected_rank} elements, got {spec_len}"

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.7.2.dev20250910
jaxlib: 0.7.2.dev20250910
numpy:  2.3.2
python: 3.11.13 (main, Aug 16 2025, 02:17:37) [GCC 11.4.0]
device info: cpu-4, 4 local devices
process_count: 1
platform: uname_result(system='Linux', node='maxb-devbox-0', release='5.10.236-227.928.amzn2.x86_64', version='#1 SMP Sat Apr 19 16:54:57 UTC 2025', machine='x86_64')
JAX_PLATFORMS=cpu
XLA_FLAGS=--xla_force_host_platform_device_count=4

Metadata

Metadata

Assignees

Labels

bug — Something isn't working; type: Bug

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions