
Conversation

@josephdviviano (Collaborator) commented Dec 11, 2025

  • I've read the .github/CONTRIBUTING.md file
  • My code follows the typing guidelines
  • I've added appropriate tests
  • I've run pre-commit hooks locally

Description

  • Added a debug flag to all envs, which gets passed internally to states / actions.
  • This replaces validate_actions.
  • Moves torch.compile()-breaking ops (e.g., asserts) out of the hot path when debug=False.
  • In theory this should slightly speed up non-compiled code as well.
  • Added a benchmark to test the current configuration.
  • The current changes are not enough to actually speed up torch.compile performance on these small benchmarks, but they do speed up non-compiled performance marginally, and I think this is a good first step towards a general optimization of the library's speed (without introducing confusing overhead).
  • See the benchmark in examples/notebooks.
  • I also completed the testing of the various example scripts (we were missing a few) and, as a consequence, fixed a few bugs we had missed.

I know the diff is large in this one, but a lot of it is simple changes to function calls.

The basic idea is that a user tests their code with debug=True, then flips it off for large-scale training runs.
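
A minimal sketch of the intended workflow (the HyperGrid import path and kwargs below are assumptions for illustration; any env exposing the new debug flag would be used the same way):

from gfn.gym import HyperGrid  # assumed import path

# During development: validation (asserts, action checks) runs on every step.
env = HyperGrid(ndim=2, height=8, debug=True)

# For large-scale or compiled runs: validation is skipped in the hot path.
env = HyperGrid(ndim=2, height=8, debug=False)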

@josephdviviano self-assigned this Dec 11, 2025
@younik (Collaborator) left a comment

LGTM

All asserts are already disabled with -O, so we don't need if self.debug in most places
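
For context, a minimal illustration of the -O behavior referenced here (the function below is illustrative): running python -O sets __debug__ to False and strips assert statements at compile time, so plain asserts already cost nothing in optimized runs.

def validate_batch(x):
    # Removed entirely under `python -O`; only explicit if/raise checks survive.
    assert x.ndim == 2, "expected a batch of states"
    return x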

Collaborator

This is not a notebook though :D
Maybe move it to examples?

Collaborator Author

It's used by the notebook, and I don't think it's really an example either. Maybe we can do a PR where we reorganize these folders.

Collaborator Author

I ended up moving these to their own misc folder to keep them distinct from the tutorial notebooks.

Comment on lines +28 to +44
def _assert_factory_accepts_debug(factory: Callable, factory_name: str) -> None:
    """Ensure the factory can accept a debug kwarg (explicit or via **kwargs)."""
    try:
        sig = inspect.signature(factory)
    except (TypeError, ValueError):
        # Some callables have no inspectable signature; skip the check for those.
        return

    params = sig.parameters
    # Either a **kwargs parameter or an explicit `debug` parameter satisfies the contract.
    if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return
    debug_param = params.get("debug")
    if debug_param is not None:
        return
    raise TypeError(
        f"{factory_name} must accept a `debug` keyword argument (or **kwargs) "
        "to support debug-gated States construction."
    )
Collaborator

Do we need this? The method will throw an error anyway if we pass debug where it isn't accepted (since we always use kwargs).

Collaborator Author

This is meant to help env designers remember to define the debug wiring. It's true it's not strictly needed, but I think it's useful for users of the library.
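
A minimal sketch of what this check accepts and rejects (the factory names below are illustrative):

def factory_with_debug(batch_shape, debug: bool = False):  # explicit debug kwarg
    ...

def factory_with_kwargs(batch_shape, **kwargs):  # catches debug via **kwargs
    ...

def legacy_factory(batch_shape):  # neither debug nor **kwargs
    ...

_assert_factory_accepts_debug(factory_with_debug, "factory_with_debug")    # passes
_assert_factory_accepts_debug(factory_with_kwargs, "factory_with_kwargs")  # passes
_assert_factory_accepts_debug(legacy_factory, "legacy_factory")            # raises TypeError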

Comment on lines +774 to +779
exit_mask = torch.zeros(
    self.batch_shape + (1,), device=cond.device, dtype=cond.dtype
)

if not allow_exit:
    exit_mask.fill_(True)
Collaborator

might simply be: torch.full(..., not allow_exit, ...)

Collaborator Author

In retrospect I'm not sure why I went this way, but it had to do with torch.compile compatibility.
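
For reference, the one-liner suggested above would read as follows (same resulting tensor; whether it interacts equally well with torch.compile in this path is the open question):

exit_mask = torch.full(
    self.batch_shape + (1,), not allow_exit, device=cond.device, dtype=cond.dtype
)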

@josephdviviano (Collaborator Author)

We need to explicitly move those paths under a conditional if we want to use torch.compile.
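
A minimal sketch of that pattern, assuming a debug attribute on the env (the class and check below are illustrative, not the library's API):

import torch

class ToyEnv:
    def __init__(self, debug: bool = False):
        self.debug = debug

    def step(self, states: torch.Tensor) -> torch.Tensor:
        if self.debug:
            # Data-dependent checks like this would force a graph break under
            # torch.compile; gating them on a plain Python bool means the branch
            # is specialized away entirely when debug=False.
            if not bool(torch.all(states >= 0)):
                raise ValueError("negative state encountered")
        return states + 1

env = ToyEnv(debug=False)
compiled_step = torch.compile(env.step)  # compiled hot path contains no validation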

@josephdviviano merged commit b1bb28c into master Dec 11, 2025
3 checks passed