From 73a4a75b3c48f1d0b56886d30a7f75e7d7369198 Mon Sep 17 00:00:00 2001 From: slebedev Date: Thu, 19 Oct 2023 10:51:51 -0700 Subject: [PATCH] Updated jax.config import PiperOrigin-RevId: 574931286 --- spinoffs/fun_mc/fun_mc/fun_mc_test.py | 2 +- spinoffs/fun_mc/fun_mc/malt_test.py | 2 +- spinoffs/fun_mc/fun_mc/prefab_test.py | 2 +- spinoffs/fun_mc/fun_mc/sga_hmc_test.py | 2 +- spinoffs/fun_mc/fun_mc/util_tfp_test.py | 2 +- ...ase_Studies_Atmospheric_CO2_and_Electricity_Demand_JAX.ipynb | 2 +- .../jupyter_notebooks/TFP_Release_Notebook_0_11_0.ipynb | 2 +- tensorflow_probability/python/internal/samplers_test.py | 2 +- tensorflow_probability/python/internal/test_util.py | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/spinoffs/fun_mc/fun_mc/fun_mc_test.py b/spinoffs/fun_mc/fun_mc/fun_mc_test.py index 82ad4a20f0..4941891633 100644 --- a/spinoffs/fun_mc/fun_mc/fun_mc_test.py +++ b/spinoffs/fun_mc/fun_mc/fun_mc_test.py @@ -22,7 +22,7 @@ from absl.testing import parameterized import jax -from jax.config import config as jax_config +from jax import config as jax_config import numpy as np import scipy.stats import tensorflow.compat.v2 as real_tf diff --git a/spinoffs/fun_mc/fun_mc/malt_test.py b/spinoffs/fun_mc/fun_mc/malt_test.py index 54db9965c6..beb927a192 100644 --- a/spinoffs/fun_mc/fun_mc/malt_test.py +++ b/spinoffs/fun_mc/fun_mc/malt_test.py @@ -20,7 +20,7 @@ # Dependency imports import jax -from jax.config import config as jax_config +from jax import config as jax_config import numpy as np import tensorflow.compat.v2 as real_tf diff --git a/spinoffs/fun_mc/fun_mc/prefab_test.py b/spinoffs/fun_mc/fun_mc/prefab_test.py index dc8f88ecf8..5b7b85be3a 100644 --- a/spinoffs/fun_mc/fun_mc/prefab_test.py +++ b/spinoffs/fun_mc/fun_mc/prefab_test.py @@ -20,7 +20,7 @@ # Dependency imports import jax -from jax.config import config as jax_config +from jax import config as jax_config import numpy as np import tensorflow.compat.v2 as real_tf diff --git a/spinoffs/fun_mc/fun_mc/sga_hmc_test.py b/spinoffs/fun_mc/fun_mc/sga_hmc_test.py index a26036def7..4cdee429ce 100644 --- a/spinoffs/fun_mc/fun_mc/sga_hmc_test.py +++ b/spinoffs/fun_mc/fun_mc/sga_hmc_test.py @@ -21,7 +21,7 @@ from absl.testing import parameterized import jax -from jax.config import config as jax_config +from jax import config as jax_config import tensorflow.compat.v2 as real_tf from tensorflow_probability.python.internal import test_util as tfp_test_util diff --git a/spinoffs/fun_mc/fun_mc/util_tfp_test.py b/spinoffs/fun_mc/fun_mc/util_tfp_test.py index b52503820f..6315f8e6e0 100644 --- a/spinoffs/fun_mc/fun_mc/util_tfp_test.py +++ b/spinoffs/fun_mc/fun_mc/util_tfp_test.py @@ -17,7 +17,7 @@ # Dependency imports from absl.testing import parameterized -from jax.config import config as jax_config +from jax import config as jax_config import numpy as np import tensorflow.compat.v2 as real_tf diff --git a/tensorflow_probability/examples/jupyter_notebooks/Structural_Time_Series_Modeling_Case_Studies_Atmospheric_CO2_and_Electricity_Demand_JAX.ipynb b/tensorflow_probability/examples/jupyter_notebooks/Structural_Time_Series_Modeling_Case_Studies_Atmospheric_CO2_and_Electricity_Demand_JAX.ipynb index f076e1efd2..34f0d4c5de 100644 --- a/tensorflow_probability/examples/jupyter_notebooks/Structural_Time_Series_Modeling_Case_Studies_Atmospheric_CO2_and_Electricity_Demand_JAX.ipynb +++ b/tensorflow_probability/examples/jupyter_notebooks/Structural_Time_Series_Modeling_Case_Studies_Atmospheric_CO2_and_Electricity_Demand_JAX.ipynb @@ -95,7 +95,7 @@ "\n", "import numpy as np\n", "import jax\n", - "from jax.config import config\n", + "from jax import config\n", "config.update('jax_enable_x64', True)\n", "\n", "from tensorflow_probability.substrates import jax as tfp\n", diff --git a/tensorflow_probability/examples/jupyter_notebooks/TFP_Release_Notebook_0_11_0.ipynb b/tensorflow_probability/examples/jupyter_notebooks/TFP_Release_Notebook_0_11_0.ipynb index ee40cea633..28c7a447fe 100644 --- a/tensorflow_probability/examples/jupyter_notebooks/TFP_Release_Notebook_0_11_0.ipynb +++ b/tensorflow_probability/examples/jupyter_notebooks/TFP_Release_Notebook_0_11_0.ipynb @@ -143,7 +143,7 @@ }, "source": [ "import jax\n", - "from jax.config import config\n", + "from jax import config\n", "config.update('jax_enable_x64', True)\n", "\n", "def demo_jax():\n", diff --git a/tensorflow_probability/python/internal/samplers_test.py b/tensorflow_probability/python/internal/samplers_test.py index 2b860b93f9..3ae5fdfd0e 100644 --- a/tensorflow_probability/python/internal/samplers_test.py +++ b/tensorflow_probability/python/internal/samplers_test.py @@ -37,7 +37,7 @@ def setUp(self): super().setUp() if JAX_MODE and FLAGS.test_tfp_jax_prng != 'default': - from jax.config import config # pylint: disable=g-import-not-at-top + from jax import config # pylint: disable=g-import-not-at-top config.update('jax_default_prng_impl', FLAGS.test_tfp_jax_prng) @test_util.substrate_disable_stateful_random_test diff --git a/tensorflow_probability/python/internal/test_util.py b/tensorflow_probability/python/internal/test_util.py index 1058376766..0da39d05b5 100644 --- a/tensorflow_probability/python/internal/test_util.py +++ b/tensorflow_probability/python/internal/test_util.py @@ -2021,7 +2021,7 @@ def getTestCaseNames(self, testCaseClass): # pylint:disable=invalid-name def main(jax_mode=JAX_MODE, jax_enable_x64=True): """Test main function that injects a custom loader.""" if jax_mode and jax_enable_x64: - from jax.config import config # pylint: disable=g-import-not-at-top + from jax import config # pylint: disable=g-import-not-at-top config.update('jax_enable_x64', True) # This logic is borrowed from TensorFlow.