diff --git a/python/nutpie/transform_adapter.py b/python/nutpie/transform_adapter.py index 136fbf8..f658ca6 100644 --- a/python/nutpie/transform_adapter.py +++ b/python/nutpie/transform_adapter.py @@ -18,7 +18,6 @@ def make_transform_adapter( import flowjax import flowjax.train import flowjax.flows - from flowjax.bijections import mvscale import optax import traceback from paramax import Parameterize, unwrap @@ -164,6 +163,8 @@ def make_layer(key, is_last=False): flow = flowjax.flows._add_default_permute(coupling, n_dim, key_permute) if scale_layer: + from flowjax.bijections import mvscale + bijections = list(flow.bijections) bijections.append(mvscale.MvScale4(jnp.ones(n_dim) * 1e-5)) # bijections.append(mvscale.MvScale3(jnp.ones(n_dim) * 1e-5))