You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Lets say you want to pass image obs with pretrained image encoders. How would that work with Ninjax. I have a class that loads a ViT model below that I had don't add to self.modules in Agent so that the parameters are not optimized. But it errors out at self.get('flax') saying 'ValueError: not enough values to unpack (expected at least 1, got 0)'
class ViT(nj.Module):
num_heads: int = 6
embed_dim: int = 384
mlp_ratio: int = 4
img_size: int = 70
def __call__(self, x, training=False):
if nj.creating():
from .vit import ViT
from .dino_weights import load_vit_params
vit_cls = functools.partial(
ViT,
num_heads=self.num_heads,
embed_dim=self.embed_dim,
mlp_ratio=self.mlp_ratio,
depth=12,
img_size=self.img_size,
)
vit_def = vit_cls()
state = vit_def.init(nj.seed(), x)
vit = torch.hub.load('pytorch/vision:v0.14.0', 'vit_b_16', pretrained=True)
state['params'] = load_vit_params(state['params'], vit)
self.module = vit_def
self.put('flax', state)
state = self.get('flax')
return self.module.apply(state, x, training)
/zfs/aditya/workspace/dreamerv3/dreamerv3/nets.py:309 in __call__ │
│ │
│ 306 │ state['params'] = load_vit_params(state['params'], vit) │
│ 307 │ self.module = vit_def │
│ 308 │ self.put('flax', state) │
│ ❱ 309 │ state = self.get('flax') │
│ 310 │ return self.module.apply(state, x, training) │
│ 311 │
│ 312 class ViT(nj.Module): │
│ │
│ /zfs/aditya/workspace/dreamerv3/dreamerv3/ninjax.py:467 in wrapper │
│ │
│ 464 def wrapper(self, *args, **kwargs): │
│ 465 │ with scope(self._path, absolute=True): │
│ 466 │ with jax.named_scope(self._path.split('/')[-1]): │
│ ❱ 467 │ │ return method(self, *args, **kwargs) │
│ 468 return wrapper │
│ 469 │
│ 470 │
│ │
│ /zfs/aditya/workspace/dreamerv3/dreamerv3/ninjax.py:496 in get │
│ │
│ 493 │ return self._submodules[name] │
│ 494 │ if path in context(): │
│ 495 │ return context()[path] │
│ ❱ 496 │ ctor, *args = args │
│ 497 │ if 'name' in inspect.signature(ctor).parameters: │
│ 498 │ kwargs['name'] = name │
│ 499 │ value = ctor(*args, **kwargs) │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ValueError: not enough values to unpack (expected at least 1, got 0)
It seems that the tracing works fine but after the initial checkpoint is saved and then executed with the policy this just breaks.
The text was updated successfully, but these errors were encountered:
Lets say you want to pass image obs with pretrained image encoders. How would that work with Ninjax. I have a class that loads a ViT model below that I had don't add to self.modules in Agent so that the parameters are not optimized. But it errors out at self.get('flax') saying 'ValueError: not enough values to unpack (expected at least 1, got 0)'
It seems that the tracing works fine but after the initial checkpoint is saved and then executed with the policy this just breaks.
The text was updated successfully, but these errors were encountered: