
Commit 4105764

Ryan McKenna authored and OptaxDev committed
Add optional "in_axes" and "argnames" kwargs to microbatch, and define a simple vmap wrapper around it.
PiperOrigin-RevId: 823690163
1 parent d75a5ec commit 4105764

File tree: 2 files changed (+140, -37 lines)

  optax/experimental/_microbatching.py
  optax/experimental/_microbatching_test.py


optax/experimental/_microbatching.py

Lines changed: 103 additions & 36 deletions
@@ -74,18 +74,33 @@ def _identity(value: Any) -> Any:
   return value
 
 
-def reshape_batch_axis(pytree: Any, microbatch_size: int):
-  """Reshape pytree leaves to shape (num_microbatches, microbatch_size, ...)."""
-  # If data is sharded along the 0th axis, using column-major order is important
-  # to ensure that each microbatch is sharded in the same manner.
-  # For example, if the data was sharded across 2 devices, each device would
-  # handle one of the examples in each microbatch.
-  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] --> [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]
-
-  return jax.tree.map(
-      lambda x: x.reshape(-1, microbatch_size, *x.shape[1:], order='F'),
-      pytree,
-  )
+def reshape_batch_axis(tree: Any, microbatch_size: int, axis: int = 0):
+  """Reshape batch axis of pytree leaves for use with microbatching.
+
+  This function reshapes the batch axis of each leaf into a shape
+  (num_microbatches, microbatch_size) appearing at the same axis as the original
+  batch axis. The reshape is done using a column-major order, so any sharding
+  along the batch axis should be preserved in the new `microbatch_size` axis,
+  while the new `num_microbatches` axis will generally be replicated.
+
+  Args:
+    tree: A pytree of jax.Arrays, each having a batch axis.
+    microbatch_size: The size of sub-batches used for each microbatch.
+    axis: The axis to reshape.
+
+  Returns:
+    A pytree of reshaped jax.Arrays.
+  """
+
+  def leaf_fn(x):
+    shape = x.shape
+    batch_size = shape[axis]
+    if batch_size % microbatch_size != 0:
+      raise ValueError(f'{batch_size=} not divisible by {microbatch_size=}')
+    new_shape = shape[:axis] + (-1, microbatch_size) + shape[axis + 1:]
+    return x.reshape(new_shape, order='F')
+
+  return jax.tree.map(leaf_fn, tree)
 
 
 def _lift(accumulator: Accumulator) -> Accumulator:
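
As a quick illustration of the column-major behavior described in the new docstring, the reshape reproduces the example from the removed comment; this is a minimal sketch using plain jax.numpy, not code from the commit:

    import jax.numpy as jnp

    x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    # Reshape the batch axis into (num_microbatches, microbatch_size) with
    # microbatch_size=2. Column-major order takes strided elements, so any
    # sharding along the batch axis carries over to the microbatch_size axis.
    print(x.reshape(-1, 2, order='F'))
    # [[1. 4.]
    #  [2. 5.]
    #  [3. 6.]]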
@@ -127,13 +142,14 @@ def finalize(carry):
 
   def aggregate(values):
     return jax.tree.map(
-        lambda acc, val: acc.accumulate(val), accumulators, values
+        lambda acc, val: acc.aggregate(val), accumulators, values
     )
 
   return Accumulator(init, update, finalize, aggregate)
 
 
 def _sum() -> Accumulator:
+  """An Accumulator that computes the sum of microbatched outputs."""
   return _lift(
       Accumulator(
           init=_identity,
@@ -145,6 +161,7 @@ def _sum() -> Accumulator:
 
 
 def _mean(num_microbatches: int) -> Accumulator:
+  """An Accumulator that computes the mean of microbatched outputs."""
   return _lift(
       Accumulator(
           init=_with_floating_check(_identity),
@@ -156,6 +173,7 @@ def _mean(num_microbatches: int) -> Accumulator:
 
 
 def _running_mean() -> Accumulator:
+  """An Accumulator that computes the running mean of microbatched outputs."""
   def update(carry, value, index):
     p = index / (index + 1)
     new_state = carry * p + value * (1 - p)
@@ -172,8 +190,11 @@ def update(carry, value, index):
 
 
 def _concat(num_microbatches: int) -> Accumulator:
+  """An Accumulator that concatenates microbatched outputs along the axis 0."""
   def init(value):
-    return jnp.broadcast_to(value, (num_microbatches,) + value.shape)
+    shape = (num_microbatches,) + value.shape
+    zeros = jnp.broadcast_to(jnp.zeros_like(value), shape)
+    return zeros.at[0].set(value)
 
   def update(carry, value, index):
     return carry.at[index].set(value)
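
The new init for CONCAT builds the carry from zeros and writes only the first microbatch result, so slots that early stopping (num_real_microbatches) never reaches stay zero instead of repeating the first value. A minimal sketch of that initialization, assuming a small 1-D value and plain jax.numpy:

    import jax.numpy as jnp

    value = jnp.array([7.0, 8.0])   # result of the first microbatch
    num_microbatches = 3
    shape = (num_microbatches,) + value.shape
    # Previously, jnp.broadcast_to(value, shape) repeated [7., 8.] in every slot.
    # Now the carry starts from zeros and only slot 0 is filled.
    carry = jnp.broadcast_to(jnp.zeros_like(value), shape).at[0].set(value)
    print(carry)
    # [[7. 8.]
    #  [0. 0.]
    #  [0. 0.]]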
@@ -217,14 +238,49 @@ def fun(acc):
   return _compose(jax.tree.map(fun, tree))
 
 
-_DEFAULT = AccumulationType.SUM
+def _reshape_all_args(
+    microbatch_size: int,
+    argnums: Sequence[int],
+    argnames: Sequence[str],
+    in_axes: Sequence[int],
+    args: tuple[Any, ...],
+    kwargs: dict[str, Any]
+) -> tuple[tuple[Any, ...], dict[str, Any], int]:
+  """Reshapes all batch arguments to have a microbatch axis."""
+  new_args = list(args)
+  new_kwargs = dict(kwargs)
+  batch_args = [args[i] for i in argnums] + [kwargs[i] for i in argnames]
+
+  batch_sizes = jax.tree.flatten(jax.tree.map(
+      lambda ax, subtree: jax.tree.map(lambda x: x.shape[ax], subtree),
+      tuple(in_axes), tuple(batch_args)
+  ))[0]
+
+  if len(set(batch_sizes)) > 1:
+    raise ValueError(
+        'Batch Arguments must have the same shape along the batch axis, found'
+        f' multiple batch sizes: {batch_sizes}'
+    )
+
+  for i, ax in zip(argnums, in_axes):
+    new_args[i] = reshape_batch_axis(args[i], microbatch_size, ax)
+
+  for name, ax in zip(argnames, in_axes[len(argnums) :]):
+    new_kwargs[name] = reshape_batch_axis(kwargs[name], microbatch_size, ax)
+
+  return tuple(new_args), new_kwargs, tuple(batch_sizes)[0]
 
 
 def microbatch(
     fun: Callable[..., Any],
     argnums: int | Sequence[int],
     microbatch_size: int | None,
-    accumulator: Accumulator | AccumulationType | AccumulatorTree = _DEFAULT,
+    accumulator: (
+        Accumulator | AccumulationType | AccumulatorTree
+    ) = AccumulationType.SUM,
+    *,
+    argnames: str | Sequence[str] = (),
+    in_axes: int | Sequence[int] = 0,
     num_real_microbatches: int | None = None,
 ) -> Callable[..., Any]:
   """A general microbatching transformation.
@@ -269,17 +325,21 @@ def microbatched_fun(full_batch):
   (Array([2, 3, 4, 5], dtype=int32), Array(30, dtype=int32))
 
   Args:
-    fun: An arbitrary function. All kwargs are assumed to have a batch axis.
-    argnums: A sequence of argument indices that have a batch axis. All
-      kwargs are assumed to have a batch axis, similar to ``jax.vmap``.
+    fun: An arbitrary function.
+    argnums: A sequence of argument indices that have a batch axis.
     microbatch_size: The number of rows in the overall batch used in each
       microbatch. Smaller values reduce memory overhead, but require more
       sequential computation. This must evenly divide the batch axis size of
       the batch arguments.
     accumulator: Specifies how to combine results from each microbatch; can be
-      a single ``Accumulator``, a pytree matching the structure of ``fun``'s
-      output, with ``Accumulator`` values at the leaves, or anything in
-      between (i.e., a PyTree prefix of ``fun``'s output`).
+      a single `Accumulator`, a pytree matching the structure of `fun`'s
+      output, with `Accumulator` values at the leaves, or anything in between
+      (i.e., a PyTree prefix of `fun`'s output).
+    argnames: A sequence of keyword argument names that have a batch axis.
+    in_axes: An integer or sequence of integers giving the batch axis index
+      for each argument in `argnums` and `argnames`; a sequence should be
+      aligned with the list `argnums + argnames`. The default value of 0
+      assumes that all arguments have a batch axis on the 0th dimension.
     num_real_microbatches: Optional number of microbatches that are actually
       executed. If specified, microbatching will terminate early after this
       many steps. Can be helpful to handle variable batch sizes without
@@ -295,31 +355,38 @@ def microbatched_fun(full_batch):
   if isinstance(argnums, int):
     argnums = (argnums,)
 
+  if isinstance(argnames, str):
+    argnames = (argnames,)
+
+  if isinstance(in_axes, int):
+    in_axes = (in_axes,) * (len(argnums) + len(argnames))
+
   def microbatched_fun(*args, **kwargs):
-    batch_args = [args[i] for i in argnums]
-    batch_size = jax.tree.leaves(batch_args)[0].shape[0]
-    if batch_size % microbatch_size != 0:
-      raise ValueError(f'{batch_size=} not divisible by {microbatch_size=}')
+    reshaped_args, reshaped_kwargs, batch_size = _reshape_all_args(
+        microbatch_size, argnums, argnames, in_axes, args, kwargs
+    )
     num_microbatches = batch_size // microbatch_size
     accumulator_ = _canonicalize(accumulator, num_microbatches)
 
-    reshaped_batch_args = reshape_batch_axis(batch_args, microbatch_size)
-    reshaped_kwargs = reshape_batch_axis(kwargs, microbatch_size)
-
     def f(index):
-      fetch = lambda arg: jax.tree.map(lambda x: x[index], arg)
-      inputs = list(args)
-      for i, arg in zip(argnums, reshaped_batch_args):
-        inputs[i] = fetch(arg)
-      input_kwargs = {k: fetch(kwarg) for k, kwarg in reshaped_kwargs.items()}
-      return fun(*inputs, **input_kwargs)
+      input_args = list(reshaped_args)
+      input_kwargs = dict(reshaped_kwargs)
+      for i, ax in zip(argnums, in_axes):
+        input_args[i] = jax.tree.map(
+            functools.partial(jnp.take, indices=index, axis=ax), input_args[i]
+        )
+      for i, ax in zip(argnames, in_axes[len(argnums) :]):
+        input_kwargs[i] = jax.tree.map(
+            functools.partial(jnp.take, indices=index, axis=ax), input_kwargs[i]
+        )
+      return fun(*input_args, **input_kwargs)
 
     def body_fun(index, carry):
       return accumulator_.update(carry, f(index), index)
 
     loop_bound = num_real_microbatches or num_microbatches
     answer = jax.lax.fori_loop(
-        1, loop_bound, body_fun, accumulator_.init(f(0))
+        1, loop_bound, body_fun, accumulator_.init(f(0)),
     )
 
     return accumulator_.finalize(answer)
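
Putting the new keyword arguments together, a hedged usage sketch: the import path and the toy loss are assumptions for illustration, and only the microbatch call reflects the new signature. Here x carries its batch axis at position 1 and labels at position 0, so in_axes is aligned with the combined list argnums + argnames:

    import jax.numpy as jnp
    from optax.experimental import _microbatching as microbatching  # import path assumed

    def loss(params, x, *, labels):
      preds = params @ x                    # x: (features, microbatch) -> (microbatch,)
      return jnp.mean((preds - labels) ** 2)

    params = jnp.zeros((5,))
    x = jnp.ones((5, 8))       # batch axis at position 1
    labels = jnp.zeros((8,))   # batch axis at position 0

    microbatched_loss = microbatching.microbatch(
        loss,
        argnums=(1,),                 # positional arguments with a batch axis
        microbatch_size=2,
        accumulator=microbatching.AccumulationType.MEAN,
        argnames=('labels',),         # keyword arguments with a batch axis
        in_axes=(1, 0),               # aligned with argnums + argnames
    )
    print(microbatched_loss(params, x, labels=labels))  # averaged over 4 microbatches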

optax/experimental/_microbatching_test.py

Lines changed: 37 additions & 1 deletion
@@ -13,10 +13,10 @@
 # limitations under the License.
 # ==============================================================================
 
+import functools
 
 from absl.testing import absltest
 from absl.testing import parameterized
-
 import chex
 import jax.numpy as jnp
 import numpy as np
@@ -101,6 +101,7 @@ def test_microbatch_with_kwargs(self):
         argnums=(1,),
         microbatch_size=2,
         accumulator=accumulator,
+        argnames=('batch_kwarg2',),
     )
     expected_answer = fun(nonbatch_arg, batch_arg1, batch_kwarg2=batch_kwarg2)
     actual_answer = microbatched_fun(
@@ -158,6 +159,41 @@ def test_correct_dtype_returned(self, arg_dtype):
     answer = microbatched_fun(nonbatch_arg, batch_arg1, batch_arg2)
     self.assertEqual(answer.dtype, arg_dtype)
 
+  def test_early_stopping_concat(self):
+    x = jnp.arange(12).astype(jnp.float32) + 1
+
+    output = microbatching.microbatch(
+        lambda x: x,
+        argnums=0,
+        accumulator=microbatching.AccumulationType.CONCAT,
+        microbatch_size=3,
+        num_real_microbatches=2,
+    )(x)
+
+    chex.assert_trees_all_close(jnp.sum(output != 0), 6)
+
+  @parameterized.parameters(
+      microbatching.AccumulationType.SUM,
+      microbatching.AccumulationType.MEAN,
+      microbatching.AccumulationType.RUNNING_MEAN,
+      microbatching.AccumulationType.CONCAT,
+  )
+  def test_in_axes_invariant(self, acc):
+
+    arg_axis0 = jnp.array(np.random.normal(size=(10, 4, 5)))
+    arg_axis1 = jnp.transpose(arg_axis0, axes=(1, 0, 2))
+    self.assertEqual(arg_axis1.shape, (4, 10, 5))
+    fun_axis0 = functools.partial(jnp.einsum, 'bij,bkj->ik')
+    fun_axis1 = functools.partial(jnp.einsum, 'ibj,kbj->ik')
+
+    result0 = microbatching.microbatch(
+        fun_axis0, argnums=(0, 1), microbatch_size=2, in_axes=0, accumulator=acc
+    )(arg_axis0, arg_axis0)
+    result1 = microbatching.microbatch(
+        fun_axis1, argnums=(0, 1), microbatch_size=2, in_axes=1, accumulator=acc
+    )(arg_axis1, arg_axis1)
+    chex.assert_trees_all_close(result0, result1)
+
 
 if __name__ == '__main__':
   absltest.main()
