From 169ba3fbab53c7dbe9af6489fba501e631e7dc96 Mon Sep 17 00:00:00 2001 From: Adam Fishman Date: Sat, 20 Mar 2021 21:45:18 -0700 Subject: [PATCH 1/2] Added ability to set directory in pytorch lightning Signed-off-by: Adam Fishman --- python/keepsake/pl_callback.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/keepsake/pl_callback.py b/python/keepsake/pl_callback.py index 7bbfedae..56089de3 100644 --- a/python/keepsake/pl_callback.py +++ b/python/keepsake/pl_callback.py @@ -1,5 +1,6 @@ from copy import deepcopy from typing import Optional, Dict, Tuple, Any +from pathlib import Path import keepsake from pytorch_lightning.callbacks.base import Callback @@ -56,7 +57,7 @@ def __init__( """ super().__init__() - self.filepath = filepath + self.filepath = Path(filepath).resolve() self.params = params self.primary_metric = primary_metric self.period = period @@ -64,7 +65,10 @@ def __init__( self.last_global_step_saved = -1 def on_pretrain_routine_start(self, trainer, pl_module): - self.experiment = keepsake.init(path=".", params=self.params) + self.experiment = keepsake.init( + path=str(self.filepath.parent), + params=self.params, + ) def on_epoch_end(self, trainer, pl_module): self._save_model(trainer, pl_module) @@ -89,7 +93,7 @@ def _save_model(self, trainer, pl_module): return if self.filepath != None: - trainer.save_checkpoint(self.filepath, self.save_weights_only) + trainer.save_checkpoint(self.filepath.name, self.save_weights_only) self.last_global_step_saved = global_step @@ -99,7 +103,7 @@ def _save_model(self, trainer, pl_module): metrics.update({"global_step": trainer.global_step}) self.experiment.checkpoint( - path=self.filepath, + path=self.filepath.name, step=trainer.current_epoch, metrics=metrics, primary_metric=self.primary_metric, From 455df288729266051473b8a636a223e05245014e Mon Sep 17 00:00:00 2001 From: Adam Fishman Date: Sun, 4 Apr 2021 16:01:16 -0700 Subject: [PATCH 2/2] Updated per @andreasjansson's suggestion --- python/keepsake/pl_callback.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/keepsake/pl_callback.py b/python/keepsake/pl_callback.py index 56089de3..6a379686 100644 --- a/python/keepsake/pl_callback.py +++ b/python/keepsake/pl_callback.py @@ -30,6 +30,7 @@ class KeepsakeCallback(Callback): def __init__( self, filepath="model.pth", + experiment_path=".", params: Optional[Dict[str, Any]] = None, primary_metric: Optional[Tuple[str, str]] = None, period: Optional[int] = 1, @@ -58,6 +59,7 @@ def __init__( super().__init__() self.filepath = Path(filepath).resolve() + self.experiment_path = Path(experiment_path).resolve() self.params = params self.primary_metric = primary_metric self.period = period @@ -66,7 +68,7 @@ def __init__( def on_pretrain_routine_start(self, trainer, pl_module): self.experiment = keepsake.init( - path=str(self.filepath.parent), + path=str(self.experiment_path), params=self.params, )