Skip to content

maybe a wrong loop init bug #825

@SyntaxSmith

Description

@SyntaxSmith

I am using jax==0.6.1, jaxlib==0.6.1, and haiku==0.0.14. When I launch my script, I find that the model's `__init__()` method is called again every time I call `apply()`, which produces unexpected output. My script is:

# minimal_test.py
import jax
import jax.numpy as jnp
import haiku as hk


# Simplified for testing
class SimpleMessage(hk.Module):
    """Toy Haiku module that multiplies its input by a parameter ``w``."""

    def __init__(self, name=None):
        """Log the construction, register the module, and create ``w``.

        Args:
            name: Optional module name forwarded to ``hk.Module.__init__``.
        """
        # Printed before anything else so every construction shows up in the log.
        print(f"SimpleMessage __init__ called, name: {name}")
        super().__init__(name=name)
        # Parameter "w" has shape [1] and is initialised with ``jnp.ones``.
        weight = hk.get_parameter("w", shape=[1], init=jnp.ones)
        self.w = weight

    def __call__(self, x):
        """Return ``x`` multiplied by the parameter ``w``."""
        scaled = x * self.w
        return scaled

class SimpleModel(hk.Module):
    """Sequential stack of ``SimpleMessage`` layers."""

    def __init__(self, num_layers, name=None):
        """Log the construction and build ``num_layers`` sub-layers.

        Args:
            num_layers: How many ``SimpleMessage`` layers to create.
            name: Optional module name forwarded to ``hk.Module.__init__``.
        """
        print(f"SimpleModel __init__ called, name: {name}")
        super().__init__(name=name)
        # Layers are created in index order and named "layer_0", "layer_1", ...
        stack = []
        for i in range(num_layers):
            stack.append(SimpleMessage(name=f"layer_{i}"))
        self.layers = stack

    def __call__(self, x_batch):
        """Feed ``x_batch`` through every layer in order and return the result.

        Args:
            x_batch: Input handed to the first layer (assumed to be a single
                JAX array).
        """
        out = x_batch
        for message in self.layers:
            out = message(out)
        return out

# In minimal_test.py, adapt forward_test_fn and model_config
# Keyword arguments that forward_test_fn_simple unpacks into SimpleModel(...).
model_config_simple = {"num_layers": 2}
def forward_test_fn_simple(batch_input_array): # Assume batch is just a JAX array now
    """Build a ``SimpleModel`` from ``model_config_simple`` and apply it.

    The model is constructed inside this function, so its ``__init__`` (and
    the ``__init__`` of each of its layers) runs on every invocation.

    Args:
        batch_input_array: Input forwarded to the model (assumed to be a
            single JAX array).

    Returns:
        Whatever the model returns for ``batch_input_array``.
    """
    return SimpleModel(**model_config_simple)(batch_input_array)

# Wrap the forward function with Haiku; the result exposes the .init and
# .apply callables used below.
transformed_simple = hk.transform(forward_test_fn_simple)

# Adapt sample_batch_data to be a simple JAX array
sample_input_array = jnp.ones((3,1)) # Example: batch of 3, feature dim 1

# The same key is passed to init and to both apply calls.
rng = jax.random.PRNGKey(0)
# The marker prints bracket each call so that any "__init__ called" lines can
# be attributed to init, the first apply, or the second apply.
print("--- Before transformed_simple.init ---")
params_simple = transformed_simple.init(rng, sample_input_array)
print("--- After transformed_simple.init / Before first apply ---")
outputs1_simple = transformed_simple.apply(params_simple, rng, sample_input_array)
print("--- After first apply / Before second apply ---")
# Second apply with identical arguments, to check whether construction repeats.
outputs2_simple = transformed_simple.apply(params_simple, rng, sample_input_array)
print("--- After second apply ---")

My output is:

--- Before transformed_simple.init ---
SimpleModel __init__ called, name: None
SimpleMessage __init__ called, name: layer_0
SimpleMessage __init__ called, name: layer_1
--- After transformed_simple.init / Before first apply ---
SimpleModel __init__ called, name: None
SimpleMessage __init__ called, name: layer_0
SimpleMessage __init__ called, name: layer_1
--- After first apply / Before second apply ---
SimpleModel __init__ called, name: None
SimpleMessage __init__ called, name: layer_0
SimpleMessage __init__ called, name: layer_1
--- After second apply ---

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions