Description
#!/usr/bin/env python3
"""
Minimal reproducer for a JAX bug where compiled.output_shardings.spec returns
an empty PartitionSpec() instead of a PartitionSpec with one entry per output dimension.
Bug Description:
When compiling a JAX function that performs a matmul with sharded inputs,
the compiled.output_shardings.spec attribute returns an empty PartitionSpec()
instead of a PartitionSpec with elements corresponding to each output dimension.
"""
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec
# Create a 2x2 device mesh
devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(devices, axis_names=("X", "Y"))
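# Note: with XLA_FLAGS forcing 4 host CPU devices, create_device_mesh((2, 2))
# arranges them in a 2x2 grid; Mesh names axis 0 "X" and axis 1 "Y", so each
# named axis has size 2.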
print("=" * 60)
print("CASE 1: Working case - simple sharding without reduction")
print("=" * 60)
with mesh:
    # Simple case where sharding works correctly
    a_simple = jnp.ones((8, 16))
    b_simple = jnp.ones((16, 32))
    # Shard with no reduction dimension conflict
    a_simple_sharding = NamedSharding(mesh, PartitionSpec("X", None))
    b_simple_sharding = NamedSharding(mesh, PartitionSpec(None, "Y"))
    a_simple = jax.device_put(a_simple, a_simple_sharding)
    b_simple = jax.device_put(b_simple, b_simple_sharding)

    @jax.jit
    def matmul_simple(a, b):
        return a @ b

    compiled_simple = matmul_simple.lower(a_simple, b_simple).compile()
    output_simple = matmul_simple(a_simple, b_simple)
    print(f"Output shape: {output_simple.shape}")
    print(f"compiled.output_shardings.spec: {compiled_simple.output_shardings.spec}")
    print(f"Spec length: {len(compiled_simple.output_shardings.spec)} (expected: {len(output_simple.shape)})")
    if len(compiled_simple.output_shardings.spec) == len(output_simple.shape):
        print("✓ PASS: Spec length matches output shape")
    else:
        print("✗ FAIL: Spec length does NOT match output shape")
    print()
print("=" * 60)
print("CASE 2: Failing case - sharding on reduction dimension")
print("=" * 60)
with mesh:
    # Create input arrays for matmul: (8, 16) @ (16, 32) -> (8, 32)
    a = jnp.ones((8, 16))
    b = jnp.ones((16, 32))
    # Shard inputs on the reduction dimension.
    # This triggers the bug: the reduction dimension (dim 1 of a, dim 0 of b)
    # is sharded on X.
    a_sharding = NamedSharding(mesh, PartitionSpec(None, "X"))
    b_sharding = NamedSharding(mesh, PartitionSpec("X", None))
    a = jax.device_put(a, a_sharding)
    b = jax.device_put(b, b_sharding)

    @jax.jit
    def matmul(a, b):
        return a @ b  # or jnp.einsum("ab,bc->ac", a, b)

    # Compile the function
    compiled = matmul.lower(a, b).compile()
    # Execute to get the output
    output = matmul(a, b)
    # The bug:
    print(f"Output shape: {output.shape}")  # (8, 32) - a 2D array
    print(f"compiled.output_shardings.spec: {compiled.output_shardings.spec}")  # PartitionSpec() - empty!
    print(f"Spec length: {len(compiled.output_shardings.spec)} (expected: {len(output.shape)})")
    if len(compiled.output_shardings.spec) == len(output.shape):
        print("✓ PASS: Spec length matches output shape")
    else:
        print("✗ FAIL: Spec length does NOT match output shape")
    print()
print("=" * 60)
print("SUMMARY")
print("=" * 60)
print("The bug occurs when the reduction dimension of a matmul is sharded.")
print("In Case 1, the inputs are sharded on non-reduction dimensions and it works.")
print("In Case 2, the reduction dimension is sharded and output_shardings.spec is empty.")
print()
# This assertion demonstrates the bug
assert len(compiled.output_shardings.spec) == len(output.shape), \
    f"BUG: Expected spec with {len(output.shape)} elements, got {len(compiled.output_shardings.spec)}"
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.7.2.dev20250910
jaxlib: 0.7.2.dev20250910
numpy: 2.3.2
python: 3.11.13 (main, Aug 16 2025, 02:17:37) [GCC 11.4.0]
device info: cpu-4, 4 local devices
process_count: 1
platform: uname_result(system='Linux', node='maxb-devbox-0', release='5.10.236-227.928.amzn2.x86_64', version='#1 SMP Sat Apr 19 16:54:57 UTC 2025', machine='x86_64')
JAX_PLATFORMS=cpu
XLA_FLAGS=--xla_force_host_platform_device_count=4