[Question] How to tie/share weights in a flax neural network #1264
-
@jheek -- do you have an example handy of how a variant of ...
-
For now I would use something like this:
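(The code in this reply did not survive the export. Independent of what was originally shown, a minimal sketch of one way to tie the input embedding to the output projection in Linen, using a made-up module name, is to create the embedding table once and reuse it transposed:)

import jax
import jax.numpy as jnp
import flax.linen as nn

class TiedLM(nn.Module):  # hypothetical name, not from the original reply
    vocab_size: int = 10
    features: int = 8

    @nn.compact
    def __call__(self, xs):
        # A single parameter serves both as the input embedding table
        # and, transposed, as the output projection kernel.
        embedding = self.param('embedding',
                               nn.initializers.normal(stddev=0.02),
                               (self.vocab_size, self.features))
        ys = jnp.take(embedding, xs, axis=0)  # embedding lookup
        # ... intermediate layers would go here ...
        logits = ys @ embedding.T             # tied output projection
        return logits

variables = TiedLM().init(jax.random.PRNGKey(0), jnp.arange(6).reshape(2, 3))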
-
you can use ...
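(The reference in this reply was also cut off in the export. Whatever was meant, one built-in option is flax.linen.Embed's attend method, which computes logits against the same embedding table; a minimal sketch with a hypothetical module name:)

import flax.linen as nn

class TiedDecoder(nn.Module):  # hypothetical example, not from the original reply
    vocab_size: int = 10
    features: int = 8

    @nn.compact
    def __call__(self, xs):
        embed = nn.Embed(self.vocab_size, self.features)
        ys = embed(xs)             # token ids -> feature vectors
        # ... transformer blocks, etc. ...
        logits = embed.attend(ys)  # dot product against the same embedding table
        return logits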
-
@patrickvonplaten @jheek Check out my solution. It works pretty well and looks quite idiomatic.

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.traverse_util import flatten_dict, unflatten_dict


def tie(target, mappings, collections='params', transpose=True):
    """Tie weights of the `target` module enumerated in `mappings` from
    `collections`.

    Example::

        >>> class Model(nn.Module):
        ...     @nn.compact
        ...     def __call__(self, xs):
        ...         ys = nn.Embed(10, 8)(xs)
        ...         zs = nn.Dense(10)(ys)
        ...         return zs
        ...
        >>> rules = {('params', 'Embed_0', 'embedding'):
        ...          ('params', 'Dense_0', 'kernel')}
        >>> TiedModel = tie(Model, rules)
        >>> model = TiedModel()
        >>> variables = model.init(jax.random.PRNGKey(42),
        ...                        jnp.arange(6).reshape(2, 3))

    Args:
      target: the module or function to be transformed.
      mappings: weight sharing rules.
      collections: the collection(s) to be transformed.
      transpose: transpose tied weights or not.

    Returns:
      a wrapped version of ``target`` with shared weights.
    """
    if isinstance(mappings, dict):
        mappings = [*mappings.items()]

    def tie_in(variables):
        # On the way in, materialize every tied (destination) weight from its
        # source so the wrapped module sees a full set of parameters.
        variables = flatten_dict(variables)
        for src, dst in mappings:
            if transpose:
                variables[dst] = variables[src].T
            else:
                variables[dst] = variables[src]
        return unflatten_dict(variables)

    def tie_out(variables):
        # On the way out, drop the derived weights again so only the source
        # copy is ever stored.
        variables = flatten_dict(variables)
        for _, dst in mappings:
            variables.pop(dst, None)
        return unflatten_dict(variables)

    return nn.map_variables(target, collections, tie_in, tie_out, init=True)
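A note on the design: because tie_out strips the derived entries (including at init, thanks to init=True), only the source weight is stored, and tie_in recreates the tied copy every time the module runs, so a single parameter receives all the gradient. Continuing the docstring example, the stored variables should look roughly like this (a hedged expectation, not verified here):

'kernel' in variables['params'].get('Dense_0', {})     # expected False: the tied kernel is not stored
variables['params']['Embed_0']['embedding'].shape      # expected (10, 8): the single stored copy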
-
Sorry, this is not really a bug report but more a question. Lots of language models tie their word embedding matrix to the output logits matrix (this both saves quite a bit of memory, since vocab sizes can be huge, and can lead to better results). In PyTorch it's quite easy to do so - one can simply do the following:
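(The PyTorch snippet was not preserved in the export; a typical version of what is described, with hypothetical attribute names, looks like:)

# Hypothetical PyTorch attribute names; the original snippet was not preserved.
# Assigning the same Parameter object to both layers ties them.
model.lm_head.weight = model.embed_tokens.weight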
This will simply make sure that both weights point to the same weight node in the graph and consequently the weights are tied in the network.
Is there a way to do this in Flax?
I tried to simply set the weights of state to each other (roughly the snippet shown below this question), but this doesn't seem to actually tie the weights: during gradient descent the weights are updated independently. Do you know if there is an elegant way to share weights in Flax?
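(The attempted snippet was not preserved either; presumably it was something along these lines, which only copies values once rather than tying the parameters:)

from flax.core import freeze, unfreeze

# Hypothetical reconstruction of the attempt: copy one array into the other.
params = unfreeze(variables['params'])
params['Dense_0']['kernel'] = params['Embed_0']['embedding'].T
variables = freeze({'params': params})
# After this the two leaves are still independent arrays, so the optimizer
# updates them separately -- the weights start out equal but drift apart.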
Thanks a lot!