Skip to content

Conversation

@carlosgmartin
Copy link
Contributor

Add a tree_bytes function to the tree utilities, analogous to Haiku's.

For context, see #1321 (comment).

@rdyro
Copy link
Collaborator

rdyro commented Jun 20, 2025

I like the idea!

Can we check for the presence of jnp.ones(10).on_device_size_in_bytes() on the object, for int4 arrays itemsize=1, but on device size is half a byte per element.

@carlosgmartin
Copy link
Contributor Author

@rdyro Are you suggesting adding a flag on_device to tree_bytes that replaces jnp.asarray(leaf).nbytes with jnp.asarray(leaf).on_device_size_in_bytes()?

@carlosgmartin carlosgmartin force-pushed the tree_bytes branch 2 times, most recently from e206e87 to 7956e45 Compare June 21, 2025 23:20
@rdyro
Copy link
Collaborator

rdyro commented Jun 22, 2025

I was thinking of using Python's hasattr, however, this will fail for tracers that are backend independent and will report they don't have on_device_size_in_bytes.

I'm not sure about the name of this function now, size * itemsize does not correctly report the byte size and on_device_size_in_bytes is not available on tracers.

@carlosgmartin
Copy link
Contributor Author

carlosgmartin commented Jun 22, 2025

I'm now using nbytes. That seems like the simplest, most straightforward approach. Let's go with that for now.

In the meantime, I left a question about on_device_size_in_bytes here.

@carlosgmartin carlosgmartin force-pushed the tree_bytes branch 2 times, most recently from 4d48682 to 4bc029b Compare June 22, 2025 18:34
@rdyro
Copy link
Collaborator

rdyro commented Jun 22, 2025

Hmmm, nbytes appears to just use size * itemsize.

Can you test this on the following case:

fn = lambda: tree_bytes(jnp.ones((1024,), dtype=jnp.int32))
fn()
jax.jit(fn)()

It should report 512 in both cases

@carlosgmartin
Copy link
Contributor Author

carlosgmartin commented Jun 22, 2025

@rdyro Shouldn't it be 4096?

An int32 has 32 bits / 8 bits per byte = 4 bytes. And 1024 * 4 = 4096.

I get 4096 in both cases.

@rdyro
Copy link
Collaborator

rdyro commented Jun 22, 2025

Should have been int4, that's the whole problem here :/

fn = lambda: tree_bytes(jnp.ones((1024,), dtype=jnp.int4))
fn()
jax.jit(fn)()

@carlosgmartin
Copy link
Contributor Author

For int4, I get 1024 in both cases.

@rdyro
Copy link
Collaborator

rdyro commented Jun 22, 2025

For int4, I get 1024 in both cases.

.nbytes and .size * .itemsize both give the wrong number. Currently the only way to get the true byte size for all dtypes is to call .on_device_size_in_bytes() in eager mode.

@carlosgmartin carlosgmartin force-pushed the tree_bytes branch 2 times, most recently from 96697d4 to 53e365f Compare June 23, 2025 00:31
@carlosgmartin
Copy link
Contributor Author

This still works for the most common dtypes, so let's go with the nbytes-based implementation for now.

I added a warning to the function's docstring. Feel free to reword this warning.

@rdyro
Copy link
Collaborator

rdyro commented Jun 23, 2025

Let’s wait until we have a solution on the JAX side, I’ll keep the PR open for now.

@carlosgmartin
Copy link
Contributor Author

Will you be opening an issue for that?

@rdyro
Copy link
Collaborator

rdyro commented Jun 23, 2025

I'll follow up on this internally. For the JAX issue you opened, can you explicitly ask about the use case of getting int4 byte size under jit?

@carlosgmartin
Copy link
Contributor Author

@rdyro In your opinion, what would be the ideal output of jnp.ones(1, dtype=jnp.int4).nbytes, or whatever the equivalent to nbytes is? Seems ugly to mix float and int values.

Perhaps this suggests that, more generally, we ought to be counting in bits rather than bytes?

@rdyro
Copy link
Collaborator

rdyro commented Jun 25, 2025

@rdyro In your opinion, what would be the ideal output of jnp.ones(1, dtype=jnp.int4).nbytes, or whatever the equivalent to nbytes is? Seems ugly to mix float and int values.

Perhaps this suggests that, more generally, we ought to be counting in bits rather than bytes?

Currently on CPU, GPU and TPU the byte size of jnp.ones(1, dtype=jnp.int4) is 1, but the byte size jnp.ones(2, dtype=jnp.int4) is also 1 since it's packed.

However, it's possible that a platform doesn't guarantee packing for int4, I don't think it's possible to have a jit-compatible function counting bytes currently. I believe users interested in RAM/VRAM size should use a custom lambda with jax.tree.map to explicitly state their own assumptions.

I believe fp4 will suffer from the same problem as int4, I'm not sure there's a difference between integer or floating point representations.

We'd typically make an assumption that 1 byte is 8 bits, so it shouldn't change the calculation and it doesn't solve the packing representation problem of fp4/int4.

Perhaps this function could be tree_nbytes, but I'm really wary of exposing bytes counting in optax because it can be a source of issues and can be potentially confusing for users working with large models.

I'd prefer not to merge this function into optax. I find the haiku version actively confusing when working with int4 quantized models.

@carlosgmartin
Copy link
Contributor Author

What does @vroulet think?

A potential alternative is to say that we're interested only in how much information there is in a pytree (not how it will be packed or laid out on devices, which is hardware-dependent). We can do this by counting bits. A tree_bits function would be able to do so exactly, even for fractional-byte types like int4, int2, float4, bool, etc.

@carlosgmartin carlosgmartin force-pushed the tree_bytes branch 2 times, most recently from 89d298f to f3944d4 Compare July 11, 2025 18:03
@carlosgmartin carlosgmartin changed the title Add tree_bytes function. Add tree_bits function. Jul 11, 2025
@carlosgmartin
Copy link
Contributor Author

@rdyro I fixed the issue by switching to a tree_bits function.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants