diff --git a/jax/_src/tree.py b/jax/_src/tree.py index 870f74aeb0ed..30699867fc48 100644 --- a/jax/_src/tree.py +++ b/jax/_src/tree.py @@ -191,6 +191,11 @@ def reduce(function: Callable[[T, Any], T], >>> jax.tree.reduce(operator.add, [1, (2, 3), [4, 5, 6]]) 21 + Notes: + **Tip**: You can exclude leaves from the reduction by first mapping them to + ``None`` using :func:`jax.tree.map`. This causes them to not be counted as + leaves after that. + See Also: - :func:`jax.tree.reduce_associative` - :func:`jax.tree.leaves` @@ -230,6 +235,11 @@ def reduce_associative( >>> jax.tree.reduce_associative(operator.add, [1, (2, 3), [4, 5, 6]]) 21 + Notes: + **Tip**: You can exclude leaves from the reduction by first mapping them to + ``None`` using :func:`jax.tree.map`. This causes them to not be counted as + leaves after that. + See Also: - :func:`jax.tree.reduce` """