diff --git a/python/keepsake/pl_callback.py b/python/keepsake/pl_callback.py index 7bbfedae..6a379686 100644 --- a/python/keepsake/pl_callback.py +++ b/python/keepsake/pl_callback.py @@ -1,5 +1,6 @@ from copy import deepcopy from typing import Optional, Dict, Tuple, Any +from pathlib import Path import keepsake from pytorch_lightning.callbacks.base import Callback @@ -29,6 +30,7 @@ class KeepsakeCallback(Callback): def __init__( self, filepath="model.pth", + experiment_path=".", params: Optional[Dict[str, Any]] = None, primary_metric: Optional[Tuple[str, str]] = None, period: Optional[int] = 1, @@ -56,7 +58,8 @@ def __init__( """ super().__init__() - self.filepath = filepath + self.filepath = Path(filepath).resolve() + self.experiment_path = Path(experiment_path).resolve() self.params = params self.primary_metric = primary_metric self.period = period @@ -64,7 +67,10 @@ def __init__( self.last_global_step_saved = -1 def on_pretrain_routine_start(self, trainer, pl_module): - self.experiment = keepsake.init(path=".", params=self.params) + self.experiment = keepsake.init( + path=str(self.experiment_path), + params=self.params, + ) def on_epoch_end(self, trainer, pl_module): self._save_model(trainer, pl_module) @@ -89,7 +95,7 @@ def _save_model(self, trainer, pl_module): return if self.filepath != None: - trainer.save_checkpoint(self.filepath, self.save_weights_only) + trainer.save_checkpoint(self.filepath.name, self.save_weights_only) self.last_global_step_saved = global_step @@ -99,7 +105,7 @@ def _save_model(self, trainer, pl_module): metrics.update({"global_step": trainer.global_step}) self.experiment.checkpoint( - path=self.filepath, + path=self.filepath.name, step=trainer.current_epoch, metrics=metrics, primary_metric=self.primary_metric,