Skip to content

Commit 53e365f

Browse files
committed
Add tree_bytes function.
1 parent c6d8eba commit 53e365f

File tree

6 files changed

+32
-2
lines changed

6 files changed

+32
-2
lines changed

docs/api/utilities.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ Tree
119119
tree_scale
120120
tree_set
121121
tree_size
122+
tree_bytes
122123
tree_sub
123124
tree_sum
124125
tree_vdot
@@ -221,6 +222,10 @@ Tree size
221222
~~~~~~~~~
222223
.. autofunction:: tree_size
223224

225+
Tree bytes
226+
~~~~~~~~~~
227+
.. autofunction:: tree_bytes
228+
224229
Tree subtract
225230
~~~~~~~~~~~~~
226231
.. autofunction:: tree_sub

optax/contrib/_sophia.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,8 @@ def update_fn(updates, state: SophiaState, params=None, **hess_fn_kwargs):
161161
lambda x, y: x + y,
162162
jax.tree.map(lambda u: jnp.sum(jnp.abs(u) < clip_threshold), updates),
163163
)
164-
total_tree_size = sum(x.size for x in jax.tree.leaves(updates))
165164
if verbose:
166-
win_rate = sum_not_clipped / total_tree_size
165+
win_rate = sum_not_clipped / optax.tree.size(updates)
167166
jax.lax.cond(
168167
count_inc % print_win_rate_every_n_steps == 0,
169168
lambda: jax.debug.print("Sophia optimizer win rate: {}", win_rate),

optax/tree/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
real = _tree_math.tree_real
4848
scale = _tree_math.tree_scale
4949
size = _tree_math.tree_size
50+
bytes = _tree_math.tree_bytes # pylint: disable=redefined-builtin
5051
sub = _tree_math.tree_sub
5152
sum = _tree_math.tree_sum # pylint: disable=redefined-builtin
5253
update_infinity_moment = _tree_math.tree_update_infinity_moment

optax/tree_utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from optax.tree_utils._tree_math import tree_real
4747
from optax.tree_utils._tree_math import tree_scale
4848
from optax.tree_utils._tree_math import tree_size
49+
from optax.tree_utils._tree_math import tree_bytes
4950
from optax.tree_utils._tree_math import tree_sub
5051
from optax.tree_utils._tree_math import tree_sum
5152
from optax.tree_utils._tree_math import tree_update_infinity_moment

optax/tree_utils/_tree_math.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,23 @@ def tree_size(tree: Any) -> int:
209209
return sum(jnp.size(leaf) for leaf in jax.tree.leaves(tree))
210210

211211

212+
def tree_bytes(tree: Any) -> int:
213+
r"""Total number of bytes in a pytree.
214+
215+
Args:
216+
tree: pytree
217+
218+
Returns:
219+
the total size of the pytree in bytes.
220+
221+
.. warning::
222+
It is assumed that every leaf's dtype has an integer byte size.
223+
Fractional byte sizes may yield an incorrect result.
224+
For example, ``int4`` might be only half a byte on device.
225+
"""
226+
return sum(jnp.asarray(leaf).nbytes for leaf in jax.tree.leaves(tree))
227+
228+
212229
def tree_conj(tree: Any) -> Any:
213230
"""Compute the conjugate of a pytree.
214231

optax/tree_utils/_tree_math_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,13 @@ def test_tree_size(self, key):
172172
got = tu.tree_size(tree)
173173
np.testing.assert_allclose(expected, got)
174174

175+
@parameterized.product(
176+
n=[1, 10, 100, 1000],
177+
dtype=[jnp.int16, jnp.int32, jnp.float16, jnp.float32],
178+
)
179+
def test_tree_bytes(self, n, dtype):
180+
assert tu.tree_bytes(jnp.ones(n, dtype)) == n * dtype.dtype.itemsize
181+
175182
def test_tree_conj(self):
176183
expected = jnp.conj(self.array_a)
177184
got = tu.tree_conj(self.array_a)

0 commit comments

Comments
 (0)