Add tree_bits function. #1341
6b8a952 to 6d0d300
|
I like the idea! Can we check for the presence of |
6d0d300 to a24fe51
|
@rdyro Are you suggesting adding a flag |
e206e87 to 7956e45
|
I was thinking of using Python's I'm not sure about the name of this function now, |
4d48682 to 4bc029b
|
Hmmm, can you test this on the following case:
fn = lambda: tree_bytes(jnp.ones((1024,), dtype=jnp.int32))
fn()
jax.jit(fn)()
It should report 512 in both cases. |
|
@rdyro Shouldn't it be 4096? An int32 has 32 bits / 8 bits per byte = 4 bytes. And 1024 * 4 = 4096. I get 4096 in both cases. |
4bc029b to 780739d
|
Should have been:
fn = lambda: tree_bytes(jnp.ones((1024,), dtype=jnp.int4))
fn()
jax.jit(fn)() |
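The two figures in this exchange (4096 for int32, 512 for int4) can be checked with a small sketch. It assumes elements are stored back to back with no padding, which is exactly the packing assumption under discussion; `packed_nbytes` is a hypothetical helper, not part of optax or JAX.

```python
def packed_nbytes(num_elements: int, bits_per_element: int) -> int:
    """Bytes needed if elements are tightly packed (an assumption, not a guarantee)."""
    total_bits = num_elements * bits_per_element
    # Round up to a whole number of 8-bit bytes.
    return (total_bits + 7) // 8


# 1024 int32 values: 32 bits each, i.e. 4 bytes per element.
int32_bytes = packed_nbytes(1024, 32)

# 1024 int4 values: 4 bits each, i.e. two elements per byte when packed.
int4_bytes = packed_nbytes(1024, 4)

# If a platform stored each int4 in its own byte, the count would double.
int4_unpacked_bytes = 1024 * 1

print(int32_bytes, int4_bytes, int4_unpacked_bytes)
```

The gap between the packed and unpacked int4 counts is why a single byte count is ambiguous for sub-byte dtypes.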
780739d to 85ef02b
|
For |
|
96697d4 to 53e365f
|
This still works for the most common dtypes, so let's go with the I added a warning to the function's docstring. Feel free to reword this warning. |
|
Let’s wait until we have a solution on the JAX side, I’ll keep the PR open for now. |
|
Will you be opening an issue for that? |
|
I'll follow up on this internally. For the JAX issue you opened, can you explicitly ask about the use case of getting int4 byte size under jit? |
53e365f to 44714c9
|
@rdyro In your opinion, what would be the ideal output of Perhaps this suggests that, more generally, we ought to be counting in bits rather than bytes? |
Currently on CPU, GPU and TPU the byte size of However, it's possible that a platform doesn't guarantee packing for int4. I don't think it's possible to have a jit-compatible function counting bytes currently. I believe users interested in RAM/VRAM size should use a custom lambda with

I believe fp4 will suffer from the same problem as int4; I'm not sure there's a difference between integer and floating-point representations. We'd typically assume that 1 byte is 8 bits, so it shouldn't change the calculation, and it doesn't solve the packing representation problem of fp4/int4. Perhaps this function could be

I'd prefer not to merge this function into optax. I find the haiku version actively confusing when working with int4 quantized models. |
|
What does @vroulet think? A potential alternative is to say that we're interested only in how much information there is in a pytree (not how it will be packed or laid out on devices, which is hardware-dependent). We can do this by counting bits. A |
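The bit-counting idea can be sketched as follows. This is an illustration only, not the implementation in this PR: `dtype_bits` and `tree_bits_sketch` are hypothetical names, and the sketch assumes a recent JAX in which `jnp.int4` and `jnp.iinfo` support for it are available. It reads only static shapes and dtypes, so it gives the same count eagerly and under `jax.jit`.

```python
import math

import jax
import jax.numpy as jnp


def dtype_bits(dtype) -> int:
    """Nominal bit width of one element, independent of device layout."""
    dtype = jnp.dtype(dtype)
    if jnp.issubdtype(dtype, jnp.integer):
        return jnp.iinfo(dtype).bits  # sub-byte integers report their true width
    if jnp.issubdtype(dtype, jnp.floating):
        return jnp.finfo(dtype).bits
    # Fallback (bool, complex): use the storage size; this is an assumption.
    return dtype.itemsize * 8


def tree_bits_sketch(tree) -> int:
    """Sum of (number of elements * bits per element) over all leaves."""
    leaves = jax.tree_util.tree_leaves(tree)
    return sum(math.prod(x.shape) * dtype_bits(x.dtype) for x in leaves)


# The int32 case from the thread: 1024 elements * 32 bits.
fn = lambda: tree_bits_sketch(jnp.ones((1024,), dtype=jnp.int32))
eager_bits = fn()
jitted_bits = int(jax.jit(fn)())

# An int4 leaf described abstractly, so no int4 buffer is allocated.
int4_spec = jax.ShapeDtypeStruct((1024,), jnp.int4)
int4_bits = tree_bits_sketch({"w": int4_spec})
```

Counting this way says nothing about how many bytes a backend actually allocates; it only measures the information content of the pytree.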
89d298f to f3944d4
f3944d4 to cbbd7ca
|
@rdyro I fixed the issue by switching to a |
cbbd7ca to ad15aa0
ad15aa0 to ebe94c1
Add a tree_bytes function to the tree utilities, analogous to Haiku's. For context, see #1321 (comment).