diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 476c9d71654..ef43829ef32 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -333,6 +333,47 @@ def reset(cls): cls.erase() +def _maybe_timeit(name): + """Return timeit context if not compiling, nullcontext otherwise. + + torch.compiler.is_compiling() returns True when inside a compiled region, + and timeit uses time.time() which dynamo cannot trace. + """ + if is_compiling(): + return nullcontext() + return timeit(name) + + +def _maybe_record_function(name): + """Return record_function context if not compiling, nullcontext otherwise. + + torch.autograd.profiler.record_function cannot be used inside compiled regions. + """ + from torch.autograd.profiler import record_function + + if is_compiling(): + return nullcontext() + return record_function(name) + + +def _maybe_record_function_decorator(name: str) -> Callable[[Callable], Callable]: + """Decorator version of :func:`_maybe_record_function`. + + This is preferred over sprinkling many context managers in hot code paths, + as it reduces Python overhead while keeping a useful profiler structure. + """ + + def decorator(fn: Callable) -> Callable: + @wraps(fn) + def wrapped(*args, **kwargs): + with _maybe_record_function(name): + return fn(*args, **kwargs) + + return wrapped + + return decorator + + def _check_for_faulty_process(processes): terminate = False for p in processes: