diff --git a/docs/gallery.rst b/docs/gallery.rst index caeb20812..c11aa4a0f 100644 --- a/docs/gallery.rst +++ b/docs/gallery.rst @@ -300,6 +300,21 @@ Examples that make use of the :doc:`api/contrib` module.
AdEMAMix.
+.. raw:: html + +
+ +.. only:: html + + .. image:: /images/favicon.svg + :alt: Aggregators. + + :doc:`_collections/examples/aggregators` + +.. raw:: html + +
Aggregators.
+
.. raw:: html diff --git a/examples/aggregate.ipynb b/examples/aggregate.ipynb new file mode 100644 index 000000000..07c09e3f1 --- /dev/null +++ b/examples/aggregate.ipynb @@ -0,0 +1,1067 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Aggregating and processing gradients\n", + "\n", + "⚠️ **Warning**: *This is an experimental feature subject to code changes.*\n", + "\n", + "Optax implements GradientTransformations that operate on an average of gradients, computed as the gradient of an average loss over a mini-batch. Such optimizers can be written as\n", + "$$\n", + "w_{\\text{next}} = w - t_k \\circ \\ldots \\circ t_1 \\circ \\text{avg} (grads)\n", + "$$\n", + "where grads is a collection of gradients and avg is implemented implicitly by computing the gradient of the average loss over the mini-batch.\n", + "\n", + "The class {py:class}`optax.experimental.aggregate.Aggregator` and the function {py:func}`optax.experimental.aggregate.process` extend this paradigm to allow for optimizers of the form\n", + "$$\n", + "w_{\\text{next}} = w -\n", + "t_k \\circ \\ldots \\circ t_1 \\circ\n", + "\\text{agg} \\circ \n", + "s_j \\circ \\ldots \\circ s_1 (grads)\n", + "$$\n", + "where the transformations $t_i$ and $s_i$ preserve the shape of their inputs (as usual gradient transformations) and agg (short for aggregator) may, for example, average its inputs (and so change their shape). This paradigm enables, for example, a simple implementation of differentially private training, where individual gradients are clipped before being averaged.\n", + "\n", + "In this notebook, we explain how this paradigm is implemented and present several instantiations.\n", + "\n", + "\n", + "\n" + ], + "metadata": { + "id": "Dz6mpfWCPU6H" + } + },
+ { + "cell_type": "code", + "source": [ + "# @title Imports\n", + "import functools\n", + "from typing import Iterator, NamedTuple, Tuple\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import jax.random as jrd\n", + "\n", + "import optax\n", + "from optax import tree\n", + "from optax._src import base\n", + "from optax._src import transform\n", + "from optax._src import utils\n", + "from optax.transforms import _clipping\n", + "from optax.transforms import _combining\n", + "\n", + "from optax.experimental import aggregate" + ], + "metadata": { + "id": "vY4D-01UkyKO" + }, + "execution_count": null, + "outputs": [] + },
+ { + "cell_type": "markdown", + "source": [ + "# Aggregators class\n", + "\n", + "We introduce a new class to isolate gradient transformations that operate on per-element gradients. Its signature mimics GradientTransformationExtraArgs with `init` and `update` functions.\n", + "\n", + "Optax's base GradientTransformation expects input and output updates to have the same shape as the parameters. Aggregators instead take per-example gradients of shape `[*batch_shape, *params_shape]` as inputs and return an update direction of shape `[*params_shape]`. Note that we do not enforce shape constraints when defining instances of this class (nor do we for\n", + "GradientTransformation). This class serves as a guide indicating which parts of the\n", + "rest of the optimization pipeline may need to change.\n", + "\n", + "\n", + "```python\n", + "PerElementUpdates = chex.ArrayTree\n", + "AggregatedUpdates = chex.ArrayTree\n", + "\n", + "\n", + "class AggregatorUpdateFn(Protocol):\n", + " \"\"\"Update function for aggregators.\"\"\"\n", + "\n", + " def __call__(\n", + " self,\n", + " per_elt_updates: PerElementUpdates,\n", + " state: base.OptState,\n", + " params: base.Params | None = None,\n", + " **extra_args: Any,\n", + " ) -\u003e tuple[AggregatedUpdates, base.OptState]:\n", + " \"\"\"Transforms per-element updates into aggregated update direction.\"\"\"\n", + "\n", + "\n", + "class Aggregator(base.GradientTransformationExtraArgs):\n", + " \"\"\"A pair of pure functions that implement stateful aggregation of gradients.\n", + "\n", + " Attributes:\n", + " init: Initialization function that takes params and returns state.\n", + " update: Update function that takes per-example gradients, state and params\n", + " (optionally) and returns updates and updated state.\n", + " \"\"\"\n", + "\n", + " init: base.TransformInitFn\n", + " update: AggregatorUpdateFn\n", + "```\n", + "\n", + "The main benefit is to let the user know what type of gradient oracle is expected.\n", + "\n", + "The usual training pipeline takes the form\n", + "```python\n", + "grads = jax.grad(loss)(params, batch)\n", + "updates, opt_state = opt.update(grads, opt_state)\n", + "```\n", + "To accommodate aggregators, a simple class check suffices. Namely, one can replace the lines above by\n", + "```python\n", + "if isinstance(opt, aggregate.Aggregator):\n", + " grads = jax.vmap(jax.grad(loss), in_axes=(None, 0))(params, batch)\n", + "else:\n", + " grads = jax.grad(loss)(params, batch)\n", + "updates, opt_state = opt.update(grads, opt_state)\n", + "\n", + "```\n", + "\n", + "Let's look at a first basic instance, `average_per_element_udpates` (defined in the `aggregate` module).\n", + "\n", + "```python\n", + "def average_per_element_udpates(\n", + " per_elt_axis: int | list[int] = 0\n", + " ) -\u003e aggregate.Aggregator:\n", + " \"\"\"Average per-element updates.\"\"\"\n", + "\n", + " def update_fn(per_elt_updates, state, params=None):\n", + " del params\n", + " avg_updates = jax.tree.map(\n", + " lambda x: jnp.mean(x, axis=per_elt_axis), per_elt_updates\n", + " )\n", + " return avg_updates, state\n", + "\n", + " return aggregate.Aggregator(base.init_empty_state, update_fn)\n", + "```" + ], + "metadata": { + "id": "t3LkWAdhke00" + } + },
+ { + "cell_type": "code", + "source": [ + "def data_iterator(\n", + " key: jrd.PRNGKey,\n", + " num_samples: int,\n", + " dim: int,\n", + " num_classes: int,\n", + " batch_size: int,\n", + ") -\u003e Iterator[Tuple[jnp.ndarray, jnp.ndarray]]:\n", + " \"\"\"Generates a synthetic set of inputs and targets.\"\"\"\n", + " inputs_key, targets_key = jrd.split(key)\n", + " inputs = jrd.normal(inputs_key, (num_samples, dim))\n", + " targets = jrd.normal(targets_key, (num_samples, num_classes))\n", + "\n", + " for i in range(0, num_samples, batch_size):\n", + " yield inputs[i : i + batch_size], targets[i : i + batch_size]\n", + "\n", + "\n", + "def loss_fun(params, batch):\n", + " inputs, targets = batch\n", + " return jnp.mean(jnp.sum((inputs.dot(params) - targets) ** 2, -1))\n", + "\n", + "\n", + "def basic_train(opt):\n", + " num_samples, batch_size, dim, num_classes = 16, 4, 4, 2\n", + "\n", + " data = data_iterator(jrd.key(0), num_samples, dim, num_classes, 
batch_size)\n", + " params = jrd.normal(jrd.key(1), (dim, num_classes))\n", + "\n", + " @jax.jit\n", + " def train_step(params, state, batch):\n", + " if isinstance(opt, aggregate.Aggregator):\n", + " losses, grads = jax.vmap(jax.value_and_grad(loss_fun), (None, 0))(\n", + " params, batch\n", + " )\n", + " loss = jnp.mean(losses)\n", + " else:\n", + " loss, grads = jax.value_and_grad(loss_fun)(params, batch)\n", + " updates, state = opt.update(grads, state)\n", + " params = optax.apply_updates(params, updates)\n", + " return params, state, loss\n", + "\n", + " state = opt.init(params)\n", + " for i, batch in enumerate(data):\n", + " params, state, loss = train_step(params, state, batch)\n", + " print(f'Step: {i} | Batch loss: {loss:.2e}')" + ], + "metadata": { + "id": "3yc25FeULwul" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print('Standard training')\n", + "opt = optax.sgd(learning_rate=0.01)\n", + "\n", + "basic_train(opt)\n", + "\n", + "print('\\nWith explicit aggregation')\n", + "opt = aggregate.chain(\n", + " aggregate.average_per_element_udpates(), optax.sgd(learning_rate=0.01)\n", + ")\n", + "basic_train(opt)" + ], + "metadata": { + "id": "mTNfRXyeNyPa" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Processing gradients\n", + "\n", + "Optimizers of the form\n", + "$$\n", + "w_{\\text{next}} = w -\n", + "t_k \\circ \\ldots \\circ t_1 \\circ\n", + "\\text{agg} \\circ \n", + "s_j \\circ \\ldots \\circ s_1 (grads)\n", + "$$\n", + "may be defined just a chain of gradient transformations.\n", + "\n", + "However,\n", + "1. we may aggregate gradients by doing simple gradient accumulation rather than computing all gradients at once,\n", + "2. 
we may want to do more than just average the gradients (for example we may want to access to some variance).\n", + "\n", + "\n", + "For this reason, we provide a `process` function.\n", + "```python\n", + "def process(\n", + " preprocessor: base.GradientTransformation,\n", + " aggregator: base.GradientTransformation | Aggregator,\n", + " postprocessor: base.GradientTransformation,\n", + " aggregator_has_aux: bool = False,\n", + "):\n", + " \"\"\"Process gradients through a sequence of transformations.\n", + "\n", + " Args:\n", + " preprocessor: A transformation that maps per-example gradients to\n", + " per-example updates.\n", + " aggregator: A transformation that aggregates per-example updates into a\n", + " single update.\n", + " postprocessor: A transformation that maps aggregated updates to the final\n", + " updates.\n", + " aggregator_has_aux: Whether the aggregator returns more than just the\n", + " average updates.\n", + "\n", + " Returns:\n", + " A :class:`optax.GradientTransformation`.\n", + " \"\"\"\n", + "\n", + " def init_fn(params) -\u003e tuple[base.OptState, base.OptState, base.OptState]:\n", + " preprocess_state = preprocessor.init(params)\n", + " aggregate_state = aggregator.init(params)\n", + " postprocess_state = postprocessor.init(params)\n", + " return preprocess_state, aggregate_state, postprocess_state\n", + "\n", + " def update_fn(indiv_grads, states, params=None, **extra_args):\n", + " preprocess_state, aggregate_state, postprocess_state = states\n", + "\n", + " indiv_updates, new_preprocess_state = preprocessor.update(\n", + " indiv_grads, preprocess_state, params, **extra_args\n", + " )\n", + "\n", + " aggregated, new_aggregate_state = aggregator.update(\n", + " indiv_updates, aggregate_state, params, **extra_args\n", + " )\n", + "\n", + " if aggregator_has_aux:\n", + " avg_updates, agg_aux = aggregated\n", + " extra_args = extra_args | agg_aux\n", + " else:\n", + " avg_updates = aggregated\n", + "\n", + " ready_to_post_process = tree.get(new_aggregate_state, 'ready', True)\n", + "\n", + " updates, new_postprocess_state = jax.lax.cond(\n", + " ready_to_post_process,\n", + " lambda g, s, p, kw: postprocessor.update(g, s, p, **kw),\n", + " lambda g, s, *_: (tree.zeros_like(avg_updates), s),\n", + " avg_updates,\n", + " postprocess_state,\n", + " params,\n", + " extra_args,\n", + " )\n", + " return updates, (\n", + " new_preprocess_state,\n", + " new_aggregate_state,\n", + " new_postprocess_state,\n", + " )\n", + "\n", + " if isinstance(aggregator, Aggregator):\n", + " return Aggregator(init_fn, update_fn)\n", + " else:\n", + " return base.GradientTransformationExtraArgs(init_fn, update_fn)\n", + "```\n", + "\n", + "This function lets the user define an `aggregate` transform that can aggregate gradients it receives chunk by chunk until is ready to post-process them.\n", + "It also provides the possibility to pass along more than the average updates to\n", + "the post-processing stage.\n", + "\n", + "As a simple example, we can extend the basic `average_per_element_updates` to\n", + "work with micro-batches using the following tools (in the `aggregate` module).\n", + "\n", + "\n", + "```python\n", + "class AccumulateAvgUpdatesState(NamedTuple):\n", + " \"\"\"State for the average gradient accumulator.\"\"\"\n", + "\n", + " micro_step: int\n", + " ready: bool\n", + " avg_grad: base.Updates\n", + "\n", + "\n", + "def accumulate_avg_udpates(\n", + " num_microbatches: int,\n", + ") -\u003e base.GradientTransformation:\n", + " \"\"\"Accumulate average 
gradients.\"\"\"\n", + "\n", + " if num_microbatches \u003c 1:\n", + " raise ValueError('num_microbatches must be larger than or equal to than 0.')\n", + "\n", + " if num_microbatches == 1:\n", + " # If there is only one microbatch, we don't need accumulation.\n", + " # We return identity to save unnecessary state tracking.\n", + " return base.identity()\n", + "\n", + " def init_fn(params):\n", + " return AccumulateAvgUpdatesState(\n", + " micro_step=0, ready=False, avg_grad=tree.zeros_like(params)\n", + " )\n", + "\n", + " def update_fn(updates, state, params=None):\n", + " del params\n", + " new_micro_step = state.micro_step + 1\n", + " new_avg_grad = jax.tree.map(\n", + " lambda u, a: a + (u - a) / new_micro_step,\n", + " updates,\n", + " state.avg_grad,\n", + " )\n", + " ready_state = AccumulateAvgUpdatesState(\n", + " micro_step=0, ready=True, avg_grad=tree.zeros_like(new_avg_grad)\n", + " )\n", + " not_ready_state = AccumulateAvgUpdatesState(\n", + " micro_step=new_micro_step, ready=False, avg_grad=new_avg_grad\n", + " )\n", + " updates, new_state = tree.where(\n", + " new_micro_step == num_microbatches,\n", + " (new_avg_grad, ready_state),\n", + " (tree.zeros_like(new_avg_grad), not_ready_state),\n", + " )\n", + " return updates, new_state\n", + "\n", + " return base.GradientTransformation(init_fn, update_fn)\n", + "\n", + "\n", + "def average_incrementally_updates(\n", + " per_elt_axis: int | list[int] | None, num_microbatches: int\n", + ") -\u003e Aggregator | base.GradientTransformation:\n", + " \"\"\"Average and accumulate per-element updates.\"\"\"\n", + " if per_elt_axis is None:\n", + " return accumulate_avg_udpates(num_microbatches)\n", + " else:\n", + " return chain(\n", + " average_per_element_udpates(per_elt_axis),\n", + " accumulate_avg_udpates(num_microbatches),\n", + " )\n", + "```\n", + "\n", + "We can revise our basic example with this." 
+ ], + "metadata": { + "id": "-ZBm1jhfJdm9" + } + }, + { + "cell_type": "code", + "source": [ + "def train(\n", + " opt,\n", + " num_microbatches: int = 1,\n", + " num_samples: int = 16,\n", + " batch_size: int = 4,\n", + " dim: int = 4,\n", + " num_classes: int = 2,\n", + "):\n", + "\n", + " data_iter = lambda: data_iterator(\n", + " jrd.key(0), num_samples, dim, num_classes, batch_size // num_microbatches\n", + " )\n", + " full_data = [jnp.concatenate(a, axis=0) for a in zip(*data_iter())]\n", + " params = jrd.normal(jrd.key(1), (dim, num_classes))\n", + "\n", + " @jax.jit\n", + " def train_step(params, state, batch):\n", + " if isinstance(opt, aggregate.Aggregator):\n", + " losses, grads = jax.vmap(jax.value_and_grad(loss_fun), (None, 0))(\n", + " params, batch\n", + " )\n", + " loss = jnp.mean(losses)\n", + " else:\n", + " loss, grads = jax.value_and_grad(loss_fun)(params, batch)\n", + " updates, state = opt.update(grads, state)\n", + " params = optax.apply_updates(params, updates)\n", + " return params, state, loss\n", + "\n", + " state = opt.init(params)\n", + " for i, batch in enumerate(data_iter()):\n", + " full_batch_loss = loss_fun(params, full_data)\n", + " params, state, loss = train_step(params, state, batch)\n", + " print(\n", + " f'Step: {i} |'\n", + " f'Mini-batch Loss: {loss:.2e} |'\n", + " f'Full batch loss: {full_batch_loss:.2e}'\n", + " )" + ], + "metadata": { + "id": "4swmJOuVAawj" + }, + "execution_count": null, + "outputs": [] + },
+ { + "cell_type": "code", + "source": [ + "print('Standard training')\n", + "opt = optax.sgd(learning_rate=0.01)\n", + "train(opt)\n", + "\n", + "print('\\nWith accumulation')\n", + "num_microbatches = 2\n", + "# The optimizer below does not average per-example gradients since we use\n", + "# per_elt_axis=None. It returns a standard GradientTransformation and uses\n", + "# the jax.grad branch in the train pipeline above.\n", + "opt = aggregate.process(\n", + " base.identity(),\n", + " aggregate.average_incrementally_updates(\n", + " per_elt_axis=None, num_microbatches=num_microbatches\n", + " ),\n", + " optax.sgd(learning_rate=0.01),\n", + ")\n", + "train(opt, num_microbatches)\n", + "\n", + "print('\\nWith explicit aggregation and accumulation')\n", + "# This optimizer is an Aggregator and will use the vmapped grads.\n", + "opt = aggregate.process(\n", + " base.identity(),\n", + " aggregate.average_incrementally_updates(\n", + " per_elt_axis=0, num_microbatches=num_microbatches),\n", + " optax.sgd(learning_rate=0.01),\n", + ")\n", + "train(opt, num_microbatches)\n" + ], + "metadata": { + "id": "t1YlQEsPQxW3" + }, + "execution_count": null, + "outputs": [] + },
+ { + "cell_type": "markdown", + "source": [ + "The resulting mini-batch losses with and without accumulation\n", + "do not match since the mini-batches are not the same and we are not\n", + "accumulating losses the way we are accumulating gradients.\n", + "The full batch losses naturally match: the full batch loss at step `i` without accumulation matches the full batch loss with accumulation at step `num_microbatches x i`.\n",
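+ "\n", + "As a quick sanity check, the minimal sketch below (not part of the library; it simply reuses `data_iterator` and `loss_fun` defined above and hard-codes the same problem sizes) records the full-batch losses of both runs and compares every other entry of the accumulated run against the standard run:\n", + "\n", + "```python\n",
+ "def record_full_batch_losses(opt, num_microbatches=1):\n", + "  # Re-run the small regression problem and record the full-batch loss\n", + "  # before every (micro-)step.\n", + "  num_samples, batch_size, dim, num_classes = 16, 4, 4, 2\n", + "  data = list(data_iterator(\n", + "      jrd.key(0), num_samples, dim, num_classes, batch_size // num_microbatches\n", + "  ))\n", + "  full_data = [jnp.concatenate(a, axis=0) for a in zip(*data)]\n", + "  params = jrd.normal(jrd.key(1), (dim, num_classes))\n", + "  state = opt.init(params)\n", + "  losses = []\n", + "  for batch in data:\n", + "    losses.append(loss_fun(params, full_data))\n", + "    if isinstance(opt, aggregate.Aggregator):\n", + "      grads = jax.vmap(jax.grad(loss_fun), (None, 0))(params, batch)\n", + "    else:\n", + "      grads = jax.grad(loss_fun)(params, batch)\n", + "    updates, state = opt.update(grads, state)\n", + "    params = optax.apply_updates(params, updates)\n", + "  return jnp.array(losses)\n", + "\n",
+ "losses_std = record_full_batch_losses(optax.sgd(learning_rate=0.01))\n", + "opt_acc = aggregate.process(\n", + "    base.identity(),\n", + "    aggregate.average_incrementally_updates(\n", + "        per_elt_axis=None, num_microbatches=2\n", + "    ),\n", + "    optax.sgd(learning_rate=0.01),\n", + ")\n", + "losses_acc = record_full_batch_losses(opt_acc, num_microbatches=2)\n", + "# Every other full-batch loss of the accumulated run should line up with the\n", + "# standard run (up to numerical precision).\n", + "print(jnp.allclose(losses_std, losses_acc[::2]))\n", + "```"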
+ ], + "metadata": { + "id": "4mUdjNTe5qXY" + } + }, + { + "cell_type": "markdown", + "source": [ + "Note that the proposed `accumulate_avg_udpates` combined with `process` can also replace `optax.MultiSteps` as\n", + "\n", + "```python\n", + "def accumulate_grads(\n", + " opt: base.GradientTransformation,\n", + " num_microbatches: int,\n", + ") -\u003e base.GradientTransformation:\n", + " \"\"\"Accumulate gradients.\"\"\"\n", + " return process(\n", + " preprocessor=base.identity(),\n", + " aggregator=average_incrementally_updates(\n", + " per_elt_axis=None,\n", + " num_microbatches=num_microbatches,\n", + " ),\n", + " postprocessor=opt,\n", + " )\n", + "```\n", + "\n", + "In the rest of the notebook, we present how this paradigm can be applied (i) in differential privacy, (ii) to record variance of gradients per coordinate." + ], + "metadata": { + "id": "NAFYZW-5YZzX" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Differentially private SGD\n", + "\n", + "Differentially private sgd is an algorithm that clips per-example gradients, then add noise to preserve privacy (see [Deep Learning with Differential Privacy (Abadi et al., 2016)](https://arxiv.org/abs/1607.00133) for more details). With the help of the `process` function such an algorithm can be implemented very easily." + ], + "metadata": { + "id": "jP9XRrSBZER6" + } + }, + { + "cell_type": "code", + "source": [ + "def per_example_clip(\n", + " l2_norm_clip: float,\n", + " per_elt_axis: int | list[int] | None = 0,\n", + ") -\u003e base.GradientTransformation:\n", + " \"\"\"Clip per-example gradients with their individual norm.\"\"\"\n", + "\n", + " if per_elt_axis is None:\n", + " return _clipping.clip_by_global_norm(l2_norm_clip)\n", + "\n", + " def update_fn(per_elt_grads, state, params=None):\n", + " del params\n", + " clip = functools.partial(\n", + " optax.projections.projection_l2_ball, scale=l2_norm_clip\n", + " )\n", + " clipped_updates = jax.vmap(clip, in_axes=per_elt_axis)(per_elt_grads)\n", + " return clipped_updates, state\n", + "\n", + " return base.GradientTransformation(base.init_empty_state, update_fn)\n", + "\n", + "\n", + "def differentially_private_aggregate(\n", + " l2_norm_clip: float,\n", + " noise_multiplier: float,\n", + " key: jax.Array | int,\n", + " per_elt_axis: int | list[int] | None = 0,\n", + " num_microbatches: int = 1,\n", + ") -\u003e base.GradientTransformation | aggregate.Aggregator:\n", + " \"\"\"Processes gradients based on the DPSGD algorithm.\"\"\"\n", + " noise_std = l2_norm_clip * noise_multiplier\n", + "\n", + " return aggregate.process(\n", + " preprocessor=per_example_clip(l2_norm_clip, per_elt_axis),\n", + " aggregator=aggregate.average_incrementally_updates(\n", + " per_elt_axis=per_elt_axis,\n", + " num_microbatches=num_microbatches,\n", + " ),\n", + " postprocessor=transform.add_noise(\n", + " eta=noise_std,\n", + " gamma=1.0,\n", + " key=key,\n", + " ),\n", + " )" + ], + "metadata": { + "id": "nnoILZL3Yogs" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "We can instantiate it the pipeline defined before." 
+ ], + "metadata": { + "id": "-OXsFc7NZv-B" + } + }, + { + "cell_type": "code", + "source": [ + "print('Without DP')\n", + "train(optax.sgd(learning_rate=0.01))\n", + "\n", + "print('\\nWith DP')\n", + "opt = aggregate.chain(\n", + " differentially_private_aggregate(\n", + " l2_norm_clip=1.0, noise_multiplier=1.0, key=jrd.key(2)\n", + " ),\n", + " optax.sgd(learning_rate=0.01),\n", + ")\n", + "train(opt)" + ], + "metadata": { + "id": "WLZu7zfvtCEY" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Micro-Adam\n", + "\n", + "We can instantiate a variant of Adam that uses average of square rather than square of average gradients (see [Batch size invariant Adam (Wang \u0026 Aitchison, 2024)](https://arxiv.org/abs/2402.18824)).\n", + "\n", + "This instance requires us to aggregate both average gradient and average of square gradients. So it illustrates the need of a `has_aux` argument for the `process` function.\n" + ], + "metadata": { + "id": "Ym2Au8qgu5sb" + } + }, + { + "cell_type": "code", + "source": [ + "def scale_by_micro_adam(\n", + " b1: float,\n", + " b2: float,\n", + " eps: float,\n", + ") -\u003e base.GradientTransformationExtraArgs:\n", + " \"\"\"Micro-Adam optimizer.\"\"\"\n", + "\n", + " def init_fn(params):\n", + " return transform.scale_by_adam(b1=b1, b2=b2, eps=eps).init(params)\n", + "\n", + " def update_fn(updates, state, params=None, *, avg_sq_updates, **extra_args):\n", + " del params, extra_args\n", + " mu = tree.update_moment(updates, state.mu, b1, 1)\n", + " nu = tree.update_moment(avg_sq_updates, state.nu, b2, 1)\n", + " count_inc = utils.safe_int32_increment(state.count)\n", + "\n", + " mu_hat = tree.bias_correction(mu, b1, count_inc)\n", + " nu_hat = tree.bias_correction(nu, b2, count_inc)\n", + " updates = jax.tree.map(lambda m, v: m / (jnp.sqrt(v) + eps), mu_hat, nu_hat)\n", + " new_state = transform.ScaleByAdamState(count=count_inc, mu=mu, nu=nu)\n", + " return updates, new_state\n", + "\n", + " return base.GradientTransformationExtraArgs(init_fn, update_fn)\n", + "\n", + "\n", + "def get_avg_and_avg_sq_updates(\n", + " per_elt_axis: int | list[int] | None = 0,\n", + " num_microbatches: int = 1,\n", + ") -\u003e base.GradientTransformationExtraArgs:\n", + " \"\"\"Collect average and average of squares of gradients.\"\"\"\n", + "\n", + " def incremental_update_fn(updates, state, params=None):\n", + " # With this update function, we use the accumulation below to progressively\n", + " # compute the average and average of squares of updates.\n", + " del params\n", + " sq_updates = jax.tree.map(jnp.square, updates)\n", + " return (updates, {'avg_sq_updates': sq_updates}), state\n", + "\n", + " def per_elt_update_fn(per_elt_updates, state, params=None):\n", + " # With this update function, we consider `per_elt_updates` to be per-element\n", + " # updates that we want to average on. 
The accumulator below may even let us\n", + " # get reach larger batches.\n", + " del params\n", + " avg_updates = jax.tree.map(\n", + " lambda x: jnp.mean(x, axis=per_elt_axis), per_elt_updates\n", + " )\n", + " avg_sq_updates = jax.tree.map(\n", + " lambda x: jnp.mean(jnp.square(x), axis=per_elt_axis), per_elt_updates\n", + " )\n", + " return (avg_updates, {'avg_sq_updates': avg_sq_updates}), state\n", + "\n", + " if per_elt_axis is None:\n", + " opt = aggregate.chain(\n", + " base.GradientTransformation(\n", + " base.init_empty_state, incremental_update_fn\n", + " ),\n", + " aggregate.accumulate_avg_udpates(num_microbatches),\n", + " )\n", + " else:\n", + " opt = aggregate.chain(\n", + " aggregate.Aggregator(base.init_empty_state, per_elt_update_fn),\n", + " aggregate.accumulate_avg_udpates(num_microbatches),\n", + " )\n", + " return opt\n", + "\n", + "\n", + "def micro_adam(\n", + " learning_rate: base.ScalarOrSchedule,\n", + " b1: float = 0.9,\n", + " b2: float = 0.999,\n", + " eps: float = 1e-8,\n", + " per_elt_axis: int | list[int] | None = 0,\n", + " num_microbatches: int = 1,\n", + ") -\u003e base.GradientTransformation:\n", + " \"\"\"Micro-Adam optimizer.\"\"\"\n", + " return aggregate.process(\n", + " preprocessor=base.identity(),\n", + " aggregator=get_avg_and_avg_sq_updates(per_elt_axis, num_microbatches),\n", + " postprocessor=aggregate.chain(\n", + " scale_by_micro_adam(b1, b2, eps),\n", + " transform.scale_by_learning_rate(learning_rate),\n", + " ),\n", + " aggregator_has_aux=True,\n", + " )\n" + ], + "metadata": { + "id": "E8QCx7PZvUHn" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print('Classical Adam')\n", + "opt = optax.adam(learning_rate=0.01)\n", + "train(opt)\n", + "\n", + "print('Micro-Adam')\n", + "opt = micro_adam(learning_rate=0.01)\n", + "train(opt)\n" + ], + "metadata": { + "id": "8VLg-RzCvzN6" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Adam with variance computations\n", + "\n", + "Finally, we present how to track per element gradient variance during optimization. The variance is accumulated in a numerically stable manner with Welford's algorithm." + ], + "metadata": { + "id": "4VUEUX428ZZN" + } + }, + { + "cell_type": "code", + "source": [ + "def get_batch_size_from_per_elt_updates(\n", + " per_elt_updates: base.Updates, per_elt_axis: int | list[int]\n", + ") -\u003e int:\n", + " \"\"\"Get batch size from per-element updates.\"\"\"\n", + "\n", + " def get_batch_size(u):\n", + " if isinstance(per_elt_axis, int):\n", + " return u.shape[per_elt_axis]\n", + " else:\n", + " return functools.reduce(\n", + " lambda a, b: a * b, [u.shape[i] for i in per_elt_axis]\n", + " )\n", + "\n", + " batch_sizes = jax.tree.map(get_batch_size, per_elt_updates)\n", + " batch_sizes = jax.tree.leaves(batch_sizes)\n", + " if not all(b == batch_sizes[0] for b in batch_sizes):\n", + " raise ValueError(\n", + " f'Per-element updates must have the same batch size. 
Got: {batch_sizes}'\n", + " )\n", + " return batch_sizes[0]\n", + "\n", + "\n", + "class PerElementMeanAndSumSqDiffGradsState(NamedTuple):\n", + " \"\"\"State for the per-element mean and variance accumulator.\"\"\"\n", + "\n", + " micro_step: int\n", + " ready: bool\n", + " mean_grads: base.Updates\n", + " sum_sq_diff_grads: base.Updates\n", + "\n", + "\n", + "def get_per_element_mean_and_sum_sq_diff_grads(\n", + " per_elt_axis: int | list[int] = 0,\n", + " num_microbatches: int = 1,\n", + ") -\u003e aggregate.Aggregator:\n", + " \"\"\"Collect per-element mean and variance gradient metrics.\"\"\"\n", + "\n", + " if per_elt_axis is None:\n", + " raise NotImplementedError(\n", + " 'Per-element mean and sum square diff need a per_elt_axis.'\n", + " )\n", + "\n", + " def compute_avg_and_sum_sq_diff(\n", + " per_elt_udpates: base.Updates,\n", + " state: base.OptState,\n", + " params: base.Params | None,\n", + " ) -\u003e tuple[base.Updates, base.Updates]:\n", + " del params\n", + " batch_size = get_batch_size_from_per_elt_updates(\n", + " per_elt_udpates, per_elt_axis\n", + " )\n", + " mean_grads = jax.tree.map(\n", + " lambda x: jnp.mean(x, axis=per_elt_axis), per_elt_udpates\n", + " )\n", + " sum_sq_diff_grads = jax.tree.map(\n", + " lambda x: jnp.sum(jnp.square(x), axis=per_elt_axis), per_elt_udpates\n", + " )\n", + " return (\n", + " mean_grads,\n", + " {'sum_sq_diff_grads': sum_sq_diff_grads, 'sample_size': batch_size},\n", + " ), state\n", + "\n", + " if num_microbatches == 1:\n", + " return aggregate.Aggregator(\n", + " base.init_empty_state, compute_avg_and_sum_sq_diff\n", + " )\n", + "\n", + " def init_fn(params):\n", + " return PerElementMeanAndSumSqDiffGradsState(\n", + " micro_step=0,\n", + " ready=False,\n", + " mean_grads=tree.zeros_like(params),\n", + " sum_sq_diff_grads=tree.zeros_like(params),\n", + " )\n", + "\n", + " def update_fn(per_elt_udpates, state, params=None):\n", + " del params\n", + " batch_size = get_batch_size_from_per_elt_updates(\n", + " per_elt_udpates, per_elt_axis\n", + " )\n", + " new_micro_step = state.micro_step + 1\n", + "\n", + " # Compute batch averages.\n", + " batch_mean_grads = jax.tree.map(\n", + " lambda x: jnp.mean(x, axis=per_elt_axis, keepdims=True), per_elt_udpates\n", + " )\n", + " batch_sum_sq_diff_grads = jax.tree.map(\n", + " lambda x, a: jnp.sum(jnp.square(x - a), axis=per_elt_axis),\n", + " per_elt_udpates,\n", + " batch_mean_grads,\n", + " )\n", + " batch_mean_grads = jax.tree.map(\n", + " lambda x: x.squeeze(axis=per_elt_axis), batch_mean_grads\n", + " )\n", + "\n", + " # Update accumulated averages.\n", + " delta = jax.tree.map(lambda u, a: u - a, batch_mean_grads, state.mean_grads)\n", + " new_mean_grads = jax.tree.map(\n", + " lambda a, d: a + d / new_micro_step,\n", + " state.mean_grads,\n", + " delta,\n", + " )\n", + " size_factor = state.micro_step * batch_size / new_micro_step\n", + " new_sum_sq_diff_grads = jax.tree.map(\n", + " lambda a, s, d: a + s + d**2 * size_factor,\n", + " state.sum_sq_diff_grads,\n", + " batch_sum_sq_diff_grads,\n", + " delta,\n", + " )\n", + " maybe_outputs = (\n", + " new_mean_grads,\n", + " {\n", + " 'sum_sq_diff_grads': new_sum_sq_diff_grads,\n", + " 'sample_size': batch_size * new_micro_step,\n", + " },\n", + " )\n", + "\n", + " # Output or not the accumulated averages.\n", + " ready_state = PerElementMeanAndSumSqDiffGradsState(\n", + " micro_step=0,\n", + " ready=True,\n", + " mean_grads=tree.zeros_like(new_mean_grads),\n", + " sum_sq_diff_grads=tree.zeros_like(new_sum_sq_diff_grads),\n", + " 
)\n", + " not_ready_state = PerElementMeanAndSumSqDiffGradsState(\n", + " micro_step=new_micro_step,\n", + " ready=False,\n", + " mean_grads=new_mean_grads,\n", + " sum_sq_diff_grads=new_sum_sq_diff_grads,\n", + " )\n", + " updates, new_state = tree.where(\n", + " new_micro_step == num_microbatches,\n", + " (maybe_outputs, ready_state),\n", + " (tree.zeros_like(maybe_outputs), not_ready_state),\n", + " )\n", + " return updates, new_state\n", + "\n", + " return aggregate.Aggregator(init_fn, update_fn)\n", + "\n", + "\n", + "class PerElementMeanAndVarianceEMAState(NamedTuple):\n", + " \"\"\"State for the per-element mean and variance accumulator.\"\"\"\n", + "\n", + " count: jax.Array\n", + " ema_decay: jax.Array\n", + " mean_grads_ema: base.Updates\n", + " variance_grads_ema: base.Updates\n", + "\n", + "\n", + "def track_per_element_mean_and_variance_with_ema(\n", + " ema_decay: float = 0.9,\n", + ") -\u003e base.GradientTransformation:\n", + " \"\"\"Track variance metrics with an EMA over time.\"\"\"\n", + "\n", + " def init_fn(params):\n", + " return PerElementMeanAndVarianceEMAState(\n", + " count=jnp.zeros([], jnp.int32),\n", + " ema_decay=jnp.asarray(ema_decay),\n", + " mean_grads_ema=tree.zeros_like(params),\n", + " variance_grads_ema=tree.zeros_like(params),\n", + " )\n", + "\n", + " def update_fn(updates, state, params=None, *, sum_sq_diff_grads, sample_size):\n", + " del params\n", + " mean_grads_ema = jax.tree.map(\n", + " lambda x, y: (1.0 - ema_decay) * x + ema_decay * y,\n", + " updates,\n", + " state.mean_grads_ema,\n", + " )\n", + " variance_step = tree.scale(1 / (sample_size - 1), sum_sq_diff_grads)\n", + " variance_grads_ema = jax.tree.map(\n", + " lambda x, y: (1.0 - ema_decay) * x + ema_decay * y,\n", + " variance_step,\n", + " state.variance_grads_ema,\n", + " )\n", + " new_count = utils.safe_int32_increment(state.count)\n", + " new_state = state._replace(\n", + " count=new_count,\n", + " mean_grads_ema=mean_grads_ema,\n", + " variance_grads_ema=variance_grads_ema,\n", + " )\n", + " return updates, new_state\n", + "\n", + " return base.GradientTransformationExtraArgs(init_fn, update_fn)\n", + "\n", + "\n", + "def get_unbiased_mean_and_variance_ema(\n", + " state: base.OptState,\n", + ") -\u003e tuple[base.Updates, base.Updates]:\n", + " \"\"\"Track unbiased mean and variance with an EMA over time.\"\"\"\n", + " per_elt_mean_and_variance_ema_state = tree.get(\n", + " state, 'PerElementMeanAndVarianceEMAState', None\n", + " )\n", + " if per_elt_mean_and_variance_ema_state is None:\n", + " raise ValueError(\n", + " 'State must have PerElementMeanAndVarianceEMAState to compute unbiased'\n", + " ' mean and variance EMA.'\n", + " )\n", + " count = per_elt_mean_and_variance_ema_state.count\n", + " ema_decay = per_elt_mean_and_variance_ema_state.ema_decay\n", + " mean_grads_ema = per_elt_mean_and_variance_ema_state.mean_grads_ema\n", + " variance_grads_ema = per_elt_mean_and_variance_ema_state.variance_grads_ema\n", + " unbiased_mean_grads_ema = jax.tree.map(\n", + " lambda x: x / (1 - ema_decay**count), mean_grads_ema\n", + " )\n", + " unbiased_variance_grads_ema = jax.tree.map(\n", + " lambda x: x / (1 - ema_decay**count), variance_grads_ema\n", + " )\n", + " return unbiased_mean_grads_ema, unbiased_variance_grads_ema\n", + "\n", + "\n", + "def add_mean_variance_to_opt(\n", + " opt: base.GradientTransformation,\n", + " ema_decay: float = 0.9,\n", + " per_elt_axis: int | list[int] | None = 0,\n", + " num_microbatches: int = 1,\n", + "):\n", + " \"\"\"Add mean and variance to an 
optimizer.\"\"\"\n", + " return aggregate.process(\n", + " preprocessor=base.identity(),\n", + " aggregator=get_per_element_mean_and_sum_sq_diff_grads(\n", + " per_elt_axis, num_microbatches\n", + " ),\n", + " postprocessor=aggregate.chain(\n", + " track_per_element_mean_and_variance_with_ema(ema_decay),\n", + " opt,\n", + " ),\n", + " aggregator_has_aux=True,\n", + " )" + ], + "metadata": { + "id": "niI5AGiQ8Yv3" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "def train_and_track(\n", + " opt,\n", + " num_microbatches: int = 1,\n", + " num_samples: int = 16,\n", + " batch_size: int = 4,\n", + " dim: int = 4,\n", + " num_classes: int = 2,\n", + "):\n", + "\n", + " data_iter = lambda: data_iterator(\n", + " jrd.key(0), num_samples, dim, num_classes, batch_size // num_microbatches\n", + " )\n", + " full_data = [jnp.concatenate(a, axis=0) for a in zip(*data_iter())]\n", + " params = jrd.normal(jrd.key(1), (dim, num_classes))\n", + "\n", + " opt = add_mean_variance_to_opt(opt)\n", + "\n", + " @jax.jit\n", + " def train_step(params, state, batch):\n", + " if isinstance(opt, aggregate.Aggregator):\n", + " losses, grads = jax.vmap(jax.value_and_grad(loss_fun), (None, 0))(\n", + " params, batch\n", + " )\n", + " loss = jnp.mean(losses)\n", + " else:\n", + " loss, grads = jax.value_and_grad(loss_fun)(params, batch)\n", + " updates, state = opt.update(grads, state)\n", + " params = optax.apply_updates(params, updates)\n", + " return params, state, loss\n", + "\n", + " state = opt.init(params)\n", + " for i, batch in enumerate(data_iter()):\n", + " full_batch_loss = loss_fun(params, full_data)\n", + " params, state, loss = train_step(params, state, batch)\n", + " mean, var = get_unbiased_mean_and_variance_ema(state)\n", + " print(\n", + " f'Step: {i} |'\n", + " f'Mini-batch Loss: {loss:.2e} |'\n", + " f'Full batch loss: {full_batch_loss:.2e}\\n'\n", + " f'Mean EMA:\\n {mean}\\n'\n", + " f'Variance EMA:\\n {var}:'\n", + " )" + ], + "metadata": { + "id": "miAFLNf4_GQ_" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "train_and_track(optax.adam(learning_rate=1e-1))" + ], + "metadata": { + "id": "3FpwgB5C_wee" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Sharp edges\n", + "\n", + "Currently, the axis along which both the aggregator is done is not accessible where the train step is defined. Ideally GradientTransforms would be DataClasses so Aggregators could store meta-fields like the axis along which a vmap can be done." + ], + "metadata": { + "id": "sefOimx5RxRQ" + } + } + ], + "metadata": { + "colab": { + "private_outputs": true, + "provenance": [], + "last_runtime": { + "build_target": "//learning/grp/tools/ml_python:ml_python_notebook", + "kind": "private" + } + }, + "language_info": { + "name": "python" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/optax/experimental/__init__.py b/optax/experimental/__init__.py new file mode 100644 index 000000000..ac5d30c23 --- /dev/null +++ b/optax/experimental/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental features for Optax.""" + +from optax.experimental.aggregate import add_mean_variance_to_opt +from optax.experimental.aggregate import Aggregator +from optax.experimental.aggregate import average_incrementally_updates +from optax.experimental.aggregate import get_unbiased_mean_and_variance_ema +from optax.experimental.aggregate import process diff --git a/optax/experimental/_aggregate_test.py b/optax/experimental/_aggregate_test.py new file mode 100644 index 000000000..eb1c32fbb --- /dev/null +++ b/optax/experimental/_aggregate_test.py @@ -0,0 +1,199 @@ +# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from absl.testing import absltest +from absl.testing import parameterized +import chex +import jax +import jax.numpy as jnp +import jax.random as jrd +from optax._src import alias +from optax._src import base +from optax._src import update +from optax.experimental import aggregate + + +def _train( + opt, + num_microbatches: int = 1, + num_samples: int = 16, + batch_size: int = 4, + dim: int = 4, + num_classes: int = 2, +): + """Synthetic training with the given optimizer.""" + microbatch_size = batch_size // num_microbatches + + def data_iterator(key): + inputs_key, targets_key = jrd.split(key) + inputs = jrd.normal(inputs_key, (num_samples, dim)) + targets = jrd.normal(targets_key, (num_samples, num_classes)) + + for i in range(0, num_samples, microbatch_size): + yield inputs[i : i + microbatch_size], targets[i : i + microbatch_size] + + def loss_fun(params, batch): + inputs, targets = batch + return jnp.mean(jnp.sum((inputs.dot(params) - targets) ** 2, -1)) + + data_key, param_key = jrd.split(jrd.key(0)) + full_data = [ + jnp.concatenate(a, axis=0) for a in zip(*data_iterator(data_key)) + ] + params = jrd.normal(param_key, (dim, num_classes)) + + @jax.jit + def train_step(params, state, batch): + var_grads = None + if isinstance(opt, aggregate.Aggregator): + losses, grads = jax.vmap(jax.value_and_grad(loss_fun), (None, 0))( + params, batch + ) + loss = jnp.mean(losses) + if num_microbatches == 1: + var_grads = jax.tree.map(lambda g: jnp.var(g, axis=0, ddof=1), grads) + else: + loss, grads = jax.value_and_grad(loss_fun)(params, batch) + updates, state = opt.update(grads, state) + params = update.apply_updates(params, updates) + return params, state, loss, var_grads + + state = opt.init(params) + metrics = {} + for batch in 
data_iterator(data_key): + full_batch_loss = loss_fun(params, full_data) + params, state, loss, var_grads = train_step(params, state, batch) + step_metrics = {'loss': loss, 'full_batch_loss': full_batch_loss} + if var_grads is not None: + step_metrics['var_grads'] = var_grads + try: + mean_grads_ema, var_grads_ema = ( + aggregate.get_unbiased_mean_and_variance_ema(state) + ) + step_metrics['mean_grads_ema'] = mean_grads_ema + step_metrics['var_grads_ema'] = var_grads_ema + except ValueError: + pass + if not metrics: + for key in step_metrics: + metrics[key] = [] + for key, value in step_metrics.items(): + metrics[key].append(value) + return params, metrics + + +class AggregatorsTest(parameterized.TestCase): + + def test_aggregation_and_accumulation_match_standard(self): + base_opt = alias.sgd(learning_rate=0.1) + std_params, std_metrics = _train(base_opt) + + opt = aggregate.process( + base.identity(), + aggregate.average_incrementally_updates( + per_elt_axis=0, num_microbatches=1 + ), + base_opt, + ) + agg_params, agg_metrics = _train(opt, 1) + device_type = jax.devices()[0].platform + rtol = 5*1e-3 if device_type == 'tpu' else 1e-5 + with self.subTest('aggregation matches standard'): + chex.assert_trees_all_close(std_params, agg_params, rtol=rtol) + chex.assert_trees_all_close( + std_metrics['full_batch_loss'], + agg_metrics['full_batch_loss'], + rtol=rtol, + ) + + opt = aggregate.process( + base.identity(), + aggregate.average_incrementally_updates( + per_elt_axis=None, num_microbatches=2 + ), + base_opt, + ) + acc_params, acc_metrics = _train(opt, num_microbatches=2) + + with self.subTest('accumulation matches standard'): + chex.assert_trees_all_close(std_params, acc_params) + chex.assert_trees_all_close( + std_metrics['full_batch_loss'], acc_metrics['full_batch_loss'][::2] + ) + + opt = aggregate.process( + base.identity(), + aggregate.average_incrementally_updates( + per_elt_axis=0, num_microbatches=2 + ), + base_opt, + ) + agg_acc_params, agg_acc_metrics = _train(opt, num_microbatches=2) + + with self.subTest('aggregation and accumulation match standard'): + chex.assert_trees_all_close(std_params, agg_acc_params) + chex.assert_trees_all_close( + std_metrics['full_batch_loss'], + agg_acc_metrics['full_batch_loss'][::2], + ) + + @parameterized.product(ema_decay=[0.0, 0.9]) + def test_mean_variance_ema_match_standard(self, ema_decay: float = 0.99): + base_opt = alias.sgd(learning_rate=0.1) + std_params, std_metrics = _train(base_opt) + + opt = aggregate.add_mean_variance_to_opt(base_opt, ema_decay) + mean_var_agg_params, mean_var_agg_metrics = _train(opt) + + with self.subTest( + 'mean variance ema with aggregation training matches standard' + ): + chex.assert_trees_all_close(std_params, mean_var_agg_params, rtol=1e-2) + chex.assert_trees_all_close( + std_metrics['full_batch_loss'], + mean_var_agg_metrics['full_batch_loss'], + rtol=1e-2, + ) + with self.subTest('var grads ema matches var grads'): + if ema_decay == 0.0: + chex.assert_trees_all_close( + mean_var_agg_metrics['var_grads_ema'], + mean_var_agg_metrics['var_grads'], + ) + + opt = aggregate.add_mean_variance_to_opt( + base_opt, ema_decay, num_microbatches=2 + ) + mean_var_acc_params, mean_var_acc_metrics = _train(opt, num_microbatches=2) + with self.subTest( + 'mean variance ema with accumulation training matches standard' + ): + chex.assert_trees_all_close(std_params, mean_var_acc_params, rtol=1e-2) + chex.assert_trees_all_close( + std_metrics['full_batch_loss'], + mean_var_acc_metrics['full_batch_loss'][::2], + rtol=1e-2, + 
) + + with self.subTest( + 'var grads ema with accumulation matches var grads ema with' + ' aggregation' + ): + chex.assert_trees_all_close( + mean_var_agg_metrics['var_grads_ema'], + mean_var_acc_metrics['var_grads_ema'][1::2], + ) + +if __name__ == '__main__': + absltest.main() diff --git a/optax/experimental/aggregate.py b/optax/experimental/aggregate.py new file mode 100644 index 000000000..483c3529a --- /dev/null +++ b/optax/experimental/aggregate.py @@ -0,0 +1,480 @@ +# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Gradient transformations that aggregate gradients.""" + +import functools +from typing import Any, NamedTuple, Protocol +import chex +import jax +import jax.numpy as jnp +from optax import tree +from optax._src import base +from optax._src import utils +from optax.transforms import _combining + + +############################################################################### +# Aggregators + + +PerElementUpdates = chex.ArrayTree +AggregatedUpdates = chex.ArrayTree +MaybeAxis = int | list[int] | None + + +class AggregatorUpdateFn(Protocol): + """Update function for aggregators.""" + + def __call__( + self, + per_elt_updates: PerElementUpdates, + state: base.OptState, + params: base.Params | None = None, + **extra_args: Any, + ) -> tuple[AggregatedUpdates, base.OptState]: + """Transforms per-element updates into aggregated update direction.""" + + +class Aggregator(base.GradientTransformationExtraArgs): + """A pair of pure functions that implement stateful aggregation of gradients. + + This class differs from a standard optax GradientTransformation as it is + defined to operate on a set of invidividual gradients, rather than on + aggregated gradients -- like the mini-batch average of gradients. + + Optax base GradientTransformation expect input and output updates to be of the + same shape as the parameters. The aggregators take as inputs per-example + gradients of shape [*batch_shape, *params_shape] and return update direction + of shape [*params_shape]. + + While usual optax transformations are used in an api of the form + grads = jax.grad(loss)(params, batch) + updates, opt_state = transformation.update(grads, opt_state) + The aggregators are used in an api of the form + grads = jax.vmap(jax.grad(loss), in_axes=(None, 0))(params, batch) + updates, opt_state = aggregator.update(grads, opt_state) + + The signatures of AggregatorUpdateFn and GradientTransformationUpdateFn are + identical, but the distinction is necessary for the user to adapt the gradient + oracles to such specific transformations. + + Attributes: + init: Initialization function that takes params and returns state. + update: Update function that takes per-example gradients, state and params + (optionally) and returns updates and updated state. 
+ """ + + init: base.TransformInitFn + update: AggregatorUpdateFn + + +def chain(*transforms) -> base.GradientTransformationExtraArgs: + """Combines transforms, returning an Aggregator if one is present.""" + opt = _combining.chain(*transforms) + if any(isinstance(t, Aggregator) for t in transforms): + return Aggregator(opt.init, opt.update) + return opt + + +################################################################################# + + +def process( + preprocessor: base.GradientTransformation, + aggregator: base.GradientTransformation | Aggregator, + postprocessor: base.GradientTransformation, + aggregator_has_aux: bool = False, +): + """Process gradients through a sequence of transformations. + + Args: + preprocessor: A transformation that maps per-example gradients to + per-example updates. + aggregator: A transformation that aggregates per-example updates into a + single update. + postprocessor: A transformation that maps aggregated updates to the final + updates. + aggregator_has_aux: Whether the aggregator returns more than just the + average updates. + + Returns: + A :class:`optax.GradientTransformation`. + """ + + def init_fn(params) -> tuple[base.OptState, base.OptState, base.OptState]: + preprocess_state = preprocessor.init(params) + aggregate_state = aggregator.init(params) + postprocess_state = postprocessor.init(params) + return preprocess_state, aggregate_state, postprocess_state + + def update_fn(indiv_grads, states, params=None, **extra_args): + preprocess_state, aggregate_state, postprocess_state = states + + indiv_updates, new_preprocess_state = preprocessor.update( + indiv_grads, preprocess_state, params, **extra_args + ) + + aggregated, new_aggregate_state = aggregator.update( + indiv_updates, aggregate_state, params, **extra_args + ) + + if aggregator_has_aux: + avg_updates, agg_aux = aggregated + extra_args = extra_args | agg_aux + else: + avg_updates = aggregated + + ready_to_post_process = tree.get(new_aggregate_state, 'ready', True) + + updates, new_postprocess_state = jax.lax.cond( + ready_to_post_process, + lambda g, s, p, kw: postprocessor.update(g, s, p, **kw), + lambda g, s, *_: (tree.zeros_like(avg_updates), s), + avg_updates, + postprocess_state, + params, + extra_args, + ) + return updates, ( + new_preprocess_state, + new_aggregate_state, + new_postprocess_state, + ) + + if isinstance(aggregator, Aggregator): + return Aggregator(init_fn, update_fn) + else: + return base.GradientTransformationExtraArgs(init_fn, update_fn) + + +################################################################################ +# Base aggregator/accumulator + + +def average_per_element_udpates( + per_elt_axis: int | list[int] = 0 +) -> Aggregator: + """Average per-element updates.""" + + def update_fn(per_elt_updates, state, params=None): + del params + avg_updates = jax.tree.map( + lambda x: jnp.mean(x, axis=per_elt_axis), per_elt_updates + ) + return avg_updates, state + + return Aggregator(base.init_empty_state, update_fn) + + +class AccumulateAvgUpdatesState(NamedTuple): + """State for the average gradient accumulator.""" + + micro_step: int + ready: bool + avg_grad: base.Updates + + +def accumulate_avg_udpates( + num_microbatches: int, +) -> base.GradientTransformation: + """Accumulate average gradients.""" + + if num_microbatches < 1: + raise ValueError('num_microbatches must be larger than or equal to than 0.') + + if num_microbatches == 1: + # If there is only one microbatch, we don't need accumulation. + # We return identity to save unnecessary state tracking. 
+ return base.identity() + + def init_fn(params): + return AccumulateAvgUpdatesState( + micro_step=0, ready=False, avg_grad=tree.zeros_like(params) + ) + + def update_fn(updates, state, params=None): + del params + new_micro_step = state.micro_step + 1 + new_avg_grad = jax.tree.map( + lambda u, a: a + (u - a) / new_micro_step, + updates, + state.avg_grad, + ) + ready_state = AccumulateAvgUpdatesState( + micro_step=0, ready=True, avg_grad=tree.zeros_like(new_avg_grad) + ) + not_ready_state = AccumulateAvgUpdatesState( + micro_step=new_micro_step, ready=False, avg_grad=new_avg_grad + ) + updates, new_state = tree.where( + new_micro_step == num_microbatches, + (new_avg_grad, ready_state), + (tree.zeros_like(new_avg_grad), not_ready_state), + ) + return updates, new_state + + return base.GradientTransformation(init_fn, update_fn) + + +def average_incrementally_updates( + per_elt_axis: MaybeAxis, num_microbatches: int +) -> Aggregator | base.GradientTransformation: + """Average and accumulate per-element updates.""" + if per_elt_axis is None: + return accumulate_avg_udpates(num_microbatches) + else: + return chain( + average_per_element_udpates(per_elt_axis), + accumulate_avg_udpates(num_microbatches), + ) + + +################################################################################ +# Adding mean and variance gradient metrics + + +def get_batch_size_from_per_elt_updates( + per_elt_updates: base.Updates, per_elt_axis: MaybeAxis +) -> int: + """Get batch size from per-element updates.""" + + def get_batch_size(u): + if isinstance(per_elt_axis, int): + return u.shape[per_elt_axis] + else: + return functools.reduce( + lambda a, b: a * b, [u.shape[i] for i in per_elt_axis] + ) + + batch_sizes = jax.tree.map(get_batch_size, per_elt_updates) + batch_sizes = jax.tree.leaves(batch_sizes) + if not all(b == batch_sizes[0] for b in batch_sizes): + raise ValueError( + f'Per-element updates must have the same batch size. Got: {batch_sizes}' + ) + return batch_sizes[0] + + +class PerElementMeanAndSumSqDiffGradsState(NamedTuple): + """State for the per-element mean and variance accumulator.""" + + micro_step: int + ready: bool + mean_grads: base.Updates + sum_sq_diff_grads: base.Updates + + +def get_per_element_mean_and_sum_sq_diff_grads( + per_elt_axis: int | list[int] = 0, + num_microbatches: int = 1, +) -> Aggregator: + """Collect per-element variance metrics.""" + + if per_elt_axis is None: + raise NotImplementedError( + 'Per-element mean and sum square diff need a per_elt_axis.' 
+ ) + + def compute_avg_and_sum_sq_diff( + per_elt_udpates: base.Updates, + state: base.OptState, + params: base.Params | None, + ) -> tuple[base.Updates, base.Updates]: + del params + batch_size = get_batch_size_from_per_elt_updates( + per_elt_udpates, per_elt_axis + ) + mean_grads = jax.tree.map( + lambda x: jnp.mean(x, axis=per_elt_axis, keepdims=True), + per_elt_udpates, + ) + sum_sq_diff_grads = jax.tree.map( + lambda x, a: jnp.sum(jnp.square(x - a), axis=per_elt_axis), + per_elt_udpates, + mean_grads, + ) + mean_grads = jax.tree.map( + lambda x: x.squeeze(axis=per_elt_axis), mean_grads + ) + return ( + mean_grads, + {'sum_sq_diff_grads': sum_sq_diff_grads, 'sample_size': batch_size}, + ), state + + if num_microbatches == 1: + return Aggregator(base.init_empty_state, compute_avg_and_sum_sq_diff) + + def init_fn(params): + return PerElementMeanAndSumSqDiffGradsState( + micro_step=0, + ready=False, + mean_grads=tree.zeros_like(params), + sum_sq_diff_grads=tree.zeros_like(params), + ) + + def update_fn(per_elt_udpates, state, params=None): + del params + batch_size = get_batch_size_from_per_elt_updates( + per_elt_udpates, per_elt_axis + ) + new_micro_step = state.micro_step + 1 + + # Compute batch averages. + batch_mean_grads = jax.tree.map( + lambda x: jnp.mean(x, axis=per_elt_axis, keepdims=True), per_elt_udpates + ) + batch_sum_sq_diff_grads = jax.tree.map( + lambda x, a: jnp.sum(jnp.square(x - a), axis=per_elt_axis), + per_elt_udpates, + batch_mean_grads, + ) + batch_mean_grads = jax.tree.map( + lambda x: x.squeeze(axis=per_elt_axis), batch_mean_grads + ) + + # Update accumulated averages. + delta = jax.tree.map(lambda u, a: u - a, batch_mean_grads, state.mean_grads) + new_mean_grads = jax.tree.map( + lambda a, d: a + d / new_micro_step, + state.mean_grads, + delta, + ) + size_factor = state.micro_step * batch_size / new_micro_step + new_sum_sq_diff_grads = jax.tree.map( + lambda a, s, d: a + s + d**2 * size_factor, + state.sum_sq_diff_grads, + batch_sum_sq_diff_grads, + delta, + ) + maybe_outputs = ( + new_mean_grads, + { + 'sum_sq_diff_grads': new_sum_sq_diff_grads, + 'sample_size': batch_size * new_micro_step, + }, + ) + + # Output or not the accumulated averages. 
+ ready_state = PerElementMeanAndSumSqDiffGradsState( + micro_step=0, + ready=True, + mean_grads=tree.zeros_like(new_mean_grads), + sum_sq_diff_grads=tree.zeros_like(new_sum_sq_diff_grads), + ) + not_ready_state = PerElementMeanAndSumSqDiffGradsState( + micro_step=new_micro_step, + ready=False, + mean_grads=new_mean_grads, + sum_sq_diff_grads=new_sum_sq_diff_grads, + ) + updates, new_state = tree.where( + new_micro_step == num_microbatches, + (maybe_outputs, ready_state), + (tree.zeros_like(maybe_outputs), not_ready_state), + ) + return updates, new_state + + return Aggregator(init_fn, update_fn) + + +class PerElementMeanAndVarianceEMAState(NamedTuple): + """State for the per-element mean and variance accumulator.""" + + count: jax.Array + ema_decay: jax.Array + mean_grads_ema: base.Updates + variance_grads_ema: base.Updates + + +def track_per_element_mean_and_variance_with_ema( + ema_decay: float = 0.9, +) -> base.GradientTransformation: + """Track variance metrics with an EMA over time.""" + + def init_fn(params): + return PerElementMeanAndVarianceEMAState( + count=jnp.zeros([], jnp.int32), + ema_decay=jnp.asarray(ema_decay), + mean_grads_ema=tree.zeros_like(params), + variance_grads_ema=tree.zeros_like(params), + ) + + def update_fn(updates, state, params=None, *, sum_sq_diff_grads, sample_size): + del params + mean_grads_ema = jax.tree.map( + lambda x, y: (1.0 - ema_decay) * x + ema_decay * y, + updates, + state.mean_grads_ema, + ) + variance_step = tree.scale(1 / (sample_size - 1), sum_sq_diff_grads) + variance_grads_ema = jax.tree.map( + lambda x, y: (1.0 - ema_decay) * x + ema_decay * y, + variance_step, + state.variance_grads_ema, + ) + new_count = utils.safe_int32_increment(state.count) + new_state = state._replace( + count=new_count, + mean_grads_ema=mean_grads_ema, + variance_grads_ema=variance_grads_ema, + ) + return updates, new_state + + return base.GradientTransformationExtraArgs(init_fn, update_fn) + + +def get_unbiased_mean_and_variance_ema( + state: base.OptState, +) -> tuple[base.Updates, base.Updates]: + """Track unbiased mean and variance with an EMA over time.""" + per_elt_mean_and_variance_ema_state = tree.get( + state, 'PerElementMeanAndVarianceEMAState', None + ) + if per_elt_mean_and_variance_ema_state is None: + raise ValueError( + 'State must have PerElementMeanAndVarianceEMAState to compute unbiased' + ' mean and variance EMA.' + ) + count = per_elt_mean_and_variance_ema_state.count + ema_decay = per_elt_mean_and_variance_ema_state.ema_decay + mean_grads_ema = per_elt_mean_and_variance_ema_state.mean_grads_ema + variance_grads_ema = per_elt_mean_and_variance_ema_state.variance_grads_ema + unbiased_mean_grads_ema = jax.tree.map( + lambda x: x / (1 - ema_decay**count), mean_grads_ema + ) + unbiased_variance_grads_ema = jax.tree.map( + lambda x: x / (1 - ema_decay**count), variance_grads_ema + ) + return unbiased_mean_grads_ema, unbiased_variance_grads_ema + + +def add_mean_variance_to_opt( + opt: base.GradientTransformation, + ema_decay: float = 0.9, + per_elt_axis: MaybeAxis = 0, + num_microbatches: int = 1, +): + """Add mean and variance to an optimizer.""" + return process( + preprocessor=base.identity(), + aggregator=get_per_element_mean_and_sum_sq_diff_grads( + per_elt_axis, num_microbatches + ), + postprocessor=chain( + track_per_element_mean_and_variance_with_ema(ema_decay), + opt, + ), + aggregator_has_aux=True, + )