Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ Tree
tree_scale
tree_set
tree_size
tree_bits
tree_sub
tree_sum
tree_vdot
Expand Down Expand Up @@ -226,6 +227,10 @@ Tree size
~~~~~~~~~
.. autofunction:: tree_size

Tree bits
~~~~~~~~~
.. autofunction:: tree_bits

Tree subtract
~~~~~~~~~~~~~
.. autofunction:: tree_sub
Expand Down
1 change: 1 addition & 0 deletions optax/tree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
real = _tree_math.tree_real
scale = _tree_math.tree_scale
size = _tree_math.tree_size
bits = _tree_math.tree_bits
sub = _tree_math.tree_sub
sum = _tree_math.tree_sum # pylint: disable=redefined-builtin
update_infinity_moment = _tree_math.tree_update_infinity_moment
Expand Down
1 change: 1 addition & 0 deletions optax/tree_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from optax.tree_utils._tree_math import tree_real
from optax.tree_utils._tree_math import tree_scale
from optax.tree_utils._tree_math import tree_size
from optax.tree_utils._tree_math import tree_bits
from optax.tree_utils._tree_math import tree_sub
from optax.tree_utils._tree_math import tree_sum
from optax.tree_utils._tree_math import tree_update_infinity_moment
Expand Down
31 changes: 31 additions & 0 deletions optax/tree_utils/_tree_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,37 @@ def tree_size(tree: Any) -> int:
return sum(jnp.size(leaf) for leaf in jax.tree.leaves(tree))


def _get_bits(dtype):
if jnp.issubdtype(dtype, jnp.integer):
return jnp.iinfo(dtype).bits
elif jnp.issubdtype(dtype, jnp.floating):
return jnp.finfo(dtype).bits
elif dtype is bool:
return 1
else:
raise NotImplementedError(f"_get_bits not implemented for {dtype=}")


def tree_bits(tree: Any) -> int:
r"""Total number of bits in a pytree.

Args:
tree: pytree

Returns:
the total size of the pytree in bits.

.. warning::
It is assumed that every leaf's dtype has an integer byte size.
Fractional byte sizes may yield an incorrect result.
For example, ``int4`` might be only half a byte on device.
"""
return sum(
_get_bits(jnp.asarray(leaf).dtype) * jnp.size(leaf)
for leaf in jax.tree.leaves(tree)
)


def tree_conj(tree: Any) -> Any:
"""Compute the conjugate of a pytree.

Expand Down
33 changes: 33 additions & 0 deletions optax/tree_utils/_tree_math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,39 @@ def test_tree_allclose(self):
assert tu.tree_allclose(1, 1 + 1e-7)
assert not tu.tree_allclose(1, 2)

@parameterized.product(
size=[1, 10, 100, 1000],
dtype=[
jnp.int4,
jnp.int8,
jnp.int16,
jnp.int32,
jnp.uint4,
jnp.uint8,
jnp.uint16,
jnp.uint32,
jnp.float16,
jnp.float32,
jnp.bfloat16,
],
)
def test_tree_bits(self, size, dtype):
tree = jnp.zeros(size, dtype)
bits = {
jnp.int4: 4,
jnp.int8: 8,
jnp.int16: 16,
jnp.int32: 32,
jnp.uint4: 4,
jnp.uint8: 8,
jnp.uint16: 16,
jnp.uint32: 32,
jnp.float16: 16,
jnp.float32: 32,
jnp.bfloat16: 16,
}[dtype]
assert tu.tree_bits(tree) == bits * size

def test_tree_conj(self):
expected = jnp.conj(self.array_a)
got = tu.tree_conj(self.array_a)
Expand Down
Loading