-
Notifications
You must be signed in to change notification settings - Fork 53
Optimize states and actions compute paths with debug=False mode in Env
#449
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…t new env developers, augmented documentation
…ts to pass debug to the env properly (and to handle this properly in the tests) - fixed one failing test.
younik
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
All asserts are already disabled with -O, so we don't need if self.debug in most places
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is not a notebook tho :D
Maybe move it in examples?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it's used by the notebook, i don't think it's really an example either - maybe we can do a PR where we reorganize these folders.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i ended up moving these to their own misc folder so keep them distinct from the tutorial notebooks
| def _assert_factory_accepts_debug(factory: Callable, factory_name: str) -> None: | ||
| """Ensure the factory can accept a debug kwarg (explicit or via **kwargs).""" | ||
| try: | ||
| sig = inspect.signature(factory) | ||
| except (TypeError, ValueError): | ||
| return | ||
|
|
||
| params = sig.parameters | ||
| if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()): | ||
| return | ||
| debug_param = params.get("debug") | ||
| if debug_param is not None: | ||
| return | ||
| raise TypeError( | ||
| f"{factory_name} must accept a `debug` keyword argument (or **kwargs) " | ||
| "to support debug-gated States construction." | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need this? The method will anyway throw an error if we pass debug while we can't (as we always use kwargs)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is meant to help env designers remember to define the debug wiring. It's true it's not needed but i think it's useful for the users of the library.
| exit_mask = torch.zeros( | ||
| self.batch_shape + (1,), device=cond.device, dtype=cond.dtype | ||
| ) | ||
|
|
||
| if not allow_exit: | ||
| exit_mask.fill_(True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
might simply be: torch.full(..., not allow_exit, ...)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i'm not sure why I went this way in retrospect, but it had to do with torch.compile compatibility.
|
We need to explicitly move those paths under a conditional if you want to use torch.compile |
Description
debugflag to all envs, which get passed internally to states / actions.validate_actions.torch.compile()breaking ops outside of the hot path whendebug=False(e.g., asserts).torch.compileperformance in these small benchmarks, but they do speed up non-compiled performance marginally, and I think this is a good first step towards general optimization of the speed of the library (without introducing confusing overhead).examples/notebooks.I know the diffs are large in this one but a lot of it is simple changes to function calls.
The basic idea is a user would test their code with debug=True, then flip it off for large scale training runs.