Skip to content

Commit

Permalink
Add orbax-checkpoint as dependancy
Browse files Browse the repository at this point in the history
  • Loading branch information
tttc3 committed Sep 25, 2023
1 parent 458d880 commit 827f686
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 2 deletions.
3 changes: 1 addition & 2 deletions deepxde/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import OrderedDict

import numpy as np
import orbax.checkpoint as ocp

from . import config
from . import display
Expand Down Expand Up @@ -1012,9 +1013,7 @@ def save(self, save_path, protocol="backend", verbose=0):
save_path += ".ckpt"
self.net.save_weights(save_path)
elif backend_name == "jax":
# Lazy load Orbax to avoid a hard dependancy when using JAX
# TODO: identify a better solution that complies with PEP8
import orbax.checkpoint as ocp
from flax.training import orbax_utils
save_path += ".ckpt"
checkpoint = {
Expand Down
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
matplotlib
numpy
orbax-checkpoint
scikit-learn
scikit-optimize>=0.9.0
scipy
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ classifiers = [
dependencies = [
"matplotlib",
"numpy",
"orbax-checkpoint",
"scikit-learn",
"scikit-optimize>=0.9.0",
"scipy",
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
matplotlib
numpy
orbax-checkpoint
scikit-learn
scikit-optimize>=0.9.0
scipy

0 comments on commit 827f686

Please sign in to comment.