Modules got silently "reused" with hk.vmap #740

@jjyyxx

I have to admit that I do not fully understand why hk.vmap is needed instead of jax.vmap. Nevertheless, whenever I need to vmap a function that calls Haiku modules, I use hk.vmap. This worked fine until I started debugging the poor performance of a transformer model. The problem boils down to the following snippet:

import jax, haiku as hk
jax.config.update("jax_platforms", "cpu")

def f1(x):
    def g(x):
        return hk.Linear(2)(x)
    x = g(x)
    x = g(x)
    return x

def f2(x):
    def g(x):
        return hk.Linear(2)(x)
    x = hk.vmap(g, split_rng=False)(x)
    x = hk.vmap(g, split_rng=False)(x)
    return x

key = jax.random.PRNGKey(42)
x = jax.random.normal(key, (1, 2))
w1 = hk.transform(f1).init(key, x)
w2 = hk.transform(f2).init(key, x)
print("w1:", w1.keys())
print("w2:", w2.keys())
# w1: dict_keys(['linear', 'linear_1'])
# w2: dict_keys(['linear'])

It turns out that when g is vmapped, a module created inside g reuses a previously created module instead of getting its own name and parameters, which is why w2 contains only one 'linear' entry. In some cases an error is raised immediately because the shapes are incompatible, but in other cases (for me, transformer layers all have the same shapes) the parameters are shared silently and nothing fails.

My question: Is this behavior intended? Could the documentation be improved on this topic? Or am I missing something?
