Skip to content

Commit

Permalink
Updated jax.config import
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 574931286
  • Loading branch information
superbobry authored and tensorflower-gardener committed Oct 19, 2023
1 parent e16c0b7 commit 73a4a75
Show file tree
Hide file tree
Showing 9 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion spinoffs/fun_mc/fun_mc/fun_mc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from absl.testing import parameterized
import jax
from jax.config import config as jax_config
from jax import config as jax_config
import numpy as np
import scipy.stats
import tensorflow.compat.v2 as real_tf
Expand Down
2 changes: 1 addition & 1 deletion spinoffs/fun_mc/fun_mc/malt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# Dependency imports

import jax
from jax.config import config as jax_config
from jax import config as jax_config
import numpy as np
import tensorflow.compat.v2 as real_tf

Expand Down
2 changes: 1 addition & 1 deletion spinoffs/fun_mc/fun_mc/prefab_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# Dependency imports

import jax
from jax.config import config as jax_config
from jax import config as jax_config
import numpy as np
import tensorflow.compat.v2 as real_tf

Expand Down
2 changes: 1 addition & 1 deletion spinoffs/fun_mc/fun_mc/sga_hmc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from absl.testing import parameterized
import jax
from jax.config import config as jax_config
from jax import config as jax_config
import tensorflow.compat.v2 as real_tf

from tensorflow_probability.python.internal import test_util as tfp_test_util
Expand Down
2 changes: 1 addition & 1 deletion spinoffs/fun_mc/fun_mc/util_tfp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# Dependency imports

from absl.testing import parameterized
from jax.config import config as jax_config
from jax import config as jax_config
import numpy as np
import tensorflow.compat.v2 as real_tf

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@
"\n",
"import numpy as np\n",
"import jax\n",
"from jax.config import config\n",
"from jax import config\n",
"config.update('jax_enable_x64', True)\n",
"\n",
"from tensorflow_probability.substrates import jax as tfp\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@
},
"source": [
"import jax\n",
"from jax.config import config\n",
"from jax import config\n",
"config.update('jax_enable_x64', True)\n",
"\n",
"def demo_jax():\n",
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_probability/python/internal/samplers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def setUp(self):
super().setUp()

if JAX_MODE and FLAGS.test_tfp_jax_prng != 'default':
from jax.config import config # pylint: disable=g-import-not-at-top
from jax import config # pylint: disable=g-import-not-at-top
config.update('jax_default_prng_impl', FLAGS.test_tfp_jax_prng)

@test_util.substrate_disable_stateful_random_test
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_probability/python/internal/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2021,7 +2021,7 @@ def getTestCaseNames(self, testCaseClass): # pylint:disable=invalid-name
def main(jax_mode=JAX_MODE, jax_enable_x64=True):
"""Test main function that injects a custom loader."""
if jax_mode and jax_enable_x64:
from jax.config import config # pylint: disable=g-import-not-at-top
from jax import config # pylint: disable=g-import-not-at-top
config.update('jax_enable_x64', True)

# This logic is borrowed from TensorFlow.
Expand Down

0 comments on commit 73a4a75

Please sign in to comment.