@@ -74,18 +74,33 @@ def _identity(value: Any) -> Any:
   return value


-def reshape_batch_axis(pytree: Any, microbatch_size: int):
-  """Reshape pytree leaves to shape (num_microbatches, microbatch_size, ...)."""
-  # If data is sharded along the 0th axis, using column-major order is important
-  # to ensure that each microbatch is sharded in the same manner.
-  # For example, if the data was sharded across 2 devices, each device would
-  # handle one of the examples in each microbatch.
-  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] --> [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]
-
-  return jax.tree.map(
-      lambda x: x.reshape(-1, microbatch_size, *x.shape[1:], order='F'),
-      pytree,
-  )
+def reshape_batch_axis(tree: Any, microbatch_size: int, axis: int = 0):
+  """Reshape batch axis of pytree leaves for use with microbatching.
+
+  This function reshapes the batch axis of each leaf into a shape
+  (num_microbatches, microbatch_size) appearing at the same axis as the original
+  batch axis. The reshape is done using a column-major order, so any sharding
+  along the batch axis should be preserved in the new `microbatch_size` axis,
+  while the new `num_microbatches` axis will generally be replicated.
+
+  Args:
+    tree: A pytree of jax.Arrays, each having a batch axis.
+    microbatch_size: The size of sub-batches used for each microbatch.
+    axis: The axis to reshape.
+
+  Returns:
+    A pytree of reshaped jax.Arrays.
+  """
+
+  def leaf_fn(x):
+    shape = x.shape
+    batch_size = shape[axis]
+    if batch_size % microbatch_size != 0:
+      raise ValueError(f'{batch_size=} not divisible by {microbatch_size=}')
+    new_shape = shape[:axis] + (-1, microbatch_size) + shape[axis + 1:]
+    return x.reshape(new_shape, order='F')
+
+  return jax.tree.map(leaf_fn, tree)


 def _lift(accumulator: Accumulator) -> Accumulator:
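Not part of the diff itself, but a minimal sketch of the column-major reshape the new docstring describes, reusing the numbers from the removed comment and assuming a single 1-D leaf:

import jax.numpy as jnp

# A batch of 6 examples split into microbatches of size 2.
batch = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
microbatches = batch.reshape(-1, 2, order='F')  # the same reshape leaf_fn performs
# microbatches == [[1., 4.],
#                  [2., 5.],
#                  [3., 6.]]
# If `batch` were sharded over two devices along axis 0 (device 0: [1, 2, 3],
# device 1: [4, 5, 6]), each microbatch row would still hold one example per
# device, so the per-device sharding carries over to the microbatch_size axis.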
@@ -127,13 +142,14 @@ def finalize(carry):

   def aggregate(values):
     return jax.tree.map(
-        lambda acc, val: acc.accumulate(val), accumulators, values
+        lambda acc, val: acc.aggregate(val), accumulators, values
     )

   return Accumulator(init, update, finalize, aggregate)


 def _sum() -> Accumulator:
+  """An Accumulator that computes the sum of microbatched outputs."""
   return _lift(
       Accumulator(
           init=_identity,
@@ -145,6 +161,7 @@ def _sum() -> Accumulator:


 def _mean(num_microbatches: int) -> Accumulator:
+  """An Accumulator that computes the mean of microbatched outputs."""
   return _lift(
       Accumulator(
           init=_with_floating_check(_identity),
@@ -156,6 +173,7 @@ def _mean(num_microbatches: int) -> Accumulator:


 def _running_mean() -> Accumulator:
+  """An Accumulator that computes the running mean of microbatched outputs."""
   def update(carry, value, index):
     p = index / (index + 1)
     new_state = carry * p + value * (1 - p)
@@ -172,8 +190,11 @@ def update(carry, value, index):


 def _concat(num_microbatches: int) -> Accumulator:
+  """An Accumulator that concatenates microbatched outputs along axis 0."""
   def init(value):
-    return jnp.broadcast_to(value, (num_microbatches,) + value.shape)
+    shape = (num_microbatches,) + value.shape
+    zeros = jnp.broadcast_to(jnp.zeros_like(value), shape)
+    return zeros.at[0].set(value)

   def update(carry, value, index):
     return carry.at[index].set(value)
@@ -217,14 +238,49 @@ def fun(acc):
   return _compose(jax.tree.map(fun, tree))


-_DEFAULT = AccumulationType.SUM
+def _reshape_all_args(
+    microbatch_size: int,
+    argnums: Sequence[int],
+    argnames: Sequence[str],
+    in_axes: Sequence[int],
+    args: tuple[Any, ...],
+    kwargs: dict[str, Any]
+) -> tuple[tuple[Any, ...], dict[str, Any], int]:
+  """Reshapes all batch arguments to have a microbatch axis."""
+  new_args = list(args)
+  new_kwargs = dict(kwargs)
+  batch_args = [args[i] for i in argnums] + [kwargs[i] for i in argnames]
+
+  batch_sizes = jax.tree.flatten(jax.tree.map(
+      lambda ax, subtree: jax.tree.map(lambda x: x.shape[ax], subtree),
+      tuple(in_axes), tuple(batch_args)
+  ))[0]
+
+  if len(set(batch_sizes)) > 1:
+    raise ValueError(
+        'Batch arguments must have the same shape along the batch axis, found'
+        f' multiple batch sizes: {batch_sizes}'
+    )
+
+  for i, ax in zip(argnums, in_axes):
+    new_args[i] = reshape_batch_axis(args[i], microbatch_size, ax)
+
+  for name, ax in zip(argnames, in_axes[len(argnums):]):
+    new_kwargs[name] = reshape_batch_axis(kwargs[name], microbatch_size, ax)
+
+  return tuple(new_args), new_kwargs, tuple(batch_sizes)[0]


 def microbatch(
     fun: Callable[..., Any],
     argnums: int | Sequence[int],
     microbatch_size: int | None,
-    accumulator: Accumulator | AccumulationType | AccumulatorTree = _DEFAULT,
+    accumulator: (
+        Accumulator | AccumulationType | AccumulatorTree
+    ) = AccumulationType.SUM,
+    *,
+    argnames: str | Sequence[str] = (),
+    in_axes: int | Sequence[int] = 0,
     num_real_microbatches: int | None = None,
 ) -> Callable[..., Any]:
   """A general microbatching transformation.
@@ -269,17 +325,21 @@ def microbatched_fun(full_batch):
     (Array([2, 3, 4, 5], dtype=int32), Array(30, dtype=int32))

   Args:
-    fun: An arbitrary function. All kwargs are assumed to have a batch axis.
-    argnums: A sequence of argument indices that have a batch axis. All
-      kwargs are assumed to have a batch axis, similar to ``jax.vmap``.
+    fun: An arbitrary function.
+    argnums: A sequence of argument indices that have a batch axis.
     microbatch_size: The number of rows in the overall batch used in each
       microbatch. Smaller values reduce memory overhead, but require more
       sequential computation. This must evenly divide the batch axis size of
       the batch arguments.
     accumulator: Specifies how to combine results from each microbatch; can be
-      a single ``Accumulator``, a pytree matching the structure of ``fun``'s
-      output, with ``Accumulator`` values at the leaves, or anything in
-      between (i.e., a PyTree prefix of ``fun``'s output`).
+      a single `Accumulator`, a pytree matching the structure of `fun`'s
+      output, with `Accumulator` values at the leaves, or anything in between
+      (i.e., a PyTree prefix of `fun`'s output).
+    argnames: A sequence of keyword argument names that have a batch axis.
+    in_axes: An integer or sequence of integers indicating the batch axis
+      index for each argument in `argnums` and `argnames`; should be aligned
+      with the list `argnums + argnames`. The default value of 0 assumes
+      that all arguments have a batch axis on the 0th dimension of the array.
     num_real_microbatches: Optional number of microbatches that are actually
       executed. If specified, microbatching will terminate early after this
       many steps. Can be helpful to handle variable batch sizes without
@@ -295,31 +355,38 @@ def microbatched_fun(full_batch):
   if isinstance(argnums, int):
     argnums = (argnums,)

+  if isinstance(argnames, str):
+    argnames = (argnames,)
+
+  if isinstance(in_axes, int):
+    in_axes = (in_axes,) * (len(argnums) + len(argnames))
+
   def microbatched_fun(*args, **kwargs):
-    batch_args = [args[i] for i in argnums]
-    batch_size = jax.tree.leaves(batch_args)[0].shape[0]
-    if batch_size % microbatch_size != 0:
-      raise ValueError(f'{batch_size=} not divisible by {microbatch_size=}')
+    reshaped_args, reshaped_kwargs, batch_size = _reshape_all_args(
+        microbatch_size, argnums, argnames, in_axes, args, kwargs
+    )
     num_microbatches = batch_size // microbatch_size
     accumulator_ = _canonicalize(accumulator, num_microbatches)

-    reshaped_batch_args = reshape_batch_axis(batch_args, microbatch_size)
-    reshaped_kwargs = reshape_batch_axis(kwargs, microbatch_size)
-
     def f(index):
-      fetch = lambda arg: jax.tree.map(lambda x: x[index], arg)
-      inputs = list(args)
-      for i, arg in zip(argnums, reshaped_batch_args):
-        inputs[i] = fetch(arg)
-      input_kwargs = {k: fetch(kwarg) for k, kwarg in reshaped_kwargs.items()}
-      return fun(*inputs, **input_kwargs)
+      input_args = list(reshaped_args)
+      input_kwargs = dict(reshaped_kwargs)
+      for i, ax in zip(argnums, in_axes):
+        input_args[i] = jax.tree.map(
+            functools.partial(jnp.take, indices=index, axis=ax), input_args[i]
+        )
+      for i, ax in zip(argnames, in_axes[len(argnums):]):
+        input_kwargs[i] = jax.tree.map(
+            functools.partial(jnp.take, indices=index, axis=ax), input_kwargs[i]
+        )
+      return fun(*input_args, **input_kwargs)

     def body_fun(index, carry):
       return accumulator_.update(carry, f(index), index)

     loop_bound = num_real_microbatches or num_microbatches
     answer = jax.lax.fori_loop(
-        1, loop_bound, body_fun, accumulator_.init(f(0))
+        1, loop_bound, body_fun, accumulator_.init(f(0)),
     )

     return accumulator_.finalize(answer)
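For reference, a minimal usage sketch of the extended signature with keyword batch arguments; the loss function and array names below are made up for illustration, while `microbatch` and `AccumulationType` are the names that appear in this module:

import jax.numpy as jnp

def sum_squared_error(params, *, inputs, labels):
  # Per-example squared error, summed over whatever (micro)batch it receives.
  preds = inputs @ params
  return jnp.sum((preds - labels) ** 2)

params = jnp.ones(4)
inputs = jnp.ones((32, 4))  # batch axis 0
labels = jnp.zeros(32)      # batch axis 0

# The batch data is passed by keyword, so `argnames` is used instead of
# `argnums`; `in_axes` keeps its default of 0 for both arrays.
microbatched_loss = microbatch(
    sum_squared_error,
    argnums=(),
    microbatch_size=8,
    accumulator=AccumulationType.SUM,
    argnames=('inputs', 'labels'),
)
total = microbatched_loss(params, inputs=inputs, labels=labels)
# `total` matches sum_squared_error(params, inputs=inputs, labels=labels),
# but is computed in 4 sequential microbatches of 8 rows each.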