Added JAX checkpointing via Orbax
tttc3 committed Sep 25, 2023
1 parent 0b518c6 commit 458d880
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions deepxde/model.py
@@ -1011,6 +1011,23 @@ def save(self, save_path, protocol="backend", verbose=0):
elif backend_name == "tensorflow":
save_path += ".ckpt"
self.net.save_weights(save_path)
elif backend_name == "jax":
# Lazy load Orbax to avoid a hard dependancy when using JAX
# TODO: identify a better solution that complies with PEP8
import orbax.checkpoint as ocp
from flax.training import orbax_utils
save_path += ".ckpt"
checkpoint = {
"params": self.params,
"state": self.opt_state
}
self.checkpointer = ocp.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(checkpoint)
# `Force=True` option causes existing checkpoints to be
# overwritten, matching the PyTorch checkpointer behaviour.
self.checkpointer.save(
save_path, checkpoint, force=True, save_args=save_args
)
elif backend_name == "pytorch":
save_path += ".pt"
checkpoint = {
@@ -1055,6 +1072,10 @@ def restore(self, save_path, device=None, verbose=0):
            self.saver.restore(self.sess, save_path)
        elif backend_name == "tensorflow":
            self.net.load_weights(save_path)
        elif backend_name == "jax":
            checkpoint = self.checkpointer.restore(save_path)
            self.params, self.opt_state = checkpoint["params"], checkpoint["state"]
            self.net.params, self.external_trainable_variables = self.params
        elif backend_name == "pytorch":
            if device is not None:
                checkpoint = torch.load(save_path, map_location=torch.device(device))
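For reference, the Orbax pattern these branches rely on can be exercised on its own. The sketch below is not part of this commit: it saves and restores a small pytree of parameters and optimizer state with PyTreeCheckpointer, mirroring the calls added to Model.save and Model.restore. The checkpoint path and the toy pytree contents are illustrative assumptions.

# Standalone sketch of the Orbax calls introduced in this commit: write a
# pytree of parameters and optimizer state, then read it back. The path and
# pytree contents are placeholders, not values used by DeepXDE.
import jax.numpy as jnp
import orbax.checkpoint as ocp
from flax.training import orbax_utils

checkpoint = {
    "params": {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)},
    "state": {"step": jnp.array(0)},
}

checkpointer = ocp.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(checkpoint)
# force=True overwrites any existing checkpoint at the same path, matching
# the behaviour relied on in Model.save above.
checkpointer.save("/tmp/orbax_demo.ckpt", checkpoint, force=True, save_args=save_args)

restored = checkpointer.restore("/tmp/orbax_demo.ckpt")
params, state = restored["params"], restored["state"]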
