I'm unable to call …

Code:

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

# Weight-normalized conv; the commented-out line is the plain Conv for comparison.
model = nn.WeightNorm(nn.Conv(features=64, kernel_size=3, strides=2))
# model = nn.Conv(features=64, kernel_size=3, strides=2)

x = jnp.ones((8, 44100, 1))  # (batch, length, channels): Flax convs are channels-last
key = jax.random.PRNGKey(0)
params = model.init(key, x)

print(model.tabulate(
    key, x, depth=4,
    console_kwargs={'width': 180},
    column_kwargs={'width': 180},
    compute_flops=True,
    compute_vjp_flops=True,
))

y = model.apply(params, x)
print(y.shape)
```

Output:
Maybe the answer involves specifying …

For context, the model that I want is based on the following PyTorch code:

```python
import torch.nn as nn
from torch.nn.utils import weight_norm

model = weight_norm(nn.Conv1d(1, 64, 3))
```
Answered by DBraun, Mar 25, 2024
Resolved in #3735
Answer selected by DBraun