Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update jax.tree module docs #24925

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,15 @@ def grad(fun: Callable, argnums: int | Sequence[int] = 0,
has_aux: bool = False, holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: Sequence[AxisName] = ()) -> Callable:
"""Creates a function that evaluates the gradient of ``fun``.

"""A JAX transformation that creates a function that evaluates the gradient of ``fun``.

Learn more in the Automatic differentiation `[1]`_ and Advanced automatic differentiation `[2]`_
tutorials, and the Quickstart `[3]`_ documentation.

.. _[1] https://jax.readthedocs.io/en/latest/automatic-differentiation.html
.. _[2] https://jax.readthedocs.io/en/latest/advanced-autodiff.html
.. _[3] https://jax.readthedocs.io/en/latest/quickstart.html#taking-derivatives-with-jax-grad

Args:
fun: Function to be differentiated. Its arguments at positions specified by
``argnums`` should be arrays, scalars, or standard Python containers.
Expand All @@ -352,12 +359,12 @@ def grad(fun: Callable, argnums: int | Sequence[int] = 0,
positional argument(s) to differentiate with respect to (default 0).
has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default False.
differentiated and the second element is auxiliary data. Default ``False``.
holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
holomorphic. If True, inputs and outputs must be complex. Default False.
holomorphic. If ``True``, inputs and outputs must be complex. Default ``False``.
allow_int: Optional, bool. Whether to allow differentiating with
respect to integer valued inputs. The gradient of an integer input will
have a trivial vector-space dtype (float0). Default False.
have a trivial vector-space dtype (float0). Default ``False``.

Returns:
A function with the same arguments as ``fun``, that evaluates the gradient
Expand Down
4 changes: 4 additions & 0 deletions jax/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
"""Utilities for working with tree-like container data structures.
The :mod:`jax.tree` namespace contains aliases of utilities from :mod:`jax.tree_util`.
Refer to the Working with pytrees `[1]`_ tutorial for examples.
.. _[1] https://jax.readthedocs.io/en/latest/working-with-pytrees.html
"""

from jax._src.tree import (
Expand Down
7 changes: 4 additions & 3 deletions jax/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@
functions in this file.
The primary purpose of this module is to enable the interoperability between
user defined data structures and JAX transformations (e.g. `jit`). This is not
user defined data structures and JAX transformations (e.g. ``jax.jit``). This is not
meant to be a general purpose tree-like data structure handling library.
See the `JAX pytrees note <pytrees.html>`_
for examples.
Refer to the Working with pytrees `[1]`_ tutorial for examples.
.. _[1] https://jax.readthedocs.io/en/latest/working-with-pytrees.html
"""

# Note: import <name> as <name> is required for names to be exported.
Expand Down