Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 148 additions & 23 deletions ivy/functional/backends/jax/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,8 +698,36 @@ def _addindent(s_, numSpaces):
s = first + "\n" + s
return s

@staticmethod
def _contains_array(value: Any) -> bool:
    """Report whether any pytree leaf of ``value`` is an array-like object.

    A leaf counts as array-like when it is an instance of ``jax.Array``,
    ``nnx.Param`` or, when ``nnx`` exposes it, ``nnx.Variable``.

    Args:
        value: Arbitrary object; its leaves are obtained via
            ``tree.tree_leaves`` (``tree`` is a file-level import).

    Returns:
        True if at least one leaf matches, False otherwise. Any exception
        raised while flattening or checking is treated as "no array".
    """
    try:
        leaf_types = (jax.Array, nnx.Param)
        # Only add nnx.Variable when it actually exists. Falling back to
        # ``object`` would make the isinstance check match every leaf and
        # report arrays in any non-empty container.
        variable_cls = getattr(nnx, "Variable", None)
        if variable_cls is not None:
            leaf_types += (variable_cls,)
        return any(
            isinstance(leaf, leaf_types) for leaf in tree.tree_leaves(value)
        )
    except Exception:
        # Best-effort probe: values that cannot be flattened are reported
        # as not containing arrays.
        return False

@staticmethod
def _contains_module(value: Any) -> bool:
    """Report whether ``value`` is, or contains, a module.

    ``value`` matches when it is an instance of ``Module`` or ``nnx.Module``.
    Lists, tuples and dict values are searched recursively; every other
    type yields False. Any exception raised during the check also yields
    False.
    """
    try:
        if isinstance(value, (Module, nnx.Module)):
            return True
        # pick the children to recurse into for the supported containers
        if isinstance(value, (list, tuple)):
            children = value
        elif isinstance(value, dict):
            children = value.values()
        else:
            return False
        return any(ModelHelpers._contains_module(child) for child in children)
    except Exception:
        return False


class Module(nnx.Module, ModelHelpers, TorchModuleHelpers):
# mark private containers that hold arrays as data for nnx 0.12 strict pytree
_v: nnx.Data[dict]
_buffers: nnx.Data[dict]
_module_dict: nnx.Data[dict]
_build_mode = None
_with_partial_v = None
_store_vars = True
Expand Down Expand Up @@ -739,7 +767,8 @@ def __init__(
self._store_vars = store_vars
self._built = False
self._v_from_constructor = v if isinstance(v, dict) or v is None else dict(v)
self._v = v if v is not None else dict()
# keep internal containers as plain dicts; type annotations mark them as data
self._v = v if isinstance(v, dict) else (v if v is not None else {})
self._buffers = dict(buffers or {})
self._module_dict = module_dict if module_dict is not None else dict()
self._args = args
Expand Down Expand Up @@ -768,11 +797,17 @@ def build(
return

def register_buffer(self, name: str, value: jax.Array, persistent: bool = False):
    """Store ``value`` under ``name`` in this module's buffer container.

    Args:
        name: Key under which the buffer is stored.
        value: The buffer to store.
        persistent: Accepted for signature compatibility; not used here.

    Returns:
        The ``value`` that was stored.
    """
    # The None guard must run before the container is first dereferenced;
    # calling ``.update`` ahead of it would raise on a None container.
    if self._buffers is None:
        # initialize buffers container on first use; writing through
        # __dict__ does not go through the class's __setattr__
        self.__dict__["_buffers"] = {}
    self._buffers[name] = value
    return value

def register_parameter(self, name: str, value: jax.Array):
    """Store ``value`` under ``name`` in this module's parameter container.

    Args:
        name: Key under which the parameter is stored.
        value: The parameter to store.
    """
    # The None guard must run before the container is first dereferenced;
    # calling ``.update`` ahead of it would raise on a None container.
    if self._v is None:
        # initialize parameters container on first use; writing through
        # __dict__ does not go through the class's __setattr__
        self.__dict__["_v"] = {}
    self._v[name] = value

def train(self, mode: bool = True):
for _, module in self.named_modules():
Expand Down Expand Up @@ -945,19 +980,20 @@ def training(self, value):

@property
def v(self):
    """Return the parameter container ``_v``, or a new empty dict if it is None."""
    # A plain ``return self._v`` ahead of this line would make the None
    # fallback unreachable, so there is a single return.
    return self._v if self._v is not None else {}

@property
def buffers(self):
    """Return the buffer container ``_buffers``, or a new empty dict if it is None."""
    # A plain ``return self._buffers`` ahead of this line would make the
    # None fallback unreachable, so there is a single return.
    return self._buffers if self._buffers is not None else {}

@property
def state_dict(self):
    """Return a new plain dict merging ``self.v`` and ``self.buffers``.

    On a key clash the entry from ``self.buffers`` wins, because it is
    unpacked last.
    """
    # ensure we return a plain mapping: both containers are converted with
    # dict() before merging, and this is the only return
    return {**dict(self.v), **dict(self.buffers)}

@property
def module_dict(self):
    """Return the ``_module_dict`` container, or a new empty dict if it is None."""
    # A plain ``return self._module_dict`` ahead of this line would make
    # the None fallback unreachable, so there is a single return.
    return self._module_dict if self._module_dict is not None else {}

# Dunder Methods #
# ---------------#
Expand All @@ -982,8 +1018,18 @@ def __getattr__(self, name):
if name in _dict:
return _dict[name]

elif "_v" in _dict and name in _dict["_v"]:
return _dict["_v"][name]
elif "_v" in _dict and _dict["_v"]:
container = _dict["_v"]
try:
# support nnx.Dict which exposes keys as attributes
sentinel = object()
val = getattr(container, name, sentinel)
if val is not sentinel:
return val
except Exception:
pass
if isinstance(container, dict) and name in container:
return container[name]

return super().__getattribute__(name)

Expand All @@ -1009,6 +1055,8 @@ def __setattr__(self, name, value):
_dict[name] = value

# compute the module dict
if "_module_dict" not in self.__dict__ or self.__dict__.get("_module_dict") is None:
object.__setattr__(self, "_module_dict", {})
self._compute_module_dict()

obj_to_search = (
Expand Down Expand Up @@ -1066,19 +1114,51 @@ def __setattr__(self, name, value):
return
elif isinstance(value, jax.Array):
_dict = getattr(self, "__dict__", None)
if _dict and name in _dict:
orig_value = _dict[name]
if isinstance(orig_value, nnx.Param):
new_value = nnx.Param(value)
_dict[name] = new_value
self.register_parameter(name, new_value)
object.__setattr__(self, name, new_value)
return

# always wrap Arrays as nnx.Param to satisfy strict pytree rules
new_value = nnx.Param(value)
if _dict:
_dict[name] = value
object.__setattr__(self, name, value)
_dict[name] = new_value
self.register_parameter(name, new_value)
object.__setattr__(self, name, new_value)
return
elif value is None:
# keep private/internal attributes static None to avoid data tags inside Pytrees
if name.startswith("_"):
return object.__setattr__(self, name, None)
# for public attrs, explicitly mark as data
_dict = getattr(self, "__dict__", None)
data_value = nnx.data(None)
if _dict:
_dict[name] = data_value
object.__setattr__(self, name, data_value)
return
elif isinstance(value, list):
# wrap lists only if they contain arrays or modules
if ModelHelpers._contains_array(value) or ModelHelpers._contains_module(value):
_dict = getattr(self, "__dict__", None)
list_value = nnx.List(value)
if _dict:
_dict[name] = list_value
object.__setattr__(self, name, list_value)
return
elif isinstance(value, tuple):
# tuples remain static unless they contain arrays or modules
if ModelHelpers._contains_array(value) or ModelHelpers._contains_module(value):
_dict = getattr(self, "__dict__", None)
list_value = nnx.List(list(value))
if _dict:
_dict[name] = list_value
object.__setattr__(self, name, list_value)
return
elif isinstance(value, dict):
# wrap dicts only if they contain arrays or modules
if ModelHelpers._contains_array(value) or ModelHelpers._contains_module(value):
_dict = getattr(self, "__dict__", None)
dict_value = nnx.Dict(value)
if _dict:
_dict[name] = dict_value
object.__setattr__(self, name, dict_value)
return
else:
try:
obj_to_search = getattr(self, name)
Expand Down Expand Up @@ -1107,6 +1187,8 @@ def __setattr__(self, name, value):
submod.register_buffer(b_key, value)

# finally update the module dict
if "_module_dict" not in self.__dict__ or self.__dict__.get("_module_dict") is None:
object.__setattr__(self, "_module_dict", {})
self._module_dict[name] = value

# TODO: super().__setattr__ leads to an error during jax.jit
Expand All @@ -1131,14 +1213,57 @@ def _find_variables(
if isinstance(obj, (Module)) and obj is not self:
fn = "_build_and_return_v" if trainable else "_build_and_return_buffers"
if not obj._built and without_initialisation:
obj_kwargs = obj._kwargs if isinstance(obj._kwargs, dict) else {}
return lambda: getattr(obj, fn)(
*obj._args, dynamic_backend=self._dynamic_backend, **obj._kwargs
*obj._args, dynamic_backend=self._dynamic_backend, **obj_kwargs
)
obj_kwargs = obj._kwargs if isinstance(obj._kwargs, dict) else {}
return getattr(obj, fn)(
*obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs
*obj._args, dynamic_backend=obj._dynamic_backend, **obj_kwargs
)
elif isinstance(obj, nnx.Module) and obj is not self:
return obj.v if trainable else obj.buffers
# Some nnx containers (e.g., nnx.Dict/nnx.List) are Pytree Modules but
# do not expose `.v`/`.buffers`. Treat them as plain containers.
try:
return obj.v if trainable else obj.buffers
except AttributeError:
# handle container-like nnx modules here
# nnx.Dict or dict-like
is_nnx_dict = hasattr(nnx, "Dict") and isinstance(obj, nnx.Dict)
is_nnx_list = hasattr(nnx, "List") and isinstance(obj, nnx.List)
if is_nnx_dict or isinstance(obj, dict):
try:
items_iter = obj.items() if hasattr(obj, "items") else dict(obj).items()
except Exception:
return {}
for k, v_child in items_iter:
ret = self._find_variables(
obj=v_child,
without_initialisation=without_initialisation,
_visited=_visited,
trainable=trainable,
)
if ret:
vs[k[1:] if isinstance(k, str) and k and k[0] == "_" else k] = ret
return vs
# nnx.List/list/tuple-like
if is_nnx_list or isinstance(obj, (list, tuple)):
try:
seq = list(obj)
except Exception:
return {}
for i, v_child in enumerate(seq):
ret = self._find_variables(
obj=v_child,
without_initialisation=without_initialisation,
_visited=_visited,
trainable=trainable,
)
if ret:
vs[f"v{str(i)}"] = ret
return vs
# unknown nnx.Module without v/buffers
return {}

elif isinstance(obj, (list, tuple)):
for i, v in enumerate(obj):
Expand Down
64 changes: 54 additions & 10 deletions ivy/functional/frontends/torch/nn/modules/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,16 +449,43 @@ def __getattribute__(self, name: str) -> Any:
return super().__getattribute__(name)
if "_module_dict" in self.__dict__:
modules = self.__dict__["_module_dict"]
if name in modules:
return modules[name]
try:
if isinstance(modules, dict):
if name in modules:
return modules[name]
else:
sentinel = object()
val = getattr(modules, name, sentinel)
if val is not sentinel:
return val
except Exception:
pass
if "_buffers" in self.__dict__:
buffers = self.__dict__["_buffers"]
if name in buffers:
return buffers[name]
try:
if isinstance(buffers, dict):
if name in buffers:
return buffers[name]
else:
sentinel = object()
val = getattr(buffers, name, sentinel)
if val is not sentinel:
return val
except Exception:
pass
if "_v" in self.__dict__:
v = self.__dict__["_v"]
if name in v:
return v[name]
try:
if isinstance(v, dict):
if name in v:
return v[name]
else:
sentinel = object()
val = getattr(v, name, sentinel)
if val is not sentinel:
return val
except Exception:
pass
# Adding this attribute mapping s.t if someone tries
# to retrieve self._modules/self._parameters, we
# can handle that here
Expand All @@ -478,7 +505,15 @@ def remove_from(*dicts_or_sets):
d.discard(name)

params = self.__dict__.get("_v")
if params is not None and name in params and isinstance(value, Parameter):
def _has_key_like(container, key):
if isinstance(container, dict):
return key in container
try:
return hasattr(container, key)
except Exception:
return False

if params is not None and _has_key_like(params, name) and isinstance(value, Parameter):
remove_from(self.__dict__, self._buffers, self._module_dict)
self.register_parameter(name, value)
super().__setattr__(name, value)
Expand Down Expand Up @@ -513,9 +548,18 @@ def __repr__(self):
def __dir__(self):
module_attrs = dir(self.__class__)
attrs = list(self.__dict__.keys())
parameters = list(self._v.keys())
modules = list(self._module_dict.keys())
buffers = list(self._buffers.keys())
def _keys(container):
try:
return list(container.keys())
except Exception:
try:
return [k for k in dir(container) if not k.startswith("_")]
except Exception:
return []

parameters = _keys(self._v)
modules = _keys(self._module_dict)
buffers = _keys(self._buffers)
keys = module_attrs + attrs + parameters + modules + buffers

# Eliminate attrs that are not legal Python variable names
Expand Down
Loading