diff --git a/ivy/functional/backends/jax/module.py b/ivy/functional/backends/jax/module.py index ba078d0b02594..cde2b841d7ad1 100644 --- a/ivy/functional/backends/jax/module.py +++ b/ivy/functional/backends/jax/module.py @@ -698,8 +698,36 @@ def _addindent(s_, numSpaces): s = first + "\n" + s return s + @staticmethod + def _contains_array(value: Any) -> bool: + try: + for leaf in tree.tree_leaves(value): + if isinstance(leaf, (jax.Array, getattr(nnx, "Variable", object), nnx.Param)): + return True + return False + except Exception: + return False + + @staticmethod + def _contains_module(value: Any) -> bool: + # recursively inspect common python containers for modules + try: + if isinstance(value, (Module, nnx.Module)): + return True + if isinstance(value, (list, tuple)): + return any(ModelHelpers._contains_module(v) for v in value) + if isinstance(value, dict): + return any(ModelHelpers._contains_module(v) for v in value.values()) + except Exception: + return False + return False + class Module(nnx.Module, ModelHelpers, TorchModuleHelpers): + # mark private containers that hold arrays as data for nnx 0.12 strict pytree + _v: nnx.Data[dict] + _buffers: nnx.Data[dict] + _module_dict: nnx.Data[dict] _build_mode = None _with_partial_v = None _store_vars = True @@ -739,7 +767,8 @@ def __init__( self._store_vars = store_vars self._built = False self._v_from_constructor = v if isinstance(v, dict) or v is None else dict(v) - self._v = v if v is not None else dict() + # keep internal containers as plain dicts; type annotations mark them as data + self._v = v if isinstance(v, dict) else (v if v is not None else {}) self._buffers = dict(buffers or {}) self._module_dict = module_dict if module_dict is not None else dict() self._args = args @@ -768,11 +797,17 @@ def build( return def register_buffer(self, name: str, value: jax.Array, persistent: bool = False): - self._buffers.update({name: value}) + if self._buffers is None: + # initialize buffers container on first use + self.__dict__["_buffers"] = {} + self._buffers[name] = value return value def register_parameter(self, name: str, value: jax.Array): - self._v.update({name: value}) + if self._v is None: + # initialize parameters container on first use + self.__dict__["_v"] = {} + self._v[name] = value def train(self, mode: bool = True): for _, module in self.named_modules(): @@ -945,19 +980,20 @@ def training(self, value): @property def v(self): - return self._v + return self._v if self._v is not None else {} @property def buffers(self): - return self._buffers + return self._buffers if self._buffers is not None else {} @property def state_dict(self): - return {**self.v, **self.buffers} + # ensure we return a plain mapping + return {**dict(self.v), **dict(self.buffers)} @property def module_dict(self): - return self._module_dict + return self._module_dict if self._module_dict is not None else {} # Dunder Methods # # ---------------# @@ -982,8 +1018,18 @@ def __getattr__(self, name): if name in _dict: return _dict[name] - elif "_v" in _dict and name in _dict["_v"]: - return _dict["_v"][name] + elif "_v" in _dict and _dict["_v"]: + container = _dict["_v"] + try: + # support nnx.Dict which exposes keys as attributes + sentinel = object() + val = getattr(container, name, sentinel) + if val is not sentinel: + return val + except Exception: + pass + if isinstance(container, dict) and name in container: + return container[name] return super().__getattribute__(name) @@ -1009,6 +1055,8 @@ def __setattr__(self, name, value): _dict[name] = value # compute the module dict + if "_module_dict" not in self.__dict__ or self.__dict__.get("_module_dict") is None: + object.__setattr__(self, "_module_dict", {}) self._compute_module_dict() obj_to_search = ( @@ -1066,19 +1114,51 @@ def __setattr__(self, name, value): return elif isinstance(value, jax.Array): _dict = getattr(self, "__dict__", None) - if _dict and name in _dict: - orig_value = _dict[name] - if isinstance(orig_value, nnx.Param): - new_value = nnx.Param(value) - _dict[name] = new_value - self.register_parameter(name, new_value) - object.__setattr__(self, name, new_value) - return - + # always wrap Arrays as nnx.Param to satisfy strict pytree rules + new_value = nnx.Param(value) if _dict: - _dict[name] = value - object.__setattr__(self, name, value) + _dict[name] = new_value + self.register_parameter(name, new_value) + object.__setattr__(self, name, new_value) return + elif value is None: + # keep private/internal attributes static None to avoid data tags inside Pytrees + if name.startswith("_"): + return object.__setattr__(self, name, None) + # for public attrs, explicitly mark as data + _dict = getattr(self, "__dict__", None) + data_value = nnx.data(None) + if _dict: + _dict[name] = data_value + object.__setattr__(self, name, data_value) + return + elif isinstance(value, list): + # wrap lists only if they contain arrays or modules + if ModelHelpers._contains_array(value) or ModelHelpers._contains_module(value): + _dict = getattr(self, "__dict__", None) + list_value = nnx.List(value) + if _dict: + _dict[name] = list_value + object.__setattr__(self, name, list_value) + return + elif isinstance(value, tuple): + # tuples remain static unless they contain arrays or modules + if ModelHelpers._contains_array(value) or ModelHelpers._contains_module(value): + _dict = getattr(self, "__dict__", None) + list_value = nnx.List(list(value)) + if _dict: + _dict[name] = list_value + object.__setattr__(self, name, list_value) + return + elif isinstance(value, dict): + # wrap dicts only if they contain arrays or modules + if ModelHelpers._contains_array(value) or ModelHelpers._contains_module(value): + _dict = getattr(self, "__dict__", None) + dict_value = nnx.Dict(value) + if _dict: + _dict[name] = dict_value + object.__setattr__(self, name, dict_value) + return else: try: obj_to_search = getattr(self, name) @@ -1107,6 +1187,8 @@ def __setattr__(self, name, value): submod.register_buffer(b_key, value) # finally update the module dict + if "_module_dict" not in self.__dict__ or self.__dict__.get("_module_dict") is None: + object.__setattr__(self, "_module_dict", {}) self._module_dict[name] = value # TODO: super().__setattr__ leads to an error during jax.jit @@ -1131,14 +1213,57 @@ def _find_variables( if isinstance(obj, (Module)) and obj is not self: fn = "_build_and_return_v" if trainable else "_build_and_return_buffers" if not obj._built and without_initialisation: + obj_kwargs = obj._kwargs if isinstance(obj._kwargs, dict) else {} return lambda: getattr(obj, fn)( - *obj._args, dynamic_backend=self._dynamic_backend, **obj._kwargs + *obj._args, dynamic_backend=self._dynamic_backend, **obj_kwargs ) + obj_kwargs = obj._kwargs if isinstance(obj._kwargs, dict) else {} return getattr(obj, fn)( - *obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs + *obj._args, dynamic_backend=obj._dynamic_backend, **obj_kwargs ) elif isinstance(obj, nnx.Module) and obj is not self: - return obj.v if trainable else obj.buffers + # Some nnx containers (e.g., nnx.Dict/nnx.List) are Pytree Modules but + # do not expose `.v`/`.buffers`. Treat them as plain containers. + try: + return obj.v if trainable else obj.buffers + except AttributeError: + # handle container-like nnx modules here + # nnx.Dict or dict-like + is_nnx_dict = hasattr(nnx, "Dict") and isinstance(obj, nnx.Dict) + is_nnx_list = hasattr(nnx, "List") and isinstance(obj, nnx.List) + if is_nnx_dict or isinstance(obj, dict): + try: + items_iter = obj.items() if hasattr(obj, "items") else dict(obj).items() + except Exception: + return {} + for k, v_child in items_iter: + ret = self._find_variables( + obj=v_child, + without_initialisation=without_initialisation, + _visited=_visited, + trainable=trainable, + ) + if ret: + vs[k[1:] if isinstance(k, str) and k and k[0] == "_" else k] = ret + return vs + # nnx.List/list/tuple-like + if is_nnx_list or isinstance(obj, (list, tuple)): + try: + seq = list(obj) + except Exception: + return {} + for i, v_child in enumerate(seq): + ret = self._find_variables( + obj=v_child, + without_initialisation=without_initialisation, + _visited=_visited, + trainable=trainable, + ) + if ret: + vs[f"v{str(i)}"] = ret + return vs + # unknown nnx.Module without v/buffers + return {} elif isinstance(obj, (list, tuple)): for i, v in enumerate(obj): diff --git a/ivy/functional/frontends/torch/nn/modules/module.py b/ivy/functional/frontends/torch/nn/modules/module.py index 5c0f3d9de8723..97bc37b50be76 100644 --- a/ivy/functional/frontends/torch/nn/modules/module.py +++ b/ivy/functional/frontends/torch/nn/modules/module.py @@ -449,16 +449,43 @@ def __getattribute__(self, name: str) -> Any: return super().__getattribute__(name) if "_module_dict" in self.__dict__: modules = self.__dict__["_module_dict"] - if name in modules: - return modules[name] + try: + if isinstance(modules, dict): + if name in modules: + return modules[name] + else: + sentinel = object() + val = getattr(modules, name, sentinel) + if val is not sentinel: + return val + except Exception: + pass if "_buffers" in self.__dict__: buffers = self.__dict__["_buffers"] - if name in buffers: - return buffers[name] + try: + if isinstance(buffers, dict): + if name in buffers: + return buffers[name] + else: + sentinel = object() + val = getattr(buffers, name, sentinel) + if val is not sentinel: + return val + except Exception: + pass if "_v" in self.__dict__: v = self.__dict__["_v"] - if name in v: - return v[name] + try: + if isinstance(v, dict): + if name in v: + return v[name] + else: + sentinel = object() + val = getattr(v, name, sentinel) + if val is not sentinel: + return val + except Exception: + pass # Adding this attribute mapping s.t if someone tries # to retrieve self._modules/self._parameters, we # can handle that here @@ -478,7 +505,15 @@ def remove_from(*dicts_or_sets): d.discard(name) params = self.__dict__.get("_v") - if params is not None and name in params and isinstance(value, Parameter): + def _has_key_like(container, key): + if isinstance(container, dict): + return key in container + try: + return hasattr(container, key) + except Exception: + return False + + if params is not None and _has_key_like(params, name) and isinstance(value, Parameter): remove_from(self.__dict__, self._buffers, self._module_dict) self.register_parameter(name, value) super().__setattr__(name, value) @@ -513,9 +548,18 @@ def __repr__(self): def __dir__(self): module_attrs = dir(self.__class__) attrs = list(self.__dict__.keys()) - parameters = list(self._v.keys()) - modules = list(self._module_dict.keys()) - buffers = list(self._buffers.keys()) + def _keys(container): + try: + return list(container.keys()) + except Exception: + try: + return [k for k in dir(container) if not k.startswith("_")] + except Exception: + return [] + + parameters = _keys(self._v) + modules = _keys(self._module_dict) + buffers = _keys(self._buffers) keys = module_attrs + attrs + parameters + modules + buffers # Eliminate attrs that are not legal Python variable names