-
Notifications
You must be signed in to change notification settings - Fork 253
Open
Description
I am using jax==0.6.1, jaxlib==0.6.1 and haiku==0.0.14. When I launch my script, I find that the model's __init__() method is called again every time I use the apply() method, and I get strange output. My script is:
# minimal_test.py
import jax
import jax.numpy as jnp
import haiku as hk
# Simplified for testing
class SimpleMessage(hk.Module):
    """Toy layer that multiplies its input by a single parameter ``w``."""

    def __init__(self, name=None):
        """Announce construction, register the module and create ``w``.

        Args:
            name: Optional module name forwarded to ``hk.Module``.
        """
        print(f"SimpleMessage __init__ called, name: {name}")
        super().__init__(name=name)
        # One-element parameter named "w", initialised with jnp.ones.
        weight = hk.get_parameter("w", shape=[1], init=jnp.ones)
        self.w = weight

    def __call__(self, x):
        """Return ``x`` multiplied by the parameter ``w``."""
        scaled = x * self.w
        return scaled
class SimpleModel(hk.Module):
    """Chain of ``SimpleMessage`` layers applied one after another."""

    def __init__(self, num_layers, name=None):
        """Announce construction, register the module and build the layers.

        Args:
            num_layers: Number of ``SimpleMessage`` layers to create; they
                are named ``layer_0``, ``layer_1``, ...
            name: Optional module name forwarded to ``hk.Module``.
        """
        print(f"SimpleModel __init__ called, name: {name}")
        super().__init__(name=name)
        stack = []
        for idx in range(num_layers):
            stack.append(SimpleMessage(name=f"layer_{idx}"))
        self.layers = stack

    def __call__(self, x_batch):
        """Feed ``x_batch`` through every layer in order and return the result.

        ``x_batch`` is assumed to be a single JAX array.
        """
        out = x_batch
        for message in self.layers:
            out = message(out)
        return out
# Keyword arguments unpacked into the SimpleModel constructor.
model_config_simple = dict(num_layers=2)
def forward_test_fn_simple(batch_input_array):
    """Build a fresh ``SimpleModel`` from ``model_config_simple`` and apply it.

    Args:
        batch_input_array: Input handed straight to the model; assumed to be
            a single JAX array.

    Returns:
        The model's output for ``batch_input_array``.
    """
    return SimpleModel(**model_config_simple)(batch_input_array)
# Wrap the forward function; the result exposes the .init and .apply used below.
transformed_simple = hk.transform(forward_test_fn_simple)
# The input is a plain JAX array rather than a structured batch.
sample_input_array = jnp.ones((3,1)) # Example: batch of 3, feature dim 1
rng = jax.random.PRNGKey(0)
print("--- Before transformed_simple.init ---")
# init executes forward_test_fn_simple, so both module constructors run
# (and print) here while the parameters are created.
params_simple = transformed_simple.init(rng, sample_input_array)
print("--- After transformed_simple.init / Before first apply ---")
# apply executes forward_test_fn_simple again with the supplied parameters:
# the Python module objects are rebuilt on every call, so the constructor
# prints reappear even though the parameter values come from params_simple.
outputs1_simple = transformed_simple.apply(params_simple, rng, sample_input_array)
print("--- After first apply / Before second apply ---")
# Same as above: each un-jitted apply call re-runs the Python function.
outputs2_simple = transformed_simple.apply(params_simple, rng, sample_input_array)
print("--- After second apply ---")
My output is:
--- Before transformed_simple.init ---
SimpleModel __init__ called, name: None
SimpleMessage __init__ called, name: layer_0
SimpleMessage __init__ called, name: layer_1
--- After transformed_simple.init / Before first apply ---
SimpleModel __init__ called, name: None
SimpleMessage __init__ called, name: layer_0
SimpleMessage __init__ called, name: layer_1
--- After first apply / Before second apply ---
SimpleModel __init__ called, name: None
SimpleMessage __init__ called, name: layer_0
SimpleMessage __init__ called, name: layer_1
--- After second apply ---
Metadata
Metadata
Assignees
Labels
No labels