Skip to content

Commit

Permalink
Added ability to set directory in pytorch lightning
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Fishman <[email protected]>
  • Loading branch information
fishbotics committed Mar 21, 2021
1 parent a8032b4 commit 169ba3f
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions python/keepsake/pl_callback.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from copy import deepcopy
from typing import Optional, Dict, Tuple, Any
from pathlib import Path

import keepsake
from pytorch_lightning.callbacks.base import Callback
Expand Down Expand Up @@ -56,15 +57,18 @@ def __init__(
"""

super().__init__()
self.filepath = filepath
self.filepath = Path(filepath).resolve()
self.params = params
self.primary_metric = primary_metric
self.period = period
self.save_weights_only = save_weights_only
self.last_global_step_saved = -1

def on_pretrain_routine_start(self, trainer, pl_module):
self.experiment = keepsake.init(path=".", params=self.params)
self.experiment = keepsake.init(
path=str(self.filepath.parent),
params=self.params,
)

def on_epoch_end(self, trainer, pl_module):
self._save_model(trainer, pl_module)
Expand All @@ -89,7 +93,7 @@ def _save_model(self, trainer, pl_module):
return

if self.filepath != None:
trainer.save_checkpoint(self.filepath, self.save_weights_only)
trainer.save_checkpoint(self.filepath.name, self.save_weights_only)

self.last_global_step_saved = global_step

Expand All @@ -99,7 +103,7 @@ def _save_model(self, trainer, pl_module):
metrics.update({"global_step": trainer.global_step})

self.experiment.checkpoint(
path=self.filepath,
path=self.filepath.name,
step=trainer.current_epoch,
metrics=metrics,
primary_metric=self.primary_metric,
Expand Down

0 comments on commit 169ba3f

Please sign in to comment.