Using pretrained weights with Ninjax #142

Open
truncs opened this issue Jun 26, 2024 · 0 comments
truncs commented Jun 26, 2024

Let's say you want to pass image observations through a pretrained image encoder. How would that work with Ninjax? Below is a class that loads a ViT model. I don't add it to self.modules in Agent, so that its parameters are not optimized. However, it errors out at self.get('flax') with 'ValueError: not enough values to unpack (expected at least 1, got 0)'.

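import functools

import torch

from . import ninjax as nj


class ViT(nj.Module):
  num_heads: int = 6
  embed_dim: int = 384
  mlp_ratio: int = 4
  img_size: int = 70

  def __call__(self, x, training=False):
    if nj.creating():
      # Only runs while Ninjax is creating state: build the Flax ViT and
      # initialize its variables.
      from .vit import ViT
      from .dino_weights import load_vit_params
      vit_cls = functools.partial(
        ViT,
        num_heads=self.num_heads,
        embed_dim=self.embed_dim,
        mlp_ratio=self.mlp_ratio,
        depth=12,
        img_size=self.img_size,
      )
      vit_def = vit_cls()
      state = vit_def.init(nj.seed(), x)

      # Load the pretrained PyTorch model and copy its weights into the
      # Flax parameters.
      vit = torch.hub.load('pytorch/vision:v0.14.0', 'vit_b_16', pretrained=True)

      state['params'] = load_vit_params(state['params'], vit)
      self.module = vit_def
      self.put('flax', state)
    # This is the call that fails: no constructor is passed, so the 'flax'
    # entry must already exist in the state.
    state = self.get('flax')
    return self.module.apply(state, x, training)
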
/zfs/aditya/workspace/dreamerv3/dreamerv3/nets.py:309 in __call__                                  │
│                                                                                                  │
│    306 │     state['params'] = load_vit_params(state['params'], vit)                             │
│    307 │     self.module = vit_def                                                               │
│    308 │     self.put('flax', state)                                                             │
│ ❱  309 │   state = self.get('flax')                                                              │
│    310 │   return self.module.apply(state, x, training)                                          │
│    311                                                                                           │
│    312 class ViT(nj.Module):                                                                     │
│                                                                                                  │
│ /zfs/aditya/workspace/dreamerv3/dreamerv3/ninjax.py:467 in wrapper                               │
│                                                                                                  │
│   464   def wrapper(self, *args, **kwargs):                                                      │
│   465 │   with scope(self._path, absolute=True):                                                 │
│   466 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 467 │   │   return method(self, *args, **kwargs)                                               │
│   468   return wrapper                                                                           │
│   469                                                                                            │
│   470                                                                                            │
│                                                                                                  │
│ /zfs/aditya/workspace/dreamerv3/dreamerv3/ninjax.py:496 in get                                   │
│                                                                                                  │
│   493 │     return self._submodules[name]                                                        │
│   494 │   if path in context():                                                                  │
│   495 │     return context()[path]                                                               │
│ ❱ 496 │   ctor, *args = args                                                                     │
│   497 │   if 'name' in inspect.signature(ctor).parameters:                                       │
│   498 │     kwargs['name'] = name                                                                │
│   499 │   value = ctor(*args, **kwargs)                                                          │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ValueError: not enough values to unpack (expected at least 1, got 0)
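
The last traceback frame can be reproduced in isolation. Below is a minimal sketch, assuming a plain dict stands in for the Ninjax state context. The `get` function here is a simplified stand-in modeled on the lines shown in the traceback, not the library's actual implementation. It suggests the error means the 'flax' entry was not found in the context, and since no constructor was passed to `get`, there was nothing to unpack.

```python
def get(context, path, *args, **kwargs):
  """Simplified stand-in for the lookup in the traceback (not ninjax itself)."""
  if path in context:
    return context[path]
  # With no positional arguments this unpacking raises the ValueError above.
  ctor, *args = args
  value = ctor(*args, **kwargs)
  context[path] = value
  return value


# The entry exists, so the lookup succeeds without a constructor.
print(get({'vit/flax': 'state'}, 'vit/flax'))

# The entry is missing and no constructor is given, so unpacking fails.
try:
  get({}, 'vit/flax')
except ValueError as e:
  print(e)
```

So the question becomes why the 'flax' entry is absent from the state at the point where the policy runs.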

Tracing seems to work fine, but after the initial checkpoint is saved and the policy is then executed, this call breaks.
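
One way to narrow this down is to compare the state keys that exist right after creation with the keys available when the policy runs. The sketch below is purely illustrative and unverified against dreamerv3: the key names, the regex, and the idea that the policy receives a filtered subset of the state are all assumptions, and `filter_state` and `missing_keys` are hypothetical helpers, not Ninjax functions.

```python
import re


def filter_state(state, pattern):
  """Keep only entries whose key matches the regex (hypothetical filter)."""
  return {k: v for k, v in state.items() if re.match(pattern, k)}


def missing_keys(created, available):
  """Return keys that existed after creation but are absent later."""
  return sorted(set(created) - set(available))


# State as it might look right after the creation trace (keys are made up).
created = {'agent/enc/vit/flax': 'pretrained', 'agent/actor/kernel': 'w'}

# Assumption: the policy only receives entries matching this pattern.
policy_state = filter_state(created, r'agent/actor/')

print(missing_keys(created, policy_state))
```

If a comparison like this shows the 'flax' entry missing from the state passed to the policy, that would be consistent with tracing succeeding while the later call fails.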
