diff --git a/tensorflow_probability/substrates/BUILD b/tensorflow_probability/substrates/BUILD index ae3da5ba44..3f99a48001 100644 --- a/tensorflow_probability/substrates/BUILD +++ b/tensorflow_probability/substrates/BUILD @@ -47,6 +47,7 @@ py_library( tags = ["alt_dep=//tensorflow_probability:jax"], deps = [ # "//tensorflow_probability/google:google.jax", # DisableOnExport +# "//tensorflow_probability/google/autosts:autosts.jax", # DisableOnExport # "//tensorflow_probability/google/staging:staging.jax", # DisableOnExport # "//tensorflow_probability/google/tfp_google:tfp_google.jax", # DisableOnExport "//tensorflow_probability/python:version", diff --git a/tensorflow_probability/substrates/jax/__init__.py b/tensorflow_probability/substrates/jax/__init__.py index a02ed1cd7b..aeae242357 100644 --- a/tensorflow_probability/substrates/jax/__init__.py +++ b/tensorflow_probability/substrates/jax/__init__.py @@ -37,6 +37,7 @@ def _ensure_jax_install(): # pylint: disable=g-statement-before-imports del _ensure_jax_install # Cleanup symbol to avoid polluting namespace. from tensorflow_probability.python.version import __version__ +# from tensorflow_probability.substrates.jax.google import autosts # DisableOnExport # pylint:disable=line-too-long # from tensorflow_probability.substrates.jax.google import staging # DisableOnExport # pylint:disable=line-too-long from tensorflow_probability.substrates.jax import bijectors from tensorflow_probability.substrates.jax import distributions diff --git a/tensorflow_probability/substrates/meta/rewrite.py b/tensorflow_probability/substrates/meta/rewrite.py index 3f05c71017..4fe8edc349 100644 --- a/tensorflow_probability/substrates/meta/rewrite.py +++ b/tensorflow_probability/substrates/meta/rewrite.py @@ -97,8 +97,8 @@ ('auto_batching', 'composite_tensor', 'linalg', 'marginalize', 'nn', 'sequential', 'substrates'), } -LIBS = ('bijectors', 'distributions', 'experimental', 'glm', 'math', 'mcmc', - 'monte_carlo', 'optimizer', 'random', 'staging', 'stats', 'sts', +LIBS = ('autosts', 'bijectors', 'distributions', 'experimental', 'glm', 'math', + 'mcmc', 'monte_carlo', 'optimizer', 'random', 'staging', 'stats', 'sts', 'tfp_google', 'util', 'vi') DISTRIBUTION_INTERNALS = ('stochastic_process_util',) INTERNALS = ('assert_util', 'auto_composite_tensor',