diff --git a/STYLE_GUIDE.md b/STYLE_GUIDE.md index c0e2a54540..045ef6fbd8 100644 --- a/STYLE_GUIDE.md +++ b/STYLE_GUIDE.md @@ -187,8 +187,8 @@ they supersede all previous conventions. 1. Submodule names should be singular, except where they overlap to TF. Justification: Having plural looks strange in user code, ie, - tf.optimizer.Foo reads nicer than tf.optimizers.Foo since submodules are - only used to access a single, specific thing (at a time). + tf.optimizer.Foo reads nicer than tf_keras.optimizers.Foo since submodules + are only used to access a single, specific thing (at a time). 1. Use `tf.newaxis` rather than `None` to `tf.expand_dims`. diff --git a/SUBSTRATES.md b/SUBSTRATES.md index e926007a4f..99edbf1ee0 100644 --- a/SUBSTRATES.md +++ b/SUBSTRATES.md @@ -75,11 +75,11 @@ vmap, etc.), we will special-case using an `if JAX_MODE:` block. tests, TFP impl, etc), with `tfp.math.value_and_gradient` or similar. Then, we can special-case `JAX_MODE` inside the body of `value_and_gradient`. -* __`tf.Variable`, `tf.optimizers.Optimizer`__ +* __`tf.Variable`, `tf_keras.optimizers.Optimizer`__ TF provides a `Variable` abstraction so that graph functions may modify - state, including using the TF `Optimizer` subclasses like `Adam`. JAX, in - contrast, operates only on pure functions. In general, TFP is fairly + state, including using the Keras `Optimizer` subclasses like `Adam`. JAX, + in contrast, operates only on pure functions. In general, TFP is fairly functional (e.g. `tfp.optimizer.lbfgs_minimize`), but in some cases (e.g. `tfp.vi.fit_surrogate_posterior`, `tfp.optimizer.StochasticGradientLangevinDynamics`) we have felt the diff --git a/discussion/adaptive_malt/adaptive_malt.py b/discussion/adaptive_malt/adaptive_malt.py index a6edb0b0e9..3952b04d09 100644 --- a/discussion/adaptive_malt/adaptive_malt.py +++ b/discussion/adaptive_malt/adaptive_malt.py @@ -350,7 +350,7 @@ def adaptive_mcmc_step( target_log_prob_fn: fun_mc.PotentialFn, num_mala_steps: int, num_adaptation_steps: int, - seed: jax.random.KeyArray, + seed: jax.Array, method: str = 'hmc', damping: Optional[jnp.ndarray] = None, scalar_step_size: Optional[jnp.ndarray] = None, @@ -778,7 +778,7 @@ def adaptive_nuts_step( target_log_prob_fn: fun_mc.PotentialFn, num_mala_steps: int, num_adaptation_steps: int, - seed: jax.random.KeyArray, + seed: jax.Array, scalar_step_size: Optional[jnp.ndarray] = None, vector_step_size: Optional[jnp.ndarray] = None, rvar_factor: int = 8, @@ -1040,7 +1040,7 @@ class MeadsExtra(NamedTuple): def meads_init(state: jnp.ndarray, target_log_prob_fn: fun_mc.PotentialFn, - num_folds: int, seed: jax.random.KeyArray): + num_folds: int, seed: jax.Array): """Initializes MEADS.""" num_dimensions = state.shape[-1] num_chains = state.shape[0] @@ -1062,7 +1062,7 @@ def meads_init(state: jnp.ndarray, target_log_prob_fn: fun_mc.PotentialFn, def meads_step(meads_state: MeadsState, target_log_prob_fn: fun_mc.PotentialFn, - seed: jax.random.KeyArray, + seed: jax.Array, vector_step_size: Optional[jnp.ndarray] = None, damping: Optional[jnp.ndarray] = None, step_size_multiplier: float = 0.5, @@ -1221,7 +1221,7 @@ def run_adaptive_mcmc_on_target( init_step_size: jnp.ndarray, num_adaptation_steps: int, num_results: int, - seed: jax.random.KeyArray, + seed: jax.Array, num_mala_steps: int = 100, rvar_smoothing: int = 0, trajectory_opt_kwargs: Mapping[str, Any] = immutabledict.immutabledict({ @@ -1358,7 +1358,7 @@ def run_adaptive_nuts_on_target( init_step_size: jnp.ndarray, num_adaptation_steps: int, num_results: int, - seed: 
jax.random.KeyArray, + seed: jax.Array, num_mala_steps: int = 100, rvar_smoothing: int = 0, num_chains: Optional[int] = None, @@ -1478,7 +1478,7 @@ def run_meads_on_target( num_adaptation_steps: int, num_results: int, thinning: int, - seed: jax.random.KeyArray, + seed: jax.Array, num_folds: int, num_chains: Optional[int] = None, init_x: Optional[jnp.ndarray] = None, @@ -1596,7 +1596,7 @@ def run_fixed_mcmc_on_target( target: gym.targets.Model, init_x: jnp.ndarray, method: str, - seed: jax.random.KeyArray, + seed: jax.Array, num_warmup_steps: int, num_results: int, scalar_step_size: jnp.ndarray, @@ -1706,7 +1706,7 @@ def run_vi_on_target( init_x: jnp.ndarray, num_steps: int, learning_rate: float, - seed: jax.random.KeyArray, + seed: jax.Array, ): """Run VI on a target. diff --git a/required_packages.py b/required_packages.py index cefd122969..fbb6305291 100644 --- a/required_packages.py +++ b/required_packages.py @@ -26,7 +26,6 @@ 'cloudpickle>=1.3', 'gast>=0.3.2', # For autobatching 'dm-tree', # For NumPy/JAX backends (hence, also for prefer_static) - 'typing-extensions<4.6.0', # TODO(b/284106340): Remove this pin ] if __name__ == '__main__': diff --git a/setup.py b/setup.py index 7d5107064d..6834b256d1 100644 --- a/setup.py +++ b/setup.py @@ -70,7 +70,7 @@ def has_ext_modules(self): url='http://github.com/tensorflow/probability', license='Apache 2.0', packages=find_packages(), - python_requires='>=3.8', + python_requires='>=3.9', install_requires=REQUIRED_PACKAGES, # Add in any packaged data. include_package_data=True, @@ -88,7 +88,6 @@ def has_ext_modules(self): 'Intended Audience :: Science/Research', 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', diff --git a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py index f017b19f86..bb0805c4d5 100644 --- a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py +++ b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py @@ -97,7 +97,9 @@ def make_tensor_seed(seed): """Converts a seed to a `Tensor` seed.""" if seed is None: raise ValueError('seed must not be None when using JAX') - if isinstance(seed, jax.random.PRNGKeyArray): + if hasattr(seed, 'dtype') and jax.dtypes.issubdtype( + seed.dtype, jax.dtypes.prng_key + ): return seed return jnp.asarray(seed, jnp.uint32) diff --git a/spinoffs/fun_mc/fun_mc/fun_mc_lib.py b/spinoffs/fun_mc/fun_mc/fun_mc_lib.py index 01e0fbead9..498a909a0b 100644 --- a/spinoffs/fun_mc/fun_mc/fun_mc_lib.py +++ b/spinoffs/fun_mc/fun_mc/fun_mc_lib.py @@ -732,9 +732,9 @@ def maybe_broadcast_structure(from_structure: Any, to_structure: Any) -> Any: """Maybe broadcasts `from_structure` to `to_structure`. - If `from_structure` is a singleton, it is tiled to match the structure of - `to_structure`. Note that the elements in `from_structure` are not copied if - this tiling occurs. + This assumes that `from_structure` is a shallow version of `to_structure`. + Subtrees of `to_structure` are set to the leaf values of `from_structure` that + those subtrees correspond to. Args: from_structure: A structure. @@ -743,11 +743,12 @@ def maybe_broadcast_structure(from_structure: Any, Returns: new_from_structure: Same structure as `to_structure`. 
""" - flat_from = util.flatten_tree(from_structure) - flat_to = util.flatten_tree(to_structure) - if len(flat_from) == 1: - flat_from *= len(flat_to) - return util.unflatten_tree(to_structure, flat_from) + def _broadcast_leaf(from_val, to_subtree): + return util.map_tree(lambda _: from_val, to_subtree) + + return util.map_tree_up_to( + from_structure, _broadcast_leaf, from_structure, to_structure + ) def reparameterize_potential_fn( @@ -3420,8 +3421,11 @@ def _default_log_weight_fn(old_state, new_state, stage, transition_extra): @util.named_call def systematic_resample( - particles: State, log_weights: FloatTensor, - seed: Any) -> (tuple[tuple[State, FloatTensor], IntTensor]): + particles: State, + log_weights: FloatTensor, + seed: Any, + do_resample: Optional[BooleanTensor] = None, +) -> tuple[tuple[State, FloatTensor], IntTensor]: """Systematically resamples particles in proportion to their weights. This uses the algorithm from [1]. @@ -3430,6 +3434,8 @@ def systematic_resample( particles: The particles. log_weights: Un-normalized weights. seed: PRNG seed. + do_resample: Whether to perform the resample. If None, resampling is + performed unconditionally. Returns: particles_and_weights: tuple of resampled particles and weights. @@ -3453,9 +3459,13 @@ def systematic_resample( repeats = tf.cast(util.diff(tf.floor(pie), prepend=0), tf.int32) parent_idxs = util.repeat( tf.range(num_particles), repeats, total_repeat_length=num_particles) + if do_resample is not None: + parent_idxs = tf.where(do_resample, parent_idxs, tf.range(num_particles)) new_particles = util.map_tree(lambda x: tf.gather(x, parent_idxs), particles) new_log_weights = tf.fill(log_weights.shape, tfp.math.reduce_logmeanexp(log_weights)) + if do_resample is not None: + new_log_weights = tf.where(do_resample, new_log_weights, log_weights) return (new_particles, new_log_weights), parent_idxs @@ -3463,20 +3473,22 @@ def systematic_resample( def annealed_importance_sampling_resample( ais_state: AnnealedImportanceSamplingState, resample_fn: Callable[ - [State, FloatTensor, Any], tuple[tuple[State, tf.Tensor], ResampleExtra] + [State, FloatTensor, Any, BooleanTensor], + tuple[tuple[State, tf.Tensor], ResampleExtra], ] = systematic_resample, min_ess_threshold: FloatTensor = 0.5, seed: Any = None, ) -> tuple[AnnealedImportanceSamplingState, ResampleExtra]: """Resamples the particles in AnnealedImportanceSamplingState.""" - (state, log_weight), extra = resample_fn(ais_state.state, - ais_state.log_weight, seed) - state, log_weight = choose( - ais_state.ess() < - tf.cast(log_weight.shape[0], log_weight.dtype) * min_ess_threshold, - (state, log_weight), - (ais_state.state, ais_state.log_weight), + log_weight = tf.convert_to_tensor(ais_state.log_weight) + do_resample = ( + ais_state.ess() + < tf.cast(log_weight.shape[0], log_weight.dtype) + * min_ess_threshold + ) + (state, log_weight), extra = resample_fn( + ais_state.state, ais_state.log_weight, seed, do_resample ) return ais_state._replace(state=state, log_weight=log_weight), extra @@ -3500,7 +3512,7 @@ def geometric_annealing_path( initial_target_log_prob_fn: PotentialFn, final_target_log_prob_fn: PotentialFn, fraction_fn: Optional[Callable[[FloatTensor], tf.Tensor]] = None, -) -> Callable[[Stage], PotentialFn]: +) -> PotentialFn: """Returns a geometrically interpolated target density function. 
This interpolates between `initial_target_log_prob_fn` and diff --git a/spinoffs/fun_mc/fun_mc/fun_mc_test.py b/spinoffs/fun_mc/fun_mc/fun_mc_test.py index 527836d037..4941891633 100644 --- a/spinoffs/fun_mc/fun_mc/fun_mc_test.py +++ b/spinoffs/fun_mc/fun_mc/fun_mc_test.py @@ -22,7 +22,7 @@ from absl.testing import parameterized import jax -from jax.config import config as jax_config +from jax import config as jax_config import numpy as np import scipy.stats import tensorflow.compat.v2 as real_tf @@ -361,6 +361,9 @@ def testBroadcastStructure(self): struct = fun_mc.maybe_broadcast_structure([3, 4], [1, 2]) self.assertEqual([3, 4], struct) + struct = fun_mc.maybe_broadcast_structure([1, 2], [[0, 0], [0, 0, 0]]) + self.assertEqual([[1, 1], [2, 2, 2]], struct) + def testCallPotentialFn(self): def potential(x): @@ -1885,6 +1888,36 @@ def body(seed): new_log_weights, tf.fill(probs.shape, tfp.math.reduce_logmeanexp(log_weights))) + def testSystematicResampleAncestors(self): + log_weights = self._constant([-float('inf'), 0.]) + particles = tf.range(log_weights.shape[0]) + seed = self._make_seed(_test_seed()) + + (new_particles, new_log_weights), ancestors = fun_mc.systematic_resample( + particles, log_weights, seed=seed + ) + self.assertAllEqual(new_particles, tf.ones_like(particles)) + self.assertAllEqual( + new_log_weights, tf.math.log(self._constant([0.5, 0.5])) + ) + self.assertAllEqual(ancestors, tf.ones_like(particles)) + + (new_particles, new_log_weights), ancestors = fun_mc.systematic_resample( + particles, log_weights, do_resample=True, seed=seed + ) + self.assertAllEqual(new_particles, tf.ones_like(particles)) + self.assertAllEqual( + new_log_weights, tf.math.log(self._constant([0.5, 0.5])) + ) + self.assertAllEqual(ancestors, tf.ones_like(particles)) + + (new_particles, new_log_weights), ancestors = fun_mc.systematic_resample( + particles, log_weights, do_resample=False, seed=seed + ) + self.assertAllEqual(new_particles, particles) + self.assertAllEqual(new_log_weights, log_weights) + self.assertAllEqual(ancestors, particles) + def testAIS(self): def tlp_1(x): diff --git a/spinoffs/fun_mc/fun_mc/malt_test.py b/spinoffs/fun_mc/fun_mc/malt_test.py index 54db9965c6..beb927a192 100644 --- a/spinoffs/fun_mc/fun_mc/malt_test.py +++ b/spinoffs/fun_mc/fun_mc/malt_test.py @@ -20,7 +20,7 @@ # Dependency imports import jax -from jax.config import config as jax_config +from jax import config as jax_config import numpy as np import tensorflow.compat.v2 as real_tf diff --git a/spinoffs/fun_mc/fun_mc/prefab_test.py b/spinoffs/fun_mc/fun_mc/prefab_test.py index dc8f88ecf8..5b7b85be3a 100644 --- a/spinoffs/fun_mc/fun_mc/prefab_test.py +++ b/spinoffs/fun_mc/fun_mc/prefab_test.py @@ -20,7 +20,7 @@ # Dependency imports import jax -from jax.config import config as jax_config +from jax import config as jax_config import numpy as np import tensorflow.compat.v2 as real_tf diff --git a/spinoffs/fun_mc/fun_mc/sga_hmc_test.py b/spinoffs/fun_mc/fun_mc/sga_hmc_test.py index a26036def7..4cdee429ce 100644 --- a/spinoffs/fun_mc/fun_mc/sga_hmc_test.py +++ b/spinoffs/fun_mc/fun_mc/sga_hmc_test.py @@ -21,7 +21,7 @@ from absl.testing import parameterized import jax -from jax.config import config as jax_config +from jax import config as jax_config import tensorflow.compat.v2 as real_tf from tensorflow_probability.python.internal import test_util as tfp_test_util diff --git a/spinoffs/fun_mc/fun_mc/util_tfp_test.py b/spinoffs/fun_mc/fun_mc/util_tfp_test.py index b52503820f..6315f8e6e0 100644 --- 
a/spinoffs/fun_mc/fun_mc/util_tfp_test.py +++ b/spinoffs/fun_mc/fun_mc/util_tfp_test.py @@ -17,7 +17,7 @@ # Dependency imports from absl.testing import parameterized -from jax.config import config as jax_config +from jax import config as jax_config import numpy as np import tensorflow.compat.v2 as real_tf diff --git a/spinoffs/inference_gym/inference_gym/BUILD b/spinoffs/inference_gym/inference_gym/BUILD index 339318ae8b..7aa8632343 100644 --- a/spinoffs/inference_gym/inference_gym/BUILD +++ b/spinoffs/inference_gym/inference_gym/BUILD @@ -18,7 +18,6 @@ # Placeholder: py_library # [internal] load pytype.bzl (pytype_strict_library) -# [internal] load dummy dependency package( # default_applicable_licenses @@ -98,5 +97,3 @@ py_library( name = "backend_tensorflow", srcs = ["dynamic/backend_tensorflow/__init__.py"], ) - -# third_party_dependency(package = "py/inference_gym") # DisableOnExport diff --git a/tensorflow_probability/examples/BUILD b/tensorflow_probability/examples/BUILD index 2193b49f76..58c89b9220 100644 --- a/tensorflow_probability/examples/BUILD +++ b/tensorflow_probability/examples/BUILD @@ -84,6 +84,7 @@ py_library( # six dep, # tensorflow dep, "//tensorflow_probability", + "//tensorflow_probability/python/internal:tf_keras", ], ) diff --git a/tensorflow_probability/examples/bayesian_neural_network.py b/tensorflow_probability/examples/bayesian_neural_network.py index fe1f08cd4a..976a99de69 100644 --- a/tensorflow_probability/examples/bayesian_neural_network.py +++ b/tensorflow_probability/examples/bayesian_neural_network.py @@ -37,6 +37,7 @@ import numpy as np import tensorflow.compat.v2 as tf import tensorflow_probability as tfp +from tensorflow_probability.python.internal import tf_keras tf.enable_v2_behavior() @@ -174,26 +175,26 @@ def create_model(): # and two fully connected dense layers. We use the Flipout # Monte Carlo estimator for these layers, which enables lower variance # stochastic gradients than naive reparameterization. - model = tf.keras.models.Sequential([ + model = tf_keras.models.Sequential([ tfp.layers.Convolution2DFlipout( 6, kernel_size=5, padding='SAME', kernel_divergence_fn=kl_divergence_function, activation=tf.nn.relu), - tf.keras.layers.MaxPooling2D( + tf_keras.layers.MaxPooling2D( pool_size=[2, 2], strides=[2, 2], padding='SAME'), tfp.layers.Convolution2DFlipout( 16, kernel_size=5, padding='SAME', kernel_divergence_fn=kl_divergence_function, activation=tf.nn.relu), - tf.keras.layers.MaxPooling2D( + tf_keras.layers.MaxPooling2D( pool_size=[2, 2], strides=[2, 2], padding='SAME'), tfp.layers.Convolution2DFlipout( 120, kernel_size=5, padding='SAME', kernel_divergence_fn=kl_divergence_function, activation=tf.nn.relu), - tf.keras.layers.Flatten(), + tf_keras.layers.Flatten(), tfp.layers.DenseFlipout( 84, kernel_divergence_fn=kl_divergence_function, activation=tf.nn.relu), @@ -203,7 +204,7 @@ def create_model(): ]) # Model compilation. - optimizer = tf.keras.optimizers.Adam(lr=FLAGS.learning_rate) + optimizer = tf_keras.optimizers.Adam(lr=FLAGS.learning_rate) # We use the categorical_crossentropy loss since the MNIST dataset contains # ten labels. 
The Keras API will then automatically add the # Kullback-Leibler divergence (contained on the individual layers of @@ -214,7 +215,7 @@ def create_model(): return model -class MNISTSequence(tf.keras.utils.Sequence): +class MNISTSequence(tf_keras.utils.Sequence): """Produces a sequence of MNIST digits with labels.""" def __init__(self, data=None, batch_size=128, fake_data_size=None): @@ -272,7 +273,7 @@ def __preprocessing(images, labels): images = 2 * (images / 255.) - 1. images = images[..., tf.newaxis] - labels = tf.keras.utils.to_categorical(labels) + labels = tf_keras.utils.to_categorical(labels) return images, labels def __len__(self): @@ -298,7 +299,7 @@ def main(argv): heldout_seq = MNISTSequence(batch_size=FLAGS.batch_size, fake_data_size=NUM_HELDOUT_EXAMPLES) else: - train_set, heldout_set = tf.keras.datasets.mnist.load_data() + train_set, heldout_set = tf_keras.datasets.mnist.load_data() train_seq = MNISTSequence(data=train_set, batch_size=FLAGS.batch_size) heldout_seq = MNISTSequence(data=heldout_set, batch_size=FLAGS.batch_size) diff --git a/tensorflow_probability/examples/cifar10_bnn.py b/tensorflow_probability/examples/cifar10_bnn.py index 666f5aca1b..4504667bd7 100644 --- a/tensorflow_probability/examples/cifar10_bnn.py +++ b/tensorflow_probability/examples/cifar10_bnn.py @@ -47,6 +47,8 @@ from tensorflow_probability.examples.models.bayesian_resnet import bayesian_resnet from tensorflow_probability.examples.models.bayesian_vgg import bayesian_vgg +from tensorflow_probability.python.internal import tf_keras + matplotlib.use("Agg") warnings.simplefilter(action="ignore") tfd = tfp.distributions @@ -169,7 +171,7 @@ def main(argv): if FLAGS.fake_data: (x_train, y_train), (x_test, y_test) = build_fake_data() else: - (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data() + (x_train, y_train), (x_test, y_test) = tf_keras.datasets.cifar10.load_data() (images, labels, handle, training_iterator, diff --git a/tensorflow_probability/examples/disentangled_vae.py b/tensorflow_probability/examples/disentangled_vae.py index 483153adb9..de8f823ff9 100644 --- a/tensorflow_probability/examples/disentangled_vae.py +++ b/tensorflow_probability/examples/disentangled_vae.py @@ -102,10 +102,12 @@ from absl import app from absl import flags -import tensorflow.compat.v1 as tf +import tensorflow.compat.v1 as tf1 +import tensorflow.compat.v2 as tf import tensorflow_probability as tfp from tensorflow_probability.examples import sprites_dataset +from tensorflow_probability.python.internal import tf_keras tfd = tfp.distributions @@ -178,7 +180,7 @@ FLAGS = flags.FLAGS -class LearnableMultivariateNormalDiag(tf.keras.Model): +class LearnableMultivariateNormalDiag(tf_keras.v1.Model): """Learnable multivariate diagonal normal distribution. The model is a multivariate normal distribution with learnable @@ -193,19 +195,19 @@ def __init__(self, dimensions): distribution. """ super(LearnableMultivariateNormalDiag, self).__init__() - with tf.compat.v1.name_scope(self._name): + with tf1.name_scope(self._name): self.dimensions = dimensions - self._mean = tf.compat.v2.Variable( - tf.random.normal([dimensions], stddev=0.1), name="mean") + self._mean = tf.Variable( + tf1.random.normal([dimensions], stddev=0.1), name="mean") # Initialize the std dev such that it will be close to 1 after a softplus # function. 
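A quick numeric check of the softplus-initialization comment above (a sketch in plain NumPy, not part of the patch): values drawn near 0.55, the initialization mean used just below, map through softplus to roughly 1, which is why the post-softplus stddev starts close to 1.

import numpy as np
softplus = lambda x: np.log1p(np.exp(x))
# softplus(0.55) = log(1 + e^0.55) ~= 1.006, so the transformed stddev begins near 1.
print(softplus(0.55))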
- self._untransformed_stddev = tf.compat.v2.Variable( - tf.random.normal([dimensions], mean=0.55, stddev=0.1), + self._untransformed_stddev = tf.Variable( + tf1.random.normal([dimensions], mean=0.55, stddev=0.1), name="untransformed_stddev") def __call__(self, *args, **kwargs): # Allow this Model to be called without inputs. - dummy = tf.zeros(self.dimensions) + dummy = tf1.zeros(self.dimensions) return super(LearnableMultivariateNormalDiag, self).__call__( dummy, *args, **kwargs) @@ -221,7 +223,7 @@ def call(self, inputs): dimensions]. """ del inputs # unused - with tf.compat.v1.name_scope(self._name): + with tf1.name_scope(self._name): return tfd.MultivariateNormalDiag(self.loc, self.scale_diag) @property @@ -232,10 +234,10 @@ def loc(self): @property def scale_diag(self): """The diagonal standard deviation of the normal distribution.""" - return tf.nn.softplus(self._untransformed_stddev) + 1e-5 # keep > 0 + return tf1.nn.softplus(self._untransformed_stddev) + 1e-5 # keep > 0 -class LearnableMultivariateNormalDiagCell(tf.keras.Model): +class LearnableMultivariateNormalDiagCell(tf_keras.v1.Model): """Multivariate diagonal normal distribution RNN cell. The model is an LSTM-based recurrent function that computes the @@ -254,8 +256,8 @@ def __init__(self, dimensions, hidden_size): super(LearnableMultivariateNormalDiagCell, self).__init__() self.dimensions = dimensions self.hidden_size = hidden_size - self.lstm_cell = tf.keras.layers.LSTMCell(hidden_size) - self.output_layer = tf.keras.layers.Dense(2*dimensions) + self.lstm_cell = tf_keras.v1.layers.LSTMCell(hidden_size) + self.output_layer = tf_keras.v1.layers.Dense(2*dimensions) def zero_state(self, sample_batch_shape=()): """Returns an initial state for the LSTM cell. @@ -268,12 +270,11 @@ def zero_state(self, sample_batch_shape=()): A tuple of the initial previous output at timestep 0 of shape [sample_batch_shape, dimensions], and the cell state. """ - h0 = tf.zeros([1, self.hidden_size]) - c0 = tf.zeros([1, self.hidden_size]) - combined_shape = tf.concat((tf.convert_to_tensor( - value=sample_batch_shape, dtype=tf.int32), [self.dimensions]), - axis=-1) - previous_output = tf.zeros(combined_shape) + h0 = tf1.zeros([1, self.hidden_size]) + c0 = tf1.zeros([1, self.hidden_size]) + combined_shape = tf1.concat((tf1.convert_to_tensor( + value=sample_batch_shape, dtype=tf1.int32), [self.dimensions]), axis=-1) + previous_output = tf1.zeros(combined_shape) return previous_output, (h0, c0) def call(self, inputs, state): @@ -298,20 +299,20 @@ def call(self, inputs, state): # In order to allow the user to pass in a single example without a batch # dimension, we always expand the input to at least two dimensions, then # fix the output shape to remove the batch dimension if necessary. 
- original_shape = inputs.shape - if len(original_shape) < 2: - inputs = tf.reshape(inputs, [1, -1]) + # original_shape = inputs.shape + # if len(original_shape) < 2: + # inputs = tf1.reshape(inputs, [1, -1]) out, state = self.lstm_cell(inputs, state) out = self.output_layer(out) - correct_shape = tf.concat((original_shape[:-1], tf.shape(input=out)[-1:]), - 0) - out = tf.reshape(out, correct_shape) + # correct_shape = tf1.concat( + # (original_shape[:-1], tf1.shape(input=out)[-1:]), 0) + # out = tf1.reshape(out, correct_shape) loc = out[..., :self.dimensions] - scale_diag = tf.nn.softplus(out[..., self.dimensions:]) + 1e-5 # keep > 0 + scale_diag = tf1.nn.softplus(out[..., self.dimensions:]) + 1e-5 # keep > 0 return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag), state -class Decoder(tf.keras.Model): +class Decoder(tf_keras.v1.Model): """Probabilistic decoder for `p(x_t | z_t, f)`. The decoder generates a sequence of image frames `x_{1:T}` from @@ -341,11 +342,11 @@ def __init__(self, hidden_size, channels=3): """ super(Decoder, self).__init__() self.hidden_size = hidden_size - activation = tf.nn.leaky_relu - self.dense = tf.keras.layers.Dense(hidden_size, activation=activation) + activation = tf1.nn.leaky_relu + self.dense = tf_keras.v1.layers.Dense(hidden_size, activation=activation) # Spatial sizes: (1,1) -> (8,8) -> (16,16) -> (32,32) -> (64,64). - conv_transpose = functools.partial( - tf.keras.layers.Conv2DTranspose, padding="SAME", activation=activation) + conv_transpose = functools.partial(tf_keras.v1.layers.Conv2DTranspose, + padding="SAME", activation=activation) self.conv_transpose1 = conv_transpose(256, 8, 1, padding="VALID") self.conv_transpose2 = conv_transpose(256, 3, 2) self.conv_transpose3 = conv_transpose(256, 3, 2) @@ -367,27 +368,27 @@ def call(self, inputs): batch_size, timesteps, height, width, channels]. """ # We explicitly broadcast f to the same shape as z other than the final - # dimension, because `tf.concat` can't automatically do this. + # dimension, because `tf1.concat` can't automatically do this. dynamic, static = inputs - timesteps = tf.shape(input=dynamic)[-2] - static = static[..., tf.newaxis, :] + tf.zeros([timesteps, 1]) - latents = tf.concat([dynamic, static], axis=-1) # (sample, N, T, latents) + timesteps = tf1.shape(input=dynamic)[-2] + static = static[..., tf1.newaxis, :] + tf1.zeros([timesteps, 1]) + latents = tf1.concat([dynamic, static], axis=-1) # (sample, N, T, latents) out = self.dense(latents) - out = tf.reshape(out, (-1, 1, 1, self.hidden_size)) + out = tf1.reshape(out, (-1, 1, 1, self.hidden_size)) out = self.conv_transpose1(out) out = self.conv_transpose2(out) out = self.conv_transpose3(out) out = self.conv_transpose4(out) # (sample*N*T, h, w, c) - expanded_shape = tf.concat( - (tf.shape(input=latents)[:-1], tf.shape(input=out)[1:]), axis=0) - out = tf.reshape(out, expanded_shape) # (sample, N, T, h, w, c) + expanded_shape = tf1.concat( + (tf1.shape(input=latents)[:-1], tf1.shape(input=out)[1:]), axis=0) + out = tf1.reshape(out, expanded_shape) # (sample, N, T, h, w, c) return tfd.Independent( distribution=tfd.Normal(loc=out, scale=1.), reinterpreted_batch_ndims=3, # wrap (h, w, c) name="decoded_image") -class Compressor(tf.keras.Model): +class Compressor(tf_keras.v1.Model): """Feature extractor. This convolutional model aims to extract features corresponding to a @@ -408,7 +409,7 @@ def __init__(self, hidden_size): self.hidden_size = hidden_size # Spatial sizes: (64,64) -> (32,32) -> (16,16) -> (8,8) -> (1,1). 
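The spatial-size comment above can be checked with a little arithmetic (a sketch; it assumes the usual SAME-padding rule out = ceil(in / stride) for the three stride-2 convs, and that the final (8,8) -> (1,1) step comes from a VALID conv whose kernel spans the remaining 8x8 map):

import math
size = 64
for _ in range(3):              # conv1..conv3: stride 2, SAME padding
    size = math.ceil(size / 2)  # 64 -> 32 -> 16 -> 8
print(size)                     # 8; a VALID conv covering the full 8x8 map then yields (1, 1)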
conv = functools.partial( - tf.keras.layers.Conv2D, padding="SAME", activation=tf.nn.leaky_relu) + tf_keras.v1.layers.Conv2D, padding="SAME", activation=tf1.nn.leaky_relu) self.conv1 = conv(256, 3, 2) self.conv2 = conv(256, 3, 2) self.conv3 = conv(256, 3, 2) @@ -426,18 +427,18 @@ def call(self, inputs): A batch of intermediate representations of shape [sample_shape, batch_size, timesteps, hidden_size]. """ - image_shape = tf.shape(input=inputs)[-3:] - collapsed_shape = tf.concat(([-1], image_shape), axis=0) - out = tf.reshape(inputs, collapsed_shape) # (sample*batch*T, h, w, c) + image_shape = tf1.shape(input=inputs)[-3:] + collapsed_shape = tf1.concat(([-1], image_shape), axis=0) + out = tf1.reshape(inputs, collapsed_shape) # (sample*batch*T, h, w, c) out = self.conv1(out) out = self.conv2(out) out = self.conv3(out) out = self.conv4(out) - expanded_shape = tf.concat((tf.shape(input=inputs)[:-3], [-1]), axis=0) - return tf.reshape(out, expanded_shape) # (sample, batch, T, hidden) + expanded_shape = tf1.concat((tf1.shape(input=inputs)[:-3], [-1]), axis=0) + return tf1.reshape(out, expanded_shape) # (sample, batch, T, hidden) -class EncoderStatic(tf.keras.Model): +class EncoderStatic(tf_keras.v1.Model): """Probabilistic encoder for the time-invariant latent variable `f`. The conditional distribution `q(f | x_{1:T})` is a multivariate @@ -476,10 +477,10 @@ def __init__(self, latent_size, hidden_size): super(EncoderStatic, self).__init__() self.latent_size = latent_size self.hidden_size = hidden_size - self.bilstm = tf.keras.layers.Bidirectional( - tf.keras.layers.LSTM(hidden_size), + self.bilstm = tf_keras.v1.layers.Bidirectional( + tf_keras.v1.layers.LSTM(hidden_size), merge_mode="sum") - self.output_layer = tf.keras.layers.Dense(2*latent_size) + self.output_layer = tf_keras.v1.layers.Dense(2*latent_size) def call(self, inputs): """Runs the model to generate a distribution `q(f | x_{1:T})`. @@ -500,18 +501,18 @@ def call(self, inputs): """ # TODO(dusenberrymw): Remove these reshaping commands after b/113126249 is # fixed. - collapsed_shape = tf.concat(([-1], tf.shape(input=inputs)[-2:]), axis=0) - out = tf.reshape(inputs, collapsed_shape) # (sample*batch_size, T, hidden) + collapsed_shape = tf1.concat(([-1], tf1.shape(input=inputs)[-2:]), axis=0) + out = tf1.reshape(inputs, collapsed_shape) # (sample*batch_size, T, hidden) out = self.bilstm(out) # (sample*batch_size, hidden) - expanded_shape = tf.concat((tf.shape(input=inputs)[:-2], [-1]), axis=0) - out = tf.reshape(out, expanded_shape) # (sample, batch_size, hidden) + expanded_shape = tf1.concat((tf1.shape(input=inputs)[:-2], [-1]), axis=0) + out = tf1.reshape(out, expanded_shape) # (sample, batch_size, hidden) out = self.output_layer(out) # (sample, batch_size, 2*latent_size) loc = out[..., :self.latent_size] - scale_diag = tf.nn.softplus(out[..., self.latent_size:]) + 1e-5 # keep > 0 + scale_diag = tf1.nn.softplus(out[..., self.latent_size:]) + 1e-5 # keep > 0 return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag) -class EncoderDynamicFactorized(tf.keras.Model): +class EncoderDynamicFactorized(tf_keras.v1.Model): """Probabilistic encoder for the time-variant latent variable `z_t`. 
The conditional distribution `q(z_t | x_t)` is a multivariate normal @@ -542,8 +543,9 @@ def __init__(self, latent_size, hidden_size): super(EncoderDynamicFactorized, self).__init__() self.latent_size = latent_size self.hidden_size = hidden_size - self.dense = tf.keras.layers.Dense(hidden_size, activation=tf.nn.leaky_relu) - self.output_layer = tf.keras.layers.Dense(2*latent_size) + self.dense = tf_keras.v1.layers.Dense(hidden_size, + activation=tf1.nn.leaky_relu) + self.output_layer = tf_keras.v1.layers.Dense(2*latent_size) def call(self, inputs): """Runs the model to generate a distribution `q(z_{1:T} | x_{1:T})`. @@ -562,11 +564,11 @@ def call(self, inputs): out = self.dense(inputs) # (..., batch, time, hidden) out = self.output_layer(out) # (..., batch, time, 2*latent) loc = out[..., :self.latent_size] - scale_diag = tf.nn.softplus(out[..., self.latent_size:]) + 1e-5 # keep > 0 + scale_diag = tf1.nn.softplus(out[..., self.latent_size:]) + 1e-5 # keep > 0 return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag) -class EncoderDynamicFull(tf.keras.Model): +class EncoderDynamicFull(tf_keras.v1.Model): """Probabilistic encoder for the time-variant latent variable `z_t`. The conditional distribution `q(z_{1:T} | x_{1:T}, f)` is a @@ -601,11 +603,11 @@ def __init__(self, latent_size, hidden_size): super(EncoderDynamicFull, self).__init__() self.latent_size = latent_size self.hidden_size = hidden_size - self.bilstm = tf.keras.layers.Bidirectional( - tf.keras.layers.LSTM(hidden_size, return_sequences=True), + self.bilstm = tf_keras.v1.layers.Bidirectional( + tf_keras.v1.layers.LSTM(hidden_size, return_sequences=True), merge_mode="sum") - self.rnn = tf.keras.layers.SimpleRNN(hidden_size, return_sequences=True) - self.output_layer = tf.keras.layers.Dense(2*latent_size) + self.rnn = tf_keras.v1.layers.SimpleRNN(hidden_size, return_sequences=True) + self.output_layer = tf_keras.v1.layers.Dense(2*latent_size) def call(self, inputs): """Runs the model to generate a distribution `q(z_{1:T} | x_{1:T}, f)`. @@ -629,37 +631,37 @@ def call(self, inputs): sample. """ # We explicitly broadcast `x` and `f` to the same shape other than the final - # dimension, because `tf.concat` can't automatically do this. This will + # dimension, because `tf1.concat` can't automatically do this. This will # entail adding a `timesteps` dimension to `f` to give the shape `(..., # batch, timesteps, latent)`, and then broadcasting the sample shapes of # both tensors to the same shape. 
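The broadcasting comment above describes the usual add-zeros trick for aligning sample shapes before a concat; here is a minimal NumPy sketch with made-up toy shapes (the shapes and names are illustrative assumptions, not the model's actual dimensions):

import numpy as np
features = np.zeros((2, 4, 3, 8))                        # (samples_x, batch, timesteps, hidden)
static = np.zeros((5, 2, 4, 6))                          # (samples_f, samples_x, batch, latent)
static = static[..., np.newaxis, :] + np.zeros([3, 1])   # add a timesteps axis -> (5, 2, 4, 3, 6)
features = features + np.zeros((5, 1, 1, 1, 1))          # broadcast the sample dims -> (5, 2, 4, 3, 8)
combined = np.concatenate([features, static], axis=-1)   # concat needs exact shapes -> (5, 2, 4, 3, 14)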
features, static_sample = inputs - length = tf.shape(input=features)[-2] - static_sample = static_sample[..., tf.newaxis, :] + tf.zeros([length, 1]) - sample_shape_static = tf.shape(input=static_sample)[:-3] - sample_shape_inputs = tf.shape(input=features)[:-3] - broadcast_shape_inputs = tf.concat((sample_shape_static, [1, 1, 1]), 0) - broadcast_shape_static = tf.concat((sample_shape_inputs, [1, 1, 1]), 0) - features = features + tf.zeros(broadcast_shape_inputs) - static_sample = static_sample + tf.zeros(broadcast_shape_static) + length = tf1.shape(input=features)[-2] + static_sample = static_sample[..., tf1.newaxis, :] + tf1.zeros([length, 1]) + sample_shape_static = tf1.shape(input=static_sample)[:-3] + sample_shape_inputs = tf1.shape(input=features)[:-3] + broadcast_shape_inputs = tf1.concat((sample_shape_static, [1, 1, 1]), 0) + broadcast_shape_static = tf1.concat((sample_shape_inputs, [1, 1, 1]), 0) + features = features + tf1.zeros(broadcast_shape_inputs) + static_sample = static_sample + tf1.zeros(broadcast_shape_static) # `combined` will have shape (..., batch, T, hidden+latent). - combined = tf.concat((features, static_sample), axis=-1) + combined = tf1.concat((features, static_sample), axis=-1) # TODO(dusenberrymw): Remove these reshaping commands after b/113126249 is # fixed. - collapsed_shape = tf.concat(([-1], tf.shape(input=combined)[-2:]), axis=0) - out = tf.reshape(combined, collapsed_shape) + collapsed_shape = tf1.concat(([-1], tf1.shape(input=combined)[-2:]), axis=0) + out = tf1.reshape(combined, collapsed_shape) out = self.bilstm(out) # (sample*batch, T, hidden_size) out = self.rnn(out) # (sample*batch, T, hidden_size) - expanded_shape = tf.concat( - (tf.shape(input=combined)[:-2], tf.shape(input=out)[1:]), axis=0) - out = tf.reshape(out, expanded_shape) # (sample, batch, T, hidden_size) + expanded_shape = tf1.concat( + (tf1.shape(input=combined)[:-2], tf1.shape(input=out)[1:]), axis=0) + out = tf1.reshape(out, expanded_shape) # (sample, batch, T, hidden_size) out = self.output_layer(out) # (sample, batch, T, 2*latent_size) loc = out[..., :self.latent_size] - scale_diag = tf.nn.softplus(out[..., self.latent_size:]) + 1e-5 # keep > 0 + scale_diag = tf1.nn.softplus(out[..., self.latent_size:]) + 1e-5 # keep > 0 return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag) -class DisentangledSequentialVAE(tf.keras.Model): +class DisentangledSequentialVAE(tf_keras.v1.Model): """Disentangled Sequential Variational Autoencoder. The disentangled sequential variational autoencoder posits a generative @@ -812,8 +814,8 @@ def reconstruct(self, inputs, samples=1, sample_static=False, sample shape [sample_shape, samples, batch_size, timesteps, height, width, channels]. 
""" - batch_size = tf.shape(input=inputs)[-5] - length = len(tf.unstack(inputs, axis=-4)) # hack for graph mode + batch_size = tf1.shape(input=inputs)[-5] + length = len(tf1.unstack(inputs, axis=-4)) # hack for graph mode features = self.compressor(inputs) # (..., batch, timesteps, hidden) @@ -824,7 +826,7 @@ def reconstruct(self, inputs, samples=1, sample_static=False, static_sample, _ = self.sample_static_posterior(features, samples) if swap_static: - static_sample = tf.reverse(static_sample, axis=[1]) + static_sample = tf1.reverse(static_sample, axis=[1]) if sample_dynamic: dynamic_sample, _ = self.sample_dynamic_prior( @@ -834,7 +836,7 @@ def reconstruct(self, inputs, samples=1, sample_static=False, features, samples, static_sample) if swap_dynamic: - dynamic_sample = tf.reverse(dynamic_sample, axis=[1]) + dynamic_sample = tf1.reverse(dynamic_sample, axis=[1]) likelihood = self.decoder((dynamic_sample, static_sample)) return likelihood @@ -856,7 +858,7 @@ def sample_static_prior(self, samples, batch_size, fixed=False): """ dist = self.static_prior() if fixed: # in either case, shape is (samples, batch, latent) - sample = dist.sample((samples, 1)) + tf.zeros([batch_size, 1]) + sample = dist.sample((samples, 1)) + tf1.zeros([batch_size, 1]) else: sample = dist.sample((samples, batch_size)) return sample, dist @@ -913,12 +915,12 @@ def sample_dynamic_prior(self, samples, batch_size, length, fixed=False): scale_diags.append(dist.parameters["scale_diag"]) sample_list.append(sample) - sample = tf.stack(sample_list, axis=2) - loc = tf.stack(locs, axis=2) - scale_diag = tf.stack(scale_diags, axis=2) + sample = tf1.stack(sample_list, axis=2) + loc = tf1.stack(locs, axis=2) + scale_diag = tf1.stack(scale_diags, axis=2) if fixed: # tile along the batch axis - sample = sample + tf.zeros([batch_size, 1, 1]) + sample = sample + tf1.zeros([batch_size, 1, 1]) return sample, tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag) @@ -967,15 +969,15 @@ def image_summary(seqs, name, num=None): num: Integer for the number of examples to visualize. Defaults to all examples. """ - seqs = tf.clip_by_value(seqs, 0., 1.) - seqs = tf.unstack(seqs[:num]) - joined_seqs = [tf.concat(tf.unstack(seq), 1) for seq in seqs] - joined_seqs = tf.expand_dims(tf.concat(joined_seqs, 0), 0) - tf.compat.v2.summary.image( + seqs = tf1.clip_by_value(seqs, 0., 1.) + seqs = tf1.unstack(seqs[:num]) + joined_seqs = [tf1.concat(tf1.unstack(seq), 1) for seq in seqs] + joined_seqs = tf1.expand_dims(tf1.concat(joined_seqs, 0), 0) + tf.summary.image( name, joined_seqs, max_outputs=1, - step=tf.compat.v1.train.get_or_create_global_step()) + step=tf1.train.get_or_create_global_step()) def visualize_reconstruction(inputs, reconstruct, num=3, name="reconstruction"): @@ -989,8 +991,8 @@ def visualize_reconstruction(inputs, reconstruct, num=3, name="reconstruction"): num: Integer for the number of examples to visualize. name: String name of this summary. """ - reconstruct = tf.clip_by_value(reconstruct, 0., 1.) - inputs_and_reconstruct = tf.concat((inputs[:num], reconstruct[:num]), axis=0) + reconstruct = tf1.clip_by_value(reconstruct, 0., 1.) + inputs_and_reconstruct = tf1.concat((inputs[:num], reconstruct[:num]), axis=0) image_summary(inputs_and_reconstruct, name) @@ -1006,9 +1008,9 @@ def visualize_qualitative_analysis(inputs, model, samples=1, batch_size=3, batch_size: Number of sequences to generate. length: Number of timesteps to generate for each sequence. 
""" - average = lambda dist: tf.reduce_mean( + average = lambda dist: tf1.reduce_mean( input_tensor=dist.mean(), axis=0) # avg over samples - with tf.compat.v1.name_scope("val_reconstruction"): + with tf1.name_scope("val_reconstruction"): reconstruct = functools.partial(model.reconstruct, inputs=inputs, samples=samples) visualize_reconstruction(inputs, average(reconstruct())) @@ -1021,7 +1023,7 @@ def visualize_qualitative_analysis(inputs, model, samples=1, batch_size=3, visualize_reconstruction(inputs, average(reconstruct(swap_dynamic=True)), name="swap_dynamic") - with tf.compat.v1.name_scope("generation"): + with tf1.name_scope("generation"): generate = functools.partial(model.generate, batch_size=batch_size, length=length, samples=samples) image_summary(average(generate(fix_static=True)), "fix_static") @@ -1037,15 +1039,15 @@ def summarize_dist_params(dist, name, name_scope="dist_params"): name: The name of the distribution. name_scope: The name scope of this summary. """ - with tf.compat.v1.name_scope(name_scope): - tf.compat.v2.summary.histogram( + with tf1.name_scope(name_scope): + tf.summary.histogram( name="{}/{}".format(name, "mean"), data=dist.mean(), - step=tf.compat.v1.train.get_or_create_global_step()) - tf.compat.v2.summary.histogram( + step=tf1.train.get_or_create_global_step()) + tf.summary.histogram( name="{}/{}".format(name, "stddev"), data=dist.stddev(), - step=tf.compat.v1.train.get_or_create_global_step()) + step=tf1.train.get_or_create_global_step()) def summarize_mean_in_nats_and_bits(inputs, units, name, @@ -1061,29 +1063,29 @@ def summarize_mean_in_nats_and_bits(inputs, units, name, nats_name_scope: The name scope of the nats summary. bits_name_scope: The name scope of the bits summary. """ - mean = tf.reduce_mean(input_tensor=inputs) - with tf.compat.v1.name_scope(nats_name_scope): - tf.compat.v2.summary.scalar( + mean = tf1.reduce_mean(input_tensor=inputs) + with tf1.name_scope(nats_name_scope): + tf.summary.scalar( name, mean, - step=tf.compat.v1.train.get_or_create_global_step()) - with tf.compat.v1.name_scope(bits_name_scope): - tf.compat.v2.summary.scalar( + step=tf1.train.get_or_create_global_step()) + with tf1.name_scope(bits_name_scope): + tf.summary.scalar( name, - mean / units / tf.math.log(2.), - step=tf.compat.v1.train.get_or_create_global_step()) + mean / units / tf1.math.log(2.), + step=tf1.train.get_or_create_global_step()) def main(argv): del argv # unused - tf.compat.v1.enable_eager_execution() - tf.compat.v1.set_random_seed(FLAGS.seed) + tf1.enable_eager_execution() + tf1.set_random_seed(FLAGS.seed) timestamp = datetime.strftime(datetime.today(), "%y%m%d_%H%M%S") FLAGS.logdir = FLAGS.logdir.format(timestamp=timestamp) FLAGS.model_dir = FLAGS.model_dir.format(timestamp=timestamp) - if not tf.io.gfile.exists(FLAGS.model_dir): - tf.io.gfile.makedirs(FLAGS.model_dir) + if not tf1.io.gfile.exists(FLAGS.model_dir): + tf1.io.gfile.makedirs(FLAGS.model_dir) sprites_data = sprites_dataset.SpritesDataset(fake_data=FLAGS.fake_data) @@ -1093,18 +1095,17 @@ def main(argv): hidden_size=FLAGS.hidden_size, channels=sprites_data.channels, latent_posterior=FLAGS.latent_posterior) - global_step = tf.compat.v1.train.get_or_create_global_step() - optimizer = tf.compat.v1.train.AdamOptimizer( - tf.compat.v1.train.cosine_decay(FLAGS.learning_rate, global_step, - FLAGS.max_steps)) + global_step = tf1.train.get_or_create_global_step() + optimizer = tf1.train.AdamOptimizer( + tf1.train.cosine_decay(FLAGS.learning_rate, global_step, FLAGS.max_steps)) - checkpoint = 
tf.train.Checkpoint(model=model, global_step=global_step, - optimizer=optimizer) - checkpoint_manager = tf.train.CheckpointManager( + checkpoint = tf1.train.Checkpoint(model=model, global_step=global_step, + optimizer=optimizer) + checkpoint_manager = tf1.train.CheckpointManager( checkpoint, directory=FLAGS.model_dir, max_to_keep=5) checkpoint.restore(checkpoint_manager.latest_checkpoint) - writer = tf.compat.v2.summary.create_file_writer(FLAGS.logdir) + writer = tf.summary.create_file_writer(FLAGS.logdir) writer.set_as_default() dataset = sprites_data.train.map(lambda *x: x[0]).shuffle(1000).repeat() @@ -1112,14 +1113,14 @@ def main(argv): if FLAGS.enable_debug_logging: for inputs in dataset.prefetch(buffer_size=None): - with tf.compat.v2.summary.record_if( - lambda: tf.math.equal(0, global_step % FLAGS.log_steps)): - tf.compat.v2.summary.histogram( + with tf.summary.record_if( + lambda: tf1.math.equal(0, global_step % FLAGS.log_steps)): + tf.summary.histogram( "image", data=inputs, - step=tf.compat.v1.train.get_or_create_global_step()) + step=tf1.train.get_or_create_global_step()) - with tf.GradientTape() as tape: + with tf1.GradientTape() as tape: features = model.compressor(inputs) # (batch, timesteps, hidden) static_sample, static_posterior = model.sample_static_posterior( features, FLAGS.num_samples) # (samples, batch, latent) @@ -1127,7 +1128,7 @@ def main(argv): features, FLAGS.num_samples, static_sample) # (sampl, N, T, latent) likelihood = model.decoder((dynamic_sample, static_sample)) - reconstruction = tf.reduce_mean( # integrate samples + reconstruction = tf1.reduce_mean( # integrate samples input_tensor=likelihood.mean()[:FLAGS.num_reconstruction_samples], axis=0) visualize_reconstruction(inputs, reconstruction, @@ -1146,17 +1147,17 @@ def main(argv): static_prior_log_prob = static_prior.log_prob(static_sample) static_posterior_log_prob = static_posterior.log_prob(static_sample) - dynamic_prior_log_prob = tf.reduce_sum( + dynamic_prior_log_prob = tf1.reduce_sum( input_tensor=dynamic_prior.log_prob(dynamic_sample), axis=-1) # sum time - dynamic_posterior_log_prob = tf.reduce_sum( + dynamic_posterior_log_prob = tf1.reduce_sum( input_tensor=dynamic_posterior.log_prob(dynamic_sample), axis=-1) # sum time - likelihood_log_prob = tf.reduce_sum( + likelihood_log_prob = tf1.reduce_sum( input_tensor=likelihood.log_prob(inputs), axis=-1) # sum time if FLAGS.enable_debug_logging: - with tf.compat.v1.name_scope("log_probs"): + with tf1.name_scope("log_probs"): summarize_mean_in_nats_and_bits( static_prior_log_prob, FLAGS.latent_size_static, "static_prior") summarize_mean_in_nats_and_bits( @@ -1172,40 +1173,40 @@ def main(argv): likelihood_log_prob, sprites_data.frame_size ** 2 * sprites_data.channels * sprites_data.length, "likelihood") - elbo = tf.reduce_mean(input_tensor=static_prior_log_prob - - static_posterior_log_prob + - dynamic_prior_log_prob - - dynamic_posterior_log_prob + likelihood_log_prob) + elbo = tf1.reduce_mean(input_tensor=static_prior_log_prob - + static_posterior_log_prob + + dynamic_prior_log_prob - + dynamic_posterior_log_prob + likelihood_log_prob) loss = -elbo - tf.compat.v2.summary.scalar( + tf.summary.scalar( "elbo", elbo, - step=tf.compat.v1.train.get_or_create_global_step()) + step=tf1.train.get_or_create_global_step()) grads = tape.gradient(loss, model.variables) - grads, global_norm = tf.clip_by_global_norm(grads, FLAGS.clip_norm) + grads, global_norm = tf1.clip_by_global_norm(grads, FLAGS.clip_norm) grads_and_vars = list(zip(grads, model.variables)) # allow 
reuse in py3 if FLAGS.enable_debug_logging: - with tf.compat.v1.name_scope("grads"): - tf.compat.v2.summary.scalar( + with tf1.name_scope("grads"): + tf.summary.scalar( "global_norm_grads", global_norm, - step=tf.compat.v1.train.get_or_create_global_step()) - tf.compat.v2.summary.scalar( + step=tf1.train.get_or_create_global_step()) + tf.summary.scalar( "global_norm_grads_clipped", - tf.linalg.global_norm(grads), - step=tf.compat.v1.train.get_or_create_global_step()) + tf1.linalg.global_norm(grads), + step=tf1.train.get_or_create_global_step()) for grad, var in grads_and_vars: - with tf.compat.v1.name_scope("grads"): - tf.compat.v2.summary.histogram( + with tf1.name_scope("grads"): + tf.summary.histogram( "{}/grad".format(var.name), data=grad, - step=tf.compat.v1.train.get_or_create_global_step()) - with tf.compat.v1.name_scope("vars"): - tf.compat.v2.summary.histogram( + step=tf1.train.get_or_create_global_step()) + with tf1.name_scope("vars"): + tf.summary.histogram( var.name, data=var, - step=tf.compat.v1.train.get_or_create_global_step()) + step=tf1.train.get_or_create_global_step()) optimizer.apply_gradients(grads_and_vars, global_step) is_log_step = global_step.numpy() % FLAGS.log_steps == 0 @@ -1214,7 +1215,7 @@ def main(argv): checkpoint_manager.save() print("ELBO ({}/{}): {}".format(global_step.numpy(), FLAGS.max_steps, elbo.numpy())) - with tf.compat.v2.summary.record_if(True): + with tf.summary.record_if(True): val_data = sprites_data.test.take(20) inputs = next(iter(val_data.shuffle(20).batch(3)))[0] visualize_qualitative_analysis(inputs, model, diff --git a/tensorflow_probability/examples/jupyter_notebooks/Fitting_DPMM_Using_pSGLD.ipynb b/tensorflow_probability/examples/jupyter_notebooks/Fitting_DPMM_Using_pSGLD.ipynb index b2a9ade5c7..536d04540d 100644 --- a/tensorflow_probability/examples/jupyter_notebooks/Fitting_DPMM_Using_pSGLD.ipynb +++ b/tensorflow_probability/examples/jupyter_notebooks/Fitting_DPMM_Using_pSGLD.ipynb @@ -411,8 +411,8 @@ "To update parameters $\\boldsymbol{\\theta}\\equiv\\{\\boldsymbol{\\pi},\\,\\alpha,\\, \\boldsymbol{\\mu_j},\\,\\boldsymbol{\\sigma_j}\\}$ in $t\\,$th iteration with mini-batch size $M$, the update is sampled as:\n", "\n", "$$\\begin{align*}\n", - "\\Delta \\boldsymbol { \\theta } _ { t } \u0026 \\sim \\frac { \\epsilon _ { t } } { 2 } \\bigl[ G \\left( \\boldsymbol { \\theta } _ { t } \\right) \\bigl( \\nabla _ { \\boldsymbol { \\theta } } \\log p \\left( \\boldsymbol { \\theta } _ { t } \\right) \n", - " + \\frac { N } { M } \\sum _ { k = 1 } ^ { M } \\nabla _ \\boldsymbol { \\theta } \\log \\text{GMM}(x_{t_k})\\bigr) + \\sum_\\boldsymbol{\\theta}\\nabla_\\theta G \\left( \\boldsymbol { \\theta } _ { t } \\right) \\bigr]\\\\\n", + "\\Delta \\boldsymbol { \\theta } _ { t } \u0026 \\sim \\frac { \\epsilon _ { t } } { 2 } \\bigl[ G \\left( \\boldsymbol { \\theta } _ { t } \\right) \\bigl( \\nabla _ { \\boldsymbol { \\theta } } \\log p \\left( \\boldsymbol { \\theta } _ { t } \\right) +\n", + " \\frac { N } { M } \\sum _ { k = 1 } ^ { M } \\nabla _ \\boldsymbol { \\theta } \\log \\text{GMM}(x_{t_k})\\bigr) + \\sum_\\boldsymbol{\\theta}\\nabla_\\theta G \\left( \\boldsymbol { \\theta } _ { t } \\right) \\bigr]\\\\\n", "\u0026+ G ^ { \\frac { 1 } { 2 } } \\left( \\boldsymbol { \\theta } _ { t } \\right) \\text { Normal } \\left( \\text{loc}=\\boldsymbol{0} ,\\, \\text{scale}=\\epsilon _ { t }\\boldsymbol{1} \\right)\\\\\n", "\\end{align*}$$\n", "\n", diff --git 
a/tensorflow_probability/examples/jupyter_notebooks/Gaussian_Process_Latent_Variable_Model.ipynb b/tensorflow_probability/examples/jupyter_notebooks/Gaussian_Process_Latent_Variable_Model.ipynb index 352461a31c..8ae554c36d 100644 --- a/tensorflow_probability/examples/jupyter_notebooks/Gaussian_Process_Latent_Variable_Model.ipynb +++ b/tensorflow_probability/examples/jupyter_notebooks/Gaussian_Process_Latent_Variable_Model.ipynb @@ -345,7 +345,7 @@ " unconstrained_observation_noise,\n", " latent_index_points]\n", "\n", - "optimizer = tf.optimizers.Adam(learning_rate=1.0)\n", + "optimizer = tf.keras.optimizers.Adam(learning_rate=1.0)\n", "\n", "@tf.function(autograph=False, jit_compile=True)\n", "def train_model():\n", diff --git a/tensorflow_probability/examples/jupyter_notebooks/Gaussian_Process_Regression_In_TFP.ipynb b/tensorflow_probability/examples/jupyter_notebooks/Gaussian_Process_Regression_In_TFP.ipynb index 2a86903c1e..af1b67a7ec 100644 --- a/tensorflow_probability/examples/jupyter_notebooks/Gaussian_Process_Regression_In_TFP.ipynb +++ b/tensorflow_probability/examples/jupyter_notebooks/Gaussian_Process_Regression_In_TFP.ipynb @@ -541,7 +541,7 @@ "source": [ "# Now we optimize the model parameters.\n", "num_iters = 1000\n", - "optimizer = tf.optimizers.Adam(learning_rate=.01)\n", + "optimizer = tf.keras.optimizers.Adam(learning_rate=.01)\n", "\n", "# Use `tf.function` to trace the loss for more efficient evaluation.\n", "@tf.function(autograph=False, jit_compile=False)\n", diff --git a/tensorflow_probability/examples/jupyter_notebooks/Linear_Mixed_Effects_Model_Variational_Inference.ipynb b/tensorflow_probability/examples/jupyter_notebooks/Linear_Mixed_Effects_Model_Variational_Inference.ipynb index b60c89bfe6..874d6fcb97 100644 --- a/tensorflow_probability/examples/jupyter_notebooks/Linear_Mixed_Effects_Model_Variational_Inference.ipynb +++ b/tensorflow_probability/examples/jupyter_notebooks/Linear_Mixed_Effects_Model_Variational_Inference.ipynb @@ -800,7 +800,7 @@ }, "outputs": [], "source": [ - "optimizer = tf.optimizers.Adam(learning_rate=1e-2)\n", + "optimizer = tf.keras.optimizers.Adam(learning_rate=1e-2)\n", "\n", "losses = tfp.vi.fit_surrogate_posterior(\n", " target_log_prob_fn, \n", diff --git a/tensorflow_probability/examples/jupyter_notebooks/Linear_Mixed_Effects_Models.ipynb b/tensorflow_probability/examples/jupyter_notebooks/Linear_Mixed_Effects_Models.ipynb index 81a7bd6c27..d9fb7b6b5e 100644 --- a/tensorflow_probability/examples/jupyter_notebooks/Linear_Mixed_Effects_Models.ipynb +++ b/tensorflow_probability/examples/jupyter_notebooks/Linear_Mixed_Effects_Models.ipynb @@ -743,7 +743,7 @@ " previous_kernel_results=kernel_results)\n", " return next_state, next_kernel_results\n", "\n", - "optimizer = tf.optimizers.Adam(learning_rate=.01)\n", + "optimizer = tf.keras.optimizers.Adam(learning_rate=.01)\n", "\n", "# Set up M-step (gradient descent).\n", "@tf.function(autograph=False, jit_compile=True)\n", diff --git a/tensorflow_probability/examples/jupyter_notebooks/Multiple_changepoint_detection_and_Bayesian_model_selection.ipynb b/tensorflow_probability/examples/jupyter_notebooks/Multiple_changepoint_detection_and_Bayesian_model_selection.ipynb index 6c2139f913..e41f6fe90a 100644 --- a/tensorflow_probability/examples/jupyter_notebooks/Multiple_changepoint_detection_and_Bayesian_model_selection.ipynb +++ b/tensorflow_probability/examples/jupyter_notebooks/Multiple_changepoint_detection_and_Bayesian_model_selection.ipynb @@ -317,7 +317,7 @@ "\n", "losses = 
tfp.math.minimize(\n", " lambda: -log_prob(),\n", - " optimizer=tf.optimizers.Adam(learning_rate=0.1),\n", + " optimizer=tf.keras.optimizers.Adam(learning_rate=0.1),\n", " num_steps=100)\n", "plt.plot(losses)\n", "plt.ylabel('Negative log marginal likelihood')" @@ -740,7 +740,7 @@ "source": [ "losses = tfp.math.minimize(\n", " lambda: -log_prob(),\n", - " optimizer=tf.optimizers.Adam(0.1),\n", + " optimizer=tf.keras.optimizers.Adam(0.1),\n", " num_steps=100)\n", "plt.plot(losses)\n", "plt.ylabel('Negative log marginal likelihood')" diff --git a/tensorflow_probability/examples/jupyter_notebooks/Probabilistic_Layers_Regression.ipynb b/tensorflow_probability/examples/jupyter_notebooks/Probabilistic_Layers_Regression.ipynb index 0fea808da2..f90231691d 100644 --- a/tensorflow_probability/examples/jupyter_notebooks/Probabilistic_Layers_Regression.ipynb +++ b/tensorflow_probability/examples/jupyter_notebooks/Probabilistic_Layers_Regression.ipynb @@ -289,7 +289,7 @@ "])\n", "\n", "# Do inference.\n", - "model.compile(optimizer=tf.optimizers.Adam(learning_rate=0.01), loss=negloglik)\n", + "model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), loss=negloglik)\n", "model.fit(x, y, epochs=1000, verbose=False);\n", "\n", "# Profit.\n", @@ -391,7 +391,7 @@ "])\n", "\n", "# Do inference.\n", - "model.compile(optimizer=tf.optimizers.Adam(learning_rate=0.01), loss=negloglik)\n", + "model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), loss=negloglik)\n", "model.fit(x, y, epochs=1000, verbose=False);\n", "\n", "# Profit.\n", @@ -540,7 +540,7 @@ "])\n", "\n", "# Do inference.\n", - "model.compile(optimizer=tf.optimizers.Adam(learning_rate=0.01), loss=negloglik)\n", + "model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), loss=negloglik)\n", "model.fit(x, y, epochs=1000, verbose=False);\n", "\n", "# Profit.\n", @@ -650,7 +650,7 @@ "])\n", "\n", "# Do inference.\n", - "model.compile(optimizer=tf.optimizers.Adam(learning_rate=0.01), loss=negloglik)\n", + "model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), loss=negloglik)\n", "model.fit(x, y, epochs=1000, verbose=False);\n", "\n", "# Profit.\n", @@ -806,7 +806,7 @@ "batch_size = 32\n", "loss = lambda y, rv_y: rv_y.variational_loss(\n", " y, kl_weight=np.array(batch_size, x.dtype) / x.shape[0])\n", - "model.compile(optimizer=tf.optimizers.Adam(learning_rate=0.01), loss=loss)\n", + "model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), loss=loss)\n", "model.fit(x, y, batch_size=batch_size, epochs=1000, verbose=False)\n", "\n", "# Profit.\n", diff --git a/tensorflow_probability/examples/jupyter_notebooks/Probabilistic_Layers_VAE.ipynb b/tensorflow_probability/examples/jupyter_notebooks/Probabilistic_Layers_VAE.ipynb index 063a7041d7..71cd8347ed 100644 --- a/tensorflow_probability/examples/jupyter_notebooks/Probabilistic_Layers_VAE.ipynb +++ b/tensorflow_probability/examples/jupyter_notebooks/Probabilistic_Layers_VAE.ipynb @@ -434,7 +434,7 @@ "source": [ "negloglik = lambda x, rv_x: -rv_x.log_prob(x)\n", "\n", - "vae.compile(optimizer=tf.optimizers.Adam(learning_rate=1e-3),\n", + "vae.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),\n", " loss=negloglik)\n", "\n", "_ = vae.fit(train_dataset,\n", diff --git a/tensorflow_probability/examples/jupyter_notebooks/Probabilistic_PCA.ipynb b/tensorflow_probability/examples/jupyter_notebooks/Probabilistic_PCA.ipynb index f3c38dc8a5..0de23fb122 100644 --- a/tensorflow_probability/examples/jupyter_notebooks/Probabilistic_PCA.ipynb 
+++ b/tensorflow_probability/examples/jupyter_notebooks/Probabilistic_PCA.ipynb @@ -337,7 +337,7 @@ "target_log_prob_fn = lambda w, z: model.log_prob((w, z, x_train))\n", "losses = tfp.math.minimize(\n", " lambda: -target_log_prob_fn(w, z),\n", - " optimizer=tf.optimizers.Adam(learning_rate=0.05),\n", + " optimizer=tf.keras.optimizers.Adam(learning_rate=0.05),\n", " num_steps=200)" ] }, @@ -479,7 +479,7 @@ "losses = tfp.vi.fit_surrogate_posterior(\n", " target_log_prob_fn,\n", " surrogate_posterior=surrogate_posterior,\n", - " optimizer=tf.optimizers.Adam(learning_rate=0.05),\n", + " optimizer=tf.keras.optimizers.Adam(learning_rate=0.05),\n", " num_steps=200)" ] }, diff --git a/tensorflow_probability/examples/jupyter_notebooks/STS_approximate_inference_for_models_with_non_Gaussian_observations.ipynb b/tensorflow_probability/examples/jupyter_notebooks/STS_approximate_inference_for_models_with_non_Gaussian_observations.ipynb index 7316016f68..6c86b1969b 100644 --- a/tensorflow_probability/examples/jupyter_notebooks/STS_approximate_inference_for_models_with_non_Gaussian_observations.ipynb +++ b/tensorflow_probability/examples/jupyter_notebooks/STS_approximate_inference_for_models_with_non_Gaussian_observations.ipynb @@ -660,7 +660,7 @@ "t0 = time.time()\n", "losses = tfp.vi.fit_surrogate_posterior(pinned_model.unnormalized_log_prob,\n", " surrogate_posterior,\n", - " optimizer=tf.optimizers.Adam(0.1),\n", + " optimizer=tf.keras.optimizers.Adam(0.1),\n", " num_steps=num_variational_steps)\n", "t1 = time.time()\n", "print(\"Inference ran in {:.2f}s.\".format(t1-t0))" diff --git a/tensorflow_probability/examples/jupyter_notebooks/Structural_Time_Series_Modeling_Case_Studies_Atmospheric_CO2_and_Electricity_Demand_JAX.ipynb b/tensorflow_probability/examples/jupyter_notebooks/Structural_Time_Series_Modeling_Case_Studies_Atmospheric_CO2_and_Electricity_Demand_JAX.ipynb index f076e1efd2..34f0d4c5de 100644 --- a/tensorflow_probability/examples/jupyter_notebooks/Structural_Time_Series_Modeling_Case_Studies_Atmospheric_CO2_and_Electricity_Demand_JAX.ipynb +++ b/tensorflow_probability/examples/jupyter_notebooks/Structural_Time_Series_Modeling_Case_Studies_Atmospheric_CO2_and_Electricity_Demand_JAX.ipynb @@ -95,7 +95,7 @@ "\n", "import numpy as np\n", "import jax\n", - "from jax.config import config\n", + "from jax import config\n", "config.update('jax_enable_x64', True)\n", "\n", "from tensorflow_probability.substrates import jax as tfp\n", diff --git a/tensorflow_probability/examples/jupyter_notebooks/TFP_Release_Notebook_0_11_0.ipynb b/tensorflow_probability/examples/jupyter_notebooks/TFP_Release_Notebook_0_11_0.ipynb index ee40cea633..28c7a447fe 100644 --- a/tensorflow_probability/examples/jupyter_notebooks/TFP_Release_Notebook_0_11_0.ipynb +++ b/tensorflow_probability/examples/jupyter_notebooks/TFP_Release_Notebook_0_11_0.ipynb @@ -143,7 +143,7 @@ }, "source": [ "import jax\n", - "from jax.config import config\n", + "from jax import config\n", "config.update('jax_enable_x64', True)\n", "\n", "def demo_jax():\n", diff --git a/tensorflow_probability/examples/jupyter_notebooks/TFP_Release_Notebook_0_12_1.ipynb b/tensorflow_probability/examples/jupyter_notebooks/TFP_Release_Notebook_0_12_1.ipynb index 85728d1589..8bbd6eb75e 100644 --- a/tensorflow_probability/examples/jupyter_notebooks/TFP_Release_Notebook_0_12_1.ipynb +++ b/tensorflow_probability/examples/jupyter_notebooks/TFP_Release_Notebook_0_12_1.ipynb @@ -1237,7 +1237,7 @@ "\r\n", "asvi_losses = tfp.vi.fit_surrogate_posterior(target_log_prob,\r\n", 
" asvi_surrogate_posterior,\r\n", - " optimizer=tf.optimizers.Adam(learning_rate=0.1),\r\n", + " optimizer=tf.keras.optimizers.Adam(learning_rate=0.1),\r\n", " num_steps=500)\r\n", "logging.getLogger('tensorflow').setLevel(logging.NOTSET)" ] @@ -1255,7 +1255,7 @@ "\r\n", "factored_losses = tfp.vi.fit_surrogate_posterior(target_log_prob,\r\n", " factored_surrogate_posterior,\r\n", - " optimizer=tf.optimizers.Adam(learning_rate=0.1),\r\n", + " optimizer=tf.keras.optimizers.Adam(learning_rate=0.1),\r\n", " num_steps=500)" ] }, diff --git a/tensorflow_probability/examples/jupyter_notebooks/Variational_Inference_and_Joint_Distributions.ipynb b/tensorflow_probability/examples/jupyter_notebooks/Variational_Inference_and_Joint_Distributions.ipynb index 74a15b0a62..604d7c8663 100644 --- a/tensorflow_probability/examples/jupyter_notebooks/Variational_Inference_and_Joint_Distributions.ipynb +++ b/tensorflow_probability/examples/jupyter_notebooks/Variational_Inference_and_Joint_Distributions.ipynb @@ -512,7 +512,7 @@ } ], "source": [ - "optimizer = tf.optimizers.Adam(learning_rate=1e-2)\n", + "optimizer = tf.keras.optimizers.Adam(learning_rate=1e-2)\n", "mvn_loss = tfp.vi.fit_surrogate_posterior(\n", " target_model.unnormalized_log_prob,\n", " surrogate_posterior,\n", @@ -706,7 +706,7 @@ } ], "source": [ - "optimizer=tf.optimizers.Adam(learning_rate=1e-2)\n", + "optimizer=tf.keras.optimizers.Adam(learning_rate=1e-2)\n", "iaf_loss = tfp.vi.fit_surrogate_posterior(\n", " target_model.unnormalized_log_prob,\n", " iaf_surrogate_posterior,\n", @@ -830,7 +830,7 @@ " mean_field_scale # apply the block matrix transformation to the standard Normal distribution\n", " ]))\n", "\n", - "optimizer=tf.optimizers.Adam(learning_rate=1e-2)\n", + "optimizer=tf.keras.optimizers.Adam(learning_rate=1e-2)\n", "mean_field_loss = tfp.vi.fit_surrogate_posterior(\n", " target_model.unnormalized_log_prob,\n", " mean_field_surrogate_posterior,\n", diff --git a/tensorflow_probability/examples/logistic_regression.py b/tensorflow_probability/examples/logistic_regression.py index 095d362d34..c2171a3e8e 100644 --- a/tensorflow_probability/examples/logistic_regression.py +++ b/tensorflow_probability/examples/logistic_regression.py @@ -25,6 +25,7 @@ import numpy as np import tensorflow.compat.v2 as tf import tensorflow_probability as tfp +from tensorflow_probability.python.internal import tf_keras tf.enable_v2_behavior() @@ -132,7 +133,7 @@ def toy_logistic_data(num_examples, input_size=2, weights_prior_stddev=5.0): return random_weights, random_bias, np.float32(design_matrix), labels -class ToyDataSequence(tf.keras.utils.Sequence): +class ToyDataSequence(tf_keras.utils.Sequence): """Creates a sequence of labeled points from provided numpy arrays.""" def __init__(self, features, labels, batch_size): @@ -177,7 +178,7 @@ def create_model(num_samples, num_dimensions): # parameterized by logits from a single linear layer. We use the Flipout # Monte Carlo estimator for the layer: this enables lower variance # stochastic gradients than naive reparameterization. - input_layer = tf.keras.layers.Input(shape=num_dimensions) + input_layer = tf_keras.layers.Input(shape=num_dimensions) dense_layer = tfp.layers.DenseFlipout( units=1, activation='sigmoid', @@ -186,8 +187,8 @@ def create_model(num_samples, num_dimensions): kernel_divergence_fn=kl_divergence_function)(input_layer) # Model compilation. 
- model = tf.keras.Model(inputs=input_layer, outputs=dense_layer) - optimizer = tf.keras.optimizers.Adam(lr=FLAGS.learning_rate) + model = tf_keras.Model(inputs=input_layer, outputs=dense_layer) + optimizer = tf_keras.optimizers.Adam(lr=FLAGS.learning_rate) # We use the binary_crossentropy loss since this toy example contains # two labels. The Keras API will then automatically add the # Kullback-Leibler divergence (contained on the individual layers of diff --git a/tensorflow_probability/examples/models/bayesian_resnet.py b/tensorflow_probability/examples/models/bayesian_resnet.py index 1ad4f9be24..8a2c16e824 100644 --- a/tensorflow_probability/examples/models/bayesian_resnet.py +++ b/tensorflow_probability/examples/models/bayesian_resnet.py @@ -16,6 +16,7 @@ import tensorflow.compat.v1 as tf import tensorflow_probability as tfp +from tensorflow_probability.python.internal import tf_keras def bayesian_resnet(input_shape, @@ -42,7 +43,7 @@ def bayesian_resnet(input_shape, i.e. log_var <= log(kernel_posterior_scale_constraint). Returns: - tf.keras.Model. + tf_keras.Model. """ filters = [64, 128, 256, 512] @@ -59,7 +60,7 @@ def _untransformed_scale_constraint(t): stddev=kernel_posterior_scale_stddev), untransformed_scale_constraint=_untransformed_scale_constraint) - image = tf.keras.layers.Input(shape=input_shape, dtype='float32') + image = tf_keras.layers.Input(shape=input_shape, dtype='float32') x = tfp.layers.Convolution2DFlipout( 64, 3, @@ -75,23 +76,23 @@ def _untransformed_scale_constraint(t): strides[i], kernel_posterior_fn) - x = tf.keras.layers.BatchNormalization()(x) - x = tf.keras.layers.Activation('relu')(x) - x = tf.keras.layers.AveragePooling2D(4, 1)(x) - x = tf.keras.layers.Flatten()(x) + x = tf_keras.layers.BatchNormalization()(x) + x = tf_keras.layers.Activation('relu')(x) + x = tf_keras.layers.AveragePooling2D(4, 1)(x) + x = tf_keras.layers.Flatten()(x) x = tfp.layers.DenseFlipout( num_classes, kernel_posterior_fn=kernel_posterior_fn)(x) - model = tf.keras.Model(inputs=image, outputs=x, name='resnet18') + model = tf_keras.Model(inputs=image, outputs=x, name='resnet18') return model def _resnet_block(x, filters, kernel, stride, kernel_posterior_fn): """Network block for ResNet.""" - x = tf.keras.layers.BatchNormalization()(x) - x = tf.keras.layers.Activation('relu')(x) + x = tf_keras.layers.BatchNormalization()(x) + x = tf_keras.layers.Activation('relu')(x) if stride != 1 or filters != x.shape[1]: shortcut = _projection_shortcut(x, filters, stride, kernel_posterior_fn) @@ -104,8 +105,8 @@ def _resnet_block(x, filters, kernel, stride, kernel_posterior_fn): strides=stride, padding='same', kernel_posterior_fn=kernel_posterior_fn)(x) - x = tf.keras.layers.BatchNormalization()(x) - x = tf.keras.layers.Activation('relu')(x) + x = tf_keras.layers.BatchNormalization()(x) + x = tf_keras.layers.Activation('relu')(x) x = tfp.layers.Convolution2DFlipout( filters, @@ -113,7 +114,7 @@ def _resnet_block(x, filters, kernel, stride, kernel_posterior_fn): strides=1, padding='same', kernel_posterior_fn=kernel_posterior_fn)(x) - x = tf.keras.layers.add([x, shortcut]) + x = tf_keras.layers.add([x, shortcut]) return x diff --git a/tensorflow_probability/examples/models/bayesian_vgg.py b/tensorflow_probability/examples/models/bayesian_vgg.py index 339e4e6015..f3a8826e9e 100644 --- a/tensorflow_probability/examples/models/bayesian_vgg.py +++ b/tensorflow_probability/examples/models/bayesian_vgg.py @@ -16,6 +16,7 @@ import tensorflow.compat.v1 as tf import tensorflow_probability as tfp +from 
tensorflow_probability.python.internal import tf_keras def bayesian_vgg(input_shape, @@ -42,7 +43,7 @@ def bayesian_vgg(input_shape, i.e. log_var <= log(kernel_posterior_scale_constraint). Returns: - tf.keras.Model. + tf_keras.Model. """ filters = [64, 128, 256, 512, 512] @@ -59,7 +60,7 @@ def _untransformed_scale_constraint(t): stddev=kernel_posterior_scale_stddev), untransformed_scale_constraint=_untransformed_scale_constraint) - image = tf.keras.layers.Input(shape=input_shape, dtype='float32') + image = tf_keras.layers.Input(shape=input_shape, dtype='float32') x = image for i in range(len(kernels)): @@ -70,11 +71,11 @@ def _untransformed_scale_constraint(t): strides[i], kernel_posterior_fn) - x = tf.keras.layers.Flatten()(x) + x = tf_keras.layers.Flatten()(x) x = tfp.layers.DenseFlipout( num_classes, kernel_posterior_fn=kernel_posterior_fn)(x) - model = tf.keras.Model(inputs=image, outputs=x, name='vgg16') + model = tf_keras.Model(inputs=image, outputs=x, name='vgg16') return model @@ -85,17 +86,17 @@ def _vggconv_block(x, filters, kernel, stride, kernel_posterior_fn): kernel, padding='same', kernel_posterior_fn=kernel_posterior_fn)(x) - out = tf.keras.layers.BatchNormalization()(out) - out = tf.keras.layers.Activation('relu')(out) + out = tf_keras.layers.BatchNormalization()(out) + out = tf_keras.layers.Activation('relu')(out) out = tfp.layers.Convolution2DFlipout( filters, kernel, padding='same', kernel_posterior_fn=kernel_posterior_fn)(out) - out = tf.keras.layers.BatchNormalization()(out) - out = tf.keras.layers.Activation('relu')(out) + out = tf_keras.layers.BatchNormalization()(out) + out = tf_keras.layers.Activation('relu')(out) - out = tf.keras.layers.MaxPooling2D( + out = tf_keras.layers.MaxPooling2D( pool_size=(2, 2), strides=stride)(out) return out diff --git a/tensorflow_probability/examples/vq_vae.py b/tensorflow_probability/examples/vq_vae.py index 2bb73e6bb6..d2b4e08f35 100644 --- a/tensorflow_probability/examples/vq_vae.py +++ b/tensorflow_probability/examples/vq_vae.py @@ -43,6 +43,7 @@ import tensorflow.compat.v1 as tf from tensorflow_probability import distributions as tfd +from tensorflow_probability.python.internal import tf_keras from tensorflow.contrib.learn.python.learn.datasets import mnist from tensorflow.python.training import moving_averages @@ -174,17 +175,17 @@ def make_encoder(base_depth, activation, latent_size, code_size): `[..., latent_size, code_size]`. """ conv = functools.partial( - tf.keras.layers.Conv2D, padding="SAME", activation=activation) + tf_keras.layers.Conv2D, padding="SAME", activation=activation) - encoder_net = tf.keras.Sequential([ + encoder_net = tf_keras.Sequential([ conv(base_depth, 5, 1), conv(base_depth, 5, 2), conv(2 * base_depth, 5, 1), conv(2 * base_depth, 5, 2), conv(4 * latent_size, 7, padding="VALID"), - tf.keras.layers.Flatten(), - tf.keras.layers.Dense(latent_size * code_size, activation=None), - tf.keras.layers.Reshape([latent_size, code_size]) + tf_keras.layers.Flatten(), + tf_keras.layers.Dense(latent_size * code_size, activation=None), + tf_keras.layers.Reshape([latent_size, code_size]) ]) def encoder(images): @@ -219,11 +220,11 @@ def make_decoder(base_depth, activation, input_size, output_shape): `tfd.Distribution` instance over images. 
""" deconv = functools.partial( - tf.keras.layers.Conv2DTranspose, padding="SAME", activation=activation) + tf_keras.layers.Conv2DTranspose, padding="SAME", activation=activation) conv = functools.partial( - tf.keras.layers.Conv2D, padding="SAME", activation=activation) - decoder_net = tf.keras.Sequential([ - tf.keras.layers.Reshape((1, 1, input_size)), + tf_keras.layers.Conv2D, padding="SAME", activation=activation) + decoder_net = tf_keras.Sequential([ + tf_keras.layers.Reshape((1, 1, input_size)), deconv(2 * base_depth, 7, padding="VALID"), deconv(2 * base_depth, 5), deconv(2 * base_depth, 5, 2), @@ -231,7 +232,7 @@ def make_decoder(base_depth, activation, input_size, output_shape): deconv(base_depth, 5, 2), deconv(base_depth, 5), conv(output_shape[-1], 5, activation=None), - tf.keras.layers.Reshape(output_shape), + tf_keras.layers.Reshape(output_shape), ]) def decoder(codes): diff --git a/tensorflow_probability/python/__init__.py b/tensorflow_probability/python/__init__.py index a313019fe0..f4742348ad 100644 --- a/tensorflow_probability/python/__init__.py +++ b/tensorflow_probability/python/__init__.py @@ -51,7 +51,7 @@ def _validate_tf_environment(package): # # Update this whenever we need to depend on a newer TensorFlow release. # - required_tensorflow_version = '2.11' + required_tensorflow_version = '2.14' # required_tensorflow_version = '1.15' # Needed internally -- DisableOnExport if (distutils.version.LooseVersion(tf.__version__) < diff --git a/tensorflow_probability/python/bijectors/BUILD b/tensorflow_probability/python/bijectors/BUILD index f049ceca71..14b1dc619f 100644 --- a/tensorflow_probability/python/bijectors/BUILD +++ b/tensorflow_probability/python/bijectors/BUILD @@ -260,6 +260,7 @@ multi_substrate_py_library( deps = [ ":bijector", # tensorflow dep, + "//tensorflow_probability/python/internal:tf_keras", ], ) @@ -275,6 +276,7 @@ multi_substrate_py_library( # numpy dep, # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", + "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:prefer_static", "//tensorflow_probability/python/internal:tensorshape_util", ], @@ -300,6 +302,7 @@ multi_substrate_py_library( deps = [ ":bijector", # tensorflow dep, + "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:distribution_util", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:nest_util", @@ -314,6 +317,7 @@ multi_substrate_py_library( deps = [ ":bijector", # tensorflow dep, + "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:nest_util", ], ) @@ -474,6 +478,7 @@ py_library( ":tanh", ":transpose", # tensorflow dep, + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/util", ], ) @@ -557,6 +562,8 @@ multi_substrate_py_library( srcs = ["invert.py"], deps = [ ":bijector", + "//tensorflow_probability/python/internal:auto_composite_tensor", + "//tensorflow_probability/python/internal:parameter_properties", ], ) @@ -607,6 +614,7 @@ multi_substrate_py_library( # tensorflow dep, "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:tensorshape_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/math:numeric", ], ) @@ -740,6 +748,7 @@ multi_substrate_py_library( "//tensorflow_probability/python/internal:parameter_properties", 
"//tensorflow_probability/python/internal:tensor_util", "//tensorflow_probability/python/internal:tensorshape_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/util", ], ) @@ -765,6 +774,7 @@ multi_substrate_py_library( # numpy dep, # tensorflow dep, "//tensorflow_probability/python/internal:tensorshape_util", + "//tensorflow_probability/python/internal:tf_keras", ], ) @@ -820,6 +830,7 @@ multi_substrate_py_library( ":softplus", ":transform_diagonal", # tensorflow dep, + "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:tensor_util", ], @@ -1102,6 +1113,7 @@ multi_substrate_py_test( "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:tensor_util", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", ], ) @@ -1201,6 +1213,7 @@ py_test( "//tensorflow_probability/python/distributions:normal", "//tensorflow_probability/python/distributions:transformed_distribution", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", ], ) @@ -1265,6 +1278,7 @@ multi_substrate_py_test( "//tensorflow_probability/python/distributions:sample", "//tensorflow_probability/python/distributions:transformed_distribution", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", ], ) @@ -1693,6 +1707,7 @@ multi_substrate_py_test( "//tensorflow_probability/python/distributions:transformed_distribution", "//tensorflow_probability/python/internal:tensorshape_util", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/math:gradient", ], ) @@ -1790,6 +1805,7 @@ multi_substrate_py_test( # numpy dep, # tensorflow dep, "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", ], ) @@ -1853,6 +1869,7 @@ multi_substrate_py_test( "//tensorflow_probability/python/distributions:transformed_distribution", "//tensorflow_probability/python/internal:tensorshape_util", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", ], ) @@ -1873,6 +1890,7 @@ multi_substrate_py_test( "//tensorflow_probability/python/internal:hypothesis_testlib", "//tensorflow_probability/python/internal:prefer_static", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", ], ) diff --git a/tensorflow_probability/python/bijectors/batch_normalization.py b/tensorflow_probability/python/bijectors/batch_normalization.py index 1c7619880c..74537c9c6c 100644 --- a/tensorflow_probability/python/bijectors/batch_normalization.py +++ b/tensorflow_probability/python/bijectors/batch_normalization.py @@ -16,10 +16,10 @@ # Dependency imports -import tensorflow.compat.v1 as tf1 import tensorflow.compat.v2 as tf from tensorflow_probability.python.bijectors import bijector +from tensorflow_probability.python.internal import tf_keras __all__ = [ @@ -128,7 +128,7 @@ def __init__(self, Args: batchnorm_layer: `tf.layers.BatchNormalization` layer object. If `None`, - defaults to a `tf.keras.layers.BatchNormalization` with + defaults to a `tf_keras.layers.BatchNormalization` with `gamma_constraint=tf.nn.relu(x) + 1e-6)`. This ensures positivity of the scale variable. 
@@ -146,7 +146,7 @@ def __init__(self, with tf.name_scope(name) as name: # Scale must be positive. g_constraint = lambda x: tf.nn.relu(x) + 1e-6 - self.batchnorm = batchnorm_layer or tf.keras.layers.BatchNormalization( + self.batchnorm = batchnorm_layer or tf_keras.layers.BatchNormalization( gamma_constraint=g_constraint) self._validate_bn_layer(self.batchnorm) self._training = training @@ -174,11 +174,11 @@ def _validate_bn_layer(self, layer): `tf.layers.BatchNormalization`, or if `batchnorm_layer.renorm=True` or if `batchnorm_layer.virtual_batch_size` is specified. """ - if (not isinstance(layer, tf.keras.layers.BatchNormalization) and - not isinstance(layer, tf1.layers.BatchNormalization)): + if (not isinstance(layer, tf_keras.layers.BatchNormalization) and + not isinstance(layer, tf_keras.tf1_layers.BatchNormalization)): raise ValueError( 'batchnorm_layer must be an instance of ' - '`tf.keras.layers.BatchNormalization` or ' + '`tf_keras.layers.BatchNormalization` or ' '`tf.compat.v1.layers.BatchNormalization`. Got {}'.format( type(layer))) if layer.renorm: diff --git a/tensorflow_probability/python/bijectors/batch_normalization_test.py b/tensorflow_probability/python/bijectors/batch_normalization_test.py index f5b3a50788..7a1604380d 100644 --- a/tensorflow_probability/python/bijectors/batch_normalization_test.py +++ b/tensorflow_probability/python/bijectors/batch_normalization_test.py @@ -29,6 +29,7 @@ from tensorflow_probability.python.distributions import sample from tensorflow_probability.python.distributions import transformed_distribution from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras @test_util.test_all_tf_execution_regimes @@ -68,7 +69,7 @@ def testForwardInverse(self, input_shape, event_dims, training): x_, input_shape if 0 in event_dims else (None,) + input_shape[1:]) # When training, memorize the exact mean of the last # minibatch that it normalized (instead of moving average assignment). - layer = tf.keras.layers.BatchNormalization( + layer = tf_keras.layers.BatchNormalization( axis=event_dims, momentum=0., epsilon=0.) batch_norm = batch_normalization.BatchNormalization( batchnorm_layer=layer, training=training) @@ -140,13 +141,13 @@ def testForwardInverse(self, input_shape, event_dims, training): @parameterized.named_parameters( ("2d_event_ndims_v1", - (10, 4), [-1], False, tf1.layers.BatchNormalization), + (10, 4), [-1], False, tf_keras.tf1_layers.BatchNormalization), ("1d_event_ndims_v1", - 2, [-1], False, tf1.layers.BatchNormalization), + 2, [-1], False, tf_keras.tf1_layers.BatchNormalization), ("2d_event_ndims_keras", - (10, 4), [-1], False, tf.keras.layers.BatchNormalization), + (10, 4), [-1], False, tf_keras.layers.BatchNormalization), ("1d_event_ndims_keras", - 2, [-1], False, tf.keras.layers.BatchNormalization)) + 2, [-1], False, tf_keras.layers.BatchNormalization)) def testLogProb(self, event_shape, event_dims, training, layer_cls): training = tf1.placeholder_with_default(training, (), "training") layer = layer_cls(axis=event_dims, epsilon=0.) 
@@ -173,8 +174,8 @@ def testLogProb(self, event_shape, event_dims, training, layer_cls): self.assertAllClose(base_log_prob_, dist_log_prob_) @parameterized.named_parameters( - ("v1", tf1.layers.BatchNormalization), - ("keras", tf.keras.layers.BatchNormalization)) + ("v1", tf_keras.tf1_layers.BatchNormalization), + ("keras", tf_keras.layers.BatchNormalization)) def testMutuallyConsistent(self, layer_cls): # BatchNorm bijector is only mutually consistent when training=False. dims = 4 @@ -195,8 +196,8 @@ def testMutuallyConsistent(self, layer_cls): rtol=0.02) @parameterized.named_parameters( - ("v1", tf1.layers.BatchNormalization), - ("keras", tf.keras.layers.BatchNormalization)) + ("v1", tf_keras.tf1_layers.BatchNormalization), + ("keras", tf_keras.layers.BatchNormalization)) def testInvertMutuallyConsistent(self, layer_cls): # BatchNorm bijector is only mutually consistent when training=False. dims = 4 @@ -219,7 +220,7 @@ def testInvertMutuallyConsistent(self, layer_cls): def testWithKeras(self): # NOTE: Keras throws an error below if we use - # tf1.layers.BatchNormalization() here. + # tf_keras.tf1_layers.BatchNormalization() here. layer = None dist = transformed_distribution.TransformedDistribution( @@ -227,9 +228,9 @@ def testWithKeras(self): bijector=batch_normalization.BatchNormalization(batchnorm_layer=layer), validate_args=True) - x_ = tf.keras.Input(shape=(1,)) + x_ = tf_keras.Input(shape=(1,)) log_prob_ = dist.log_prob(x_) - model = tf.keras.Model(x_, log_prob_) + model = tf_keras.Model(x_, log_prob_) model.compile(optimizer="adam", loss=lambda _, log_prob: -log_prob) diff --git a/tensorflow_probability/python/bijectors/bijector_test.py b/tensorflow_probability/python/bijectors/bijector_test.py index 92f7e3d1fb..2d87e67daa 100644 --- a/tensorflow_probability/python/bijectors/bijector_test.py +++ b/tensorflow_probability/python/bijectors/bijector_test.py @@ -46,6 +46,7 @@ from tensorflow_probability.python.internal import tensor_util from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras JAX_MODE = False @@ -978,7 +979,7 @@ def testJacobianRespectsCache(self, keras): bijector = InverseOnlyBijector(scale=2.) y = tf.constant(10.) if keras: - y = tf.keras.layers.Input(shape=(), dtype=tf.float32, tensor=y) + y = tf_keras.layers.Input(shape=(), dtype=tf.float32, tensor=y) x = bijector.inverse(y) # Forward computation should work here because it should look up # `y` in the cache and call `inverse_log_det_jacobian`. 
diff --git a/tensorflow_probability/python/bijectors/bijector_test_util.py b/tensorflow_probability/python/bijectors/bijector_test_util.py index 5052f0733a..7243d64c62 100644 --- a/tensorflow_probability/python/bijectors/bijector_test_util.py +++ b/tensorflow_probability/python/bijectors/bijector_test_util.py @@ -28,6 +28,8 @@ from tensorflow_probability.python.internal import test_util as tfp_test_util from tensorflow_probability.python.math.gradient import batch_jacobian +JAX_MODE = False + def assert_finite(array): if not np.isfinite(array).all(): @@ -368,3 +370,28 @@ def _inverse(self, y): def _parameter_properties(cls, dtype): return dict() + +class PytreeShift(bijector_lib.Bijector): + """Mimics a user-defined bijector that is registered as a Pytree.""" + + def __init__(self, shift): + parameters = dict(locals()) + self.shift = shift + super(PytreeShift, self).__init__( + validate_args=True, + forward_min_event_ndims=0, + parameters=parameters, + name='pytree_shift') + + def _forward(self, x): + return x + self.shift + + def _inverse(self, y): + return y - self.shift + +if JAX_MODE: + from jax import tree_util # pylint: disable=g-import-not-at-top, g-bad-import-order + tree_util.register_pytree_node( + PytreeShift, + flatten_func=lambda v: (v.shift, None), + unflatten_func=lambda _, c: PytreeShift(c)) diff --git a/tensorflow_probability/python/bijectors/blockwise.py b/tensorflow_probability/python/bijectors/blockwise.py index 6870acf0a7..cfd81b635d 100644 --- a/tensorflow_probability/python/bijectors/blockwise.py +++ b/tensorflow_probability/python/bijectors/blockwise.py @@ -24,6 +24,7 @@ from tensorflow_probability.python.bijectors import joint_map from tensorflow_probability.python.bijectors import split from tensorflow_probability.python.internal import assert_util +from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import tensorshape_util @@ -278,7 +279,7 @@ def __new__(cls, *args, **kwargs): raise TypeError( '`Blockwise.__new__()` is missing argument `bijectors`.') - if not all(isinstance(b, tf.__internal__.CompositeTensor) + if not all(auto_composite_tensor.is_composite_tensor(b) for b in bijectors): return _Blockwise(*args, **kwargs) return super(Blockwise, cls).__new__(cls) diff --git a/tensorflow_probability/python/bijectors/chain.py b/tensorflow_probability/python/bijectors/chain.py index 8f9e32a810..3158d978b7 100644 --- a/tensorflow_probability/python/bijectors/chain.py +++ b/tensorflow_probability/python/bijectors/chain.py @@ -18,6 +18,7 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.bijectors import bijector as bijector_lib from tensorflow_probability.python.bijectors import composition +from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import prefer_static as ps @@ -161,7 +162,7 @@ def __new__(cls, *args, **kwargs): bijectors = kwargs.get('bijectors') if bijectors is not None: - if not all(isinstance(b, tf.__internal__.CompositeTensor) + if not all(auto_composite_tensor.is_composite_tensor(b) for b in bijectors): return _Chain(*args, **kwargs) return super(Chain, cls).__new__(cls) diff --git a/tensorflow_probability/python/bijectors/cumsum.py b/tensorflow_probability/python/bijectors/cumsum.py index ae7f886a70..8736dc3779 100644 --- a/tensorflow_probability/python/bijectors/cumsum.py +++ b/tensorflow_probability/python/bijectors/cumsum.py @@ 
-117,5 +117,5 @@ def _forward_log_det_jacobian(self, x): return tf.constant(0., x.dtype) @property - def _compposite_tensor_shape_params(self): + def _composite_tensor_shape_params(self): return ('axis',) diff --git a/tensorflow_probability/python/bijectors/fill_scale_tril.py b/tensorflow_probability/python/bijectors/fill_scale_tril.py index daf5f51527..205254a1e1 100644 --- a/tensorflow_probability/python/bijectors/fill_scale_tril.py +++ b/tensorflow_probability/python/bijectors/fill_scale_tril.py @@ -20,6 +20,7 @@ from tensorflow_probability.python.bijectors import shift from tensorflow_probability.python.bijectors import softplus from tensorflow_probability.python.bijectors import transform_diagonal +from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import tensor_util @@ -28,6 +29,8 @@ 'FillScaleTriL', ] +JAX_MODE = False + class FillScaleTriL(chain.Chain): """Transforms unconstrained vectors to TriL matrices with positive diagonal. @@ -103,15 +106,18 @@ def __init__(self, Raises: TypeError, if `diag_bijector` is not an instance of - `tf.__internal__.CompositeTensor`. + `tf.__internal__.CompositeTensor` (or a pytree in JAX mode). """ parameters = dict(locals()) with tf.name_scope(name) as name: if diag_bijector is None: diag_bijector = softplus.Softplus(validate_args=validate_args) - if not isinstance(diag_bijector, tf.__internal__.CompositeTensor): - raise TypeError('`diag_bijector` must be an instance of ' - '`tf.__internal__.CompositeTensor`.') + if not auto_composite_tensor.is_composite_tensor(diag_bijector): + if JAX_MODE: + raise TypeError('`diag_bijector` must be a pytree.') + else: + raise TypeError('`diag_bijector` must be an instance of ' + '`tf.__internal__.CompositeTensor`.') if diag_shift is not None: dtype = dtype_util.common_dtype([diag_bijector, diag_shift], tf.float32) diff --git a/tensorflow_probability/python/bijectors/generalized_pareto.py b/tensorflow_probability/python/bijectors/generalized_pareto.py index eafaf01ca2..822645104e 100644 --- a/tensorflow_probability/python/bijectors/generalized_pareto.py +++ b/tensorflow_probability/python/bijectors/generalized_pareto.py @@ -101,38 +101,64 @@ def scale(self): def concentration(self): return self._concentration - def _negative_concentration_bijector(self): + def _classify_conc(self): + scale_div_conc = self.scale / self.concentration + # Guard against overflow when scale >> concentration + use_negative = (self._concentration < 0.) & tf.math.is_finite( + scale_div_conc + ) + return use_negative, tf.where( + use_negative, scale_div_conc, tf.ones_like(scale_div_conc) + ) + + def _negative_concentration_bijector(self, scale_div_conc=None): # Constructed dynamically so that `loc + scale / concentration` is # tape-safe. 
+ if scale_div_conc is None: + scale_div_conc = self.scale / self.concentration loc = tf.convert_to_tensor(self.loc) - high = loc + tf.math.abs(self.scale / self.concentration) + high = loc + tf.math.abs(scale_div_conc) return sigmoid_bijector.Sigmoid( low=loc, high=high, validate_args=self.validate_args) def _forward(self, x): - return tf.where(self._concentration < 0., - self._negative_concentration_bijector().forward(x), - self._non_negative_concentration_bijector.forward(x)) + use_negative, scale_div_conc = self._classify_conc() + return tf.where( + use_negative, + self._negative_concentration_bijector(scale_div_conc).forward(x), + self._non_negative_concentration_bijector.forward(x), + ) def _inverse(self, y): - return tf.where(self._concentration < 0., - self._negative_concentration_bijector().inverse(y), - self._non_negative_concentration_bijector.inverse(y)) + use_negative, scale_div_conc = self._classify_conc() + return tf.where( + use_negative, + self._negative_concentration_bijector(scale_div_conc).inverse(y), + self._non_negative_concentration_bijector.inverse(y), + ) def _forward_log_det_jacobian(self, x): event_ndims = self.forward_min_event_ndims + use_negative, scale_div_conc = self._classify_conc() return tf.where( - self._concentration < 0., - self._negative_concentration_bijector().forward_log_det_jacobian( - x, event_ndims=event_ndims), + use_negative, + self._negative_concentration_bijector( + scale_div_conc + ).forward_log_det_jacobian(x, event_ndims=event_ndims), self._non_negative_concentration_bijector.forward_log_det_jacobian( - x, event_ndims=event_ndims)) + x, event_ndims=event_ndims + ), + ) def _inverse_log_det_jacobian(self, y): event_ndims = self.inverse_min_event_ndims + use_negative, scale_div_conc = self._classify_conc() return tf.where( - self._concentration < 0., - self._negative_concentration_bijector().inverse_log_det_jacobian( - y, event_ndims=event_ndims), + use_negative, + self._negative_concentration_bijector( + scale_div_conc + ).inverse_log_det_jacobian(y, event_ndims=event_ndims), self._non_negative_concentration_bijector.inverse_log_det_jacobian( - y, event_ndims=event_ndims)) + y, event_ndims=event_ndims + ), + ) diff --git a/tensorflow_probability/python/bijectors/generalized_pareto_test.py b/tensorflow_probability/python/bijectors/generalized_pareto_test.py index 936a272f20..44c23c49a3 100644 --- a/tensorflow_probability/python/bijectors/generalized_pareto_test.py +++ b/tensorflow_probability/python/bijectors/generalized_pareto_test.py @@ -42,6 +42,15 @@ def testScalarCongruencyNegativeConcentration(self): eval_func=self.evaluate, rtol=.1) + def testScalarCongruencyTinyNegativeConcentration(self): + bijector_test_util.assert_scalar_congruency( + generalized_pareto.GeneralizedPareto( + loc=1., scale=8., concentration=-2e-38, validate_args=True), + lower_x=-7., + upper_x=7., + eval_func=self.evaluate, + rtol=.2) + def testBijectiveAndFinitePositiveConcentration(self): loc = 5. 
x = np.linspace(-10., 20., 20).astype(np.float32) diff --git a/tensorflow_probability/python/bijectors/glow.py b/tensorflow_probability/python/bijectors/glow.py index a2b4d9727e..bdcd5cde42 100644 --- a/tensorflow_probability/python/bijectors/glow.py +++ b/tensorflow_probability/python/bijectors/glow.py @@ -34,10 +34,11 @@ from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import prefer_static from tensorflow_probability.python.internal import tensorshape_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.util.deferred_tensor import TransformedVariable from tensorflow_probability.python.util.seed_stream import SeedStream -tfk = tf.keras +tfk = tf_keras tfkl = tfk.layers __all__ = [ @@ -859,15 +860,15 @@ def __init__(self, input_shape, num_hidden=400, kernel_shape=3): conv_last = functools.partial( tfkl.Conv2D, padding='same', - kernel_initializer=tf.initializers.zeros(), - bias_initializer=tf.initializers.zeros()) + kernel_initializer=tf_keras.initializers.zeros(), + bias_initializer=tf_keras.initializers.zeros()) super(GlowDefaultNetwork, self).__init__([ tfkl.Input(shape=input_shape), tfkl.Conv2D(num_hidden, kernel_shape, padding='same', - kernel_initializer=tf.initializers.he_normal(), + kernel_initializer=tf_keras.initializers.he_normal(), activation='relu'), tfkl.Conv2D(num_hidden, 1, padding='same', - kernel_initializer=tf.initializers.he_normal(), + kernel_initializer=tf_keras.initializers.he_normal(), activation='relu'), conv_last(this_nchan, kernel_shape) ]) @@ -886,8 +887,8 @@ def __init__(self, input_shape, output_chan, kernel_shape=3): conv = functools.partial( tfkl.Conv2D, padding='same', - kernel_initializer=tf.initializers.zeros(), - bias_initializer=tf.initializers.zeros()) + kernel_initializer=tf_keras.initializers.zeros(), + bias_initializer=tf_keras.initializers.zeros()) super(GlowDefaultExitNetwork, self).__init__([ tfkl.Input(input_shape), diff --git a/tensorflow_probability/python/bijectors/glow_test.py b/tensorflow_probability/python/bijectors/glow_test.py index 735d365ce7..37903ea362 100644 --- a/tensorflow_probability/python/bijectors/glow_test.py +++ b/tensorflow_probability/python/bijectors/glow_test.py @@ -29,6 +29,7 @@ from tensorflow_probability.python.distributions import independent from tensorflow_probability.python.distributions import normal from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.math.gradient import batch_jacobian @@ -331,14 +332,14 @@ def testDtypes(self): def float64_net(input_shape): input_nchan = input_shape[-1] - return tf.keras.Sequential([ - tf.keras.layers.Input(input_shape, dtype=tf.float64), - tf.keras.layers.Conv2D( + return tf_keras.Sequential([ + tf_keras.layers.Input(input_shape, dtype=tf.float64), + tf_keras.layers.Conv2D( 2 * input_nchan, 3, padding='same', dtype=tf.float64)]) def float64_exit(input_shape, output_chan): - return tf.keras.Sequential([ - tf.keras.layers.Input(input_shape, dtype=tf.float64), - tf.keras.layers.Conv2D( + return tf_keras.Sequential([ + tf_keras.layers.Input(input_shape, dtype=tf.float64), + tf_keras.layers.Conv2D( 2*output_chan, 3, padding='same', dtype=tf.float64)]) float64_bijection = glow.Glow( @@ -359,15 +360,15 @@ def testBijectorFn(self): ims = self._make_images() def shiftfn(input_shape): input_nchan = input_shape[-1] - return tf.keras.Sequential([ - tf.keras.layers.Input(input_shape), 
- tf.keras.layers.Conv2D( + return tf_keras.Sequential([ + tf_keras.layers.Input(input_shape), + tf_keras.layers.Conv2D( input_nchan, 3, padding='same')]) def shiftexitfn(input_shape, output_chan): - return tf.keras.Sequential([ - tf.keras.layers.Input(input_shape), - tf.keras.layers.Conv2D( + return tf_keras.Sequential([ + tf_keras.layers.Input(input_shape), + tf_keras.layers.Conv2D( output_chan, 3, padding='same')]) shiftonlyglow = glow.Glow( diff --git a/tensorflow_probability/python/bijectors/hypothesis_testlib.py b/tensorflow_probability/python/bijectors/hypothesis_testlib.py index 7cd644c43c..d0b947f06e 100644 --- a/tensorflow_probability/python/bijectors/hypothesis_testlib.py +++ b/tensorflow_probability/python/bijectors/hypothesis_testlib.py @@ -565,7 +565,9 @@ def generalized_pareto_constraint(loc, scale, conc): def constrain(x): conc_ = tf.convert_to_tensor(conc) loc_ = tf.convert_to_tensor(loc) - return tf.where(conc_ >= 0., + # When conc is very small but negative, the maximum of the support is + # infinite, so we treat it as if it were non-negative. + return tf.where((conc_ >= 0.) | ~tf.math.is_finite(scale / conc_), tf.math.softplus(x) + loc_, loc_ - tf.math.sigmoid(x) * scale / conc_) return constrain diff --git a/tensorflow_probability/python/bijectors/invert.py b/tensorflow_probability/python/bijectors/invert.py index 353742ad70..f061a66d2d 100644 --- a/tensorflow_probability/python/bijectors/invert.py +++ b/tensorflow_probability/python/bijectors/invert.py @@ -17,6 +17,7 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.bijectors import bijector as bijector_lib +from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import parameter_properties __all__ = [ @@ -160,7 +161,7 @@ def __new__(cls, *args, **kwargs): else: raise TypeError('`Invert.__new__()` is missing argument `bijector`.') - if not isinstance(bijector, tf.__internal__.CompositeTensor): + if not auto_composite_tensor.is_composite_tensor(bijector): return _Invert(*args, **kwargs) return super(Invert, cls).__new__(cls) diff --git a/tensorflow_probability/python/bijectors/joint_map.py b/tensorflow_probability/python/bijectors/joint_map.py index 8b1d80b00a..a54ef5156b 100644 --- a/tensorflow_probability/python/bijectors/joint_map.py +++ b/tensorflow_probability/python/bijectors/joint_map.py @@ -17,6 +17,7 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.bijectors import bijector as bijector_lib from tensorflow_probability.python.bijectors import composition +from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import @@ -124,7 +125,7 @@ def __new__(cls, *args, **kwargs): else: bijectors = kwargs.get('bijectors') if bijectors is not None: - if not all(isinstance(b, tf.__internal__.CompositeTensor) + if not all(auto_composite_tensor.is_composite_tensor(b) for b in tf.nest.flatten(bijectors)): return _JointMap(*args, **kwargs) return super(JointMap, cls).__new__(cls) diff --git a/tensorflow_probability/python/bijectors/masked_autoregressive.py b/tensorflow_probability/python/bijectors/masked_autoregressive.py index c83cacb48b..7c1fb5b60d 100644 --- a/tensorflow_probability/python/bijectors/masked_autoregressive.py +++ b/tensorflow_probability/python/bijectors/masked_autoregressive.py @@ -27,6 +27,7 @@ from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import 
prefer_static as ps from tensorflow_probability.python.internal import tensorshape_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.math.numeric import clip_by_value_preserve_gradient from tensorflow.python.util import deprecation # pylint: disable=g-direct-tensorflow-import @@ -87,7 +88,7 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): is possible that this architecture is suboptimal for your task. To build alternative networks, either change the arguments to `tfp.bijectors.AutoregressiveNetwork` or use some other architecture, e.g., - using `tf.keras.layers`. + using `tf_keras.layers`. Warning: no attempt is made to validate that the `shift_and_log_scale_fn` enforces the 'autoregressive property'. @@ -215,7 +216,7 @@ def inverse(y): track variables used inside `shift_and_log_scale_fn` or `bijector_fn`. To get `tfb.MaskedAutoregressiveFlow` to track such variables, either: - 1. Replace the Python function with a `tf.Module`, `tf.keras.Layer`, + 1. Replace the Python function with a `tf.Module`, `tf_keras.Layer`, or other callable object through which `tf.Module` can find variables. 2. Or, add a reference to the variables to the `tfb.MaskedAutoregressiveFlow` @@ -482,7 +483,7 @@ def masked_initializer(shape, dtype=None, partition_info=None): return mask * kernel_initializer(shape, dtype, partition_info) with tf.name_scope(name or 'masked_dense'): - layer = tf1.layers.Dense( + layer = tf_keras.tf1_layers.Dense( units, kernel_initializer=masked_initializer, kernel_constraint=lambda x: mask * x, @@ -621,7 +622,7 @@ def _fn(x): return tf1.make_template(name, _fn) -class AutoregressiveNetwork(tf.keras.layers.Layer): +class AutoregressiveNetwork(tf_keras.layers.Layer): r"""Masked Autoencoder for Distribution Estimation [Germain et al. (2015)][1]. A `AutoregressiveNetwork` takes as input a Tensor of shape `[..., event_size]` @@ -664,7 +665,7 @@ class AutoregressiveNetwork(tf.keras.layers.Layer): log_prob_ = distribution.log_prob(x_) model = tfk.Model(x_, log_prob_) - model.compile(optimizer=tf.optimizers.Adam(), + model.compile(optimizer=tf_keras.optimizers.Adam(), loss=lambda _, log_prob: -log_prob) batch_size = 25 @@ -718,7 +719,7 @@ class AutoregressiveNetwork(tf.keras.layers.Layer): x_, bijector_kwargs={'conditional_input': c_}) model = tfk.Model([x_, c_], log_prob_) - model.compile(optimizer=tf.optimizers.Adam(learning_rate=0.1), + model.compile(optimizer=tf_keras.optimizers.Adam(learning_rate=0.1), loss=lambda _, log_prob: -log_prob) batch_size = 25 @@ -780,7 +781,7 @@ class AutoregressiveNetwork(tf.keras.layers.Layer): log_prob_ = distribution.log_prob(x_) model = tfk.Model(x_, log_prob_) - model.compile(optimizer=tf.optimizers.Adam(), + model.compile(optimizer=tf_keras.optimizers.Adam(), loss=lambda _, log_prob: -log_prob) batch_size = 10 @@ -838,7 +839,7 @@ class AutoregressiveNetwork(tf.keras.layers.Layer): log_prob_ = distribution.log_prob(x_) model = tfk.Model(x_, log_prob_) - model.compile(optimizer=tf.optimizers.Adam(), + model.compile(optimizer=tf_keras.optimizers.Adam(), loss=lambda _, log_prob: -log_prob) batch_size = 10 @@ -923,10 +924,10 @@ def __init__(self, hidden_degrees: Method for assigning degrees to the hidden units: 'equal', 'random'. If 'equal', hidden units in each layer are allocated equally (up to a remainder term) to each degree. Default: 'equal'. - activation: An activation function. See `tf.keras.layers.Dense`. Default: + activation: An activation function. See `tf_keras.layers.Dense`. Default: `None`. 
use_bias: Whether or not the dense layers constructed in this layer - should have a bias term. See `tf.keras.layers.Dense`. Default: `True`. + should have a bias term. See `tf_keras.layers.Dense`. Default: `True`. kernel_initializer: Initializer for the `Dense` kernel weight matrices. Default: 'glorot_uniform'. bias_initializer: Initializer for the `Dense` bias vectors. Default: @@ -944,7 +945,7 @@ def __init__(self, performance. When `False` invalid inputs may silently render incorrect outputs. **kwargs: Additional keyword arguments passed to this layer (but not to - the `tf.keras.layer.Dense` layers constructed by this layer). + the `tf_keras.layer.Dense` layers constructed by this layer). """ super().__init__(**kwargs) @@ -964,7 +965,7 @@ def __init__(self, self._bias_initializer = bias_initializer self._kernel_regularizer = kernel_regularizer self._bias_regularizer = bias_regularizer - self._kernel_constraint = tf.keras.constraints.get(kernel_constraint) + self._kernel_constraint = tf_keras.constraints.get(kernel_constraint) self._bias_constraint = bias_constraint self._validate_args = validate_args self._kwargs = kwargs @@ -1030,10 +1031,10 @@ def build(self, input_shape): hidden_degrees=self._hidden_degrees, ) - outputs = [tf.keras.Input((self._event_size,), dtype=self.dtype)] + outputs = [tf_keras.Input((self._event_size,), dtype=self.dtype)] inputs = outputs[0] if self._conditional: - conditional_input = tf.keras.Input((self._conditional_size,), + conditional_input = tf_keras.Input((self._conditional_size,), dtype=self.dtype) inputs = [inputs, conditional_input] @@ -1043,7 +1044,7 @@ def build(self, input_shape): # [..., self._hidden_units[-1]] -> [..., event_size * self._params]. layer_output_sizes = self._hidden_units + [self._event_size * self._params] for k in range(len(self._masks)): - autoregressive_output = tf.keras.layers.Dense( + autoregressive_output = tf_keras.layers.Dense( layer_output_sizes[k], activation=None, use_bias=self._use_bias, @@ -1059,7 +1060,7 @@ def build(self, input_shape): if (self._conditional and ((self._conditional_layers == 'all_layers') or ((self._conditional_layers == 'first_layer') and (k == 0)))): - conditional_output = tf.keras.layers.Dense( + conditional_output = tf_keras.layers.Dense( layer_output_sizes[k], activation=None, use_bias=False, @@ -1070,16 +1071,16 @@ def build(self, input_shape): kernel_constraint=self._kernel_constraint, bias_constraint=None, dtype=self.dtype)(conditional_input) - outputs.append(tf.keras.layers.Add()([ + outputs.append(tf_keras.layers.Add()([ autoregressive_output, conditional_output])) else: outputs.append(autoregressive_output) if k + 1 < len(self._masks): outputs.append( - tf.keras.layers.Activation(self._activation) + tf_keras.layers.Activation(self._activation) (outputs[-1])) - self._network = tf.keras.models.Model( + self._network = tf_keras.models.Model( inputs=inputs, outputs=outputs[-1]) # Allow network to be called with inputs of shapes that don't match @@ -1352,11 +1353,11 @@ def _create_masks(degrees): def _make_masked_initializer(mask, initializer): """Returns a masked version of the given initializer.""" - initializer = tf.keras.initializers.get(initializer) + initializer = tf_keras.initializers.get(initializer) def masked_initializer(shape, dtype=None, partition_info=None): # If no `partition_info` is given, then don't pass it to `initializer`, as - # `initializer` may be a `tf.initializers.Initializer` (which don't accept a - # `partition_info` argument). 
+ # `initializer` may be a `tf_keras.initializers.Initializer` (which don't + # accept a `partition_info` argument). if partition_info is None: x = initializer(shape, dtype) else: @@ -1366,7 +1367,7 @@ def masked_initializer(shape, dtype=None, partition_info=None): def _make_masked_constraint(mask, constraint=None): - constraint = tf.keras.constraints.get(constraint) + constraint = tf_keras.constraints.get(constraint) def masked_constraint(x): x = tf.convert_to_tensor(x, dtype_hint=tf.float32, name='x') if constraint is not None: diff --git a/tensorflow_probability/python/bijectors/masked_autoregressive_test.py b/tensorflow_probability/python/bijectors/masked_autoregressive_test.py index 4c4dad6152..11e126fff6 100644 --- a/tensorflow_probability/python/bijectors/masked_autoregressive_test.py +++ b/tensorflow_probability/python/bijectors/masked_autoregressive_test.py @@ -39,10 +39,11 @@ from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import tensorshape_util from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.math import gradient -tfk = tf.keras -tfkl = tf.keras.layers +tfk = tf_keras +tfkl = tf_keras.layers def _funnel_bijector_fn(x): @@ -711,7 +712,7 @@ def test_layer_no_hidden_units(self): self.assertIsAutoregressive(made, event_size=3, order="left-to-right") def test_layer_v2_kernel_initializer(self): - init = tf.keras.initializers.GlorotNormal() + init = tf_keras.initializers.GlorotNormal() made = masked_autoregressive.AutoregressiveNetwork( params=2, event_shape=4, @@ -798,9 +799,9 @@ def test_doc_string_2(self): model = tfk.Model([x_, c_], log_prob_) if tf.__internal__.tf2.enabled() and tf.executing_eagerly(): - optimizer = tf.keras.optimizers.Adam(learning_rate=0.1) + optimizer = tf_keras.optimizers.Adam(learning_rate=0.1) else: - optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=0.1) + optimizer = tf_keras.optimizers.legacy.Adam(learning_rate=0.1) model.compile( optimizer=optimizer, loss=lambda _, log_prob: -log_prob) diff --git a/tensorflow_probability/python/bijectors/permute_test.py b/tensorflow_probability/python/bijectors/permute_test.py index eef1994567..cce4e5b439 100644 --- a/tensorflow_probability/python/bijectors/permute_test.py +++ b/tensorflow_probability/python/bijectors/permute_test.py @@ -22,6 +22,7 @@ from tensorflow_probability.python.bijectors import bijector_test_util from tensorflow_probability.python.bijectors import permute from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras @test_util.test_all_tf_execution_regimes @@ -88,7 +89,7 @@ def testPreservesShape(self): # TODO(b/131157549, b/131124359): Test should not be needed. Consider # deleting when underlying issue with constant eager tensors is fixed. 
permutation = [2, 1, 0] - x = tf.keras.Input((3,), batch_size=None) + x = tf_keras.Input((3,), batch_size=None) bijector = permute.Permute( permutation=permutation, axis=-1, validate_args=True) diff --git a/tensorflow_probability/python/bijectors/rational_quadratic_spline.py b/tensorflow_probability/python/bijectors/rational_quadratic_spline.py index 2b3f12e785..8c2e2e13ca 100644 --- a/tensorflow_probability/python/bijectors/rational_quadratic_spline.py +++ b/tensorflow_probability/python/bijectors/rational_quadratic_spline.py @@ -100,11 +100,11 @@ def _slopes(x): x = tf.reshape(x, out_shape) return tf.math.softplus(x) + self._min_slope - self._bin_widths = tf.keras.layers.Dense( + self._bin_widths = tf_keras.layers.Dense( nunits * self._nbins, activation=_bin_positions, name='w') - self._bin_heights = tf.keras.layers.Dense( + self._bin_heights = tf_keras.layers.Dense( nunits * self._nbins, activation=_bin_positions, name='h') - self._knot_slopes = tf.keras.layers.Dense( + self._knot_slopes = tf_keras.layers.Dense( nunits * (self._nbins - 1), activation=_slopes, name='s') self._built = True diff --git a/tensorflow_probability/python/bijectors/rational_quadratic_spline_test.py b/tensorflow_probability/python/bijectors/rational_quadratic_spline_test.py index 9e6210bdb7..dbad3a4ed6 100644 --- a/tensorflow_probability/python/bijectors/rational_quadratic_spline_test.py +++ b/tensorflow_probability/python/bijectors/rational_quadratic_spline_test.py @@ -31,6 +31,8 @@ from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras + JAX_MODE = False @@ -96,11 +98,11 @@ def _slopes(x): x = tf.reshape(x, out_shape) return tf.math.softplus(x) + 1e-2 - self._bin_widths = tf.keras.layers.Dense( + self._bin_widths = tf_keras.layers.Dense( nunits * self._nbins, activation=_bin_positions, name='w') - self._bin_heights = tf.keras.layers.Dense( + self._bin_heights = tf_keras.layers.Dense( nunits * self._nbins, activation=_bin_positions, name='h') - self._knot_slopes = tf.keras.layers.Dense( + self._knot_slopes = tf_keras.layers.Dense( nunits * (self._nbins - 1), activation=_slopes, name='s') self._built = True diff --git a/tensorflow_probability/python/bijectors/real_nvp.py b/tensorflow_probability/python/bijectors/real_nvp.py index d9b7f5deb1..c51e857a9f 100644 --- a/tensorflow_probability/python/bijectors/real_nvp.py +++ b/tensorflow_probability/python/bijectors/real_nvp.py @@ -23,6 +23,7 @@ from tensorflow_probability.python.bijectors import scale as scale_lib from tensorflow_probability.python.bijectors import shift as shift_lib from tensorflow_probability.python.internal import tensorshape_util +from tensorflow_probability.python.internal import tf_keras __all__ = [ @@ -389,13 +390,13 @@ def _fn(x, output_units, **condition_kwargs): else: reshape_output = lambda x: x for units in hidden_layers: - x = tf1.layers.dense( + x = tf_keras.tf1_layers.dense( inputs=x, units=units, activation=activation, *args, # pylint: disable=keyword-arg-before-vararg **kwargs) - x = tf1.layers.dense( + x = tf_keras.tf1_layers.dense( inputs=x, units=(1 if shift_only else 2) * output_units, activation=None, diff --git a/tensorflow_probability/python/bijectors/real_nvp_test.py b/tensorflow_probability/python/bijectors/real_nvp_test.py index 1af9299353..43dce97222 100644 --- a/tensorflow_probability/python/bijectors/real_nvp_test.py +++ b/tensorflow_probability/python/bijectors/real_nvp_test.py @@ 
-30,6 +30,7 @@ from tensorflow_probability.python.distributions import transformed_distribution from tensorflow_probability.python.internal import tensorshape_util from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras @test_util.test_all_tf_execution_regimes @@ -226,7 +227,7 @@ def _bijector_fn(x, output_units): else: reshape_output = lambda x: x - out = tf1.layers.dense(inputs=x, units=2 * output_units) + out = tf_keras.tf1_layers.dense(inputs=x, units=2 * output_units) shift, logit_gate = tf.split(out, 2, axis=-1) shift = reshape_output(shift) logit_gate = reshape_output(logit_gate) diff --git a/tensorflow_probability/python/build_defs.bzl b/tensorflow_probability/python/build_defs.bzl index 51bff587ca..47de202039 100644 --- a/tensorflow_probability/python/build_defs.bzl +++ b/tensorflow_probability/python/build_defs.bzl @@ -212,6 +212,7 @@ def multi_substrate_py_library( remove_deps = [ "//third_party/py/tensorflow", "//third_party/py/tensorflow:tensorflow", + "//tensorflow_probability/python/internal:tf_keras", ] trimmed_deps = [dep for dep in deps if (dep not in substrates_omit_deps and @@ -337,6 +338,7 @@ def multi_substrate_py_test( remove_deps = [ "//third_party/py/tensorflow", "//third_party/py/tensorflow:tensorflow", + "//tensorflow_probability/python/internal:tf_keras", ] trimmed_deps = [dep for dep in deps if dep not in remove_deps] diff --git a/tensorflow_probability/python/distributions/BUILD b/tensorflow_probability/python/distributions/BUILD index bed669970a..228f0a6532 100644 --- a/tensorflow_probability/python/distributions/BUILD +++ b/tensorflow_probability/python/distributions/BUILD @@ -189,6 +189,8 @@ multi_substrate_py_library( ":distribution", # numpy dep, # tensorflow dep, + "//tensorflow_probability/python/internal:assert_util", + "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:prefer_static", "//tensorflow_probability/python/internal:tensorshape_util", @@ -204,6 +206,7 @@ multi_substrate_py_library( # tensorflow dep, "//tensorflow_probability/python/bijectors:bijector", "//tensorflow_probability/python/internal:assert_util", + "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:tensorshape_util", ], @@ -341,6 +344,7 @@ multi_substrate_py_library( ":kullback_leibler", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", + "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:prefer_static", "//tensorflow_probability/python/internal:reparameterization", @@ -758,6 +762,7 @@ multi_substrate_py_library( "//tensorflow_probability/python/internal:reparameterization", "//tensorflow_probability/python/internal:tensor_util", "//tensorflow_probability/python/internal:tensorshape_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/math:linalg", "//tensorflow_probability/python/math/psd_kernels/internal:util", ], @@ -776,6 +781,7 @@ multi_substrate_py_library( "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:nest_util", "//tensorflow_probability/python/internal:tensorshape_util", + "//tensorflow_probability/python/internal:tf_keras", 
"//tensorflow_probability/python/math/psd_kernels:schur_complement", "//tensorflow_probability/python/util", ], @@ -982,6 +988,7 @@ multi_substrate_py_library( ":log_prob_ratio", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", + "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:prefer_static", "//tensorflow_probability/python/internal:tensor_util", "//tensorflow_probability/python/internal:tensorshape_util", @@ -1151,6 +1158,7 @@ multi_substrate_py_library( "//tensorflow_probability/python/distributions:joint_distribution_coroutine", "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:distribution_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/util:seed_stream", ], ) @@ -1433,6 +1441,7 @@ multi_substrate_py_library( # tensorflow dep, "//tensorflow_probability/python/bijectors:identity", "//tensorflow_probability/python/internal:assert_util", + "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:distribution_util", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:parameter_properties", @@ -1450,6 +1459,7 @@ multi_substrate_py_library( ":distribution", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", + "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:distribution_util", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:parameter_properties", @@ -1470,6 +1480,7 @@ multi_substrate_py_library( ":independent", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", + "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:distribution_util", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:parameter_properties", @@ -1849,6 +1860,7 @@ multi_substrate_py_library( "//tensorflow_probability/python/internal:prefer_static", "//tensorflow_probability/python/internal:reparameterization", "//tensorflow_probability/python/internal:tensorshape_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/layers:weight_norm", ], ) @@ -1970,6 +1982,7 @@ multi_substrate_py_library( # numpy dep, # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", + "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:distribution_util", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:parameter_properties", @@ -2030,6 +2043,7 @@ multi_substrate_py_library( # numpy dep, # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", + "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:prefer_static", @@ -2173,7 +2187,6 @@ multi_substrate_py_library( ":cholesky_util", ":distribution", ":multivariate_student_t", - ":student_t", # tensorflow dep, "//tensorflow_probability/python/bijectors:identity", "//tensorflow_probability/python/bijectors:softplus", @@ -2186,6 +2199,7 @@ multi_substrate_py_library( 
"//tensorflow_probability/python/internal:reparameterization", "//tensorflow_probability/python/internal:tensor_util", "//tensorflow_probability/python/internal:tensorshape_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/math:linalg", "//tensorflow_probability/python/math:special", ], @@ -2397,6 +2411,7 @@ multi_substrate_py_library( "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:tensor_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/math:linalg", "//tensorflow_probability/python/math/psd_kernels:positive_semidefinite_kernel", "//tensorflow_probability/python/math/psd_kernels/internal:util", @@ -3141,7 +3156,7 @@ multi_substrate_py_test( name = "gaussian_process_regression_model_test", srcs = ["gaussian_process_regression_model_test.py"], jax_size = "medium", - shard_count = 2, + shard_count = 4, deps = [ ":gaussian_process", ":gaussian_process_regression_model", @@ -3614,6 +3629,7 @@ multi_substrate_py_test( "//tensorflow_probability/python/internal:reparameterization", "//tensorflow_probability/python/internal:tensor_util", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/layers:distribution_layer", ], ) @@ -4206,6 +4222,7 @@ multi_substrate_py_test( # numpy dep, # tensorflow dep, "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/math:gradient", ], ) @@ -4493,7 +4510,6 @@ multi_substrate_py_test( shard_count = 2, deps = [ ":multivariate_student_t", - ":student_t", ":student_t_process", # absl/testing:parameterized dep, # numpy dep, @@ -4532,6 +4548,7 @@ multi_substrate_py_test( tags = ["colab-smoke"], deps = [ ":beta", + ":dirichlet", ":exponential", ":independent", ":joint_distribution_auto_batched", @@ -4544,6 +4561,7 @@ multi_substrate_py_test( ":normal", ":sample", ":transformed_distribution", + ":uniform", # numpy dep, # scipy dep, # tensorflow dep, @@ -4858,6 +4876,9 @@ py_library( # hypothesis dep, # jax dep, # numpy dep, + "//tensorflow_probability/python/bijectors:bijector_test_util.jax", + "//tensorflow_probability/python/distributions:normal.jax", + "//tensorflow_probability/python/distributions:transformed_distribution.jax", "//tensorflow_probability/python/internal:hypothesis_testlib.jax", "//tensorflow_probability/python/internal:reparameterization", "//tensorflow_probability/python/internal:tensor_util.jax", @@ -4981,6 +5002,7 @@ multi_substrate_py_library( # numpy dep, # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", + "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:prefer_static", diff --git a/tensorflow_probability/python/distributions/batch_broadcast.py b/tensorflow_probability/python/distributions/batch_broadcast.py index dbee015334..6c1bbe835b 100644 --- a/tensorflow_probability/python/distributions/batch_broadcast.py +++ b/tensorflow_probability/python/distributions/batch_broadcast.py @@ -20,6 +20,7 @@ from tensorflow_probability.python.bijectors import bijector as bijector_lib from tensorflow_probability.python.distributions import distribution as distribution_lib from 
tensorflow_probability.python.internal import assert_util +from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import tensor_util @@ -385,7 +386,7 @@ def __new__(cls, *args, **kwargs): else: distribution = kwargs.get('distribution') - if not isinstance(distribution, tf.__internal__.CompositeTensor): + if not auto_composite_tensor.is_composite_tensor(distribution): return _BatchBroadcast(*args, **kwargs) return super(BatchBroadcast, cls).__new__(cls) @@ -473,7 +474,7 @@ def __new__(cls, *args, **kwargs): else: bijector = kwargs.get('bijector') - if not (isinstance(bcast_dist, tf.__internal__.CompositeTensor) - and isinstance(bijector, tf.__internal__.CompositeTensor)): + if not (auto_composite_tensor.is_composite_tensor(bcast_dist) + and auto_composite_tensor.is_composite_tensor(bijector)): return _NonCompositeTensorBroadcastingBijector(*args, **kwargs) return super(_BroadcastingBijector, cls).__new__(cls) diff --git a/tensorflow_probability/python/distributions/batch_concat.py b/tensorflow_probability/python/distributions/batch_concat.py index 2e4f169a76..0487410cb1 100644 --- a/tensorflow_probability/python/distributions/batch_concat.py +++ b/tensorflow_probability/python/distributions/batch_concat.py @@ -21,6 +21,7 @@ from tensorflow_probability.python.distributions import distribution as distribution_lib from tensorflow_probability.python.internal import assert_util +from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import prefer_static as ps @@ -499,7 +500,7 @@ def __new__(cls, *args, **kwargs): else: distributions = kwargs.get('distributions') - if not all(isinstance(d, tf.__internal__.CompositeTensor) + if not all(auto_composite_tensor.is_composite_tensor(d) for d in distributions): return _BatchConcat(*args, **kwargs) return super(BatchConcat, cls).__new__(cls) diff --git a/tensorflow_probability/python/distributions/batch_reshape.py b/tensorflow_probability/python/distributions/batch_reshape.py index 421d4c4c2d..c39880ca7d 100644 --- a/tensorflow_probability/python/distributions/batch_reshape.py +++ b/tensorflow_probability/python/distributions/batch_reshape.py @@ -21,6 +21,7 @@ from tensorflow_probability.python.bijectors import bijector as bijector_lib from tensorflow_probability.python.distributions import distribution as distribution_lib from tensorflow_probability.python.internal import assert_util +from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import prefer_static as ps @@ -486,7 +487,7 @@ def __new__(cls, *args, **kwargs): else: distribution = kwargs.get('distribution') - if not isinstance(distribution, tf.__internal__.CompositeTensor): + if not auto_composite_tensor.is_composite_tensor(distribution): return _BatchReshape(*args, **kwargs) return super(BatchReshape, cls).__new__(cls) @@ -625,6 +626,6 @@ def __new__(cls, *args, **kwargs): else: base_bijector = kwargs.get('base_bijector') - if not isinstance(base_bijector, tf.__internal__.CompositeTensor): + if not 
auto_composite_tensor.is_composite_tensor(base_bijector): return _NonCompositeTensorBatchReshapeBijector(*args, **kwargs) return super(_BatchReshapeBijector, cls).__new__(cls) diff --git a/tensorflow_probability/python/distributions/blockwise.py b/tensorflow_probability/python/distributions/blockwise.py index e30306d186..0ab15b4eb6 100644 --- a/tensorflow_probability/python/distributions/blockwise.py +++ b/tensorflow_probability/python/distributions/blockwise.py @@ -22,6 +22,7 @@ from tensorflow_probability.python.distributions import joint_distribution_sequential from tensorflow_probability.python.distributions import kullback_leibler from tensorflow_probability.python.internal import assert_util +from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import prefer_static as ps @@ -95,7 +96,7 @@ def __new__(cls, *args, **kwargs): else: distribution = kwargs.get('distribution') - if not isinstance(distribution, tf.__internal__.CompositeTensor): + if not auto_composite_tensor.is_composite_tensor(distribution): return _NonCompositeTensorCast(*args, **kwargs) return super(_Cast, cls).__new__(cls) @@ -430,7 +431,7 @@ def __new__(cls, *args, **kwargs): else: distributions = kwargs.get('distributions') - if not all(isinstance(d, tf.__internal__.CompositeTensor) + if not all(auto_composite_tensor.is_composite_tensor(d) for d in tf.nest.flatten(distributions)): return _Blockwise(*args, **kwargs) return super(Blockwise, cls).__new__(cls) diff --git a/tensorflow_probability/python/distributions/gaussian_process.py b/tensorflow_probability/python/distributions/gaussian_process.py index 30479f4894..6131b688e4 100644 --- a/tensorflow_probability/python/distributions/gaussian_process.py +++ b/tensorflow_probability/python/distributions/gaussian_process.py @@ -15,7 +15,6 @@ """The GaussianProcess distribution class.""" import functools -import warnings # Dependency imports import numpy as np @@ -50,18 +49,6 @@ JAX_MODE = False -_ALWAYS_YIELD_MVN_DEPRECATION_WARNING = ( - '`always_yield_multivariate_normal` is deprecated. This arg is now ignored' - 'and will be removed after 2023-07-01. A `GaussianProcess` evaluated at a' - 'single index point now always has event shape `[1]` (the previous behavior' - 'for `always_yield_multivariate_normal=True`). To reproduce the previous ' - 'behavior of `always_yield_multivariate_normal=False`, squeeze the ' - 'rightmost singleton dimension from the output of `mean`, `sample`, etc.') - - -_GET_MARGINAL_DISTRIBUTION_ALREADY_WARNED = False - - def make_cholesky_factored_marginal_fn(cholesky_fn): """Construct a `marginal_fn` for use with `tfd.GaussianProcess`. 
@@ -234,7 +221,7 @@ class GaussianProcess( gp = tfd.GaussianProcess(kernel, observed_index_points) - optimizer = tf.optimizers.Adam() + optimizer = tf_keras.optimizers.Adam() @tf.function def optimize(): @@ -258,10 +245,6 @@ def optimize(): '2021-05-10', '`jitter` is deprecated; please use `marginal_fn` directly.', 'jitter') - @deprecation.deprecated_args( - '2023-07-01', - _ALWAYS_YIELD_MVN_DEPRECATION_WARNING, - 'always_yield_multivariate_normal') def __init__(self, kernel, index_points=None, @@ -270,7 +253,6 @@ def __init__(self, marginal_fn=None, cholesky_fn=None, jitter=1e-6, - always_yield_multivariate_normal=None, validate_args=False, allow_nan_stats=False, parameters=None, @@ -317,7 +299,6 @@ def __init__(self, `marginal_fn` and `cholesky_fn` is None. This argument is ignored if `cholesky_fn` is set. Default value: `1e-6`. - always_yield_multivariate_normal: Deprecated and ignored. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -338,28 +319,40 @@ def __init__(self, """ parameters = dict(locals()) if parameters is None else parameters with tf.name_scope(name) as name: - if tf.nest.is_nested(kernel.feature_ndims): - input_dtype = dtype_util.common_dtype( + input_dtype = dtype_util.common_dtype( + dict( + kernel=kernel, + index_points=index_points, + ), + dtype_hint=nest_util.broadcast_structure( + kernel.feature_ndims, tf.float32 + ), + ) + + # If the input dtype is non-nested float, we infer a single dtype for the + # input and the float parameters, which is also the dtype of the GP's + # samples, log_prob, etc. If the input dtype is nested (or not float), we + # do not use it to infer the GP's float dtype. + if (not tf.nest.is_nested(input_dtype) and + dtype_util.is_floating(input_dtype)): + dtype = dtype_util.common_dtype( dict( kernel=kernel, index_points=index_points, + observation_noise_variance=observation_noise_variance, + jitter=jitter, ), - dtype_hint=nest_util.broadcast_structure( - kernel.feature_ndims, tf.float32 - ), + dtype_hint=tf.float32, ) - dtype = dtype_util.common_dtype( - [observation_noise_variance, jitter], tf.float32) + input_dtype = dtype else: - # If the index points are not nested, we assume they are of the same - # float dtype as the GP. dtype = dtype_util.common_dtype( - { - 'index_points': index_points, - 'observation_noise_variance': observation_noise_variance, - 'jitter': jitter - }, tf.float32) - input_dtype = dtype + dict( + observation_noise_variance=observation_noise_variance, + jitter=jitter, + ), + dtype_hint=tf.float32, + ) if index_points is not None: index_points = nest_util.convert_to_nested_tensor( @@ -395,7 +388,6 @@ def __init__(self, else: self._marginal_fn = marginal_fn - self._always_yield_multivariate_normal = always_yield_multivariate_normal with tf.name_scope('init'): super(GaussianProcess, self).__init__( dtype=dtype, @@ -424,24 +416,6 @@ def get_marginal_distribution(self, index_points=None): marginal: a Normal distribution with vector event shape. 
""" with self._name_and_control_scope('get_marginal_distribution'): - global _GET_MARGINAL_DISTRIBUTION_ALREADY_WARNED - if (not _GET_MARGINAL_DISTRIBUTION_ALREADY_WARNED and # pylint: disable=protected-access - self._always_yield_multivariate_normal is not None): # pylint: disable=protected-access - warnings.warn( - 'The `always_yield_multivariate_normal` arg to ' - '`GaussianProcess.__init__` is now ignored and ' - '`get_marginal_distribution` always returns a Normal distribution' - 'with vector event shape. This was the previous behavior of' - '`always_yield_multivariate_normal=True`. To recover the behavior' - 'of `always_yield_multivariate_normal=False` when `index_points`' - 'contains a single index point, build a scalar `Normal`' - 'distribution as follows: ' - '`mvn = get_marginal_distribution(index_points); `' - '`norm = tfd.Normal(mvn.loc[..., 0], scale=mvn.stddev()[..., 0])`' - '. To suppress these warnings, build the `GaussianProcess` with ' - '`always_yield_multivariate_normal=True`.', - FutureWarning) - _GET_MARGINAL_DISTRIBUTION_ALREADY_WARNED = True # pylint: disable=protected-access return self._get_marginal_distribution(index_points=index_points) def _get_marginal_distribution(self, index_points=None, is_missing=None): @@ -770,8 +744,6 @@ def posterior_predictive( 'cholesky_fn': self.cholesky_fn, 'mean_fn': self.mean_fn, 'jitter': self.jitter, - 'always_yield_multivariate_normal': - self._always_yield_multivariate_normal, 'validate_args': self.validate_args, 'allow_nan_stats': self.allow_nan_stats } diff --git a/tensorflow_probability/python/distributions/gaussian_process_regression_model.py b/tensorflow_probability/python/distributions/gaussian_process_regression_model.py index d5dd4814df..1df5cb596e 100644 --- a/tensorflow_probability/python/distributions/gaussian_process_regression_model.py +++ b/tensorflow_probability/python/distributions/gaussian_process_regression_model.py @@ -28,7 +28,6 @@ from tensorflow_probability.python.internal import slicing from tensorflow_probability.python.internal import tensor_util from tensorflow_probability.python.math.psd_kernels import schur_complement -from tensorflow.python.util import deprecation # pylint: disable=g-direct-tensorflow-import __all__ = [ @@ -36,16 +35,6 @@ ] -_ALWAYS_YIELD_MVN_DEPRECATION_WARNING = ( - '`always_yield_multivariate_normal` is deprecated. This arg is now ignored' - 'and will be removed after 2023-07-01. A `GaussianProcessRegressionModel`' - 'evaluated at a single index point now always has event shape `[1]` (the' - 'previous behavior for `always_yield_multivariate_normal=True`). 
To' - 'reproduce the previous behavior of' - '`always_yield_multivariate_normal=False`, squeeze the rightmost singleton' - 'dimension from the output of `mean`, `sample`, etc.') - - class GaussianProcessRegressionModel( gaussian_process.GaussianProcess, distribution.AutoCompositeTensorDistribution): @@ -201,7 +190,7 @@ class GaussianProcessRegressionModel( index_points=observation_index_points, observation_noise_variance=observation_noise_variance) - optimizer = tf.optimizers.Adam(learning_rate=.05, beta_1=.5, beta_2=.99) + optimizer = tf_keras.optimizers.Adam(learning_rate=.05, beta_1=.5, beta_2=.99) @tf.function def optimize(): @@ -326,10 +315,6 @@ def run_mcmc(): """ # pylint:disable=invalid-name - @deprecation.deprecated_args( - '2023-07-01', - _ALWAYS_YIELD_MVN_DEPRECATION_WARNING, - 'always_yield_multivariate_normal') def __init__(self, kernel, index_points=None, @@ -340,7 +325,6 @@ def __init__(self, mean_fn=None, cholesky_fn=None, jitter=1e-6, - always_yield_multivariate_normal=None, validate_args=False, allow_nan_stats=False, name='GaussianProcessRegressionModel', @@ -409,7 +393,6 @@ def __init__(self, matrix to ensure positive definiteness of the covariance matrix. This argument is ignored if `cholesky_fn` is set. Default value: `1e-6`. - always_yield_multivariate_normal: Deprecated and ignored. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -432,22 +415,42 @@ def __init__(self, """ parameters = dict(locals()) with tf.name_scope(name) as name: - if tf.nest.is_nested(kernel.feature_ndims): - input_dtype = dtype_util.common_dtype( - [kernel, index_points, observation_index_points], - dtype_hint=nest_util.broadcast_structure( - kernel.feature_ndims, tf.float32)) + input_dtype = dtype_util.common_dtype( + dict( + kernel=kernel, + index_points=index_points, + observation_index_points=observation_index_points, + ), + dtype_hint=nest_util.broadcast_structure( + kernel.feature_ndims, tf.float32)) + + # If the input dtype is non-nested float, we infer a single dtype for the + # input and the float parameters, which is also the dtype of the GP's + # samples, log_prob, etc. If the input dtype is nested (or not float), we + # do not use it to infer the GP's float dtype. + if (not tf.nest.is_nested(input_dtype) and + dtype_util.is_floating(input_dtype)): dtype = dtype_util.common_dtype( - [observations, observation_noise_variance, - predictive_noise_variance, jitter], tf.float32) - else: - # If the index points are not nested, we assume they are of the same - # dtype as the GPRM. 
- dtype = dtype_util.common_dtype([ - index_points, observation_index_points, observations, - observation_noise_variance, predictive_noise_variance, jitter - ], tf.float32) + dict( + kernel=kernel, + index_points=index_points, + observations=observations, + observation_index_points=observation_index_points, + observation_noise_variance=observation_noise_variance, + predictive_noise_variance=predictive_noise_variance, + jitter=jitter, + ), + dtype_hint=tf.float32, + ) input_dtype = dtype + else: + dtype = dtype_util.common_dtype( + dict( + observations=observations, + observation_noise_variance=observation_noise_variance, + predictive_noise_variance=predictive_noise_variance, + jitter=jitter, + ), dtype_hint=tf.float32) if index_points is not None: index_points = nest_util.convert_to_nested_tensor( @@ -541,7 +544,6 @@ def conditional_mean_fn(x): index_points=index_points, cholesky_fn=cholesky_fn, jitter=jitter, - always_yield_multivariate_normal=always_yield_multivariate_normal, # What the GP super class calls "observation noise variance" we call # here the "predictive noise variance". We use the observation noise # variance for the fit/solve process above, and predictive for @@ -552,10 +554,6 @@ def conditional_mean_fn(x): self._parameters = parameters @staticmethod - @deprecation.deprecated_args( - '2023-07-01', - _ALWAYS_YIELD_MVN_DEPRECATION_WARNING, - 'always_yield_multivariate_normal') def precompute_regression_model( kernel, observation_index_points, @@ -567,7 +565,6 @@ def precompute_regression_model( mean_fn=None, cholesky_fn=None, jitter=1e-6, - always_yield_multivariate_normal=None, validate_args=False, allow_nan_stats=False, name='PrecomputedGaussianProcessRegressionModel', @@ -661,7 +658,6 @@ def precompute_regression_model( jitter: `float` scalar `Tensor` added to the diagonal of the covariance matrix to ensure positive definiteness of the covariance matrix. Default value: `1e-6`. - always_yield_multivariate_normal: Deprecated and ignored. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. 
When `False` invalid inputs may silently render incorrect @@ -773,7 +769,6 @@ def conditional_mean_fn(x): predictive_noise_variance=predictive_noise_variance, cholesky_fn=cholesky_fn, jitter=jitter, - always_yield_multivariate_normal=always_yield_multivariate_normal, _conditional_kernel=conditional_kernel, _conditional_mean_fn=conditional_mean_fn, validate_args=validate_args, diff --git a/tensorflow_probability/python/distributions/generalized_pareto_test.py b/tensorflow_probability/python/distributions/generalized_pareto_test.py index fa779f3fa1..0c30b71124 100644 --- a/tensorflow_probability/python/distributions/generalized_pareto_test.py +++ b/tensorflow_probability/python/distributions/generalized_pareto_test.py @@ -141,6 +141,7 @@ def testCDF(self, dist): loc, scale, conc = self.evaluate([dist.loc, dist.scale, dist.concentration]) hp.assume(abs(loc / scale) < 1e7) + hp.assume((abs(conc) > 1e-12) or (conc == 0.)) expected_cdf = sp_stats.genpareto(conc, loc=loc, scale=scale).cdf(xs) actual_cdf = self.evaluate(cdf) msg = ('Location: {}, scale: {}, concentration: {}, xs: {} ' diff --git a/tensorflow_probability/python/distributions/independent.py b/tensorflow_probability/python/distributions/independent.py index 955a0592f2..892a170a23 100644 --- a/tensorflow_probability/python/distributions/independent.py +++ b/tensorflow_probability/python/distributions/independent.py @@ -23,6 +23,7 @@ from tensorflow_probability.python.distributions import kullback_leibler from tensorflow_probability.python.distributions import log_prob_ratio from tensorflow_probability.python.internal import assert_util +from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import tensor_util @@ -365,7 +366,7 @@ def __new__(cls, *args, **kwargs): else: distribution = kwargs.get('distribution') - if not isinstance(distribution, tf.__internal__.CompositeTensor): + if not auto_composite_tensor.is_composite_tensor(distribution): return _Independent(*args, **kwargs) return super(Independent, cls).__new__(cls) diff --git a/tensorflow_probability/python/distributions/inflated.py b/tensorflow_probability/python/distributions/inflated.py index ca87d81ffb..0e69bee3ab 100644 --- a/tensorflow_probability/python/distributions/inflated.py +++ b/tensorflow_probability/python/distributions/inflated.py @@ -242,7 +242,7 @@ def __new__(cls, *args, **kwargs): else: distribution = kwargs.get('distribution') - if not isinstance(distribution, tf.__internal__.CompositeTensor): + if not auto_composite_tensor.is_composite_tensor(distribution): return _Inflated(*args, **kwargs) return super(Inflated, cls).__new__(cls) diff --git a/tensorflow_probability/python/distributions/internal/statistical_testing.py b/tensorflow_probability/python/distributions/internal/statistical_testing.py index 75fe286711..2cf7189a3f 100644 --- a/tensorflow_probability/python/distributions/internal/statistical_testing.py +++ b/tensorflow_probability/python/distributions/internal/statistical_testing.py @@ -127,6 +127,7 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.internal import distribution_util from tensorflow_probability.python.internal import dtype_util +from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import tensorshape_util from 
tensorflow_probability.python.util.seed_stream import SeedStream @@ -1494,7 +1495,7 @@ def _random_unit_hypersphere(sample_shape, event_shape, dtype, seed): target_shape = tf.concat([sample_shape, event_shape], axis=0) return tf.math.l2_normalize( tf.random.normal(target_shape, seed=seed, dtype=dtype), - axis=-1 - tf.range(tf.size(event_shape))) + axis=-1 - ps.range(ps.size(event_shape))) def assert_multivariate_true_cdf_equal_on_projections_two_sample( diff --git a/tensorflow_probability/python/distributions/jax_transformation_test.py b/tensorflow_probability/python/distributions/jax_transformation_test.py index a2f0da6bd1..3604b4e464 100644 --- a/tensorflow_probability/python/distributions/jax_transformation_test.py +++ b/tensorflow_probability/python/distributions/jax_transformation_test.py @@ -28,7 +28,10 @@ from tensorflow_probability.python.internal import reparameterization from tensorflow_probability.python.internal.backend import jax as tf +from tensorflow_probability.substrates.jax.bijectors import bijector_test_util from tensorflow_probability.substrates.jax.distributions import hypothesis_testlib as dhps +from tensorflow_probability.substrates.jax.distributions import normal +from tensorflow_probability.substrates.jax.distributions import transformed_distribution from tensorflow_probability.substrates.jax.internal import hypothesis_testlib as tfp_hps from tensorflow_probability.substrates.jax.internal import test_util @@ -430,6 +433,30 @@ def dist_and_sample(dist): eligibility_filter=lambda dname: dname not in PYTREE_BLOCKLIST)) dist_and_sample(dist) + def test_user_defined_pytree(self): + k = np.asarray([3]) + pytree_shift = bijector_test_util.PytreeShift(k) + td = transformed_distribution.TransformedDistribution( + normal.Normal(0., 1), bijector=pytree_shift) + leaves, treedef = jax.tree_util.tree_flatten(td) + node_data = treedef.node_data() + + # `td` and `td.bijector` are both Pytrees, but only `td` was registered as a + # Pytree via AutoCompositeTensor. + self.assertFalse(jax.tree_util.treedef_is_leaf(treedef)) + self.assertFalse( + jax.tree_util.treedef_is_leaf(jax.tree_util.tree_structure(td.bijector)) + ) + self.assertIsInstance(td, tf.__internal__.CompositeTensor) + self.assertNotIsInstance(td.bijector, tf.__internal__.CompositeTensor) + + # `"bijector"` is in the tuple of arg names for the Pytree children and not + # the auxiliary data. + self.assertIn('bijector', node_data[1][0]) + # The shift parameter (and both Normal parameters) are leaves. + self.assertLen(leaves, 3) + + if __name__ == '__main__': os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8' test_util.main() diff --git a/tensorflow_probability/python/distributions/joint_distribution_auto_batched.py b/tensorflow_probability/python/distributions/joint_distribution_auto_batched.py index c114a7f46c..34081e44eb 100644 --- a/tensorflow_probability/python/distributions/joint_distribution_auto_batched.py +++ b/tensorflow_probability/python/distributions/joint_distribution_auto_batched.py @@ -597,8 +597,8 @@ def __new__(cls, *args, **kwargs): # Return a `_JointDistributionSequentialAutoBatched` instance if `model` # contains distributions that are not CompositeTensors. 
- if not all(isinstance(d, tf.__internal__.CompositeTensor) or callable(d) - for d in model): + if not all(auto_composite_tensor.is_composite_tensor(d) + or callable(d) for d in model): return _JointDistributionSequentialAutoBatched(*args, **kwargs) return super(JointDistributionSequentialAutoBatched, cls).__new__(cls) @@ -634,8 +634,8 @@ def __new__(cls, *args, **kwargs): # Return a `_JointDistributionNamedAutoBatched` instance if `model` # contains distributions that are not CompositeTensors. - if not all(isinstance(d, tf.__internal__.CompositeTensor) or callable(d) - for d in tf.nest.flatten(model)): + if not all(auto_composite_tensor.is_composite_tensor(d) + or callable(d) for d in tf.nest.flatten(model)): return _JointDistributionNamedAutoBatched(*args, **kwargs) return super(JointDistributionNamedAutoBatched, cls).__new__(cls) diff --git a/tensorflow_probability/python/distributions/joint_distribution_named.py b/tensorflow_probability/python/distributions/joint_distribution_named.py index f0bbaeaf18..1ddb49828c 100644 --- a/tensorflow_probability/python/distributions/joint_distribution_named.py +++ b/tensorflow_probability/python/distributions/joint_distribution_named.py @@ -470,8 +470,8 @@ def __new__(cls, *args, **kwargs): else: model = kwargs.get('model') - if not all(isinstance(d, tf.__internal__.CompositeTensor) or callable(d) - for d in tf.nest.flatten(model)): + if not all(auto_composite_tensor.is_composite_tensor(d) + or callable(d) for d in tf.nest.flatten(model)): return _JointDistributionNamed(*args, **kwargs) return super(JointDistributionNamed, cls).__new__(cls) @@ -509,7 +509,7 @@ def _to_components(self, obj): if self._callable_params: components = [] for d in tf.nest.flatten(obj.model): - if isinstance(d, tf.__internal__.CompositeTensor): + if auto_composite_tensor.is_composite_tensor(d): components.append(d) else: components = obj.model @@ -526,7 +526,7 @@ def _from_components(self, components): def from_instance(cls, obj): model_param_specs, callable_model_params = [], [] for d in tf.nest.flatten(obj.model): - if isinstance(d, tf.__internal__.CompositeTensor): + if auto_composite_tensor.is_composite_tensor(d): model_param_specs.append(d._type_spec) # pylint: disable=protected-access else: callable_model_params.append(d) @@ -556,8 +556,7 @@ def from_instance(cls, obj): # there are no callable elements of `model`, in which case the nested # structure of `model` is recorded in `param_specs`. structure_with_callables = tf.nest.map_structure( - lambda x: (None if isinstance(x, tf.__internal__.CompositeTensor) # pylint: disable=g-long-lambda - else x), + lambda x: None if auto_composite_tensor.is_composite_tensor(x) else x, obj.model) spec._structure_with_callables = structure_with_callables return spec diff --git a/tensorflow_probability/python/distributions/joint_distribution_sequential.py b/tensorflow_probability/python/distributions/joint_distribution_sequential.py index c936788b43..4653ba0941 100644 --- a/tensorflow_probability/python/distributions/joint_distribution_sequential.py +++ b/tensorflow_probability/python/distributions/joint_distribution_sequential.py @@ -54,10 +54,10 @@ class _JointDistributionSequential(joint_distribution_lib.JointDistribution): a single model specification. A joint distribution is a collection of possibly interdependent distributions. 
- Like `tf.keras.Sequential`, the `JointDistributionSequential` can be specified + Like `tf_keras.Sequential`, the `JointDistributionSequential` can be specified via a `list` of functions (each responsible for making a `tfp.distributions.Distribution`-like instance). Unlike - `tf.keras.Sequential`, each function can depend on the output of all previous + `tf_keras.Sequential`, each function can depend on the output of all previous elements rather than only the immediately previous. #### Mathematical Details @@ -734,8 +734,8 @@ def __new__(cls, *args, **kwargs): else: model = kwargs.get('model') - if not all(isinstance(d, tf.__internal__.CompositeTensor) or callable(d) - for d in model): + if not all(auto_composite_tensor.is_composite_tensor(d) + or callable(d) for d in model): return _JointDistributionSequential(*args, **kwargs) return super(JointDistributionSequential, cls).__new__(cls) diff --git a/tensorflow_probability/python/distributions/lambertw_f_test.py b/tensorflow_probability/python/distributions/lambertw_f_test.py index a5e3f6b4e3..95d8da2918 100644 --- a/tensorflow_probability/python/distributions/lambertw_f_test.py +++ b/tensorflow_probability/python/distributions/lambertw_f_test.py @@ -27,6 +27,7 @@ from tensorflow_probability.python.distributions import transformed_distribution from tensorflow_probability.python.distributions import uniform from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras @test_util.test_all_tf_execution_regimes @@ -190,16 +191,16 @@ def dist_lambda(t): from tensorflow_probability.python.layers import distribution_layer # pylint:disable=g-import-not-at-top dist_layer = distribution_layer.DistributionLambda(dist_lambda) - model = tf.keras.Sequential([ - tf.keras.layers.Dense(10, "relu"), - tf.keras.layers.Dense(5, "selu"), - tf.keras.layers.Dense(1 + 1 + 1), + model = tf_keras.Sequential([ + tf_keras.layers.Dense(10, "relu"), + tf_keras.layers.Dense(5, "selu"), + tf_keras.layers.Dense(1 + 1 + 1), dist_layer]) negloglik = lambda y, p_y: -p_y.log_prob(y) if tf.__internal__.tf2.enabled() and tf.executing_eagerly(): - optimizer = tf.keras.optimizers.Adam(learning_rate=0.01) + optimizer = tf_keras.optimizers.Adam(learning_rate=0.01) else: - optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=0.01) + optimizer = tf_keras.optimizers.legacy.Adam(learning_rate=0.01) model.compile(optimizer=optimizer, loss=negloglik) diff --git a/tensorflow_probability/python/distributions/linear_gaussian_ssm.py b/tensorflow_probability/python/distributions/linear_gaussian_ssm.py index f68099f899..0938fbd242 100644 --- a/tensorflow_probability/python/distributions/linear_gaussian_ssm.py +++ b/tensorflow_probability/python/distributions/linear_gaussian_ssm.py @@ -36,7 +36,7 @@ from tensorflow_probability.python.internal import tensorshape_util from tensorflow_probability.python.math import linalg -from tensorflow.python.ops import parallel_for # pylint: disable=g-direct-tensorflow-import +from tensorflow.python.ops.parallel_for import control_flow_ops # pylint: disable=g-direct-tensorflow-import tfl = tf.linalg @@ -694,8 +694,8 @@ def _build_model_spec_kwargs_for_parallel_fns(self, sample_shape=(), pass_covariance=False): """Builds a dict of model parameters across all timesteps.""" - kwargs = parallel_for.pfor(self._get_time_varying_kwargs, - self.num_timesteps) + kwargs = control_flow_ops.pfor(self._get_time_varying_kwargs, + self.num_timesteps) # If given a sample shape, encode it as additional batch 
dimension(s). # It is sufficient to do this for one parameter (we use initial_mean), @@ -1371,7 +1371,7 @@ def pfor_body(t): t=self.initial_step + t, latent_mean=tf.gather(latent_means, t), latent_cov=tf.gather(latent_covs, t)) - observation_means, observation_covs = parallel_for.pfor( + observation_means, observation_covs = control_flow_ops.pfor( pfor_body, self._num_timesteps) observation_means = distribution_util.move_dimension( @@ -1831,7 +1831,7 @@ def linear_gaussian_update( # P* = P - K * H * P # but this is prone to numerical issues because it subtracts a # value from a PSD matrix. We choose instead to use the more - # expensive Jordan form update + # expensive Joseph form update # P* = (I - K H) * P * (I - K H)' + K R K' # which always produces a PSD result. This uses # tmp_term = (I - K * H)' diff --git a/tensorflow_probability/python/distributions/masked.py b/tensorflow_probability/python/distributions/masked.py index be64e074af..27475f2ba2 100644 --- a/tensorflow_probability/python/distributions/masked.py +++ b/tensorflow_probability/python/distributions/masked.py @@ -22,6 +22,7 @@ from tensorflow_probability.python.distributions import kullback_leibler from tensorflow_probability.python.distributions import log_prob_ratio from tensorflow_probability.python.internal import assert_util +from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import samplers @@ -309,7 +310,7 @@ def __new__(cls, *args, **kwargs): else: distribution = kwargs.get('distribution') - if not isinstance(distribution, tf.__internal__.CompositeTensor): + if not auto_composite_tensor.is_composite_tensor(distribution): return _Masked(*args, **kwargs) return super(Masked, cls).__new__(cls) @@ -463,7 +464,7 @@ def __new__(cls, *args, **kwargs): else: bijector = kwargs.get('underlying_bijector') - if not (isinstance(masked, tf.__internal__.CompositeTensor) - and isinstance(bijector, tf.__internal__.CompositeTensor)): + if not (auto_composite_tensor.is_composite_tensor(masked) + and auto_composite_tensor.is_composite_tensor(bijector)): return _NonCompositeTensorMaskedBijector(*args, **kwargs) return super(_MaskedBijector, cls).__new__(cls) diff --git a/tensorflow_probability/python/distributions/mixture.py b/tensorflow_probability/python/distributions/mixture.py index 564f576ec9..9f1d4a99fb 100644 --- a/tensorflow_probability/python/distributions/mixture.py +++ b/tensorflow_probability/python/distributions/mixture.py @@ -22,6 +22,7 @@ from tensorflow_probability.python.distributions import categorical from tensorflow_probability.python.distributions import distribution from tensorflow_probability.python.internal import assert_util +from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import distribution_util from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import prefer_static as ps @@ -450,8 +451,8 @@ def __new__(cls, *args, **kwargs): components = kwargs.get('components') _validate_cat_and_components(cat, components) - if not (isinstance(cat, tf.__internal__.CompositeTensor) - and all(isinstance(d, tf.__internal__.CompositeTensor) + if not (auto_composite_tensor.is_composite_tensor(cat) + and all(auto_composite_tensor.is_composite_tensor(d) for d in components)): return 
_Mixture(*args, **kwargs) return super(Mixture, cls).__new__(cls) diff --git a/tensorflow_probability/python/distributions/mixture_same_family.py b/tensorflow_probability/python/distributions/mixture_same_family.py index d929769971..8f81bb1548 100644 --- a/tensorflow_probability/python/distributions/mixture_same_family.py +++ b/tensorflow_probability/python/distributions/mixture_same_family.py @@ -23,6 +23,7 @@ from tensorflow_probability.python.distributions import distribution from tensorflow_probability.python.distributions import independent from tensorflow_probability.python.internal import assert_util +from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import custom_gradient as tfp_custom_gradient from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties @@ -695,9 +696,9 @@ def __new__(cls, *args, **kwargs): else: components_distribution = kwargs.get('components_distribution') - if not (isinstance(mixture_distribution, tf.__internal__.CompositeTensor) - and isinstance( - components_distribution, tf.__internal__.CompositeTensor)): + if not (auto_composite_tensor.is_composite_tensor(mixture_distribution) + and auto_composite_tensor.is_composite_tensor( + components_distribution)): return _MixtureSameFamily(*args, **kwargs) return super(MixtureSameFamily, cls).__new__(cls) diff --git a/tensorflow_probability/python/distributions/mvn_tril_test.py b/tensorflow_probability/python/distributions/mvn_tril_test.py index b1f30ca427..3f990b8e44 100644 --- a/tensorflow_probability/python/distributions/mvn_tril_test.py +++ b/tensorflow_probability/python/distributions/mvn_tril_test.py @@ -390,7 +390,7 @@ def testSampleLarge(self): self.assertAllClose(true_mean, sample_mean_, atol=0., rtol=0.03) self.assertAllClose(true_mean, analytical_mean_, atol=0., rtol=1e-6) - self.assertAllClose(true_covariance, sample_covariance_, atol=0., rtol=0.03) + self.assertAllClose(true_covariance, sample_covariance_, atol=0., rtol=0.04) self.assertAllClose( true_covariance, analytical_covariance_, atol=0., rtol=1e-6) diff --git a/tensorflow_probability/python/distributions/pixel_cnn.py b/tensorflow_probability/python/distributions/pixel_cnn.py index 08582f88ce..d1b325a0f8 100644 --- a/tensorflow_probability/python/distributions/pixel_cnn.py +++ b/tensorflow_probability/python/distributions/pixel_cnn.py @@ -30,6 +30,7 @@ from tensorflow_probability.python.internal import prefer_static from tensorflow_probability.python.internal import reparameterization from tensorflow_probability.python.internal import tensorshape_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.layers import weight_norm @@ -103,8 +104,8 @@ class PixelCNN(distribution.Distribution): import tensorflow_probability as tfp tfd = tfp.distributions - tfk = tf.keras - tfkl = tf.keras.layers + tfk = tf_keras + tfkl = tf_keras.layers # Load MNIST from tensorflow_datasets data = tfds.load('mnist') @@ -381,7 +382,7 @@ class labels), or `None`. May have leading batch dimension(s), which must broadcast to the leading batch dimensions of `value`. training: `bool` or `None`. If `bool`, it controls the dropout layer, where `True` implies dropout is active. If `None`, it defaults to - `tf.keras.backend.learning_phase()`. + `tf_keras.backend.learning_phase()`. Returns: log_prob_values: `Tensor`. 
""" @@ -618,7 +619,7 @@ def _event_shape(self): return tf.TensorShape(self.image_shape) -class _PixelCNNNetwork(tf.keras.layers.Layer): +class _PixelCNNNetwork(tf_keras.layers.Layer): """Keras `Layer` to parameterize a Pixel CNN++ distribution. This is a Keras implementation of the Pixel CNN++ network, as described in @@ -701,33 +702,33 @@ def build(self, input_shape): dtype = self.dtype if len(input_shape) == 2: batch_image_shape, batch_conditional_shape = input_shape - conditional_input = tf.keras.layers.Input( + conditional_input = tf_keras.layers.Input( shape=batch_conditional_shape[1:], dtype=dtype) else: batch_image_shape = input_shape conditional_input = None image_shape = batch_image_shape[1:] - image_input = tf.keras.layers.Input(shape=image_shape, dtype=dtype) + image_input = tf_keras.layers.Input(shape=image_shape, dtype=dtype) if self._resnet_activation == 'concat_elu': - activation = tf.keras.layers.Lambda( + activation = tf_keras.layers.Lambda( lambda x: tf.nn.elu(tf.concat([x, -x], axis=-1)), dtype=dtype) else: - activation = tf.keras.activations.get(self._resnet_activation) + activation = tf_keras.activations.get(self._resnet_activation) # Define layers with default inputs and layer wrapper applied Conv2D = functools.partial( # pylint:disable=invalid-name - self._layer_wrapper(tf.keras.layers.Convolution2D), + self._layer_wrapper(tf_keras.layers.Convolution2D), filters=self._num_filters, padding='same', dtype=dtype) Dense = functools.partial( # pylint:disable=invalid-name - self._layer_wrapper(tf.keras.layers.Dense), dtype=dtype) + self._layer_wrapper(tf_keras.layers.Dense), dtype=dtype) Conv2DTranspose = functools.partial( # pylint:disable=invalid-name - self._layer_wrapper(tf.keras.layers.Conv2DTranspose), + self._layer_wrapper(tf_keras.layers.Conv2DTranspose), filters=self._num_filters, padding='same', strides=(2, 2), @@ -773,7 +774,7 @@ def build(self, input_shape): kernel_constraint=_make_kernel_constraint( (3, cols), (0, 2), (0, cols // 2)))(image_input) - horizontal_stack_init = tf.keras.layers.add( + horizontal_stack_init = tf_keras.layers.add( [horizontal_stack_up, horizontal_stack_left], dtype=dtype) layer_stacks = { @@ -803,10 +804,10 @@ def build(self, input_shape): if stack == 'horizontal': h = activation(layer_stacks['vertical'][-1]) h = Dense(self._num_filters)(h) - x = tf.keras.layers.add([h, x], dtype=dtype) + x = tf_keras.layers.add([h, x], dtype=dtype) x = activation(x) - x = tf.keras.layers.Dropout(self._dropout_p, dtype=dtype)(x) + x = tf_keras.layers.Dropout(self._dropout_p, dtype=dtype)(x) x = Conv2D(filters=2*self._num_filters, kernel_size=kernel_sizes[stack], kernel_constraint=kernel_constraints[stack])(x) @@ -814,12 +815,12 @@ def build(self, input_shape): if conditional_input is not None: h_projection = _build_and_apply_h_projection( conditional_input, self._num_filters, dtype=dtype) - x = tf.keras.layers.add([x, h_projection], dtype=dtype) + x = tf_keras.layers.add([x, h_projection], dtype=dtype) x = _apply_sigmoid_gating(x) # Add a residual connection from the layer's input. - out = tf.keras.layers.add([input_x, x], dtype=dtype) + out = tf_keras.layers.add([input_x, x], dtype=dtype) layer_stacks[stack].append(out) if i < self._num_hierarchies - 1: @@ -872,17 +873,17 @@ def build(self, input_shape): # Include the vertical-stack layer of the upward pass in the layers # to be added to the horizontal layer. 
if stack == 'horizontal': - x_symmetric = tf.keras.layers.Concatenate(axis=-1, dtype=dtype)( + x_symmetric = tf_keras.layers.Concatenate(axis=-1, dtype=dtype)( [upward_pass['vertical'], x_symmetric]) # Add a skip-connection from the symmetric layer in the downward # pass to the layer `x` in the upward pass. h = activation(x_symmetric) h = Dense(self._num_filters)(h) - x = tf.keras.layers.add([h, x], dtype=dtype) + x = tf_keras.layers.add([h, x], dtype=dtype) x = activation(x) - x = tf.keras.layers.Dropout(self._dropout_p, dtype=dtype)(x) + x = tf_keras.layers.Dropout(self._dropout_p, dtype=dtype)(x) x = Conv2D(filters=2*self._num_filters, kernel_size=kernel_sizes[stack], kernel_constraint=kernel_constraints[stack])(x) @@ -890,10 +891,10 @@ def build(self, input_shape): if conditional_input is not None: h_projection = _build_and_apply_h_projection( conditional_input, self._num_filters, dtype=dtype) - x = tf.keras.layers.add([x, h_projection], dtype=dtype) + x = tf_keras.layers.add([x, h_projection], dtype=dtype) x = _apply_sigmoid_gating(x) - upward_pass[stack] = tf.keras.layers.add([input_x, x], dtype=dtype) + upward_pass[stack] = tf_keras.layers.add([input_x, x], dtype=dtype) # Define deconvolutional layers that expand height/width dimensions on the # upward pass (e.g. expanding from 8x8 to 16x16 in Figure 2 of [1]), with @@ -918,7 +919,7 @@ def build(self, input_shape): kernel_constraint=kernel_constraint)(x) upward_pass[stack] = x - x_out = tf.keras.layers.ELU(dtype=dtype)(upward_pass['horizontal']) + x_out = tf_keras.layers.ELU(dtype=dtype)(upward_pass['horizontal']) # Build final Dense/Reshape layers to output the correct number of # parameters per pixel. @@ -948,7 +949,7 @@ def build(self, input_shape): inputs = (image_input if conditional_input is None else [image_input, conditional_input]) - self._network = tf.keras.Model(inputs=inputs, outputs=outputs) + self._network = tf_keras.Model(inputs=inputs, outputs=outputs) super(_PixelCNNNetwork, self).build(input_shape) def call(self, inputs, training=None): @@ -962,7 +963,7 @@ def call(self, inputs, training=None): same leading batch dimension as the image `Tensor`. training: `bool` or `None`. If `bool`, it controls the dropout layer, where `True` implies dropout is active. 
If `None`, it it defaults to - `tf.keras.backend.learning_phase()` + `tf_keras.backend.learning_phase()` Returns: outputs: a 3- or 4-element `list` of `Tensor`s in the following order: @@ -996,8 +997,8 @@ def _make_kernel_constraint(kernel_size, valid_rows, valid_columns): def _build_and_apply_h_projection(h, num_filters, dtype): """Project the conditional input.""" - h = tf.keras.layers.Flatten(dtype=dtype)(h) - h_projection = tf.keras.layers.Dense( + h = tf_keras.layers.Flatten(dtype=dtype)(h) + h_projection = tf_keras.layers.Dense( 2*num_filters, kernel_initializer='random_normal', dtype=dtype)(h) return h_projection[..., tf.newaxis, tf.newaxis, :] @@ -1006,6 +1007,6 @@ def _apply_sigmoid_gating(x): """Apply the sigmoid gating in Figure 2 of [2].""" activation_tensor, gate_tensor = tf.split(x, 2, axis=-1) sigmoid_gate = tf.sigmoid(gate_tensor) - return tf.keras.layers.multiply( + return tf_keras.layers.multiply( [sigmoid_gate, activation_tensor], dtype=x.dtype) diff --git a/tensorflow_probability/python/distributions/pixel_cnn_test.py b/tensorflow_probability/python/distributions/pixel_cnn_test.py index 630f862ac3..a035a61c18 100644 --- a/tensorflow_probability/python/distributions/pixel_cnn_test.py +++ b/tensorflow_probability/python/distributions/pixel_cnn_test.py @@ -21,6 +21,7 @@ from tensorflow_probability.python.distributions import pixel_cnn from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.math import gradient @@ -64,7 +65,7 @@ def _make_fake_inputs(self): return self._make_fake_images() def _make_input_layers(self): - return tf.keras.layers.Input(self.image_shape) + return tf_keras.layers.Input(self.image_shape) def _get_single_pixel_logit_gradients(self, dist, logit_ind, pixel_ind): @@ -170,12 +171,12 @@ def testAutoregression(self): log_prob = dist.log_prob(inputs) # Build/fit a model to activate autoregressive kernel constraints - model = tf.keras.Model(inputs=inputs, outputs=log_prob) + model = tf_keras.Model(inputs=inputs, outputs=log_prob) model.add_loss(-tf.reduce_mean(log_prob)) model.compile() if not tf.executing_eagerly() and isinstance( - model.optimizer, tf.keras.optimizers.experimental.Optimizer): + model.optimizer, tf_keras.optimizers.experimental.Optimizer): return train_data = self._make_fake_inputs() model.fit(x=train_data) @@ -276,8 +277,8 @@ def _make_fake_inputs(self): return [self._make_fake_images(), self._make_fake_conditional()] def _make_input_layers(self): - return [tf.keras.layers.Input(shape=self.image_shape), - tf.keras.layers.Input(shape=self.h_shape)] + return [tf_keras.layers.Input(shape=self.image_shape), + tf_keras.layers.Input(shape=self.h_shape)] def testScalarConditional(self): dist = pixel_cnn.PixelCNN( diff --git a/tensorflow_probability/python/distributions/quantized_distribution.py b/tensorflow_probability/python/distributions/quantized_distribution.py index e7b5214e4b..a6cc4b0c8c 100644 --- a/tensorflow_probability/python/distributions/quantized_distribution.py +++ b/tensorflow_probability/python/distributions/quantized_distribution.py @@ -20,6 +20,7 @@ from tensorflow_probability.python.distributions import distribution as distributions from tensorflow_probability.python.internal import assert_util +from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import distribution_util from tensorflow_probability.python.internal import dtype_util from 
tensorflow_probability.python.internal import parameter_properties @@ -586,7 +587,7 @@ def __new__(cls, *args, **kwargs): else: distribution = kwargs.get('distribution') - if not isinstance(distribution, tf.__internal__.CompositeTensor): + if not auto_composite_tensor.is_composite_tensor(distribution): return _QuantizedDistribution(*args, **kwargs) return super(QuantizedDistribution, cls).__new__(cls) diff --git a/tensorflow_probability/python/distributions/sample.py b/tensorflow_probability/python/distributions/sample.py index 9c5777c72b..40f2cf05ab 100644 --- a/tensorflow_probability/python/distributions/sample.py +++ b/tensorflow_probability/python/distributions/sample.py @@ -26,6 +26,7 @@ from tensorflow_probability.python.distributions import kullback_leibler from tensorflow_probability.python.distributions import log_prob_ratio from tensorflow_probability.python.internal import assert_util +from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import prefer_static as ps @@ -366,7 +367,7 @@ def __new__(cls, *args, **kwargs): else: distribution = kwargs.get('distribution') - if not isinstance(distribution, tf.__internal__.CompositeTensor): + if not auto_composite_tensor.is_composite_tensor(distribution): return _Sample(*args, **kwargs) return super(Sample, cls).__new__(cls) @@ -557,8 +558,8 @@ def __new__(cls, *args, **kwargs): else: bijector = kwargs.get('bijector') - if not (isinstance(distribution, tf.__internal__.CompositeTensor) - and isinstance(bijector, tf.__internal__.CompositeTensor)): + if not (auto_composite_tensor.is_composite_tensor(distribution) + and auto_composite_tensor.is_composite_tensor(bijector)): return _NonCompositeTensorDefaultSampleBijector(*args, **kwargs) return super(_DefaultSampleBijector, cls).__new__(cls) diff --git a/tensorflow_probability/python/distributions/student_t_process.py b/tensorflow_probability/python/distributions/student_t_process.py index be1c655e6e..f19c0346af 100644 --- a/tensorflow_probability/python/distributions/student_t_process.py +++ b/tensorflow_probability/python/distributions/student_t_process.py @@ -26,7 +26,6 @@ from tensorflow_probability.python.distributions import cholesky_util from tensorflow_probability.python.distributions import distribution from tensorflow_probability.python.distributions import multivariate_student_t -from tensorflow_probability.python.distributions import student_t from tensorflow_probability.python.distributions.internal import stochastic_process_util from tensorflow_probability.python.internal import assert_util from tensorflow_probability.python.internal import batch_shape_lib @@ -48,12 +47,12 @@ _ALWAYS_YIELD_MVST_DEPRECATION_WARNING = ( - '`always_yield_multivariate_student_t` is deprecated. After 2023-07-01, ' - 'this arg will be ignored, and behavior will be as though ' - '`always_yield_multivariate_student_t=True`. This means that a ' - '`StudentTProcess` evaluated at a single index point will have event shape ' - '`[1]`. To reproduce the behavior of ' - '`always_yield_multivariate_student_t=False` squeeze the rightmost ' + '`always_yield_multivariate_student_t` is deprecated. This arg is now ' + 'ignored and will be removed after 2023-11-15. 
A `StudentTProcess` ' + 'evaluated at a single index point now always has event shape `[1]` (the ' + 'previous behavior for `always_yield_multivariate_student_t=True`). To ' + 'reproduce the previous behavior of ' + '`always_yield_multivariate_student_t=False`, squeeze the rightmost ' 'singleton dimension from the output of `mean`, `sample`, etc.') @@ -65,7 +64,7 @@ def make_cholesky_factored_marginal_fn(cholesky_fn): The returned function computes the Cholesky factorization of the input covariance plus a diagonal jitter, and uses that for the `scale` of a - `tfd.MultivariateNormalLinearOperator`. + `tfd.MultivariateStudentTLinearOperator`. Args: cholesky_fn: Callable which takes a single (batch) matrix argument and @@ -74,7 +73,7 @@ def make_cholesky_factored_marginal_fn(cholesky_fn): Returns: marginal_fn: A Python function that takes a location, covariance matrix, optional `validate_args`, `allow_nan_stats` and `name` arguments, and - returns a `tfd.MultivariateNormalLinearOperator`. + returns a `tfd.MultivariateStudentTLinearOperator`. """ def marginal_fn( df, @@ -227,7 +226,7 @@ class StudentTProcess(distribution.AutoCompositeTensorDistribution): tp = tfd.StudentTProcess(3., kernel, observed_index_points) - optimizer = tf.optimizers.Adam() + optimizer = tf_keras.optimizers.Adam() @tf.function def optimize(): @@ -256,10 +255,10 @@ def optimize(): '2021-06-26', '`jitter` is deprecated; please use `marginal_fn` directly.', 'jitter') - @deprecation.deprecated_arg_values( - '2023-07-01', + @deprecation.deprecated_args( + '2023-11-15', _ALWAYS_YIELD_MVST_DEPRECATION_WARNING, - always_yield_multivariate_student_t=False) + 'always_yield_multivariate_student_t') def __init__(self, df, kernel, @@ -269,7 +268,7 @@ def __init__(self, marginal_fn=None, cholesky_fn=None, jitter=1e-6, - always_yield_multivariate_student_t=False, + always_yield_multivariate_student_t=None, validate_args=False, allow_nan_stats=False, name='StudentTProcess'): @@ -302,7 +301,7 @@ def __init__(self, Default value: `0.` marginal_fn: A Python callable that takes a location, covariance matrix, optional `validate_args`, `allow_nan_stats` and `name` arguments, and - returns a multivariate normal subclass of `tfd.Distribution`. + returns a multivariate Student T subclass of `tfd.Distribution`. Default value: `None`, in which case a Cholesky-factorizing function is created using `make_cholesky_factored_marginal_fn` and the `jitter` argument. @@ -314,11 +313,7 @@ def __init__(self, matrix to ensure positive definiteness of the covariance matrix. This argument is ignored if `cholesky_fn` is set. Default value: `1e-6`. - always_yield_multivariate_student_t: Deprecated. If `False` (the default), - we produce a scalar `StudentT` distribution when the number of - `index_points` is statically known to be `1`. If `True`, we avoid this - behavior, ensuring that the event shape will retain the `1` from - `index_points`. + always_yield_multivariate_student_t: Deprecated and ignored. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. 
When `False` invalid inputs may silently render incorrect @@ -337,20 +332,38 @@ def __init__(self, """ parameters = dict(locals()) with tf.name_scope(name) as name: - if tf.nest.is_nested(kernel.feature_ndims): - input_dtype = dtype_util.common_dtype( - [kernel, index_points], - dtype_hint=nest_util.broadcast_structure( - kernel.feature_ndims, tf.float32)) + input_dtype = dtype_util.common_dtype( + dict( + kernel=kernel, + index_points=index_points, + ), + dtype_hint=nest_util.broadcast_structure( + kernel.feature_ndims, tf.float32)) + + # If the input dtype is non-nested float, we infer a single dtype for the + # input and the float parameters, which is also the dtype of the STP's + # samples, log_prob, etc. If the input dtype is nested (or not float), we + # do not use it to infer the STP's float dtype. + if (not tf.nest.is_nested(input_dtype) and + dtype_util.is_floating(input_dtype)): dtype = dtype_util.common_dtype( - [df, observation_noise_variance, jitter], tf.float32) + dict( + kernel=kernel, + index_points=index_points, + observation_noise_variance=observation_noise_variance, + jitter=jitter, + df=df, + ), + dtype_hint=tf.float32, + ) + input_dtype = dtype else: - # If the index points are not nested, we assume they are of the same - # float dtype as the TP. dtype = dtype_util.common_dtype( - [df, kernel, index_points, observation_noise_variance, jitter], - tf.float32) - input_dtype = dtype + dict( + df=df, + observation_noise_variance=observation_noise_variance, + jitter=jitter, + ), dtype_hint=tf.float32) if index_points is not None: index_points = nest_util.convert_to_nested_tensor( @@ -397,42 +410,6 @@ def __init__(self, parameters=parameters, name=name) - def _is_univariate_marginal(self, index_points): - """True if the given index_points would yield a univariate marginal. - - Args: - index_points: the set of index set locations at which to compute the - marginal Student T distribution. If this set is of size 1, the marginal is - univariate. - - Returns: - is_univariate: Boolean indicating whether the marginal is univariate or - multivariate. In the case of dynamic shape in the number of index points, - defaults to "multivariate" since that's the best we can do. - """ - if self._always_yield_multivariate_student_t: - return False - - num_index_points = tf.nest.map_structure( - lambda x, nd: tf.compat.dimension_value(x.shape[-(nd + 1)]), - index_points, self.kernel.feature_ndims) - flat_num_index_points = tf.nest.flatten(num_index_points) - static_non_singleton_num_points = set( - n for n in flat_num_index_points if n is not None and n != 1) - if len(static_non_singleton_num_points) > 1: - raise ValueError( - 'Nested components of `index_points` must contain the same or ' - 'broadcastable numbers of examples. Saw components with ' - f'{", ".join(list(str(n) for n in static_non_singleton_num_points))} ' - 'examples.') - if None in flat_num_index_points: - warnings.warn( - 'Unable to detect statically whether the number of index_points is ' - '1. As a result, defaulting to treating the marginal Student T ' - 'Process at `index_points` as a multivariate Student T. This makes ' - 'some methods, like `cdf` unavailable.') - return all(n == 1 for n in flat_num_index_points) - @classmethod def _parameter_properties(cls, dtype, num_classes=None): return dict( @@ -466,28 +443,26 @@ def get_marginal_distribution(self, index_points=None): `kernel.batch_shape` and any batch dims yielded by `mean_fn`. 
Returns: - marginal: a Student T distribution with vector event shape, or - (deprecated) a scalar `StudentT` distribution if `index_points` consists - of a single index point and `always_yield_multivariate_student_t=False`. + marginal: a Student T distribution with vector event shape. """ with self._name_and_control_scope('get_marginal_distribution'): global _GET_MARGINAL_DISTRIBUTION_ALREADY_WARNED if (not _GET_MARGINAL_DISTRIBUTION_ALREADY_WARNED and # pylint: disable=protected-access - not self._always_yield_multivariate_student_t): # pylint: disable=protected-access + self._always_yield_multivariate_student_t is not None): # pylint: disable=protected-access warnings.warn( - 'After 2023-07-01, the `always_yield_multivariate_student_t` arg ' - 'to `StudentTProcess.__init__` will be ignored, which means that ' - '`get_marginal_distribution` will always return a Student T ' - 'distribution with vector event shape. This is the current ' - 'behavior when `always_yield_multivariate_student_t=True`. ' - 'To recover the behavior of ' - '`always_yield_multivariate_student_t=False` when `index_points` ' - 'contains a single index point, build a scalar `StudentT` ' - 'distribution as follows:\n' - '`mvst = get_marginal_distribution(index_points);`\n' - '`st = tfd.StudentT(`\n' - '` mvst.df, loc=mvst.loc[..., 0], scale=mvst.stddev()[..., 0])`\n' - 'To suppress these warnings, build the `StudentTProcess` with ' + 'The `always_yield_multivariate_student_t` arg to ' + '`StudentTProcess.__init__` is now ignored and ' + '`get_marginal_distribution` always returns a Student T ' + 'distribution with vector event shape. This was the previous ' + 'behavior of `always_yield_multivariate_student_t=True`. To ' + 'recover the behavior of ' + '`always_yield_multivariate_student_t=False` when `index_points` ' + 'contains a single index point, build a scalar `StudentT` ' + 'distribution as follows: ' + '`mvst = get_marginal_distribution(index_points); `' + '`dist = tfd.StudentT(mvst.df, loc=mvst.loc[..., 0], `' + '`scale=mvst.stddev()[..., 0])`. To suppress these ' + 'warnings, build the `StudentTProcess` with ' '`always_yield_multivariate_student_t=True`.', FutureWarning) _GET_MARGINAL_DISTRIBUTION_ALREADY_WARNED = True # pylint: disable=protected-access @@ -496,31 +471,13 @@ def get_marginal_distribution(self, index_points=None): covariance = stochastic_process_util.compute_kernel_matrix( self.kernel, index_points, self.observation_noise_variance) loc = self._mean_fn(index_points) - - # If we're sure the number of index points is 1, we can just construct a - # scalar Normal. This has computational benefits and supports things like - # CDF that aren't otherwise straightforward to provide. - if self._is_univariate_marginal(index_points): - covariance = tf.squeeze(covariance, axis=[-1, -2]) - squared_scale = (df - 2.) / df * covariance - scale = tf.sqrt(squared_scale) - # `loc` has a trailing 1 in the shape; squeeze it. 
- loc = tf.squeeze(loc, axis=-1) - return student_t.StudentT( - df=df, - loc=loc, - scale=scale, - validate_args=self.validate_args, - allow_nan_stats=self.allow_nan_stats, - name='marginal_distribution') - else: - return self._marginal_fn( - df=df, - loc=loc, - covariance=covariance, - validate_args=self.validate_args, - allow_nan_stats=self.allow_nan_stats, - name='marginal_distribution') + return self._marginal_fn( + df=df, + loc=loc, + covariance=covariance, + validate_args=self.validate_args, + allow_nan_stats=self.allow_nan_stats, + name='marginal_distribution') @property def df(self): @@ -586,7 +543,7 @@ def _get_index_points(self, index_points=None): 'must equal ' '`self.kernel.feature_ndims` (or its corresponding ' 'nested component) and `e` is the number of index points in each ' 'batch. Ultimately, this distribution corresponds to an ' - '`e`-dimensional multivariate normal. The batch shape must be ' + '`e`-dimensional multivariate Student T. The batch shape must be ' 'broadcastable with `kernel.batch_shape` and any batch dims yielded' 'by `mean_fn`. If not specified, `self.index_points` is used. ' 'Default value: `None`.', @@ -605,8 +562,6 @@ def _log_prob(self, value, index_points=None, is_missing=None): is_missing = tf.convert_to_tensor(is_missing) value = tf.convert_to_tensor(value, dtype=self.dtype) index_points = self._get_index_points(index_points) - if self._is_univariate_marginal(index_points): - value = value[..., tf.newaxis] observation_noise_variance = tf.convert_to_tensor( self.observation_noise_variance) loc, covariance = stochastic_process_util.get_loc_and_kernel_matrix( @@ -633,24 +588,14 @@ def _log_prob(self, value, index_points=None, is_missing=None): value = tf.where(is_missing, 0., value) num_masked_dims = tf.cast( tf.math.count_nonzero(is_missing, axis=-1), self.dtype) - if self._is_univariate_marginal(index_points): - num_dims = 1 - else: - num_dims = tf.cast(event_shape[-1], self.dtype) - - if self._is_univariate_marginal(index_points): - covariance = tf.squeeze(covariance, axis=[-1, -2]) - value = tf.squeeze(value, axis=-1) - lp = -(df + num_dims - num_masked_dims) / 2. * tf.math.log1p( - tf.math.square(value) / (covariance * (df - 2.))) - lp = lp - 0.5 * tf.math.log(covariance) - else: - chol_covariance = self.cholesky_fn(covariance) # pylint: disable=not-callable - lp = -(df + num_dims - num_masked_dims) / 2. * tf.math.log1p( - linalg.hpsd_quadratic_form_solvevec( - covariance, value, cholesky_matrix=chol_covariance) / (df - 2.)) - lp = lp - 0.5 * linalg.hpsd_logdet( - covariance, cholesky_matrix=chol_covariance) + num_dims = tf.cast(event_shape[-1], self.dtype) + + chol_covariance = self.cholesky_fn(covariance) # pylint: disable=not-callable + lp = -(df + num_dims - num_masked_dims) / 2. * tf.math.log1p( + linalg.hpsd_quadratic_form_solvevec( + covariance, value, cholesky_matrix=chol_covariance) / (df - 2.)) + lp = lp - 0.5 * linalg.hpsd_logdet( + covariance, cholesky_matrix=chol_covariance) lp = lp - special.log_gamma_difference( (num_dims - num_masked_dims) / 2., df / 2.) 
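Editor's note (illustration only, not part of the patch): after this hunk, `_log_prob` always takes the vector-event branch, i.e. it evaluates a multivariate Student-T log density parameterized by its covariance matrix `K` (the Student-T scale matrix is `(df - 2) / df * K`, so `K` is the distribution's covariance when `df > 2`). Below is a minimal NumPy sketch of that density, assuming no missing observations and a single unbatched event; the helper name `mvt_log_prob` is hypothetical. This is the same quantity the existing `testLogProbMatchesMVT` test checks against `tfd.MultivariateStudentTLinearOperator`.

```python
import math
import numpy as np

def mvt_log_prob(value, df, loc, covariance):
  """Multivariate Student-T log density, covariance-parameterized (df > 2)."""
  d = loc.shape[-1]
  chol = np.linalg.cholesky(covariance)          # covariance = chol @ chol.T
  white = np.linalg.solve(chol, value - loc)     # chol^{-1} (x - mu)
  quad_form = np.sum(white ** 2)                 # (x - mu)^T K^{-1} (x - mu)
  log_det = 2. * np.sum(np.log(np.diag(chol)))   # log |K|
  return (math.lgamma((df + d) / 2.) - math.lgamma(df / 2.)
          - 0.5 * d * math.log((df - 2.) * math.pi)
          - 0.5 * log_det
          - 0.5 * (df + d) * math.log1p(quad_form / (df - 2.)))

# Tiny usage example with an arbitrary positive-definite covariance.
rng = np.random.default_rng(0)
cov = np.eye(3) + 0.1
lp = mvt_log_prob(rng.normal(size=3), df=5., loc=np.zeros(3), covariance=cov)
```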
@@ -660,15 +605,11 @@ def _log_prob(self, value, index_points=None, is_missing=None): def _event_shape_tensor(self, index_points=None): index_points = self._get_index_points(index_points) - if self._is_univariate_marginal(index_points): - return tf.constant([], dtype=tf.int32) return stochastic_process_util.event_shape_tensor(self.kernel, index_points) def _event_shape(self, index_points=None): index_points = ( index_points if index_points is not None else self._index_points) - if self._is_univariate_marginal(index_points): - return tf.TensorShape([]) return stochastic_process_util.event_shape(self.kernel, index_points) def _batch_shape(self, index_points=None): @@ -723,31 +664,21 @@ def _variance(self, index_points=None): index_points = self._get_index_points(index_points) kernel_diag = self.kernel.apply(index_points, index_points, example_ndims=1) - if self._is_univariate_marginal(index_points): - return (tf.squeeze(kernel_diag, axis=[-1]) + - self.observation_noise_variance) - else: - # We are computing diag(K + obs_noise_variance * I) = diag(K) + - # obs_noise_variance. We pad obs_noise_variance with a dimension in order - # to broadcast batch shapes of kernel_diag and obs_noise_variance (since - # kernel_diag has an extra dimension corresponding to the number of index - # points). - return kernel_diag + self.observation_noise_variance[..., tf.newaxis] + # We are computing diag(K + obs_noise_variance * I) = diag(K) + + # obs_noise_variance. We pad obs_noise_variance with a dimension in order + # to broadcast batch shapes of kernel_diag and obs_noise_variance (since + # kernel_diag has an extra dimension corresponding to the number of index + # points). + return kernel_diag + self.observation_noise_variance[..., tf.newaxis] def _covariance(self, index_points=None): observation_noise_variance = tf.convert_to_tensor( self.observation_noise_variance) index_points = self._get_index_points(index_points) - kernel_matrix = stochastic_process_util.compute_kernel_matrix( + return stochastic_process_util.compute_kernel_matrix( kernel=self.kernel, index_points=index_points, observation_noise_variance=observation_noise_variance) - if self._is_univariate_marginal(index_points): - # kernel_matrix thus has shape [..., 1, 1]; squeeze off the last dims and - # tack on the observation noise variance. - return tf.squeeze(kernel_matrix, axis=[-2, -1]) - else: - return kernel_matrix def _mode(self, index_points=None): return self.get_marginal_distribution(index_points).mode() diff --git a/tensorflow_probability/python/distributions/student_t_process_regression_model.py b/tensorflow_probability/python/distributions/student_t_process_regression_model.py index 402c3828fb..a1f6a2f65b 100644 --- a/tensorflow_probability/python/distributions/student_t_process_regression_model.py +++ b/tensorflow_probability/python/distributions/student_t_process_regression_model.py @@ -38,13 +38,13 @@ _ALWAYS_YIELD_MVST_DEPRECATION_WARNING = ( - '`always_yield_multivariate_student_t` is deprecated. After 2023-07-01, ' - 'this arg will be ignored, and behavior will be as though ' - '`always_yield_multivariate_student_t=True`. This means that a ' - '`StudentTProcessRegressionModel` evaluated at a single index point will ' - 'have event shape `[1]`. To reproduce the behavior of ' - '`always_yield_multivariate_student_t=False` squeeze the rightmost ' - 'singleton dimension from the output of `mean`, `sample`, etc.') + '`always_yield_multivariate_student_t` is deprecated. 
This arg is now ' + 'ignored and will be removed after 2023-11-15. A ' + '`StudentTProcessRegressionModel` evaluated at a single index point now ' + 'always has event shape `[1]` (the previous behavior for ' + '`always_yield_multivariate_student_t=True`). To reproduce the previous ' + 'behavior of `always_yield_multivariate_student_t=False`, squeeze the ' + 'rightmost singleton dimension from the output of `mean`, `sample`, etc.') class DampedSchurComplement(psd_kernel.AutoCompositeTensorPsdKernel): @@ -69,19 +69,31 @@ def __init__(self, name='DampedSchurComplement'): parameters = dict(locals()) with tf.name_scope(name) as name: - if tf.nest.is_nested(schur_complement.feature_ndims): + kernel_dtype = schur_complement.dtype + + # If the input dtype is non-nested float, we infer a single dtype for the + # input and the float parameters, which is also the dtype of the STP's + # samples, log_prob, etc. If the input dtype is nested (or not float), we + # do not use it to infer the STP's float dtype. + if (not tf.nest.is_nested(kernel_dtype) and + dtype_util.is_floating(kernel_dtype)): dtype = dtype_util.common_dtype( - [df, fixed_inputs_observations], - tf.float32) - kernel_dtype = schur_complement.dtype - else: - # If the index points are not nested, we assume they are of the same - # dtype as the STPRM. - dtype = dtype_util.common_dtype([ - schur_complement, - fixed_inputs_observations, - df], tf.float32) + dict( + schur_complement=schur_complement, + fixed_inputs_observations=fixed_inputs_observations, + df=df, + ), + dtype_hint=tf.float32, + ) kernel_dtype = dtype + else: + dtype = dtype_util.common_dtype( + dict( + fixed_inputs_observations=fixed_inputs_observations, + df=df, + ), + dtype_hint=tf.float32, + ) self._schur_complement = schur_complement self._df = tensor_util.convert_nonref_to_tensor( df, name='df', dtype=dtype) @@ -235,10 +247,10 @@ class StudentTProcessRegressionModel(student_t_process.StudentTProcess): """ # pylint:disable=invalid-name - @deprecation.deprecated_arg_values( - '2023-07-01', + @deprecation.deprecated_args( + '2023-11-15', _ALWAYS_YIELD_MVST_DEPRECATION_WARNING, - always_yield_multivariate_student_t=False) + 'always_yield_multivariate_student_t') def __init__( self, df, @@ -251,7 +263,7 @@ def __init__( mean_fn=None, cholesky_fn=None, marginal_fn=None, - always_yield_multivariate_student_t=False, + always_yield_multivariate_student_t=None, validate_args=False, allow_nan_stats=False, name='StudentTProcessRegressionModel', @@ -271,7 +283,7 @@ def __init__( must equal `kernel.feature_ndims` (or its corresponding nested component) and `e` is the number (size) of index points in each batch. Ultimately this distribution corresponds to an `e`-dimensional - multivariate normal. The batch shape must be broadcastable with + multivariate Student T. The batch shape must be broadcastable with `kernel.batch_shape` and any batch dims yielded by `mean_fn`. observation_index_points: (Nested) `Tensor` representing finite collection, or batch of collections, of points in the index set for @@ -319,11 +331,7 @@ def __init__( returns a multivariate Student-T subclass of `tfd.Distribution`. Default value: `None`, in which case a Cholesky-factorizing function is created using `make_cholesky_with_jitter_fn`. - always_yield_multivariate_student_t: Deprecated. If `False` (the default), - we produce a scalar `StudentT` distribution when the number of - `index_points` is statically known to be `1`. 
If `True`, we avoid this - behavior, ensuring that the event shape will retain the `1` from - `index_points`. + always_yield_multivariate_student_t: Deprecated and ignored. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -463,10 +471,10 @@ def conditional_mean_fn(x): self._parameters = parameters @staticmethod - @deprecation.deprecated_arg_values( - '2023-07-01', + @deprecation.deprecated_args( + '2023-11-15', _ALWAYS_YIELD_MVST_DEPRECATION_WARNING, - always_yield_multivariate_student_t=False) + 'always_yield_multivariate_student_t') def precompute_regression_model( df, kernel, @@ -478,7 +486,7 @@ def precompute_regression_model( predictive_noise_variance=None, mean_fn=None, cholesky_fn=None, - always_yield_multivariate_student_t=False, + always_yield_multivariate_student_t=None, validate_args=False, allow_nan_stats=False, name='PrecomputedStudentTProcessRegressionModel', @@ -547,7 +555,7 @@ def precompute_regression_model( dimensions and must equal `kernel.feature_ndims` (or its corresponding nested component) and `e` is the number (size) of index points in each batch. Ultimately this distribution corresponds to an `e`-dimensional - multivariate normal. The batch shape must be broadcastable with + multivariate Student T. The batch shape must be broadcastable with `kernel.batch_shape` and any batch dims yielded by `mean_fn`. observation_noise_variance: `float` `Tensor` representing the variance of the noise in the Normal likelihood distribution of the model. May be @@ -571,11 +579,7 @@ def precompute_regression_model( returns a Cholesky-like lower triangular factor. Default value: `None`, in which case `make_cholesky_with_jitter_fn` is used with the `jitter` parameter. - always_yield_multivariate_student_t: Deprecated. If `False` (the default), - we produce a scalar `StudentT` distribution when the number of - `index_points` is statically known to be `1`. If `True`, we avoid this - behavior, ensuring that the event shape will retain the `1` from - `index_points`. + always_yield_multivariate_student_t: Deprecated and ignored. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. 
When `False` invalid inputs may silently render incorrect diff --git a/tensorflow_probability/python/distributions/student_t_process_test.py b/tensorflow_probability/python/distributions/student_t_process_test.py index 51dcbb3264..b1ee6285dc 100644 --- a/tensorflow_probability/python/distributions/student_t_process_test.py +++ b/tensorflow_probability/python/distributions/student_t_process_test.py @@ -20,7 +20,6 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.distributions import multivariate_student_t as mvst -from tensorflow_probability.python.distributions import student_t from tensorflow_probability.python.distributions import student_t_process from tensorflow_probability.python.internal import tensorshape_util from tensorflow_probability.python.internal import test_util @@ -240,19 +239,6 @@ def _kernel_fn(x, y): with self.assertRaises(ValueError): tp.mean() - def testMarginalHasCorrectTypes(self): - tp = student_t_process.StudentTProcess( - df=3., kernel=psd_kernels.ExponentiatedQuadratic(), validate_args=True) - - self.assertIsInstance( - tp.get_marginal_distribution( - index_points=np.ones([1, 1], dtype=np.float32)), student_t.StudentT) - - self.assertIsInstance( - tp.get_marginal_distribution( - index_points=np.ones([10, 1], dtype=np.float32)), - mvst.MultivariateStudentTLinearOperator) - @parameterized.parameters( {"foo_feature_shape": [5], "bar_feature_shape": [3]}, {"foo_feature_shape": [3, 2], "bar_feature_shape": [5]}, @@ -339,26 +325,6 @@ def testStructuredIndexPoints(self, foo_feature_shape, bar_feature_shape): stp_with_list.batch_shape_tensor()) self.assertAllClose(base_stp.log_prob(s), stp_with_list.log_prob(s)) - def testAlwaysYieldMultivariateStudentT(self): - df = np.float32(3.) - stp = student_t_process.StudentTProcess( - df=df, - kernel=psd_kernels.ExponentiatedQuadratic(), - index_points=tf.ones([5, 1, 2]), - always_yield_multivariate_student_t=False, - ) - self.assertAllEqual([5], self.evaluate(stp.batch_shape_tensor())) - self.assertAllEqual([], self.evaluate(stp.event_shape_tensor())) - - stp = student_t_process.StudentTProcess( - df=df, - kernel=psd_kernels.ExponentiatedQuadratic(), - index_points=tf.ones([5, 1, 2]), - always_yield_multivariate_student_t=True, - ) - self.assertAllEqual([5], self.evaluate(stp.batch_shape_tensor())) - self.assertAllEqual([1], self.evaluate(stp.event_shape_tensor())) - def testLogProbMatchesMVT(self): df = tf.convert_to_tensor(3.) 
index_points = tf.convert_to_tensor( diff --git a/tensorflow_probability/python/distributions/transformed_distribution.py b/tensorflow_probability/python/distributions/transformed_distribution.py index 6291be7517..8d05d607bc 100644 --- a/tensorflow_probability/python/distributions/transformed_distribution.py +++ b/tensorflow_probability/python/distributions/transformed_distribution.py @@ -388,7 +388,7 @@ def _log_prob(self, y, **kwargs): return tf.reduce_logsumexp(tf.stack(lp_on_fibers), axis=0) def _prob(self, y, **kwargs): - if not hasattr(self.distribution, '_prob'): + if not hasattr(self.distribution, '_prob') or self.bijector._is_injective: # pylint: disable=protected-access return tf.exp(self._log_prob(y, **kwargs)) distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs) @@ -400,9 +400,6 @@ def _prob(self, y, **kwargs): ) ildj = self.bijector.inverse_log_det_jacobian( y, event_ndims=event_ndims, **bijector_kwargs) - if self.bijector._is_injective: # pylint: disable=protected-access - base_prob = self.distribution.prob(x, **distribution_kwargs) - return base_prob * tf.exp(tf.cast(ildj, base_prob.dtype)) # Compute prob on each element of the inverse image. prob_on_fibers = [] @@ -684,8 +681,8 @@ def __new__(cls, *args, **kwargs): else: bijector = kwargs.get('bijector') - if not (isinstance(distribution, tf.__internal__.CompositeTensor) - and isinstance(bijector, tf.__internal__.CompositeTensor)): + if not (auto_composite_tensor.is_composite_tensor(distribution) + and auto_composite_tensor.is_composite_tensor(bijector)): return _TransformedDistribution(*args, **kwargs) return super(TransformedDistribution, cls).__new__(cls) diff --git a/tensorflow_probability/python/distributions/transformed_distribution_test.py b/tensorflow_probability/python/distributions/transformed_distribution_test.py index 53416a55ac..8ea33bea6f 100644 --- a/tensorflow_probability/python/distributions/transformed_distribution_test.py +++ b/tensorflow_probability/python/distributions/transformed_distribution_test.py @@ -42,6 +42,7 @@ from tensorflow_probability.python.bijectors import split from tensorflow_probability.python.bijectors import tanh from tensorflow_probability.python.distributions import beta +from tensorflow_probability.python.distributions import dirichlet from tensorflow_probability.python.distributions import exponential from tensorflow_probability.python.distributions import independent from tensorflow_probability.python.distributions import joint_distribution_auto_batched as jdab @@ -54,6 +55,7 @@ from tensorflow_probability.python.distributions import normal as normal_lib from tensorflow_probability.python.distributions import sample as sample_lib from tensorflow_probability.python.distributions import transformed_distribution +from tensorflow_probability.python.distributions import uniform from tensorflow_probability.python.internal import hypothesis_testlib as tfp_hps from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import tensorshape_util @@ -650,6 +652,26 @@ def testLogProbRatio(self): # oracle_64, d0.log_prob(x0) - d1.log_prob(x1), # rtol=0., atol=0.007) + @test_util.numpy_disable_test_missing_functionality('b/306384754') + def testLogProbMatchesProbDirichlet(self): + # This was https://github.com/tensorflow/probability/issues/1761 + scaled_dir = transformed_distribution.TransformedDistribution( + distribution=dirichlet.Dirichlet([2.0, 3.0]), + bijector=scale_lib.Scale(2.0)) + x = np.array([0.2, 1.8], 
dtype=np.float32) + self.assertAllClose(scaled_dir.prob(x), + tf.exp(scaled_dir.log_prob(x))) + + @test_util.numpy_disable_test_missing_functionality('b/306384754') + def testLogProbMatchesProbUniform(self): + # Uniform does not define _log_prob + scaled_uniform = transformed_distribution.TransformedDistribution( + distribution=uniform.Uniform(), + bijector=scale_lib.Scale(2.0)) + x = np.array([0.2], dtype=np.float32) + self.assertAllClose(scaled_uniform.prob(x), + tf.exp(scaled_uniform.log_prob(x))) + @test_util.test_all_tf_execution_regimes class ScalarToMultiTest(test_util.TestCase): @@ -747,8 +769,7 @@ def testMVN(self, event_shape, shift, tril, dynamic_shape): num_samples = 7e3 y = fake_mvn.sample(int(num_samples), seed=test_util.test_seed()) x = y[0:5, ...] - self.assertAllMeansClose(y, expected_mean, axis=0, - atol=0.1, rtol=0.1) + self.assertAllMeansClose(y, expected_mean, axis=0, atol=0.25) self.assertAllClose(expected_cov, sample_stats.covariance(y, sample_axis=0), atol=0., rtol=0.1) diff --git a/tensorflow_probability/python/distributions/two_piece_normal_test.py b/tensorflow_probability/python/distributions/two_piece_normal_test.py index 4887b91abf..7e04c0c80b 100644 --- a/tensorflow_probability/python/distributions/two_piece_normal_test.py +++ b/tensorflow_probability/python/distributions/two_piece_normal_test.py @@ -369,7 +369,7 @@ def get_abs_sample_mean(skewness): err = self.compute_max_gradient_error( get_abs_sample_mean, [tf.constant(skewness, self.dtype)], delta=1e-1) - maxerr = 0.05 if self.dtype == np.float64 else 0.09 + maxerr = 0.2 self.assertLess(err, maxerr) @test_util.numpy_disable_gradient_test diff --git a/tensorflow_probability/python/distributions/variational_gaussian_process.py b/tensorflow_probability/python/distributions/variational_gaussian_process.py index e10d4442d1..88aa7aa7ea 100644 --- a/tensorflow_probability/python/distributions/variational_gaussian_process.py +++ b/tensorflow_probability/python/distributions/variational_gaussian_process.py @@ -558,7 +558,7 @@ class VariationalGaussianProcess(gaussian_process.GaussianProcess, # For training, we use some simplistic numpy-based minibatching. batch_size = 64 - optimizer = tf.optimizers.Adam(learning_rate=.1) + optimizer = tf_keras.optimizers.Adam(learning_rate=.1) @tf.function def optimize(x_train_batch, y_train_batch): @@ -670,7 +670,7 @@ def optimize(x_train_batch, y_train_batch): # For training, we use some simplistic numpy-based minibatching. 
batch_size = 64 - optimizer = tf.optimizers.Adam(learning_rate=.05, beta_1=.5, beta_2=.99) + optimizer = tf_keras.optimizers.Adam(learning_rate=.05, beta_1=.5, beta_2=.99) @tf.function def optimize(x_train_batch, y_train_batch): diff --git a/tensorflow_probability/python/experimental/bayesopt/acquisition/BUILD b/tensorflow_probability/python/experimental/bayesopt/acquisition/BUILD index 9756920033..89793d0cf8 100644 --- a/tensorflow_probability/python/experimental/bayesopt/acquisition/BUILD +++ b/tensorflow_probability/python/experimental/bayesopt/acquisition/BUILD @@ -140,6 +140,7 @@ multi_substrate_py_library( # tensorflow dep, "//tensorflow_probability/python/distributions:normal", "//tensorflow_probability/python/internal:dtype_util", + "//tensorflow_probability/python/internal:samplers", ], ) diff --git a/tensorflow_probability/python/experimental/bayesopt/acquisition/__init__.py b/tensorflow_probability/python/experimental/bayesopt/acquisition/__init__.py index 6bb3e9aecf..52f88f2b9e 100644 --- a/tensorflow_probability/python/experimental/bayesopt/acquisition/__init__.py +++ b/tensorflow_probability/python/experimental/bayesopt/acquisition/__init__.py @@ -21,6 +21,7 @@ from tensorflow_probability.python.experimental.bayesopt.acquisition.expected_improvement import StudentTProcessExpectedImprovement from tensorflow_probability.python.experimental.bayesopt.acquisition.max_value_entropy_search import GaussianProcessMaxValueEntropySearch from tensorflow_probability.python.experimental.bayesopt.acquisition.probability_of_improvement import GaussianProcessProbabilityOfImprovement +from tensorflow_probability.python.experimental.bayesopt.acquisition.probability_of_improvement import ParallelProbabilityOfImprovement from tensorflow_probability.python.experimental.bayesopt.acquisition.upper_confidence_bound import GaussianProcessUpperConfidenceBound from tensorflow_probability.python.experimental.bayesopt.acquisition.upper_confidence_bound import ParallelUpperConfidenceBound from tensorflow_probability.python.experimental.bayesopt.acquisition.weighted_power_scalarization import WeightedPowerScalarization @@ -36,6 +37,7 @@ 'GaussianProcessUpperConfidenceBound', 'MCMCReducer', 'ParallelExpectedImprovement', + 'ParallelProbabilityOfImprovement', 'ParallelUpperConfidenceBound', 'StudentTProcessExpectedImprovement', 'WeightedPowerScalarization', diff --git a/tensorflow_probability/python/experimental/bayesopt/acquisition/max_value_entropy_search.py b/tensorflow_probability/python/experimental/bayesopt/acquisition/max_value_entropy_search.py index 13e2ac3cb5..13c6311c81 100644 --- a/tensorflow_probability/python/experimental/bayesopt/acquisition/max_value_entropy_search.py +++ b/tensorflow_probability/python/experimental/bayesopt/acquisition/max_value_entropy_search.py @@ -25,7 +25,7 @@ from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.math import root_search from tensorflow_probability.python.math import special -from tensorflow_probability.python.mcmc import sample_halton_sequence +from tensorflow_probability.python.mcmc import sample_halton_sequence_lib class GaussianProcessMaxValueEntropySearch( @@ -156,7 +156,7 @@ def fit_max_value_distribution( # where F_k is the marginal (Normal) CDF at various points. # Adjoin a grid of points so the approximation is more accurate. 
- grid_points = sample_halton_sequence.sample_halton_sequence( + grid_points = sample_halton_sequence_lib.sample_halton_sequence( dim=predictive_distribution.index_points.shape[-1], num_results=num_grid_points, dtype=predictive_distribution.index_points.dtype, seed=seed) diff --git a/tensorflow_probability/python/experimental/bayesopt/acquisition/probability_of_improvement.py b/tensorflow_probability/python/experimental/bayesopt/acquisition/probability_of_improvement.py index 5e7f98bbd5..1008a66c2b 100644 --- a/tensorflow_probability/python/experimental/bayesopt/acquisition/probability_of_improvement.py +++ b/tensorflow_probability/python/experimental/bayesopt/acquisition/probability_of_improvement.py @@ -19,6 +19,123 @@ from tensorflow_probability.python.distributions import normal from tensorflow_probability.python.experimental.bayesopt.acquisition import acquisition_function from tensorflow_probability.python.internal import dtype_util +from tensorflow_probability.python.internal import samplers + + +class ParallelProbabilityOfImprovement( + acquisition_function.AcquisitionFunction): + """Parallel probability of improvement acquisition function. + + Computes the q-PI from a multivariate observation model. This is also known as + batch probability of improvement. + + Requires that `predictive_distribution` has a `sample` method. + + #### Examples + + Build and evaluate a Parallel Probability of Improvement acquisition function. + + ```python + import numpy as np + import tensorflow_probability as tfp + + tfd = tfp.distributions + tfpk = tfp.math.psd_kernels + tfp_acq = tfp.experimental.bayesopt.acquisition + + # Sample 10 20-dimensional index points and associated observations. + index_points = np.random.uniform(size=[10, 20]) + observations = np.random.uniform(size=[10]) + + # Build a Student T Process regression model conditioned on observed data. + dist = tfd.StudentTProcessRegressionModel( + kernel=tfpk.ExponentiatedQuadratic(), + df=5., + observation_index_points=index_points, + observations=observations) + + # Define a Parallel Probability of Improvement acquisition function. + stp_pei = tfp_acq.ParallelProbabilityOfImprovement( + predictive_distribution=dist, + observations=observations, + num_samples=10_000) + + # Evaluate the acquisition function at a new set of index points. + pred_index_points = np.random.uniform(size=[6, 20]) + acq_fn_vals = stp_pei(pred_index_points) # Has shape [6]. + ``` + + """ + + def __init__( + self, + predictive_distribution, + observations, + seed=None, + num_samples=100, + transform_fn=None): + """Constructs a Parallel Probability of Improvement acquisition function. + + Args: + predictive_distribution: `tfd.Distribution`-like, the distribution over + observations at a set of index points. Must have a `sample` method. + observations: `Float` `Tensor` of observations. Shape has the form + `[b1, ..., bB, e]`, where `e` is the number of index points (such that + the event shape of `predictive_distribution` is `[e]`) and + `[b1, ..., bB]` is broadcastable with the batch shape of + `predictive_distribution`. + seed: PRNG seed; see tfp.random.sanitize_seed for details. + num_samples: The number of samples to use for the Parallel Probability of + Improvement approximation. + transform_fn: Optional Python `Callable` that transforms objective values. + This is used for optimizing a composite grey box function `g(f(x))` + where `f` is our black box function and `g` is `transform_fn`. 
+ """ + self._num_samples = num_samples + self._transform_fn = transform_fn + super(ParallelProbabilityOfImprovement, self).__init__( + predictive_distribution=predictive_distribution, + observations=observations, + seed=seed) + + @property + def num_samples(self): + return self._num_samples + + @property + def transform_fn(self): + return self._transform_fn + + @property + def is_parallel(self): + return True + + def __call__(self, **kwargs): + """Computes the Parallel Probability of Improvement. + + Args: + **kwargs: Keyword args passed on to the `sample` method of + `predictive_distribution`. + + Returns: + Parallel Probability of improvement at index points implied by + `predictive_distribution` (or overridden in `**kwargs`). + """ + # Fix the seed so we get a deterministic objective per iteration. + seed = samplers.sanitize_seed( + [100, 2] if self.seed is None else self.seed, salt='qei') + + samples = self.predictive_distribution.sample( + self.num_samples, seed=seed, **kwargs) + + transform_fn = lambda x: x + if self._transform_fn is not None: + transform_fn = self._transform_fn + + best_observed = tf.reduce_max(transform_fn(self.observations), axis=-1) + qpi = (transform_fn(samples) - best_observed) > 0. + return tf.reduce_mean( + tf.cast(tf.reduce_any(qpi, axis=-1), dtype=samples.dtype), axis=0) class GaussianProcessProbabilityOfImprovement( diff --git a/tensorflow_probability/python/experimental/bayesopt/acquisition/probability_of_improvement_test.py b/tensorflow_probability/python/experimental/bayesopt/acquisition/probability_of_improvement_test.py index 55229d14ed..c368ef3bc9 100644 --- a/tensorflow_probability/python/experimental/bayesopt/acquisition/probability_of_improvement_test.py +++ b/tensorflow_probability/python/experimental/bayesopt/acquisition/probability_of_improvement_test.py @@ -74,6 +74,28 @@ def test_gp_expected_improvement(self): self.assertAllNotNan(grads) self.assertDTypeEqual(actual_poi, self.dtype) + def test_normal_probability_of_improvement_matches_parallel(self): + shape = [5, 20] + loc = 2. * np.random.uniform(size=shape).astype(self.dtype) + scale = 3. 
+ np.random.uniform(size=[20]).astype(self.dtype) + observations = np.array([2., 3., 4.]).astype(self.dtype) + best_observed = tf.reduce_max(observations) + actual_pi = probability_of_improvement.normal_probability_of_improvement( + best_observed=best_observed, + mean=loc, + stddev=scale) + + model = normal.Normal( + loc[..., tf.newaxis], scale[..., tf.newaxis], validate_args=True) + expected_pi = probability_of_improvement.ParallelProbabilityOfImprovement( + predictive_distribution=model, + observations=observations, + num_samples=int(2e5), + seed=test_util.test_seed())() + self.assertAllClose( + self.evaluate(actual_pi), self.evaluate(expected_pi), atol=1e-2) + self.assertDTypeEqual(actual_pi, self.dtype) + @test_util.test_all_tf_execution_regimes class ProbabilityOfImprovementFloat32Test(_ProbabilityOfImprovementTest, diff --git a/tensorflow_probability/python/experimental/bijectors/BUILD b/tensorflow_probability/python/experimental/bijectors/BUILD index 53b07f6a16..b41602157b 100644 --- a/tensorflow_probability/python/experimental/bijectors/BUILD +++ b/tensorflow_probability/python/experimental/bijectors/BUILD @@ -96,6 +96,7 @@ multi_substrate_py_library( "//tensorflow_probability/python/distributions:sample", "//tensorflow_probability/python/distributions:transformed_distribution", "//tensorflow_probability/python/distributions:uniform", + "//tensorflow_probability/python/internal:tf_keras", ], ) @@ -124,6 +125,7 @@ multi_substrate_py_test( "//tensorflow_probability/python/internal:tensor_util", "//tensorflow_probability/python/internal:tensorshape_util", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/math:gradient", "//tensorflow_probability/python/mcmc:dual_averaging_step_size_adaptation", "//tensorflow_probability/python/mcmc:nuts", diff --git a/tensorflow_probability/python/experimental/bijectors/distribution_bijectors.py b/tensorflow_probability/python/experimental/bijectors/distribution_bijectors.py index ef0a9656a2..d794ba1655 100644 --- a/tensorflow_probability/python/experimental/bijectors/distribution_bijectors.py +++ b/tensorflow_probability/python/experimental/bijectors/distribution_bijectors.py @@ -107,7 +107,7 @@ def make_distribution_bijector(distribution, name='make_distribution_bijector'): pinned_model) _ = tfp.vi.fit_surrogate_posterior(pinned_model.unnormalized_log_prob, surrogate_posterior=surrogate_posterior, - optimizer=tf.optimizers.Adam(0.01), + optimizer=tf_keras.optimizers.Adam(0.01), num_steps=200) ``` diff --git a/tensorflow_probability/python/experimental/bijectors/distribution_bijectors_test.py b/tensorflow_probability/python/experimental/bijectors/distribution_bijectors_test.py index 731d4953b2..344a9467b7 100644 --- a/tensorflow_probability/python/experimental/bijectors/distribution_bijectors_test.py +++ b/tensorflow_probability/python/experimental/bijectors/distribution_bijectors_test.py @@ -35,6 +35,7 @@ from tensorflow_probability.python.internal import hypothesis_testlib as tfp_hps from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.math import gradient from tensorflow_probability.python.mcmc import dual_averaging_step_size_adaptation as dassa from tensorflow_probability.python.mcmc import nuts @@ -205,7 +206,7 @@ def model_with_funnel(): optimization.fit_surrogate_posterior( 
pinned_model.unnormalized_log_prob, surrogate_posterior=surrogate_posterior, - optimizer=tf.optimizers.Adam(0.01), + optimizer=tf_keras.optimizers.Adam(0.01), sample_size=10, num_steps=1) bijector = ( diff --git a/tensorflow_probability/python/experimental/distribute/BUILD b/tensorflow_probability/python/experimental/distribute/BUILD index 3c56c6ee80..201d9347e8 100644 --- a/tensorflow_probability/python/experimental/distribute/BUILD +++ b/tensorflow_probability/python/experimental/distribute/BUILD @@ -47,6 +47,7 @@ multi_substrate_py_library( "//tensorflow_probability/python/distributions:distribution", "//tensorflow_probability/python/distributions:log_prob_ratio", "//tensorflow_probability/python/experimental/bijectors:sharded", + "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:distribute_lib", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:samplers", diff --git a/tensorflow_probability/python/experimental/distribute/joint_distribution_test.py b/tensorflow_probability/python/experimental/distribute/joint_distribution_test.py index 7db5c3ecb7..e1775c9c9e 100644 --- a/tensorflow_probability/python/experimental/distribute/joint_distribution_test.py +++ b/tensorflow_probability/python/experimental/distribute/joint_distribution_test.py @@ -272,7 +272,7 @@ def model(): self.strategy_run( run, (self.key,), in_axes=None)) for i in range(test_lib.NUM_DEVICES): - self.assertAllClose(sharded_log_prob[i], true_log_prob, atol=2e-2) + self.assertAllClose(sharded_log_prob[i], true_log_prob, atol=0.025) self.assertAllClose(sharded_log_prob_grad[i], true_log_prob_grad, atol=2e-2) diff --git a/tensorflow_probability/python/experimental/distribute/sharded.py b/tensorflow_probability/python/experimental/distribute/sharded.py index 5b2fd8ca9a..d58e47a1a7 100644 --- a/tensorflow_probability/python/experimental/distribute/sharded.py +++ b/tensorflow_probability/python/experimental/distribute/sharded.py @@ -21,6 +21,7 @@ from tensorflow_probability.python.distributions import distribution as distribution_lib from tensorflow_probability.python.distributions import log_prob_ratio from tensorflow_probability.python.experimental.bijectors import sharded as sharded_bij +from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import distribute_lib from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import samplers @@ -76,7 +77,7 @@ def __init__(self, distribution, shard_axis_name=None, validate_args=False, """ parameters = dict(locals()) - if not isinstance(distribution, tf.__internal__.CompositeTensor): + if not auto_composite_tensor.is_composite_tensor(distribution): raise ValueError('`distribution` must be a `CompositeTensor`.') if shard_axis_name is None: diff --git a/tensorflow_probability/python/experimental/distributions/BUILD b/tensorflow_probability/python/experimental/distributions/BUILD index 6019005130..c6d3b45c87 100644 --- a/tensorflow_probability/python/experimental/distributions/BUILD +++ b/tensorflow_probability/python/experimental/distributions/BUILD @@ -58,6 +58,7 @@ multi_substrate_py_library( "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:prefer_static", "//tensorflow_probability/python/internal:samplers", + "//tensorflow_probability/python/internal:tf_keras", 
"//tensorflow_probability/python/mcmc/internal:util", ], ) @@ -120,6 +121,7 @@ multi_substrate_py_library( deps = [ # numpy dep, # tensorflow dep, + "//tensorflow_probability/python/internal:tf_keras", ], ) diff --git a/tensorflow_probability/python/experimental/distributions/importance_resample.py b/tensorflow_probability/python/experimental/distributions/importance_resample.py index 93a46634c0..5b4ce87917 100644 --- a/tensorflow_probability/python/experimental/distributions/importance_resample.py +++ b/tensorflow_probability/python/experimental/distributions/importance_resample.py @@ -142,7 +142,7 @@ def target_log_prob_fn(x): importance_weighted_losses = tfp.vi.fit_surrogate_posterior( target_log_prob_fn, surrogate_posterior=proposal_distribution, - optimizer=tf.optimizers.Adam(0.1), + optimizer=tf_keras.optimizers.Adam(0.1), num_steps=200, importance_sample_size=importance_sample_size) approximate_posterior = tfed.ImportanceResample( @@ -167,7 +167,7 @@ def target_log_prob_fn(x): proposal_distribution=proposal_distribution, target_log_prob_fn=target_log_prob_fn, importance_sample_size=importance_sample_size), - optimizer=tf.optimizers.Adam(0.1), + optimizer=tf_keras.optimizers.Adam(0.1), num_steps=200) ``` diff --git a/tensorflow_probability/python/experimental/distributions/joint_distribution_pinned.py b/tensorflow_probability/python/experimental/distributions/joint_distribution_pinned.py index 2cb2c68731..e376b8462a 100644 --- a/tensorflow_probability/python/experimental/distributions/joint_distribution_pinned.py +++ b/tensorflow_probability/python/experimental/distributions/joint_distribution_pinned.py @@ -246,7 +246,7 @@ def target_log_prob_fn(loc, scale): pulled_back_shape) vars = tf.nest.map_structure(tf.Variable, uniform_init) - opt = tf.optimizers.Adam(.01) + opt = tf_keras.optimizers.Adam(.01) @tf.function(autograph=False) def one_step(): diff --git a/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process.py b/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process.py index 98376caf02..bc03dff245 100644 --- a/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process.py +++ b/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process.py @@ -246,19 +246,32 @@ def __init__(self, """ parameters = dict(locals()) with tf.name_scope(name) as name: - if tf.nest.is_nested(kernel.feature_ndims): - input_dtype = dtype_util.common_dtype( - [kernel, index_points], - dtype_hint=nest_util.broadcast_structure( - kernel.feature_ndims, tf.float32)) + input_dtype = dtype_util.common_dtype( + dict( + kernel=kernel, + index_points=index_points), + dtype_hint=nest_util.broadcast_structure( + kernel.feature_ndims, tf.float32)) + + # If the input dtype is non-nested float, we infer a single dtype for the + # input and the float parameters, which is also the dtype of the MTGP's + # samples, log_prob, etc. If the input dtype is nested (or not float), we + # do not use it to infer the MTGP's float dtype. + if (not tf.nest.is_nested(input_dtype) and + dtype_util.is_floating(input_dtype)): dtype = dtype_util.common_dtype( - [observation_noise_variance], tf.float32) + dict( + kernel=kernel, + index_points=index_points, + observation_noise_variance=observation_noise_variance, + ), + dtype_hint=tf.float32, + ) + input_dtype = dtype else: - # If the index points are not nested, we assume they are of the same - # float dtype as the kernel. 
dtype = dtype_util.common_dtype( - [kernel, index_points, observation_noise_variance], tf.float32) - input_dtype = dtype + dict(observation_noise_variance=observation_noise_variance), + dtype_hint=tf.float32) if index_points is not None: index_points = nest_util.convert_to_nested_tensor( diff --git a/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py b/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py index 5078e0b3ab..fc7ec6a5de 100644 --- a/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py +++ b/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py @@ -292,22 +292,40 @@ def __init__(self, if not isinstance(kernel, multitask_kernel.MultiTaskKernel): raise ValueError('`kernel` must be a `MultiTaskKernel`.') - if tf.nest.is_nested(kernel.feature_ndims): - input_dtype = dtype_util.common_dtype( - [kernel, index_points, observation_index_points], - dtype_hint=nest_util.broadcast_structure( - kernel.feature_ndims, tf.float32)) + input_dtype = dtype_util.common_dtype( + dict( + kernel=kernel, + index_points=index_points, + observation_index_points=observation_index_points, + ), + dtype_hint=nest_util.broadcast_structure( + kernel.feature_ndims, tf.float32)) + + # If the input dtype is non-nested float, we infer a single dtype for the + # input and the float parameters, which is also the dtype of the MTGP's + # samples, log_prob, etc. If the input dtype is nested (or not float), we + # do not use it to infer the MTGP's float dtype. + if (not tf.nest.is_nested(input_dtype) and + dtype_util.is_floating(input_dtype)): dtype = dtype_util.common_dtype( - [observations, observation_noise_variance, - predictive_noise_variance], tf.float32) - else: - # If the index points are not nested, we assume they are of the same - # dtype as the kernel. - dtype = dtype_util.common_dtype([ - kernel, index_points, observation_index_points, observations, - observation_noise_variance, predictive_noise_variance - ], tf.float32) + dict( + kernel=kernel, + index_points=index_points, + observations=observations, + observation_index_points=observation_index_points, + observation_noise_variance=observation_noise_variance, + predictive_noise_variance=predictive_noise_variance, + ), + dtype_hint=tf.float32, + ) input_dtype = dtype + else: + dtype = dtype_util.common_dtype( + dict( + observations=observations, + observation_noise_variance=observation_noise_variance, + predictive_noise_variance=predictive_noise_variance, + ), dtype_hint=tf.float32) if index_points is not None: index_points = nest_util.convert_to_nested_tensor( @@ -595,48 +613,46 @@ def precompute_regression_model( if _precomputed_divisor_matrix_cholesky is not None: observation_scale = _scale_from_precomputed( _precomputed_divisor_matrix_cholesky, kernel) - elif observations_is_missing is not None: - # If observations are missing, there's nothing we can do to preserve the - # operator structure, so densify. 
- - observation_covariance = kernel.matrix_over_all_tasks( - observation_index_points, observation_index_points).to_dense() - - if observation_noise_variance is not None: - broadcast_shape = distribution_util.get_broadcast_shape( - observation_covariance, observation_noise_variance[ - ..., tf.newaxis, tf.newaxis]) - observation_covariance = tf.broadcast_to(observation_covariance, - broadcast_shape) - observation_covariance = _add_diagonal_shift( - observation_covariance, observation_noise_variance) - vec_observations_is_missing = _vec(observations_is_missing) - observation_covariance = tf.linalg.LinearOperatorFullMatrix( - psd_kernels_util.mask_matrix( - observation_covariance, - is_missing=vec_observations_is_missing), - is_non_singular=True, - is_positive_definite=True) - observation_scale = cholesky_util.cholesky_from_fn( - observation_covariance, cholesky_fn) + solve_on_observations = _precomputed_solve_on_observation else: - observation_scale = mtgp._compute_flattened_scale( # pylint:disable=protected-access - kernel=kernel, - index_points=observation_index_points, - cholesky_fn=cholesky_fn, - observation_noise_variance=observation_noise_variance) - - # Note that the conditional mean is - # k(x, o) @ (k(o, o) + sigma**2)^-1 obs. We can precompute the latter - # term since it won't change per iteration. - vec_diff = _vec(observations - mean_fn(observation_index_points)) - - if observations_is_missing is not None: - vec_diff = tf.where(vec_observations_is_missing, - tf.zeros([], dtype=vec_diff.dtype), - vec_diff) - solve_on_observations = _precomputed_solve_on_observation - if solve_on_observations is None: + # Note that the conditional mean is + # k(x, o) @ (k(o, o) + sigma**2)^-1 obs. We can precompute the latter + # term since it won't change per iteration. + vec_diff = _vec(observations - mean_fn(observation_index_points)) + + if observations_is_missing is not None: + # If observations are missing, there's nothing we can do to preserve + # the operator structure, so densify. 
+ vec_observations_is_missing = _vec(observations_is_missing) + vec_diff = tf.where(vec_observations_is_missing, + tf.zeros([], dtype=vec_diff.dtype), + vec_diff) + + observation_covariance = kernel.matrix_over_all_tasks( + observation_index_points, observation_index_points).to_dense() + + if observation_noise_variance is not None: + broadcast_shape = distribution_util.get_broadcast_shape( + observation_covariance, observation_noise_variance[ + ..., tf.newaxis, tf.newaxis]) + observation_covariance = tf.broadcast_to(observation_covariance, + broadcast_shape) + observation_covariance = _add_diagonal_shift( + observation_covariance, observation_noise_variance) + observation_covariance = tf.linalg.LinearOperatorFullMatrix( + psd_kernels_util.mask_matrix( + observation_covariance, + is_missing=vec_observations_is_missing), + is_non_singular=True, + is_positive_definite=True) + observation_scale = cholesky_util.cholesky_from_fn( + observation_covariance, cholesky_fn) + else: + observation_scale = mtgp._compute_flattened_scale( # pylint:disable=protected-access + kernel=kernel, + index_points=observation_index_points, + cholesky_fn=cholesky_fn, + observation_noise_variance=observation_noise_variance) solve_on_observations = observation_scale.solvevec( observation_scale.solvevec(vec_diff), adjoint=True) @@ -660,6 +676,7 @@ def flattened_conditional_mean_fn(x): observation_noise_variance=observation_noise_variance, predictive_noise_variance=predictive_noise_variance, cholesky_fn=cholesky_fn, + observations_is_missing=observations_is_missing, _flattened_conditional_mean_fn=flattened_conditional_mean_fn, _observation_scale=observation_scale, validate_args=validate_args, diff --git a/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model_test.py b/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model_test.py index 66258acc99..2680bf6038 100644 --- a/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model_test.py +++ b/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model_test.py @@ -474,16 +474,26 @@ def testMeanVarianceJit(self): tf.function(jit_compile=True)(mtgprm.mean)() tf.function(jit_compile=True)(mtgprm.variance)() - def testMeanVarianceAndCovariancePrecomputed(self): + @parameterized.parameters(True, False) + def testMeanVarianceAndCovariancePrecomputed(self, has_missing_observations): num_tasks = 3 + num_obs = 7 amplitude = np.array([1., 2.], np.float64).reshape([2, 1]) length_scale = np.array([.1, .2, .3], np.float64).reshape([1, 3]) observation_noise_variance = np.array([1e-9], np.float64) observation_index_points = ( - np.random.uniform(-1., 1., (1, 1, 7, 2)).astype(np.float64)) + np.random.uniform(-1., 1., (1, 1, num_obs, 2)).astype(np.float64)) observations = np.linspace( - -20., 20., 7 * num_tasks).reshape(7, num_tasks).astype(np.float64) + -20., 20., num_obs * num_tasks).reshape( + num_obs, num_tasks).astype(np.float64) + + if has_missing_observations: + observations_is_missing = np.stack( + [np.random.randint(2, size=(num_obs,))] * num_tasks, axis=-1 + ).astype(np.bool_) + else: + observations_is_missing = None index_points = np.random.uniform(-1., 1., (6, 2)).astype(np.float64) @@ -497,6 +507,7 @@ def testMeanVarianceAndCovariancePrecomputed(self): observation_index_points=observation_index_points, observations=observations, observation_noise_variance=observation_noise_variance, + 
observations_is_missing=observations_is_missing, validate_args=True) precomputed_mtgprm = mtgprm_lib.MultiTaskGaussianProcessRegressionModel.precompute_regression_model( @@ -505,6 +516,7 @@ def testMeanVarianceAndCovariancePrecomputed(self): observation_index_points=observation_index_points, observations=observations, observation_noise_variance=observation_noise_variance, + observations_is_missing=observations_is_missing, validate_args=True) mock_cholesky_fn = mock.Mock(return_value=None) @@ -514,6 +526,7 @@ def testMeanVarianceAndCovariancePrecomputed(self): observation_index_points=observation_index_points, observations=observations, observation_noise_variance=observation_noise_variance, + observations_is_missing=observations_is_missing, _precomputed_divisor_matrix_cholesky=precomputed_mtgprm._precomputed_divisor_matrix_cholesky, _precomputed_solve_on_observation=precomputed_mtgprm._precomputed_solve_on_observation, cholesky_fn=mock_cholesky_fn, diff --git a/tensorflow_probability/python/experimental/linalg/linear_operator_psd_kernel_test.py b/tensorflow_probability/python/experimental/linalg/linear_operator_psd_kernel_test.py index c3e45be183..9a4aefcf13 100644 --- a/tensorflow_probability/python/experimental/linalg/linear_operator_psd_kernel_test.py +++ b/tensorflow_probability/python/experimental/linalg/linear_operator_psd_kernel_test.py @@ -271,15 +271,17 @@ def test_matmul_grad_xla_kernelparams(self): feature_dim = 3 def kernel_fn(eq_params, poly_params): - return (exponentiated_quadratic.ExponentiatedQuadratic(**eq_params) * - polynomial.Polynomial(**poly_params)) + return (exponentiated_quadratic.ExponentiatedQuadratic(*eq_params) * + polynomial.Polynomial(bias_amplitude=poly_params[0], + shift=poly_params[1])) + # TODO(b/284106340): Return this to a dictionary. kernel_args = ( - dict(length_scale=tf.random.uniform([], .5, 1.5, dtype=tf.float64), - amplitude=tf.random.uniform([], 1.5, 2.5, dtype=tf.float64)), - dict(bias_amplitude=tf.random.uniform([feature_dim], .5, 1.5, - dtype=tf.float64), - shift=tf.random.normal([feature_dim], dtype=tf.float64))) + (tf.random.uniform([], 1.5, 2.5, dtype=tf.float64), # amplitude + tf.random.uniform([], .5, 1.5, dtype=tf.float64)), # length_scale + (tf.random.uniform([feature_dim], .5, 1.5, # bias_amplitude + dtype=tf.float64), + tf.random.normal([feature_dim], dtype=tf.float64))) # shift x1 = tf.random.normal([5, feature_dim], dtype=tf.float64) x2 = tf.random.normal([7, feature_dim], dtype=tf.float64) diff --git a/tensorflow_probability/python/experimental/linalg/no_pivot_ldl_test.py b/tensorflow_probability/python/experimental/linalg/no_pivot_ldl_test.py index 88e6b62e0f..63c74db978 100644 --- a/tensorflow_probability/python/experimental/linalg/no_pivot_ldl_test.py +++ b/tensorflow_probability/python/experimental/linalg/no_pivot_ldl_test.py @@ -86,7 +86,11 @@ def testXlaCompileBug(self): self.assertAllClose(self.evaluate(alt_chol(inp)), answer) self.assertAllClose(self.evaluate(alt_chol_nojit(inp)), answer) - self.assertAllClose(self.evaluate(alt_chol_jit(inp)), answer) + # TODO(phandu): Enable the test again when the bug is resolved. 
+ # Bug in tensorflow since 2.15.0-dev20230812, + # see details at https://github.com/tensorflow/tensorflow/issues/61674 + # self.assertAllClose(self.evaluate(alt_chol_jit(inp)), answer) + del alt_chol_jit with tf.GradientTape(): chol_with_grad = alt_chol(inp) @@ -102,7 +106,11 @@ def jit_with_grad(mat): with tf.GradientTape(): return alt_chol_jit(mat) - self.assertAllClose(self.evaluate(jit_with_grad(inp)), answer) + # TODO(phandu): Enable the test again when the bug is resolved. + # Bug in tensorflow since 2.15.0-dev20230812, + # see details at https://github.com/tensorflow/tensorflow/issues/61674 + # self.assertAllClose(self.evaluate(jit_with_grad(inp)), answer) + del jit_with_grad if __name__ == '__main__': diff --git a/tensorflow_probability/python/experimental/marginalize/BUILD b/tensorflow_probability/python/experimental/marginalize/BUILD index e92a16f96c..d2d9c1865e 100644 --- a/tensorflow_probability/python/experimental/marginalize/BUILD +++ b/tensorflow_probability/python/experimental/marginalize/BUILD @@ -15,8 +15,11 @@ # Description: # Automatic marginalization of latent variables. -# Placeholder: py_library -# Placeholder: py_test +load( + "//tensorflow_probability/python:build_defs.bzl", + "multi_substrate_py_library", + "multi_substrate_py_test", +) package( # default_applicable_licenses @@ -27,17 +30,18 @@ package( licenses(["notice"]) -py_library( +multi_substrate_py_library( name = "logeinsumexp", srcs = ["logeinsumexp.py"], deps = [ # numpy dep, # opt_einsum dep, # tensorflow dep, + "//tensorflow_probability/python/internal:prefer_static", ], ) -py_test( +multi_substrate_py_test( name = "logeinsumexp_test", size = "medium", srcs = [ @@ -53,7 +57,7 @@ py_test( ], ) -py_library( +multi_substrate_py_library( name = "marginalize", srcs = ["__init__.py"], deps = [ @@ -62,7 +66,7 @@ py_library( ], ) -py_library( +multi_substrate_py_library( name = "marginalizable", srcs = ["marginalizable.py"], deps = [ @@ -72,13 +76,17 @@ py_library( "//tensorflow_probability/python/distributions:categorical", "//tensorflow_probability/python/distributions:joint_distribution_coroutine", "//tensorflow_probability/python/distributions:sample", + "//tensorflow_probability/python/internal:prefer_static", + "//tensorflow_probability/python/internal:samplers", ], ) -py_test( +multi_substrate_py_test( name = "marginalizable_test", size = "medium", srcs = ["marginalizable_test.py"], + jax_tags = ["notap"], + numpy_tags = ["notap"], deps = [ ":marginalize", # absl/testing:parameterized dep, @@ -92,6 +100,7 @@ py_test( "//tensorflow_probability/python/distributions:normal", "//tensorflow_probability/python/distributions:poisson", "//tensorflow_probability/python/distributions:sample", + "//tensorflow_probability/python/internal:prefer_static", "//tensorflow_probability/python/internal:test_util", ], ) diff --git a/tensorflow_probability/python/experimental/marginalize/logeinsumexp.py b/tensorflow_probability/python/experimental/marginalize/logeinsumexp.py index 7d8794c0b5..1923f79a27 100644 --- a/tensorflow_probability/python/experimental/marginalize/logeinsumexp.py +++ b/tensorflow_probability/python/experimental/marginalize/logeinsumexp.py @@ -15,7 +15,8 @@ """Compute einsums in log space.""" import opt_einsum as oe -import tensorflow.compat.v1 as tf +import tensorflow.compat.v2 as tf +from tensorflow_probability.python.internal import prefer_static as ps # pylint: disable=no-member @@ -72,8 +73,8 @@ def rearrange(src, dst, t): if i not in src: new_indices += i new_src = src + new_indices - new_t = 
tf.reshape(t, tf.concat( - [tf.shape(t), tf.ones(len(new_indices), dtype=tf.int32)], axis=0)) + new_t = tf.reshape(t, ps.concat( + [ps.shape(t), ps.ones(len(new_indices), dtype=tf.int32)], axis=0)) formula = '{}->{}'.format(new_src, dst) # It is safe to use ordinary `einsum` here as no summations # are performed. diff --git a/tensorflow_probability/python/experimental/marginalize/logeinsumexp_test.py b/tensorflow_probability/python/experimental/marginalize/logeinsumexp_test.py index 182a6d42dd..016284f24f 100644 --- a/tensorflow_probability/python/experimental/marginalize/logeinsumexp_test.py +++ b/tensorflow_probability/python/experimental/marginalize/logeinsumexp_test.py @@ -18,7 +18,7 @@ from hypothesis.extra import numpy as hpnp import hypothesis.strategies as hps import numpy as np -import tensorflow.compat.v1 as tf +import tensorflow.compat.v2 as tf from tensorflow_probability.python.experimental.marginalize.logeinsumexp import _binary_einslogsumexp from tensorflow_probability.python.experimental.marginalize.logeinsumexp import logeinsumexp from tensorflow_probability.python.internal import test_util @@ -179,7 +179,6 @@ def test_compare_einsum(self): formula = 'abcdcfg,edfcbaa->bd' u = tf.math.log(tf.einsum(formula, a, b)) v = logeinsumexp(formula, tf.math.log(a), tf.math.log(b)) - self.assertAllClose(u, v) def test_zero_zero_multiplication(self): diff --git a/tensorflow_probability/python/experimental/marginalize/marginalizable.py b/tensorflow_probability/python/experimental/marginalize/marginalizable.py index d9f327f720..e3ae6fb97c 100644 --- a/tensorflow_probability/python/experimental/marginalize/marginalizable.py +++ b/tensorflow_probability/python/experimental/marginalize/marginalizable.py @@ -24,6 +24,8 @@ from tensorflow_probability.python.distributions import joint_distribution_coroutine as jdc_lib from tensorflow_probability.python.distributions import sample as sample_lib from tensorflow_probability.python.experimental.marginalize.logeinsumexp import logeinsumexp +from tensorflow_probability.python.internal import prefer_static as ps +from tensorflow_probability.python.internal import samplers __all__ = [ @@ -117,10 +119,9 @@ def _support(dist): dist.sample_shape, 'expand_sample_shape') p, rank = _support(dist.distribution) product = _power(p, n) - new_shape = tf.concat([tf.shape(product)[:-1], sample_shape], axis=-1) + new_shape = ps.concat([ps.shape(product)[:-1], sample_shape], axis=-1) - new_rank = rank + tf.compat.v2.compat.dimension_value( - sample_shape.shape[0]) + new_rank = rank + tf.compat.dimension_value(sample_shape.shape[0]) return tf.reshape(product, new_shape), new_rank else: raise ValueError('Unable to find support for distribution ' + @@ -141,11 +142,11 @@ def _expand_right(a, n, pos): Tensor with inserted dimensions. """ - axis = tf.rank(a) + pos + 1 - return tf.reshape(a, tf.concat([ - tf.shape(a)[:axis], - tf.ones([n], dtype=tf.int32), - tf.shape(a)[axis:]], axis=0)) + axis = ps.rank(a) + pos + 1 + return tf.reshape(a, ps.concat([ + ps.shape(a)[:axis], + ps.ones([n], dtype=tf.int32), + ps.shape(a)[axis:]], axis=0)) def _letter(i): @@ -216,7 +217,9 @@ def marginalized_log_prob(self, values, name='marginalized_log_prob', with tf.name_scope(name): ds = self._call_execute_model( - sample_and_trace_fn=jd_lib.trace_distributions_only) + sample_and_trace_fn=jd_lib.trace_distributions_only, + # Only used for tracing so can be fixed. 
+ seed=samplers.zeros_seed()) # Both 'marginalize' and 'tabulate' indicate that # instead of using samples provided by the user, this method @@ -229,7 +232,7 @@ def marginalized_log_prob(self, values, name='marginalized_log_prob', for value, dist in zip(values, ds): if value == 'marginalize': supp, rank = _support(dist) - r = supp.shape.rank + r = ps.rank(supp) num_new_variables = r - rank # We can think of supp as being a tensor containing tensors, # each of which is a draw from the distribution. @@ -251,7 +254,7 @@ def marginalized_log_prob(self, values, name='marginalized_log_prob', formula.append(indices) elif value == 'tabulate': supp, rank = _support(dist) - r = supp.shape.rank + r = ps.rank(supp) if r is None: raise ValueError('Need to be able to statically find rank of' 'support of random variable: {}'.format(str(dist))) diff --git a/tensorflow_probability/python/experimental/marginalize/marginalizable_test.py b/tensorflow_probability/python/experimental/marginalize/marginalizable_test.py index 1211246c5e..b0da46d476 100644 --- a/tensorflow_probability/python/experimental/marginalize/marginalizable_test.py +++ b/tensorflow_probability/python/experimental/marginalize/marginalizable_test.py @@ -34,6 +34,7 @@ from tensorflow_probability.python.distributions import poisson from tensorflow_probability.python.distributions import sample as sample_dist_lib import tensorflow_probability.python.experimental.marginalize as marginalize +from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import test_util @@ -48,10 +49,6 @@ def _conform(ts): return [tf.broadcast_to(a, shape) for a in ts] -def _cat(*ts): - return tf.concat(ts, axis=0) - - def _stack(*ts): return tf.stack(_conform(ts), axis=-1) @@ -209,7 +206,7 @@ def test_hmm(self): n_steps = 4 infer_step = 2 - observations = [-1.0, 0.0, 1.0, 2.0] + observations = np.array([-1.0, 0.0, 1.0, 2.0], np.float32) initial_prob = tf.constant([0.6, 0.4], dtype=tf.float32) transition_matrix = tf.constant([[0.6, 0.4], @@ -309,7 +306,7 @@ def model(): 0.4 * tf.roll(o, shift=[1, 0], axis=[-2, -1])) # Reshape just last two dimensions. - p = tf.reshape(p, _cat(p.shape[:-2], [-1])) + p = tf.reshape(p, ps.concat([ps.shape(p)[:-2], [-1]], axis=0)) xy = yield categorical.Categorical(probs=p, dtype=tf.int32) x[i] = xy // n y[i] = xy % n @@ -342,6 +339,7 @@ def model(): # order chosen by `tf.einsum` closer matches `_tree_example` above. 
self.assertAllClose(p, q) + @test_util.numpy_disable_gradient_test def test_marginalized_gradient(self): n = 10 diff --git a/tensorflow_probability/python/experimental/mcmc/BUILD b/tensorflow_probability/python/experimental/mcmc/BUILD index 82ed63c698..7fb140497b 100644 --- a/tensorflow_probability/python/experimental/mcmc/BUILD +++ b/tensorflow_probability/python/experimental/mcmc/BUILD @@ -548,7 +548,6 @@ multi_substrate_py_library( "//tensorflow_probability/python/internal:prefer_static", "//tensorflow_probability/python/internal:tensor_util", "//tensorflow_probability/python/internal:tensorshape_util", - "//tensorflow_probability/python/distributions:batch_reshape", ], ) @@ -575,8 +574,6 @@ multi_substrate_py_test( "//tensorflow_probability/python/distributions:sample", "//tensorflow_probability/python/distributions:transformed_distribution", "//tensorflow_probability/python/distributions:uniform", - "//tensorflow_probability/python/distributions:categorical", - "//tensorflow_probability/python/distributions:hidden_markov_model", "//tensorflow_probability/python/internal:test_util", "//tensorflow_probability/python/math:gradient", # "//third_party/tensorflow/compiler/jit:xla_cpu_jit", # DisableOnExport diff --git a/tensorflow_probability/python/experimental/mcmc/diagonal_mass_matrix_adaptation_test.py b/tensorflow_probability/python/experimental/mcmc/diagonal_mass_matrix_adaptation_test.py index d157139aef..8e270cfdd9 100644 --- a/tensorflow_probability/python/experimental/mcmc/diagonal_mass_matrix_adaptation_test.py +++ b/tensorflow_probability/python/experimental/mcmc/diagonal_mass_matrix_adaptation_test.py @@ -317,11 +317,11 @@ def testMeanGoesInRightDirection(self): initial_running_variance=initial_running_variance) # This number started at `error_factor`. Make sure the mean is now at least - # 75% closer. + # 50% closer. final_mean_diff = tf.abs(results.final_mean - results.true_mean) np.testing.assert_array_less( self.evaluate(final_mean_diff), - self.evaluate(0.25 * error_factor)) + self.evaluate(0.5 * error_factor)) def testDoesNotGoesInWrongDirection(self): # As above, we test a weaker property, which is that the variance and diff --git a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py index 098769f36a..b586e44b86 100644 --- a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py +++ b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py @@ -42,8 +42,7 @@ def propose_and_update_log_weights_fn(_, weighted_particles, seed=None): return WeightedParticles( particles=proposed_particles, log_weights=weighted_particles.log_weights + - normal.Normal(loc=-2.6, scale=0.1).log_prob(proposed_particles) - ) + normal.Normal(loc=-2.6, scale=0.1).log_prob(proposed_particles)) num_particles = 16 initial_state = self.evaluate( @@ -51,8 +50,7 @@ def propose_and_update_log_weights_fn(_, weighted_particles, seed=None): particles=tf.random.normal([num_particles], seed=test_util.test_seed()), log_weights=tf.fill([num_particles], - -tf.math.log(float(num_particles))) - )) + -tf.math.log(float(num_particles))))) # Run a couple of steps. 
seeds = samplers.split_seed( diff --git a/tensorflow_probability/python/experimental/nn/BUILD b/tensorflow_probability/python/experimental/nn/BUILD index 35de967cb7..8b3a035d06 100644 --- a/tensorflow_probability/python/experimental/nn/BUILD +++ b/tensorflow_probability/python/experimental/nn/BUILD @@ -55,6 +55,7 @@ py_library( "//tensorflow_probability/python/distributions:normal", "//tensorflow_probability/python/experimental/nn/util:kernel_bias", "//tensorflow_probability/python/internal:prefer_static", + "//tensorflow_probability/python/internal:tf_keras", ], ) @@ -73,6 +74,7 @@ py_test( "//tensorflow_probability/python/distributions:normal", "//tensorflow_probability/python/internal:prefer_static", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/util:deferred_tensor", ], ) @@ -88,6 +90,7 @@ py_library( "//tensorflow_probability/python/experimental/nn/util", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:prefer_static", + "//tensorflow_probability/python/internal:tf_keras", ], ) @@ -104,6 +107,7 @@ py_test( "//tensorflow_probability/python/distributions:independent", "//tensorflow_probability/python/distributions:normal", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/util:deferred_tensor", ], ) @@ -137,6 +141,7 @@ py_test( "//tensorflow_probability/python/distributions:independent", "//tensorflow_probability/python/distributions:normal", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/util:deferred_tensor", ], ) @@ -151,6 +156,7 @@ py_library( "//tensorflow_probability/python/distributions:distribution", "//tensorflow_probability/python/experimental/nn/util:utils", "//tensorflow_probability/python/internal:prefer_static", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/util:deferred_tensor", ], ) @@ -167,6 +173,7 @@ py_test( "//tensorflow_probability/python/distributions:independent", "//tensorflow_probability/python/distributions:normal", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", ], ) diff --git a/tensorflow_probability/python/experimental/nn/README.md b/tensorflow_probability/python/experimental/nn/README.md index 4384fe1b9b..6a95f75bfb 100644 --- a/tensorflow_probability/python/experimental/nn/README.md +++ b/tensorflow_probability/python/experimental/nn/README.md @@ -11,7 +11,7 @@ Design goals include but are not limited to: - extensibility - simple implementations. -The primary differences from `tf.keras` are: +The primary differences from `tf_keras` are: 1. The TFP NN toolbox use `tf.Module` for `tf.Variable` tracking. 2. Users are expected to implement their own train loops. 
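The two README points above are easiest to see in a short sketch: layers are `tf.Module`s, so `trainable_variables` is tracked automatically, and the training loop is written by hand. The layer sizes, toy data, and step count below are illustrative assumptions, not part of the patch.

```python
# Minimal sketch of a user-written train loop over a tfp.experimental.nn layer.
# Variables are tracked via tf.Module, so no Keras Model/compile/fit is needed.
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import tf_keras

tfn = tfp.experimental.nn

model = tfn.Affine(4, 1)              # input_size=4, output_size=1 (toy sizes)
opt = tf_keras.optimizers.Adam(0.01)  # any optimizer exposing apply_gradients

x = tf.random.normal([32, 4])         # toy data, purely for illustration
y = tf.random.normal([32, 1])

for _ in range(10):
  with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x) - y))
  grads = tape.gradient(loss, model.trainable_variables)
  opt.apply_gradients(zip(grads, model.trainable_variables))
```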
diff --git a/tensorflow_probability/python/experimental/nn/affine_layers.py b/tensorflow_probability/python/experimental/nn/affine_layers.py index 578de4ec49..5181fd7a39 100644 --- a/tensorflow_probability/python/experimental/nn/affine_layers.py +++ b/tensorflow_probability/python/experimental/nn/affine_layers.py @@ -45,7 +45,7 @@ def __init__( output_size, # Weights kernel_initializer=None, # tfp.nn.initializers.glorot_uniform() - bias_initializer=None, # tf.initializers.zeros() + bias_initializer=None, # tf_keras.initializers.zeros() make_kernel_bias_fn=kernel_bias_lib.make_kernel_bias, dtype=tf.float32, batch_shape=(), @@ -61,7 +61,7 @@ def __init__( Default value: `None` (i.e., `tfp.experimental.nn.initializers.glorot_uniform()`). bias_initializer: ... - Default value: `None` (i.e., `tf.initializers.zeros()`). + Default value: `None` (i.e., `tf_keras.initializers.zeros()`). make_kernel_bias_fn: ... Default value: `tfp.experimental.nn.util.make_kernel_bias`. dtype: ... @@ -179,11 +179,11 @@ def _preprocess(image, label): padding='same', filter_shape=5, # Use `he_uniform` because we'll use the `relu` family. - kernel_initializer=tf.initializers.he_uniform()) + kernel_initializer=tf_keras.initializers.he_uniform()) BayesAffine = functools.partial( tfn.AffineVariationalReparameterization, - kernel_initializer=tf.initializers.he_normal()) + kernel_initializer=tf_keras.initializers.he_normal()) scale = tfp.util.TransformedVariable(1., tfb.Softplus()) bnn = tfn.Sequential([ @@ -206,7 +206,7 @@ def loss_fn(): kl = bnn.extra_loss / tf.cast(train_size, tf.float32) loss = nll + kl return loss, (nll, kl) - opt = tf.optimizers.Adam() + opt = tf_keras.optimizers.Adam() fit_op = tfn.util.make_fit_op(loss_fn, opt, bnn.trainable_variables) for _ in range(200): loss, (nll, kl), g = fit_op() @@ -232,7 +232,7 @@ def __init__( output_size, # Weights kernel_initializer=None, # tfp.nn.initializers.glorot_uniform() - bias_initializer=None, # tf.initializers.zeros() + bias_initializer=None, # tf_keras.initializers.zeros() make_posterior_fn=kernel_bias_lib.make_kernel_bias_posterior_mvn_diag, make_prior_fn=kernel_bias_lib.make_kernel_bias_prior_spike_and_slab, posterior_value_fn=tfd.Distribution.sample, @@ -252,7 +252,7 @@ def __init__( Default value: `None` (i.e., `tfp.experimental.nn.initializers.glorot_uniform()`). bias_initializer: ... - Default value: `None` (i.e., `tf.initializers.zeros()`). + Default value: `None` (i.e., `tf_keras.initializers.zeros()`). make_posterior_fn: ... Default value: `tfp.experimental.nn.util.make_kernel_bias_posterior_mvn_diag`. @@ -363,7 +363,7 @@ def __init__( output_size, # Weights kernel_initializer=None, # tfp.nn.initializers.glorot_uniform() - bias_initializer=None, # tf.initializers.zeros() + bias_initializer=None, # tf_keras.initializers.zeros() make_posterior_fn=kernel_bias_lib.make_kernel_bias_posterior_mvn_diag, make_prior_fn=kernel_bias_lib.make_kernel_bias_prior_spike_and_slab, posterior_value_fn=tfd.Distribution.sample, @@ -383,7 +383,7 @@ def __init__( Default value: `None` (i.e., `tfp.experimental.nn.initializers.glorot_uniform()`). bias_initializer: ... - Default value: `None` (i.e., `tf.initializers.zeros()`). + Default value: `None` (i.e., `tf_keras.initializers.zeros()`). make_posterior_fn: ... Default value: `tfp.experimental.nn.util.make_kernel_bias_posterior_mvn_diag`. 
@@ -502,7 +502,7 @@ def __init__( output_size, # Weights kernel_initializer=None, # tfp.nn.initializers.glorot_uniform() - bias_initializer=None, # tf.initializers.zeros() + bias_initializer=None, # tf_keras.initializers.zeros() make_posterior_fn=kernel_bias_lib.make_kernel_bias_posterior_mvn_diag, make_prior_fn=kernel_bias_lib.make_kernel_bias_prior_spike_and_slab, posterior_value_fn=tfd.Distribution.sample, @@ -522,7 +522,7 @@ def __init__( Default value: `None` (i.e., `tfp.nn.initializers.glorot_uniform()`). bias_initializer: ... - Default value: `None` (i.e., `tf.initializers.zeros()`). + Default value: `None` (i.e., `tf_keras.initializers.zeros()`). make_posterior_fn: ... Default value: `tfp.experimental.nn.util.make_kernel_bias_posterior_mvn_diag`. diff --git a/tensorflow_probability/python/experimental/nn/affine_layers_test.py b/tensorflow_probability/python/experimental/nn/affine_layers_test.py index 91ab67de86..43433f4199 100644 --- a/tensorflow_probability/python/experimental/nn/affine_layers_test.py +++ b/tensorflow_probability/python/experimental/nn/affine_layers_test.py @@ -29,6 +29,7 @@ from tensorflow_probability.python.experimental import nn as tfn from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.util import deferred_tensor @@ -87,7 +88,7 @@ def loss_fn(): nll = -tf.reduce_mean(bnn(x).log_prob(y), axis=-1) kl = tfn.losses.compute_extra_loss(bnn) / n return nll + kl, (nll, kl) - opt = tf.optimizers.Adam() + opt = tf_keras.optimizers.Adam() fit_op = tfn.util.make_fit_op(loss_fn, opt, bnn.trainable_variables) for _ in range(2): loss, (nll, kl) = fit_op() # pylint: disable=unused-variable diff --git a/tensorflow_probability/python/experimental/nn/convolutional_layers.py b/tensorflow_probability/python/experimental/nn/convolutional_layers.py index 4a34059de1..38d5ab6550 100644 --- a/tensorflow_probability/python/experimental/nn/convolutional_layers.py +++ b/tensorflow_probability/python/experimental/nn/convolutional_layers.py @@ -91,7 +91,7 @@ def __init__( dilations=1, # keras::Conv::dilation_rate # Weights kernel_initializer=None, # tfp.nn.initializers.glorot_uniform() - bias_initializer=None, # tf.initializers.zeros() + bias_initializer=None, # tf_keras.initializers.zeros() make_kernel_bias_fn=kernel_bias_lib.make_kernel_bias, dtype=tf.float32, batch_shape=(), @@ -147,7 +147,7 @@ def __init__( Default value: `None` (i.e., `tfp.experimental.nn.initializers.glorot_uniform()`). bias_initializer: ... - Default value: `None` (i.e., `tf.initializers.zeros()`). + Default value: `None` (i.e., `tf_keras.initializers.zeros()`). make_kernel_bias_fn: ... Default value: `tfp.experimental.nn.util.make_kernel_bias`. dtype: ... @@ -288,7 +288,7 @@ def _preprocess(image, label): padding='same', filter_shape=5, # Use `he_uniform` because we'll use the `relu` family. - kernel_initializer=tf.initializers.he_uniform(), + kernel_initializer=tf_keras.initializers.he_uniform(), penalty_weight=1. / n) BayesAffine = functools.partial( @@ -316,7 +316,7 @@ def loss_fn(): kl = bnn.extra_loss # Already normalized via `penalty_weight` arg. 
loss = nll + kl return loss, (nll, kl) - opt = tf.optimizers.Adam() + opt = tf_keras.optimizers.Adam() fit_op = tfn.util.make_fit_op(loss_fn, opt, bnn.trainable_variables) for _ in range(200): loss, (nll, kl), g = fit_op() @@ -349,7 +349,7 @@ def __init__( dilations=1, # keras::Conv::dilation_rate # Weights kernel_initializer=None, # tfp.nn.initializers.glorot_uniform() - bias_initializer=None, # tf.initializers.zeros() + bias_initializer=None, # tf_keras.initializers.zeros() make_posterior_fn=kernel_bias_lib.make_kernel_bias_posterior_mvn_diag, make_prior_fn=kernel_bias_lib.make_kernel_bias_prior_spike_and_slab, posterior_value_fn=tfd.Distribution.sample, @@ -408,7 +408,7 @@ def __init__( Default value: `None` (i.e., `tfp.experimental.nn.initializers.glorot_uniform()`). bias_initializer: ... - Default value: `None` (i.e., `tf.initializers.zeros()`). + Default value: `None` (i.e., `tf_keras.initializers.zeros()`). make_posterior_fn: ... Default value: `tfp.experimental.nn.util.make_kernel_bias_posterior_mvn_diag`. @@ -538,7 +538,7 @@ def __init__( dilations=1, # keras::Conv::dilation_rate # Weights kernel_initializer=None, # tfp.nn.initializers.glorot_uniform() - bias_initializer=None, # tf.initializers.zeros() + bias_initializer=None, # tf_keras.initializers.zeros() make_posterior_fn=kernel_bias_lib.make_kernel_bias_posterior_mvn_diag, make_prior_fn=kernel_bias_lib.make_kernel_bias_prior_spike_and_slab, posterior_value_fn=tfd.Distribution.sample, @@ -597,7 +597,7 @@ def __init__( Default value: `None` (i.e., `tfp.experimental.nn.initializers.glorot_uniform()`). bias_initializer: ... - Default value: `None` (i.e., `tf.initializers.zeros()`). + Default value: `None` (i.e., `tf_keras.initializers.zeros()`). make_posterior_fn: ... Default value: `tfp.experimental.nn.util.make_kernel_bias_posterior_mvn_diag`. 
diff --git a/tensorflow_probability/python/experimental/nn/convolutional_layers_test.py b/tensorflow_probability/python/experimental/nn/convolutional_layers_test.py index 9fd5e2e962..a6525de128 100644 --- a/tensorflow_probability/python/experimental/nn/convolutional_layers_test.py +++ b/tensorflow_probability/python/experimental/nn/convolutional_layers_test.py @@ -25,6 +25,7 @@ from tensorflow_probability.python.distributions import normal from tensorflow_probability.python.experimental import nn as tfn from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.util import deferred_tensor @@ -79,7 +80,7 @@ def loss_fn(): nll = -tf.reduce_mean(bnn(x).log_prob(y), axis=-1) kl = tfn.losses.compute_extra_loss(bnn) / n return nll + kl, (nll, kl) - opt = tf.optimizers.Adam() + opt = tf_keras.optimizers.Adam() fit_op = tfn.util.make_fit_op(loss_fn, opt, bnn.trainable_variables) for _ in range(2): loss, (nll, kl) = fit_op() # pylint: disable=unused-variable diff --git a/tensorflow_probability/python/experimental/nn/convolutional_layers_v2.py b/tensorflow_probability/python/experimental/nn/convolutional_layers_v2.py index 5485833888..039755846d 100644 --- a/tensorflow_probability/python/experimental/nn/convolutional_layers_v2.py +++ b/tensorflow_probability/python/experimental/nn/convolutional_layers_v2.py @@ -94,7 +94,7 @@ def __init__( dilations=1, # keras::Conv::dilation_rate # Weights kernel_initializer=None, # tfp.nn.initializers.glorot_uniform() - bias_initializer=None, # tf.initializers.zeros() + bias_initializer=None, # tf_keras.initializers.zeros() make_kernel_bias_fn=kernel_bias_lib.make_kernel_bias, dtype=tf.float32, index_dtype=tf.int32, @@ -151,7 +151,7 @@ def __init__( Default value: `None` (i.e., `tfp.experimental.nn.initializers.glorot_uniform()`). bias_initializer: ... - Default value: `None` (i.e., `tf.initializers.zeros()`). + Default value: `None` (i.e., `tf_keras.initializers.zeros()`). make_kernel_bias_fn: ... Default value: `tfp.experimental.nn.util.make_kernel_bias`. dtype: ... @@ -288,7 +288,7 @@ def _preprocess(image, label): padding='same', filter_shape=5, # Use `he_uniform` because we'll use the `relu` family. - kernel_initializer=tf.initializers.he_uniform(), + kernel_initializer=tf_keras.initializers.he_uniform(), penalty_weight=1. / n) BayesAffine = functools.partial( @@ -316,7 +316,7 @@ def loss_fn(): kl = bnn.extra_loss # Already normalized via `penalty_weight` arg. loss = nll + kl return loss, (nll, kl) - opt = tf.optimizers.Adam() + opt = tf_keras.optimizers.Adam() fit_op = tfn.util.make_fit_op(loss_fn, opt, bnn.trainable_variables) for _ in range(200): loss, (nll, kl), g = fit_op() @@ -349,7 +349,7 @@ def __init__( dilations=1, # keras::Conv::dilation_rate # Weights kernel_initializer=None, # tfp.nn.initializers.glorot_uniform() - bias_initializer=None, # tf.initializers.zeros() + bias_initializer=None, # tf_keras.initializers.zeros() make_posterior_fn=kernel_bias_lib.make_kernel_bias_posterior_mvn_diag, make_prior_fn=kernel_bias_lib.make_kernel_bias_prior_spike_and_slab, posterior_value_fn=tfd.Distribution.sample, @@ -409,7 +409,7 @@ def __init__( Default value: `None` (i.e., `tfp.experimental.nn.initializers.glorot_uniform()`). bias_initializer: ... - Default value: `None` (i.e., `tf.initializers.zeros()`). + Default value: `None` (i.e., `tf_keras.initializers.zeros()`). make_posterior_fn: ... 
Default value: `tfp.experimental.nn.util.make_kernel_bias_posterior_mvn_diag`. @@ -549,7 +549,7 @@ def __init__( dilations=1, # keras::Conv::dilation_rate # Weights kernel_initializer=None, # tfp.nn.initializers.glorot_uniform() - bias_initializer=None, # tf.initializers.zeros() + bias_initializer=None, # tf_keras.initializers.zeros() make_posterior_fn=kernel_bias_lib.make_kernel_bias_posterior_mvn_diag, make_prior_fn=kernel_bias_lib.make_kernel_bias_prior_spike_and_slab, posterior_value_fn=tfd.Distribution.sample, @@ -609,7 +609,7 @@ def __init__( Default value: `None` (i.e., `tfp.experimental.nn.initializers.glorot_uniform()`). bias_initializer: ... - Default value: `None` (i.e., `tf.initializers.zeros()`). + Default value: `None` (i.e., `tf_keras.initializers.zeros()`). make_posterior_fn: ... Default value: `tfp.experimental.nn.util.make_kernel_bias_posterior_mvn_diag`. diff --git a/tensorflow_probability/python/experimental/nn/convolutional_layers_v2_test.py b/tensorflow_probability/python/experimental/nn/convolutional_layers_v2_test.py index 93b5d987c5..0893af1b25 100644 --- a/tensorflow_probability/python/experimental/nn/convolutional_layers_v2_test.py +++ b/tensorflow_probability/python/experimental/nn/convolutional_layers_v2_test.py @@ -27,6 +27,7 @@ from tensorflow_probability.python.distributions import normal from tensorflow_probability.python.experimental import nn as tfn from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.util import deferred_tensor @@ -78,7 +79,7 @@ def loss_fn(): nll = -tf.reduce_mean(bnn(x).log_prob(y), axis=-1) kl = tfn.losses.compute_extra_loss(bnn) / n return nll + kl, (nll, kl) - opt = tf.optimizers.Adam() + opt = tf_keras.optimizers.Adam() fit_op = tfn.util.make_fit_op(loss_fn, opt, bnn.trainable_variables) for _ in range(2): loss, (nll, kl) = fit_op() # pylint: disable=unused-variable diff --git a/tensorflow_probability/python/experimental/nn/convolutional_transpose_layers.py b/tensorflow_probability/python/experimental/nn/convolutional_transpose_layers.py index 5d2ad4ce14..ead55e8430 100644 --- a/tensorflow_probability/python/experimental/nn/convolutional_transpose_layers.py +++ b/tensorflow_probability/python/experimental/nn/convolutional_transpose_layers.py @@ -91,7 +91,7 @@ def __init__( method='auto', # Weights kernel_initializer=None, # tfp.nn.initializers.glorot_uniform() - bias_initializer=None, # tf.initializers.zeros() + bias_initializer=None, # tf_keras.initializers.zeros() make_kernel_bias_fn=kernel_bias_lib.make_kernel_bias, dtype=tf.float32, index_dtype=tf.int32, @@ -156,7 +156,7 @@ def __init__( Default value: `None` (i.e., `tfp.experimental.nn.initializers.glorot_uniform()`). bias_initializer: ... - Default value: `None` (i.e., `tf.initializers.zeros()`). + Default value: `None` (i.e., `tf_keras.initializers.zeros()`). make_kernel_bias_fn: ... Default value: `tfp.experimental.nn.util.make_kernel_bias`. dtype: ... @@ -278,7 +278,7 @@ def _preprocess(image, label): padding='same', filter_shape=5, # Use `he_uniform` because we'll use the `relu` family. - kernel_initializer=tf.initializers.he_uniform()) + kernel_initializer=tf_keras.initializers.he_uniform()) BayesDeconv2D = functools.partial( tfn.ConvolutionTransposeVariationalReparameterization, @@ -286,7 +286,7 @@ def _preprocess(image, label): padding='same', filter_shape=5, # Use `he_uniform` because we'll use the `relu` family. 
- kernel_initializer=tf.initializers.he_uniform()) + kernel_initializer=tf_keras.initializers.he_uniform()) scale = tfp.util.TransformedVariable(1., tfb.Softplus()) bnn = tfn.Sequential([ @@ -316,7 +316,7 @@ def loss_fn(): kl = bnn.extra_loss / tf.cast(train_size, tf.float32) loss = nll + kl return loss, (nll, kl) - opt = tf.optimizers.Adam() + opt = tf_keras.optimizers.Adam() fit_op = tfn.util.make_fit_op(loss_fn, opt, bnn.trainable_variables) for _ in range(200): loss, (nll, kl), g = fit_op() @@ -351,7 +351,7 @@ def __init__( method='auto', # Weights kernel_initializer=None, # tfp.nn.initializers.glorot_uniform() - bias_initializer=None, # tf.initializers.zeros() + bias_initializer=None, # tf_keras.initializers.zeros() make_posterior_fn=kernel_bias_lib.make_kernel_bias_posterior_mvn_diag, make_prior_fn=kernel_bias_lib.make_kernel_bias_prior_spike_and_slab, posterior_value_fn=tfd.Distribution.sample, @@ -420,7 +420,7 @@ def __init__( Default value: `None` (i.e., `tfp.experimental.nn.initializers.glorot_uniform()`). bias_initializer: ... - Default value: `None` (i.e., `tf.initializers.zeros()`). + Default value: `None` (i.e., `tf_keras.initializers.zeros()`). make_posterior_fn: ... Default value: `tfp.experimental.nn.util.make_kernel_bias_posterior_mvn_diag`. @@ -527,14 +527,14 @@ class ConvolutionTransposeVariationalFlipout( padding='same', filter_shape=5, # Use `he_uniform` because we'll use the `relu` family. - kernel_initializer=tf.initializers.he_uniform()) + kernel_initializer=tf_keras.initializers.he_uniform()) BayesDeconv2D = functools.partial( tfn.ConvolutionTransposeVariationalFlipout, rank=2, padding='same', filter_shape=5, # Use `he_uniform` because we'll use the `relu` family. - kernel_initializer=tf.initializers.he_uniform()) + kernel_initializer=tf_keras.initializers.he_uniform()) ``` This example uses reparameterization gradients to minimize the @@ -567,7 +567,7 @@ def __init__( method='auto', # Weights kernel_initializer=None, # tfp.nn.initializers.glorot_uniform() - bias_initializer=None, # tf.initializers.zeros() + bias_initializer=None, # tf_keras.initializers.zeros() make_posterior_fn=kernel_bias_lib.make_kernel_bias_posterior_mvn_diag, make_prior_fn=kernel_bias_lib.make_kernel_bias_prior_spike_and_slab, posterior_value_fn=tfd.Distribution.sample, @@ -636,7 +636,7 @@ def __init__( Default value: `None` (i.e., `tfp.experimental.nn.initializers.glorot_uniform()`). bias_initializer: ... - Default value: `None` (i.e., `tf.initializers.zeros()`). + Default value: `None` (i.e., `tf_keras.initializers.zeros()`). make_posterior_fn: ... Default value: `tfp.experimental.nn.util.make_kernel_bias_posterior_mvn_diag`. 
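Each of these docstring edits points at an internal `tf_keras` shim rather than at `tf.keras` directly (the `:tf_keras` target is added to `internal/BUILD` further down in this diff). The shim's source is not part of the patch; the sketch below is only a guess at its shape, assuming it prefers the standalone Keras 2 `tf_keras` package and falls back to the Keras bundled with TensorFlow.

```python
# Hypothetical sketch of internal/tf_keras.py -- the real file is not shown in
# this diff. Assumed behavior: re-export Keras 2 namespaces under one name.
try:
  import tf_keras as _keras          # standalone Keras 2 package, if installed
except ImportError:
  import tensorflow.compat.v2 as tf
  _keras = tf.keras                  # fall back to the Keras bundled with TF

initializers = _keras.initializers
layers = _keras.layers
optimizers = _keras.optimizers
```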
diff --git a/tensorflow_probability/python/experimental/nn/convolutional_transpose_layers_test.py b/tensorflow_probability/python/experimental/nn/convolutional_transpose_layers_test.py index e7c166644d..eceba593ec 100644 --- a/tensorflow_probability/python/experimental/nn/convolutional_transpose_layers_test.py +++ b/tensorflow_probability/python/experimental/nn/convolutional_transpose_layers_test.py @@ -24,6 +24,7 @@ from tensorflow_probability.python.distributions import normal from tensorflow_probability.python.experimental import nn as tfn from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.util import deferred_tensor @@ -78,7 +79,7 @@ def loss_fn(): kl = tfn.losses.compute_extra_loss(bnn) / tf.cast(train_size, tf.float32) loss = nll + kl return loss, (nll, kl) - opt = tf.optimizers.Adam() + opt = tf_keras.optimizers.Adam() fit_op = tfn.util.make_fit_op(loss_fn, opt, bnn.trainable_variables) for _ in range(2): loss, (nll, kl) = fit_op() # pylint: disable=unused-variable diff --git a/tensorflow_probability/python/experimental/nn/examples/bnn_mnist_advi.ipynb b/tensorflow_probability/python/experimental/nn/examples/bnn_mnist_advi.ipynb index c5ac9827fc..0fa1c85003 100644 --- a/tensorflow_probability/python/experimental/nn/examples/bnn_mnist_advi.ipynb +++ b/tensorflow_probability/python/experimental/nn/examples/bnn_mnist_advi.ipynb @@ -91,6 +91,8 @@ "\n", "from tensorflow_probability.python.internal import prefer_static\n", "\n", + "from tensorflow_probability.python.internal import tf_keras\n", + "\n", "# Globally Enable XLA.\n", "# tf.config.optimizer.set_jit(True)\n", "\n", @@ -229,7 +231,7 @@ " kernel_name='posterior_kernel',\n", " bias_name='posterior_bias'):\n", " if kernel_initializer is None:\n", - " kernel_initializer = tf.initializers.glorot_uniform()\n", + " kernel_initializer = tf_keras.initializers.glorot_uniform()\n", " if bias_initializer is None:\n", " bias_initializer = tf.zeros\n", " make_loc = lambda shape, init, name: tf.Variable( # pylint: disable=g-long-lambda\n", @@ -325,7 +327,7 @@ } ], "source": [ - "max_pool = tf.keras.layers.MaxPooling2D( # Has no tf.Variables.\n", + "max_pool = tf_keras.layers.MaxPooling2D( # Has no tf.Variables.\n", " pool_size=(2, 2),\n", " strides=(2, 2),\n", " padding='SAME',\n", @@ -348,7 +350,7 @@ " output_size=8,\n", " filter_shape=5,\n", " padding='SAME',\n", - " init_kernel_fn=tf.initializers.he_uniform(),\n", + " init_kernel_fn=tf_keras.initializers.he_uniform(),\n", " penalty_weight=1 / train_size,\n", " # penalty_weight=1e2 / train_size, # Layer specific \"beta\".\n", " # make_posterior_fn=make_posterior,\n", @@ -361,7 +363,7 @@ " output_size=16,\n", " filter_shape=5,\n", " padding='SAME',\n", - " init_kernel_fn=tf.initializers.he_uniform(),\n", + " init_kernel_fn=tf_keras.initializers.he_uniform(),\n", " penalty_weight=1 / train_size,\n", " # penalty_weight=1e2 / train_size, # Layer specific \"beta\".\n", " # make_posterior_fn=make_posterior,\n", @@ -375,7 +377,7 @@ " output_size=32,\n", " filter_shape=5,\n", " padding='SAME',\n", - " init_kernel_fn=tf.initializers.he_uniform(),\n", + " init_kernel_fn=tf_keras.initializers.he_uniform(),\n", " penalty_weight=1 / train_size,\n", " # penalty_weight=1e2 / train_size, # Layer specific \"beta\".\n", " # make_posterior_fn=make_posterior,\n", @@ -448,7 +450,7 @@ " loss, (nll, kl), _ = compute_loss_bnn(x, y)\n", " return loss, (nll, kl)\n", "\n", - "opt_bnn = 
tf.optimizers.Adam(learning_rate=0.003)\n", + "opt_bnn = tf_keras.optimizers.Adam(learning_rate=0.003)\n", " \n", "fit_bnn = tfn.util.make_fit_op(\n", " train_loss_bnn,\n", @@ -1191,7 +1193,7 @@ } ], "source": [ - "max_pool = tf.keras.layers.MaxPooling2D( # Has no tf.Variables.\n", + "max_pool = tf_keras.layers.MaxPooling2D( # Has no tf.Variables.\n", " pool_size=(2, 2),\n", " strides=(2, 2),\n", " padding='SAME',\n", @@ -1207,7 +1209,7 @@ " output_size=8,\n", " filter_shape=5,\n", " padding='SAME',\n", - " init_kernel_fn=tf.initializers.he_uniform(),\n", + " init_kernel_fn=tf_keras.initializers.he_uniform(),\n", " name='conv1'),\n", " maybe_batchnorm,\n", " tf.nn.leaky_relu,\n", @@ -1216,7 +1218,7 @@ " output_size=16,\n", " filter_shape=5,\n", " padding='SAME',\n", - " init_kernel_fn=tf.initializers.he_uniform(),\n", + " init_kernel_fn=tf_keras.initializers.he_uniform(),\n", " name='conv1'),\n", " maybe_batchnorm,\n", " tf.nn.leaky_relu,\n", @@ -1226,7 +1228,7 @@ " output_size=32,\n", " filter_shape=5,\n", " padding='SAME',\n", - " init_kernel_fn=tf.initializers.he_uniform(),\n", + " init_kernel_fn=tf_keras.initializers.he_uniform(),\n", " name='conv2'),\n", " maybe_batchnorm,\n", " tf.nn.leaky_relu,\n", @@ -1280,7 +1282,7 @@ " nll, _ = compute_loss_dnn(x, y)\n", " return nll, None\n", "\n", - "opt_dnn = tf.optimizers.Adam(learning_rate=0.003)\n", + "opt_dnn = tf_keras.optimizers.Adam(learning_rate=0.003)\n", " \n", "fit_dnn = tfn.util.make_fit_op(\n", " train_loss_dnn,\n", diff --git a/tensorflow_probability/python/experimental/nn/examples/single_column_mnist.ipynb b/tensorflow_probability/python/experimental/nn/examples/single_column_mnist.ipynb index a9e3490f4c..575f613919 100644 --- a/tensorflow_probability/python/experimental/nn/examples/single_column_mnist.ipynb +++ b/tensorflow_probability/python/experimental/nn/examples/single_column_mnist.ipynb @@ -283,7 +283,7 @@ "\n", " # Convenience function\n", " affine = functools.partial(tfn.Affine,\n", - " init_kernel_fn=tf.initializers.he_normal(),\n", + " init_kernel_fn=tf_keras.initializers.he_normal(),\n", " init_bias_fn = tf.zeros_initializer())\n", "\n", " self._dnn = tfn.Sequential([\n", @@ -333,7 +333,7 @@ "\n", " # Convenience function\n", " affine = functools.partial(tfn.Affine, \n", - " init_kernel_fn=tf.initializers.he_normal(),\n", + " init_kernel_fn=tf_keras.initializers.he_normal(),\n", " init_bias_fn = tf.zeros_initializer())\n", "\n", " # DNN is just an affine transformation for the decoder\n", @@ -475,7 +475,7 @@ " beta=beta,\n", " seed=seedstream)\n", "\n", - "opt = tf.optimizers.Adam(lr)\n", + "opt = tf_keras.optimizers.Adam(lr)\n", "train_op = tfn.util.make_fit_op(\n", " loss_fn=loss_fn, optimizer=opt,\n", " trainable_variables=loss_fn.trainable_variables,\n", @@ -675,7 +675,7 @@ " beta=beta,\n", " seed=seedstream)\n", "\n", - " opt = tf.optimizers.Adam(lr)\n", + " opt = tf_keras.optimizers.Adam(lr)\n", " train_op = tfn.util.make_fit_op(\n", " loss_fn=loss_fn, optimizer=opt,\n", " trainable_variables=loss_fn.trainable_variables,\n", diff --git a/tensorflow_probability/python/experimental/nn/examples/vae_mnist_advi.ipynb b/tensorflow_probability/python/experimental/nn/examples/vae_mnist_advi.ipynb index a8359220d6..c55819d5ee 100644 --- a/tensorflow_probability/python/experimental/nn/examples/vae_mnist_advi.ipynb +++ b/tensorflow_probability/python/experimental/nn/examples/vae_mnist_advi.ipynb @@ -240,7 +240,7 @@ "source": [ "Conv = functools.partial(\n", " tfn.Convolution,\n", - " 
init_kernel_fn=tf.initializers.he_uniform()) # Better for leaky_relu.\n", + " init_kernel_fn=tf_keras.initializers.he_uniform()) # Better for leaky_relu.\n", "\n", "encoder = tfn.Sequential([\n", " lambda x: 2. * tf.cast(x, tf.float32) - 1., # Center.\n", @@ -303,7 +303,7 @@ "source": [ "DeConv = functools.partial(\n", " tfn.ConvolutionTranspose,\n", - " init_kernel_fn=tf.initializers.he_uniform()) # Better for leaky_relu.\n", + " init_kernel_fn=tf_keras.initializers.he_uniform()) # Better for leaky_relu.\n", " \n", "decoder = tfn.Sequential([\n", " lambda x: x[..., tf.newaxis, tf.newaxis, :],\n", @@ -380,7 +380,7 @@ " loss, (nll, kl), _ = compute_loss(x)\n", " return loss, (nll, kl)\n", "\n", - "opt = tf.optimizers.Adam(learning_rate=1e-3)\n", + "opt = tf_keras.optimizers.Adam(learning_rate=1e-3)\n", "\n", "fit = tfn.util.make_fit_op(\n", " loss,\n", diff --git a/tensorflow_probability/python/experimental/nn/examples/vib_dose.ipynb b/tensorflow_probability/python/experimental/nn/examples/vib_dose.ipynb index 2d5c5c7430..2b717f6f81 100644 --- a/tensorflow_probability/python/experimental/nn/examples/vib_dose.ipynb +++ b/tensorflow_probability/python/experimental/nn/examples/vib_dose.ipynb @@ -275,7 +275,7 @@ "Conv = functools.partial(\n", " tfn.Convolution,\n", " init_bias_fn=tf.zeros_initializer(),\n", - " init_kernel_fn=tf.initializers.he_uniform()) # Better for leaky_relu.\n", + " init_kernel_fn=tf_keras.initializers.he_uniform()) # Better for leaky_relu.\n", "\n", "encoder = tfn.Sequential([\n", " lambda x: 2. * tf.cast(x, tf.float32) - 1., # Center.\n", @@ -326,11 +326,11 @@ "source": [ "DeConv = functools.partial(\n", " tfn.ConvolutionTranspose,\n", - " init_kernel_fn=tf.initializers.he_uniform()) # Better for leaky_relu.\n", + " init_kernel_fn=tf_keras.initializers.he_uniform()) # Better for leaky_relu.\n", " \n", "Affine = functools.partial(\n", " tfn.Affine,\n", - " init_kernel_fn=tf.initializers.he_uniform())\n", + " init_kernel_fn=tf_keras.initializers.he_uniform())\n", "\n", "decoder = tfn.Sequential([\n", " Affine(encoded_size, 10),\n", @@ -390,7 +390,7 @@ " loss, (nll, kl), _ = compute_loss(x, y, beta=0.075)\n", " return loss, (nll, kl)\n", "\n", - "opt = tf.optimizers.Adam(learning_rate=1e-3, decay=0.00005)\n", + "opt = tf_keras.optimizers.Adam(learning_rate=1e-3, decay=0.00005)\n", "\n", "fit = tfn.util.make_fit_op(\n", " loss,\n", diff --git a/tensorflow_probability/python/experimental/nn/util/BUILD b/tensorflow_probability/python/experimental/nn/util/BUILD index 99e0c450c2..64e7557c72 100644 --- a/tensorflow_probability/python/experimental/nn/util/BUILD +++ b/tensorflow_probability/python/experimental/nn/util/BUILD @@ -49,6 +49,7 @@ py_test( # tensorflow dep, "//tensorflow_probability/python/internal:prefer_static", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", ], ) @@ -68,6 +69,7 @@ py_library( "//tensorflow_probability/python/distributions:sample", "//tensorflow_probability/python/experimental/nn/initializers:initializers_lib", "//tensorflow_probability/python/internal:prefer_static", + "//tensorflow_probability/python/internal:tf_keras", ], ) @@ -131,5 +133,6 @@ py_library( "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:prefer_static", "//tensorflow_probability/python/internal:tensorshape_util", + "//tensorflow_probability/python/internal:tf_keras", ], ) diff --git a/tensorflow_probability/python/experimental/nn/util/convolution_util_test.py 
b/tensorflow_probability/python/experimental/nn/util/convolution_util_test.py index 7028a1e949..d86b31a3cd 100644 --- a/tensorflow_probability/python/experimental/nn/util/convolution_util_test.py +++ b/tensorflow_probability/python/experimental/nn/util/convolution_util_test.py @@ -24,7 +24,7 @@ from tensorflow_probability.python.experimental.nn.util import convolution_util from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import test_util - +from tensorflow_probability.python.internal import tf_keras # pylint: disable=bad-whitespace _CONV_TEST_CASES = ( @@ -374,7 +374,7 @@ def test_works_like_conv2d_transpose( perm=[0, 1, 3, 2]) # conv2d_transpose does not support dilations > 1; use Keras instead. if any(d > 1 for d in dilations): - keras_convt = tf.keras.layers.Conv2DTranspose( + keras_convt = tf_keras.layers.Conv2DTranspose( filters=channels_out, kernel_size=filter_shape, strides=strides, diff --git a/tensorflow_probability/python/experimental/nn/util/kernel_bias.py b/tensorflow_probability/python/experimental/nn/util/kernel_bias.py index e365aa8def..5b24b5002d 100644 --- a/tensorflow_probability/python/experimental/nn/util/kernel_bias.py +++ b/tensorflow_probability/python/experimental/nn/util/kernel_bias.py @@ -1,3 +1,4 @@ + # Copyright 2020 The TensorFlow Probability Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,6 +29,7 @@ from tensorflow_probability.python.distributions.sample import Sample from tensorflow_probability.python.experimental.nn import initializers as nn_init_lib from tensorflow_probability.python.internal import prefer_static as ps +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.util.deferred_tensor import TransformedVariable @@ -58,9 +60,9 @@ def make_kernel_bias( kernel_shape: ... bias_shape: ... kernel_initializer: ... - Default value: `None` (i.e., `tf.initializers.glorot_uniform()`). + Default value: `None` (i.e., `tf_keras.initializers.glorot_uniform()`). bias_initializer: ... - Default value: `None` (i.e., `tf.initializers.zeros()`). + Default value: `None` (i.e., `tf_keras.initializers.zeros()`). kernel_batch_ndims: ... Default value: `0`. bias_batch_ndims: ... @@ -79,13 +81,13 @@ def make_kernel_bias( #### Recommendations: ```python - # tf.nn.relu ==> tf.initializers.he_* - # tf.nn.elu ==> tf.initializers.he_* - # tf.nn.selu ==> tf.initializers.lecun_* - # tf.nn.tanh ==> tf.initializers.glorot_* - # tf.nn.sigmoid ==> tf.initializers.glorot_* - # tf.nn.softmax ==> tf.initializers.glorot_* - # None ==> tf.initializers.glorot_* + # tf.nn.relu ==> tf_keras.initializers.he_* + # tf.nn.elu ==> tf_keras.initializers.he_* + # tf.nn.selu ==> tf_keras.initializers.lecun_* + # tf.nn.tanh ==> tf_keras.initializers.glorot_* + # tf.nn.sigmoid ==> tf_keras.initializers.glorot_* + # tf.nn.softmax ==> tf_keras.initializers.glorot_* + # None ==> tf_keras.initializers.glorot_* # https://towardsdatascience.com/hyper-parameters-in-action-part-ii-weight-initializers-35aee1a28404 # https://stats.stackexchange.com/a/393012/1835 @@ -94,7 +96,7 @@ def make_uniform(size): return tfd.Uniform(low=-s, high=s) def make_normal(size): - # Constant is: `scipy.stats.truncnorm.var(loc=0., scale=1., a=-2., b=2.)`. + # Constant is: `scipy.stats.truncnorm.std(loc=0., scale=1., a=-2., b=2.)`. s = tf.math.rsqrt(size) / 0.87962566103423978 return tfd.TruncatedNormal(loc=0, scale=s, low=-2., high=2.) 
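The corrected comment in `make_normal` above is easy to verify: `0.87962566103423978` is the standard deviation (not the variance) of a unit normal truncated to `[-2, 2]`, so dividing the scale by it compensates for the variance lost to truncation. A two-line SciPy check, shown only for reference:

```python
# Check that the constant in make_normal is the std of a truncated unit normal.
from scipy import stats

std = stats.truncnorm.std(a=-2., b=2., loc=0., scale=1.)
print(std)  # ~0.8796256610342398, matching the constant divided out above
```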
@@ -112,7 +114,7 @@ def make_normal(size): if kernel_initializer is None: kernel_initializer = nn_init_lib.glorot_uniform() if bias_initializer is None: - bias_initializer = tf.initializers.zeros() + bias_initializer = tf_keras.initializers.zeros() return ( tf.Variable(_try_call_init_fn(kernel_initializer, kernel_shape, @@ -156,9 +158,9 @@ def make_kernel_bias_prior_spike_and_slab( kernel_shape: ... bias_shape: ... kernel_initializer: Ignored. - Default value: `None` (i.e., `tf.initializers.glorot_uniform()`). + Default value: `None` (i.e., `tf_keras.initializers.glorot_uniform()`). bias_initializer: Ignored. - Default value: `None` (i.e., `tf.initializers.zeros()`). + Default value: `None` (i.e., `tf_keras.initializers.zeros()`). kernel_batch_ndims: ... Default value: `0`. bias_batch_ndims: ... @@ -200,9 +202,9 @@ def make_kernel_bias_posterior_mvn_diag( kernel_shape: ... bias_shape: ... kernel_initializer: ... - Default value: `None` (i.e., `tf.initializers.glorot_uniform()`). + Default value: `None` (i.e., `tf_keras.initializers.glorot_uniform()`). bias_initializer: ... - Default value: `None` (i.e., `tf.initializers.zeros()`). + Default value: `None` (i.e., `tf_keras.initializers.zeros()`). kernel_batch_ndims: ... Default value: `0`. bias_batch_ndims: ... @@ -220,7 +222,7 @@ def make_kernel_bias_posterior_mvn_diag( if kernel_initializer is None: kernel_initializer = nn_init_lib.glorot_uniform() if bias_initializer is None: - bias_initializer = tf.initializers.zeros() + bias_initializer = tf_keras.initializers.zeros() make_loc = lambda init_fn, shape, batch_ndims, name: tf.Variable( # pylint: disable=g-long-lambda _try_call_init_fn(init_fn, shape, dtype, batch_ndims), name=name + '_loc') diff --git a/tensorflow_probability/python/experimental/nn/util/utils.py b/tensorflow_probability/python/experimental/nn/util/utils.py index 1e60503682..c502298721 100644 --- a/tensorflow_probability/python/experimental/nn/util/utils.py +++ b/tensorflow_probability/python/experimental/nn/util/utils.py @@ -249,7 +249,7 @@ def make_fit_op(loss_fn, optimizer, trainable_variables, loss_fn: Python `callable` which returns the pair `loss` (`tf.Tensor`) and any other second result such that `tf.nest.map_structure(tf.convert_to_tensor, other)` will succeed. - optimizer: `tf.optimizers.Optimizer`-like instance which has members + optimizer: `tf_keras.optimizers.Optimizer`-like instance which has members `gradient` and `apply_gradients`. trainable_variables: `tf.nest.flatten`-able structure of `tf.Variable` instances. diff --git a/tensorflow_probability/python/experimental/parallel_filter/parallel_kalman_filter_lib.py b/tensorflow_probability/python/experimental/parallel_filter/parallel_kalman_filter_lib.py index ae40d42794..838558fc0f 100644 --- a/tensorflow_probability/python/experimental/parallel_filter/parallel_kalman_filter_lib.py +++ b/tensorflow_probability/python/experimental/parallel_filter/parallel_kalman_filter_lib.py @@ -21,6 +21,7 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.distributions import mvn_tril +from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import samplers from tensorflow_probability.python.math import linalg @@ -625,11 +626,8 @@ def kalman_filter(transition_matrix, axis=0), added_cov=time_dep.observation_cov) - # TODO(srvasude): The JVP for this can be implemented more efficiently. 
- log_likelihoods = mvn_tril.MultivariateNormalTriL( - loc=observation_means, - scale_tril=tf.linalg.cholesky(observation_covs)).log_prob( - observation.y) + log_likelihoods = _mvn_log_prob( + observation_means, observation_covs, observation.y) if observation.mask is not None: log_likelihoods = tf.where(observation.mask, tf.zeros([], dtype=log_likelihoods.dtype), @@ -644,6 +642,17 @@ def kalman_filter(transition_matrix, observation_covs) +def _mvn_log_prob(mean, covariance, y): + cholesky_matrix = tf.linalg.cholesky(covariance) + log_prob = -0.5 * linalg.hpsd_quadratic_form_solvevec( + covariance, y - mean, cholesky_matrix=cholesky_matrix) + log_prob = log_prob - 0.5 * linalg.hpsd_logdet( + covariance, cholesky_matrix=cholesky_matrix) + event_dims = ps.shape(mean)[-1] + return log_prob - 0.5 * event_dims * dtype_util.as_numpy_dtype( + mean.dtype)(np.log(2 * np.pi)) + + def _extract_batch_shape(x, sample_ndims, event_ndims): """Slice out the batch component of `x`'s shape.""" if x is None: diff --git a/tensorflow_probability/python/experimental/psd_kernels/additive_kernel_test.py b/tensorflow_probability/python/experimental/psd_kernels/additive_kernel_test.py index 0e685e3e6b..5ea56cc61a 100644 --- a/tensorflow_probability/python/experimental/psd_kernels/additive_kernel_test.py +++ b/tensorflow_probability/python/experimental/psd_kernels/additive_kernel_test.py @@ -139,7 +139,7 @@ def testMatrixValuesAreCorrect( amplitudes, length_scale, dim, x, y, method='matrix') self.assertAllClose( - self.evaluate(actual), self.evaluate(expected), rtol=1e-5) + self.evaluate(actual), self.evaluate(expected), rtol=3e-5) @test_util.disable_test_for_backend( disable_numpy=True, diff --git a/tensorflow_probability/python/experimental/sts_gibbs/spike_and_slab_test.py b/tensorflow_probability/python/experimental/sts_gibbs/spike_and_slab_test.py index d33a3e691b..e2c4ad8237 100644 --- a/tensorflow_probability/python/experimental/sts_gibbs/spike_and_slab_test.py +++ b/tensorflow_probability/python/experimental/sts_gibbs/spike_and_slab_test.py @@ -145,7 +145,8 @@ def test_posterior_on_nonzero_subset_matches_bayesian_regression( self.assertAllClose( nonzero_subvector(self.evaluate( initial_state.conditional_weights_mean)), - restricted_weights_posterior_mean) + restricted_weights_posterior_mean, + atol=5e-5) self.assertAllClose( nonzero_submatrix(initial_state.conditional_posterior_precision_chol), tf.linalg.cholesky(restricted_weights_posterior_prec.to_dense())) @@ -346,7 +347,7 @@ def loop_body(var_weights_seed, _): tf.float32) self.assertAllClose(nonzero_prior_prob, tf.reduce_mean(nonzero_weight_samples), - atol=0.03) + atol=0.04) @parameterized.named_parameters(('', False), ('_xla', True)) def test_deterministic_given_seed(self, use_xla): diff --git a/tensorflow_probability/python/experimental/util/BUILD b/tensorflow_probability/python/experimental/util/BUILD index ba24229343..464e167ce6 100644 --- a/tensorflow_probability/python/experimental/util/BUILD +++ b/tensorflow_probability/python/experimental/util/BUILD @@ -149,6 +149,7 @@ multi_substrate_py_test( "//tensorflow_probability/python/distributions:wishart", "//tensorflow_probability/python/internal:structural_tuple", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/math:gradient", "//tensorflow_probability/python/math:minimize", ], diff --git a/tensorflow_probability/python/experimental/util/trainable.py 
b/tensorflow_probability/python/experimental/util/trainable.py index d668ffae8f..6dea8680db 100644 --- a/tensorflow_probability/python/experimental/util/trainable.py +++ b/tensorflow_probability/python/experimental/util/trainable.py @@ -185,7 +185,7 @@ def _make_trainable(cls, model = tfp.util.make_trainable(tfd.Normal) losses = tfp.math.minimize( lambda: -model.log_prob(samples), - optimizer=tf.optimizers.Adam(0.1), + optimizer=tf_keras.optimizers.Adam(0.1), num_steps=200) print('Fit Normal distribution with mean {} and stddev {}'.format( model.mean(), diff --git a/tensorflow_probability/python/experimental/util/trainable_test.py b/tensorflow_probability/python/experimental/util/trainable_test.py index e9ed422207..c9e23aae6b 100644 --- a/tensorflow_probability/python/experimental/util/trainable_test.py +++ b/tensorflow_probability/python/experimental/util/trainable_test.py @@ -35,6 +35,7 @@ from tensorflow_probability.python.experimental.util import trainable from tensorflow_probability.python.internal import samplers from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.math import gradient from tensorflow_probability.python.math.minimize import minimize from tensorflow_probability.python.math.minimize import minimize_stateless @@ -198,7 +199,7 @@ def test_docstring_example_normal(self): normal.Normal, seed=test_util.test_seed(sampler_type='stateless')) losses = minimize( lambda: -model.log_prob(samples), - optimizer=tf.optimizers.Adam(0.1), + optimizer=tf_keras.optimizers.Adam(0.1), num_steps=200) self.evaluate(tf1.global_variables_initializer()) self.evaluate(losses) diff --git a/tensorflow_probability/python/experimental/vi/BUILD b/tensorflow_probability/python/experimental/vi/BUILD index 70de67fb3c..b21fb43958 100644 --- a/tensorflow_probability/python/experimental/vi/BUILD +++ b/tensorflow_probability/python/experimental/vi/BUILD @@ -69,6 +69,7 @@ multi_substrate_py_library( "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:prefer_static", "//tensorflow_probability/python/internal:samplers", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/internal:trainable_state_util", "//tensorflow_probability/python/util", ], @@ -141,6 +142,7 @@ multi_substrate_py_test( "//tensorflow_probability/python/internal:custom_gradient", "//tensorflow_probability/python/internal:samplers", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/math:gradient", "//tensorflow_probability/python/math:minimize", "//tensorflow_probability/python/vi:optimization", @@ -180,6 +182,7 @@ multi_substrate_py_test( "//tensorflow_probability/python/distributions:student_t", "//tensorflow_probability/python/experimental/vi/util", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/vi:optimization", ], ) diff --git a/tensorflow_probability/python/experimental/vi/automatic_structured_vi.py b/tensorflow_probability/python/experimental/vi/automatic_structured_vi.py index 9f6b4e0e24..df713787d4 100644 --- a/tensorflow_probability/python/experimental/vi/automatic_structured_vi.py +++ b/tensorflow_probability/python/experimental/vi/automatic_structured_vi.py @@ -497,7 +497,7 @@ def model_fn(): target_log_prob_fn, surrogate_posterior=surrogate_posterior, num_steps=100, 
- optimizer=tf.optimizers.Adam(0.1), + optimizer=tf_keras.optimizers.Adam(0.1), sample_size=10) # After optimization, samples from the surrogate will approximate @@ -509,7 +509,7 @@ def model_fn(): #### References - [1]: Luca Ambrogioni, Kate Line, Emily Fertig, Sharad Vikram, Max Hinne, + [1]: Luca Ambrogioni, Kate Lin, Emily Fertig, Sharad Vikram, Max Hinne, Dave Moore, Marcel van Gerven. Automatic structured variational inference. _arXiv preprint arXiv:2002.00643_, 2020 https://arxiv.org/abs/2002.00643 diff --git a/tensorflow_probability/python/experimental/vi/automatic_structured_vi_test.py b/tensorflow_probability/python/experimental/vi/automatic_structured_vi_test.py index 9d94e3dcfd..e9287768dd 100644 --- a/tensorflow_probability/python/experimental/vi/automatic_structured_vi_test.py +++ b/tensorflow_probability/python/experimental/vi/automatic_structured_vi_test.py @@ -48,6 +48,7 @@ from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import samplers from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.math import gradient from tensorflow_probability.python.math.minimize import minimize_stateless from tensorflow_probability.python.vi import optimization @@ -239,7 +240,7 @@ def test_fitting_surrogate_posterior(self, dtype): target_log_prob, surrogate_posterior, num_steps=3, # Don't optimize to completion. - optimizer=tf.optimizers.Adam(0.1), + optimizer=tf_keras.optimizers.Adam(0.1), sample_size=5) # Compute posterior statistics. diff --git a/tensorflow_probability/python/experimental/vi/surrogate_posteriors.py b/tensorflow_probability/python/experimental/vi/surrogate_posteriors.py index 7c6647c3a4..6ff693367d 100644 --- a/tensorflow_probability/python/experimental/vi/surrogate_posteriors.py +++ b/tensorflow_probability/python/experimental/vi/surrogate_posteriors.py @@ -153,7 +153,7 @@ def model_fn(): lambda rate, concentration: model.log_prob([rate, concentration, y]), surrogate_posterior=surrogate_posterior, num_steps=100, - optimizer=tf.optimizers.Adam(0.1), + optimizer=tf_keras.optimizers.Adam(0.1), sample_size=10) # After optimization, samples from the surrogate will approximate @@ -350,7 +350,7 @@ def model_fn(): target_model.unnormalized_log_prob, surrogate_posterior, num_steps=100, - optimizer=tf.optimizers.Adam(0.1), + optimizer=tf_keras.optimizers.Adam(0.1), sample_size=10) ``` """ @@ -532,7 +532,7 @@ def model_fn(): target_model.unnormalized_log_prob, surrogate_posterior, num_steps=100, - optimizer=tf.optimizers.Adam(0.1), + optimizer=tf_keras.optimizers.Adam(0.1), sample_size=10) ``` @@ -728,7 +728,7 @@ def build_split_flow_surrogate_posterior( target_model.unnormalized_log_prob, surrogate_posterior, num_steps=100, - optimizer=tf.optimizers.Adam(0.1), + optimizer=tf_keras.optimizers.Adam(0.1), sample_size=10) ``` diff --git a/tensorflow_probability/python/experimental/vi/surrogate_posteriors_test.py b/tensorflow_probability/python/experimental/vi/surrogate_posteriors_test.py index b6298f6255..bfe84b0bb2 100644 --- a/tensorflow_probability/python/experimental/vi/surrogate_posteriors_test.py +++ b/tensorflow_probability/python/experimental/vi/surrogate_posteriors_test.py @@ -44,6 +44,7 @@ from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import samplers from tensorflow_probability.python.internal import test_util +from 
tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.vi import optimization from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import @@ -131,7 +132,7 @@ def _test_fitting(self, model, surrogate_posterior): lambda rate, concentration: model.log_prob((rate, concentration, y)), surrogate_posterior, num_steps=5, # Don't optimize to completion. - optimizer=tf.optimizers.Adam(0.1), + optimizer=tf_keras.optimizers.Adam(0.1), sample_size=10) # Compute posterior statistics. diff --git a/tensorflow_probability/python/internal/BUILD b/tensorflow_probability/python/internal/BUILD index 8ab50c16d2..2182c40599 100644 --- a/tensorflow_probability/python/internal/BUILD +++ b/tensorflow_probability/python/internal/BUILD @@ -15,6 +15,8 @@ # Description: # Internal utilities for TensorFlow probability. +# [internal] load pytype.bzl (pytype_strict_test) +# [internal] load strict.bzl # Placeholder: py_library # Placeholder: py_test load( @@ -22,8 +24,6 @@ load( "multi_substrate_py_library", "multi_substrate_py_test", ) -# [internal] load pytype.bzl (pytype_strict_test) -# [internal] load strict.bzl licenses(["notice"]) @@ -71,6 +71,7 @@ py_test( # absl/testing:parameterized dep, # tensorflow dep, "//tensorflow_probability/python/bijectors:bijector", + "//tensorflow_probability/python/bijectors:bijector_test_util", "//tensorflow_probability/python/bijectors:reshape", "//tensorflow_probability/python/bijectors:scale", "//tensorflow_probability/python/bijectors:shift", @@ -652,6 +653,7 @@ multi_substrate_py_test( srcs = ["trainable_state_util_test.py"], jax_size = "medium", numpy_tags = ["notap"], + tf_tags = ["no-oss-ci"], # TODO(b/308579205) deps = [ # optax dep, # tensorflow dep, @@ -665,6 +667,7 @@ multi_substrate_py_test( "//tensorflow_probability/python/distributions:normal", "//tensorflow_probability/python/experimental/util", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/internal:trainable_state_util", "//tensorflow_probability/python/math:gradient", "//tensorflow_probability/python/math:minimize", @@ -920,3 +923,11 @@ exports_files( # "//tensorflow_probability/google:friends", # DisableOnExport ], ) + +py_library( + name = "tf_keras", + srcs = ["tf_keras.py"], + deps = [ + # tensorflow dep, + ], +) diff --git a/tensorflow_probability/python/internal/auto_composite_tensor.py b/tensorflow_probability/python/internal/auto_composite_tensor.py index 84fe461371..a61aa3638b 100644 --- a/tensorflow_probability/python/internal/auto_composite_tensor.py +++ b/tensorflow_probability/python/internal/auto_composite_tensor.py @@ -42,6 +42,29 @@ _DEFERRED_ASSERTION_CONTEXT.is_deferred = False +def is_composite_tensor(value): + """Returns True for CTs and non-CT custom pytrees in JAX mode. + + Args: + value: A TFP component (e.g. a distribution or bijector instance) or object + that behaves as one. + + Returns: + value_is_composite: bool, True if `value` is a `CompositeTensor` in TF mode + or a non-leaf pytree in JAX mode. + """ + if isinstance(value, composite_tensor.CompositeTensor): + return True + if JAX_MODE: + from jax import tree_util # pylint: disable=g-import-not-at-top + # If `value` is not a pytree leaf, then it must be an instance of a class + # that was specially registered as a pytree or that inherits from a class + # representing a nested structure. 
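+    # Illustrative note: a bare `jnp.ndarray` is a pytree leaf, so this
+    # returns False for it, whereas an object registered as a JAX pytree
+    # (or a plain Python container of Tensors) has a non-leaf treedef and
+    # returns True.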
+ treedef = tree_util.tree_structure(value) + return not tree_util.treedef_is_leaf(treedef) + return False + + def is_deferred_assertion_context(): return getattr(_DEFERRED_ASSERTION_CONTEXT, 'is_deferred', False) @@ -132,15 +155,16 @@ def _extract_type_spec_recursively(value): `value` is a collection containing `Tensor` values, recursively supplant them with their respective `TypeSpec`s in a collection of parallel stucture. - If `value` is nont of the above, return it unchanged. + If `value` is none of the above, return it unchanged. Args: value: a Python `object` to (possibly) turn into a (collection of) `tf.TypeSpec`(s). Returns: - spec: the `TypeSpec` or collection of `TypeSpec`s corresponding to `value` - or `value`, if no `Tensor`s are found. + spec: the `TypeSpec` or collection of `TypeSpec`s corresponding to `value`; + `value`, if no `Tensor`s are found; or `None` to indicate that `value` is + registered as a JAX pytree. """ if isinstance(value, composite_tensor.CompositeTensor): return value._type_spec # pylint: disable=protected-access @@ -161,6 +185,14 @@ def _extract_type_spec_recursively(value): 'Found `{}` with both Tensor and non-Tensor parts: {}'.format( type(value), value)) return specs + elif JAX_MODE: # Handle custom pytrees. + from jax import tree_util # pylint: disable=g-import-not-at-top + treedef = tree_util.tree_structure(value) + # Return None so that the object identity comparison in + # `_AutoCompositeTensorTypeSpec.from_instance` is False, indicating that + # `value` should be treated as a "Tensor" param. + if not tree_util.treedef_is_leaf(treedef): + return None return value diff --git a/tensorflow_probability/python/internal/backend/jax/BUILD b/tensorflow_probability/python/internal/backend/jax/BUILD index 15befd719e..f7041476c4 100644 --- a/tensorflow_probability/python/internal/backend/jax/BUILD +++ b/tensorflow_probability/python/internal/backend/jax/BUILD @@ -79,12 +79,8 @@ FILENAMES = [ GEN_FILENAMES = [ "gen/__init__", "gen/tensor_shape", - "gen/adjoint_registrations", - "gen/cholesky_registrations", - "gen/inverse_registrations", "gen/linear_operator_addition", "gen/linear_operator_adjoint", - "gen/linear_operator_algebra", "gen/linear_operator_block_diag", "gen/linear_operator_block_lower_triangular", "gen/linear_operator_full_matrix", @@ -102,10 +98,8 @@ GEN_FILENAMES = [ "gen/linear_operator_toeplitz", "gen/linear_operator_util", "gen/linear_operator_zeros", - "gen/matmul_registrations", - "gen/registrations_util", + "gen/property_hint_util", "gen/slicing", - "gen/solve_registrations", ] [ diff --git a/tensorflow_probability/python/internal/backend/meta/gen_linear_operators.py b/tensorflow_probability/python/internal/backend/meta/gen_linear_operators.py index 7c915ceeaa..5c2cc38110 100644 --- a/tensorflow_probability/python/internal/backend/meta/gen_linear_operators.py +++ b/tensorflow_probability/python/internal/backend/meta/gen_linear_operators.py @@ -53,8 +53,9 @@ COMMENT_OUT = [ 'from tensorflow.python.util import dispatch', 'from tensorflow.python.util.tf_export', - 'from tensorflow.python.framework import ' + - 'tensor_conversion', + 'from tensorflow.python.framework import tensor\n', + 'from tensorflow.python.framework import ' + + 'tensor_conversion', 'from tensorflow.python.framework import tensor_util', '@tf_export', '@dispatch', @@ -195,6 +196,7 @@ def gen_module(module_name): 'np.issubdtype(\\1, np.complexfloating)', code) code = re.sub(r'([_a-zA-Z0-9.\[\]]+).is_integer', 'np.issubdtype(\\1, np.integer)', code) + code = 
code.replace('tensor.Tensor', 'np.ndarray') code = code.replace('array_ops.shape', 'prefer_static.shape') code = code.replace('array_ops.concat', 'prefer_static.concat') diff --git a/tensorflow_probability/python/internal/backend/numpy/BUILD b/tensorflow_probability/python/internal/backend/numpy/BUILD index 01435c76e0..48f890c4a0 100644 --- a/tensorflow_probability/python/internal/backend/numpy/BUILD +++ b/tensorflow_probability/python/internal/backend/numpy/BUILD @@ -479,7 +479,9 @@ py_test( "--test_mode=xla", # TODO(b/168718272): reduce_*([nan, nan], axis=0) (GPU) # histogram_fixed_width_bins fails with f32([0.]), [0.0, 0.0], 2 - "--xla_disabled=math.cumulative_logsumexp,math.reduce_min,math.reduce_max,histogram_fixed_width_bins", + ("--xla_disabled=math.cumulative_logsumexp,math.reduce_min,math.reduce_max,histogram_fixed_width_bins," + + # TODO(b/298426124): TF floomod GPU bug + "math.floormod"), ], main = "numpy_test.py", shard_count = 11, @@ -535,12 +537,8 @@ py_library( ) LINOP_FILES = [ - "adjoint_registrations", - "cholesky_registrations", - "inverse_registrations", "linear_operator_addition", "linear_operator_adjoint", - "linear_operator_algebra", "linear_operator_block_diag", "linear_operator_block_lower_triangular", "linear_operator_circulant", @@ -558,10 +556,8 @@ LINOP_FILES = [ "linear_operator_toeplitz", "linear_operator_util", "linear_operator_zeros", - "matmul_registrations", - "registrations_util", + "property_hint_util", "slicing", - "solve_registrations", ] [genrule( diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/adjoint_registrations.py b/tensorflow_probability/python/internal/backend/numpy/gen/adjoint_registrations.py deleted file mode 100644 index b34f66c3a5..0000000000 --- a/tensorflow_probability/python/internal/backend/numpy/gen/adjoint_registrations.py +++ /dev/null @@ -1,166 +0,0 @@ -# Copyright 2020 The TensorFlow Probability Authors. All Rights Reserved. -# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ -# THIS FILE IS AUTO-GENERATED BY `gen_linear_operators.py`. -# DO NOT MODIFY DIRECTLY. -# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ -# pylint: disable=g-import-not-at-top -# pylint: disable=g-direct-tensorflow-import -# pylint: disable=g-bad-import-order -# pylint: disable=unused-import -# pylint: disable=line-too-long -# pylint: disable=reimported -# pylint: disable=g-bool-id-comparison -# pylint: disable=g-statement-before-imports -# pylint: disable=bad-continuation -# pylint: disable=useless-import-alias -# pylint: disable=property-with-parameters -# pylint: disable=trailing-whitespace -# pylint: disable=g-inconsistent-quotes - -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Registrations for LinearOperator.adjoint.""" - -from tensorflow_probability.python.internal.backend.numpy import numpy_math as math_ops -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_adjoint -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_algebra -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_block_diag -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_circulant -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_diag -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_householder -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_identity -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_kronecker - - -# By default, return LinearOperatorAdjoint which switched the .matmul -# and .solve methods. -@linear_operator_algebra.RegisterAdjoint(linear_operator.LinearOperator) -def _adjoint_linear_operator(linop): - return linear_operator_adjoint.LinearOperatorAdjoint( - linop, - is_non_singular=linop.is_non_singular, - is_self_adjoint=linop.is_self_adjoint, - is_positive_definite=linop.is_positive_definite, - is_square=linop.is_square) - - -@linear_operator_algebra.RegisterAdjoint( - linear_operator_adjoint.LinearOperatorAdjoint) -def _adjoint_adjoint_linear_operator(linop): - return linop.operator - - -@linear_operator_algebra.RegisterAdjoint( - linear_operator_identity.LinearOperatorIdentity) -def _adjoint_identity(identity_operator): - return identity_operator - - -@linear_operator_algebra.RegisterAdjoint( - linear_operator_identity.LinearOperatorScaledIdentity) -def _adjoint_scaled_identity(identity_operator): - multiplier = identity_operator.multiplier - if np.issubdtype(multiplier.dtype, np.complexfloating): - multiplier = math_ops.conj(multiplier) - - return linear_operator_identity.LinearOperatorScaledIdentity( - num_rows=identity_operator._num_rows, # pylint: disable=protected-access - multiplier=multiplier, - is_non_singular=identity_operator.is_non_singular, - is_self_adjoint=identity_operator.is_self_adjoint, - is_positive_definite=identity_operator.is_positive_definite, - is_square=True) - - -@linear_operator_algebra.RegisterAdjoint( - linear_operator_diag.LinearOperatorDiag) -def _adjoint_diag(diag_operator): - diag = diag_operator.diag - if np.issubdtype(diag.dtype, np.complexfloating): - diag = math_ops.conj(diag) - - return linear_operator_diag.LinearOperatorDiag( - diag=diag, - is_non_singular=diag_operator.is_non_singular, - is_self_adjoint=diag_operator.is_self_adjoint, - is_positive_definite=diag_operator.is_positive_definite, - is_square=True) - - -@linear_operator_algebra.RegisterAdjoint( - linear_operator_block_diag.LinearOperatorBlockDiag) -def _adjoint_block_diag(block_diag_operator): - # We take the adjoint of each block on the diagonal. 
- return linear_operator_block_diag.LinearOperatorBlockDiag( - operators=[ - operator.adjoint() for operator in block_diag_operator.operators], - is_non_singular=block_diag_operator.is_non_singular, - is_self_adjoint=block_diag_operator.is_self_adjoint, - is_positive_definite=block_diag_operator.is_positive_definite, - is_square=True) - - -@linear_operator_algebra.RegisterAdjoint( - linear_operator_kronecker.LinearOperatorKronecker) -def _adjoint_kronecker(kronecker_operator): - # Adjoint of a Kronecker product is the Kronecker product - # of adjoints. - return linear_operator_kronecker.LinearOperatorKronecker( - operators=[ - operator.adjoint() for operator in kronecker_operator.operators], - is_non_singular=kronecker_operator.is_non_singular, - is_self_adjoint=kronecker_operator.is_self_adjoint, - is_positive_definite=kronecker_operator.is_positive_definite, - is_square=True) - - -@linear_operator_algebra.RegisterAdjoint( - linear_operator_circulant._BaseLinearOperatorCirculant) # pylint: disable=protected-access -def _adjoint_circulant(circulant_operator): - spectrum = circulant_operator.spectrum - if np.issubdtype(spectrum.dtype, np.complexfloating): - spectrum = math_ops.conj(spectrum) - - # Conjugating the spectrum is sufficient to get the adjoint. - return circulant_operator.__class__( - spectrum=spectrum, - is_non_singular=circulant_operator.is_non_singular, - is_self_adjoint=circulant_operator.is_self_adjoint, - is_positive_definite=circulant_operator.is_positive_definite, - is_square=True) - - -@linear_operator_algebra.RegisterAdjoint( - linear_operator_householder.LinearOperatorHouseholder) -def _adjoint_householder(householder_operator): - return householder_operator - -import numpy as np -from tensorflow_probability.python.internal.backend.numpy import linalg_impl as _linalg -from tensorflow_probability.python.internal.backend.numpy import ops as _ops -from tensorflow_probability.python.internal.backend.numpy.gen import tensor_shape - -from tensorflow_probability.python.internal.backend.numpy import private -distribution_util = private.LazyLoader( - "distribution_util", globals(), - "tensorflow_probability.substrates.numpy.internal.distribution_util") -tensorshape_util = private.LazyLoader( - "tensorshape_util", globals(), - "tensorflow_probability.substrates.numpy.internal.tensorshape_util") -prefer_static = private.LazyLoader( - "prefer_static", globals(), - "tensorflow_probability.substrates.numpy.internal.prefer_static") - diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/cholesky_registrations.py b/tensorflow_probability/python/internal/backend/numpy/gen/cholesky_registrations.py deleted file mode 100644 index 1260abcf83..0000000000 --- a/tensorflow_probability/python/internal/backend/numpy/gen/cholesky_registrations.py +++ /dev/null @@ -1,198 +0,0 @@ -# Copyright 2020 The TensorFlow Probability Authors. All Rights Reserved. -# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ -# THIS FILE IS AUTO-GENERATED BY `gen_linear_operators.py`. -# DO NOT MODIFY DIRECTLY. 
-# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ -# pylint: disable=g-import-not-at-top -# pylint: disable=g-direct-tensorflow-import -# pylint: disable=g-bad-import-order -# pylint: disable=unused-import -# pylint: disable=line-too-long -# pylint: disable=reimported -# pylint: disable=g-bool-id-comparison -# pylint: disable=g-statement-before-imports -# pylint: disable=bad-continuation -# pylint: disable=useless-import-alias -# pylint: disable=property-with-parameters -# pylint: disable=trailing-whitespace -# pylint: disable=g-inconsistent-quotes - -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Registrations for LinearOperator.cholesky.""" - -from tensorflow_probability.python.internal.backend.numpy import numpy_array as array_ops -from tensorflow_probability.python.internal.backend.numpy import linalg_impl as linalg_ops -from tensorflow_probability.python.internal.backend.numpy import numpy_math as math_ops -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_algebra -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_block_diag -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_composition -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_diag -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_identity -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_kronecker -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_lower_triangular -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_util - -LinearOperatorLowerTriangular = ( - linear_operator_lower_triangular.LinearOperatorLowerTriangular) - - -# By default, compute the Cholesky of the dense matrix, and return a -# LowerTriangular operator. Methods below specialize this registration. -@linear_operator_algebra.RegisterCholesky(linear_operator.LinearOperator) -def _cholesky_linear_operator(linop): - return LinearOperatorLowerTriangular( - linalg_ops.cholesky(linop.to_dense()), - is_non_singular=True, - is_self_adjoint=False, - is_square=True) - - -def _is_llt_product(linop): - """Determines if linop = L @ L.H for L = LinearOperatorLowerTriangular.""" - if len(linop.operators) != 2: - return False - if not linear_operator_util.is_aat_form(linop.operators): - return False - return isinstance(linop.operators[0], LinearOperatorLowerTriangular) - - -@linear_operator_algebra.RegisterCholesky( - linear_operator_composition.LinearOperatorComposition) -def _cholesky_linear_operator_composition(linop): - """Computes Cholesky(LinearOperatorComposition).""" - # L @ L.H will be handled with special code below. 
Why is L @ L.H the most - # important special case? - # Note that Diag @ Diag.H and Diag @ TriL and TriL @ Diag are already - # compressed to Diag or TriL by diag matmul - # registration. Similarly for Identity and ScaledIdentity. - # So these would not appear in a LinearOperatorComposition unless explicitly - # constructed as such. So the most important thing to check is L @ L.H. - if not _is_llt_product(linop): - return LinearOperatorLowerTriangular( - linalg_ops.cholesky(linop.to_dense()), - is_non_singular=True, - is_self_adjoint=False, - is_square=True) - - left_op = linop.operators[0] - - # left_op.is_positive_definite ==> op already has positive diag. So return it. - if left_op.is_positive_definite: - return left_op - - # Recall that the base class has already verified linop.is_positive_definite, - # else linop.cholesky() would have raised. - # So in particular, we know the diagonal has nonzero entries. - # In the generic case, we make op have positive diag by dividing each row - # by the sign of the diag. This is equivalent to setting A = L @ D where D is - # diag(sign(1 / L.diag_part())). Then A is lower triangular with positive diag - # and A @ A^H = L @ D @ D^H @ L^H = L @ L^H = linop. - # This also works for complex L, since sign(x + iy) = exp(i * angle(x + iy)). - diag_sign = array_ops.expand_dims(math_ops.sign(left_op.diag_part()), axis=-2) - return LinearOperatorLowerTriangular( - tril=left_op.tril / diag_sign, - is_non_singular=left_op.is_non_singular, - # L.is_self_adjoint ==> L is diagonal ==> L @ D is diagonal ==> SA - # L.is_self_adjoint is False ==> L not diagonal ==> L @ D not diag ... - is_self_adjoint=left_op.is_self_adjoint, - # L.is_positive_definite ==> L has positive diag ==> L = L @ D - # ==> (L @ D).is_positive_definite. - # L.is_positive_definite is False could result in L @ D being PD or not.. - # Consider L = [[1, 0], [-2, 1]] and quadratic form with x = [1, 1]. - # Note we will already return left_op if left_op.is_positive_definite - # above, but to be explicit write this below. - is_positive_definite=True if left_op.is_positive_definite else None, - is_square=True, - ) - - -@linear_operator_algebra.RegisterCholesky( - linear_operator_diag.LinearOperatorDiag) -def _cholesky_diag(diag_operator): - return linear_operator_diag.LinearOperatorDiag( - math_ops.sqrt(diag_operator.diag), - is_non_singular=True, - is_self_adjoint=True, - is_positive_definite=True, - is_square=True) - - -@linear_operator_algebra.RegisterCholesky( - linear_operator_identity.LinearOperatorIdentity) -def _cholesky_identity(identity_operator): - return linear_operator_identity.LinearOperatorIdentity( - num_rows=identity_operator._num_rows, # pylint: disable=protected-access - batch_shape=identity_operator.batch_shape, - dtype=identity_operator.dtype, - is_non_singular=True, - is_self_adjoint=True, - is_positive_definite=True, - is_square=True) - - -@linear_operator_algebra.RegisterCholesky( - linear_operator_identity.LinearOperatorScaledIdentity) -def _cholesky_scaled_identity(identity_operator): - return linear_operator_identity.LinearOperatorScaledIdentity( - num_rows=identity_operator._num_rows, # pylint: disable=protected-access - multiplier=math_ops.sqrt(identity_operator.multiplier), - is_non_singular=True, - is_self_adjoint=True, - is_positive_definite=True, - is_square=True) - - -@linear_operator_algebra.RegisterCholesky( - linear_operator_block_diag.LinearOperatorBlockDiag) -def _cholesky_block_diag(block_diag_operator): - # We take the cholesky of each block on the diagonal. 
- return linear_operator_block_diag.LinearOperatorBlockDiag( - operators=[ - operator.cholesky() for operator in block_diag_operator.operators], - is_non_singular=True, - is_self_adjoint=None, # Let the operators passed in decide. - is_square=True) - - -@linear_operator_algebra.RegisterCholesky( - linear_operator_kronecker.LinearOperatorKronecker) -def _cholesky_kronecker(kronecker_operator): - # Cholesky decomposition of a Kronecker product is the Kronecker product - # of cholesky decompositions. - return linear_operator_kronecker.LinearOperatorKronecker( - operators=[ - operator.cholesky() for operator in kronecker_operator.operators], - is_non_singular=True, - is_self_adjoint=None, # Let the operators passed in decide. - is_square=True) - -import numpy as np -from tensorflow_probability.python.internal.backend.numpy import linalg_impl as _linalg -from tensorflow_probability.python.internal.backend.numpy import ops as _ops -from tensorflow_probability.python.internal.backend.numpy.gen import tensor_shape - -from tensorflow_probability.python.internal.backend.numpy import private -distribution_util = private.LazyLoader( - "distribution_util", globals(), - "tensorflow_probability.substrates.numpy.internal.distribution_util") -tensorshape_util = private.LazyLoader( - "tensorshape_util", globals(), - "tensorflow_probability.substrates.numpy.internal.tensorshape_util") -prefer_static = private.LazyLoader( - "prefer_static", globals(), - "tensorflow_probability.substrates.numpy.internal.prefer_static") - diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/inverse_registrations.py b/tensorflow_probability/python/internal/backend/numpy/gen/inverse_registrations.py deleted file mode 100644 index d6549e0d19..0000000000 --- a/tensorflow_probability/python/internal/backend/numpy/gen/inverse_registrations.py +++ /dev/null @@ -1,257 +0,0 @@ -# Copyright 2020 The TensorFlow Probability Authors. All Rights Reserved. -# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ -# THIS FILE IS AUTO-GENERATED BY `gen_linear_operators.py`. -# DO NOT MODIFY DIRECTLY. -# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ -# pylint: disable=g-import-not-at-top -# pylint: disable=g-direct-tensorflow-import -# pylint: disable=g-bad-import-order -# pylint: disable=unused-import -# pylint: disable=line-too-long -# pylint: disable=reimported -# pylint: disable=g-bool-id-comparison -# pylint: disable=g-statement-before-imports -# pylint: disable=bad-continuation -# pylint: disable=useless-import-alias -# pylint: disable=property-with-parameters -# pylint: disable=trailing-whitespace -# pylint: disable=g-inconsistent-quotes - -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Registrations for LinearOperator.inverse.""" - -from tensorflow_probability.python.internal.backend.numpy import numpy_math as math_ops -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_addition -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_algebra -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_block_diag -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_block_lower_triangular -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_circulant -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_diag -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_full_matrix -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_householder -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_identity -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_inversion -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_kronecker - - -# By default, return LinearOperatorInversion which switched the .matmul -# and .solve methods. -@linear_operator_algebra.RegisterInverse(linear_operator.LinearOperator) -def _inverse_linear_operator(linop): - return linear_operator_inversion.LinearOperatorInversion( - linop, - is_non_singular=linop.is_non_singular, - is_self_adjoint=linop.is_self_adjoint, - is_positive_definite=linop.is_positive_definite, - is_square=linop.is_square) - - -@linear_operator_algebra.RegisterInverse( - linear_operator_inversion.LinearOperatorInversion) -def _inverse_inverse_linear_operator(linop_inversion): - return linop_inversion.operator - - -@linear_operator_algebra.RegisterInverse( - linear_operator_diag.LinearOperatorDiag) -def _inverse_diag(diag_operator): - return linear_operator_diag.LinearOperatorDiag( - 1. / diag_operator.diag, - is_non_singular=diag_operator.is_non_singular, - is_self_adjoint=diag_operator.is_self_adjoint, - is_positive_definite=diag_operator.is_positive_definite, - is_square=True) - - -@linear_operator_algebra.RegisterInverse( - linear_operator_identity.LinearOperatorIdentity) -def _inverse_identity(identity_operator): - return identity_operator - - -@linear_operator_algebra.RegisterInverse( - linear_operator_identity.LinearOperatorScaledIdentity) -def _inverse_scaled_identity(identity_operator): - return linear_operator_identity.LinearOperatorScaledIdentity( - num_rows=identity_operator._num_rows, # pylint: disable=protected-access - multiplier=1. / identity_operator.multiplier, - is_non_singular=identity_operator.is_non_singular, - is_self_adjoint=True, - is_positive_definite=identity_operator.is_positive_definite, - is_square=True) - - -@linear_operator_algebra.RegisterInverse( - linear_operator_block_diag.LinearOperatorBlockDiag) -def _inverse_block_diag(block_diag_operator): - # We take the inverse of each block on the diagonal. 
- return linear_operator_block_diag.LinearOperatorBlockDiag( - operators=[ - operator.inverse() for operator in block_diag_operator.operators], - is_non_singular=block_diag_operator.is_non_singular, - is_self_adjoint=block_diag_operator.is_self_adjoint, - is_positive_definite=block_diag_operator.is_positive_definite, - is_square=True) - - -@linear_operator_algebra.RegisterInverse( - linear_operator_block_lower_triangular.LinearOperatorBlockLowerTriangular) -def _inverse_block_lower_triangular(block_lower_triangular_operator): - """Inverse of LinearOperatorBlockLowerTriangular. - - We recursively apply the identity: - - ```none - |A 0|' = | A' 0| - |B C| |-C'BA' C'| - ``` - - where `A` is n-by-n, `B` is m-by-n, `C` is m-by-m, and `'` denotes inverse. - - This identity can be verified through multiplication: - - ```none - |A 0|| A' 0| - |B C||-C'BA' C'| - - = | AA' 0| - |BA'-CC'BA' CC'| - - = |I 0| - |0 I| - ``` - - Args: - block_lower_triangular_operator: Instance of - `LinearOperatorBlockLowerTriangular`. - - Returns: - block_lower_triangular_operator_inverse: Instance of - `LinearOperatorBlockLowerTriangular`, the inverse of - `block_lower_triangular_operator`. - """ - if len(block_lower_triangular_operator.operators) == 1: - return (linear_operator_block_lower_triangular. - LinearOperatorBlockLowerTriangular( - [[block_lower_triangular_operator.operators[0][0].inverse()]], - is_non_singular=block_lower_triangular_operator.is_non_singular, - is_self_adjoint=block_lower_triangular_operator.is_self_adjoint, - is_positive_definite=(block_lower_triangular_operator. - is_positive_definite), - is_square=True)) - - blockwise_dim = len(block_lower_triangular_operator.operators) - - # Calculate the inverse of the `LinearOperatorBlockLowerTriangular` - # representing all but the last row of `block_lower_triangular_operator` with - # a recursive call (the matrix `A'` in the docstring definition). - upper_left_inverse = ( - linear_operator_block_lower_triangular.LinearOperatorBlockLowerTriangular( - block_lower_triangular_operator.operators[:-1]).inverse()) - - bottom_row = block_lower_triangular_operator.operators[-1] - bottom_right_inverse = bottom_row[-1].inverse() - - # Find the bottom row of the inverse (equal to `[-C'BA', C']` in the docstring - # definition, where `C` is the bottom-right operator of - # `block_lower_triangular_operator` and `B` is the set of operators in the - # bottom row excluding `C`). To find `-C'BA'`, we first iterate over the - # column partitions of `A'`. - inverse_bottom_row = [] - for i in range(blockwise_dim - 1): - # Find the `i`-th block of `BA'`. - blocks = [] - for j in range(i, blockwise_dim - 1): - result = bottom_row[j].matmul(upper_left_inverse.operators[j][i]) - if not any(isinstance(result, op_type) - for op_type in linear_operator_addition.SUPPORTED_OPERATORS): - result = linear_operator_full_matrix.LinearOperatorFullMatrix( - result.to_dense()) - blocks.append(result) - - summed_blocks = linear_operator_addition.add_operators(blocks) - assert len(summed_blocks) == 1 - block = summed_blocks[0] - - # Find the `i`-th block of `-C'BA'`. - block = bottom_right_inverse.matmul(block) - block = linear_operator_identity.LinearOperatorScaledIdentity( - num_rows=bottom_right_inverse.domain_dimension_tensor(), - multiplier=_ops.cast(-1, dtype=block.dtype)).matmul(block) - inverse_bottom_row.append(block) - - # `C'` is the last block of the inverted linear operator. 
- inverse_bottom_row.append(bottom_right_inverse) - - return ( - linear_operator_block_lower_triangular.LinearOperatorBlockLowerTriangular( - upper_left_inverse.operators + [inverse_bottom_row], - is_non_singular=block_lower_triangular_operator.is_non_singular, - is_self_adjoint=block_lower_triangular_operator.is_self_adjoint, - is_positive_definite=(block_lower_triangular_operator. - is_positive_definite), - is_square=True)) - - -@linear_operator_algebra.RegisterInverse( - linear_operator_kronecker.LinearOperatorKronecker) -def _inverse_kronecker(kronecker_operator): - # Inverse decomposition of a Kronecker product is the Kronecker product - # of inverse decompositions. - return linear_operator_kronecker.LinearOperatorKronecker( - operators=[ - operator.inverse() for operator in kronecker_operator.operators], - is_non_singular=kronecker_operator.is_non_singular, - is_self_adjoint=kronecker_operator.is_self_adjoint, - is_positive_definite=kronecker_operator.is_positive_definite, - is_square=True) - - -@linear_operator_algebra.RegisterInverse( - linear_operator_circulant._BaseLinearOperatorCirculant) # pylint: disable=protected-access -def _inverse_circulant(circulant_operator): - # Inverting the spectrum is sufficient to get the inverse. - return circulant_operator.__class__( - spectrum=1. / circulant_operator.spectrum, - is_non_singular=circulant_operator.is_non_singular, - is_self_adjoint=circulant_operator.is_self_adjoint, - is_positive_definite=circulant_operator.is_positive_definite, - is_square=True, - input_output_dtype=circulant_operator.dtype) - - -@linear_operator_algebra.RegisterInverse( - linear_operator_householder.LinearOperatorHouseholder) -def _inverse_householder(householder_operator): - return householder_operator - -import numpy as np -from tensorflow_probability.python.internal.backend.numpy import linalg_impl as _linalg -from tensorflow_probability.python.internal.backend.numpy import ops as _ops -from tensorflow_probability.python.internal.backend.numpy.gen import tensor_shape - -from tensorflow_probability.python.internal.backend.numpy import private -distribution_util = private.LazyLoader( - "distribution_util", globals(), - "tensorflow_probability.substrates.numpy.internal.distribution_util") -tensorshape_util = private.LazyLoader( - "tensorshape_util", globals(), - "tensorflow_probability.substrates.numpy.internal.tensorshape_util") -prefer_static = private.LazyLoader( - "prefer_static", globals(), - "tensorflow_probability.substrates.numpy.internal.prefer_static") - diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator.py b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator.py index 1e6208b713..a7af244ec3 100644 --- a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator.py +++ b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator.py @@ -56,8 +56,8 @@ from tensorflow_probability.python.internal.backend.numpy import resource_variable_ops from tensorflow_probability.python.internal.backend.numpy import variables from tensorflow_probability.python.internal.backend.numpy import linalg_impl as linalg -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_algebra from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_util +from tensorflow_probability.python.internal.backend.numpy.gen import property_hint_util from tensorflow_probability.python.internal.backend.numpy.gen import slicing from absl import logging as 
logging from tensorflow_probability.python.internal.backend.numpy import data_structures @@ -67,6 +67,7 @@ from tensorflow_probability.python.internal.backend.numpy import variable_utils # from tensorflow.python.util.tf_export import tf_export + __all__ = ["LinearOperator"] @@ -691,7 +692,13 @@ def _check_input_dtype(self, arg): def _matmul(self, x, adjoint=False, adjoint_arg=False): raise NotImplementedError("_matmul is not implemented.") - def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"): + def matmul( + self, + x, + adjoint=False, + adjoint_arg=False, + name="matmul", + ): """Transform [batch] matrix `x` with left multiplication: `x --> Ax`. ```python @@ -731,8 +738,9 @@ def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"): "Operators are incompatible. Expected `x` to have dimension" " {} but got {}.".format( left_operator.domain_dimension, right_operator.range_dimension)) + with self._name_scope(name): # pylint: disable=not-callable - return linear_operator_algebra.matmul(left_operator, right_operator) + return self._linop_matmul(left_operator, right_operator) with self._name_scope(name): # pylint: disable=not-callable x = ops.convert_to_tensor(x, name="x") @@ -746,6 +754,54 @@ def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"): return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg) + def _linop_matmul( + self, left_operator: "LinearOperator", right_operator: "LinearOperator" + ) -> "LinearOperator": + # instance of linear_operator_identity.LinearOperatorIdentity + if hasattr(right_operator, "_ones_diag") and not hasattr( + right_operator, "multiplier" + ): + return left_operator + + # instance of linear_operator_zeros.LinearOperatorZeros + elif hasattr(right_operator, "_zeros_diag"): + if not right_operator.is_square or not left_operator.is_square: + raise ValueError( + "Matmul with non-square `LinearOperator`s or " + "non-square `LinearOperatorZeros` not supported at this time." + ) + return right_operator + + else: + # Generic matmul of two `LinearOperator`s. + is_square = property_hint_util.is_square(left_operator, right_operator) + is_non_singular = None + is_self_adjoint = None + is_positive_definite = None + + if is_square: + is_non_singular = property_hint_util.combined_non_singular_hint( + left_operator, right_operator + ) + # is_square can be None, so the explicit check for False is needed. + elif is_square is False: # pylint:disable=g-bool-id-comparison + is_non_singular = False + is_self_adjoint = False + is_positive_definite = False + + # LinearOperator outputs a LinearOperatorComposition instance, which + # inherits from LinearOperator. The inline import is necessary to avoid + # errors due to this cyclic dependency. 
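+      # The `LinearOperatorComposition` built below lazily represents the
+      # product `left_operator @ right_operator`, carrying the property hints
+      # computed above rather than forming a dense matrix.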
+ from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_composition # pylint: disable=g-import-not-at-top + + return linear_operator_composition.LinearOperatorComposition( + operators=[left_operator, right_operator], + is_non_singular=is_non_singular, + is_self_adjoint=is_self_adjoint, + is_positive_definite=is_positive_definite, + is_square=is_square, + ) + def __matmul__(self, other): return self.matmul(other) @@ -925,7 +981,7 @@ def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"): " {} but got {}.".format( left_operator.domain_dimension, right_operator.range_dimension)) with self._name_scope(name): # pylint: disable=not-callable - return linear_operator_algebra.solve(left_operator, right_operator) + return self._linop_solve(left_operator, right_operator) with self._name_scope(name): # pylint: disable=not-callable rhs = ops.convert_to_tensor( @@ -941,6 +997,48 @@ def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"): return self._solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg) + def _linop_solve( + self, left_operator: "LinearOperator", right_operator: "LinearOperator" + ) -> "LinearOperator": + # instance of linear_operator_identity.LinearOperatorIdentity + if hasattr(right_operator, "_ones_diag") and not hasattr( + right_operator, "multiplier" + ): + return left_operator.inverse() + + # Generic solve of two `LinearOperator`s. + is_square = property_hint_util.is_square(left_operator, right_operator) + is_non_singular = None + is_self_adjoint = None + is_positive_definite = None + + if is_square: + is_non_singular = property_hint_util.combined_non_singular_hint( + left_operator, right_operator + ) + elif is_square is False: # pylint:disable=g-bool-id-comparison + is_non_singular = False + is_self_adjoint = False + is_positive_definite = False + + # LinearOperator outputs a LinearOperatorComposition instance that contains + # a LinearOperatorInversion instance, both of which + # inherit from LinearOperator. The inline import is necessary to avoid + # errors due to this cyclic dependency. + from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_composition # pylint: disable=g-import-not-at-top + from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_inversion # pylint: disable=g-import-not-at-top + + return linear_operator_composition.LinearOperatorComposition( + operators=[ + linear_operator_inversion.LinearOperatorInversion(left_operator), + right_operator, + ], + is_non_singular=is_non_singular, + is_self_adjoint=is_self_adjoint, + is_positive_definite=is_positive_definite, + is_square=is_square, + ) + def _solvevec(self, rhs, adjoint=False): """Default implementation of _solvevec.""" rhs_mat = array_ops.expand_dims(rhs, axis=-1) @@ -997,7 +1095,7 @@ def solvevec(self, rhs, adjoint=False, name="solve"): return self._solvevec(rhs, adjoint=adjoint) - def adjoint(self, name="adjoint"): + def adjoint(self, name: str = "adjoint") -> "LinearOperator": """Returns the adjoint of the current `LinearOperator`. Given `A` representing this `LinearOperator`, return `A*`. @@ -1012,12 +1110,21 @@ def adjoint(self, name="adjoint"): if self.is_self_adjoint is True: # pylint: disable=g-bool-id-comparison return self with self._name_scope(name): # pylint: disable=not-callable - return linear_operator_algebra.adjoint(self) + return self._linop_adjoint() # self.H is equivalent to self.adjoint(). 
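  # For instance, `operator.H.matmul(x)` applies the adjoint (conjugate
  # transpose) of `operator` to `x`.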
H = property(adjoint, None) - def inverse(self, name="inverse"): + def _linop_adjoint(self) -> "LinearOperator": + from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_adjoint # pylint: disable=g-import-not-at-top + return linear_operator_adjoint.LinearOperatorAdjoint( + self, + is_non_singular=self.is_non_singular, + is_self_adjoint=self.is_self_adjoint, + is_positive_definite=self.is_positive_definite, + is_square=self.is_square) + + def inverse(self, name: str = "inverse") -> "LinearOperator": """Returns the Inverse of this `LinearOperator`. Given `A` representing this `LinearOperator`, return a `LinearOperator` @@ -1040,9 +1147,23 @@ def inverse(self, name="inverse"): "a singular matrix.") with self._name_scope(name): # pylint: disable=not-callable - return linear_operator_algebra.inverse(self) - - def cholesky(self, name="cholesky"): + return self._linop_inverse() + + def _linop_inverse(self) -> "LinearOperator": + # The in-line import is necessary because linear_operator_inversion.py + # depends on linear_operator.py. The in-line import works because the two + # files are now in the same build target, but if the import were at the top + # of the file there would be a partially-initialized module error caused by + # the code cycle. + from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_inversion # pylint: disable=g-import-not-at-top + return linear_operator_inversion.LinearOperatorInversion( + self, + is_non_singular=self.is_non_singular, + is_self_adjoint=self.is_self_adjoint, + is_positive_definite=self.is_positive_definite, + is_square=self.is_square) + + def cholesky(self, name: str = "cholesky") -> "LinearOperator": """Returns a Cholesky factor as a `LinearOperator`. Given `A` representing this `LinearOperator`, if `A` is positive definite @@ -1065,7 +1186,15 @@ def cholesky(self, name="cholesky"): raise ValueError("Cannot take the Cholesky decomposition: " "Not a positive definite self adjoint matrix.") with self._name_scope(name): # pylint: disable=not-callable - return linear_operator_algebra.cholesky(self) + return self._linop_cholesky() + + def _linop_cholesky(self) -> "LinearOperator": + from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_lower_triangular # pylint: disable=g-import-not-at-top + return linear_operator_lower_triangular.LinearOperatorLowerTriangular( + linalg_ops.cholesky(self.to_dense()), + is_non_singular=True, + is_self_adjoint=False, + is_square=True) def _to_dense(self): """Generic and often inefficient implementation. 
Override often.""" diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_adjoint.py b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_adjoint.py index 77a38c9ad9..69c51544cf 100644 --- a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_adjoint.py +++ b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_adjoint.py @@ -181,6 +181,9 @@ def operator(self): """The operator before taking the adjoint.""" return self._operator + def _linop_adjoint(self) -> linear_operator.LinearOperator: + return self.operator + def _assert_non_singular(self): return self.operator.assert_non_singular() diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_algebra.py b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_algebra.py deleted file mode 100644 index 891b885c6a..0000000000 --- a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_algebra.py +++ /dev/null @@ -1,442 +0,0 @@ -# Copyright 2020 The TensorFlow Probability Authors. All Rights Reserved. -# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ -# THIS FILE IS AUTO-GENERATED BY `gen_linear_operators.py`. -# DO NOT MODIFY DIRECTLY. -# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ -# pylint: disable=g-import-not-at-top -# pylint: disable=g-direct-tensorflow-import -# pylint: disable=g-bad-import-order -# pylint: disable=unused-import -# pylint: disable=line-too-long -# pylint: disable=reimported -# pylint: disable=g-bool-id-comparison -# pylint: disable=g-statement-before-imports -# pylint: disable=bad-continuation -# pylint: disable=useless-import-alias -# pylint: disable=property-with-parameters -# pylint: disable=trailing-whitespace -# pylint: disable=g-inconsistent-quotes - -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Registration mechanisms for various n-ary operations on LinearOperators.""" - -import itertools - -from tensorflow_probability.python.internal.backend.numpy import ops -from tensorflow_probability.python.internal.backend.numpy import tf_inspect - - -_ADJOINTS = {} -_CHOLESKY_DECOMPS = {} -_MATMUL = {} -_SOLVE = {} -_INVERSES = {} - - -def _registered_function(type_list, registry): - """Given a list of classes, finds the most specific function registered.""" - enumerated_hierarchies = [enumerate(tf_inspect.getmro(t)) for t in type_list] - # Get all possible combinations of hierarchies. 
- cls_combinations = list(itertools.product(*enumerated_hierarchies)) - - def hierarchy_distance(cls_combination): - candidate_distance = sum(c[0] for c in cls_combination) - if tuple(c[1] for c in cls_combination) in registry: - return candidate_distance - return 10000 - - registered_combination = min(cls_combinations, key=hierarchy_distance) - return registry.get(tuple(r[1] for r in registered_combination), None) - - -def _registered_adjoint(type_a): - """Get the Adjoint function registered for class a.""" - return _registered_function([type_a], _ADJOINTS) - - -def _registered_cholesky(type_a): - """Get the Cholesky function registered for class a.""" - return _registered_function([type_a], _CHOLESKY_DECOMPS) - - -def _registered_matmul(type_a, type_b): - """Get the Matmul function registered for classes a and b.""" - return _registered_function([type_a, type_b], _MATMUL) - - -def _registered_solve(type_a, type_b): - """Get the Solve function registered for classes a and b.""" - return _registered_function([type_a, type_b], _SOLVE) - - -def _registered_inverse(type_a): - """Get the Cholesky function registered for class a.""" - return _registered_function([type_a], _INVERSES) - - -def adjoint(lin_op_a, name=None): - """Get the adjoint associated to lin_op_a. - - Args: - lin_op_a: The LinearOperator to take the adjoint of. - name: Name to use for this operation. - - Returns: - A LinearOperator that represents the adjoint of `lin_op_a`. - - Raises: - NotImplementedError: If no Adjoint method is defined for the LinearOperator - type of `lin_op_a`. - """ - adjoint_fn = _registered_adjoint(type(lin_op_a)) - if adjoint_fn is None: - raise ValueError("No adjoint registered for {}".format( - type(lin_op_a))) - - with ops.name_scope(name, "Adjoint"): - return adjoint_fn(lin_op_a) - - -def cholesky(lin_op_a, name=None): - """Get the Cholesky factor associated to lin_op_a. - - Args: - lin_op_a: The LinearOperator to decompose. - name: Name to use for this operation. - - Returns: - A LinearOperator that represents the lower Cholesky factor of `lin_op_a`. - - Raises: - NotImplementedError: If no Cholesky method is defined for the LinearOperator - type of `lin_op_a`. - """ - cholesky_fn = _registered_cholesky(type(lin_op_a)) - if cholesky_fn is None: - raise ValueError("No cholesky decomposition registered for {}".format( - type(lin_op_a))) - - with ops.name_scope(name, "Cholesky"): - return cholesky_fn(lin_op_a) - - -def matmul(lin_op_a, lin_op_b, name=None): - """Compute lin_op_a.matmul(lin_op_b). - - Args: - lin_op_a: The LinearOperator on the left. - lin_op_b: The LinearOperator on the right. - name: Name to use for this operation. - - Returns: - A LinearOperator that represents the matmul between `lin_op_a` and - `lin_op_b`. - - Raises: - NotImplementedError: If no matmul method is defined between types of - `lin_op_a` and `lin_op_b`. - """ - matmul_fn = _registered_matmul(type(lin_op_a), type(lin_op_b)) - if matmul_fn is None: - raise ValueError("No matmul registered for {}.matmul({})".format( - type(lin_op_a), type(lin_op_b))) - - with ops.name_scope(name, "Matmul"): - return matmul_fn(lin_op_a, lin_op_b) - - -def solve(lin_op_a, lin_op_b, name=None): - """Compute lin_op_a.solve(lin_op_b). - - Args: - lin_op_a: The LinearOperator on the left. - lin_op_b: The LinearOperator on the right. - name: Name to use for this operation. - - Returns: - A LinearOperator that represents the solve between `lin_op_a` and - `lin_op_b`. 
- - Raises: - NotImplementedError: If no solve method is defined between types of - `lin_op_a` and `lin_op_b`. - """ - solve_fn = _registered_solve(type(lin_op_a), type(lin_op_b)) - if solve_fn is None: - raise ValueError("No solve registered for {}.solve({})".format( - type(lin_op_a), type(lin_op_b))) - - with ops.name_scope(name, "Solve"): - return solve_fn(lin_op_a, lin_op_b) - - -def inverse(lin_op_a, name=None): - """Get the Inverse associated to lin_op_a. - - Args: - lin_op_a: The LinearOperator to decompose. - name: Name to use for this operation. - - Returns: - A LinearOperator that represents the inverse of `lin_op_a`. - - Raises: - NotImplementedError: If no Inverse method is defined for the LinearOperator - type of `lin_op_a`. - """ - inverse_fn = _registered_inverse(type(lin_op_a)) - if inverse_fn is None: - raise ValueError("No inverse registered for {}".format( - type(lin_op_a))) - - with ops.name_scope(name, "Inverse"): - return inverse_fn(lin_op_a) - - -class RegisterAdjoint: - """Decorator to register an Adjoint implementation function. - - Usage: - - @linear_operator_algebra.RegisterAdjoint(lin_op.LinearOperatorIdentity) - def _adjoint_identity(lin_op_a): - # Return the identity matrix. - """ - - def __init__(self, lin_op_cls_a): - """Initialize the LinearOperator registrar. - - Args: - lin_op_cls_a: the class of the LinearOperator to decompose. - """ - self._key = (lin_op_cls_a,) - - def __call__(self, adjoint_fn): - """Perform the Adjoint registration. - - Args: - adjoint_fn: The function to use for the Adjoint. - - Returns: - adjoint_fn - - Raises: - TypeError: if adjoint_fn is not a callable. - ValueError: if a Adjoint function has already been registered for - the given argument classes. - """ - if not callable(adjoint_fn): - raise TypeError( - "adjoint_fn must be callable, received: {}".format(adjoint_fn)) - if self._key in _ADJOINTS: - raise ValueError("Adjoint({}) has already been registered to: {}".format( - self._key[0].__name__, _ADJOINTS[self._key])) - _ADJOINTS[self._key] = adjoint_fn - return adjoint_fn - - -class RegisterCholesky: - """Decorator to register a Cholesky implementation function. - - Usage: - - @linear_operator_algebra.RegisterCholesky(lin_op.LinearOperatorIdentity) - def _cholesky_identity(lin_op_a): - # Return the identity matrix. - """ - - def __init__(self, lin_op_cls_a): - """Initialize the LinearOperator registrar. - - Args: - lin_op_cls_a: the class of the LinearOperator to decompose. - """ - self._key = (lin_op_cls_a,) - - def __call__(self, cholesky_fn): - """Perform the Cholesky registration. - - Args: - cholesky_fn: The function to use for the Cholesky. - - Returns: - cholesky_fn - - Raises: - TypeError: if cholesky_fn is not a callable. - ValueError: if a Cholesky function has already been registered for - the given argument classes. - """ - if not callable(cholesky_fn): - raise TypeError( - "cholesky_fn must be callable, received: {}".format(cholesky_fn)) - if self._key in _CHOLESKY_DECOMPS: - raise ValueError("Cholesky({}) has already been registered to: {}".format( - self._key[0].__name__, _CHOLESKY_DECOMPS[self._key])) - _CHOLESKY_DECOMPS[self._key] = cholesky_fn - return cholesky_fn - - -class RegisterMatmul: - """Decorator to register a Matmul implementation function. - - Usage: - - @linear_operator_algebra.RegisterMatmul( - lin_op.LinearOperatorIdentity, - lin_op.LinearOperatorIdentity) - def _matmul_identity(a, b): - # Return the identity matrix. 
- """ - - def __init__(self, lin_op_cls_a, lin_op_cls_b): - """Initialize the LinearOperator registrar. - - Args: - lin_op_cls_a: the class of the LinearOperator to multiply. - lin_op_cls_b: the class of the second LinearOperator to multiply. - """ - self._key = (lin_op_cls_a, lin_op_cls_b) - - def __call__(self, matmul_fn): - """Perform the Matmul registration. - - Args: - matmul_fn: The function to use for the Matmul. - - Returns: - matmul_fn - - Raises: - TypeError: if matmul_fn is not a callable. - ValueError: if a Matmul function has already been registered for - the given argument classes. - """ - if not callable(matmul_fn): - raise TypeError( - "matmul_fn must be callable, received: {}".format(matmul_fn)) - if self._key in _MATMUL: - raise ValueError("Matmul({}, {}) has already been registered.".format( - self._key[0].__name__, - self._key[1].__name__)) - _MATMUL[self._key] = matmul_fn - return matmul_fn - - -class RegisterSolve: - """Decorator to register a Solve implementation function. - - Usage: - - @linear_operator_algebra.RegisterSolve( - lin_op.LinearOperatorIdentity, - lin_op.LinearOperatorIdentity) - def _solve_identity(a, b): - # Return the identity matrix. - """ - - def __init__(self, lin_op_cls_a, lin_op_cls_b): - """Initialize the LinearOperator registrar. - - Args: - lin_op_cls_a: the class of the LinearOperator that is computing solve. - lin_op_cls_b: the class of the second LinearOperator to solve. - """ - self._key = (lin_op_cls_a, lin_op_cls_b) - - def __call__(self, solve_fn): - """Perform the Solve registration. - - Args: - solve_fn: The function to use for the Solve. - - Returns: - solve_fn - - Raises: - TypeError: if solve_fn is not a callable. - ValueError: if a Solve function has already been registered for - the given argument classes. - """ - if not callable(solve_fn): - raise TypeError( - "solve_fn must be callable, received: {}".format(solve_fn)) - if self._key in _SOLVE: - raise ValueError("Solve({}, {}) has already been registered.".format( - self._key[0].__name__, - self._key[1].__name__)) - _SOLVE[self._key] = solve_fn - return solve_fn - - -class RegisterInverse: - """Decorator to register an Inverse implementation function. - - Usage: - - @linear_operator_algebra.RegisterInverse(lin_op.LinearOperatorIdentity) - def _inverse_identity(lin_op_a): - # Return the identity matrix. - """ - - def __init__(self, lin_op_cls_a): - """Initialize the LinearOperator registrar. - - Args: - lin_op_cls_a: the class of the LinearOperator to decompose. - """ - self._key = (lin_op_cls_a,) - - def __call__(self, inverse_fn): - """Perform the Inverse registration. - - Args: - inverse_fn: The function to use for the Inverse. - - Returns: - inverse_fn - - Raises: - TypeError: if inverse_fn is not a callable. - ValueError: if a Inverse function has already been registered for - the given argument classes. 
- """ - if not callable(inverse_fn): - raise TypeError( - "inverse_fn must be callable, received: {}".format(inverse_fn)) - if self._key in _INVERSES: - raise ValueError("Inverse({}) has already been registered to: {}".format( - self._key[0].__name__, _INVERSES[self._key])) - _INVERSES[self._key] = inverse_fn - return inverse_fn - -import numpy as np -from tensorflow_probability.python.internal.backend.numpy import linalg_impl as _linalg -from tensorflow_probability.python.internal.backend.numpy import ops as _ops -from tensorflow_probability.python.internal.backend.numpy.gen import tensor_shape - -from tensorflow_probability.python.internal.backend.numpy import private -distribution_util = private.LazyLoader( - "distribution_util", globals(), - "tensorflow_probability.substrates.numpy.internal.distribution_util") -tensorshape_util = private.LazyLoader( - "tensorshape_util", globals(), - "tensorflow_probability.substrates.numpy.internal.tensorshape_util") -prefer_static = private.LazyLoader( - "prefer_static", globals(), - "tensorflow_probability.substrates.numpy.internal.prefer_static") - diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_block_diag.py b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_block_diag.py index 5342e615a9..7f5fe18bba 100644 --- a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_block_diag.py +++ b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_block_diag.py @@ -43,8 +43,8 @@ from tensorflow_probability.python.internal.backend.numpy import debugging as check_ops from tensorflow_probability.python.internal.backend.numpy import control_flow as control_flow_ops from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_algebra from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_util +from tensorflow_probability.python.internal.backend.numpy.gen import property_hint_util # from tensorflow.python.util.tf_export import tf_export __all__ = ["LinearOperatorBlockDiag"] @@ -312,6 +312,75 @@ def _shape_tensor(self): return prefer_static.concat((batch_shape, matrix_shape), 0) + def _linop_adjoint(self) -> "LinearOperatorBlockDiag": + # We take the adjoint of each block on the diagonal. + return LinearOperatorBlockDiag( + operators=[operator.adjoint() for operator in self.operators], + is_non_singular=self.is_non_singular, + is_self_adjoint=self.is_self_adjoint, + is_positive_definite=self.is_positive_definite, + is_square=True) + + def _linop_cholesky(self) -> "LinearOperatorBlockDiag": + # We take the cholesky of each block on the diagonal. + return LinearOperatorBlockDiag( + operators=[operator.cholesky() for operator in self.operators], + is_non_singular=True, + is_self_adjoint=None, # Let the operators passed in decide. + is_square=True) + + def _linop_inverse(self) -> "LinearOperatorBlockDiag": + # We take the inverse of each block on the diagonal. 
+    return LinearOperatorBlockDiag(
+        operators=[
+            operator.inverse() for operator in self.operators],
+        is_non_singular=self.is_non_singular,
+        is_self_adjoint=self.is_self_adjoint,
+        is_positive_definite=self.is_positive_definite,
+        is_square=True)
+
+  def _linop_matmul(
+      self,
+      left_operator: "LinearOperatorBlockDiag",
+      right_operator: linear_operator.LinearOperator,
+  ) -> linear_operator.LinearOperator:
+    if isinstance(right_operator, LinearOperatorBlockDiag):
+      return LinearOperatorBlockDiag(
+          operators=[
+              o1.matmul(o2) for o1, o2 in zip(
+                  left_operator.operators, right_operator.operators)],
+          is_non_singular=property_hint_util.combined_non_singular_hint(
+              left_operator, right_operator),
+          # In general, a product of self-adjoint positive-definite
+          # block diagonal matrices is not self-adjoint.
+          is_self_adjoint=None,
+          # In general, a product of positive-definite block diagonal
+          # matrices is not positive-definite.
+          is_positive_definite=None,
+          is_square=True)
+    return super()._linop_matmul(left_operator, right_operator)
+
+  def _linop_solve(
+      self,
+      left_operator: "LinearOperatorBlockDiag",
+      right_operator: linear_operator.LinearOperator,
+  ) -> linear_operator.LinearOperator:
+    if isinstance(right_operator, LinearOperatorBlockDiag):
+      return LinearOperatorBlockDiag(
+          operators=[
+              o1.solve(o2) for o1, o2 in zip(
+                  left_operator.operators, right_operator.operators)],
+          is_non_singular=property_hint_util.combined_non_singular_hint(
+              left_operator, right_operator),
+          # In general, a solve of self-adjoint positive-definite block
+          # diagonal matrices is not self-adjoint.
+          is_self_adjoint=None,
+          # In general, a solve of positive-definite block diagonal matrices
+          # is not positive-definite.
+          is_positive_definite=None,
+          is_square=True)
+    return super()._linop_solve(left_operator, right_operator)
+
  # TODO(b/188080761): Add a more efficient implementation of `cond` that
  # constructs the condition number from the blockwise singular values.
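
The blockwise rules added above rely on the identity that a product or solve of
two conforming block-diagonal operators acts block by block:
diag(A1, B1) @ diag(A2, B2) == diag(A1 @ A2, B1 @ B2), and likewise for solve.
A minimal NumPy sketch of that identity, using hypothetical 2x2 blocks (not
part of this patch):

import numpy as np

def block_diag(x, y):
  # Dense block-diagonal matrix [[x, 0], [0, y]].
  return np.block([[x, np.zeros((x.shape[0], y.shape[1]))],
                   [np.zeros((y.shape[0], x.shape[1])), y]])

a1 = np.array([[2., 0.], [1., 3.]])
b1 = np.array([[1., 4.], [0., 2.]])
a2 = np.array([[1., 1.], [0., 1.]])
b2 = np.array([[3., 0.], [2., 1.]])

# Full-matrix matmul/solve agree with the blockwise versions, which is what
# _linop_matmul and _linop_solve exploit for LinearOperatorBlockDiag.
assert np.allclose(block_diag(a1, b1) @ block_diag(a2, b2),
                   block_diag(a1 @ a2, b1 @ b2))
assert np.allclose(np.linalg.solve(block_diag(a1, b1), block_diag(a2, b2)),
                   block_diag(np.linalg.solve(a1, a2), np.linalg.solve(b1, b2)))
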
@@ -378,7 +447,7 @@ def _check_operators_agree(r, l, message): o1.domain_dimension, o2.range_dimension)) with self._name_scope(name): # pylint: disable=not-callable - return linear_operator_algebra.matmul(left_operator, right_operator) + return self._linop_matmul(left_operator, right_operator) with self._name_scope(name): # pylint: disable=not-callable arg_dim = -1 if adjoint_arg else -2 @@ -575,7 +644,7 @@ def _check_operators_agree(r, l, message): o1.domain_dimension, o2.range_dimension)) with self._name_scope(name): # pylint: disable=not-callable - return linear_operator_algebra.solve(left_operator, right_operator) + return self._linop_solve(left_operator, right_operator) with self._name_scope(name): # pylint: disable=not-callable block_dimensions = (self._block_domain_dimensions() if adjoint diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_block_lower_triangular.py b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_block_lower_triangular.py index 6ceeff9e74..3ab120804c 100644 --- a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_block_lower_triangular.py +++ b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_block_lower_triangular.py @@ -45,7 +45,9 @@ from tensorflow_probability.python.internal.backend.numpy import numpy_math as math_ops from tensorflow_probability.python.internal.backend.numpy import linalg_impl as linalg from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_algebra +from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_addition +from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_full_matrix +from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_identity from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_util from tensorflow_probability.python.internal.backend.numpy import nest # from tensorflow.python.util.tf_export import tf_export @@ -418,6 +420,94 @@ def _shape_tensor(self): return prefer_static.concat((batch_shape, matrix_shape), 0) + def _linop_inverse(self) -> "LinearOperatorBlockLowerTriangular": + """Inverse of LinearOperatorBlockLowerTriangular. + + We recursively apply the identity: + + ```none + |A 0|' = | A' 0| + |B C| |-C'BA' C'| + ``` + + where `A` is n-by-n, `B` is m-by-n, + `C` is m-by-m, and `'` denotes inverse. + + This identity can be verified through multiplication: + + ```none + |A 0|| A' 0| + |B C||-C'BA' C'| + + = | AA' 0| + |BA'-CC'BA' CC'| + + = |I 0| + |0 I| + ``` + Returns: + A 'LinearOperatorBlockLowerTriangular'. + """ + if len(self.operators) == 1: + return (LinearOperatorBlockLowerTriangular( + [[self.operators[0][0].inverse()]], + is_non_singular=self.is_non_singular, + is_self_adjoint=self.is_self_adjoint, + is_positive_definite=(self. + is_positive_definite), + is_square=True)) + + blockwise_dim = len(self.operators) + + # Calculate the inverse of the `LinearOperatorBlockLowerTriangular` + # representing all but the last row of `self` with + # a recursive call (the matrix `A'` in the docstring definition). 
+ upper_left_inverse = ( + LinearOperatorBlockLowerTriangular(self.operators[:-1]).inverse()) + + bottom_row = self.operators[-1] + bottom_right_inverse = bottom_row[-1].inverse() + + # Find the bottom row of the inverse (equal to `[-C'BA', C']` + # in the docstring definition, where `C` is the bottom-right operator of + # `self` and `B` is the set of operators in the + # bottom row excluding `C`). To find `-C'BA'`, we first iterate over the + # column partitions of `A'`. + inverse_bottom_row = [] + for i in range(blockwise_dim - 1): + # Find the `i`-th block of `BA'`. + blocks = [] + for j in range(i, blockwise_dim - 1): + result = bottom_row[j].matmul(upper_left_inverse.operators[j][i]) + if not any( + isinstance(result, op_type) + for op_type in linear_operator_addition.SUPPORTED_OPERATORS + ): + result = linear_operator_full_matrix.LinearOperatorFullMatrix( + result.to_dense()) + blocks.append(result) + + summed_blocks = linear_operator_addition.add_operators(blocks) + assert len(summed_blocks) == 1 + block = summed_blocks[0] + + # Find the `i`-th block of `-C'BA'`. + block = bottom_right_inverse.matmul(block) + block = linear_operator_identity.LinearOperatorScaledIdentity( + num_rows=bottom_right_inverse.domain_dimension_tensor(), + multiplier=_ops.cast(-1, dtype=block.dtype)).matmul(block) + inverse_bottom_row.append(block) + + # `C'` is the last block of the inverted linear operator. + inverse_bottom_row.append(bottom_right_inverse) + + return (LinearOperatorBlockLowerTriangular( + upper_left_inverse.operators + [inverse_bottom_row], + is_non_singular=self.is_non_singular, + is_self_adjoint=self.is_self_adjoint, + is_positive_definite=(self.is_positive_definite), + is_square=True)) + def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"): """Transform [batch] matrix `x` with left multiplication: `x --> Ax`. @@ -461,7 +551,7 @@ class docstring for definition of shape compatibility. 
" {} but got {}.".format( left_operator.domain_dimension, right_operator.range_dimension)) with self._name_scope(name): # pylint: disable=not-callable - return linear_operator_algebra.matmul(left_operator, right_operator) + return self._linop_matmul(left_operator, right_operator) with self._name_scope(name): # pylint: disable=not-callable arg_dim = -1 if adjoint_arg else -2 @@ -700,7 +790,7 @@ def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"): " {} but got {}.".format( left_operator.domain_dimension, right_operator.range_dimension)) with self._name_scope(name): # pylint: disable=not-callable - return linear_operator_algebra.solve(left_operator, right_operator) + return self._linop_solve(left_operator, right_operator) with self._name_scope(name): # pylint: disable=not-callable block_dimensions = (self._block_domain_dimensions() if adjoint diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_circulant.py b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_circulant.py index a8302c9109..6b7c6f0196 100644 --- a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_circulant.py +++ b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_circulant.py @@ -37,6 +37,7 @@ from tensorflow_probability.python.internal.backend.numpy import dtype as dtypes from tensorflow_probability.python.internal.backend.numpy import ops +# from tensorflow.python.framework import tensor # from tensorflow.python.framework import tensor_conversion from tensorflow_probability.python.internal.backend.numpy.gen import tensor_shape from tensorflow_probability.python.internal.backend.numpy import numpy_array as array_ops @@ -57,6 +58,7 @@ from tensorflow_probability.python.internal.backend.numpy import linalg_impl as linalg from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_util +from tensorflow_probability.python.internal.backend.numpy.gen import property_hint_util from tensorflow_probability.python.internal.backend.numpy import numpy_signal as fft_ops # from tensorflow.python.util.tf_export import tf_export @@ -199,13 +201,13 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator): """ def __init__(self, - spectrum, - block_depth, + spectrum: np.ndarray, + block_depth: int, input_output_dtype=dtypes.complex64, - is_non_singular=None, - is_self_adjoint=None, - is_positive_definite=None, - is_square=True, + is_non_singular: bool = None, + is_self_adjoint: bool = None, + is_positive_definite: bool = None, + is_square: bool = True, parameters=None, name="LinearOperatorCirculant"): r"""Initialize an `_BaseLinearOperatorCirculant`. @@ -334,12 +336,78 @@ def _block_shape_tensor(self, spectrum_shape=None): if spectrum_shape is None else spectrum_shape) return spectrum_shape[-self.block_depth:] + def _linop_adjoint(self) -> "_BaseLinearOperatorCirculant": + spectrum = self.spectrum + if np.issubdtype(spectrum.dtype, np.complexfloating): + spectrum = math_ops.conj(spectrum) + + # Conjugating the spectrum is sufficient to get the adjoint. + return _BaseLinearOperatorCirculant( + spectrum=spectrum, + block_depth=self.block_depth, + is_non_singular=self.is_non_singular, + is_self_adjoint=self.is_self_adjoint, + is_positive_definite=self.is_positive_definite, + is_square=True) + + def _linop_inverse(self) -> "_BaseLinearOperatorCirculant": + return _BaseLinearOperatorCirculant( + spectrum=1. 
/ self.spectrum, + block_depth=self.block_depth, + is_non_singular=self.is_non_singular, + is_self_adjoint=self.is_self_adjoint, + is_positive_definite=self.is_positive_definite, + is_square=True, + input_output_dtype=self.dtype) + + def _linop_matmul( + self, + left_operator: "_BaseLinearOperatorCirculant", + right_operator: linear_operator.LinearOperator, + ) -> linear_operator.LinearOperator: + if (not isinstance(right_operator, _BaseLinearOperatorCirculant) + or not isinstance(left_operator, type(right_operator))): + return super()._linop_matmul(left_operator, right_operator) + + return _BaseLinearOperatorCirculant( + spectrum=left_operator.spectrum * right_operator.spectrum, + block_depth=left_operator.block_depth, + is_non_singular=property_hint_util.combined_non_singular_hint( + left_operator, right_operator), + is_self_adjoint=property_hint_util.combined_commuting_self_adjoint_hint( + left_operator, right_operator), + is_positive_definite=( + property_hint_util.combined_commuting_positive_definite_hint( + left_operator, right_operator)), + is_square=True) + + def _linop_solve( + self, + left_operator: "_BaseLinearOperatorCirculant", + right_operator: linear_operator.LinearOperator, + ) -> linear_operator.LinearOperator: + if (not isinstance(right_operator, _BaseLinearOperatorCirculant) + or not isinstance(left_operator, type(right_operator))): + return super()._linop_solve(left_operator, right_operator) + + return _BaseLinearOperatorCirculant( + spectrum=right_operator.spectrum / left_operator.spectrum, + block_depth=left_operator.block_depth, + is_non_singular=property_hint_util.combined_non_singular_hint( + left_operator, right_operator), + is_self_adjoint=property_hint_util.combined_commuting_self_adjoint_hint( + left_operator, right_operator), + is_positive_definite=( + property_hint_util.combined_commuting_positive_definite_hint( + left_operator, right_operator)), + is_square=True) + @property def block_shape(self): return tensor_shape.TensorShape(self.spectrum.shape)[-self.block_depth:] @property - def spectrum(self): + def spectrum(self) -> np.ndarray: return self._spectrum def _vectorize_then_blockify(self, matrix): @@ -888,12 +956,12 @@ class LinearOperatorCirculant(_BaseLinearOperatorCirculant): """ def __init__(self, - spectrum, + spectrum: np.ndarray, input_output_dtype=dtypes.complex64, - is_non_singular=None, - is_self_adjoint=None, - is_positive_definite=None, - is_square=True, + is_non_singular: bool = None, + is_self_adjoint: bool = None, + is_positive_definite: bool = None, + is_square: bool = True, name="LinearOperatorCirculant"): r"""Initialize an `LinearOperatorCirculant`. @@ -952,6 +1020,47 @@ def __init__(self, parameters=parameters, name=name) + def _linop_adjoint(self) -> "LinearOperatorCirculant": + spectrum = self.spectrum + if np.issubdtype(spectrum.dtype, np.complexfloating): + spectrum = math_ops.conj(spectrum) + + # Conjugating the spectrum is sufficient to get the adjoint. + return LinearOperatorCirculant( + spectrum=spectrum, + is_non_singular=self.is_non_singular, + is_self_adjoint=self.is_self_adjoint, + is_positive_definite=self.is_positive_definite, + is_square=True) + + def _linop_inverse(self) -> "LinearOperatorCirculant": + return LinearOperatorCirculant( + spectrum=1. 
/ self.spectrum, + is_non_singular=self.is_non_singular, + is_self_adjoint=self.is_self_adjoint, + is_positive_definite=self.is_positive_definite, + is_square=True, + input_output_dtype=self.dtype) + + def _linop_solve( + self, + left_operator: "LinearOperatorCirculant", + right_operator: linear_operator.LinearOperator, + ) -> linear_operator.LinearOperator: + if not isinstance(right_operator, LinearOperatorCirculant): + return super()._linop_solve(left_operator, right_operator) + + return LinearOperatorCirculant( + spectrum=right_operator.spectrum / left_operator.spectrum, + is_non_singular=property_hint_util.combined_non_singular_hint( + left_operator, right_operator), + is_self_adjoint=property_hint_util.combined_commuting_self_adjoint_hint( + left_operator, right_operator), + is_positive_definite=( + property_hint_util.combined_commuting_positive_definite_hint( + left_operator, right_operator)), + is_square=True) + # @tf_export("linalg.LinearOperatorCirculant2D") # @linear_operator.make_composite_tensor @@ -1076,12 +1185,12 @@ class LinearOperatorCirculant2D(_BaseLinearOperatorCirculant): """ def __init__(self, - spectrum, + spectrum: np.ndarray, input_output_dtype=dtypes.complex64, - is_non_singular=None, - is_self_adjoint=None, - is_positive_definite=None, - is_square=True, + is_non_singular: bool = None, + is_self_adjoint: bool = None, + is_positive_definite: bool = None, + is_square: bool = True, name="LinearOperatorCirculant2D"): r"""Initialize an `LinearOperatorCirculant2D`. @@ -1140,6 +1249,47 @@ def __init__(self, parameters=parameters, name=name) + def _linop_adjoint(self) -> "LinearOperatorCirculant2D": + spectrum = self.spectrum + if np.issubdtype(spectrum.dtype, np.complexfloating): + spectrum = math_ops.conj(spectrum) + + # Conjugating the spectrum is sufficient to get the adjoint. + return LinearOperatorCirculant2D( + spectrum=spectrum, + is_non_singular=self.is_non_singular, + is_self_adjoint=self.is_self_adjoint, + is_positive_definite=self.is_positive_definite, + is_square=True) + + def _linop_inverse(self) -> "LinearOperatorCirculant2D": + return LinearOperatorCirculant2D( + spectrum=1. 
/ self.spectrum, + is_non_singular=self.is_non_singular, + is_self_adjoint=self.is_self_adjoint, + is_positive_definite=self.is_positive_definite, + is_square=True, + input_output_dtype=self.dtype) + + def _linop_solve( + self, + left_operator: "LinearOperatorCirculant2D", + right_operator: linear_operator.LinearOperator, + ) -> linear_operator.LinearOperator: + if not isinstance(right_operator, LinearOperatorCirculant2D): + return super()._linop_solve(left_operator, right_operator) + + return LinearOperatorCirculant2D( + spectrum=right_operator.spectrum / left_operator.spectrum, + is_non_singular=property_hint_util.combined_non_singular_hint( + left_operator, right_operator), + is_self_adjoint=property_hint_util.combined_commuting_self_adjoint_hint( + left_operator, right_operator), + is_positive_definite=( + property_hint_util.combined_commuting_positive_definite_hint( + left_operator, right_operator)), + is_square=True) + # @tf_export("linalg.LinearOperatorCirculant3D") # @linear_operator.make_composite_tensor @@ -1237,12 +1387,12 @@ class LinearOperatorCirculant3D(_BaseLinearOperatorCirculant): """ def __init__(self, - spectrum, + spectrum: np.ndarray, input_output_dtype=dtypes.complex64, - is_non_singular=None, - is_self_adjoint=None, - is_positive_definite=None, - is_square=True, + is_non_singular: bool = None, + is_self_adjoint: bool = None, + is_positive_definite: bool = None, + is_square: bool = True, name="LinearOperatorCirculant3D"): """Initialize an `LinearOperatorCirculant`. @@ -1301,6 +1451,47 @@ def __init__(self, parameters=parameters, name=name) + def _linop_adjoint(self) -> "LinearOperatorCirculant3D": + spectrum = self.spectrum + if np.issubdtype(spectrum.dtype, np.complexfloating): + spectrum = math_ops.conj(spectrum) + + # Conjugating the spectrum is sufficient to get the adjoint. + return LinearOperatorCirculant3D( + spectrum=spectrum, + is_non_singular=self.is_non_singular, + is_self_adjoint=self.is_self_adjoint, + is_positive_definite=self.is_positive_definite, + is_square=True) + + def _linop_inverse(self) -> "LinearOperatorCirculant3D": + return LinearOperatorCirculant3D( + spectrum=1. 
/ self.spectrum, + is_non_singular=self.is_non_singular, + is_self_adjoint=self.is_self_adjoint, + is_positive_definite=self.is_positive_definite, + is_square=True, + input_output_dtype=self.dtype) + + def _linop_solve( + self, + left_operator: "LinearOperatorCirculant3D", + right_operator: linear_operator.LinearOperator, + ) -> linear_operator.LinearOperator: + if not isinstance(right_operator, LinearOperatorCirculant3D): + return super()._linop_solve(left_operator, right_operator) + + return LinearOperatorCirculant3D( + spectrum=right_operator.spectrum / left_operator.spectrum, + is_non_singular=property_hint_util.combined_non_singular_hint( + left_operator, right_operator), + is_self_adjoint=property_hint_util.combined_commuting_self_adjoint_hint( + left_operator, right_operator), + is_positive_definite=( + property_hint_util.combined_commuting_positive_definite_hint( + left_operator, right_operator)), + is_square=True) + def _to_complex(x): if np.issubdtype(x.dtype, np.complexfloating): diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_composition.py b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_composition.py index 7699ddea16..05191930d9 100644 --- a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_composition.py +++ b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_composition.py @@ -41,7 +41,10 @@ from tensorflow_probability.python.internal.backend.numpy import numpy_array as array_ops_stack from tensorflow_probability.python.internal.backend.numpy import debugging as check_ops from tensorflow_probability.python.internal.backend.numpy import control_flow as control_flow_ops +from tensorflow_probability.python.internal.backend.numpy import linalg_impl as linalg_ops +from tensorflow_probability.python.internal.backend.numpy import numpy_math as math_ops from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator +from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_lower_triangular from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_util # from tensorflow.python.util.tf_export import tf_export @@ -277,6 +280,66 @@ def _shape_tensor(self): return prefer_static.concat((batch_shape, matrix_shape), 0) + def _linop_cholesky(self) -> linear_operator.LinearOperator: + """Computes Cholesky(LinearOperatorComposition).""" + # L @ L.H will be handled with special code below. Why is L @ L.H the most + # important special case? + # Note that Diag @ Diag.H and Diag @ TriL and TriL @ Diag are already + # compressed to Diag or TriL by diag matmul + # registration. Similarly for Identity and ScaledIdentity. + # So these would not appear in a LinearOperatorComposition unless explicitly + # constructed as such. So the most important thing to check is L @ L.H. 
+ def _is_llt_product(self): + """Determines if linop = L @ L.H for L = LinearOperatorLowerTriangular.""" + if len(self.operators) != 2: + return False + if not linear_operator_util.is_aat_form(self.operators): + return False + return isinstance( + self.operators[0], + linear_operator_lower_triangular.LinearOperatorLowerTriangular) + + if not _is_llt_product(self): + return linear_operator_lower_triangular.LinearOperatorLowerTriangular( + linalg_ops.cholesky(self.to_dense()), + is_non_singular=True, + is_self_adjoint=False, + is_square=True) + + left_op = self.operators[0] + + # left_op.is_positive_definite ==> op already has positive diag,return it. + if left_op.is_positive_definite: + return left_op + + # Recall that the base class has already verified + # linop.is_positive_definite, else linop.cholesky() would have raised. + # So in particular, we know the diagonal has nonzero entries. + # In the generic case, we make op have positive diag by dividing each row + # by the sign of the diag. This is equivalent to setting A = L @ D where + # D is diag(sign(1 / L.diag_part())). Then A is lower triangular with + # positive diag and A @ A^H = L @ D @ D^H @ L^H = L @ L^H = linop. + # This also works for complex L, + # since sign(x + iy) = exp(i * angle(x + iy)). + diag_sign = array_ops.expand_dims( + math_ops.sign(left_op.diag_part()), axis=-2) + return linear_operator_lower_triangular.LinearOperatorLowerTriangular( + tril=left_op.tril / diag_sign, + is_non_singular=left_op.is_non_singular, + # L.is_self_adjoint ==> L is diagonal ==> L @ D is diagonal ==> SA + # L.is_self_adjoint is False ==> L not diagonal ==> L @ D not diag ... + is_self_adjoint=left_op.is_self_adjoint, + # L.is_positive_definite ==> L has positive diag ==> L = L @ D + # ==> (L @ D).is_positive_definite. + # L.is_positive_definite is False could result + # in L @ D being PD or not. + # Consider L = [[1, 0], [-2, 1]] and quadratic form with x = [1, 1]. + # Note we will already return left_op if left_op.is_positive_definite + # above, but to be explicit write this below. + is_positive_definite=True if left_op.is_positive_definite else None, + is_square=True, + ) + def _matmul(self, x, adjoint=False, adjoint_arg=False): # If self.operators = [A, B], and not adjoint, then # matmul_order_list = [B, A]. diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_diag.py b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_diag.py index 77824f04b8..ea9a7ef5df 100644 --- a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_diag.py +++ b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_diag.py @@ -40,7 +40,9 @@ from tensorflow_probability.python.internal.backend.numpy import numpy_math as math_ops from tensorflow_probability.python.internal.backend.numpy import linalg_impl as linalg from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator +from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_lower_triangular from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_util +from tensorflow_probability.python.internal.backend.numpy.gen import property_hint_util # from tensorflow.python.util.tf_export import tf_export __all__ = ["LinearOperatorDiag",] @@ -210,6 +212,101 @@ def _shape_tensor(self): def diag(self): return self._diag + def _linop_inverse(self) -> "LinearOperatorDiag": + return LinearOperatorDiag( + 1. 
/ self.diag, + is_non_singular=self.is_non_singular, + is_self_adjoint=self.is_self_adjoint, + is_positive_definite=self.is_positive_definite, + is_square=True) + + def _linop_matmul( + self, + left_operator: "LinearOperatorDiag", + right_operator: linear_operator.LinearOperator, + ) -> linear_operator.LinearOperator: + is_non_singular = property_hint_util.combined_non_singular_hint( + left_operator, right_operator) + is_self_adjoint = property_hint_util.combined_commuting_self_adjoint_hint( + left_operator, right_operator) + is_positive_definite = ( + property_hint_util.combined_commuting_positive_definite_hint( + left_operator, right_operator)) + if isinstance(right_operator, LinearOperatorDiag): + return LinearOperatorDiag( + diag=left_operator.diag * right_operator.diag, + is_non_singular=is_non_singular, + is_self_adjoint=is_self_adjoint, + is_positive_definite=is_positive_definite, + is_square=True, + ) + # instance of linear_operator_identity.LinearOperatorScaledIdentity + elif hasattr(right_operator, "_ones_diag") and hasattr( + right_operator, "multiplier" + ): + return LinearOperatorDiag( + diag=left_operator.diag * right_operator.multiplier, + is_non_singular=is_non_singular, + is_self_adjoint=is_self_adjoint, + is_positive_definite=is_positive_definite, + is_square=True) + elif isinstance( + right_operator, + linear_operator_lower_triangular.LinearOperatorLowerTriangular, + ): + return linear_operator_lower_triangular.LinearOperatorLowerTriangular( + tril=left_operator.diag[..., None] * right_operator.to_dense(), + is_non_singular=is_non_singular, + # This is safe to do since the Triangular matrix is only self-adjoint + # when it is a diagonal matrix, and hence commutes. + is_self_adjoint=is_self_adjoint, + is_positive_definite=None, + is_square=True) + else: + return super()._linop_matmul(left_operator, right_operator) + + def _linop_solve( + self, + left_operator: "LinearOperatorDiag", + right_operator: linear_operator.LinearOperator, + ) -> linear_operator.LinearOperator: + is_non_singular = property_hint_util.combined_non_singular_hint( + left_operator, right_operator) + is_self_adjoint = property_hint_util.combined_commuting_self_adjoint_hint( + left_operator, right_operator) + is_positive_definite = ( + property_hint_util.combined_commuting_positive_definite_hint( + left_operator, right_operator)) + if isinstance(right_operator, LinearOperatorDiag): + return LinearOperatorDiag( + diag=right_operator.diag / left_operator.diag, + is_non_singular=is_non_singular, + is_self_adjoint=is_self_adjoint, + is_positive_definite=is_positive_definite, + is_square=True) + # instance of linear_operator_identity.LinearOperatorScaledIdentity + elif (hasattr(right_operator, "_ones_diag") + and hasattr(right_operator, "multiplier")): + return LinearOperatorDiag( + diag=right_operator.multiplier / left_operator.diag, + is_non_singular=is_non_singular, + is_self_adjoint=is_self_adjoint, + is_positive_definite=is_positive_definite, + is_square=True) + elif isinstance( + right_operator, + linear_operator_lower_triangular.LinearOperatorLowerTriangular): + return linear_operator_lower_triangular.LinearOperatorLowerTriangular( + tril=right_operator.to_dense() / left_operator.diag[..., None], + is_non_singular=is_non_singular, + # This is safe to do since the Triangular matrix is only self-adjoint + # when it is a diagonal matrix, and hence commutes. 
+ is_self_adjoint=is_self_adjoint, + is_positive_definite=None, + is_square=True) + else: + return super()._linop_solve(left_operator, right_operator) + def _assert_non_singular(self): return linear_operator_util.assert_no_entries_with_modulus_zero( self._diag, @@ -236,6 +333,26 @@ def _assert_self_adjoint(self): "This diagonal operator contained non-zero imaginary values. " " Thus it was not self-adjoint.")) + def _linop_adjoint(self) -> "LinearOperatorDiag": + diag = self.diag + if np.issubdtype(diag.dtype, np.complexfloating): + diag = math_ops.conj(diag) + + return LinearOperatorDiag( + diag=diag, + is_non_singular=self.is_non_singular, + is_self_adjoint=self.is_self_adjoint, + is_positive_definite=self.is_positive_definite, + is_square=True) + + def _linop_cholesky(self) -> "LinearOperatorDiag": + return LinearOperatorDiag( + math_ops.sqrt(self.diag), + is_non_singular=True, + is_self_adjoint=True, + is_positive_definite=True, + is_square=True) + def _matmul(self, x, adjoint=False, adjoint_arg=False): diag_term = math_ops.conj(self._diag) if adjoint else self._diag x = linalg.adjoint(x) if adjoint_arg else x diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_householder.py b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_householder.py index 0959ef11d2..633f7e5b70 100644 --- a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_householder.py +++ b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_householder.py @@ -209,6 +209,12 @@ def _assert_positive_definite(self): def _assert_self_adjoint(self): return control_flow_ops.no_op("assert_self_adjoint") + def _linop_adjoint(self) -> "LinearOperatorHouseholder": + return self + + def _linop_inverse(self) -> "LinearOperatorHouseholder": + return self + def _matmul(self, x, adjoint=False, adjoint_arg=False): # Given a vector `v`, we would like to reflect `x` about the hyperplane # orthogonal to `v` going through the origin. 
We first project `x` to `v` diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_identity.py b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_identity.py index fafdf503ad..4487bd3aa6 100644 --- a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_identity.py +++ b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_identity.py @@ -47,7 +47,9 @@ from tensorflow_probability.python.internal.backend.numpy import numpy_math as math_ops from tensorflow_probability.python.internal.backend.numpy import linalg_impl as linalg from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator +from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_diag from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_util +from tensorflow_probability.python.internal.backend.numpy.gen import property_hint_util # from tensorflow.python.util.tf_export import tf_export __all__ = [ @@ -338,6 +340,38 @@ def _shape_tensor(self): return prefer_static.concat((self._batch_shape_arg, matrix_shape), 0) + def _linop_adjoint(self) -> "LinearOperatorIdentity": + return self + + def _linop_cholesky(self) -> "LinearOperatorIdentity": + return LinearOperatorIdentity( + num_rows=self._num_rows, # pylint: disable=protected-access + batch_shape=self.batch_shape, + dtype=self.dtype, + is_non_singular=True, + is_self_adjoint=True, + is_positive_definite=True, + is_square=True) + + def _linop_inverse(self) -> "LinearOperatorIdentity": + return self + + def _linop_matmul( + self, + left_operator: "LinearOperatorIdentity", + right_operator: linear_operator.LinearOperator, + ) -> "LinearOperatorIdentity": + del left_operator + return right_operator + + def _linop_solve( + self, + left_operator: "LinearOperatorIdentity", + right_operator: linear_operator.LinearOperator, + ) -> linear_operator.LinearOperator: + del left_operator + return right_operator + def _assert_non_singular(self): return control_flow_ops.no_op("assert_non_singular") @@ -729,6 +763,97 @@ def _make_multiplier_matrix(self, conjugate=False): multiplier_matrix = math_ops.conj(multiplier_matrix) return multiplier_matrix + def _linop_adjoint(self) -> "LinearOperatorScaledIdentity": + multiplier = self.multiplier + if np.issubdtype(multiplier.dtype, np.complexfloating): + multiplier = math_ops.conj(multiplier) + + return LinearOperatorScaledIdentity( + num_rows=self._num_rows, + multiplier=multiplier, + is_non_singular=self.is_non_singular, + is_self_adjoint=self.is_self_adjoint, + is_positive_definite=self.is_positive_definite, + is_square=True) + + def _linop_cholesky(self) -> "LinearOperatorScaledIdentity": + return LinearOperatorScaledIdentity( + num_rows=self._num_rows, + multiplier=math_ops.sqrt(self.multiplier), + is_non_singular=True, + is_self_adjoint=True, + is_positive_definite=True, + is_square=True) + + def _linop_inverse(self) -> "LinearOperatorScaledIdentity": + return LinearOperatorScaledIdentity( + num_rows=self._num_rows, + multiplier=1. 
/ self.multiplier, + is_non_singular=self.is_non_singular, + is_self_adjoint=True, + is_positive_definite=self.is_positive_definite, + is_square=True) + + def _linop_matmul( + self, + left_operator: "LinearOperatorScaledIdentity", + right_operator: linear_operator.LinearOperator, + ) -> "LinearOperatorScaledIdentity": + is_non_singular = property_hint_util.combined_non_singular_hint( + left_operator, right_operator) + is_self_adjoint = property_hint_util.combined_commuting_self_adjoint_hint( + left_operator, right_operator) + is_positive_definite = ( + property_hint_util.combined_commuting_positive_definite_hint( + left_operator, right_operator)) + if isinstance(right_operator, LinearOperatorScaledIdentity): + return LinearOperatorScaledIdentity( + num_rows=left_operator.domain_dimension_tensor(), + multiplier=left_operator.multiplier * right_operator.multiplier, + is_non_singular=is_non_singular, + is_self_adjoint=is_self_adjoint, + is_positive_definite=is_positive_definite, + is_square=True) + elif isinstance(right_operator, linear_operator_diag.LinearOperatorDiag): + return linear_operator_diag.LinearOperatorDiag( + diag=right_operator.diag * left_operator.multiplier, + is_non_singular=is_non_singular, + is_self_adjoint=is_self_adjoint, + is_positive_definite=is_positive_definite, + is_square=True) + else: + return super()._linop_matmul(left_operator, right_operator) + + def _linop_solve( + self, + left_operator: "LinearOperatorScaledIdentity", + right_operator: linear_operator.LinearOperator, + ) -> linear_operator.LinearOperator: + is_non_singular = property_hint_util.combined_non_singular_hint( + left_operator, right_operator) + is_self_adjoint = property_hint_util.combined_commuting_self_adjoint_hint( + left_operator, right_operator) + is_positive_definite = ( + property_hint_util.combined_commuting_positive_definite_hint( + left_operator, right_operator)) + if isinstance(right_operator, LinearOperatorScaledIdentity): + return LinearOperatorScaledIdentity( + num_rows=left_operator.domain_dimension_tensor(), + multiplier=right_operator.multiplier / left_operator.multiplier, + is_non_singular=is_non_singular, + is_self_adjoint=is_self_adjoint, + is_positive_definite=is_positive_definite, + is_square=True) + elif isinstance(right_operator, linear_operator_diag.LinearOperatorDiag): + return linear_operator_diag.LinearOperatorDiag( + diag=right_operator.diag / left_operator.multiplier, + is_non_singular=is_non_singular, + is_self_adjoint=is_self_adjoint, + is_positive_definite=is_positive_definite, + is_square=True) + else: + return super()._linop_solve(left_operator, right_operator) + def _matmul(self, x, adjoint=False, adjoint_arg=False): x = linalg.adjoint(x) if adjoint_arg else x if self._assert_proper_shapes: diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_inversion.py b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_inversion.py index aa52e4d257..c6f11f7037 100644 --- a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_inversion.py +++ b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_inversion.py @@ -193,10 +193,21 @@ def __init__(self, name=name) @property - def operator(self): + def operator(self) -> "LinearOperatorInversion": """The operator before inversion.""" return self._operator + def _linop_inverse(self) -> linear_operator.LinearOperator: + return self.operator + + def _linop_solve( + self, + left_operator: "LinearOperatorInversion", + right_operator: 
linear_operator.LinearOperator, + ) -> linear_operator.LinearOperator: + """Solve inverse of generic `LinearOperator`s.""" + return left_operator.operator.matmul(right_operator) + def _assert_non_singular(self): return self.operator.assert_non_singular() diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_kronecker.py b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_kronecker.py index ff62307020..28e2d78308 100644 --- a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_kronecker.py +++ b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_kronecker.py @@ -297,6 +297,34 @@ def _shape_tensor(self): return prefer_static.concat((batch_shape, matrix_shape), 0) + def _linop_adjoint(self) -> "LinearOperatorKronecker": + return LinearOperatorKronecker( + operators=[operator.adjoint() for operator in self.operators], + is_non_singular=self.is_non_singular, + is_self_adjoint=self.is_self_adjoint, + is_positive_definite=self.is_positive_definite, + is_square=True) + + def _linop_cholesky(self) -> "LinearOperatorKronecker": + # Cholesky decomposition of a Kronecker product is the Kronecker product + # of cholesky decompositions. + return LinearOperatorKronecker( + operators=[operator.cholesky() for operator in self.operators], + is_non_singular=True, + is_self_adjoint=None, # Let the operators passed in decide. + is_square=True) + + def _linop_inverse(self) -> "LinearOperatorKronecker": + # Inverse decomposition of a Kronecker product is the Kronecker product + # of inverse decompositions. + return LinearOperatorKronecker( + operators=[ + operator.inverse() for operator in self.operators], + is_non_singular=self.is_non_singular, + is_self_adjoint=self.is_self_adjoint, + is_positive_definite=self.is_positive_definite, + is_square=True) + def _solve_matmul_internal( self, x, diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_lower_triangular.py b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_lower_triangular.py index 51875d283a..a9fe2dfb86 100644 --- a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_lower_triangular.py +++ b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_lower_triangular.py @@ -39,6 +39,7 @@ from tensorflow_probability.python.internal.backend.numpy import linalg_impl as linalg from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_util +from tensorflow_probability.python.internal.backend.numpy.gen import property_hint_util # from tensorflow.python.util.tf_export import tf_export __all__ = [ @@ -217,6 +218,25 @@ def _matmul(self, x, adjoint=False, adjoint_arg=False): return _linalg.matmul( self._get_tril(), x, adjoint_a=adjoint, adjoint_b=adjoint_arg) + def _linop_matmul( + self, + left_operator: "LinearOperatorLowerTriangular", + right_operator: linear_operator.LinearOperator, + ) -> linear_operator.LinearOperator: + # instance check of linear_operator_diag.LinearOperatorDiag + if hasattr(right_operator, "_check_diag"): + return LinearOperatorLowerTriangular( + tril=left_operator.to_dense() * right_operator.diag, + is_non_singular=property_hint_util.combined_non_singular_hint( + right_operator, left_operator), + # This is safe to do since the Triangular matrix is only self-adjoint + # when it is a diagonal matrix, and hence commutes. 
+ is_self_adjoint=property_hint_util.combined_commuting_self_adjoint_hint( + right_operator, left_operator), + is_positive_definite=None, + is_square=True) + return super()._linop_matmul(left_operator, right_operator) + def _determinant(self): return math_ops.reduce_prod(self._get_diag(), axis=[-1]) diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_zeros.py b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_zeros.py index 96de3d1f46..889791c7fc 100644 --- a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_zeros.py +++ b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_zeros.py @@ -351,6 +351,16 @@ def _matmul(self, x, adjoint=False, adjoint_arg=False): zeros = array_ops.zeros(shape=output_shape, dtype=x.dtype) return self._possibly_broadcast_batch_shape(zeros) + def _linop_matmul( + self, + left_operator: "LinearOperatorZeros", + right_operator: linear_operator.LinearOperator + ) -> linear_operator.LinearOperator: + if not left_operator.is_square or not right_operator.is_square: + raise ValueError("Matmul with non-square `LinearOperator`s or non-square " + "`LinearOperatorZeros` not supported at this time.") + return left_operator + def _determinant(self): if self.batch_shape.is_fully_defined(): return array_ops.zeros(shape=self.batch_shape, dtype=self.dtype) diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/matmul_registrations.py b/tensorflow_probability/python/internal/backend/numpy/gen/matmul_registrations.py deleted file mode 100644 index 4753c46748..0000000000 --- a/tensorflow_probability/python/internal/backend/numpy/gen/matmul_registrations.py +++ /dev/null @@ -1,277 +0,0 @@ -# Copyright 2020 The TensorFlow Probability Authors. All Rights Reserved. -# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ -# THIS FILE IS AUTO-GENERATED BY `gen_linear_operators.py`. -# DO NOT MODIFY DIRECTLY. -# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ -# pylint: disable=g-import-not-at-top -# pylint: disable=g-direct-tensorflow-import -# pylint: disable=g-bad-import-order -# pylint: disable=unused-import -# pylint: disable=line-too-long -# pylint: disable=reimported -# pylint: disable=g-bool-id-comparison -# pylint: disable=g-statement-before-imports -# pylint: disable=bad-continuation -# pylint: disable=useless-import-alias -# pylint: disable=property-with-parameters -# pylint: disable=trailing-whitespace -# pylint: disable=g-inconsistent-quotes - -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Registrations for LinearOperator.matmul.""" - -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_algebra -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_block_diag -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_circulant -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_composition -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_diag -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_identity -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_lower_triangular -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_zeros -from tensorflow_probability.python.internal.backend.numpy.gen import registrations_util - - -# By default, use a LinearOperatorComposition to delay the computation. -@linear_operator_algebra.RegisterMatmul( - linear_operator.LinearOperator, linear_operator.LinearOperator) -def _matmul_linear_operator(linop_a, linop_b): - """Generic matmul of two `LinearOperator`s.""" - is_square = registrations_util.is_square(linop_a, linop_b) - is_non_singular = None - is_self_adjoint = None - is_positive_definite = None - - if is_square: - is_non_singular = registrations_util.combined_non_singular_hint( - linop_a, linop_b) - elif is_square is False: # pylint:disable=g-bool-id-comparison - is_non_singular = False - is_self_adjoint = False - is_positive_definite = False - - return linear_operator_composition.LinearOperatorComposition( - operators=[linop_a, linop_b], - is_non_singular=is_non_singular, - is_self_adjoint=is_self_adjoint, - is_positive_definite=is_positive_definite, - is_square=is_square, - ) - -# Identity - - -@linear_operator_algebra.RegisterMatmul( - linear_operator_identity.LinearOperatorIdentity, - linear_operator.LinearOperator) -def _matmul_linear_operator_identity_left(identity, linop): - del identity - return linop - - -@linear_operator_algebra.RegisterMatmul( - linear_operator.LinearOperator, - linear_operator_identity.LinearOperatorIdentity) -def _matmul_linear_operator_identity_right(linop, identity): - del identity - return linop - - -@linear_operator_algebra.RegisterMatmul( - linear_operator_identity.LinearOperatorScaledIdentity, - linear_operator_identity.LinearOperatorScaledIdentity) -def _matmul_linear_operator_scaled_identity(linop_a, linop_b): - """Matmul of two ScaledIdentity `LinearOperators`.""" - return linear_operator_identity.LinearOperatorScaledIdentity( - num_rows=linop_a.domain_dimension_tensor(), - multiplier=linop_a.multiplier * linop_b.multiplier, - is_non_singular=registrations_util.combined_non_singular_hint( - linop_a, linop_b), - is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( - linop_a, linop_b), - is_positive_definite=( - registrations_util.combined_commuting_positive_definite_hint( - linop_a, linop_b)), - is_square=True) - - -# Zeros - - -@linear_operator_algebra.RegisterMatmul( - linear_operator.LinearOperator, - linear_operator_zeros.LinearOperatorZeros) -def _matmul_linear_operator_zeros_right(linop, zeros): - if not zeros.is_square or not linop.is_square: - raise ValueError("Matmul with non-square `LinearOperator`s or non-square " - "`LinearOperatorZeros` not 
supported at this time.") - return zeros - - -@linear_operator_algebra.RegisterMatmul( - linear_operator_zeros.LinearOperatorZeros, - linear_operator.LinearOperator) -def _matmul_linear_operator_zeros_left(zeros, linop): - if not zeros.is_square or not linop.is_square: - raise ValueError("Matmul with non-square `LinearOperator`s or non-square " - "`LinearOperatorZeros` not supported at this time.") - return zeros - - -# Diag. - - -@linear_operator_algebra.RegisterMatmul( - linear_operator_diag.LinearOperatorDiag, - linear_operator_diag.LinearOperatorDiag) -def _matmul_linear_operator_diag(linop_a, linop_b): - return linear_operator_diag.LinearOperatorDiag( - diag=linop_a.diag * linop_b.diag, - is_non_singular=registrations_util.combined_non_singular_hint( - linop_a, linop_b), - is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( - linop_a, linop_b), - is_positive_definite=( - registrations_util.combined_commuting_positive_definite_hint( - linop_a, linop_b)), - is_square=True) - - -@linear_operator_algebra.RegisterMatmul( - linear_operator_diag.LinearOperatorDiag, - linear_operator_identity.LinearOperatorScaledIdentity) -def _matmul_linear_operator_diag_scaled_identity_right( - linop_diag, linop_scaled_identity): - return linear_operator_diag.LinearOperatorDiag( - diag=linop_diag.diag * linop_scaled_identity.multiplier, - is_non_singular=registrations_util.combined_non_singular_hint( - linop_diag, linop_scaled_identity), - is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( - linop_diag, linop_scaled_identity), - is_positive_definite=( - registrations_util.combined_commuting_positive_definite_hint( - linop_diag, linop_scaled_identity)), - is_square=True) - - -@linear_operator_algebra.RegisterMatmul( - linear_operator_identity.LinearOperatorScaledIdentity, - linear_operator_diag.LinearOperatorDiag) -def _matmul_linear_operator_diag_scaled_identity_left( - linop_scaled_identity, linop_diag): - return linear_operator_diag.LinearOperatorDiag( - diag=linop_diag.diag * linop_scaled_identity.multiplier, - is_non_singular=registrations_util.combined_non_singular_hint( - linop_diag, linop_scaled_identity), - is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( - linop_diag, linop_scaled_identity), - is_positive_definite=( - registrations_util.combined_commuting_positive_definite_hint( - linop_diag, linop_scaled_identity)), - is_square=True) - - -@linear_operator_algebra.RegisterMatmul( - linear_operator_diag.LinearOperatorDiag, - linear_operator_lower_triangular.LinearOperatorLowerTriangular) -def _matmul_linear_operator_diag_tril(linop_diag, linop_triangular): - return linear_operator_lower_triangular.LinearOperatorLowerTriangular( - tril=linop_diag.diag[..., None] * linop_triangular.to_dense(), - is_non_singular=registrations_util.combined_non_singular_hint( - linop_diag, linop_triangular), - # This is safe to do since the Triangular matrix is only self-adjoint - # when it is a diagonal matrix, and hence commutes. 
- is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( - linop_diag, linop_triangular), - is_positive_definite=None, - is_square=True) - - -@linear_operator_algebra.RegisterMatmul( - linear_operator_lower_triangular.LinearOperatorLowerTriangular, - linear_operator_diag.LinearOperatorDiag) -def _matmul_linear_operator_tril_diag(linop_triangular, linop_diag): - return linear_operator_lower_triangular.LinearOperatorLowerTriangular( - tril=linop_triangular.to_dense() * linop_diag.diag, - is_non_singular=registrations_util.combined_non_singular_hint( - linop_diag, linop_triangular), - # This is safe to do since the Triangular matrix is only self-adjoint - # when it is a diagonal matrix, and hence commutes. - is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( - linop_diag, linop_triangular), - is_positive_definite=None, - is_square=True) - -# Circulant. - - -# pylint: disable=protected-access -@linear_operator_algebra.RegisterMatmul( - linear_operator_circulant._BaseLinearOperatorCirculant, - linear_operator_circulant._BaseLinearOperatorCirculant) -def _matmul_linear_operator_circulant_circulant(linop_a, linop_b): - if not isinstance(linop_a, linop_b.__class__): - return _matmul_linear_operator(linop_a, linop_b) - - return linop_a.__class__( - spectrum=linop_a.spectrum * linop_b.spectrum, - is_non_singular=registrations_util.combined_non_singular_hint( - linop_a, linop_b), - is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( - linop_a, linop_b), - is_positive_definite=( - registrations_util.combined_commuting_positive_definite_hint( - linop_a, linop_b)), - is_square=True) -# pylint: enable=protected-access - -# Block Diag - - -@linear_operator_algebra.RegisterMatmul( - linear_operator_block_diag.LinearOperatorBlockDiag, - linear_operator_block_diag.LinearOperatorBlockDiag) -def _matmul_linear_operator_block_diag_block_diag(linop_a, linop_b): - return linear_operator_block_diag.LinearOperatorBlockDiag( - operators=[ - o1.matmul(o2) for o1, o2 in zip( - linop_a.operators, linop_b.operators)], - is_non_singular=registrations_util.combined_non_singular_hint( - linop_a, linop_b), - # In general, a product of self-adjoint positive-definite block diagonal - # matrices is not self = self - adjoint. - is_self_adjoint=None, - # In general, a product of positive-definite block diagonal matrices is - # not positive-definite. 
- is_positive_definite=None, - is_square=True) - -import numpy as np -from tensorflow_probability.python.internal.backend.numpy import linalg_impl as _linalg -from tensorflow_probability.python.internal.backend.numpy import ops as _ops -from tensorflow_probability.python.internal.backend.numpy.gen import tensor_shape - -from tensorflow_probability.python.internal.backend.numpy import private -distribution_util = private.LazyLoader( - "distribution_util", globals(), - "tensorflow_probability.substrates.numpy.internal.distribution_util") -tensorshape_util = private.LazyLoader( - "tensorshape_util", globals(), - "tensorflow_probability.substrates.numpy.internal.tensorshape_util") -prefer_static = private.LazyLoader( - "prefer_static", globals(), - "tensorflow_probability.substrates.numpy.internal.prefer_static") - diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/registrations_util.py b/tensorflow_probability/python/internal/backend/numpy/gen/property_hint_util.py similarity index 98% rename from tensorflow_probability/python/internal/backend/numpy/gen/registrations_util.py rename to tensorflow_probability/python/internal/backend/numpy/gen/property_hint_util.py index 506c4c6bbe..3c97f238f1 100644 --- a/tensorflow_probability/python/internal/backend/numpy/gen/registrations_util.py +++ b/tensorflow_probability/python/internal/backend/numpy/gen/property_hint_util.py @@ -31,7 +31,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Common utilities for registering LinearOperator methods.""" +"""Common utilities for LinearOperator property hints.""" # Note: only use this method in the commuting case. diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/solve_registrations.py b/tensorflow_probability/python/internal/backend/numpy/gen/solve_registrations.py deleted file mode 100644 index 958c2fb4b1..0000000000 --- a/tensorflow_probability/python/internal/backend/numpy/gen/solve_registrations.py +++ /dev/null @@ -1,250 +0,0 @@ -# Copyright 2020 The TensorFlow Probability Authors. All Rights Reserved. -# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ -# THIS FILE IS AUTO-GENERATED BY `gen_linear_operators.py`. -# DO NOT MODIFY DIRECTLY. -# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ -# pylint: disable=g-import-not-at-top -# pylint: disable=g-direct-tensorflow-import -# pylint: disable=g-bad-import-order -# pylint: disable=unused-import -# pylint: disable=line-too-long -# pylint: disable=reimported -# pylint: disable=g-bool-id-comparison -# pylint: disable=g-statement-before-imports -# pylint: disable=bad-continuation -# pylint: disable=useless-import-alias -# pylint: disable=property-with-parameters -# pylint: disable=trailing-whitespace -# pylint: disable=g-inconsistent-quotes - -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Registrations for LinearOperator.solve.""" - -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_algebra -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_block_diag -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_circulant -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_composition -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_diag -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_identity -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_inversion -from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_lower_triangular -from tensorflow_probability.python.internal.backend.numpy.gen import registrations_util - - -# By default, use a LinearOperatorComposition to delay the computation. -@linear_operator_algebra.RegisterSolve( - linear_operator.LinearOperator, linear_operator.LinearOperator) -def _solve_linear_operator(linop_a, linop_b): - """Generic solve of two `LinearOperator`s.""" - is_square = registrations_util.is_square(linop_a, linop_b) - is_non_singular = None - is_self_adjoint = None - is_positive_definite = None - - if is_square: - is_non_singular = registrations_util.combined_non_singular_hint( - linop_a, linop_b) - elif is_square is False: # pylint:disable=g-bool-id-comparison - is_non_singular = False - is_self_adjoint = False - is_positive_definite = False - - return linear_operator_composition.LinearOperatorComposition( - operators=[ - linear_operator_inversion.LinearOperatorInversion(linop_a), - linop_b - ], - is_non_singular=is_non_singular, - is_self_adjoint=is_self_adjoint, - is_positive_definite=is_positive_definite, - is_square=is_square, - ) - - -@linear_operator_algebra.RegisterSolve( - linear_operator_inversion.LinearOperatorInversion, - linear_operator.LinearOperator) -def _solve_inverse_linear_operator(linop_a, linop_b): - """Solve inverse of generic `LinearOperator`s.""" - return linop_a.operator.matmul(linop_b) - - -# Identity -@linear_operator_algebra.RegisterSolve( - linear_operator_identity.LinearOperatorIdentity, - linear_operator.LinearOperator) -def _solve_linear_operator_identity_left(identity, linop): - del identity - return linop - - -@linear_operator_algebra.RegisterSolve( - linear_operator.LinearOperator, - linear_operator_identity.LinearOperatorIdentity) -def _solve_linear_operator_identity_right(linop, identity): - del identity - return linop.inverse() - - -@linear_operator_algebra.RegisterSolve( - linear_operator_identity.LinearOperatorScaledIdentity, - linear_operator_identity.LinearOperatorScaledIdentity) -def _solve_linear_operator_scaled_identity(linop_a, linop_b): - """Solve of two ScaledIdentity `LinearOperators`.""" - return linear_operator_identity.LinearOperatorScaledIdentity( - num_rows=linop_a.domain_dimension_tensor(), - multiplier=linop_b.multiplier / linop_a.multiplier, - is_non_singular=registrations_util.combined_non_singular_hint( - linop_a, linop_b), - is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( - linop_a, linop_b), - is_positive_definite=( - 
registrations_util.combined_commuting_positive_definite_hint( - linop_a, linop_b)), - is_square=True) - - -# Diag. - - -@linear_operator_algebra.RegisterSolve( - linear_operator_diag.LinearOperatorDiag, - linear_operator_diag.LinearOperatorDiag) -def _solve_linear_operator_diag(linop_a, linop_b): - return linear_operator_diag.LinearOperatorDiag( - diag=linop_b.diag / linop_a.diag, - is_non_singular=registrations_util.combined_non_singular_hint( - linop_a, linop_b), - is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( - linop_a, linop_b), - is_positive_definite=( - registrations_util.combined_commuting_positive_definite_hint( - linop_a, linop_b)), - is_square=True) - - -@linear_operator_algebra.RegisterSolve( - linear_operator_diag.LinearOperatorDiag, - linear_operator_identity.LinearOperatorScaledIdentity) -def _solve_linear_operator_diag_scaled_identity_right( - linop_diag, linop_scaled_identity): - return linear_operator_diag.LinearOperatorDiag( - diag=linop_scaled_identity.multiplier / linop_diag.diag, - is_non_singular=registrations_util.combined_non_singular_hint( - linop_diag, linop_scaled_identity), - is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( - linop_diag, linop_scaled_identity), - is_positive_definite=( - registrations_util.combined_commuting_positive_definite_hint( - linop_diag, linop_scaled_identity)), - is_square=True) - - -@linear_operator_algebra.RegisterSolve( - linear_operator_identity.LinearOperatorScaledIdentity, - linear_operator_diag.LinearOperatorDiag) -def _solve_linear_operator_diag_scaled_identity_left( - linop_scaled_identity, linop_diag): - return linear_operator_diag.LinearOperatorDiag( - diag=linop_diag.diag / linop_scaled_identity.multiplier, - is_non_singular=registrations_util.combined_non_singular_hint( - linop_diag, linop_scaled_identity), - is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( - linop_diag, linop_scaled_identity), - is_positive_definite=( - registrations_util.combined_commuting_positive_definite_hint( - linop_diag, linop_scaled_identity)), - is_square=True) - - -@linear_operator_algebra.RegisterSolve( - linear_operator_diag.LinearOperatorDiag, - linear_operator_lower_triangular.LinearOperatorLowerTriangular) -def _solve_linear_operator_diag_tril(linop_diag, linop_triangular): - return linear_operator_lower_triangular.LinearOperatorLowerTriangular( - tril=linop_triangular.to_dense() / linop_diag.diag[..., None], - is_non_singular=registrations_util.combined_non_singular_hint( - linop_diag, linop_triangular), - # This is safe to do since the Triangular matrix is only self-adjoint - # when it is a diagonal matrix, and hence commutes. - is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( - linop_diag, linop_triangular), - is_positive_definite=None, - is_square=True) - - -# Circulant. 
- - -# pylint: disable=protected-access -@linear_operator_algebra.RegisterSolve( - linear_operator_circulant._BaseLinearOperatorCirculant, - linear_operator_circulant._BaseLinearOperatorCirculant) -def _solve_linear_operator_circulant_circulant(linop_a, linop_b): - if not isinstance(linop_a, linop_b.__class__): - return _solve_linear_operator(linop_a, linop_b) - - return linop_a.__class__( - spectrum=linop_b.spectrum / linop_a.spectrum, - is_non_singular=registrations_util.combined_non_singular_hint( - linop_a, linop_b), - is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( - linop_a, linop_b), - is_positive_definite=( - registrations_util.combined_commuting_positive_definite_hint( - linop_a, linop_b)), - is_square=True) -# pylint: enable=protected-access - - -# Block Diag - - -@linear_operator_algebra.RegisterSolve( - linear_operator_block_diag.LinearOperatorBlockDiag, - linear_operator_block_diag.LinearOperatorBlockDiag) -def _solve_linear_operator_block_diag_block_diag(linop_a, linop_b): - return linear_operator_block_diag.LinearOperatorBlockDiag( - operators=[ - o1.solve(o2) for o1, o2 in zip( - linop_a.operators, linop_b.operators)], - is_non_singular=registrations_util.combined_non_singular_hint( - linop_a, linop_b), - # In general, a solve of self-adjoint positive-definite block diagonal - # matrices is not self = self - adjoint. - is_self_adjoint=None, - # In general, a solve of positive-definite block diagonal matrices is - # not positive-definite. - is_positive_definite=None, - is_square=True) - -import numpy as np -from tensorflow_probability.python.internal.backend.numpy import linalg_impl as _linalg -from tensorflow_probability.python.internal.backend.numpy import ops as _ops -from tensorflow_probability.python.internal.backend.numpy.gen import tensor_shape - -from tensorflow_probability.python.internal.backend.numpy import private -distribution_util = private.LazyLoader( - "distribution_util", globals(), - "tensorflow_probability.substrates.numpy.internal.distribution_util") -tensorshape_util = private.LazyLoader( - "tensorshape_util", globals(), - "tensorflow_probability.substrates.numpy.internal.tensorshape_util") -prefer_static = private.LazyLoader( - "prefer_static", globals(), - "tensorflow_probability.substrates.numpy.internal.prefer_static") - diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/tensor_shape.py b/tensorflow_probability/python/internal/backend/numpy/gen/tensor_shape.py index 88dbd14c7b..e6747ca9bd 100755 --- a/tensorflow_probability/python/internal/backend/numpy/gen/tensor_shape.py +++ b/tensorflow_probability/python/internal/backend/numpy/gen/tensor_shape.py @@ -90,7 +90,7 @@ class StructuredValue: """Helper classes for tensor shape inference.""" import functools import operator -from typing import Optional, Sequence, Type +from typing import Optional, Sequence, Type, Union # from tensorflow.core.framework import tensor_shape_pb2 # from tensorflow.core.function import trace_type @@ -176,8 +176,11 @@ def disable_v2_tensorshape(): @tf_export( - "compat.dimension_value", v1=["dimension_value", "compat.dimension_value"]) -def dimension_value(dimension): + "compat.dimension_value", v1=["dimension_value", "compat.dimension_value"] +) +def dimension_value( + dimension: Union["Dimension", int, None] +) -> Union[int, None]: """Compatibility utility required to allow for both V1 and V2 behavior in TF. 
Until the release of TF 2.0, we need the legacy behavior of `TensorShape` to @@ -211,7 +214,7 @@ def dimension_value(dimension): @tf_export( "compat.dimension_at_index", v1=["dimension_at_index", "compat.dimension_at_index"]) -def dimension_at_index(shape, index): +def dimension_at_index(shape, index) -> "Dimension": """Compatibility utility required to allow for both V1 and V2 behavior in TF. Until the release of TF 2.0, we need the legacy behavior of `TensorShape` to @@ -1354,8 +1357,28 @@ def most_specific_common_supertype( @doc_controls.do_not_doc_inheritable def placeholder_value(self, placeholder_context): - raise NotImplementedError("A graph placeholder is not currently supported" - "for an object of type: TensorShape.") + """See tf.types.experimental.TraceType base class.""" + return super().placeholder_value(placeholder_context) + + @doc_controls.do_not_doc_inheritable + def from_tensors(self, tensors): + """See tf.types.experimental.TraceType base class.""" + return super().from_tensors(tensors) + + @doc_controls.do_not_doc_inheritable + def to_tensors(self, value): + """See tf.types.experimental.TraceType base class.""" + return super().to_tensors(value) + + @doc_controls.do_not_doc_inheritable + def flatten(self): + """See tf.types.experimental.TraceType base class.""" + return super().flatten() + + @doc_controls.do_not_doc_inheritable + def cast(self, value, cast_context): + """See tf.types.experimental.TraceType base class.""" + return super().cast(value, cast_context) @classmethod def experimental_type_proto(cls) -> Type[tensor_shape_pb2.TensorShapeProto]: @@ -1435,7 +1458,7 @@ def assert_is_compatible_with(self, other): if not self.is_compatible_with(other): raise ValueError("Shapes %s and %s are incompatible" % (self, other)) - def most_specific_compatible_shape(self, other): + def most_specific_compatible_shape(self, other) -> "TensorShape": """Returns the most specific TensorShape compatible with `self` and `other`. * TensorShape([None, 1]) is the most specific TensorShape compatible with @@ -1593,7 +1616,7 @@ def do_decode(self, value, decode_fn): nested_structure_coder.register_codec(_TensorShapeCodec()) -def as_shape(shape): +def as_shape(shape) -> "TensorShape": """Converts the given object to a TensorShape.""" if isinstance(shape, TensorShape): return shape @@ -1601,7 +1624,7 @@ def as_shape(shape): return TensorShape(shape) -def unknown_shape(rank=None, **kwargs): +def unknown_shape(rank=None, **kwargs) -> "TensorShape": """Returns an unknown TensorShape, optionally with a known rank. Args: diff --git a/tensorflow_probability/python/internal/backend/numpy/linalg.py b/tensorflow_probability/python/internal/backend/numpy/linalg.py index 21ba9fd8a9..3b15d1d947 100644 --- a/tensorflow_probability/python/internal/backend/numpy/linalg.py +++ b/tensorflow_probability/python/internal/backend/numpy/linalg.py @@ -25,13 +25,6 @@ # installing bazel. 
try: # pylint: disable=unused-import - from tensorflow_probability.python.internal.backend.numpy.gen import adjoint_registrations as _adjoint_registrations - from tensorflow_probability.python.internal.backend.numpy.gen import cholesky_registrations as _cholesky_registrations - from tensorflow_probability.python.internal.backend.numpy.gen import inverse_registrations as _inverse_registrations - from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_algebra as _linear_operator_algebra - from tensorflow_probability.python.internal.backend.numpy.gen import matmul_registrations as _matmul_registrations - from tensorflow_probability.python.internal.backend.numpy.gen import solve_registrations as _solve_registrations - from tensorflow_probability.python.internal.backend.numpy.gen.linear_operator import * from tensorflow_probability.python.internal.backend.numpy.gen.linear_operator_addition import * from tensorflow_probability.python.internal.backend.numpy.gen.linear_operator_adjoint import * @@ -72,7 +65,8 @@ def register_pytrees(env): 'LinearOperatorScaledIdentity': ('multiplier',), 'LinearOperatorInversion': ('operator',), 'LinearOperatorKronecker': ('operators',), - 'LinearOperatorLowRankUpdate': ('base_operator', 'diag_update'), + 'LinearOperatorLowRankUpdate': ( + 'base_operator', 'diag_update', 'u', 'v'), 'LinearOperatorLowerTriangular': ('tril',), 'LinearOperatorPermutation': ('perm',), 'LinearOperatorToeplitz': ('col', 'row'), diff --git a/tensorflow_probability/python/internal/backend/numpy/numpy_math.py b/tensorflow_probability/python/internal/backend/numpy/numpy_math.py index 40b3a2526b..2037d40e7e 100644 --- a/tensorflow_probability/python/internal/backend/numpy/numpy_math.py +++ b/tensorflow_probability/python/internal/backend/numpy/numpy_math.py @@ -17,6 +17,7 @@ import collections import functools import numpy as np +import numpy as onp # Disable JAX rewrite. # pylint: disable=reimported from tensorflow_probability.python.internal.backend.numpy import _utils as utils from tensorflow_probability.python.internal.backend.numpy.numpy_array import _reverse @@ -165,10 +166,15 @@ def _astuple(x): """Attempt to convert the given argument to be a Python tuple.""" - try: - return (int(x),) - except TypeError: - pass + # Numpy used to allow casting a size-1 ndarray to python scalar literal types. + # In version 1.25 this was deprecated, causing a warning to be issued in the + # below try/except. To avoid that, we just fall through in the case of an + # np.ndarray. + if not isinstance(x, onp.ndarray): + try: + return (int(x),) + except TypeError: + pass try: return tuple(x) diff --git a/tensorflow_probability/python/internal/backend/numpy/numpy_test.py b/tensorflow_probability/python/internal/backend/numpy/numpy_test.py index ce8288496f..56c3641350 100644 --- a/tensorflow_probability/python/internal/backend/numpy/numpy_test.py +++ b/tensorflow_probability/python/internal/backend/numpy/numpy_test.py @@ -43,7 +43,7 @@ from tensorflow_probability.python.internal import test_util from tensorflow_probability.python.internal.backend import numpy as nptf from tensorflow_probability.python.internal.backend.numpy import functional_ops as np_pfor -from tensorflow.python.ops import parallel_for as tf_pfor # pylint: disable=g-direct-tensorflow-import +from tensorflow.python.ops.parallel_for import control_flow_ops as tf_pfor_control_flow_ops # pylint: disable=g-direct-tensorflow-import # Allows us to test low-level TF:XLA match. 
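Editor's note: the `_astuple` change in the `numpy_math.py` hunk above works around NumPy 1.25 deprecating the coercion of size-1 ndarrays to Python scalars. Below is a minimal illustrative sketch of that guard, not part of the patch; the function name `_astuple_sketch` and the behavior when `tuple()` also fails are hypothetical.

```python
import numpy as onp  # same "onp" alias the patch uses to avoid the JAX rewrite

def _astuple_sketch(x):
  # Mirrors the guarded conversion in the `_astuple` hunk above (name is
  # hypothetical). NumPy >= 1.25 warns when a size-1 ndarray is coerced via
  # int(), so ndarrays skip straight to the tuple() attempt.
  if not isinstance(x, onp.ndarray):
    try:
      return (int(x),)
    except TypeError:
      pass
  return tuple(x)

assert _astuple_sketch(3) == (3,)
assert _astuple_sketch([2, 4]) == (2, 4)
assert _astuple_sketch(onp.array([7])) == (7,)  # no deprecation warning emitted
```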
@@ -1120,7 +1120,9 @@ def _not_implemented(*args, **kwargs): xla_const_args=(1,)), TestCase( 'math.reduce_prod', [ - array_axis_tuples(allow_multi_axis=True), + array_axis_tuples( + # TODO(b/298224187) TF produces 0, np NaN for large elements. + elements=floats(-1e6, 1e6), allow_multi_axis=True), array_axis_tuples(dtype=np.int32, allow_multi_axis=True) ], xla_const_args=(1,)), @@ -1175,7 +1177,7 @@ def _not_implemented(*args, **kwargs): allow_nan=False, allow_infinity=False)) ], - xla_rtol=1e-4), + atol=1e-4), TestCase('math.softmax', [ single_arrays( shape=shapes(min_dims=1), @@ -1217,8 +1219,12 @@ def _not_implemented(*args, **kwargs): # keywords=None, defaults=(0, False, False, None)) TestCase( 'math.cumprod', [ - hps.tuples(array_axis_tuples(), hps.booleans(), - hps.booleans()).map(lambda x: x[0] + (x[1], x[2])) + hps.tuples( + array_axis_tuples( + # TODO(b/298224187) TF produces 0, np NaN for large inputs. + elements=floats(min_value=-1e12, max_value=1e12)), + hps.booleans(), + hps.booleans()).map(lambda x: x[0] + (x[1], x[2])) ], xla_const_args=(1, 2, 3)), TestCase( @@ -1260,9 +1266,11 @@ def _not_implemented(*args, **kwargs): ]), TestCase('math.abs', [single_arrays()]), TestCase('math.acos', [single_arrays(elements=floats(-1., 1.))]), - TestCase('math.acosh', [single_arrays(elements=positive_floats())]), + TestCase('math.acosh', [single_arrays(elements=positive_floats())], + atol=1e-4), TestCase('math.asin', [single_arrays(elements=floats(-1., 1.))]), - TestCase('math.asinh', [single_arrays(elements=positive_floats())]), + TestCase('math.asinh', [single_arrays(elements=positive_floats())], + atol=1e-4), TestCase('math.atan', [single_arrays()]), TestCase('math.atanh', [single_arrays(elements=floats(-1., 1.))]), TestCase( @@ -1296,7 +1304,8 @@ def _not_implemented(*args, **kwargs): TestCase('math.is_inf', [single_arrays()]), TestCase('math.is_nan', [single_arrays()]), TestCase('math.lgamma', [single_arrays(elements=positive_floats())]), - TestCase('math.log', [single_arrays(elements=positive_floats())]), + TestCase('math.log', [single_arrays(elements=positive_floats())], + atol=1e-4), TestCase('math.log1p', [single_arrays(elements=floats(min_value=-1 + 1e-6))], xla_atol=1e-4, xla_rtol=1e-4), @@ -1316,11 +1325,11 @@ def _not_implemented(*args, **kwargs): TestCase('math.sign', [single_arrays()]), TestCase('math.sin', [single_arrays()]), TestCase('math.sinh', [single_arrays(elements=floats(-100., 100.))]), - TestCase('math.softplus', [single_arrays()]), + TestCase('math.softplus', [single_arrays()], atol=1e-4), TestCase('math.sqrt', [single_arrays(elements=positive_floats())]), TestCase('math.square', [single_arrays()]), TestCase('math.tan', [single_arrays()]), - TestCase('math.tanh', [single_arrays()]), + TestCase('math.tanh', [single_arrays()], atol=1e-4), # ArgSpec(args=['x', 'q', 'name'], varargs=None, keywords=None, # defaults=(None,)) @@ -1367,9 +1376,11 @@ def _not_implemented(*args, **kwargs): TestCase('math.xdivy', [n_same_shape(n=2, elements=[floats(), non_zero_floats()])]), TestCase('math.xlogy', - [n_same_shape(n=2, elements=[floats(), positive_floats()])]), + [n_same_shape(n=2, elements=[floats(), positive_floats()])], + atol=1e-4, rtol=1e-3), TestCase('math.xlog1py', - [n_same_shape(n=2, elements=[floats(), positive_floats()])]), + [n_same_shape(n=2, elements=[floats(), positive_floats()])], + atol=1e-4, rtol=1e-3), TestCase('nn.conv2d', [conv2d_params()], disabled=NUMPY_MODE), TestCase( 'nn.sparse_softmax_cross_entropy_with_logits', [sparse_xent_params()], @@ -1821,7 
+1832,7 @@ def test_foldl_struct_in_alt_out(self): def test_pfor(self): self.assertAllEqual( - self.evaluate(tf_pfor.pfor(lambda x: tf.ones([]), 7)), + self.evaluate(tf_pfor_control_flow_ops.pfor(lambda x: tf.ones([]), 7)), np_pfor.pfor(lambda x: nptf.ones([]), 7)) def test_pfor_with_closure(self): @@ -1832,7 +1843,7 @@ def tf_fn(x): def np_fn(x): return nptf.gather(val, x)**2 self.assertAllEqual( - self.evaluate(tf_pfor.pfor(tf_fn, 7)), + self.evaluate(tf_pfor_control_flow_ops.pfor(tf_fn, 7)), np_pfor.pfor(np_fn, 7)) def test_pfor_with_closure_multi_out(self): @@ -1843,7 +1854,7 @@ def tf_fn(x): def np_fn(x): return nptf.gather(val, x)**2, nptf.gather(val, x) self.assertAllEqual( - self.evaluate(tf_pfor.pfor(tf_fn, 7)), + self.evaluate(tf_pfor_control_flow_ops.pfor(tf_fn, 7)), np_pfor.pfor(np_fn, 7)) def test_convert_variable_to_tensor(self): @@ -1993,7 +2004,6 @@ def assert_same_dtype(x, y): tensorflow_value = post_processor(tensorflow_value) if assert_shape_only: - def assert_same_shape(x, y): self.assertAllEqual(x.shape, y.shape) diff --git a/tensorflow_probability/python/internal/backend/numpy/ops.py b/tensorflow_probability/python/internal/backend/numpy/ops.py index 0cb8cc9ddb..2d396f5ef4 100644 --- a/tensorflow_probability/python/internal/backend/numpy/ops.py +++ b/tensorflow_probability/python/internal/backend/numpy/ops.py @@ -218,10 +218,14 @@ def _default_convert_to_tensor(value, dtype=None): """Default tensor conversion function for array, bool, int, float, and complex.""" if JAX_MODE: # TODO(b/223267515): We shouldn't need to specialize here. - if 'PRNGKeyArray' in str(type(value)): + if hasattr(value, 'dtype') and jax.dtypes.issubdtype( + value.dtype, jax.dtypes.prng_key + ): return value if isinstance(value, (list, tuple)) and value: - if 'PRNGKeyArray' in str(type(value[0])): + if hasattr(value[0], 'dtype') and jax.dtypes.issubdtype( + value[0].dtype, jax.dtypes.prng_key + ): return np.stack(value, axis=0) inferred_dtype = _infer_dtype(value, np.float32) diff --git a/tensorflow_probability/python/internal/dtype_util_test.py b/tensorflow_probability/python/internal/dtype_util_test.py index abfcc7ca11..861b99bebc 100644 --- a/tensorflow_probability/python/internal/dtype_util_test.py +++ b/tensorflow_probability/python/internal/dtype_util_test.py @@ -74,37 +74,44 @@ def testCommonStructuredDtype(self): w = structured_dtype_obj(None) # Check that structured dtypes unify correctly. - self.assertAllEqualNested( + self.assertAllAssertsNested( + self.assertEqual, dtype_util.common_dtype([w, x, y, z]), {'a': tf.float32, 'b': (None, tf.float64)}) # Check that dict `args` works and that `dtype_hint` works. dtype_hint = {'a': tf.int32, 'b': (tf.int32, None)} - self.assertAllEqualNested( + self.assertAllAssertsNested( + self.assertEqual, dtype_util.common_dtype( {'x': x, 'y': y, 'z': z}, dtype_hint=dtype_hint), {'a': tf.float32, 'b': (tf.int32, tf.float64)}) - self.assertAllEqualNested( + self.assertAllAssertsNested( + self.assertEqual, dtype_util.common_dtype([w], dtype_hint=dtype_hint), dtype_hint) # Check that non-nested dtype_hint broadcasts. - self.assertAllEqualNested( + self.assertAllAssertsNested( + self.assertEqual, dtype_util.common_dtype([y, z], dtype_hint=tf.int32), {'a': tf.int32, 'b': (tf.int32, tf.float64)}) # Check that structured `dtype_hint` behaves as expected. 
s = {'a': [tf.ones([3], tf.float32), 4.], 'b': (np.float64(2.), None)} - self.assertAllEqualNested( + self.assertAllAssertsNested( + self.assertEqual, dtype_util.common_dtype([x, s], dtype_hint=z.dtype), {'a': tf.float32, 'b': (tf.float64, None)}) - self.assertAllEqualNested( + self.assertAllAssertsNested( + self.assertEqual, dtype_util.common_dtype([y, s], dtype_hint=z.dtype), {'a': tf.float32, 'b': (tf.float64, tf.float64)}) t = {'a': [[1., 2., 3.]], 'b': {'c': np.float64(1.), 'd': np.float64(2.)}} - self.assertAllEqualNested( + self.assertAllAssertsNested( + self.assertEqual, dtype_util.common_dtype( [w, t], dtype_hint={'a': tf.float32, 'b': tf.float32}), diff --git a/tensorflow_probability/python/internal/loop_util.py b/tensorflow_probability/python/internal/loop_util.py index 4eaa41ae0b..f7272c62b1 100644 --- a/tensorflow_probability/python/internal/loop_util.py +++ b/tensorflow_probability/python/internal/loop_util.py @@ -52,8 +52,8 @@ def _convert_variables_to_tensors(values): def tensor_array_from_element(elem, size=None, **kwargs): """Construct a tf.TensorArray of elements with the dtype + shape of `elem`.""" - if JAX_MODE and isinstance(elem, jax.random.PRNGKeyArray): - # If `trace_elt` is a `PRNGKeyArray`, then then it is not possible to create + if JAX_MODE and jax.dtypes.issubdtype(elem.dtype, jax.dtypes.prng_key): + # If `trace_elt` is a typed prng key, then then it is not possible to create # a matching (i.e., with the same custom PRNG) instance/array inside # `TensorArray.__init__` given just a `dtype`, `size`, and `shape`. # diff --git a/tensorflow_probability/python/internal/samplers_test.py b/tensorflow_probability/python/internal/samplers_test.py index 2b860b93f9..3ae5fdfd0e 100644 --- a/tensorflow_probability/python/internal/samplers_test.py +++ b/tensorflow_probability/python/internal/samplers_test.py @@ -37,7 +37,7 @@ def setUp(self): super().setUp() if JAX_MODE and FLAGS.test_tfp_jax_prng != 'default': - from jax.config import config # pylint: disable=g-import-not-at-top + from jax import config # pylint: disable=g-import-not-at-top config.update('jax_default_prng_impl', FLAGS.test_tfp_jax_prng) @test_util.substrate_disable_stateful_random_test diff --git a/tensorflow_probability/python/internal/test_util.py b/tensorflow_probability/python/internal/test_util.py index 9af934ed5f..0da39d05b5 100644 --- a/tensorflow_probability/python/internal/test_util.py +++ b/tensorflow_probability/python/internal/test_util.py @@ -163,8 +163,12 @@ def evaluate(self, x): def _evaluate(x): if x is None: return x - # TODO(b/223267515): Improve handling of JAX PRNGKeyArray objects. - if JAX_MODE and isinstance(x, jax.random.PRNGKeyArray): + # TODO(b/223267515): Improve handling of JAX typed PRNG keys. + if ( + JAX_MODE + and hasattr(x, 'dtype') + and jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key) + ): return x return np.array(x) return tf.nest.map_structure(_evaluate, x, expand_composites=True) @@ -177,11 +181,15 @@ def _GetNdArray(self, a): def _evaluateTensors(self, a, b): if JAX_MODE: import jax # pylint: disable=g-import-not-at-top - # HACK: In assertions (like self.assertAllClose), convert PRNGKeyArrays - # to "normal" arrays so they can be compared with our existing machinery. - if isinstance(a, jax.random.PRNGKeyArray): + # HACK: In assertions (like self.assertAllClose), convert typed PRNG keys + # to raw arrays so they can be compared with our existing machinery. 
+ if hasattr(a, 'dtype') and jax.dtypes.issubdtype( + a.dtype, jax.dtypes.prng_key + ): a = jax.random.key_data(a) - if isinstance(b, jax.random.PRNGKeyArray): + if hasattr(b, 'dtype') and jax.dtypes.issubdtype( + b.dtype, jax.dtypes.prng_key + ): b = jax.random.key_data(b) if tf.is_tensor(a) and tf.is_tensor(b): (a, b) = self.evaluate([a, b]) @@ -2013,7 +2021,7 @@ def getTestCaseNames(self, testCaseClass): # pylint:disable=invalid-name def main(jax_mode=JAX_MODE, jax_enable_x64=True): """Test main function that injects a custom loader.""" if jax_mode and jax_enable_x64: - from jax.config import config # pylint: disable=g-import-not-at-top + from jax import config # pylint: disable=g-import-not-at-top config.update('jax_enable_x64', True) # This logic is borrowed from TensorFlow. diff --git a/tensorflow_probability/python/internal/tf_keras.py b/tensorflow_probability/python/internal/tf_keras.py new file mode 100644 index 0000000000..5f1cdf4cff --- /dev/null +++ b/tensorflow_probability/python/internal/tf_keras.py @@ -0,0 +1,38 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Utility for importing the correct version of Keras.""" + +import tensorflow.compat.v2 as tf + +# pylint: disable=g-bad-import-order +# pylint: disable=g-import-not-at-top +# pylint: disable=unused-import +# pylint: disable=wildcard-import +_keras_version_fn = getattr(tf.keras, "version", None) +if _keras_version_fn and _keras_version_fn().startswith("3."): + from tf_keras import * + from tf_keras import __internal__ + import tf_keras.api._v1.keras.__internal__.legacy.layers as tf1_layers + import tf_keras.api._v1.keras as v1 +else: + from tensorflow.compat.v2.keras import * + from tensorflow.compat.v2.keras import __internal__ + import tensorflow.compat.v1 as tf1 + v1 = tf1.keras + tf1_layers = tf1.layers + del tf1 + +del tf +del _keras_version_fn diff --git a/tensorflow_probability/python/internal/trainable_state_util_test.py b/tensorflow_probability/python/internal/trainable_state_util_test.py index aeb374c037..47bcfea474 100644 --- a/tensorflow_probability/python/internal/trainable_state_util_test.py +++ b/tensorflow_probability/python/internal/trainable_state_util_test.py @@ -33,6 +33,7 @@ from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import samplers from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.internal import trainable_state_util from tensorflow_probability.python.math import gradient from tensorflow_probability.python.math.minimize import minimize @@ -347,7 +348,7 @@ def test_fitting_example(self): trainable_dist = build_trainable_normal( shape=[], seed=test_util.test_seed(sampler_type='stateless')) - optimizer = tf.optimizers.Adam(1.0) + optimizer = tf_keras.optimizers.Adam(1.0) # Find the maximum 
likelihood distribution given observed data. x_observed = [3., -2., 1.7] losses = minimize( diff --git a/tensorflow_probability/python/internal/vectorization_util.py b/tensorflow_probability/python/internal/vectorization_util.py index efbfa15d32..ce5176fda9 100644 --- a/tensorflow_probability/python/internal/vectorization_util.py +++ b/tensorflow_probability/python/internal/vectorization_util.py @@ -24,7 +24,7 @@ from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import samplers from tensorflow_probability.python.util import SeedStream -from tensorflow.python.ops import parallel_for # pylint: disable=g-direct-tensorflow-import +from tensorflow.python.ops.parallel_for import control_flow_ops # pylint: disable=g-direct-tensorflow-import from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import __all__ = [ @@ -99,7 +99,7 @@ def pfor_loop_body(i): if static_n == 1: draws = pfor_loop_body(0) else: - draws = parallel_for.pfor(pfor_loop_body, n) + draws = control_flow_ops.pfor(pfor_loop_body, n) return tf.nest.map_structure(unflatten, draws, expand_composites=True) return iid_sample_fn diff --git a/tensorflow_probability/python/layers/BUILD b/tensorflow_probability/python/layers/BUILD index 7508b11693..ad7677b477 100644 --- a/tensorflow_probability/python/layers/BUILD +++ b/tensorflow_probability/python/layers/BUILD @@ -54,6 +54,7 @@ py_library( "//tensorflow_probability/python/distributions:kullback_leibler", "//tensorflow_probability/python/distributions:normal", "//tensorflow_probability/python/internal:docstring_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/random", "//tensorflow_probability/python/util:seed_stream", ], @@ -67,12 +68,12 @@ py_test( deps = [ ":conv_variational", ":util", - # keras/testing_infra:test_utils dep, # numpy dep, # tensorflow dep, "//tensorflow_probability/python/distributions:independent", "//tensorflow_probability/python/distributions:normal", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/random:random_ops", "//tensorflow_probability/python/util:seed_stream", ], @@ -90,6 +91,7 @@ py_library( "//tensorflow_probability/python/distributions:kullback_leibler", "//tensorflow_probability/python/distributions:normal", "//tensorflow_probability/python/internal:docstring_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/random", "//tensorflow_probability/python/util", ], @@ -102,12 +104,12 @@ py_test( deps = [ ":dense_variational", ":util", - # keras/testing_infra:test_utils dep, # numpy dep, # tensorflow dep, "//tensorflow_probability/python/distributions:independent", "//tensorflow_probability/python/distributions:normal", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/util:seed_stream", ], ) @@ -120,6 +122,7 @@ py_library( deps = [ # tensorflow dep, "//tensorflow_probability/python/distributions:kullback_leibler", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/util:seed_stream", ], ) @@ -138,6 +141,7 @@ py_test( "//tensorflow_probability/python/distributions:independent", "//tensorflow_probability/python/distributions:normal", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", ], ) @@ -162,6 +166,7 @@ 
py_library( "//tensorflow_probability/python/distributions:poisson", "//tensorflow_probability/python/distributions:transformed_distribution", "//tensorflow_probability/python/distributions:variational_gaussian_process", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/layers/internal", ], ) @@ -193,6 +198,7 @@ py_test( "//tensorflow_probability/python/distributions:poisson", "//tensorflow_probability/python/distributions:uniform", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/math:generic", "//tensorflow_probability/python/math/psd_kernels:exponentiated_quadratic", "//tensorflow_probability/python/util:deferred_tensor", @@ -206,6 +212,7 @@ py_library( ], deps = [ # tensorflow dep, + "//tensorflow_probability/python/internal:tf_keras", ], ) @@ -215,8 +222,8 @@ py_test( srcs = ["initializers_test.py"], deps = [ ":initializers", - # tensorflow dep, "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", ], ) @@ -227,6 +234,7 @@ py_library( # tensorflow dep, "//tensorflow_probability/python/bijectors:masked_autoregressive", "//tensorflow_probability/python/distributions:transformed_distribution", + "//tensorflow_probability/python/internal:tf_keras", ], ) @@ -242,6 +250,7 @@ py_test( "//tensorflow_probability/python/bijectors:masked_autoregressive", "//tensorflow_probability/python/distributions:mvn_diag", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", ], ) @@ -251,12 +260,12 @@ py_library( "util.py", ], deps = [ - # keras dep, # numpy dep, # tensorflow dep, "//tensorflow_probability/python/distributions:deterministic", "//tensorflow_probability/python/distributions:independent", "//tensorflow_probability/python/distributions:normal", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/util", ], ) @@ -268,6 +277,7 @@ py_library( ], deps = [ # tensorflow dep, + "//tensorflow_probability/python/internal:tf_keras", ], ) @@ -282,6 +292,7 @@ py_test( "//tensorflow_probability/python/distributions:independent", "//tensorflow_probability/python/distributions:normal", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", ], ) @@ -292,6 +303,7 @@ py_library( ], deps = [ # tensorflow dep, + "//tensorflow_probability/python/internal:tf_keras", ], ) @@ -304,6 +316,7 @@ py_test( # numpy dep, # tensorflow dep, "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/layers:weight_norm", ], ) diff --git a/tensorflow_probability/python/layers/conv_variational.py b/tensorflow_probability/python/layers/conv_variational.py index 88e9fc1c8e..2003f96118 100644 --- a/tensorflow_probability/python/layers/conv_variational.py +++ b/tensorflow_probability/python/layers/conv_variational.py @@ -21,9 +21,9 @@ from tensorflow_probability.python.distributions import kullback_leibler as kl_lib from tensorflow_probability.python.distributions import normal as normal_lib from tensorflow_probability.python.internal import docstring_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.layers import util as tfp_layers_util from tensorflow_probability.python.util.seed_stream import SeedStream -from tensorflow.python.layers import utils as tf_layers_util # pylint: 
disable=g-direct-tensorflow-import from tensorflow.python.ops import nn_ops # pylint: disable=g-direct-tensorflow-import @@ -74,7 +74,7 @@ sample is a `Tensor`.""" -class _ConvVariational(tf.keras.layers.Layer): +class _ConvVariational(tf_keras.layers.Layer): """Abstract nD convolution layer (private, used as implementation base). This layer creates a convolution kernel that is convolved @@ -149,15 +149,15 @@ def __init__( **kwargs) self.rank = rank self.filters = filters - self.kernel_size = tf_layers_util.normalize_tuple( + self.kernel_size = normalize_tuple( kernel_size, rank, 'kernel_size') - self.strides = tf_layers_util.normalize_tuple(strides, rank, 'strides') - self.padding = tf_layers_util.normalize_padding(padding) - self.data_format = tf_layers_util.normalize_data_format(data_format) - self.dilation_rate = tf_layers_util.normalize_tuple( + self.strides = normalize_tuple(strides, rank, 'strides') + self.padding = normalize_padding(padding) + self.data_format = normalize_data_format(data_format) + self.dilation_rate = normalize_tuple( dilation_rate, rank, 'dilation_rate') - self.activation = tf.keras.activations.get(activation) - self.input_spec = tf.keras.layers.InputSpec(ndim=self.rank + 2) + self.activation = tf_keras.activations.get(activation) + self.input_spec = tf_keras.layers.InputSpec(ndim=self.rank + 2) self.kernel_posterior_fn = kernel_posterior_fn self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn self.kernel_prior_fn = kernel_prior_fn @@ -180,7 +180,7 @@ def build(self, input_shape): kernel_shape = self.kernel_size + (input_dim, self.filters) # If self.dtype is None, build weights using the default dtype. - dtype = tf.as_dtype(self.dtype or tf.keras.backend.floatx()) + dtype = tf.as_dtype(self.dtype or tf_keras.backend.floatx()) # Must have a posterior kernel. 
self.kernel_posterior = self.kernel_posterior_fn( @@ -208,7 +208,7 @@ def build(self, input_shape): dtype, (self.filters,), 'bias_prior', self.trainable, self.add_variable) - self.input_spec = tf.keras.layers.InputSpec( + self.input_spec = tf_keras.layers.InputSpec( ndim=self.rank + 2, axes={channel_axis: input_dim}) self._convolution_op = nn_ops.Convolution( input_shape, @@ -216,7 +216,7 @@ def build(self, input_shape): dilation_rate=self.dilation_rate, strides=self.strides, padding=self.padding.upper(), - data_format=tf_layers_util.convert_data_format( + data_format=convert_data_format( self.data_format, self.rank + 2)) self.built = True @@ -256,7 +256,7 @@ def compute_output_shape(self, input_shape): space = input_shape[1:-1] new_space = [] for i in range(len(space)): - new_dim = tf_layers_util.conv_output_length( + new_dim = conv_output_length( space[i], self.kernel_size[i], padding=self.padding, @@ -268,7 +268,7 @@ def compute_output_shape(self, input_shape): space = input_shape[2:] new_space = [] for i in range(len(space)): - new_dim = tf_layers_util.conv_output_length( + new_dim = conv_output_length( space[i], self.kernel_size[i], padding=self.padding, @@ -295,10 +295,10 @@ def get_config(self): 'padding': self.padding, 'data_format': self.data_format, 'dilation_rate': self.dilation_rate, - 'activation': (tf.keras.activations.serialize(self.activation) + 'activation': (tf_keras.activations.serialize(self.activation) if self.activation else None), 'activity_regularizer': - tf.keras.initializers.serialize(self.activity_regularizer), + tf_keras.initializers.serialize(self.activity_regularizer), } function_keys = [ 'kernel_posterior_fn', @@ -491,7 +491,7 @@ def __init__( padding=padding, data_format=data_format, dilation_rate=dilation_rate, - activation=tf.keras.activations.get(activation), + activation=tf_keras.activations.get(activation), activity_regularizer=activity_regularizer, kernel_posterior_fn=kernel_posterior_fn, kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, @@ -555,11 +555,11 @@ class Conv1DReparameterization(_ConvReparameterization): import tensorflow as tf import tensorflow_probability as tfp - model = tf.keras.Sequential([ - tf.keras.layers.Reshape([128, 1]), + model = tf_keras.Sequential([ + tf_keras.layers.Reshape([128, 1]), tfp.layers.Convolution1DReparameterization( 64, kernel_size=5, padding='SAME', activation=tf.nn.relu), - tf.keras.layers.Flatten(), + tf_keras.layers.Flatten(), tfp.layers.DenseReparameterization(10), ]) @@ -639,7 +639,7 @@ def __init__( padding=padding, data_format=data_format, dilation_rate=dilation_rate, - activation=tf.keras.activations.get(activation), + activation=tf_keras.activations.get(activation), activity_regularizer=activity_regularizer, kernel_posterior_fn=kernel_posterior_fn, kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, @@ -695,14 +695,14 @@ class Conv2DReparameterization(_ConvReparameterization): import tensorflow as tf import tensorflow_probability as tfp - model = tf.keras.Sequential([ - tf.keras.layers.Reshape([32, 32, 3]), + model = tf_keras.Sequential([ + tf_keras.layers.Reshape([32, 32, 3]), tfp.layers.Convolution2DReparameterization( 64, kernel_size=5, padding='SAME', activation=tf.nn.relu), - tf.keras.layers.MaxPooling2D(pool_size=[2, 2], + tf_keras.layers.MaxPooling2D(pool_size=[2, 2], strides=[2, 2], padding='SAME'), - tf.keras.layers.Flatten(), + tf_keras.layers.Flatten(), tfp.layers.DenseReparameterization(10), ]) @@ -788,7 +788,7 @@ def __init__( padding=padding, data_format=data_format, 
dilation_rate=dilation_rate, - activation=tf.keras.activations.get(activation), + activation=tf_keras.activations.get(activation), activity_regularizer=activity_regularizer, kernel_posterior_fn=kernel_posterior_fn, kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, @@ -840,14 +840,14 @@ class Conv3DReparameterization(_ConvReparameterization): import tensorflow as tf import tensorflow_probability as tfp - model = tf.keras.Sequential([ - tf.keras.layers.Reshape([256, 32, 32, 3]), + model = tf_keras.Sequential([ + tf_keras.layers.Reshape([256, 32, 32, 3]), tfp.layers.Convolution3DReparameterization( 64, kernel_size=5, padding='SAME', activation=tf.nn.relu), - tf.keras.layers.MaxPooling3D(pool_size=[2, 2, 2], + tf_keras.layers.MaxPooling3D(pool_size=[2, 2, 2], strides=[2, 2, 2], padding='SAME'), - tf.keras.layers.Flatten(), + tf_keras.layers.Flatten(), tfp.layers.DenseReparameterization(10), ]) @@ -934,7 +934,7 @@ def __init__( padding=padding, data_format=data_format, dilation_rate=dilation_rate, - activation=tf.keras.activations.get(activation), + activation=tf_keras.activations.get(activation), activity_regularizer=activity_regularizer, kernel_posterior_fn=kernel_posterior_fn, kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, @@ -1039,7 +1039,7 @@ def __init__( padding=padding, data_format=data_format, dilation_rate=dilation_rate, - activation=tf.keras.activations.get(activation), + activation=tf_keras.activations.get(activation), activity_regularizer=activity_regularizer, kernel_posterior_fn=kernel_posterior_fn, kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, @@ -1166,11 +1166,11 @@ class Conv1DFlipout(_ConvFlipout): import tensorflow as tf import tensorflow_probability as tfp - model = tf.keras.Sequential([ - tf.keras.layers.Reshape([128, 1]), + model = tf_keras.Sequential([ + tf_keras.layers.Reshape([128, 1]), tfp.layers.Convolution1DFlipout( 64, kernel_size=5, padding='SAME', activation=tf.nn.relu), - tf.keras.layers.Flatten(), + tf_keras.layers.Flatten(), tfp.layers.DenseFlipout(10), ]) @@ -1254,7 +1254,7 @@ def __init__( padding=padding, data_format=data_format, dilation_rate=dilation_rate, - activation=tf.keras.activations.get(activation), + activation=tf_keras.activations.get(activation), activity_regularizer=activity_regularizer, kernel_posterior_fn=kernel_posterior_fn, kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, @@ -1309,14 +1309,14 @@ class Conv2DFlipout(_ConvFlipout): import tensorflow as tf import tensorflow_probability as tfp - model = tf.keras.Sequential([ - tf.keras.layers.Reshape([32, 32, 3]), + model = tf_keras.Sequential([ + tf_keras.layers.Reshape([32, 32, 3]), tfp.layers.Convolution2DFlipout( 64, kernel_size=5, padding='SAME', activation=tf.nn.relu), - tf.keras.layers.MaxPooling2D(pool_size=[2, 2], + tf_keras.layers.MaxPooling2D(pool_size=[2, 2], strides=[2, 2], padding='SAME'), - tf.keras.layers.Flatten(), + tf_keras.layers.Flatten(), tfp.layers.DenseFlipout(10), ]) @@ -1406,7 +1406,7 @@ def __init__( padding=padding, data_format=data_format, dilation_rate=dilation_rate, - activation=tf.keras.activations.get(activation), + activation=tf_keras.activations.get(activation), activity_regularizer=activity_regularizer, kernel_posterior_fn=kernel_posterior_fn, kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, @@ -1461,14 +1461,14 @@ class Conv3DFlipout(_ConvFlipout): import tensorflow as tf import tensorflow_probability as tfp - model = tf.keras.Sequential([ - tf.keras.layers.Reshape([256, 32, 32, 3]), + model = tf_keras.Sequential([ + 
tf_keras.layers.Reshape([256, 32, 32, 3]), tfp.layers.Convolution3DFlipout( 64, kernel_size=5, padding='SAME', activation=tf.nn.relu), - tf.keras.layers.MaxPooling3D(pool_size=[2, 2, 2], + tf_keras.layers.MaxPooling3D(pool_size=[2, 2, 2], strides=[2, 2, 2], padding='SAME'), - tf.keras.layers.Flatten(), + tf_keras.layers.Flatten(), tfp.layers.DenseFlipout(10), ]) @@ -1559,7 +1559,7 @@ def __init__( padding=padding, data_format=data_format, dilation_rate=dilation_rate, - activation=tf.keras.activations.get(activation), + activation=tf_keras.activations.get(activation), activity_regularizer=activity_regularizer, kernel_posterior_fn=kernel_posterior_fn, kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, @@ -1581,3 +1581,113 @@ def __init__( Convolution1DFlipout = Conv1DFlipout Convolution2DFlipout = Conv2DFlipout Convolution3DFlipout = Conv3DFlipout + + +def convert_data_format(data_format, ndim): # pylint: disable=missing-function-docstring + if data_format == 'channels_last': + if ndim == 3: + return 'NWC' + elif ndim == 4: + return 'NHWC' + elif ndim == 5: + return 'NDHWC' + else: + raise ValueError(f'Input rank: {ndim} not supported. We only support ' + 'input rank 3, 4 or 5.') + elif data_format == 'channels_first': + if ndim == 3: + return 'NCW' + elif ndim == 4: + return 'NCHW' + elif ndim == 5: + return 'NCDHW' + else: + raise ValueError(f'Input rank: {ndim} not supported. We only support ' + 'input rank 3, 4 or 5.') + else: + raise ValueError(f'Invalid data_format: {data_format}. We only support ' + '"channels_first" or "channels_last"') + + +def normalize_tuple(value, n, name): + """Transforms a single integer or iterable of integers into an integer tuple. + + Args: + value: The value to validate and convert. Could an int, or any iterable + of ints. + n: The size of the tuple to be returned. + name: The name of the argument being validated, e.g. "strides" or + "kernel_size". This is only used to format error messages. + + Returns: + A tuple of n integers. + + Raises: + ValueError: If something else than an int/long or iterable thereof was + passed. + """ + if isinstance(value, int): + return (value,) * n + else: + try: + value_tuple = tuple(value) + except TypeError: + raise ValueError(f'Argument `{name}` must be a tuple of {str(n)} ' + f'integers. Received: {str(value)}') from None + if len(value_tuple) != n: + raise ValueError(f'Argument `{name}` must be a tuple of {str(n)} ' + f'integers. Received: {str(value)}') + for single_value in value_tuple: + try: + int(single_value) + except (ValueError, TypeError): + raise ValueError(f'Argument `{name}` must be a tuple of {str(n)} ' + f'integers. Received: {str(value)} including element ' + f'{str(single_value)} of type ' + f'{str(type(single_value))}') from None + return value_tuple + + +def normalize_data_format(value): + data_format = value.lower() + if data_format not in {'channels_first', 'channels_last'}: + raise ValueError('The `data_format` argument must be one of ' + '"channels_first", "channels_last". Received: ' + f'{str(value)}.') + return data_format + + +def normalize_padding(value): + padding = value.lower() + if padding not in {'valid', 'same'}: + raise ValueError('The `padding` argument must be one of "valid", "same". ' + f'Received: {str(padding)}.') + return padding + + +def conv_output_length(input_length, filter_size, padding, stride, dilation=1): + """Determines output length of a convolution given input length. + + Args: + input_length: integer. + filter_size: integer. + padding: one of "same", "valid", "full". 
+ stride: integer. + dilation: dilation rate, integer. + + Returns: + The output length (integer). + """ + if input_length is None: + return None + assert padding in {'same', 'valid', 'full'} + dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1) + if padding == 'same': + output_length = input_length + elif padding == 'valid': + output_length = input_length - dilated_filter_size + 1 + elif padding == 'full': + output_length = input_length + dilated_filter_size - 1 + else: + raise ValueError(f'Invalid padding: {padding}') + return (output_length + stride - 1) // stride diff --git a/tensorflow_probability/python/layers/conv_variational_test.py b/tensorflow_probability/python/layers/conv_variational_test.py index d842f45c12..3822257aa2 100644 --- a/tensorflow_probability/python/layers/conv_variational_test.py +++ b/tensorflow_probability/python/layers/conv_variational_test.py @@ -26,11 +26,11 @@ from tensorflow_probability.python.distributions import independent from tensorflow_probability.python.distributions import normal from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.layers import conv_variational from tensorflow_probability.python.layers import util from tensorflow_probability.python.random import random_ops from tensorflow_probability.python.util import seed_stream -from tensorflow.python.layers import utils as tf_layers_util from tensorflow.python.ops import nn_ops @@ -217,7 +217,7 @@ def kernel_posterior_fn(dtype, shape, name, trainable, add_variable_fn): if self.data_format == 'channels_first': input_shape = channels_last_to_first(input_shape) - with tf.keras.utils.CustomObjectScope({layer_class.__name__: layer_class}): + with tf_keras.utils.CustomObjectScope({layer_class.__name__: layer_class}): with self.cached_session(): # TODO(scottzhu): reenable the test when the repo switch change reach # the TF PIP package. 
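Editor's note: the Keras-independent helpers added to `conv_variational.py` above stand in for the removed `tensorflow.python.layers.utils` import. The following is a hypothetical usage sketch, assuming the patch is applied; it is illustration only and not part of the patch.

```python
# Assumes the patched module is importable; values follow the helper
# definitions added above.
from tensorflow_probability.python.layers import conv_variational as cv

assert cv.normalize_tuple(3, 2, 'kernel_size') == (3, 3)      # int broadcast to a 2-tuple
assert cv.normalize_padding('SAME') == 'same'                 # lower-cased and validated
assert cv.convert_data_format('channels_last', 4) == 'NHWC'
# 'same' padding with stride 2 halves the length, rounding up: (9 + 2 - 1) // 2 == 5.
assert cv.conv_output_length(input_length=9, filter_size=3,
                             padding='same', stride=2) == 5
```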
@@ -369,13 +369,13 @@ def _testConvReparameterization(self, layer_class): # pylint: disable=invalid-n tf.TensorShape(inputs.shape), filter_shape=tf.TensorShape(kernel_shape), padding='SAME', - data_format=tf_layers_util.convert_data_format( + data_format=conv_variational.convert_data_format( self.data_format, inputs.shape.rank)) expected_outputs = convolution_op(inputs, kernel_posterior.result_sample) expected_outputs = tf.nn.bias_add( expected_outputs, bias_posterior.result_sample, - data_format=tf_layers_util.convert_data_format(self.data_format, 4)) + data_format=conv_variational.convert_data_format(self.data_format, 4)) [ expected_outputs_, actual_outputs_, @@ -435,7 +435,7 @@ def _testConvFlipout(self, layer_class): # pylint: disable=invalid-name tf.TensorShape(inputs.shape), filter_shape=tf.TensorShape(kernel_shape), padding='SAME', - data_format=tf_layers_util.convert_data_format( + data_format=conv_variational.convert_data_format( self.data_format, inputs.shape.rank)) expected_kernel_posterior_affine = normal.Normal( @@ -483,7 +483,7 @@ def _testConvFlipout(self, layer_class): # pylint: disable=invalid-name expected_outputs = tf.nn.bias_add( expected_outputs, bias_posterior.result_sample, - data_format=tf_layers_util.convert_data_format(self.data_format, 4)) + data_format=conv_variational.convert_data_format(self.data_format, 4)) [ expected_outputs_, actual_outputs_, @@ -607,7 +607,7 @@ def _testLayerInSequential(self, layer_class): # pylint: disable=invalid-name inputs = self.maybe_transpose_tensor(inputs) outputs = self.maybe_transpose_tensor(outputs) - net = tf.keras.Sequential([ + net = tf_keras.Sequential([ layer_class(filters=2, kernel_size=3, data_format=self.data_format, input_shape=inputs.shape[1:]), layer_class(filters=2, kernel_size=1, data_format=self.data_format)]) @@ -718,7 +718,7 @@ def testSequentialConvolution3DFlipout(self): self._testLayerInSequential(conv_variational.Convolution3DFlipout) def testGradients(self): - net = tf.keras.Sequential([ + net = tf_keras.Sequential([ conv_variational.Convolution1DFlipout( 1, 1, data_format=self.data_format), conv_variational.Convolution1DReparameterization( diff --git a/tensorflow_probability/python/layers/dense_variational.py b/tensorflow_probability/python/layers/dense_variational.py index 2f842016b9..c58ce88061 100644 --- a/tensorflow_probability/python/layers/dense_variational.py +++ b/tensorflow_probability/python/layers/dense_variational.py @@ -21,6 +21,7 @@ from tensorflow_probability.python.distributions import kullback_leibler as kl_lib from tensorflow_probability.python.distributions import normal as normal_lib from tensorflow_probability.python.internal import docstring_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.layers import util as tfp_layers_util from tensorflow_probability.python.util import SeedStream @@ -70,7 +71,7 @@ sample is a `Tensor`.""" -class _DenseVariational(tf.keras.layers.Layer): +class _DenseVariational(tf_keras.layers.Layer): """Abstract densely-connected class (private, used as implementation base). 
This layer implements the Bayesian variational inference analogue to @@ -115,8 +116,8 @@ def __init__( activity_regularizer=activity_regularizer, **kwargs) self.units = units - self.activation = tf.keras.activations.get(activation) - self.input_spec = tf.keras.layers.InputSpec(min_ndim=2) + self.activation = tf_keras.activations.get(activation) + self.input_spec = tf_keras.layers.InputSpec(min_ndim=2) self.kernel_posterior_fn = kernel_posterior_fn self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn self.kernel_prior_fn = kernel_prior_fn @@ -132,10 +133,10 @@ def build(self, input_shape): if in_size is None: raise ValueError('The last dimension of the inputs to `Dense` ' 'should be defined. Found `None`.') - self._input_spec = tf.keras.layers.InputSpec(min_ndim=2, axes={-1: in_size}) + self._input_spec = tf_keras.layers.InputSpec(min_ndim=2, axes={-1: in_size}) # If self.dtype is None, build weights using the default dtype. - dtype = tf.as_dtype(self.dtype or tf.keras.backend.floatx()) + dtype = tf.as_dtype(self.dtype or tf_keras.backend.floatx()) # Must have a posterior kernel. self.kernel_posterior = self.kernel_posterior_fn( @@ -221,10 +222,10 @@ def get_config(self): """ config = { 'units': self.units, - 'activation': (tf.keras.activations.serialize(self.activation) + 'activation': (tf_keras.activations.serialize(self.activation) if self.activation else None), 'activity_regularizer': - tf.keras.initializers.serialize(self.activity_regularizer), + tf_keras.initializers.serialize(self.activity_regularizer), } function_keys = [ 'kernel_posterior_fn', @@ -346,7 +347,7 @@ class DenseReparameterization(_DenseVariational): import tensorflow as tf import tensorflow_probability as tfp - model = tf.keras.Sequential([ + model = tf_keras.Sequential([ tfp.layers.DenseReparameterization(512, activation=tf.nn.relu), tfp.layers.DenseReparameterization(10), ]) @@ -465,7 +466,7 @@ class DenseLocalReparameterization(_DenseVariational): ```python import tensorflow_probability as tfp - model = tf.keras.Sequential([ + model = tf_keras.Sequential([ tfp.layers.DenseLocalReparameterization(512, activation=tf.nn.relu), tfp.layers.DenseLocalReparameterization(10), ]) @@ -592,7 +593,7 @@ class DenseFlipout(_DenseVariational): ```python import tensorflow_probability as tfp - model = tf.keras.Sequential([ + model = tf_keras.Sequential([ tfp.layers.DenseFlipout(512, activation=tf.nn.relu), tfp.layers.DenseFlipout(10), ]) diff --git a/tensorflow_probability/python/layers/dense_variational_test.py b/tensorflow_probability/python/layers/dense_variational_test.py index 33b53423bf..7f06b1ade5 100644 --- a/tensorflow_probability/python/layers/dense_variational_test.py +++ b/tensorflow_probability/python/layers/dense_variational_test.py @@ -25,6 +25,7 @@ from tensorflow_probability.python.distributions import independent from tensorflow_probability.python.distributions import normal from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.layers import dense_variational from tensorflow_probability.python.layers import util from tensorflow_probability.python.random import random_ops @@ -124,7 +125,7 @@ def kernel_posterior_fn(dtype, shape, name, trainable, add_variable_fn): 'kernel_prior_fn': None, 'bias_posterior_fn': None, 'bias_prior_fn': None} - with tf.keras.utils.CustomObjectScope({layer_class.__name__: layer_class}): + with tf_keras.utils.CustomObjectScope({layer_class.__name__: layer_class}): # TODO(scottzhu): 
reenable the test when the repo switch change reach # the TF PIP package. self.skipTest('Skip the test until the TF and Keras has a new PIP.') @@ -500,7 +501,7 @@ def testDenseLayersInSequential(self): y = np.random.uniform( -1., 1., size=(data_size, out_size)).astype(np.float32) - model = tf.keras.Sequential([ + model = tf_keras.Sequential([ dense_variational.DenseReparameterization(6, activation=tf.nn.relu), dense_variational.DenseFlipout(6, activation=tf.nn.relu), dense_variational.DenseLocalReparameterization(out_size) @@ -514,7 +515,7 @@ def testDenseLayersInSequential(self): self.assertAllEqual(batch_output.shape, [batch_size, out_size]) def testGradients(self): - net = tf.keras.Sequential([ + net = tf_keras.Sequential([ dense_variational.DenseReparameterization(1), dense_variational.DenseFlipout(1), dense_variational.DenseLocalReparameterization(1) diff --git a/tensorflow_probability/python/layers/dense_variational_v2.py b/tensorflow_probability/python/layers/dense_variational_v2.py index 9f8dd3ebcd..3f6bf70566 100644 --- a/tensorflow_probability/python/layers/dense_variational_v2.py +++ b/tensorflow_probability/python/layers/dense_variational_v2.py @@ -18,13 +18,15 @@ from tensorflow_probability.python.distributions import kullback_leibler +from tensorflow_probability.python.internal import tf_keras -class DenseVariational(tf.keras.layers.Layer): + +class DenseVariational(tf_keras.layers.Layer): """Dense layer with random `kernel` and `bias`. This layer uses variational inference to fit a "surrogate" posterior to the distribution over both the `kernel` matrix and the `bias` terms which are - otherwise used in a manner similar to `tf.keras.layers.Dense`. + otherwise used in a manner similar to `tf_keras.layers.Dense`. This layer fits the "weights posterior" according to the following generative process: @@ -67,12 +69,12 @@ def __init__(self, use_bias: Boolean, whether the layer uses a bias vector. activity_regularizer: Regularizer function applied to the output of the layer (its "activation").. - **kwargs: Extra arguments forwarded to `tf.keras.layers.Layer`. + **kwargs: Extra arguments forwarded to `tf_keras.layers.Layer`. """ if 'input_shape' not in kwargs and 'input_dim' in kwargs: kwargs['input_shape'] = (kwargs.pop('input_dim'),) super(DenseVariational, self).__init__( - activity_regularizer=tf.keras.regularizers.get(activity_regularizer), + activity_regularizer=tf_keras.regularizers.get(activity_regularizer), **kwargs) self.units = int(units) @@ -81,13 +83,13 @@ def __init__(self, self._kl_divergence_fn = _make_kl_divergence_penalty( kl_use_exact, weight=kl_weight) - self.activation = tf.keras.activations.get(activation) + self.activation = tf_keras.activations.get(activation) self.use_bias = use_bias self.supports_masking = False - self.input_spec = tf.keras.layers.InputSpec(min_ndim=2) + self.input_spec = tf_keras.layers.InputSpec(min_ndim=2) def build(self, input_shape): - dtype = tf.as_dtype(self.dtype or tf.keras.backend.floatx()) + dtype = tf.as_dtype(self.dtype or tf_keras.backend.floatx()) if not (dtype.is_floating or dtype.is_complex): raise TypeError('Unable to build `Dense` layer with non-floating point ' 'dtype %s' % (dtype,)) @@ -96,7 +98,7 @@ def build(self, input_shape): if last_dim is None: raise ValueError('The last dimension of the inputs to `DenseVariational` ' 'should be defined. 
Found `None`.') - self.input_spec = tf.keras.layers.InputSpec( + self.input_spec = tf_keras.layers.InputSpec( min_ndim=2, axes={-1: last_dim}) with tf.name_scope('posterior'): @@ -113,7 +115,7 @@ def build(self, input_shape): self.built = True def call(self, inputs): - dtype = tf.as_dtype(self.dtype or tf.keras.backend.floatx()) + dtype = tf.as_dtype(self.dtype or tf_keras.backend.floatx()) inputs = tf.cast(inputs, dtype, name='inputs') q = self._posterior(inputs) diff --git a/tensorflow_probability/python/layers/dense_variational_v2_test.py b/tensorflow_probability/python/layers/dense_variational_v2_test.py index aca410fc45..51c61d9fae 100644 --- a/tensorflow_probability/python/layers/dense_variational_v2_test.py +++ b/tensorflow_probability/python/layers/dense_variational_v2_test.py @@ -22,6 +22,7 @@ from tensorflow_probability.python.distributions import independent from tensorflow_probability.python.distributions import normal from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.layers import dense_variational_v2 from tensorflow_probability.python.layers import distribution_layer from tensorflow_probability.python.layers import variable_input @@ -51,7 +52,7 @@ def s(x): def posterior_mean_field(kernel_size, bias_size=0, dtype=None): n = kernel_size + bias_size c = np.log(np.expm1(1.)) - return tf.keras.Sequential([ + return tf_keras.Sequential([ variable_input.VariableLayer(2 * n, dtype=dtype), distribution_layer.DistributionLambda(lambda t: independent.Independent( # pylint: disable=g-long-lambda normal.Normal(loc=t[..., :n], @@ -62,7 +63,7 @@ def posterior_mean_field(kernel_size, bias_size=0, dtype=None): def prior_trainable(kernel_size, bias_size=0, dtype=None): n = kernel_size + bias_size - return tf.keras.Sequential([ + return tf_keras.Sequential([ variable_input.VariableLayer(n, dtype=dtype), distribution_layer.DistributionLambda( lambda t: independent.Independent(normal.Normal(loc=t, scale=1), # pylint: disable=g-long-lambda @@ -83,16 +84,16 @@ def test_end_to_end(self): layer = dense_variational_v2.DenseVariational(1, posterior_mean_field, prior_trainable) - model = tf.keras.Sequential([ + model = tf_keras.Sequential([ layer, distribution_layer.DistributionLambda( lambda t: normal.Normal(loc=t, scale=1)) ]) if tf.__internal__.tf2.enabled() and tf.executing_eagerly(): - optimizer = tf.keras.optimizers.Adam(learning_rate=0.05) + optimizer = tf_keras.optimizers.Adam(learning_rate=0.05) else: - optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=0.05) + optimizer = tf_keras.optimizers.legacy.Adam(learning_rate=0.05) # Do inference. model.compile(optimizer=optimizer, loss=negloglik) diff --git a/tensorflow_probability/python/layers/distribution_layer.py b/tensorflow_probability/python/layers/distribution_layer.py index 82777bbec5..638d15e61d 100644 --- a/tensorflow_probability/python/layers/distribution_layer.py +++ b/tensorflow_probability/python/layers/distribution_layer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================ -"""Layers for combining `tfp.distributions` and `tf.keras`.""" +"""Layers for combining `tfp.distributions` and `tf_keras`.""" import codecs import collections @@ -43,6 +43,7 @@ from tensorflow_probability.python.distributions import transformed_distribution as transformed_distribution_lib from tensorflow_probability.python.distributions import variational_gaussian_process as variational_gaussian_process_lib from tensorflow_probability.python.internal import distribution_util as dist_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.layers.internal import distribution_tensor_coercible as dtc from tensorflow_probability.python.layers.internal import tensor_tuple @@ -65,7 +66,7 @@ ] -tf.keras.__internal__.utils.register_symbolic_tensor_type(dtc._TensorCoercible) # pylint: disable=protected-access +tf_keras.__internal__.utils.register_symbolic_tensor_type(dtc._TensorCoercible) # pylint: disable=protected-access def _event_size(event_shape, name=None): @@ -92,7 +93,7 @@ def _event_size(event_shape, name=None): return tf.reduce_prod(event_shape) -class DistributionLambda(tf.keras.layers.Lambda): +class DistributionLambda(tf_keras.layers.Lambda): """Keras layer enabling plumbing TFP distributions through Keras models. A `DistributionLambda` is minimially characterized by a function that returns @@ -108,8 +109,8 @@ class DistributionLambda(tf.keras.layers.Lambda): #### Examples ```python - tfk = tf.keras - tfkl = tf.keras.layers + tfk = tf_keras + tfkl = tf_keras.layers tfd = tfp.distributions tfpl = tfp.layers @@ -139,7 +140,7 @@ def __init__(self, instance and returns a `tf.Tensor`-like object. For examples, see `class` docstring. Default value: `tfd.Distribution.sample`. - **kwargs: Additional keyword arguments passed to `tf.keras.Layer`. + **kwargs: Additional keyword arguments passed to `tf_keras.Layer`. """ # TODO(b/120440642): See if something like this code block is needed. # if output_shape is None: @@ -298,8 +299,8 @@ class MultivariateNormalTriL(DistributionLambda): #### Example ```python - tfk = tf.keras - tfkl = tf.keras.layers + tfk = tf_keras + tfkl = tf_keras.layers tfd = tfp.distributions tfpl = tfp.layers @@ -355,7 +356,7 @@ def __init__(self, performance. When `False` invalid inputs may silently render incorrect outputs. Default value: `False`. - **kwargs: Additional keyword arguments passed to `tf.keras.Layer`. + **kwargs: Additional keyword arguments passed to `tf_keras.Layer`. """ super(MultivariateNormalTriL, self).__init__( lambda t: MultivariateNormalTriL.new(t, event_size, validate_args), @@ -396,8 +397,8 @@ class OneHotCategorical(DistributionLambda): #### Example ```python - tfk = tf.keras - tfkl = tf.keras.layers + tfk = tf_keras + tfkl = tf_keras.layers tfd = tfp.distributions tfpl = tfp.layers @@ -459,7 +460,7 @@ def __init__(self, performance. When `False` invalid inputs may silently render incorrect outputs. Default value: `False`. - **kwargs: Additional keyword arguments passed to `tf.keras.Layer`. + **kwargs: Additional keyword arguments passed to `tf_keras.Layer`. 
""" super(OneHotCategorical, self).__init__( lambda t: OneHotCategorical.new( # pylint: disable=g-long-lambda @@ -500,8 +501,8 @@ class CategoricalMixtureOfOneHotCategorical(DistributionLambda): #### Example ```python - tfk = tf.keras - tfkl = tf.keras.layers + tfk = tf_keras + tfkl = tf_keras.layers tfd = tfp.distributions tfpl = tfp.layers @@ -564,7 +565,7 @@ def __init__(self, performance. When `False` invalid inputs may silently render incorrect outputs. Default value: `False`. - **kwargs: Additional keyword arguments passed to `tf.keras.Layer`. + **kwargs: Additional keyword arguments passed to `tf_keras.Layer`. """ super(CategoricalMixtureOfOneHotCategorical, self).__init__( # pylint: disable=g-long-lambda @@ -622,8 +623,8 @@ class IndependentBernoulli(DistributionLambda): #### Example ```python - tfk = tf.keras - tfkl = tf.keras.layers + tfk = tf_keras + tfkl = tf_keras.layers tfd = tfp.distributions tfpl = tfp.layers @@ -685,7 +686,7 @@ def __init__(self, performance. When `False` invalid inputs may silently render incorrect outputs. Default value: `False`. - **kwargs: Additional keyword arguments passed to `tf.keras.Layer`. + **kwargs: Additional keyword arguments passed to `tf_keras.Layer`. """ convert_to_tensor_fn = _get_convert_to_tensor_fn(convert_to_tensor_fn) @@ -788,8 +789,8 @@ class IndependentLogistic(DistributionLambda): ```python tfd = tfp.distributions tfpl = tfp.layers - tfk = tf.keras - tfkl = tf.keras.layers + tfk = tf_keras + tfkl = tf_keras.layers # Create a stochastic encoder -- e.g., for use in a variational auto-encoder. input_shape = [28, 28, 1] @@ -823,7 +824,7 @@ def __init__(self, performance. When `False` invalid inputs may silently render incorrect outputs. Default value: `False`. - **kwargs: Additional keyword arguments passed to `tf.keras.Layer`. + **kwargs: Additional keyword arguments passed to `tf_keras.Layer`. """ convert_to_tensor_fn = _get_convert_to_tensor_fn(convert_to_tensor_fn) @@ -903,8 +904,8 @@ class IndependentNormal(DistributionLambda): ```python tfd = tfp.distributions tfpl = tfp.layers - tfk = tf.keras - tfkl = tf.keras.layers + tfk = tf_keras + tfkl = tf_keras.layers # Create a stochastic encoder -- e.g., for use in a variational auto-encoder. input_shape = [28, 28, 1] @@ -938,7 +939,7 @@ def __init__(self, performance. When `False` invalid inputs may silently render incorrect outputs. Default value: `False`. - **kwargs: Additional keyword arguments passed to `tf.keras.Layer`. + **kwargs: Additional keyword arguments passed to `tf_keras.Layer`. """ convert_to_tensor_fn = _get_convert_to_tensor_fn(convert_to_tensor_fn) @@ -1018,8 +1019,8 @@ class IndependentPoisson(DistributionLambda): ```python tfd = tfp.distributions tfpl = tfp.layers - tfk = tf.keras - tfkl = tf.keras.layers + tfk = tf_keras + tfkl = tf_keras.layers # Create example data. n = 2000 @@ -1069,7 +1070,7 @@ def __init__(self, performance. When `False` invalid inputs may silently render incorrect outputs. Default value: `False`. - **kwargs: Additional keyword arguments passed to `tf.keras.Layer`. + **kwargs: Additional keyword arguments passed to `tf_keras.Layer`. """ convert_to_tensor_fn = _get_convert_to_tensor_fn(convert_to_tensor_fn) @@ -1141,7 +1142,7 @@ def get_config(self): # We mix-in `tf.Module` since Keras `Regularizer` base class tracks neither # tf.Variables nor tf.Modules. 
-class KLDivergenceRegularizer(tf.keras.regularizers.Regularizer, tf.Module): +class KLDivergenceRegularizer(tf_keras.regularizers.Regularizer, tf.Module): """Regularizer that adds a KL divergence penalty to the model loss. When using Monte Carlo approximation (e.g., `use_exact=False`), it is presumed @@ -1154,8 +1155,8 @@ class KLDivergenceRegularizer(tf.keras.regularizers.Regularizer, tf.Module): ```python tfd = tfp.distributions tfpl = tfp.layers - tfk = tf.keras - tfkl = tf.keras.layers + tfk = tf_keras + tfkl = tf_keras.layers # Create a variational encoder and add a KL Divergence penalty to the # loss that encourages marginal coherence with a unit-MVN (the "prior"). @@ -1251,7 +1252,7 @@ def __call__(self, distribution_a): return self._kl_divergence_fn(distribution_a) -class KLDivergenceAddLoss(tf.keras.layers.Layer): +class KLDivergenceAddLoss(tf_keras.layers.Layer): """Pass-through layer that adds a KL divergence penalty to the model loss. When using Monte Carlo approximation (e.g., `use_exact=False`), it is presumed @@ -1264,8 +1265,8 @@ class KLDivergenceAddLoss(tf.keras.layers.Layer): ```python tfd = tfp.distributions tfpl = tfp.layers - tfk = tf.keras - tfkl = tf.keras.layers + tfk = tf_keras + tfkl = tf_keras.layers # Create a variational encoder and add a KL Divergence penalty to the # loss that encourages marginal coherence with a unit-MVN (the "prior"). @@ -1315,7 +1316,7 @@ def __init__(self, weight: Multiplier applied to the calculated KL divergence for each Keras batch member. Default value: `None` (i.e., do not weight each batch member). - **kwargs: Additional keyword arguments passed to `tf.keras.Layer`. + **kwargs: Additional keyword arguments passed to `tf_keras.Layer`. """ super(KLDivergenceAddLoss, self).__init__(**kwargs) self._regularizer = KLDivergenceRegularizer( @@ -1358,7 +1359,7 @@ def kl_divergence_fn(distribution_a, distribution_b): def _fn(distribution_a): """Closure that computes KLDiv as a function of `a` as in `KL[a, b]`.""" with tf.name_scope('kldivergence_loss'): - if isinstance(distribution_b, tf.keras.Model): + if isinstance(distribution_b, tf_keras.Model): distribution_b_ = distribution_b(0.) # Pass a dummy arg. elif callable(distribution_b): # TODO(b/119756336): Due to eager/graph Jacobian graph caching bug we @@ -1391,8 +1392,8 @@ class MixtureSameFamily(DistributionLambda): ```python tfd = tfp.distributions tfpl = tfp.layers - tfk = tf.keras - tfkl = tf.keras.layers + tfk = tf_keras + tfkl = tf_keras.layers # Load data -- graph of a [cardioid](https://en.wikipedia.org/wiki/Cardioid). n = 2000 @@ -1449,7 +1450,7 @@ def __init__(self, performance. When `False` invalid inputs may silently render incorrect outputs. Default value: `False`. - **kwargs: Additional keyword arguments passed to `tf.keras.Layer`. + **kwargs: Additional keyword arguments passed to `tf_keras.Layer`. """ super(MixtureSameFamily, self).__init__( lambda t: MixtureSameFamily.new( # pylint: disable=g-long-lambda @@ -1518,8 +1519,8 @@ class MixtureNormal(DistributionLambda): ```python tfd = tfp.distributions tfpl = tfp.layers - tfk = tf.keras - tfkl = tf.keras.layers + tfk = tf_keras + tfkl = tf_keras.layers # Load data -- graph of a [cardioid](https://en.wikipedia.org/wiki/Cardioid). n = 2000 @@ -1571,7 +1572,7 @@ def __init__(self, performance. When `False` invalid inputs may silently render incorrect outputs. Default value: `False`. - **kwargs: Additional keyword arguments passed to `tf.keras.Layer`. + **kwargs: Additional keyword arguments passed to `tf_keras.Layer`. 
""" convert_to_tensor_fn = _get_convert_to_tensor_fn(convert_to_tensor_fn) @@ -1643,8 +1644,8 @@ class MixtureLogistic(DistributionLambda): ```python tfd = tfp.distributions tfpl = tfp.layers - tfk = tf.keras - tfkl = tf.keras.layers + tfk = tf_keras + tfkl = tf_keras.layers # Load data -- graph of a [cardioid](https://en.wikipedia.org/wiki/Cardioid). n = 2000 @@ -1696,7 +1697,7 @@ def __init__(self, performance. When `False` invalid inputs may silently render incorrect outputs. Default value: `False`. - **kwargs: Additional keyword arguments passed to `tf.keras.Layer`. + **kwargs: Additional keyword arguments passed to `tf_keras.Layer`. """ convert_to_tensor_fn = _get_convert_to_tensor_fn(convert_to_tensor_fn) @@ -1766,7 +1767,7 @@ class VariationalGaussianProcess(DistributionLambda): Create a VariationalGaussianProcess distribtuion whose `index_points` are the inputs to the layer. Parameterized by number of inducing points and a - `kernel_provider`, which should be a `tf.keras.Layer` with an @property that + `kernel_provider`, which should be a `tf_keras.Layer` with an @property that late-binds variable parameters to a `tfp.positive_semidefinite_kernel.PositiveSemidefiniteKernel` instance (this requirement has to do with the way that variables must be created in a keras @@ -1782,7 +1783,7 @@ def __init__( event_shape=(1,), inducing_index_points_initializer=None, unconstrained_observation_noise_variance_initializer=( - tf.initializers.constant(-10.)), + tf_keras.initializers.constant(-10.)), variational_inducing_observations_scale_initializer=None, mean_fn=None, jitter=1e-6, @@ -1802,17 +1803,17 @@ def __init__( example, `event_shape = [3]` means we are modeling a batch of 3 distributions over functions. We can think of this as a distrbution over 3-dimensional vector-valued functions. - inducing_index_points_initializer: a `tf.keras.initializer.Initializer` + inducing_index_points_initializer: a `tf_keras.initializer.Initializer` used to initialize the trainable `inducing_index_points` variables. Training VGP's is pretty sensitive to choice of initial inducing index point locations. A reasonable heuristic is to scatter them near the data, not too close to each other. unconstrained_observation_noise_variance_initializer: a - `tf.keras.initializer.Initializer` used to initialize the unconstrained + `tf_keras.initializer.Initializer` used to initialize the unconstrained observation noise variable. The observation noise variance is computed from this variable via the `tf.nn.softplus` function. variational_inducing_observations_scale_initializer: a - `tf.keras.initializer.Initializer` used to initialize the variational + `tf_keras.initializer.Initializer` used to initialize the variational inducing observations scale. mean_fn: a callable that maps layer inputs to mean function values. Passed to the mean_fn parameter of VariationalGaussianProcess distribution. 
If @@ -1869,7 +1870,7 @@ def build(self, input_shape): if self._mean_fn is None: self.mean = self.add_weight( - initializer=tf.initializers.constant([0.]), + initializer=tf_keras.initializers.constant([0.]), dtype=self._dtype, name='mean') self._mean_fn = lambda x: self.mean @@ -1896,14 +1897,14 @@ def build(self, input_shape): self._variational_inducing_observations_loc = self.add_weight( name='variational_inducing_observations_loc', shape=self._event_shape.as_list() + [self._num_inducing_points], - initializer=tf.initializers.zeros(), + initializer=tf_keras.initializers.zeros(), dtype=self._dtype) if self._variational_inducing_observations_scale_initializer is None: eyes = (np.ones(self._event_shape.as_list() + [1, 1]) * np.eye(self._num_inducing_points, dtype=self._dtype)) self._variational_inducing_observations_scale_initializer = ( - tf.initializers.constant(1e-5 * eyes)) + tf_keras.initializers.constant(1e-5 * eyes)) self._variational_inducing_observations_scale = self.add_weight( name='variational_inducing_observations_scale', shape=(self._event_shape.as_list() + @@ -1945,7 +1946,7 @@ def _transposed_variational_loss(y, kl_weight=1.): # For deserialization. -tf.keras.utils.get_custom_objects().update({ +tf_keras.utils.get_custom_objects().update({ 'DistributionLambda': DistributionLambda, 'IndependentBernoulli': IndependentBernoulli, 'IndependentLogistic': IndependentLogistic, @@ -1963,11 +1964,11 @@ def _transposed_variational_loss(y, kl_weight=1.): def _serialize(convert_to_tensor_fn): - return tf.keras.utils.legacy.serialize_keras_object(convert_to_tensor_fn) + return tf_keras.utils.legacy.serialize_keras_object(convert_to_tensor_fn) def _deserialize(name, custom_objects=None): - return tf.keras.utils.legacy.deserialize_keras_object( + return tf_keras.utils.legacy.deserialize_keras_object( name, module_objects=globals(), custom_objects=custom_objects, diff --git a/tensorflow_probability/python/layers/distribution_layer_test.py b/tensorflow_probability/python/layers/distribution_layer_test.py index 1238fb5e51..197f889a7f 100644 --- a/tensorflow_probability/python/layers/distribution_layer_test.py +++ b/tensorflow_probability/python/layers/distribution_layer_test.py @@ -37,15 +37,15 @@ from tensorflow_probability.python.distributions import poisson from tensorflow_probability.python.distributions import uniform from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.layers import distribution_layer from tensorflow_probability.python.layers import variable_input from tensorflow_probability.python.math import generic from tensorflow_probability.python.math.psd_kernels import exponentiated_quadratic from tensorflow_probability.python.util import deferred_tensor -tfk = tf.keras - -tfkl = tf.keras.layers +tfk = tf_keras +tfkl = tf_keras.layers def _logit_avg_expit(t): @@ -72,8 +72,8 @@ def _unwrap_tensor_coercible(dist): def _get_adam_optimizer(learning_rate): if tf.__internal__.tf2.enabled() and tf.executing_eagerly(): - return tf.keras.optimizers.Adam(learning_rate=learning_rate) - return tf.keras.optimizers.legacy.Adam(learning_rate=learning_rate) + return tf_keras.optimizers.Adam(learning_rate=learning_rate) + return tf_keras.optimizers.legacy.Adam(learning_rate=learning_rate) # TODO(b/143642032): Figure out how to solve issues with save/load, so that we @@ -92,9 +92,9 @@ class EndToEndTest(test_util.TestCase): registered via `tf.register_tensor_conversion_function`. 
Fundamentally, there are three ways to be Keras models: - 1. `tf.keras.Sequential` + 1. `tf_keras.Sequential` 2. Functional API - 3. Subclass `tf.keras.Model`. + 3. Subclass `tf_keras.Model`. Its important to have end-to-end tests for all three, because #1 and #2 call `__call__` and `call` differently. (#3's call pattern depends on user @@ -336,8 +336,8 @@ def test_side_variable_is_auto_tracked(self): # `s` is the "side variable". s = deferred_tensor.TransformedVariable(1., softplus.Softplus()) prior = normal_lib.Normal(tf.Variable(0.), 1.) - linear_regression = tf.keras.Sequential([ - tf.keras.layers.Dense(1), + linear_regression = tf_keras.Sequential([ + tf_keras.layers.Dense(1), distribution_layer.DistributionLambda( lambda t: normal_lib.Normal(t, s), activity_regularizer=distribution_layer.KLDivergenceRegularizer( @@ -600,8 +600,8 @@ def test_doc_string(self): true_bias = np.array([0, 0, np.log(scale_noise), 0, np.log(scale_noise)]) # Create model. - model = tf.keras.Sequential([ - tf.keras.layers.Dense( + model = tf_keras.Sequential([ + tf_keras.layers.Dense( distribution_layer.MultivariateNormalTriL.params_size(d), kernel_initializer=lambda s, **_: true_kernel, bias_initializer=lambda s, **_: true_bias), @@ -660,10 +660,10 @@ def test_doc_string(self): d = y.shape[-1] # Create model. - model = tf.keras.Sequential([ - tf.keras.layers.Dense( + model = tf_keras.Sequential([ + tf_keras.layers.Dense( distribution_layer.OneHotCategorical.params_size(d) - 1), - tf.keras.layers.Lambda(_vec_pad), + tf_keras.layers.Lambda(_vec_pad), distribution_layer.OneHotCategorical(d), ]) @@ -748,8 +748,8 @@ def test_doc_string(self): k = 2 p = distribution_layer.CategoricalMixtureOfOneHotCategorical.params_size( d, k) - model = tf.keras.Sequential([ - tf.keras.layers.Dense(p), + model = tf_keras.Sequential([ + tf_keras.layers.Dense(p), distribution_layer.CategoricalMixtureOfOneHotCategorical(d, k), ]) @@ -908,8 +908,8 @@ def test_doc_string(self): event_shape = y.shape[1:] # Create model. - model = tf.keras.Sequential([ - tf.keras.layers.Dense( + model = tf_keras.Sequential([ + tf_keras.layers.Dense( distribution_layer.IndependentBernoulli.params_size(event_shape)), distribution_layer.IndependentBernoulli(event_shape), ]) @@ -1510,13 +1510,13 @@ def s(x): y = (w0 * x * (1 + np.sin(x)) + b0) + eps x0 = np.linspace(*x_range, num=1000) - class KernelFn(tf.keras.layers.Layer): + class KernelFn(tf_keras.layers.Layer): def __init__(self, **kwargs): super(KernelFn, self).__init__(**kwargs) self._amplitude = self.add_weight( - initializer=tf.initializers.constant(.54), + initializer=tf_keras.initializers.constant(.54), dtype=dtype, name='amplitude') @@ -1533,17 +1533,17 @@ def kernel(self): # Add a leading dimension for the event_shape. 
eyes = np.expand_dims(np.eye(num_inducing_points), 0) variational_inducing_observations_scale_initializer = ( - tf.initializers.constant(1e-3 * eyes)) + tf_keras.initializers.constant(1e-3 * eyes)) - model = tf.keras.Sequential([ - tf.keras.layers.InputLayer(input_shape=[1], dtype=dtype), - tf.keras.layers.Dense(1, kernel_initializer='Ones', use_bias=False, + model = tf_keras.Sequential([ + tf_keras.layers.InputLayer(input_shape=[1], dtype=dtype), + tf_keras.layers.Dense(1, kernel_initializer='Ones', use_bias=False, activation=None, dtype=dtype), distribution_layer.VariationalGaussianProcess( num_inducing_points=num_inducing_points, kernel_provider=KernelFn(dtype=dtype), inducing_index_points_initializer=( - tf.initializers.constant( + tf_keras.initializers.constant( np.linspace(*x_range, num=num_inducing_points, dtype=dtype)[..., np.newaxis])), diff --git a/tensorflow_probability/python/layers/initializers.py b/tensorflow_probability/python/layers/initializers.py index 0ebe5fdf69..c0b57bfddb 100644 --- a/tensorflow_probability/python/layers/initializers.py +++ b/tensorflow_probability/python/layers/initializers.py @@ -18,9 +18,10 @@ import numpy as np import tensorflow.compat.v2 as tf +from tensorflow_probability.python.internal import tf_keras -class BlockwiseInitializer(tf.keras.initializers.Initializer): +class BlockwiseInitializer(tf_keras.initializers.Initializer): """Initializer which concats other intializers.""" def __init__(self, initializers, sizes, validate_args=False): @@ -28,7 +29,7 @@ def __init__(self, initializers, sizes, validate_args=False): Args: initializers: `list` of Keras initializers, e.g., `"glorot_uniform"` or - `tf.keras.initializers.Constant(0.5413)`. + `tf_keras.initializers.Constant(0.5413)`. sizes: `list` of `int` scalars representing the number of elements associated with each initializer in `initializers`. validate_args: Python `bool` indicating we should do (possibly expensive) @@ -58,7 +59,7 @@ def __call__(self, shape, dtype=None): dtype: Optional dtype of the tensor. If not provided will return tensor of `tf.float32`. 
""" - dtype = tf.as_dtype(dtype or tf.keras.backend.floatx()) + dtype = tf.as_dtype(dtype or tf_keras.backend.floatx()) if isinstance(shape, tf.TensorShape): shape_dtype = tf.int32 shape_ = np.int32(shape) @@ -88,14 +89,14 @@ def __call__(self, shape, dtype=None): else shape_[:-1]) if sizes_ is not None and isinstance(s, (np.ndarray, np.generic)): return tf.concat([ - tf.keras.initializers.get(init)(np.concatenate([ + tf_keras.initializers.get(init)(np.concatenate([ s, np.array([e], shape_dtype.as_numpy_dtype)], axis=-1), dtype) for init, e in zip(self.initializers, sizes_.tolist()) ], axis=-1) sizes = tf.split(self.sizes, len(self.initializers)) return tf.concat([ - tf.keras.initializers.get(init)(tf.concat([s, e], axis=-1), dtype) + tf_keras.initializers.get(init)(tf.concat([s, e], axis=-1), dtype) for init, e in zip(self.initializers, sizes) ], axis=-1) @@ -103,8 +104,8 @@ def get_config(self): """Returns initializer configuration as a JSON-serializable dict.""" return { 'initializers': [ - tf.initializers.serialize( - tf.keras.initializers.get(init)) + tf_keras.initializers.serialize( + tf_keras.initializers.get(init)) for init in self.initializers ], 'sizes': self.sizes, @@ -115,12 +116,12 @@ def get_config(self): def from_config(cls, config): """Instantiates an initializer from a configuration dictionary.""" return cls(**{ - 'initializers': [tf.initializers.deserialize(init) + 'initializers': [tf_keras.initializers.deserialize(init) for init in config.get('initializers', [])], 'sizes': config.get('sizes', []), 'validate_args': config.get('validate_args', False), }) -tf.keras.utils.get_custom_objects()[ +tf_keras.utils.get_custom_objects()[ 'BlockwiseInitializer'] = BlockwiseInitializer diff --git a/tensorflow_probability/python/layers/initializers_test.py b/tensorflow_probability/python/layers/initializers_test.py index 91fc165a2e..dc451cee26 100644 --- a/tensorflow_probability/python/layers/initializers_test.py +++ b/tensorflow_probability/python/layers/initializers_test.py @@ -17,8 +17,8 @@ # Dependency imports import numpy as np -import tensorflow.compat.v2 as tf from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.layers import initializers @@ -34,9 +34,9 @@ def test_works_correctly(self): self.assertAllEqual(np.zeros([2, 1, 4]), x_[..., 3:]) def test_de_serialization(self): - s = tf.initializers.serialize( + s = tf_keras.initializers.serialize( initializers.BlockwiseInitializer(['glorot_uniform', 'zeros'], [3, 4])) - init_clone = tf.initializers.deserialize(s) + init_clone = tf_keras.initializers.deserialize(s) x = init_clone([2, 1, 7]) self.assertEqual((2, 1, 7), x.shape) x_ = self.evaluate(x) diff --git a/tensorflow_probability/python/layers/masked_autoregressive.py b/tensorflow_probability/python/layers/masked_autoregressive.py index 8ff923c125..07a406ec5a 100644 --- a/tensorflow_probability/python/layers/masked_autoregressive.py +++ b/tensorflow_probability/python/layers/masked_autoregressive.py @@ -19,7 +19,7 @@ from tensorflow_probability.python.bijectors import masked_autoregressive as masked_autoregressive_lib from tensorflow_probability.python.distributions import transformed_distribution as transformed_distribution_lib - +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.layers.distribution_layer import DistributionLambda @@ -61,7 +61,7 @@ def f_inverse(x): tfd = tfp.distributions tfpl = tfp.layers tfb = tfp.bijectors - 
tfk = tf.keras + tfk = tf_keras # Generate data -- as in Figure 1 in [Papamakarios et al. (2017)][1]). n = 2000 @@ -121,7 +121,7 @@ def __init__(self, made, **kwargs): Args: made: A `Made` layer, which must output two parameters for each input. - **kwargs: Additional keyword arguments passed to `tf.keras.Layer`. + **kwargs: Additional keyword arguments passed to `tf_keras.Layer`. """ super(AutoregressiveTransform, self).__init__(self._transform, **kwargs) @@ -132,8 +132,8 @@ def __init__(self, made, **kwargs): self._made = made def build(self, input_shape): - tf.keras.Sequential([ - tf.keras.layers.InputLayer( + tf_keras.Sequential([ + tf_keras.layers.InputLayer( input_shape=input_shape[1:], dtype=self.dtype), self._made ]) diff --git a/tensorflow_probability/python/layers/masked_autoregressive_test.py b/tensorflow_probability/python/layers/masked_autoregressive_test.py index ebddc2eb4d..24b382ffba 100644 --- a/tensorflow_probability/python/layers/masked_autoregressive_test.py +++ b/tensorflow_probability/python/layers/masked_autoregressive_test.py @@ -19,11 +19,12 @@ from tensorflow_probability.python.bijectors import masked_autoregressive as masked_autoregressive_lib from tensorflow_probability.python.distributions import mvn_diag from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.layers import distribution_layer from tensorflow_probability.python.layers import masked_autoregressive -tfk = tf.keras -tfkl = tf.keras.layers +tfk = tf_keras +tfkl = tf_keras.layers @test_util.test_all_tf_execution_regimes diff --git a/tensorflow_probability/python/layers/util.py b/tensorflow_probability/python/layers/util.py index 5fcdb72ca7..c8b607f3c1 100644 --- a/tensorflow_probability/python/layers/util.py +++ b/tensorflow_probability/python/layers/util.py @@ -21,7 +21,6 @@ import types # Dependency imports import numpy as np -import tensorflow.compat.v1 as tf1 import tensorflow.compat.v2 as tf from tensorflow_probability.python import util as tfp_util @@ -29,6 +28,8 @@ from tensorflow_probability.python.distributions import independent as independent_lib from tensorflow_probability.python.distributions import normal as normal_lib +from tensorflow_probability.python.internal import tf_keras + __all__ = [ 'default_loc_scale_fn', @@ -41,8 +42,8 @@ def default_loc_scale_fn( is_singular=False, - loc_initializer=tf1.initializers.random_normal(stddev=0.1), - untransformed_scale_initializer=tf1.initializers.random_normal( + loc_initializer=tf_keras.initializers.RandomNormal(stddev=0.1), + untransformed_scale_initializer=tf_keras.initializers.RandomNormal( mean=-3., stddev=0.1), loc_regularizer=None, untransformed_scale_regularizer=None, @@ -122,8 +123,8 @@ def _fn(dtype, shape, name, trainable, add_variable_fn): def default_mean_field_normal_fn( is_singular=False, - loc_initializer=tf1.initializers.random_normal(stddev=0.1), - untransformed_scale_initializer=tf1.initializers.random_normal( + loc_initializer=tf_keras.initializers.RandomNormal(stddev=0.1), + untransformed_scale_initializer=tf_keras.initializers.RandomNormal( mean=-3., stddev=0.1), loc_regularizer=None, untransformed_scale_regularizer=None, @@ -235,7 +236,7 @@ def deserialize_function(serial, function_type): Keras-deserialized functions do not perform lexical scoping. Any modules that the function requires must be imported within the function itself. - This serialization mimicks the implementation in `tf.keras.layers.Lambda`. 
+ This serialization mimicks the implementation in `tf_keras.layers.Lambda`. Args: serial: Serialized Keras object: typically a dict, string, or bytecode. @@ -255,7 +256,7 @@ def deserialize_function(serial, function_type): """ if function_type == 'function': # Simple lookup in custom objects - function = tf.keras.utils.legacy.deserialize_keras_object(serial) + function = tf_keras.utils.legacy.deserialize_keras_object(serial) elif function_type == 'lambda': # Unsafe deserialization from bytecode function = _func_load(serial) @@ -273,7 +274,7 @@ def serialize_function(func): us use the Python scope to obtain the function rather than reload it from bytecode. (Note that both cases are brittle!) - This serialization mimicks the implementation in `tf.keras.layers.Lambda`. + This serialization mimicks the implementation in `tf_keras.layers.Lambda`. Args: func: Python function to serialize. diff --git a/tensorflow_probability/python/layers/variable_input.py b/tensorflow_probability/python/layers/variable_input.py index 9dbdb2edc9..0dae6ff7ef 100644 --- a/tensorflow_probability/python/layers/variable_input.py +++ b/tensorflow_probability/python/layers/variable_input.py @@ -18,27 +18,28 @@ import numpy as np import tensorflow.compat.v2 as tf +from tensorflow_probability.python.internal import tf_keras -class VariableLayer(tf.keras.layers.Layer): +class VariableLayer(tf_keras.layers.Layer): """Simply returns a (trainable) variable, regardless of input. This layer implements the mathematical function `f(x) = c` where `c` is a constant, i.e., unchanged for all `x`. Like other Keras layers, the constant is `trainable`. This layer can also be interpretted as the special case of - `tf.keras.layers.Dense` when the `kernel` is forced to be the zero matrix + `tf_keras.layers.Dense` when the `kernel` is forced to be the zero matrix (`tf.zeros`). #### Examples ```python - trainable_normal = tf.keras.models.Sequential([ + trainable_normal = tf_keras.models.Sequential([ tfp.layers.VariableLayer( shape=[3, 4, 2], dtype=tf.float64, initializer=tfp.layers.BlockwiseInitializer([ 'zeros', - tf.keras.initializers.Constant(np.log(np.expm1(1.))), + tf_keras.initializers.Constant(np.log(np.expm1(1.))), ], sizes=[1, 1])), tfp.layers.DistributionLambda(lambda t: tfd.Independent( tfd.Normal(loc=t[..., 0], scale=tf.math.softplus(t[..., 1])), @@ -83,7 +84,7 @@ def __init__(self, shape: integer or integer vector specifying the shape of the output of this layer. dtype: TensorFlow `dtype` of the variable created by this layer. - Default value: `None` (i.e., `tf.as_dtype(tf.keras.backend.floatx())`). + Default value: `None` (i.e., `tf.as_dtype(tf_keras.backend.floatx())`). activation: Activation function to use. If you don't specify anything, no activation is applied (ie. "linear" activation: `a(x) = x`). Default value: `None`. @@ -93,7 +94,7 @@ def __init__(self, ```python tfp.layers.BlockwiseInitializer([ 'zeros', - tf.keras.initializers.Constant(np.log(np.expm1(1.))), # = 0.541325 + tf_keras.initializers.Constant(np.log(np.expm1(1.))), # = 0.541325 ], sizes=[1, 1]) ``` Default value: `'zeros'`. @@ -101,14 +102,14 @@ def __init__(self, Default value: `None`. constraint: Constraint function applied to the `constant` vector. Default value: `None`. - **kwargs: Extra arguments forwarded to `tf.keras.layers.Layer`. + **kwargs: Extra arguments forwarded to `tf_keras.layers.Layer`. 
""" super(VariableLayer, self).__init__(**kwargs) - self.activation = tf.keras.activations.get(activation) - self.initializer = tf.keras.initializers.get(initializer) - self.regularizer = tf.keras.regularizers.get(regularizer) - self.constraint = tf.keras.constraints.get(constraint) + self.activation = tf_keras.activations.get(activation) + self.initializer = tf_keras.initializers.get(initializer) + self.regularizer = tf_keras.regularizers.get(regularizer) + self.constraint = tf_keras.constraints.get(constraint) shape = tf.get_static_value(shape) if shape is None: diff --git a/tensorflow_probability/python/layers/variable_input_test.py b/tensorflow_probability/python/layers/variable_input_test.py index 94f9a57e1d..d80899bc64 100644 --- a/tensorflow_probability/python/layers/variable_input_test.py +++ b/tensorflow_probability/python/layers/variable_input_test.py @@ -18,6 +18,7 @@ from tensorflow_probability.python.distributions import independent from tensorflow_probability.python.distributions import normal from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.layers import distribution_layer from tensorflow_probability.python.layers import variable_input @@ -27,13 +28,13 @@ class VariableInputLayerTest(test_util.TestCase): def test_sequential_api(self): # Create a trainable distribution using the Sequential API. - model = tf.keras.models.Sequential([ + model = tf_keras.models.Sequential([ variable_input.VariableLayer( shape=[2, 3, 4], dtype=tf.float64, trainable=False), # You'd probably never want this in IRL. # The Dense serves no real purpose; it will change the event_shape. - tf.keras.layers.Dense(5, use_bias=False, dtype=tf.float64), + tf_keras.layers.Dense(5, use_bias=False, dtype=tf.float64), distribution_layer.DistributionLambda( lambda t: independent.Independent( # pylint: disable=g-long-lambda normal.Normal(loc=t[0], scale=t[1]), @@ -68,19 +69,19 @@ def test_sequential_api(self): def test_functional_api(self): # Create a trainable distribution using the functional API. - dummy_input = tf.keras.Input(shape=()) + dummy_input = tf_keras.Input(shape=()) x = variable_input.VariableLayer( shape=[2, 3, 4], dtype=tf.float64, trainable=False, # You'd probably never want this in IRL. )(dummy_input) # The Dense serves no real purpose; it will change the event_shape. - x = tf.keras.layers.Dense(5, use_bias=False, dtype=tf.float64)(x) + x = tf_keras.layers.Dense(5, use_bias=False, dtype=tf.float64)(x) x = distribution_layer.DistributionLambda( lambda t: independent.Independent(normal.Normal(loc=t[0], scale=t[1]), # pylint: disable=g-long-lambda reinterpreted_batch_ndims=1), dtype=tf.float64)(x) - model = tf.keras.Model(dummy_input, x) + model = tf_keras.Model(dummy_input, x) # Instantiate the model (as a TFP distribution). dist = model(tf.zeros([])) diff --git a/tensorflow_probability/python/layers/weight_norm.py b/tensorflow_probability/python/layers/weight_norm.py index b8b7b84925..c255f1c5a1 100644 --- a/tensorflow_probability/python/layers/weight_norm.py +++ b/tensorflow_probability/python/layers/weight_norm.py @@ -17,9 +17,10 @@ import warnings import tensorflow.compat.v2 as tf +from tensorflow_probability.python.internal import tf_keras -class WeightNorm(tf.keras.layers.Wrapper): +class WeightNorm(tf_keras.layers.Wrapper): """Layer wrapper to decouple magnitude and direction of the layer's weights. 
This wrapper reparameterizes a layer by decoupling the weight's @@ -32,13 +33,13 @@ class WeightNorm(tf.keras.layers.Wrapper): #### Example ```python - net = WeightNorm(tf.keras.layers.Conv2D(2, 2, activation='relu'), + net = WeightNorm(tf_keras.layers.Conv2D(2, 2, activation='relu'), input_shape=(32, 32, 3), data_init=True)(x) - net = WeightNorm(tf.keras.layers.Conv2DTranspose(16, 5, activation='relu'), + net = WeightNorm(tf_keras.layers.Conv2DTranspose(16, 5, activation='relu'), data_init=True) - net = WeightNorm(tf.keras.layers.Dense(120, activation='relu'), + net = WeightNorm(tf_keras.layers.Dense(120, activation='relu'), data_init=True)(net) - net = WeightNorm(tf.keras.layers.Dense(num_classes), + net = WeightNorm(tf_keras.layers.Dense(num_classes), data_init=True)(net) ``` @@ -54,19 +55,19 @@ def __init__(self, layer, data_init=True, **kwargs): """Initialize WeightNorm wrapper. Args: - layer: A `tf.keras.layers.Layer` instance. Supported layer types are + layer: A `tf_keras.layers.Layer` instance. Supported layer types are `Dense`, `Conv2D`, and `Conv2DTranspose`. Layers with multiple inputs are not supported. data_init: `bool`, if `True` use data dependent variable initialization. - **kwargs: Additional keyword args passed to `tf.keras.layers.Wrapper`. + **kwargs: Additional keyword args passed to `tf_keras.layers.Wrapper`. Raises: - ValueError: If `layer` is not a `tf.keras.layers.Layer` instance. + ValueError: If `layer` is not a `tf_keras.layers.Layer` instance. """ - if not isinstance(layer, tf.keras.layers.Layer): + if not isinstance(layer, tf_keras.layers.Layer): raise ValueError( - 'Please initialize `WeightNorm` layer with a `tf.keras.layers.Layer` ' + 'Please initialize `WeightNorm` layer with a `tf_keras.layers.Layer` ' 'instance. You passed: {input}'.format(input=layer)) layer_type = type(layer).__name__ @@ -138,7 +139,7 @@ def build(self, input_shape=None): input_shape = tf.TensorShape(input_shape).as_list() input_shape[0] = None - self.input_spec = tf.keras.layers.InputSpec(shape=input_shape) + self.input_spec = tf_keras.layers.InputSpec(shape=input_shape) if not self.layer.built: self.layer.build(input_shape) diff --git a/tensorflow_probability/python/layers/weight_norm_test.py b/tensorflow_probability/python/layers/weight_norm_test.py index 5d47d1a9ab..bdbda2fa49 100644 --- a/tensorflow_probability/python/layers/weight_norm_test.py +++ b/tensorflow_probability/python/layers/weight_norm_test.py @@ -24,10 +24,11 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.layers import weight_norm -tfk = tf.keras -tfkl = tf.keras.layers +tfk = tf_keras +tfkl = tf_keras.layers # TODO(b/143642032): Figure out how to get this working with @@ -225,9 +226,9 @@ def testGradientValues(self, model_type): @parameterized.parameters(['sequential', 'sequential_no_input', 'functional']) def testTrainableVariableInitializationInModelFit(self, model_type): if tf.__internal__.tf2.enabled() and tf.executing_eagerly(): - sgd = tf.keras.optimizers.SGD(learning_rate=0.) + sgd = tf_keras.optimizers.SGD(learning_rate=0.) else: - sgd = tf.keras.optimizers.legacy.SGD(learning_rate=0.) + sgd = tf_keras.optimizers.legacy.SGD(learning_rate=0.) 
model = self._define_model(model_type, self.data_dim, self.num_hidden) model.compile(optimizer=sgd, loss='mse') model.fit( diff --git a/tensorflow_probability/python/math/BUILD b/tensorflow_probability/python/math/BUILD index 2d962b39ca..f9b7557749 100644 --- a/tensorflow_probability/python/math/BUILD +++ b/tensorflow_probability/python/math/BUILD @@ -338,6 +338,7 @@ multi_substrate_py_library( "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:loop_util", "//tensorflow_probability/python/internal:prefer_static", + "//tensorflow_probability/python/internal:tf_keras", ], ) @@ -353,6 +354,7 @@ multi_substrate_py_test( # tensorflow dep, "//tensorflow_probability/python/distributions:normal", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/optimizer", # "//third_party/tensorflow/compiler/jit:xla_cpu_jit", # DisableOnExport ], diff --git a/tensorflow_probability/python/math/linalg.py b/tensorflow_probability/python/math/linalg.py index 93f770791a..fd421f354c 100644 --- a/tensorflow_probability/python/math/linalg.py +++ b/tensorflow_probability/python/math/linalg.py @@ -454,8 +454,9 @@ def low_rank_cholesky(matrix, max_rank, trace_atol=0, trace_rtol=0, name=None): dtype_hint=tf.float32) if not isinstance(matrix, tf.linalg.LinearOperator): matrix = tf.convert_to_tensor(matrix, name='matrix', dtype=dtype) + matrix = tf.linalg.LinearOperatorFullMatrix(matrix) - mtrace = tf.linalg.trace(matrix) + mtrace = matrix.trace() mrank = tensorshape_util.rank(matrix.shape) batch_dims = mrank - 2 @@ -477,7 +478,7 @@ def lr_cholesky_body(i, lr, residual_diag): residual_diag, axis=-1, output_type=tf.int64)[..., tf.newaxis] # 2. Construct vector v that kills that diagonal entry and its row & col. - # v = residual_matrix[max_j, :] / sqrt(residual_matrix[max_j, maxj]) + # v = residual_matrix[max_j, :] / sqrt(residual_matrix[max_j, max_j]) maxval = tf.gather( residual_diag, max_j, axis=-1, batch_dims=batch_dims)[..., 0] normalizer = tf.sqrt(maxval) @@ -485,7 +486,7 @@ def lr_cholesky_body(i, lr, residual_diag): matrix_row = tf.squeeze(matrix.row(max_j), axis=-2) else: matrix_row = tf.gather( - matrix, max_j, axis=-1, batch_dims=batch_dims)[..., 0] + matrix.to_dense(), max_j, axis=-1, batch_dims=batch_dims)[..., 0] # residual_matrix[max_j, :] = matrix_row[max_j, :] - (lr * lr^t)[max_j, :] # And (lr * lr^t)[max_j, :] = lr[max_j, :] * lr^t lr_row_maxj = tf.gather(lr, max_j, axis=-2, batch_dims=batch_dims) @@ -494,6 +495,13 @@ def lr_cholesky_body(i, lr, residual_diag): unnormalized_v = matrix_row - lr_lrt_row v = unnormalized_v / normalizer[..., tf.newaxis] + # Mask v so that it is zero in row/columns we've already zerod. + # We can use the sign of the residual_diag as the mask because the input + # matrix being positive definite implies that the diag starts off + # positive, and only becomes zero on the entries that we've chosen + # in previous iterations. + v = v * tf.math.sign(residual_diag) + # 3. Add v to lr. # Conceptually the same as # new_lr = lr @@ -509,11 +517,21 @@ def lr_cholesky_body(i, lr, residual_diag): # 4. Compute the new residual_diag = old_residual_diag - v * v new_residual_diag = residual_diag - v * v + # Explicitly set new_residual_diag[max_j] = 0 (both to guarantee we never + # choose its index again, and to let us use the tf.math.sign of the + # residual as a mask.) 
+ n = new_residual_diag.shape[-1] + oh = tf.one_hot( + indices=max_j[..., 0], depth=n, on_value=0.0, off_value=1.0, + dtype=new_residual_diag.dtype + ) + new_residual_diag = new_residual_diag * oh + return i + 1, new_lr, new_residual_diag lr = tf.zeros(matrix.shape, dtype=matrix.dtype)[..., :max_rank] - mdiag = tf.linalg.diag_part(matrix) + mdiag = matrix.diag_part() i, lr, residual_diag = tf.while_loop( cond=lr_cholesky_cond, body=lr_cholesky_body, diff --git a/tensorflow_probability/python/math/linalg_test.py b/tensorflow_probability/python/math/linalg_test.py index 4d57957b5e..eb36dc3e9b 100644 --- a/tensorflow_probability/python/math/linalg_test.py +++ b/tensorflow_probability/python/math/linalg_test.py @@ -447,12 +447,7 @@ def testLowRankCholesky(self): self.assertTrue(self.evaluate(tf.reduce_all( residual_trace < old_residual_trace))) old_residual_trace = residual_trace - # Compared to pivot_cholesky, low_rank_cholesky will sometimes have - # approximate zeros like 7e-17 or -2.6e-7 where it "should" have a - # real zero. - zeros_per_col = tf.math.count_nonzero( - tf.math.less(tf.math.abs(pchol), 1e-6), - axis=-2) + zeros_per_col = dim - tf.math.count_nonzero(pchol, axis=-2) mat = tf.matmul(pchol, pchol, transpose_b=True) pchol_shp, diag_diff, diff_norm, zeros_per_col = self.evaluate([ tf.shape(pchol), diff --git a/tensorflow_probability/python/math/minimize.py b/tensorflow_probability/python/math/minimize.py index 8fa2f295b6..001ee1f5a0 100644 --- a/tensorflow_probability/python/math/minimize.py +++ b/tensorflow_probability/python/math/minimize.py @@ -410,7 +410,7 @@ def minimize_stateless(loss_fn, def _make_stateful_optimizer_step_fn(loss_fn, optimizer, trainable_variables): - """Constructs a single step of a stateful (`tf.optimizers`) optimizer.""" + """Constructs a single step of a stateful (`tf_keras.optimizers`) optimizer.""" @tf.function(autograph=False) def optimizer_step(parameters, @@ -460,8 +460,8 @@ def minimize(loss_fn, `tfp.random.sanitize_seed`). num_steps: Python `int` maximum number of steps to run the optimizer. optimizer: Optimizer instance to use. This may be a TF1-style - `tf.train.Optimizer`, TF2-style `tf.optimizers.Optimizer`, or any Python - object that implements `optimizer.apply_gradients(grads_and_vars)`. + `tf.train.Optimizer`, TF2-style `tf_keras.optimizers.Optimizer`, or any + Python object that implements `optimizer.apply_gradients(grads_and_vars)`. convergence_criterion: Optional instance of `tfp.optimizer.convergence_criteria.ConvergenceCriterion` representing a criterion for detecting convergence. If `None`, @@ -528,9 +528,10 @@ def minimize(loss_fn, ```python x = tf.Variable(0.) loss_fn = lambda: (x - 5.)**2 - losses = tfp.math.minimize(loss_fn, - num_steps=100, - optimizer=tf.optimizers.Adam(learning_rate=0.1)) + losses = tfp.math.minimize( + loss_fn, + num_steps=100, + optimizer=tf_keras.optimizers.Adam(learning_rate=0.1)) # In TF2/eager mode, the optimization runs immediately. 
print("optimized value is {} with loss {}".format(x, losses[-1])) @@ -552,7 +553,9 @@ def minimize(loss_fn, ```python losses = tfp.math.minimize( - loss_fn, num_steps=1000, optimizer=tf.optimizers.Adam(learning_rate=0.1), + loss_fn, + num_steps=1000, + optimizer=tf_keras.optimizers.Adam(learning_rate=0.1), convergence_criterion=( tfp.optimizers.convergence_criteria.LossNotDecreasing(atol=0.01))) ``` @@ -574,7 +577,7 @@ def minimize(loss_fn, trace_fn = lambda traceable_quantities: { 'loss': traceable_quantities.loss, 'x': x} trace = tfp.math.minimize(loss_fn, num_steps=100, - optimizer=tf.optimizers.Adam(0.1), + optimizer=tf_keras.optimizers.Adam(0.1), trace_fn=trace_fn) print(trace['loss'].shape, # => [100] trace['x'].shape) # => [100] @@ -594,7 +597,7 @@ def minimize(loss_fn, 'loss': traceable_quantities.loss, 'has_converged': traceable_quantities.has_converged} trace = tfp.math.minimize(loss_fn, num_steps=100, - optimizer=tf.optimizers.Adam(0.1),, + optimizer=tf_keras.optimizers.Adam(0.1),, trace_fn=trace_fn, convergence_criterion=( tfp.optimizers.convergence_criteria.LossNotDecreasing(atol=0.01))) diff --git a/tensorflow_probability/python/math/minimize_test.py b/tensorflow_probability/python/math/minimize_test.py index ab16d8f602..ef373022cc 100644 --- a/tensorflow_probability/python/math/minimize_test.py +++ b/tensorflow_probability/python/math/minimize_test.py @@ -24,6 +24,7 @@ from tensorflow_probability.python import optimizer from tensorflow_probability.python.distributions import normal from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.math.minimize import minimize from tensorflow_probability.python.math.minimize import minimize_stateless @@ -32,14 +33,14 @@ def _get_adam_optimizer(learning_rate): if tf.__internal__.tf2.enabled(): - return tf.keras.optimizers.Adam(learning_rate=learning_rate) - return tf.keras.optimizers.legacy.Adam(learning_rate=learning_rate) + return tf_keras.optimizers.Adam(learning_rate=learning_rate) + return tf_keras.optimizers.legacy.Adam(learning_rate=learning_rate) def _get_sgd_optimizer(learning_rate): if tf.__internal__.tf2.enabled(): - return tf.keras.optimizers.SGD(learning_rate=learning_rate) - return tf.keras.optimizers.legacy.SGD(learning_rate=learning_rate) + return tf_keras.optimizers.SGD(learning_rate=learning_rate) + return tf_keras.optimizers.legacy.SGD(learning_rate=learning_rate) @test_util.test_all_tf_execution_regimes diff --git a/tensorflow_probability/python/math/ode/ode_test.py b/tensorflow_probability/python/math/ode/ode_test.py index ad3f2c2d74..650c9ce034 100644 --- a/tensorflow_probability/python/math/ode/ode_test.py +++ b/tensorflow_probability/python/math/ode/ode_test.py @@ -73,14 +73,19 @@ def __init__(self, make_solver_fn, first_step_size): ) def _solve(self, **kwargs): - step_size = kwargs.pop('previous_solver_internal_state') + step_size, solve_count = kwargs.pop('previous_solver_internal_state') results = self._make_solver_fn(step_size).solve(**kwargs) return results._replace( - solver_internal_state=results.solver_internal_state.step_size) + solver_internal_state=( + results.solver_internal_state.step_size, + solve_count + 1, + ) + ) def _initialize_solver_internal_state(self, **kwargs): del kwargs - return self._first_step_size + # The second value is solve count, for testing. 
+    return (self._first_step_size, 0)
 
   def _adjust_solver_internal_state_for_state_jump(self, **kwargs):
     return kwargs['previous_solver_internal_state']
 
@@ -447,17 +452,17 @@ def test_riccati_custom_adjoint_solver(self, solver, solution_times_fn):
     # Instrument the adjoint solver for testing. We have to do this because the
     # API doesn't provide access to the adjoint solver's diagnostics.
     first_step_size = np.float64(1.)
-    last_initial_step_size = tf.Variable(0., dtype=tf.float64)
-    self.evaluate(last_initial_step_size.initializer)
+    solve_count = tf.Variable(0, dtype=tf.int32)
+    self.evaluate(solve_count.initializer)
 
     class _InstrumentedSolver(StepSizeHeuristicAdjointSolver):
 
       def solve(self, **kwargs):
-        with tf.control_dependencies([
-            last_initial_step_size.assign(
-                kwargs['previous_solver_internal_state'])
-        ]):
-          return super(_InstrumentedSolver, self).solve(**kwargs)
+        results = super(_InstrumentedSolver, self).solve(**kwargs)
+        with tf.control_dependencies(
+            [solve_count.assign(results.solver_internal_state[1])]
+        ):
+          return tf.nest.map_structure(tf.identity, results)
 
     adjoint_solver = _InstrumentedSolver(
         make_solver_fn=lambda step_size: solver(  # pylint: disable=g-long-lambda
@@ -479,13 +484,14 @@ def grad_fn(initial_state):
       final_state = results.states[-1]
       return final_state
     _, grad = tfp_gradient.value_and_gradient(grad_fn, initial_state)
-    grad, last_initial_step_size = self.evaluate((grad, last_initial_step_size))
+    grad = self.evaluate(grad)
+    # There's a race condition if we evaluate solve_count right away. Evaluate
+    # it after we have finished the computation that produces `grad`.
+    solve_count = self.evaluate(solve_count)
     grad_exact = 1. / (1. - initial_state_value * final_time)**2
     self.assertAllClose(grad, grad_exact, rtol=1e-3, atol=1e-3)
-    # This indicates that the adaptation carried over to the final solve. We
-    # expect the step size to decrease because we purposefully made the initial
-    # step size way too large.
-    self.assertLess(last_initial_step_size, first_step_size)
+    # This indicates that the adaptation carried over to the final solve.
+    self.assertGreater(solve_count, 0)
 
   def test_linear_ode(self, solver, solution_times_fn):
     if not tf1.control_flow_v2_enabled():
diff --git a/tensorflow_probability/python/math/psd_kernels/psd_kernel_properties_test.py b/tensorflow_probability/python/math/psd_kernels/psd_kernel_properties_test.py
index c6b0678e03..f57f61d0b5 100644
--- a/tensorflow_probability/python/math/psd_kernels/psd_kernel_properties_test.py
+++ b/tensorflow_probability/python/math/psd_kernels/psd_kernel_properties_test.py
@@ -218,9 +218,9 @@ def _test_slicing(
         (slices,))
     apply_slices += tuple([slice(None)] * example_ndims)
-    # Check that sampling a sliced kernel produces the same shape as
-    # slicing the samples from the original.
-    self.assertAllClose(results[apply_slices], sliced_results)
+    # Check that applying a sliced kernel produces the same results as slicing
+    # the results from the original.
+ self.assertAllClose(results[apply_slices], sliced_results, rtol=1e-5) @parameterized.named_parameters( {'testcase_name': dname, 'kernel_name': dname} diff --git a/tensorflow_probability/python/math/special_test.py b/tensorflow_probability/python/math/special_test.py index 9a7613f149..675bcf2a44 100644 --- a/tensorflow_probability/python/math/special_test.py +++ b/tensorflow_probability/python/math/special_test.py @@ -560,8 +560,8 @@ def _test_betaincinv_value(self, a_high, b_high, dtype, atol, rtol): "rtol": 2e-3}, {"testcase_name": "float64", "dtype": np.float64, - "atol": 1e-12, - "rtol": 1e-11}) + "atol": 3e-12, + "rtol": 3e-11}) def testBetaincinvSmall(self, dtype, atol, rtol): self._test_betaincinv_value( a_high=1., b_high=1., dtype=dtype, atol=atol, rtol=rtol) diff --git a/tensorflow_probability/python/mcmc/BUILD b/tensorflow_probability/python/mcmc/BUILD index 357c546c25..0e7b126578 100644 --- a/tensorflow_probability/python/mcmc/BUILD +++ b/tensorflow_probability/python/mcmc/BUILD @@ -141,6 +141,7 @@ multi_substrate_py_library( "//tensorflow_probability/python/internal:distribute_lib", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:prefer_static", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/mcmc/internal:leapfrog_integrator", "//tensorflow_probability/python/mcmc/internal:util", "//tensorflow_probability/python/util:seed_stream", @@ -175,6 +176,7 @@ multi_substrate_py_test( "//tensorflow_probability/python/internal:samplers", "//tensorflow_probability/python/internal:tensorshape_util", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/math:generic", "//tensorflow_probability/python/util:deferred_tensor", ], @@ -494,7 +496,7 @@ multi_substrate_py_test( multi_substrate_py_library( name = "sample_halton_sequence", - srcs = ["sample_halton_sequence.py"], + srcs = ["sample_halton_sequence_lib.py"], deps = [ # numpy dep, # tensorflow dep, diff --git a/tensorflow_probability/python/mcmc/__init__.py b/tensorflow_probability/python/mcmc/__init__.py index 7aa4d79db2..0399a17981 100644 --- a/tensorflow_probability/python/mcmc/__init__.py +++ b/tensorflow_probability/python/mcmc/__init__.py @@ -36,7 +36,7 @@ from tensorflow_probability.python.mcmc.sample import sample_chain from tensorflow_probability.python.mcmc.sample import StatesAndTrace from tensorflow_probability.python.mcmc.sample_annealed_importance import sample_annealed_importance_chain -from tensorflow_probability.python.mcmc.sample_halton_sequence import sample_halton_sequence +from tensorflow_probability.python.mcmc.sample_halton_sequence_lib import sample_halton_sequence from tensorflow_probability.python.mcmc.simple_step_size_adaptation import SimpleStepSizeAdaptation from tensorflow_probability.python.mcmc.slice_sampler_kernel import SliceSampler from tensorflow_probability.python.mcmc.transformed_kernel import TransformedTransitionKernel diff --git a/tensorflow_probability/python/mcmc/hmc.py b/tensorflow_probability/python/mcmc/hmc.py index 9019f8d0dc..aceae587de 100644 --- a/tensorflow_probability/python/mcmc/hmc.py +++ b/tensorflow_probability/python/mcmc/hmc.py @@ -308,7 +308,7 @@ def make_response_likelihood(w, x): log_sigma = tf.Variable(0., dtype=dtype, name='log_sigma') - optimizer = tf.optimizers.SGD(learning_rate=0.01) + optimizer = tf_keras.optimizers.SGD(learning_rate=0.01) @tf.function def mcem_iter(weights_chain_start, 
step_size): diff --git a/tensorflow_probability/python/mcmc/hmc_test.py b/tensorflow_probability/python/mcmc/hmc_test.py index 5aa88ee549..ef0fd4d4a4 100644 --- a/tensorflow_probability/python/mcmc/hmc_test.py +++ b/tensorflow_probability/python/mcmc/hmc_test.py @@ -40,6 +40,7 @@ from tensorflow_probability.python.internal import samplers from tensorflow_probability.python.internal import tensorshape_util from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.math import generic from tensorflow_probability.python.mcmc import hmc from tensorflow_probability.python.mcmc import sample as sample_lib @@ -997,7 +998,7 @@ def test_mcem_converges(self): sigma = deferred_tensor.TransformedVariable( name='sigma', initial_value=np.array(1, dtype), bijector=exp.Exp()) - optimizer = tf.optimizers.SGD(learning_rate=0.01) + optimizer = tf_keras.optimizers.SGD(learning_rate=0.01) # TODO(b/144045420): eliminate the need for this tf.function decorator. The # reason it was added was that the test code is written to work in both diff --git a/tensorflow_probability/python/mcmc/sample_halton_sequence.py b/tensorflow_probability/python/mcmc/sample_halton_sequence_lib.py similarity index 84% rename from tensorflow_probability/python/mcmc/sample_halton_sequence.py rename to tensorflow_probability/python/mcmc/sample_halton_sequence_lib.py index f9d6cf2ce0..c767ed2fba 100644 --- a/tensorflow_probability/python/mcmc/sample_halton_sequence.py +++ b/tensorflow_probability/python/mcmc/sample_halton_sequence_lib.py @@ -31,7 +31,7 @@ # The maximum dimension we support. This is limited by the number of primes # in the _PRIMES array. -_MAX_DIMENSION = 1000 +_MAX_DIMENSION = 10000 def sample_halton_sequence(dim, @@ -53,7 +53,7 @@ def sample_halton_sequence(dim, Computes the members of the low discrepancy Halton sequence in dimension `dim`. The `dim`-dimensional sequence takes values in the unit hypercube in - `dim` dimensions. Currently, only dimensions up to 1000 are supported. The + `dim` dimensions. Currently, only dimensions up to 10000 are supported. The prime base for the k-th axes is the k-th prime starting from 2. For example, if `dim` = 3, then the bases will be [2, 3, 5] respectively and the first element of the non-randomized sequence will be: [0.5, 0.333, 0.2]. For a more @@ -121,7 +121,7 @@ def sample_halton_sequence(dim, Args: dim: Positive Python `int` representing each sample's `event_size.` Must - not be greater than 1000. + not be greater than 10000. num_results: (Optional) Positive scalar `Tensor` of dtype int32. The number of samples to generate. Either this parameter or sequence_indices must be specified but not both. If this parameter is None, then the behaviour @@ -158,7 +158,7 @@ def sample_halton_sequence(dim, Raises: ValueError: if both `sequence_indices` and `num_results` were specified or - if dimension `dim` is less than 1 or greater than 1000. + if dimension `dim` is less than 1 or greater than 10000. #### References @@ -182,17 +182,14 @@ def sample_halton_sequence(dim, # The coefficient dimension is an intermediate axes which will hold the # weights of the starting integer when expressed in the (prime) base for # an event dimension. 
- if num_results is not None: - num_results = tf.convert_to_tensor(num_results) if sequence_indices is not None: sequence_indices = tf.convert_to_tensor(sequence_indices) indices = _get_indices(num_results, sequence_indices, dtype) - radixes = tf.constant(_PRIMES[0:dim], dtype=dtype, shape=[dim, 1]) - - max_sizes_by_axes = _base_expansion_size( - tf.reduce_max(indices), radixes) - - max_size = tf.reduce_max(max_sizes_by_axes) + if num_results is None: + num_results = ps.reduce_max(indices) + radixes = _PRIMES[0:dim][..., np.newaxis] + max_sizes_by_axes = _base_expansion_size(num_results, radixes, dtype) + max_size = ps.reduce_max(max_sizes_by_axes) # The powers of the radixes that we will need. Note that there is a bit # of an excess here. Suppose we need the place value coefficients of 7 @@ -204,14 +201,13 @@ def sample_halton_sequence(dim, # dimensions, then the 10th prime (29) we will end up computing 29^10 even # though we don't need it. We avoid this by setting the exponents for each # axes to 0 beyond the maximum value needed for that dimension. - exponents_by_axes = tf.tile([tf.range(max_size)], [dim, 1]) + exponents_by_axes = tf.tile([tf.range(max_size, dtype=dtype)], [dim, 1]) # The mask is true for those coefficients that are irrelevant. weight_mask = exponents_by_axes < max_sizes_by_axes - capped_exponents = tf.where(weight_mask, - exponents_by_axes, - tf.constant(0, exponents_by_axes.dtype)) - weights = radixes ** capped_exponents + capped_exponents = tf.where( + weight_mask, exponents_by_axes, dtype_util.as_numpy_dtype(dtype)(0.)) + weights = tf.cast(radixes ** capped_exponents, dtype=dtype) # The following computes the base b expansion of the indices. Suppose, # x = a0 + a1*b + a2*b^2 + ... Then, performing a floor div of x with # the vector (1, b, b^2, b^3, ...) will produce @@ -246,7 +242,7 @@ def sample_halton_sequence(dim, zero_correction = samplers.uniform([dim, 1], seed=zero_correction_seed, dtype=dtype) - zero_correction /= radixes ** max_sizes_by_axes + zero_correction /= tf.cast(radixes ** max_sizes_by_axes, dtype) return base_values + tf.reshape(zero_correction, [-1]) @@ -254,14 +250,14 @@ def _randomize(coeffs, radixes, seed=None): """Applies the Owen (2017) randomization to the coefficients.""" given_dtype = coeffs.dtype coeffs = tf.cast(coeffs, dtype=tf.int32) - num_coeffs = tf.shape(coeffs)[-1] - radixes = tf.reshape(tf.cast(radixes, dtype=tf.int32), shape=[-1]) - perms = _get_permutations(num_coeffs, radixes, seed=seed) + num_coeffs = ps.shape(coeffs)[-1] + perms = _get_permutations(num_coeffs, np.squeeze(radixes, axis=-1), seed=seed) perms = tf.reshape(perms, shape=[-1]) + radixes = tf.reshape(tf.cast(radixes, dtype=tf.int32), shape=[-1]) radix_sum = tf.reduce_sum(radixes) radix_offsets = tf.reshape(tf.cumsum(radixes, exclusive=True), shape=[-1, 1]) - offsets = radix_offsets + tf.range(num_coeffs) * radix_sum + offsets = radix_offsets + ps.range(num_coeffs, dtype=tf.int32) * radix_sum permuted_coeffs = tf.gather(perms, coeffs + offsets) return tf.cast(permuted_coeffs, dtype=given_dtype) @@ -280,7 +276,7 @@ def _get_permutations(num_results, dims, seed=None): Args: num_results: A positive scalar `Tensor` of integral type. The number of draws from the discrete uniform distribution over the permutation groups. - dims: A 1D `Tensor` of the same dtype as `num_results`. The degree of the + dims: A 1D numpy array of the same dtype as `num_results`. The degree of the permutation groups from which to sample. seed: PRNG seed; see `tfp.random.sanitize_seed` for details. 
@@ -288,14 +284,20 @@ def _get_permutations(num_results, dims, seed=None): permutations: A `Tensor` of shape `[num_results, sum(dims)]` and the same dtype as `dims`. """ - seeds = samplers.split_seed(seed, n=ps.size(dims)) - - def generate_one(dim, seed): - return tf.argsort(samplers.uniform([num_results, dim], seed=seed), axis=-1) - - return tf.concat([generate_one(dim, seed) - for dim, seed in zip(tf.unstack(dims), tf.unstack(seeds))], - axis=-1) + n = dims.size + max_size = np.max(dims) + samples = samplers.uniform([num_results, n, max_size], seed=seed) + should_mask = np.arange(max_size) >= dims[..., np.newaxis] + # Choose a number that does not affect the permutation and relative location. + samples = tf.where( + should_mask, + dtype_util.as_numpy_dtype(samples.dtype)(np.arange(max_size) + 10.), + samples) + samples = tf.argsort(samples, axis=-1) + # Generate the set of indices to gather. + should_mask = np.tile(should_mask, [num_results, 1, 1]) + indices = np.stack(np.where(~should_mask), axis=-1) + return tf.gather_nd(samples, indices) def _get_indices(num_results, sequence_indices, dtype, name=None): @@ -325,8 +327,13 @@ def _get_indices(num_results, sequence_indices, dtype, name=None): """ with tf.name_scope(name or 'get_indices'): if sequence_indices is None: - num_results = tf.cast(num_results, dtype=dtype) - sequence_indices = tf.range(num_results, dtype=dtype) + np_dtype = dtype_util.as_numpy_dtype(dtype) + num_results_ = tf.get_static_value(num_results) + if num_results_ is not None: + sequence_indices = ps.range(np_dtype(num_results_), dtype=dtype) + else: + num_results = tf.cast(num_results, dtype=dtype) + sequence_indices = ps.range(num_results, dtype=dtype) else: sequence_indices = tf.cast(sequence_indices, dtype) @@ -338,7 +345,7 @@ def _get_indices(num_results, sequence_indices, dtype, name=None): return tf.reshape(indices, [-1, 1, 1]) -def _base_expansion_size(num, bases): +def _base_expansion_size(num, bases, dtype): """Computes the number of terms in the place value expansion. Let num = a0 + a1 b + a2 b^2 + ... ak b^k be the place value expansion of @@ -349,37 +356,36 @@ def _base_expansion_size(num, bases): $$k = Floor(log_b (num)) + 1 = Floor( log(num) / log(b)) + 1$$ Args: - num: Scalar `Tensor` of dtype either `float32` or `float64`. The number to + num: Scalar `Tensor` of dtype either `int32` or `int64`. The number to compute the base expansion size of. bases: `Tensor` of the same dtype as num. The bases to compute the size against. + dtype: Return `dtype`. Returns: - Tensor of same dtype and shape as `bases` containing the size of num when + Tensor of dtype `dtype` and shape as `bases` containing the size of num when written in that base. 
""" - return tf.floor(tf.math.log(num) / tf.math.log(bases)) + 1 + num_ = tf.get_static_value(num) + if num_ is not None: + return (np.floor(np.log(num_) / np.log(bases)) + 1).astype( + dtype_util.as_numpy_dtype(dtype)) + + return tf.floor( + tf.math.log(tf.cast(num, dtype)) / tf.math.log(tf.cast(bases, dtype))) + 1 def _primes_less_than(n): - # Based on - # https://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188 """Returns sorted array of primes such that `2 <= prime < n`.""" - small_primes = np.array((2, 3, 5)) - if n <= 6: - return small_primes[small_primes < n] - sieve = np.ones(n // 3 + (n % 6 == 2), dtype=np.bool_) - sieve[0] = False - m = int(n ** 0.5) // 3 + 1 - for i in range(m): - if not sieve[i]: - continue - k = 3 * i + 1 | 1 - sieve[k ** 2 // 3::2 * k] = False - sieve[(k ** 2 + 4 * k - 2 * k * (i & 1)) // 3::2 * k] = False - return np.r_[2, 3, 3 * np.nonzero(sieve)[0] + 1 | 1] - -_PRIMES = _primes_less_than(7919 + 1) - - + primes = np.ones((n + 1) // 2, dtype=bool) + j = 3 + while j * j <= n: + if primes[j//2]: + primes[j*j//2::j] = False + j += 2 + ret = 2 * np.where(primes)[0] + 1 + ret[0] = 2 # :( + return ret + +_PRIMES = _primes_less_than(104729 + 1) assert len(_PRIMES) == _MAX_DIMENSION diff --git a/tensorflow_probability/python/mcmc/sample_halton_sequence_test.py b/tensorflow_probability/python/mcmc/sample_halton_sequence_test.py index 64d3c1964b..0776dbc5cf 100644 --- a/tensorflow_probability/python/mcmc/sample_halton_sequence_test.py +++ b/tensorflow_probability/python/mcmc/sample_halton_sequence_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""Tests for sample_halton_sequence.py.""" +"""Tests for sample_halton_sequence_lib.py.""" # Dependency imports @@ -21,7 +21,7 @@ from tensorflow_probability.python.distributions import normal from tensorflow_probability.python.internal import monte_carlo from tensorflow_probability.python.internal import test_util -from tensorflow_probability.python.mcmc import sample_halton_sequence +from tensorflow_probability.python.mcmc import sample_halton_sequence_lib JAX_MODE = False @@ -38,7 +38,7 @@ def test_known_values_small_bases(self): [3. / 4, 1. / 9], [1. / 8, 4. / 9], [5. / 8, 7. / 9]], dtype=np.float32) - sample = sample_halton_sequence.sample_halton_sequence( + sample = sample_halton_sequence_lib.sample_halton_sequence( 2, num_results=5, randomized=False) self.assertAllClose(expected, self.evaluate(sample), rtol=1e-6) @@ -51,7 +51,7 @@ def test_dynamic_num_samples(self): [3. / 4, 1. / 9], [1. / 8, 4. / 9], [5. / 8, 7. 
/ 9]], dtype=np.float32) - sample = sample_halton_sequence.sample_halton_sequence( + sample = sample_halton_sequence_lib.sample_halton_sequence( 2, num_results=tf.constant(5), randomized=False) self.assertAllClose(expected, self.evaluate(sample), rtol=1e-6) @@ -59,9 +59,9 @@ def test_sequence_indices(self): """Tests access of sequence elements by index.""" dim = 5 indices = tf.range(10, dtype=tf.int32) - sample_direct = sample_halton_sequence.sample_halton_sequence( + sample_direct = sample_halton_sequence_lib.sample_halton_sequence( dim, num_results=10, randomized=False) - sample_from_indices = sample_halton_sequence.sample_halton_sequence( + sample_from_indices = sample_halton_sequence_lib.sample_halton_sequence( dim, sequence_indices=indices, randomized=False) self.assertAllClose( self.evaluate(sample_direct), self.evaluate(sample_from_indices), @@ -70,13 +70,30 @@ def test_sequence_indices(self): def test_dtypes_works_correctly(self): """Tests that all supported dtypes work without error.""" dim = 3 - sample_float32 = sample_halton_sequence.sample_halton_sequence( + sample_float32 = sample_halton_sequence_lib.sample_halton_sequence( dim, num_results=10, dtype=tf.float32, seed=test_util.test_seed()) - sample_float64 = sample_halton_sequence.sample_halton_sequence( + sample_float64 = sample_halton_sequence_lib.sample_halton_sequence( dim, num_results=10, dtype=tf.float64, seed=test_util.test_seed()) self.assertEqual(self.evaluate(sample_float32).dtype, np.float32) self.assertEqual(self.evaluate(sample_float64).dtype, np.float64) + @test_util.disable_test_for_backend( + disable_numpy=True, reason="Numpy has no notion of jit compilation.") + def test_jit_works_correctly(self): + @tf.function(jit_compile=True) + def sample_float32(): + return sample_halton_sequence_lib.sample_halton_sequence( + 5, num_results=10, dtype=tf.float32, seed=test_util.test_seed()) + samples = sample_float32() + self.assertEqual(samples.shape, [10, 5]) + + @tf.function(jit_compile=True) + def sample_float64(): + return sample_halton_sequence_lib.sample_halton_sequence( + 5, num_results=10, dtype=tf.float64, seed=test_util.test_seed()) + samples = sample_float64() + self.assertEqual(samples.shape, [10, 5]) + def test_normal_integral_mean_and_var_correctly_estimated(self): n = 1000 # This test is almost identical to the similarly named test in @@ -93,7 +110,7 @@ def test_normal_integral_mean_and_var_correctly_estimated(self): p = normal.Normal(loc=mu_p, scale=sigma_p) q = normal.Normal(loc=mu_q, scale=sigma_q) - cdf_sample = sample_halton_sequence.sample_halton_sequence( + cdf_sample = sample_halton_sequence_lib.sample_halton_sequence( 2, num_results=n, dtype=tf.float64, seed=test_util.test_seed()) q_sample = q.quantile(cdf_sample) @@ -116,7 +133,7 @@ def test_docstring_example(self): # Produce the first 1000 members of the Halton sequence in 3 dimensions. 
num_results = 1000 dim = 3 - sample = sample_halton_sequence.sample_halton_sequence( + sample = sample_halton_sequence_lib.sample_halton_sequence( dim, num_results=num_results, randomized=False) # Evaluate the integral of x_1 * x_2^2 * x_3^3 over the three dimensional @@ -134,7 +151,7 @@ def test_docstring_example(self): sequence_indices = tf.range(start=1000, limit=1000 + num_results, dtype=tf.int32) - sample_leaped = sample_halton_sequence.sample_halton_sequence( + sample_leaped = sample_halton_sequence_lib.sample_halton_sequence( dim, sequence_indices=sequence_indices, randomized=False) integral_leaped = tf.reduce_mean( @@ -150,7 +167,7 @@ def test_randomized_qmc_basic(self): num_results = 2000 replicas = 50 - samples = sample_halton_sequence.sample_halton_sequence( + samples = sample_halton_sequence_lib.sample_halton_sequence( dim, num_results=replicas * num_results, seed=test_util.test_seed_stream()) @@ -195,9 +212,9 @@ def func_estimate(x): axis=-1) stream = test_util.test_seed_stream() - sample_lo = sample_halton_sequence.sample_halton_sequence( + sample_lo = sample_halton_sequence_lib.sample_halton_sequence( dim, num_results=replica * num_results_lo, seed=stream()) - sample_hi = sample_halton_sequence.sample_halton_sequence( + sample_hi = sample_halton_sequence_lib.sample_halton_sequence( dim, num_results=replica * num_results_hi, seed=stream()) sample_lo = tf.reshape(sample_lo, [replica, -1, dim]) @@ -223,11 +240,11 @@ def test_seed_implies_deterministic_results(self): dim = 20 num_results = 100 seed = test_util.test_seed() - sample1 = sample_halton_sequence.sample_halton_sequence( + sample1 = sample_halton_sequence_lib.sample_halton_sequence( dim, num_results=num_results, seed=seed) if tf.executing_eagerly() and not JAX_MODE: tf.random.set_seed(seed) - sample2 = sample_halton_sequence.sample_halton_sequence( + sample2 = sample_halton_sequence_lib.sample_halton_sequence( dim, num_results=num_results, seed=seed) [sample1_, sample2_] = self.evaluate([sample1, sample2]) self.assertAllClose(sample1_, sample2_, atol=0., rtol=1e-6) diff --git a/tensorflow_probability/python/optimizer/BUILD b/tensorflow_probability/python/optimizer/BUILD index fd6dc3df8e..46c51d0cf4 100644 --- a/tensorflow_probability/python/optimizer/BUILD +++ b/tensorflow_probability/python/optimizer/BUILD @@ -55,6 +55,7 @@ multi_substrate_py_library( deps = [ # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/math:diag_jacobian", ], ) @@ -84,6 +85,7 @@ multi_substrate_py_library( # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", "//tensorflow_probability/python/internal:distribution_util", + "//tensorflow_probability/python/internal:tf_keras", ], ) diff --git a/tensorflow_probability/python/optimizer/bfgs.py b/tensorflow_probability/python/optimizer/bfgs.py index 4a5287af30..22342300b6 100644 --- a/tensorflow_probability/python/optimizer/bfgs.py +++ b/tensorflow_probability/python/optimizer/bfgs.py @@ -61,8 +61,12 @@ # `final_position`. If the search converged # the max-norm of this tensor should be # below the tolerance. - 'inverse_hessian_estimate' # A tensor containing the inverse of the - # estimated Hessian. + 'inverse_hessian_estimate', # A tensor containing the inverse of the + # estimated Hessian. + 'scale_initial_inverse_hessian' # Should the initial inverse Hessian + # be rescaled on the first iteration, + # as per Chapter 6 of Nocedal and + # Wright. 
     ])
@@ -72,6 +76,7 @@ def minimize(value_and_gradients_function,
              x_tolerance=0,
              f_relative_tolerance=0,
              initial_inverse_hessian_estimate=None,
+             scale_initial_inverse_hessian=True,
              max_iterations=50,
              parallel_iterations=1,
              stopping_condition=None,
@@ -149,6 +154,9 @@ def quadratic_loss_and_gradient(x):
       the inverse of the Hessian at the initial point. If not specified,
       the identity matrix is used as the starting estimate for the
      inverse Hessian.
+    scale_initial_inverse_hessian: If overridden to False, we skip scaling the
+      initial inverse Hessian (Chapter 6 of Nocedal and Wright suggests scaling
+      this).
     max_iterations: Scalar positive int32 `Tensor`. The maximum number of
       iterations for BFGS updates.
     parallel_iterations: Positive integer. The number of iterations allowed to
@@ -290,6 +298,7 @@ def _body(state):
                                          tolerance, control_inputs)
     kwargs['inverse_hessian_estimate'] = initial_inv_hessian
+    kwargs['scale_initial_inverse_hessian'] = scale_initial_inverse_hessian
     initial_state = BfgsOptimizerResults(**kwargs)
     return tf.while_loop(
         cond=_cond,
@@ -355,9 +364,11 @@ def _update_inv_hessian(prev_state, next_state):
     # Rescale the initial hessian at the first step, as suggested
     # in Chapter 6 of Numerical Optimization, by Nocedal and Wright.
     scale_factor = tf.where(
-        tf.math.equal(prev_state.num_iterations, 0),
+        (tf.math.equal(prev_state.num_iterations, 0) &
+         prev_state.scale_initial_inverse_hessian),
         normalization_factor / tf.reduce_sum(
-            tf.math.square(gradient_delta), axis=-1), 1.)
+            tf.math.square(gradient_delta), axis=-1),
+        1.)
     inverse_hessian_estimate = scale_factor[
         ..., tf.newaxis, tf.newaxis] * prev_state.inverse_hessian_estimate
diff --git a/tensorflow_probability/python/optimizer/bfgs_test.py b/tensorflow_probability/python/optimizer/bfgs_test.py
index e7cc1d1d9e..5a200dc5c0 100644
--- a/tensorflow_probability/python/optimizer/bfgs_test.py
+++ b/tensorflow_probability/python/optimizer/bfgs_test.py
@@ -427,6 +427,50 @@ def himmelblau(coord):
       self.assertArrayNear(actual, expected, 1e-5)
       self.assertEqual(batch_results.num_objective_evaluations, 31)
 
+  def test_scale_initial_inverse_hessian(self):
+    """Tests optional scaling of the initial inverse Hessian estimate.
+
+    Shows that the choice of the option determines the behaviour inside
+    the BFGS optimisation.
+    """
+    @_make_val_and_grad_fn
+    def sin_x_plus_sin_y(coord):
+      x, y = coord[0], coord[1]
+      return tf.math.sin(x) + tf.math.sin(y)
+
+    start = tf.constant((1, -2), dtype=np.float64)
+
+    results = {}
+    for scale in (True, False):
+      for max_iter in (1, 2, 50):
+        results[scale, max_iter] = self.evaluate(
+            bfgs.minimize(
+                sin_x_plus_sin_y,
+                initial_position=start,
+                tolerance=1e-8,
+                scale_initial_inverse_hessian=scale,
+                max_iterations=max_iter,
+            )
+        )
+
+    expected_positions = {
+        # Positions traced by the optimisation on the first iteration
+        # are not affected by the choice of `scale_initial_inverse_hessian`.
+        (True, 1): (-0.62581634, -0.7477782),
+        (False, 1): (-0.62581634, -0.7477782),
+        # However, gradient calculations on the first iteration _are_ affected,
+        # and this affects positions identified on the second iteration.
+        (True, 2): (-1.70200959, -0.37774139),
+        (False, 2): (-1.24714478, -0.55028845),
+        # Both approaches converge to the same minimum eventually (although
+        # this is not guaranteed; it depends on the exact problem being solved).
+ (True, 50): (-1.57079633, -1.57079633), + (False, 50): (-1.57079633, -1.57079633), + } + + for key, res in results.items(): + self.assertArrayNear(res.position, expected_positions[key], 1e-6) + def test_data_fitting(self): """Tests MLE estimation for a simple geometric GLM.""" n, dim = 100, 3 diff --git a/tensorflow_probability/python/optimizer/convergence_criteria/BUILD b/tensorflow_probability/python/optimizer/convergence_criteria/BUILD index fb18821b74..731d84822e 100644 --- a/tensorflow_probability/python/optimizer/convergence_criteria/BUILD +++ b/tensorflow_probability/python/optimizer/convergence_criteria/BUILD @@ -99,6 +99,7 @@ multi_substrate_py_test( "//tensorflow_probability/python/bijectors:softplus", "//tensorflow_probability/python/distributions:normal", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/util:deferred_tensor", "//tensorflow_probability/python/vi:csiszar_divergence", ], diff --git a/tensorflow_probability/python/optimizer/convergence_criteria/successive_gradients_are_uncorrelated_test.py b/tensorflow_probability/python/optimizer/convergence_criteria/successive_gradients_are_uncorrelated_test.py index 401a3d23f7..33c46a6011 100644 --- a/tensorflow_probability/python/optimizer/convergence_criteria/successive_gradients_are_uncorrelated_test.py +++ b/tensorflow_probability/python/optimizer/convergence_criteria/successive_gradients_are_uncorrelated_test.py @@ -20,6 +20,7 @@ from tensorflow_probability.python.bijectors import softplus from tensorflow_probability.python.distributions import normal from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.optimizer.convergence_criteria import successive_gradients_are_uncorrelated as sgau from tensorflow_probability.python.util import deferred_tensor from tensorflow_probability.python.vi import csiszar_divergence @@ -44,7 +45,7 @@ def test_stochastic_optimization(self): trained_dist = normal.Normal(locs, scales) target_dist = normal.Normal(loc=-0.4, scale=1.2) - optimizer = tf.optimizers.Adam(learning_rate=0.1) + optimizer = tf_keras.optimizers.Adam(learning_rate=0.1) @tf.function(autograph=False) def optimization_step(): with tf.GradientTape() as tape: diff --git a/tensorflow_probability/python/optimizer/sgld.py b/tensorflow_probability/python/optimizer/sgld.py index e40c6353aa..8e27f87ff6 100644 --- a/tensorflow_probability/python/optimizer/sgld.py +++ b/tensorflow_probability/python/optimizer/sgld.py @@ -19,8 +19,8 @@ from tensorflow_probability.python.internal import assert_util from tensorflow_probability.python.internal import distribution_util from tensorflow_probability.python.internal import dtype_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.math.diag_jacobian import diag_jacobian -from tensorflow.python.training import training_ops __all__ = [ @@ -29,7 +29,7 @@ # pylint: disable=g-classes-have-attributes -class StochasticGradientLangevinDynamics(tf.keras.optimizers.legacy.Optimizer): +class StochasticGradientLangevinDynamics(tf_keras.optimizers.legacy.Optimizer): """An optimizer module for stochastic gradient Langevin dynamics. 
This implements the preconditioned Stochastic Gradient Langevin Dynamics @@ -168,7 +168,7 @@ def __init__(self, diagonal_bias, name='diagonal_bias') # TODO(b/124800185): Consider migrating `learning_rate` to be a # hyperparameter handled by the base Optimizer class. This would allow - # users to plug in a `tf.keras.optimizers.schedules.LearningRateSchedule` + # users to plug in a `tf_keras.optimizers.schedules.LearningRateSchedule` # object in addition to Tensors. self._learning_rate = tf.convert_to_tensor( learning_rate, name='learning_rate') @@ -235,10 +235,10 @@ def _prepare(self, var_list): def _resource_apply_dense(self, grad, var): rms = self.get_slot(var, 'rms') new_grad = self._apply_noisy_update(rms, grad, var) - return training_ops.resource_apply_gradient_descent( - var.handle, - tf.cast(self._learning_rate_tensor, var.dtype.base_dtype), - new_grad, + return tf.raw_ops.ResourceApplyGradientDescent( + var=var.handle, + alpha=tf.cast(self._learning_rate_tensor, var.dtype.base_dtype), + delta=new_grad, use_locking=self._use_locking) def _resource_apply_sparse(self, grad, var, indices): diff --git a/tensorflow_probability/python/optimizer/variational_sgd.py b/tensorflow_probability/python/optimizer/variational_sgd.py index 635d6b6f5b..8109f8ae5b 100644 --- a/tensorflow_probability/python/optimizer/variational_sgd.py +++ b/tensorflow_probability/python/optimizer/variational_sgd.py @@ -19,7 +19,8 @@ from tensorflow_probability.python.internal import assert_util from tensorflow_probability.python.internal import distribution_util from tensorflow_probability.python.internal import dtype_util -from tensorflow.python.training import training_ops + +from tensorflow_probability.python.internal import tf_keras __all__ = [ @@ -28,7 +29,7 @@ # pylint: disable=g-classes-have-attributes -class VariationalSGD(tf.keras.optimizers.legacy.Optimizer): +class VariationalSGD(tf_keras.optimizers.legacy.Optimizer): """An optimizer module for constant stochastic gradient descent. 
This implements an optimizer module for the constant stochastic gradient @@ -236,10 +237,10 @@ def _resource_apply_dense(self, grad, var): tf.cast(max_learning_rate, var.dtype.base_dtype)) newgrad = grad * learn_rates - return training_ops.resource_apply_gradient_descent( - var.handle, - tf.cast(1., var.dtype), - newgrad, + return tf.raw_ops.ResourceApplyGradientDescent( + var=var.handle, + alpha=tf.cast(1., var.dtype), + delta=newgrad, use_locking=self._use_locking) def _resource_apply_sparse(self, grad, var, indices): diff --git a/tensorflow_probability/python/sts/BUILD b/tensorflow_probability/python/sts/BUILD index aab722e1f4..2c5c923dfb 100644 --- a/tensorflow_probability/python/sts/BUILD +++ b/tensorflow_probability/python/sts/BUILD @@ -113,6 +113,7 @@ multi_substrate_py_test( "//tensorflow_probability/python/distributions:exponential", "//tensorflow_probability/python/distributions:normal", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/optimizer", "//tensorflow_probability/python/sts/components:local_linear_trend", "//tensorflow_probability/python/sts/components:seasonal", diff --git a/tensorflow_probability/python/sts/default_model.py b/tensorflow_probability/python/sts/default_model.py index fb1f138425..6b494486db 100644 --- a/tensorflow_probability/python/sts/default_model.py +++ b/tensorflow_probability/python/sts/default_model.py @@ -95,7 +95,7 @@ def build_default_model(observed_time_series, losses = tfp.vi.fit_surrogate_posterior( target_log_prob_fn=model.joint_distribution(series).log_prob, surrogate_posterior=surrogate_posterior, - optimizer=tf.optimizers.Adam(0.1), + optimizer=tf_keras.optimizers.Adam(0.1), num_steps=1000, convergence_criterion=( tfp.optimizer.convergence_criteria.SuccessiveGradientsAreUncorrelated( diff --git a/tensorflow_probability/python/sts/default_model_test.py b/tensorflow_probability/python/sts/default_model_test.py index d96b2a471d..d679401533 100644 --- a/tensorflow_probability/python/sts/default_model_test.py +++ b/tensorflow_probability/python/sts/default_model_test.py @@ -22,6 +22,7 @@ from tensorflow_probability.python.distributions import exponential from tensorflow_probability.python.distributions import normal from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.optimizer.convergence_criteria import successive_gradients_are_uncorrelated from tensorflow_probability.python.sts import default_model from tensorflow_probability.python.sts import fitting @@ -111,7 +112,7 @@ def test_docstring_fitting_example(self): _ = optimization.fit_surrogate_posterior( target_log_prob_fn=model.joint_distribution(series).log_prob, surrogate_posterior=surrogate_posterior, - optimizer=tf.optimizers.Adam(0.1), + optimizer=tf_keras.optimizers.Adam(0.1), num_steps=1000, convergence_criterion=(successive_gradients_are_uncorrelated .SuccessiveGradientsAreUncorrelated( diff --git a/tensorflow_probability/python/sts/fitting.py b/tensorflow_probability/python/sts/fitting.py index a5cf0f4ee1..38eec124e3 100644 --- a/tensorflow_probability/python/sts/fitting.py +++ b/tensorflow_probability/python/sts/fitting.py @@ -132,7 +132,7 @@ def build_factored_surrogate_posterior( loss_curve = tfp.vi.fit_surrogate_posterior( target_log_prob_fn=model.joint_distribution(observed_time_series).log_prob, surrogate_posterior=surrogate_posterior, - 
optimizer=tf.optimizers.Adam(learning_rate=0.1), + optimizer=tf_keras.optimizers.Adam(learning_rate=0.1), num_steps=200) posterior_samples = surrogate_posterior.sample(50) @@ -152,7 +152,7 @@ def loss_fn(): surrogate_posterior, sample_size=10) - optimizer = tf.optimizers.Adam(learning_rate=0.1) + optimizer = tf_keras.optimizers.Adam(learning_rate=0.1) for step in range(200): with tf.GradientTape() as tape: loss = loss_fn() diff --git a/tensorflow_probability/python/sts/forecast.py b/tensorflow_probability/python/sts/forecast.py index 32c2322571..3950b559af 100644 --- a/tensorflow_probability/python/sts/forecast.py +++ b/tensorflow_probability/python/sts/forecast.py @@ -120,7 +120,7 @@ def one_step_predictive(model, observed_time_series, parameter_samples, loss_curve = tfp.vi.fit_surrogate_posterior( target_log_prob_fn=model.joint_distribution(observed_time_series).log_prob, surrogate_posterior=surrogate_posterior, - optimizer=tf.optimizers.Adam(learning_rate=0.1), + optimizer=tf_keras.optimizers.Adam(learning_rate=0.1), num_steps=200) samples = surrogate_posterior.sample(30) @@ -272,7 +272,7 @@ def forecast(model, loss_curve = tfp.vi.fit_surrogate_posterior( target_log_prob_fn=model.joint_distribution(observed_time_series).log_prob, surrogate_posterior=surrogate_posterior, - optimizer=tf.optimizers.Adam(learning_rate=0.1), + optimizer=tf_keras.optimizers.Adam(learning_rate=0.1), num_steps=200) samples = surrogate_posterior.sample(30) diff --git a/tensorflow_probability/python/sts/holiday_effects.py b/tensorflow_probability/python/sts/holiday_effects.py index e6c23aaa79..571fadbf7c 100644 --- a/tensorflow_probability/python/sts/holiday_effects.py +++ b/tensorflow_probability/python/sts/holiday_effects.py @@ -52,8 +52,8 @@ def get_default_holidays(times, country): columns=['geo', 'holiday', 'date']) holidays = holidays.explode('holiday') # Ensure that only holiday dates covered by times are used. 
- holidays = holidays[(holidays['date'] >= times.min()) - & (holidays['date'] <= times.max())] + holidays = holidays[(pd.to_datetime(holidays['date']) >= times.min()) + & (pd.to_datetime(holidays['date']) <= times.max())] holidays = holidays.reset_index(drop=True) holidays['date'] = pd.to_datetime(holidays['date']) holidays = holidays.sort_values('date') diff --git a/tensorflow_probability/python/sts/structural_time_series.py b/tensorflow_probability/python/sts/structural_time_series.py index 37475353b1..483ecd4100 100644 --- a/tensorflow_probability/python/sts/structural_time_series.py +++ b/tensorflow_probability/python/sts/structural_time_series.py @@ -346,7 +346,7 @@ def joint_distribution(self, losses = tfp.vi.fit_surrogate_posterior( target_log_prob_fn=jd.unnormalized_log_prob, surrogate_posterior=surrogate_posterior, - optimizer=tf.optimizers.Adam(0.1), + optimizer=tf_keras.optimizers.Adam(0.1), num_steps=200) parameter_samples = surrogate_posterior.sample(50) diff --git a/tensorflow_probability/python/util/BUILD b/tensorflow_probability/python/util/BUILD index 1c9df28512..66e603cf60 100644 --- a/tensorflow_probability/python/util/BUILD +++ b/tensorflow_probability/python/util/BUILD @@ -48,6 +48,7 @@ multi_substrate_py_library( "//tensorflow_probability/python/internal:name_util", "//tensorflow_probability/python/internal:tensor_util", "//tensorflow_probability/python/internal:tensorshape_util", + "//tensorflow_probability/python/internal:tf_keras", ], ) diff --git a/tensorflow_probability/python/util/deferred_tensor.py b/tensorflow_probability/python/util/deferred_tensor.py index 7b0f9f52fe..65858e5286 100644 --- a/tensorflow_probability/python/util/deferred_tensor.py +++ b/tensorflow_probability/python/util/deferred_tensor.py @@ -156,7 +156,7 @@ class DeferredTensor(six.with_metaclass( Which we could then fit as: ```python - opt = tf.optimizers.Adam(learning_rate=0.05) + opt = tf_keras.optimizers.Adam(learning_rate=0.05) loss = tf.function(lambda: -trainable_normal.log_prob(0.5), autograph=True) for _ in range(int(1e3)): opt.minimize(loss, trainable_normal.trainable_variables) @@ -477,7 +477,7 @@ class TransformedVariable(DeferredTensor): g = tape.gradient(negloglik, trainable_normal.trainable_variables) # ==> (-0.5, 0.75) - opt = tf.optimizers.Adam(learning_rate=0.05) + opt = tf_keras.optimizers.Adam(learning_rate=0.05) loss = tf.function(lambda: -trainable_normal.log_prob(0.5)) for _ in range(int(1e3)): opt.minimize(loss, trainable_normal.trainable_variables) diff --git a/tensorflow_probability/python/version.py b/tensorflow_probability/python/version.py index a4c934aeec..8386c1c294 100644 --- a/tensorflow_probability/python/version.py +++ b/tensorflow_probability/python/version.py @@ -16,7 +16,7 @@ # We follow Semantic Versioning (https://semver.org/) _MAJOR_VERSION = '0' -_MINOR_VERSION = '20' +_MINOR_VERSION = '23' _PATCH_VERSION = '0' # When building releases, we can update this value on the release branch to diff --git a/tensorflow_probability/python/vi/BUILD b/tensorflow_probability/python/vi/BUILD index 9ffb679efd..80bcc0ea12 100644 --- a/tensorflow_probability/python/vi/BUILD +++ b/tensorflow_probability/python/vi/BUILD @@ -55,6 +55,7 @@ multi_substrate_py_library( "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:nest_util", "//tensorflow_probability/python/internal:reparameterization", + "//tensorflow_probability/python/internal:samplers", "//tensorflow_probability/python/monte_carlo", 
"//tensorflow_probability/python/stats:leave_one_out", ], @@ -148,6 +149,7 @@ multi_substrate_py_test( "//tensorflow_probability/python/experimental/util", "//tensorflow_probability/python/internal:samplers", "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/internal:tf_keras", "//tensorflow_probability/python/math/psd_kernels:exponentiated_quadratic", "//tensorflow_probability/python/util:deferred_tensor", ], diff --git a/tensorflow_probability/python/vi/csiszar_divergence.py b/tensorflow_probability/python/vi/csiszar_divergence.py index 0904bce01c..e8aec7504c 100644 --- a/tensorflow_probability/python/vi/csiszar_divergence.py +++ b/tensorflow_probability/python/vi/csiszar_divergence.py @@ -15,6 +15,7 @@ """Csiszar f-Divergence and helpers.""" import enum +import functools import warnings # Dependency imports @@ -26,6 +27,7 @@ from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import nest_util from tensorflow_probability.python.internal import prefer_static as ps +from tensorflow_probability.python.internal import samplers from tensorflow_probability.python.internal.reparameterization import FULLY_REPARAMETERIZED from tensorflow_probability.python.stats.leave_one_out import log_soomean_exp @@ -55,6 +57,16 @@ ] +def _call_fn_maybe_with_seed(fn, args, *, seed=None): + try: + return nest_util.call_fn(functools.partial(fn, seed=seed), args) + except (TypeError, ValueError) as e: + if ("'seed'" in str(e) or ('one of *args or **kwargs' in str(e))): + return nest_util.call_fn(fn, args) + else: + raise e + + class GradientEstimators(enum.Enum): """Gradient estimators for variational losses. @@ -1045,6 +1057,7 @@ def monte_carlo_variational_loss( raise TypeError('`target_log_prob_fn` must be a Python `callable`' 'function.') + sample_seed, target_seed = samplers.split_seed(seed, 2) reparameterization_types = tf.nest.flatten( surrogate_posterior.reparameterization_type) if gradient_estimator is None: @@ -1067,7 +1080,7 @@ def monte_carlo_variational_loss( 'losses with `importance_sample_size != 1`.') # Score fn objective requires explicit gradients of `log_prob`. q_samples = surrogate_posterior.sample( - [sample_size * importance_sample_size], seed=seed) + [sample_size * importance_sample_size], seed=sample_seed) q_lp = None else: if any(reparameterization_type != FULLY_REPARAMETERIZED @@ -1080,7 +1093,7 @@ def monte_carlo_variational_loss( # Attempt to avoid bijector inverses by computing the surrogate log prob # during the forward sampling pass. q_samples, q_lp = surrogate_posterior.experimental_sample_and_log_prob( - [sample_size * importance_sample_size], seed=seed) + [sample_size * importance_sample_size], seed=sample_seed) return monte_carlo.expectation( f=_make_importance_weighted_divergence_fn( @@ -1090,8 +1103,8 @@ def monte_carlo_variational_loss( precomputed_surrogate_log_prob=q_lp, importance_sample_size=importance_sample_size, gradient_estimator=gradient_estimator, - stopped_surrogate_posterior=( - stopped_surrogate_posterior)), + stopped_surrogate_posterior=stopped_surrogate_posterior, + seed=target_seed), samples=q_samples, # Log-prob is only used if `gradient_estimator == SCORE_FUNCTION`. 
log_prob=surrogate_posterior.log_prob, @@ -1106,18 +1119,19 @@ def _make_importance_weighted_divergence_fn( precomputed_surrogate_log_prob=None, importance_sample_size=1, gradient_estimator=GradientEstimators.REPARAMETERIZATION, - stopped_surrogate_posterior=None): + stopped_surrogate_posterior=None, + seed=None): """Defines a function to compute an importance-weighted divergence.""" def divergence_fn(q_samples): q_lp = precomputed_surrogate_log_prob - target_log_prob = nest_util.call_fn(target_log_prob_fn, q_samples) + target_log_prob = _call_fn_maybe_with_seed( + target_log_prob_fn, q_samples, seed=seed) if gradient_estimator == GradientEstimators.DOUBLY_REPARAMETERIZED: # Sticking-the-landing is the special case of doubly-reparameterized # gradients with `importance_sample_size=1`. q_lp = stopped_surrogate_posterior.log_prob(q_samples) - log_weights = target_log_prob - q_lp else: if q_lp is None: q_lp = surrogate_posterior.log_prob(q_samples) @@ -1128,7 +1142,8 @@ def importance_weighted_divergence_fn(q_samples): q_lp = precomputed_surrogate_log_prob if q_lp is None: q_lp = surrogate_posterior.log_prob(q_samples) - target_log_prob = nest_util.call_fn(target_log_prob_fn, q_samples) + target_log_prob = _call_fn_maybe_with_seed( + target_log_prob_fn, q_samples, seed=seed) log_weights = target_log_prob - q_lp # Explicitly break out `importance_sample_size` as a separate axis. @@ -1243,10 +1258,12 @@ def csiszar_vimco(f, raise ValueError('Must specify num_draws > 1.') stop = tf.stop_gradient # For readability. - q_sample = q.sample(sample_shape=[num_draws, num_batch_draws], seed=seed) + sample_seed, target_seed = samplers.split_seed(seed, 2) + q_sample = q.sample(sample_shape=[num_draws, num_batch_draws], + seed=sample_seed) x = tf.nest.map_structure(stop, q_sample) logqx = q.log_prob(x) - logu = nest_util.call_fn(p_log_prob, x) - logqx + logu = _call_fn_maybe_with_seed(p_log_prob, x, seed=target_seed) - logqx f_log_sooavg_u, f_log_avg_u = map(f, log_soomean_exp(logu, axis=0)) dotprod = tf.reduce_sum( diff --git a/tensorflow_probability/python/vi/csiszar_divergence_test.py b/tensorflow_probability/python/vi/csiszar_divergence_test.py index 34d3656812..5862e6c4b5 100644 --- a/tensorflow_probability/python/vi/csiszar_divergence_test.py +++ b/tensorflow_probability/python/vi/csiszar_divergence_test.py @@ -907,7 +907,10 @@ def target_log_prob_fn(x): # Manually estimate the expected multi-sample / IWAE loss. zs, q_lp = surrogate_posterior.experimental_sample_and_log_prob( - [sample_size, importance_sample_size], seed=seed) + [sample_size, importance_sample_size], + # Brittle hack to ensure that the q samples match those + # drawn in `monte_carlo_variational_loss`. + seed=samplers.split_seed(seed, 2)[0]) log_weights = target_log_prob_fn(zs) - q_lp iwae_loss = -tf.reduce_mean( tf.math.reduce_logsumexp(log_weights, axis=1) - tf.math.log( @@ -988,7 +991,10 @@ def vimco_loss(s): def logu(s): q = build_q(s) - x = q.sample(sample_shape=[num_draws, num_batch_draws], seed=seed) + x = q.sample(sample_shape=[num_draws, num_batch_draws], + # Brittle hack to ensure that the q samples match those + # drawn in `monte_carlo_variational_loss`. 
+ seed=samplers.split_seed(seed, 2)[0]) x = tf.stop_gradient(x) return p.log_prob(x) - q.log_prob(x) @@ -997,7 +1003,10 @@ def f_log_sum_u(s): def q_log_prob_x(s): q = build_q(s) - x = q.sample(sample_shape=[num_draws, num_batch_draws], seed=seed) + x = q.sample(sample_shape=[num_draws, num_batch_draws], + # Brittle hack to ensure that the q samples match those + # drawn in `monte_carlo_variational_loss`. + seed=samplers.split_seed(seed, 2)[0]) x = tf.stop_gradient(x) return q.log_prob(x) diff --git a/tensorflow_probability/python/vi/optimization.py b/tensorflow_probability/python/vi/optimization.py index c06f31cb98..983fa8a8aa 100644 --- a/tensorflow_probability/python/vi/optimization.py +++ b/tensorflow_probability/python/vi/optimization.py @@ -442,8 +442,8 @@ def fit_surrogate_posterior(target_log_prob_fn, transformations of unconstrained variables, so that the transformations execute at runtime instead of at distribution creation. optimizer: Optimizer instance to use. This may be a TF1-style - `tf.train.Optimizer`, TF2-style `tf.optimizers.Optimizer`, or any Python - object that implements `optimizer.apply_gradients(grads_and_vars)`. + `tf.train.Optimizer`, TF2-style `tf_keras.optimizers.Optimizer`, or any + Python object that implements `optimizer.apply_gradients(grads_and_vars)`. num_steps: Python `int` number of steps to run the optimizer. convergence_criterion: Optional instance of `tfp.optimizer.convergence_criteria.ConvergenceCriterion` @@ -522,7 +522,7 @@ def log_prob(z, x): losses = tfp.vi.fit_surrogate_posterior( conditioned_log_prob, surrogate_posterior=q_z, - optimizer=tf.optimizers.Adam(learning_rate=0.1), + optimizer=tf_keras.optimizers.Adam(learning_rate=0.1), num_steps=100) print(q_z.mean(), q_z.stddev()) # => approximately [2.5, 1/sqrt(2)] ``` @@ -535,7 +535,7 @@ def log_prob(z, x): losses = tfp.vi.fit_surrogate_posterior( conditioned_log_prob, surrogate_posterior=q_z, - optimizer=tf.optimizers.Adam(learning_rate=0.1), + optimizer=tf_keras.optimizers.Adam(learning_rate=0.1), num_steps=100, discrepancy_fn=tfp.vi.kl_forward) ``` @@ -589,7 +589,7 @@ def log_prob(z, x): conditioned_log_prob, surrogate_posterior=q_z, importance_sample_size=10, - optimizer=tf.optimizers.Adam(learning_rate=0.1), + optimizer=tf_keras.optimizers.Adam(learning_rate=0.1), num_steps=200) # Estimate posterior statistics with importance sampling. 
@@ -680,7 +680,7 @@ def variational_model_fn(): losses, log_amplitude_path, sample_path = tfp.vi.fit_surrogate_posterior( target_log_prob_fn=lambda *args: model.log_prob(args), surrogate_posterior=q, - optimizer=tf.optimizers.Adam(learning_rate=0.1), + optimizer=tf_keras.optimizers.Adam(learning_rate=0.1), sample_size=1, num_steps=500, trace_fn=lambda loss, grads, vars: (loss, kernel_log_amplitude, diff --git a/tensorflow_probability/python/vi/optimization_test.py b/tensorflow_probability/python/vi/optimization_test.py index 65e75d1fe2..5193f5b898 100644 --- a/tensorflow_probability/python/vi/optimization_test.py +++ b/tensorflow_probability/python/vi/optimization_test.py @@ -33,6 +33,7 @@ from tensorflow_probability.python.experimental.util import trainable from tensorflow_probability.python.internal import samplers from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.internal import tf_keras from tensorflow_probability.python.math.psd_kernels import exponentiated_quadratic from tensorflow_probability.python.util import deferred_tensor from tensorflow_probability.python.vi import optimization @@ -79,7 +80,7 @@ def trainable_log_prob(z): q, num_steps=1000, sample_size=10, - optimizer=tf.optimizers.Adam(0.1), + optimizer=tf_keras.optimizers.Adam(0.1), seed=seed) self.evaluate(tf1.global_variables_initializer()) with tf.control_dependencies([loss_curve]): @@ -112,7 +113,7 @@ def log_prob(z, x): conditioned_log_prob, surrogate_posterior=q_z, importance_sample_size=10, - optimizer=tf.optimizers.Adam(learning_rate=0.1), + optimizer=tf_keras.optimizers.Adam(learning_rate=0.1), num_steps=100, seed=opt_seed) self.evaluate(tf1.global_variables_initializer()) @@ -140,7 +141,7 @@ def log_prob(z, x): conditioned_log_prob, surrogate_posterior=q_z_again, importance_sample_size=10, - optimizer=tf.optimizers.Adam(learning_rate=0.1), + optimizer=tf_keras.optimizers.Adam(learning_rate=0.1), num_steps=100, seed=opt_seed) self.evaluate(tf1.global_variables_initializer()) @@ -172,7 +173,7 @@ def trainable_q_fn(): q, num_steps=1000, sample_size=100, - optimizer=tf.optimizers.Adam(learning_rate=0.1), + optimizer=tf_keras.optimizers.Adam(learning_rate=0.1), seed=seed) self.evaluate(tf1.global_variables_initializer()) loss_curve_ = self.evaluate((loss_curve)) @@ -230,7 +231,7 @@ def variational_model_fn(): losses, sample_path = optimization.fit_surrogate_posterior( target_log_prob_fn=lambda *args: model.log_prob(args), surrogate_posterior=q, - optimizer=tf.optimizers.Adam(learning_rate=0.1), + optimizer=tf_keras.optimizers.Adam(learning_rate=0.1), num_steps=100, seed=test_util.test_seed(), sample_size=1, @@ -351,9 +352,14 @@ def variational_model_fn(): return import optax # pylint: disable=g-import-not-at-top + def seeded_target_log_prob_fn(*xs, seed=None): + # Add a tiny amount of noise to the target log-prob to see if it works. 
+ ret = pinned.unnormalized_log_prob(xs) + return ret + samplers.normal(ret.shape, stddev=0.01, seed=seed) + [optimized_parameters, (losses, _, sample_path)] = optimization.fit_surrogate_posterior_stateless( - target_log_prob_fn=pinned.unnormalized_log_prob, + target_log_prob_fn=seeded_target_log_prob_fn, build_surrogate_posterior_fn=build_surrogate_posterior_fn, initial_parameters=initial_parameters, optimizer=optax.adam(learning_rate=0.1), diff --git a/tensorflow_probability/substrates/meta/rewrite.py b/tensorflow_probability/substrates/meta/rewrite.py index f242583e33..3f05c71017 100644 --- a/tensorflow_probability/substrates/meta/rewrite.py +++ b/tensorflow_probability/substrates/meta/rewrite.py @@ -67,9 +67,10 @@ 'from tensorflow_probability.python.internal.backend.numpy.private', 'from tensorflow.python.ops.linalg': 'from tensorflow_probability.python.internal.backend.numpy.gen', - 'from tensorflow.python.ops import parallel_for': + ('from tensorflow.python.ops.parallel_for ' + 'import control_flow_ops'): 'from tensorflow_probability.python.internal.backend.numpy ' - 'import functional_ops as parallel_for', + 'import functional_ops as control_flow_ops', 'from tensorflow.python.ops import control_flow_case': 'from tensorflow_probability.python.internal.backend.numpy ' 'import control_flow as control_flow_case', @@ -85,7 +86,10 @@ 'pass', ('from tensorflow.python ' 'import pywrap_tensorflow as c_api'): - 'pass' + 'pass', + 'from tensorflow_probability.python.internal import tf_keras': + ('from tensorflow_probability.python.internal.backend.numpy ' + 'import keras as tf_keras'), } DISABLED_BY_PKG = { diff --git a/testing/dependency_install_lib.sh b/testing/dependency_install_lib.sh index 9db1691815..801d7a3361 100644 --- a/testing/dependency_install_lib.sh +++ b/testing/dependency_install_lib.sh @@ -69,7 +69,9 @@ install_tensorflow() { PIP_FLAGS=${2-} # NB: tf-nightly pulls in other deps, like numpy, absl, and six, transitively. TF_VERSION_STR=$(find_good_tf_nightly_version_str $TF_NIGHTLY_PACKAGE) - python -m pip install $PIP_FLAGS $TF_NIGHTLY_PACKAGE==$TF_VERSION_STR + python -m pip install $PIP_FLAGS \ + $TF_NIGHTLY_PACKAGE==$TF_VERSION_STR \ + tf-keras-nightly } install_jax() {