Skip to content

Commit a24fe51

Browse files
committed
Add tree_bytes function.
1 parent c6d8eba commit a24fe51

File tree

5 files changed

+20
-2
lines changed

5 files changed

+20
-2
lines changed

docs/api/utilities.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ Tree
119119
tree_scale
120120
tree_set
121121
tree_size
122+
tree_bytes
122123
tree_sub
123124
tree_sum
124125
tree_vdot
@@ -221,6 +222,10 @@ Tree size
221222
~~~~~~~~~
222223
.. autofunction:: tree_size
223224

225+
Tree bytes
226+
~~~~~~~~~
227+
.. autofunction:: tree_bytes
228+
224229
Tree subtract
225230
~~~~~~~~~~~~~
226231
.. autofunction:: tree_sub

optax/contrib/_sophia.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,8 @@ def update_fn(updates, state: SophiaState, params=None, **hess_fn_kwargs):
161161
lambda x, y: x + y,
162162
jax.tree.map(lambda u: jnp.sum(jnp.abs(u) < clip_threshold), updates),
163163
)
164-
total_tree_size = sum(x.size for x in jax.tree.leaves(updates))
165164
if verbose:
166-
win_rate = sum_not_clipped / total_tree_size
165+
win_rate = sum_not_clipped / optax.tree.size(updates)
167166
jax.lax.cond(
168167
count_inc % print_win_rate_every_n_steps == 0,
169168
lambda: jax.debug.print("Sophia optimizer win rate: {}", win_rate),

optax/tree/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
real = _tree_math.tree_real
4848
scale = _tree_math.tree_scale
4949
size = _tree_math.tree_size
50+
bytes = _tree_math.tree_bytes # pylint: disable=redefined-builtin
5051
sub = _tree_math.tree_sub
5152
sum = _tree_math.tree_sum # pylint: disable=redefined-builtin
5253
update_infinity_moment = _tree_math.tree_update_infinity_moment

optax/tree_utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from optax.tree_utils._tree_math import tree_real
4747
from optax.tree_utils._tree_math import tree_scale
4848
from optax.tree_utils._tree_math import tree_size
49+
from optax.tree_utils._tree_math import tree_bytes
4950
from optax.tree_utils._tree_math import tree_sub
5051
from optax.tree_utils._tree_math import tree_sum
5152
from optax.tree_utils._tree_math import tree_update_infinity_moment

optax/tree_utils/_tree_math.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,18 @@ def tree_size(tree: Any) -> int:
209209
return sum(jnp.size(leaf) for leaf in jax.tree.leaves(tree))
210210

211211

212+
def tree_bytes(tree: Any) -> int:
213+
r"""Total number of bytes in a pytree.
214+
215+
Args:
216+
tree: pytree
217+
218+
Returns:
219+
the total number of bytes in the pytree.
220+
"""
221+
return sum(jnp.asarray(leaf).nbytes for leaf in jax.tree.leaves(tree))
222+
223+
212224
def tree_conj(tree: Any) -> Any:
213225
"""Compute the conjugate of a pytree.
214226

0 commit comments

Comments
 (0)