diff --git a/policy_models/training/trainers.py b/policy_models/training/trainers.py index 3d27794..73bae1f 100644 --- a/policy_models/training/trainers.py +++ b/policy_models/training/trainers.py @@ -294,6 +294,20 @@ def log(self, obj: Dict[str, Any]): def close(self): self.f.close() + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + return False + + def __del__(self): + try: + if not self.f.closed: + self.f.close() + except Exception: + pass + class Checkpointer: def __init__(self, out_dir: str, pv_dir: str):