Commit d9a91a6

Merge two (18, 3, 512) into a single (18, 8, 512), further reducing latency from 7643 us to 7452 us

PiperOrigin-RevId: 705112984
The Jaxite Team authored and copybara-github committed Dec 11, 2024
1 parent 758c432 commit d9a91a6
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions jaxite/jaxite_lib/polymul_kernel.py
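
Annotation (not part of the commit): the 8-row operand now carries two copies of the real 4-row matrix, so broadcasting it twice and regrouping yields the same four byte-limb groups from 16 matmul rows instead of the previous 32. A minimal shape check of that packing, with illustrative names and sizes:

import jax.numpy as jnp

m, k = 8, 512  # padded row count and polynomial length, as in the kernel
lhs = jnp.ones((m, k), dtype=jnp.int32)
# Two broadcast copies regrouped into 4 limb groups of m // 2 = 4 rows each.
lhs_i8 = jnp.broadcast_to(lhs, (2, m, k)).reshape((4, m // 2, k))
assert lhs_i8.shape == (4, 4, 512)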
@@ -32,16 +32,16 @@
 def _i32_matmul_unreduced(lhs, rhs):
   lax = jax.lax
   m, k, n = lhs.shape[0], lhs.shape[1], rhs.shape[1]
-  lhs_i8 = jnp.broadcast_to(lhs, (4, *lhs.shape))
+  lhs_i8 = jnp.broadcast_to(lhs, (2, *lhs.shape)).reshape((4, m//2, k))
   lhs_shift = lax.broadcasted_iota(jnp.int32, lhs_i8.shape, dimension=0) * 8
   lhs_i8 = lax.shift_right_logical(lhs_i8, lhs_shift)
   lhs_i8 = lax.bitwise_and(lhs_i8, jnp.broadcast_to(0xFF, lhs_i8.shape))
-  lhs_i8 = lhs_i8.reshape((4 * m, k))
+  lhs_i8 = lhs_i8.reshape((2 * m, k))
 
-  acc = jnp.zeros((4 * m, n), dtype=jnp.int32)
   out_shift_base = lax.mul(
-      lax.div(lax.broadcasted_iota(jnp.int32, (4 * m, n), dimension=0), m), 8
+      lax.broadcasted_iota(jnp.int32, (4, m//2, n), dimension=0), 8
   )
+  acc = jnp.zeros((m//2, n), dtype=jnp.int32)
   for rhs_shift in range(0, 32, 8):
     # TODO(b/201562458): Don't multiply lhs rows with large shift.
     rhs_i8 = lax.shift_right_logical(
@@ -53,8 +53,9 @@ def _i32_matmul_unreduced(lhs, rhs):
         lhs_i8.astype(jnp.bfloat16),
         rhs_i8.astype(jnp.bfloat16),
         preferred_element_type=jnp.float32,
-    ).astype(jnp.int32)
-    acc += jnp.left_shift(raw_out, out_shift_base + rhs_shift)
+    ).astype(jnp.int32).reshape((4, m//2, n))
+    raw_out = jnp.left_shift(raw_out, out_shift_base + rhs_shift)
+    acc += raw_out[0] + raw_out[1] + raw_out[2] + raw_out[3]
   return acc
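
The loop above rebuilds a 32-bit matmul from 8-bit limb matmuls. A minimal sketch of that identity (not from the commit; the actual kernel additionally splits rhs into bytes and multiplies in bfloat16), with illustrative names:

import jax.numpy as jnp

def i32_matmul_via_u8_limbs(lhs, rhs):
  # Limb j holds bits [8j, 8j+8) of lhs; uint32 math wraps mod 2**32,
  # matching the kernel's unreduced accumulation.
  acc = jnp.zeros((lhs.shape[0], rhs.shape[1]), dtype=jnp.uint32)
  for j in range(4):
    limb = (lhs >> (8 * j)) & 0xFF
    acc += (limb @ rhs) << (8 * j)  # shift each partial product back into place
  return acc

lhs = jnp.arange(12, dtype=jnp.uint32).reshape(3, 4) * 0x01010101
rhs = jnp.arange(8, dtype=jnp.uint32).reshape(4, 2)
assert (i32_matmul_via_u8_limbs(lhs, rhs) == lhs @ rhs).all()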


@@ -77,10 +78,11 @@ def _vector_matrix_polymul(poly_vec1: jnp.ndarray, poly_mat2: jnp.ndarray):
   m = 8
   poly_mat2 = jnp.pad(
       poly_mat2,
-      ((0, 0), (0, m - real_m), (0, 0)),
+      ((0, 0), (0, (m // 2) - real_m), (0, 0)),
       mode="constant",
       constant_values=(0,),
   )
+  poly_mat2 = jnp.concatenate((poly_mat2, poly_mat2), axis=(1))
 
   if n % 128 != 0:
     raise ValueError(f"Input size {n} is not a multiple of 128")
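
This is the hunk that realizes the commit title: the real operand is padded only to m // 2 = 4 rows and two copies are stacked along axis 1, yielding the single (18, 8, 512) operand named in the title. A shape-only sketch using the concrete sizes from the title (real_m = 3, b = 18, n = 512 are assumed here):

import jax.numpy as jnp

real_m, m, b, n = 3, 8, 18, 512
poly_mat2 = jnp.ones((b, real_m, n), dtype=jnp.uint32)
poly_mat2 = jnp.pad(poly_mat2, ((0, 0), (0, (m // 2) - real_m), (0, 0)))
poly_mat2 = jnp.concatenate((poly_mat2, poly_mat2), axis=1)
assert poly_mat2.shape == (18, 8, 512)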
@@ -110,7 +112,7 @@ def vec_mat_polymul_kernel_single_batch(vec_ref, mat_ref, out_ref):
 
     assert vec_toeplitz.shape == (n, n)
     result = _i32_matmul_unreduced(mat_ref[...], vec_toeplitz)
-    assert result.shape == (4 * m, n), result.shape
+    assert result.shape == (m//2, n), result.shape
     out_ref[...] = result
 
   def vec_mat_polymul_kernel(vec_ref, mat_ref, out_ref):
@@ -131,15 +133,15 @@ def vec_mat_polymul_kernel(vec_ref, mat_ref, out_ref):
           pl.BlockSpec((block_b, 1, n), lambda b: (b, 0, 0)),
           pl.BlockSpec((block_b, m, n), lambda b: (b, 0, 0)),
       ),
-      out_specs=pl.BlockSpec((block_b, 4 * m, n), lambda b: (b, 0, 0)),
-      out_shape=jax.ShapeDtypeStruct((b, 4 * m, n), jnp.int32),
+      out_specs=pl.BlockSpec((block_b, m//2, n), lambda b: (b, 0, 0)),
+      out_shape=jax.ShapeDtypeStruct((b, m//2, n), jnp.int32),
       grid=(steps_b,),
   )(
       poly_vec1[:, None].astype(jnp.int32), poly_mat2.astype(jnp.int32)
   ).reshape(
-      b, 4, m, n
+      b, m//2, n
   ),
-  axis=(0, 1),
+  axis=(0, ),
   ).astype(jnp.uint32)[:real_m]
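
Because the limb partial products are now summed inside the kernel (the acc += raw_out[0] + ... line above), the host-side reduction drops the limb axis and keeps only the sum over the b blocks. A toy check with assumed names; it only shows that the limb reduction commutes with the block-axis sum, not the halved row count:

import jax.numpy as jnp

b, m, n = 18, 8, 512
partials = jnp.ones((b, 4, m // 2, n), dtype=jnp.int32)  # per-limb partial sums
old_style = jnp.sum(partials, axis=(0, 1))               # limbs reduced on the host
new_style = jnp.sum(partials.sum(axis=1), axis=(0,))     # limbs reduced in-kernel
assert (old_style == new_style).all()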

