Skip to content

Conversation

copybara-service[bot]
Copy link

[Mosaic] Canonicalize integer tpu::UnpackSubelementsOp with sign_extended=false if it's used by pack that reduces the bitwidth.

This happens in retiling (8, 128) <-> (8 * packing, 128) and (1, 128) <-> (packing, 128), where it unpacks x -> 32 and packs 32 -> x.

// (8,128) <-> (8 * packing,128) tiling change for packed type.
if (ctx.hardware_generation >= 4 && bitwidth < 32 && 32 % bitwidth == 0 &&
((src.tiling() == ctx.target_shape &&
dst_tiling == std::array<int64_t, 2>{ctx.target_shape[0] * packing,
ctx.target_shape[1]}) ||
(dst_tiling == ctx.target_shape &&
src.tiling() == std::array<int64_t, 2>{ctx.target_shape[0] * packing,
ctx.target_shape[1]}))) {
// Note: for int4, retiling with scratch is always faster.
if (bitwidth != 4 || !has_enough_scratch) {
FAILUREOR_ASSIGN_OR_RETURN(std::tie(src, vregs),
unpack_vregs(src, vregs, ctx.target_shape));
return pack_vregs(src, vregs, dst_tiling, dst_offsets_hint);
}
}
// Handle retiling from/to (1, 128 * packing) to/from (packing, 128) for
// packed data.
// TODO(tlongeri): Interleaved unpacking followed by interleaved
// packing (but with different pairings) might also be
// interesting if the next step is a retile, since we can also
// match corresponding elements without shifting. It's just that
// the tiles are not adjacent (no contiguous vreg slice).
if (bitwidth < 32 && 32 % bitwidth == 0 &&
((src.tiling() ==
std::array<int64_t, 2>{1, ctx.target_shape[1] * packing} &&
dst_tiling == std::array<int64_t, 2>{packing, ctx.target_shape[1]}) ||
(src.tiling() == std::array<int64_t, 2>{packing, ctx.target_shape[1]} &&
dst_tiling ==
std::array<int64_t, 2>{1, ctx.target_shape[1] * packing}))) {
FAILUREOR_ASSIGN_OR_RETURN(
std::tie(src, vregs),
unpack_vregs(src, vregs, {1, ctx.target_shape[1]}));
return pack_vregs(src, vregs, dst_tiling, dst_offsets_hint);
}

…xtended=false` if it's used by pack that reduces the bitwidth.

This happens in retiling (8, 128) <-> (8 * packing, 128) and (1, 128) <-> (packing, 128), where it unpacks x -> 32 and packs 32 -> x.

https://github.com/jax-ml/jax/blob/44b50826f40fecd8d8d426fba9d4e7b2c9335ab3/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc#L8362-L8396

PiperOrigin-RevId: 820691897
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant