Skip to content

Commit b9cc05f

Browse files
author
Flax Authors
committed
Merge pull request #4968 from google:get-set-metadata
PiperOrigin-RevId: 811527483
2 parents 609a0ab + 0e2b0f3 commit b9cc05f

File tree

2 files changed

+57
-2
lines changed

2 files changed

+57
-2
lines changed

flax/nnx/variablelib.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,8 +347,51 @@ def type(self):
347347
def has_ref(self) -> bool:
348348
return is_array_ref(self.raw_value)
349349

350-
def get_metadata(self):
351-
return self._var_metadata
350+
@tp.overload
351+
def get_metadata(self) -> dict[str, tp.Any]: ...
352+
@tp.overload
353+
def get_metadata(self, name: str) -> tp.Any: ...
354+
def get_metadata(self, name: str | None = None):
355+
"""Get metadata for the Variable.
356+
357+
Args:
358+
name: The key of the metadata element to get. If not provided, returns
359+
the full metadata dictionary.
360+
"""
361+
if name is None:
362+
return self._var_metadata
363+
return self._var_metadata[name]
364+
365+
@tp.overload
366+
def set_metadata(self, metadata: dict[str, tp.Any], /) -> None: ...
367+
@tp.overload
368+
def set_metadata(self, **metadata: tp.Any) -> None: ...
369+
def set_metadata(self, *args, **kwargs) -> None:
370+
"""Set metadata for the Variable.
371+
372+
`set_metadata` can be called in two ways:
373+
374+
1. By passing a dictionary of metadata as the first argument, this will replace
375+
the entire Variable's metadata.
376+
2. By using keyword arguments, these will be merged into the existing Variable's
377+
metadata.
378+
"""
379+
if not self._trace_state.is_valid():
380+
raise errors.TraceContextError(
381+
f'Cannot mutate {type(self).__name__} from a different trace level'
382+
)
383+
if not (bool(args) ^ bool(kwargs)):
384+
raise TypeError(
385+
'set_metadata takes either a single dict argument or keyword arguments'
386+
)
387+
if len(args) == 1:
388+
self._var_metadata = args[0]
389+
elif kwargs:
390+
self._var_metadata.update(kwargs)
391+
else:
392+
raise TypeError(
393+
f'set_metadata takes either 1 argument or 1 or more keyword arguments, got args={args}, kwargs={kwargs}'
394+
)
352395

353396
def copy_from(self, other: Variable[A]) -> None:
354397
if type(self) is not type(other):

tests/nnx/variable_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,5 +118,17 @@ def test_mutable_array_context(self):
118118
self.assertFalse(nnx.using_refs())
119119
self.assertFalse(nnx.is_array_ref(v.raw_value))
120120

121+
def test_get_set_metadata(self):
122+
v = nnx.Variable(jnp.array(1.0))
123+
self.assertEqual(v.get_metadata(), {})
124+
v.set_metadata(a=1, b=2)
125+
self.assertEqual(v.get_metadata('a'), 1)
126+
self.assertEqual(v.get_metadata('b'), 2)
127+
v.set_metadata({'b': 3, 'c': 4})
128+
self.assertEqual(v.get_metadata(), {'b': 3, 'c': 4})
129+
self.assertEqual(v.get_metadata('b'), 3)
130+
self.assertEqual(v.get_metadata('c'), 4)
131+
132+
121133
if __name__ == '__main__':
122134
absltest.main()

0 commit comments

Comments
 (0)