diff --git a/tensorflow_probability/python/experimental/autobnn/BUILD b/tensorflow_probability/python/experimental/autobnn/BUILD new file mode 100644 index 0000000000..1a7f1a5381 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/BUILD @@ -0,0 +1,227 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# Code for AutoBNN. See README.md for more information. + +# Placeholder: py_library +# Placeholder: py_test + +licenses(["notice"]) + +package( + # default_applicable_licenses + default_visibility = ["//visibility:public"], +) + +py_library( + name = "bnn", + srcs = ["bnn.py"], + deps = [ + ":likelihoods", + # flax:core dep, + # jax dep, + # jaxtyping dep, + "//tensorflow_probability/python/distributions:distribution.jax", + ], +) + +py_test( + name = "bnn_test", + srcs = ["bnn_test.py"], + deps = [ + ":bnn", + # absl/testing:absltest dep, + # google/protobuf:use_fast_cpp_protos dep, + # jax dep, + "//tensorflow_probability:jax", + "//tensorflow_probability/python/distributions:lognormal.jax", + "//tensorflow_probability/python/distributions:normal.jax", + ], +) + +py_library( + name = "kernels", + srcs = ["kernels.py"], + deps = [ + ":bnn", + # flax dep, + # flax:core dep, + # jax dep, + "//tensorflow_probability/python/distributions:lognormal.jax", + "//tensorflow_probability/python/distributions:normal.jax", + "//tensorflow_probability/python/distributions:student_t.jax", + "//tensorflow_probability/python/distributions:uniform.jax", + ], +) + +py_test( + name = "kernels_test", + srcs = ["kernels_test.py"], + deps = [ + ":kernels", + ":util", + # absl/testing:absltest dep, + # absl/testing:parameterized dep, + # google/protobuf:use_fast_cpp_protos dep, + # jax dep, + "//tensorflow_probability/python/distributions:lognormal.jax", + ], +) + +py_library( + name = "likelihoods", + srcs = ["likelihoods.py"], + deps = [ + # flax:core dep, + # jax dep, + # jaxtyping dep, + "//tensorflow_probability:jax", + "//tensorflow_probability/python/bijectors:softplus.jax", + "//tensorflow_probability/python/distributions:distribution.jax", + "//tensorflow_probability/python/distributions:inflated.jax", + "//tensorflow_probability/python/distributions:logistic.jax", + "//tensorflow_probability/python/distributions:lognormal.jax", + "//tensorflow_probability/python/distributions:negative_binomial.jax", + "//tensorflow_probability/python/distributions:normal.jax", + "//tensorflow_probability/python/distributions:transformed_distribution.jax", + ], +) + +py_test( + name = "likelihoods_test", + srcs = ["likelihoods_test.py"], + deps = [ + ":likelihoods", + # absl/testing:absltest dep, + # absl/testing:parameterized dep, + # jax dep, + ], +) + +py_library( + name = "models", + srcs = ["models.py"], + deps = [ + ":bnn", + ":bnn_tree", + ":kernels", + ":likelihoods", + ":operators", + # jax dep, + ], +) + +py_test( + name = "models_test", + srcs = 
["models_test.py"], + shard_count = 3, + deps = [ + ":likelihoods", + ":models", + ":operators", + # absl/testing:absltest dep, + # absl/testing:parameterized dep, + # jax dep, + ], +) + +py_library( + name = "operators", + srcs = ["operators.py"], + deps = [ + ":bnn", + ":likelihoods", + # flax:core dep, + # jax dep, + "//tensorflow_probability:jax", + "//tensorflow_probability/python/bijectors:chain.jax", + "//tensorflow_probability/python/bijectors:scale.jax", + "//tensorflow_probability/python/bijectors:shift.jax", + "//tensorflow_probability/python/distributions:beta.jax", + "//tensorflow_probability/python/distributions:dirichlet.jax", + "//tensorflow_probability/python/distributions:half_normal.jax", + "//tensorflow_probability/python/distributions:normal.jax", + "//tensorflow_probability/python/distributions:transformed_distribution.jax", + ], +) + +py_test( + name = "operators_test", + srcs = ["operators_test.py"], + deps = [ + ":kernels", + ":operators", + ":util", + # absl/testing:absltest dep, + # absl/testing:parameterized dep, + # google/protobuf:use_fast_cpp_protos dep, + # jax dep, + # numpy dep, + "//tensorflow_probability/python/distributions:distribution.jax", + ], +) + +py_library( + name = "bnn_tree", + srcs = ["bnn_tree.py"], + deps = [ + ":bnn", + ":kernels", + ":operators", + ":util", + # flax:core dep, + # jax dep, + ], +) + +py_test( + name = "bnn_tree_test", + timeout = "long", + srcs = ["bnn_tree_test.py"], + shard_count = 3, + deps = [ + ":bnn_tree", + ":kernels", + # absl/testing:absltest dep, + # absl/testing:parameterized dep, + # flax dep, + # google/protobuf:use_fast_cpp_protos dep, + # jax dep, + ], +) + +py_library( + name = "util", + srcs = ["util.py"], + deps = [ + ":bnn", + # jax dep, + # numpy dep, + # scipy dep, + "//tensorflow_probability/python/distributions:distribution.jax", + ], +) + +py_test( + name = "util_test", + srcs = ["util_test.py"], + deps = [ + ":kernels", + ":util", + # google/protobuf:use_fast_cpp_protos dep, + # jax dep, + # numpy dep, + "//tensorflow_probability/python/internal:test_util", + ], +) diff --git a/tensorflow_probability/python/experimental/autobnn/README.md b/tensorflow_probability/python/experimental/autobnn/README.md new file mode 100644 index 0000000000..c10a446ada --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/README.md @@ -0,0 +1,25 @@ +# AutoBNN + +This library contains code to specify BNNs that correspond to various useful GP +kernels and assemble them into models using operators such as Addition, +Multiplication and Changepoint. + +It is based on the ideas in the following papers: + +* Lassi Meronen, Martin Trapp, Arno Solin. _Periodic Activation Functions +Induce Stationarity_. NeurIPS 2021. + +* Tim Pearce, Russell Tsuchida, Mohamed Zaki, Alexandra Brintrup, Andy Neely. +_Expressive Priors in Bayesian Neural Networks: Kernel Combinations and +Periodic Functions_. UAI 2019. + +* Feras A. Saad, Brian J. Patton, Matthew D. Hoffman, Rif A. Saurous, +Vikash K. Mansinghka. _Sequential Monte Carlo Learning for Time Series +Structure Discovery_. ICML 2023. + + +## Setup + +AutoBNN has three additional dependencies beyond those used by the core +Tensorflow Probability package: flax, scipy and jaxtyping. These +can be installed by running `setup\_autobnn.sh`. 
diff --git a/tensorflow_probability/python/experimental/autobnn/bnn.py b/tensorflow_probability/python/experimental/autobnn/bnn.py new file mode 100644 index 0000000000..cb33e373c1 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/bnn.py @@ -0,0 +1,170 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Base class for Bayesian Neural Networks.""" + +import dataclasses + +import flax +from flax import linen as nn +import jax.numpy as jnp +from jaxtyping import Array, Float, PyTree # pylint: disable=g-importing-member,g-multiple-import +from tensorflow_probability.python.experimental.autobnn import likelihoods +from tensorflow_probability.substrates.jax.distributions import distribution as distribution_lib + + +def log_prior_of_parameters(params, distributions) -> Float: + """Return the prior of the parameters according to the distributions.""" + if 'params' in params: + params = params['params'] + # We can't use jax.tree_util.tree_map here because params is allowed to + # have extra things (like bnn_0, ... for a BnnOperator) that aren't in + # distributions. + lp = 0.0 + for k, v in distributions.items(): + p = params[k] + if isinstance(v, distribution_lib.Distribution): + lp += jnp.sum(v.log_prob(p)) + else: + lp += log_prior_of_parameters(p, v) + return lp + + +class BayesianModule(nn.Module): + """A linen.Module with distributions over its parameters. + + Example usage: + class MyModule(BayesianModule): + + def distributions(self): + return {'dense': {'kernel': tfd.Normal(loc=0, scale=1), + 'bias': tfd.Normal(loc=0, scale=1)}, + 'amplitude': tfd.LogNormal(loc=0, scale=1)} + + def setup(self): + self.dense = nn.Dense(50) + super().setup() # <-- Very important, do not forget! + + def __call__(self, inputs): + return self.amplitude * self.dense(inputs) + + + my_bnn = MyModule() + params = my_bnn.init(jax.random.PRNGKey(0), jnp.zeros(10)) + lp = my_bnn.log_prior(params) + + Note that in this example, self.amplitude will be initialized using + the given tfd.LogNormal distribution, but the self.dense's parameters + will be initialized using the nn.Dense's default initializers. However, + the log_prior score will take into account all of the parameters. + """ + + def distributions(self): + """Return a nested dictionary of distributions for the model's params. + + The nested dictionary should have the same structure as the + variables returned by the init() method, except all leaves should + be tensorflow probability Distributions. + """ + # TODO(thomaswc): Consider having this optionally also be able to + # return a tfd.JointNamedDistribution, so as to support dependencies + # between the subdistributions. 
+    raise NotImplementedError('Subclasses of BNN must define this.')
+
+  def setup(self):
+    """Subclasses must call this from their setup()!"""
+
+    def make_sample_func(dist):
+      def sample_func(key, shape):
+        return dist.sample(sample_shape=shape, seed=key)
+
+      return sample_func
+
+    for k, v in self.distributions().items():
+      # Create a variable for every distribution that doesn't already
+      # have one. If you define a variable in your setup, we assume
+      # you initialize it correctly.
+      if not hasattr(self, k):
+        try:
+          setattr(self, k, self.param(k, make_sample_func(v), 1))
+        except flax.errors.NameInUseError:
+          # Sometimes subclasses will have parameters where the
+          # parameter name doesn't exactly correspond to the name of
+          # the object field. This can happen with arrays of parameters
+          # (like PolynomialBNN's hidden parameters, for example). I
+          # don't know of any way to detect this beforehand except by
+          # trying to call self.param and having it fail with NameInUseError.
+          # (For example, self.variables doesn't exist at setup() time.)
+          pass
+
+  def log_prior(self, params) -> float:
+    """Return the log probability of the params according to the prior."""
+    return log_prior_of_parameters(params, self.distributions())
+
+  def shortname(self) -> str:
+    """Return the class name, minus any BNN suffix."""
+    return type(self).__name__.removesuffix('BNN')
+
+  def summarize(self, params=None, full: bool = False) -> str:
+    """Return a string summarizing the structure of the BNN."""
+    return self.shortname()
+
+
+class BNN(BayesianModule):
+  """A Bayesian Neural Network.
+
+  A BNN's __call__ method must accept a tensor of shape (..., num_features)
+  and return a tensor of shape (..., likelihood_model.num_outputs()).
+  Given that, it provides log_likelihood and log_prob methods based
+  on the provided likelihood_model.
+  """
+
+  likelihood_model: likelihoods.LikelihoodModel = dataclasses.field(
+      default_factory=likelihoods.NormalLikelihoodLogisticNoise
+  )
+
+  def distributions(self):
+    # Subclasses must call super().distributions() to include this!
+    return self.likelihood_model.distributions()
+
+  def set_likelihood_model(self, likelihood_model: likelihoods.LikelihoodModel):
+    self.likelihood_model = likelihood_model
+
+  def log_likelihood(
+      self,
+      params: PyTree,
+      data: Float[Array, 'time features'],
+      observations: Float[Array, 'time'],
+  ) -> Float[Array, '']:
+    """Return the log likelihood of the data given the model."""
+    nn_out = self.apply(params, data)
+    if 'params' in params:
+      params = params['params']
+    # Sum over all axes here - the user should use `vmap` for batching.
+    return jnp.sum(
+        self.likelihood_model.log_likelihood(params, nn_out, observations)
+    )
+
+  def log_prob(
+      self,
+      params: PyTree,
+      data: Float[Array, 'time features'],
+      observations: Float[Array, 'time'],
+  ) -> Float[Array, '']:
+    return self.log_prior(params) + self.log_likelihood(
+        params, data, observations
+    )
+
+  def get_all_distributions(self):
+    return self.distributions()
diff --git a/tensorflow_probability/python/experimental/autobnn/bnn_test.py b/tensorflow_probability/python/experimental/autobnn/bnn_test.py
new file mode 100644
index 0000000000..53bb7b09e4
--- /dev/null
+++ b/tensorflow_probability/python/experimental/autobnn/bnn_test.py
@@ -0,0 +1,68 @@
+# Copyright 2023 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for bnn.py.""" + +from flax import linen as nn +import jax +import jax.numpy as jnp +from tensorflow_probability.python.experimental.autobnn import bnn +from tensorflow_probability.substrates.jax.distributions import lognormal as lognormal_lib +from tensorflow_probability.substrates.jax.distributions import normal as normal_lib +from absl.testing import absltest + + +class MyBNN(bnn.BNN): + + def distributions(self): + return super().distributions() | { + 'dense': { + 'kernel': normal_lib.Normal(loc=0, scale=1), + 'bias': normal_lib.Normal(loc=0, scale=1), + }, + 'amplitude': lognormal_lib.LogNormal(loc=0, scale=1), + } + + def setup(self): + self.dense = nn.Dense(50) + super().setup() + + def __call__(self, inputs): + return self.amplitude * jnp.sum(self.dense(inputs)) + + +class BnnTests(absltest.TestCase): + + def test_mybnn(self): + my_bnn = MyBNN() + d = my_bnn.distributions() + self.assertIn('noise_scale', d) + sample_noise = d['noise_scale'].sample(1, seed=jax.random.PRNGKey(0)) + self.assertEqual((1,), sample_noise.shape) + + params = my_bnn.init(jax.random.PRNGKey(0), jnp.zeros(1)) + lp1 = my_bnn.log_prior(params) + params['params']['amplitude'] += 50 + lp2 = my_bnn.log_prior(params) + self.assertLess(lp2, lp1) + + data = jnp.array([[0], [1], [2], [3], [4], [5]], dtype=jnp.float32) + obs = jnp.array([1, 0, 1, 0, 1, 0], dtype=jnp.float32) + ll = my_bnn.log_likelihood(params, data, obs) + lp = my_bnn.log_prob(params, data, obs) + self.assertLess(jnp.sum(lp), jnp.sum(ll)) + + +if __name__ == '__main__': + absltest.main() diff --git a/tensorflow_probability/python/experimental/autobnn/bnn_tree.py b/tensorflow_probability/python/experimental/autobnn/bnn_tree.py new file mode 100644 index 0000000000..4a02f23252 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/bnn_tree.py @@ -0,0 +1,171 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Routines for making tree-structured BNNs.""" + +from typing import Iterable, List + +from flax import linen as nn +import jax +import jax.numpy as jnp +from tensorflow_probability.python.experimental.autobnn import bnn +from tensorflow_probability.python.experimental.autobnn import kernels +from tensorflow_probability.python.experimental.autobnn import operators +from tensorflow_probability.python.experimental.autobnn import util + +Array = jnp.ndarray + + +LEAVES = [ + kernels.ExponentiatedQuadraticBNN, + kernels.MaternBNN, + kernels.LinearBNN, + kernels.QuadraticBNN, + kernels.PeriodicBNN, + kernels.OneLayerBNN, +] + + +OPERATORS = [ + operators.Multiply, + operators.Add, + operators.WeightedSum, + operators.ChangePoint, + operators.LearnableChangePoint +] + + +NON_PERIODIC_KERNELS = [ + kernels.ExponentiatedQuadraticBNN, + kernels.MaternBNN, + kernels.LinearBNN, + kernels.QuadraticBNN, + kernels.OneLayerBNN, +] + + +def list_of_all( + time_series_xs: Array, + depth: int = 2, + width: int = 50, + periods: Iterable[float] = (), + parent_is_multiply: bool = False, + include_sums: bool = True, + include_changepoints: bool = True, + only_safe_products: bool = False +) -> List[bnn.BNN]: + """Return a list of all BNNs of the given depth.""" + all_bnns = [] + if depth == 0: + all_bnns.extend(k(width=width, going_to_be_multiplied=parent_is_multiply) + for k in NON_PERIODIC_KERNELS) + for p in periods: + all_bnns.append(kernels.PeriodicBNN( + width=width, period=p, going_to_be_multiplied=parent_is_multiply)) + return all_bnns + + multiply_children = list_of_all( + time_series_xs, depth-1, width, periods, True) + if parent_is_multiply: + non_multiply_children = multiply_children + else: + non_multiply_children = list_of_all( + time_series_xs, depth-1, width, periods, False) + + # Abelian operators that aren't Multiply. + if include_sums: + for i, c1 in enumerate(non_multiply_children): + for j in range(i + 1): + c2 = non_multiply_children[j] + # Add is also abelian, but WeightedSum is more general. + all_bnns.append( + operators.WeightedSum( + bnns=(c1.clone(_deep_clone=True), c2.clone(_deep_clone=True)) + ) + ) + + if parent_is_multiply: + # Remaining operators don't expose .penultimate() method. + return all_bnns + + # Multiply + for i, c1 in enumerate(multiply_children): + if only_safe_products: + # The only safe kernels to multiply by are Linear and Quadratic. + if not isinstance(c1, kernels.PolynomialBNN): + continue + for j in range(i+1): + c2 = multiply_children[j] + all_bnns.append(operators.Multiply(bnns=( + c1.clone(_deep_clone=True), c2.clone(_deep_clone=True)))) + + # Non-abelian operators + if include_changepoints: + for c1 in non_multiply_children: + for c2 in non_multiply_children: + # ChangePoint is also non-abelian, but requires that we know + # what the change point is. 
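+        # Because changepoints are not symmetric in their arguments, we loop
+        # over all ordered pairs (c1, c2) here, rather than only the
+        # unordered pairs used for the sum and product operators above.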
+ all_bnns.append(operators.LearnableChangePoint( + bnns=(c1.clone(_deep_clone=True), c2.clone(_deep_clone=True)), + time_series_xs=time_series_xs)) + + return all_bnns + + +def weighted_sum_of_all(time_series_xs: Array, + time_series_ys: Array, + depth: int = 2, width: int = 50, + alpha: float = 1.0) -> bnn.BNN: + """Return a weighted sum of all BNNs of the given depth.""" + periods = util.suggest_periods(time_series_ys) + + all_bnns = list_of_all(time_series_xs, depth, width, periods, False) + + return operators.WeightedSum(bnns=tuple(all_bnns), alpha=alpha) + + +def random_tree(key: jax.Array, depth: int, width: int, period: float, + parent_is_multiply: bool = False) -> nn.Module: + """Return a random complete tree BNN of the given depth. + + Args: + key: Random number key. + depth: Return a BNN of this tree depth. Zero based, so depth=0 returns + a leaf BNN. + width: The number of hidden nodes in the leaf layers. + period: The period of any PeriodicBNN kernels in the tree. + parent_is_multiply: If true, don't create a weight layer after the hidden + nodes of any leaf kernels and only use addition as an internal node. + + Returns: + A BNN of the specified tree depth. + """ + if depth == 0: + c = jax.random.choice(key, len(LEAVES)) + return LEAVES[c]( + width=width, going_to_be_multiplied=parent_is_multiply, + period=period) + + key1, key2, key3 = jax.random.split(key, 3) + if parent_is_multiply: + c = 1 # Can't multiply Multiply or ChangePoints + is_multiply = True + else: + c = jax.random.choice(key1, len(OPERATORS)) + is_multiply = (c == 0) + + sub1 = random_tree(key2, depth - 1, width, period, is_multiply) + sub2 = random_tree(key3, depth - 1, width, period, is_multiply) + + return OPERATORS[c](bnns=(sub1, sub2)) diff --git a/tensorflow_probability/python/experimental/autobnn/bnn_tree_test.py b/tensorflow_probability/python/experimental/autobnn/bnn_tree_test.py new file mode 100644 index 0000000000..10b38b24c2 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/bnn_tree_test.py @@ -0,0 +1,117 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for bnn_tree.py.""" + +from absl.testing import parameterized +from flax import linen as nn +import jax +import jax.numpy as jnp +from tensorflow_probability.python.experimental.autobnn import bnn_tree +from tensorflow_probability.python.experimental.autobnn import kernels +from absl.testing import absltest + + +class TreeTest(parameterized.TestCase): + + def test_list_of_all(self): + l0 = bnn_tree.list_of_all(jnp.linspace(0.0, 100.0, 100), 0) + # With no periods, there should be five kernels. 
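+    # (The five non-periodic leaf kernels: ExponentiatedQuadratic, Matern,
+    # Linear, Quadratic and OneLayer.)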
+ self.assertLen(l0, 5) + for k in l0: + self.assertFalse(k.going_to_be_multiplied) + + l0 = bnn_tree.list_of_all(100, 0, 50, [20.0, 40.0], parent_is_multiply=True) + self.assertLen(l0, 7) + for k in l0: + self.assertTrue(k.going_to_be_multiplied) + + l1 = bnn_tree.list_of_all(jnp.linspace(0.0, 100.0, 100), 1) + # With no periods, there should be + # 15 trees with a Multiply top node, + # 15 trees with a WeightedSum top node, and + # 25 trees with a LearnableChangePoint top node. + self.assertLen(l1, 55) + + # Check that all of the BNNs in the tree can be trained. + for k in l1: + params = k.init(jax.random.PRNGKey(0), jnp.zeros(5)) + lp = k.log_prior(params) + self.assertLess(lp, 0.0) + output = k.apply(params, jnp.ones(5)) + self.assertEqual((1,), output.shape) + + l1 = bnn_tree.list_of_all( + jnp.linspace(0.0, 100.0, 100), + 1, + 50, + [20.0, 40.0], + parent_is_multiply=True, + ) + # With 2 periods and parent_is_multiply, there are only WeightedSum top + # nodes, with 7*8/2 = 28 trees. + self.assertLen(l1, 28) + + l2 = bnn_tree.list_of_all(jnp.linspace(0.0, 100.0, 100), 2) + # With no periods, there should be + # 15*16/2 = 120 trees with a Multiply top node, + # 55*56/2 = 1540 trees with a WeightedSum top node, and + # 55*55 = 3025 trees with a LearnableChangePoint top node. + self.assertLen(l2, 4685) + + @parameterized.parameters(0, 1) # depth=2 segfaults on my desktop :( + def test_weighted_sum_of_all(self, depth): + soa = bnn_tree.weighted_sum_of_all( + jnp.linspace(0.0, 1.0, 100), jnp.ones(100), depth=depth + ) + params = soa.init(jax.random.PRNGKey(0), jnp.zeros(5)) + lp = soa.log_prior(params) + self.assertLess(lp, 0.0) + output = soa.apply(params, jnp.ones(5)) + self.assertEqual((1,), output.shape) + + def test_random_tree(self): + r0 = bnn_tree.random_tree( + jax.random.PRNGKey(0), depth=0, width=50, period=7 + ) + self.assertIsInstance(r0, kernels.OneLayerBNN) + params = r0.init(jax.random.PRNGKey(1), jnp.zeros(5)) + lp = r0.log_prior(params) + self.assertLess(lp, 0.0) + output = r0.apply(params, jnp.ones(5)) + self.assertEqual((1,), output.shape) + + r1 = bnn_tree.random_tree( + jax.random.PRNGKey(0), depth=1, width=50, period=24 + ) + self.assertIsInstance(r1, nn.Module) + params = r1.init(jax.random.PRNGKey(1), jnp.zeros(5)) + lp = r1.log_prior(params) + self.assertLess(lp, 0.0) + output = r1.apply(params, jnp.ones(5)) + self.assertEqual((1,), output.shape) + + r2 = bnn_tree.random_tree( + jax.random.PRNGKey(0), depth=2, width=50, period=52 + ) + self.assertIsInstance(r2, nn.Module) + params = r2.init(jax.random.PRNGKey(1), jnp.zeros(5)) + lp = r2.log_prior(params) + self.assertLess(lp, 0.0) + output = r2.apply(params, jnp.ones(5)) + self.assertEqual((1,), output.shape) + + +if __name__ == '__main__': + absltest.main() diff --git a/tensorflow_probability/python/experimental/autobnn/kernels.py b/tensorflow_probability/python/experimental/autobnn/kernels.py new file mode 100644 index 0000000000..b02ffb7316 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/kernels.py @@ -0,0 +1,324 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""`Leaf` BNNs, most of which correspond to some known GP kernel.""" + +from flax import linen as nn +from flax.linen import initializers +import jax +import jax.numpy as jnp +from tensorflow_probability.python.experimental.autobnn import bnn +from tensorflow_probability.substrates.jax.distributions import lognormal as lognormal_lib +from tensorflow_probability.substrates.jax.distributions import normal as normal_lib +from tensorflow_probability.substrates.jax.distributions import student_t as student_t_lib +from tensorflow_probability.substrates.jax.distributions import uniform as uniform_lib + + +Array = jnp.ndarray + + +SQRT_TWO = 1.41421356237309504880168872420969807856967187537694807317667 + + +class MultipliableBNN(bnn.BNN): + """Abstract base class for BNN's that can be multiplied.""" + width: int = 50 + going_to_be_multiplied: bool = False + + def penultimate(self, inputs): + raise NotImplementedError('Subclasses of MultipliableBNN must define this.') + + +class IdentityBNN(MultipliableBNN): + """A BNN that always predicts 1.""" + + def penultimate(self, inputs): + return jnp.ones(shape=inputs.shape[:-1] + (self.width,)) + + def __call__(self, inputs, deterministic=True): + out_shape = inputs.shape[:-1] + (self.likelihood_model.num_outputs(),) + return jnp.ones(shape=out_shape) + + +class OneLayerBNN(MultipliableBNN): + """A BNN with one hidden layer.""" + + # Period is currently only used by the PeriodicBNN class, but we declare it + # here so it can be passed to a "generic" OneLayerBNN instance. + period: float = 0.0 + + bias_scale: float = 1.0 + + def setup(self): + if not hasattr(self, 'input_warping'): + self.input_warping = lambda x: x + if not hasattr(self, 'activation_function'): + self.activation_function = nn.relu + if not hasattr(self, 'kernel_init'): + self.kernel_init = initializers.lecun_normal() + if not hasattr(self, 'bias_init'): + self.bias_init = initializers.zeros_init() + self.dense1 = nn.Dense(self.width, + kernel_init=self.kernel_init, + bias_init=self.bias_init) + if not self.going_to_be_multiplied: + self.dense2 = nn.Dense( + self.likelihood_model.num_outputs(), + kernel_init=nn.initializers.normal(1. / jnp.sqrt(self.width)), + bias_init=nn.initializers.zeros) + else: + def fake_dense2(x): + out_shape = x.shape[:-1] + (self.likelihood_model.num_outputs(),) + return jnp.ones(out_shape) + self.dense2 = fake_dense2 + super().setup() + + def distributions(self): + # Strictly speaking, these distributions don't exactly correspond to + # the initializations used in setup(). lecun_normal uses a truncated + # normal, for example, and the zeros_init used for the bias certainly + # isn't a sample from a normal. 
+ d = { + 'dense1': { + 'kernel': normal_lib.Normal( + loc=0, scale=1.0 / jnp.sqrt(self.width) + ), + 'bias': normal_lib.Normal(loc=0, scale=self.bias_scale), + } + } + if not self.going_to_be_multiplied: + d['dense2'] = { + 'kernel': normal_lib.Normal(loc=0, scale=1.0 / jnp.sqrt(self.width)), + 'bias': normal_lib.Normal(loc=0, scale=self.bias_scale), + } + return super().distributions() | d + + def penultimate(self, inputs): + y = self.input_warping(inputs) + return self.activation_function(self.dense1(y)) + + def __call__(self, inputs, deterministic=True): + return self.dense2(self.penultimate(inputs)) + + +class ExponentiatedQuadraticBNN(OneLayerBNN): + """A BNN corresponding to the Radial Basis Function kernel.""" + amplitude_scale: float = 1.0 + length_scale_scale: float = 1.0 + + def setup(self): + if not hasattr(self, 'activation_function'): + self.activation_function = lambda x: SQRT_TWO * jnp.sin(x) + if not hasattr(self, 'input_warping'): + self.input_warping = lambda x: x / self.length_scale + self.kernel_init = nn.initializers.normal(1.0) + def uniform_init(seed, shape, dtype): + return nn.initializers.uniform(scale=2.0 * jnp.pi)( + seed, shape, dtype=dtype) - jnp.pi + self.bias_init = uniform_init + super().setup() + + def distributions(self): + d = super().distributions() + return d | { + 'amplitude': lognormal_lib.LogNormal(loc=0, scale=self.amplitude_scale), + 'length_scale': lognormal_lib.LogNormal( + loc=0, scale=self.length_scale_scale + ), + 'dense1': { + 'kernel': normal_lib.Normal(loc=0, scale=1.0), + 'bias': uniform_lib.Uniform(low=-jnp.pi, high=jnp.pi), + }, + } + + def __call__(self, inputs, deterministic=True): + return self.amplitude * self.dense2(self.penultimate(inputs)) + + def shortname(self) -> str: + sn = super().shortname() + return 'RBF' if sn == 'ExponentiatedQuadratic' else sn + + +class MaternBNN(ExponentiatedQuadraticBNN): + """A BNN corresponding to the Matern kernel.""" + degrees_of_freedom: float = 2.5 + + def setup(self): + def kernel_init(seed, shape, unused_dtype): + return student_t_lib.StudentT( + df=2.0 * self.degrees_of_freedom, loc=0.0, scale=1.0 + ).sample(shape, seed=seed) + self.kernel_init = kernel_init + super().setup() + + def summarize(self, params=None, full: bool = False) -> str: + """Return a string summarizing the structure of the BNN.""" + return f'{self.shortname()}({self.degrees_of_freedom})' + + +class PolynomialBNN(OneLayerBNN): + """A BNN where samples are polynomial functions.""" + degree: int = 2 + shift_mean: float = 0.0 + shift_scale: float = 1.0 + amplitude_scale: float = 1.0 + bias_init_amplitude: float = 0.0 + + def distributions(self): + d = super().distributions() + del d['dense1'] + for i in range(self.degree): + # Do not scale these layers by 1/sqrt(width), because we also + # multiply these weights by the learned `amplitude` parameter. 
+ d[f'hiddens_{i}'] = { + 'kernel': normal_lib.Normal(loc=0, scale=1.0), + 'bias': normal_lib.Normal(loc=0, scale=self.bias_scale), + } + return d | { + 'shift': normal_lib.Normal(loc=self.shift_mean, scale=self.shift_scale), + 'amplitude': lognormal_lib.LogNormal(loc=0, scale=self.amplitude_scale), + } + + def setup(self): + kernel_init = nn.initializers.normal(1.0) + def bias_init(seed, shape, dtype=jnp.float32): + return self.bias_init_amplitude * jax.random.normal( + seed, shape, dtype=dtype) + self.hiddens = [ + nn.Dense(self.width, kernel_init=kernel_init, bias_init=bias_init) + for _ in range(self.degree)] + super().setup() + + def penultimate(self, inputs): + x = inputs - self.shift + ys = jnp.stack([h(x) for h in self.hiddens], axis=-1) + return self.amplitude * jnp.prod(ys, axis=-1) + + def summarize(self, params=None, full: bool = False) -> str: + """Return a string summarizing the structure of the BNN.""" + return f'{self.shortname()}(degree={self.degree})' + + +class LinearBNN(PolynomialBNN): + """A BNN where samples are lines.""" + degree: int = 1 + + def summarize(self, params=None, full: bool = False) -> str: + return self.shortname() + + +class QuadraticBNN(PolynomialBNN): + """A BNN where samples are parabolas.""" + + degree: int = 2 + + def summarize(self, params=None, full: bool = False) -> str: + return self.shortname() + + +def make_periodic_input_warping(period, periodic_index, include_original): + """Return an input warping function that adds Fourier features. + + Args: + period: The added features will repeat this many time steps. + periodic_index: Look for the time feature in input[..., periodic_index]. + include_original: If true, don't replace the time feature with the + new Fourier features. + + Returns: + A function that takes an input tensor of shape [..., n] and returns a + tensor of shape [..., n+2] if include_original is True and of shape + [..., n+1] if include_original is False. + """ + def input_warping(x): + time = x[..., periodic_index] + y = 2.0 * jnp.pi * time / period + features = [jnp.cos(y), jnp.sin(y)] + if include_original: + features.append(time) + if jnp.ndim(x) == 1: + features = jnp.array(features).T + else: + features = jnp.vstack(features).T + return jnp.concatenate( + [ + x[..., :periodic_index], + features, + x[..., periodic_index + 1:], + ], + -1, + ) + + return input_warping + + +class PeriodicBNN(ExponentiatedQuadraticBNN): + """A BNN corresponding to a periodic kernel.""" + periodic_index: int = 0 + + def setup(self): + # TODO(colcarroll): Figure out how to assert that self.period is positive. 
+ + self.input_warping = make_periodic_input_warping( + self.period, self.periodic_index, include_original=False + ) + super().setup() + + def summarize(self, params=None, full: bool = False) -> str: + """Return a string summarizing the structure of the BNN.""" + return f'{self.shortname()}(period={self.period:.2f})' + + +class MultiLayerBNN(OneLayerBNN): + """Multi-layer BNN that also has access to periodic features.""" + num_layers: int = 3 + periodic_index: int = 0 + + def setup(self): + if not hasattr(self, 'kernel_init'): + self.kernel_init = initializers.lecun_normal() + if not hasattr(self, 'bias_init'): + self.bias_init = initializers.zeros_init() + self.input_warping = make_periodic_input_warping( + self.period, self.periodic_index, include_original=True + ) + self.dense = [ + nn.Dense( + self.width, kernel_init=self.kernel_init, bias_init=self.bias_init + ) + for _ in range(self.num_layers) + ] + super().setup() + + def distributions(self): + d = super().distributions() + del d['dense1'] + for i in range(self.num_layers): + d[f'dense_{i}'] = { + 'kernel': normal_lib.Normal(loc=0, scale=1.0 / jnp.sqrt(self.width)), + 'bias': normal_lib.Normal(loc=0, scale=self.bias_scale), + } + return d + + def penultimate(self, inputs): + y = self.input_warping(inputs) + for i in range(self.num_layers): + y = self.activation_function(self.dense[i](y)) + return y + + def summarize(self, params=None, full: bool = False) -> str: + """Return a string summarizing the structure of the BNN.""" + return ( + f'{self.shortname()}(num_layers={self.num_layers},period={self.period})' + ) diff --git a/tensorflow_probability/python/experimental/autobnn/kernels_test.py b/tensorflow_probability/python/experimental/autobnn/kernels_test.py new file mode 100644 index 0000000000..67e574d517 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/kernels_test.py @@ -0,0 +1,242 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Tests for kernels.py.""" + +from absl.testing import parameterized +import jax +import jax.numpy as jnp +import numpy as np +from tensorflow_probability.python.experimental.autobnn import kernels +from tensorflow_probability.python.experimental.autobnn import util +from tensorflow_probability.substrates.jax.distributions import lognormal as lognormal_lib + +from absl.testing import absltest + + +KERNELS = [ + kernels.IdentityBNN, + kernels.OneLayerBNN, + kernels.ExponentiatedQuadraticBNN, + kernels.MaternBNN, + kernels.PeriodicBNN, + kernels.PolynomialBNN, + kernels.LinearBNN, + kernels.MultiLayerBNN, +] + + +class ReproduceExperimentTest(absltest.TestCase): + + def get_bnn_and_params(self): + x_train, y_train = util.load_fake_dataset() + linear_bnn = kernels.OneLayerBNN(width=50) + seed = jax.random.PRNGKey(0) + init_params = linear_bnn.init(seed, x_train) + constant_params = jax.tree_map( + lambda x: jnp.full(x.shape, 0.1), init_params) + constant_params['params']['noise_scale'] = jnp.array([0.005 ** 0.5]) + return linear_bnn, constant_params, x_train, y_train + + # This now uses a Logistic noise model, not Normal as in Pearce + @absltest.expectedFailure + def test_log_prior_matches(self): + # Pearce has a set `noise_scale` of 0.005 ** 0.5 that we must account for. + linear_bnn, constant_params, _, _ = self.get_bnn_and_params() + diff = lognormal_lib.LogNormal( + linear_bnn.noise_min, linear_bnn.log_noise_scale + ).log_prob(0.005**0.5) + self.assertAlmostEqual( + linear_bnn.log_prior(constant_params) - diff, + 31.59, # Hardcoded from reference implementation. + places=2) + + def test_log_likelihood_matches(self): + linear_bnn, constant_params, x_train, y_train = self.get_bnn_and_params() + self.assertAlmostEqual( + linear_bnn.log_likelihood(constant_params, x_train, y_train), + -7808.4434, + places=2) + + # This now uses a Logistic noise model, not Normal as in Pearce + @absltest.expectedFailure + def test_log_prob_matches(self): + # Pearce has a set `noise_scale` of 0.005 ** 0.5 that we must account for. + linear_bnn, constant_params, x_train, y_train = self.get_bnn_and_params() + diff = lognormal_lib.LogNormal( + linear_bnn.noise_min, linear_bnn.log_noise_scale + ).log_prob(0.005**0.5) + self.assertAlmostEqual( + linear_bnn.log_prob(constant_params, x_train, y_train) - diff, + -14505.76, # Hardcoded from reference implementation. 
+ places=2) + + +class KernelsTest(parameterized.TestCase): + + @parameterized.product( + shape=[(5,), (5, 1), (5, 5)], + kernel=KERNELS, + ) + def test_default_kernels(self, shape, kernel): + if kernel in [kernels.PeriodicBNN, kernels.MultiLayerBNN]: + bnn = kernel(period=0.1, periodic_index=shape[-1]//2) + else: + bnn = kernel() + if isinstance(bnn, kernels.PolynomialBNN): + self.assertIn('shift', bnn.distributions()) + elif isinstance(bnn, kernels.MultiLayerBNN): + self.assertIn('dense_1', bnn.distributions()) + elif isinstance(bnn, kernels.IdentityBNN): + pass + else: + self.assertIn('dense1', bnn.distributions()) + if not isinstance(bnn, kernels.IdentityBNN): + self.assertIn('dense2', bnn.distributions()) + params = bnn.init(jax.random.PRNGKey(0), jnp.zeros(shape)) + lprior = bnn.log_prior(params) + params2 = params + if 'params' in params2: + params2 = params2['params'] + params2['noise_scale'] = params2['noise_scale'] + 100.0 + lprior2 = bnn.log_prior(params2) + self.assertLess(lprior2, lprior) + output = bnn.apply(params, jnp.ones(shape)) + self.assertEqual(shape[:-1] + (1,), output.shape) + + @parameterized.parameters(KERNELS) + def test_likelihood(self, kernel): + if kernel in [kernels.PeriodicBNN, kernels.MultiLayerBNN]: + bnn = kernel(period=0.1) + else: + bnn = kernel() + params = bnn.init(jax.random.PRNGKey(1), jnp.zeros(1)) + data = jnp.array([[0], [1], [2], [3], [4], [5]], dtype=jnp.float32) + obs = jnp.array([1, 0, 1, 0, 1, 0], dtype=jnp.float32) + ll = bnn.log_likelihood(params, data, obs) + lp = bnn.log_prob(params, data, obs) + # We are mostly just testing that ll and lp are both float-ish numbers + # than can be compared. In general, there is no reason to expect that + # lp < ll because there is no reason to expect in general that the + # log_prior will be negative. + if kernel == kernels.MultiLayerBNN: + self.assertLess(ll, lp) + else: + self.assertLess(lp, ll) + + @parameterized.parameters( + (kernels.OneLayerBNN(width=10), 'OneLayer'), + (kernels.ExponentiatedQuadraticBNN(width=5), 'RBF'), + (kernels.MaternBNN(width=5), 'Matern(2.5)'), + (kernels.PeriodicBNN(period=10, width=10), 'Periodic(period=10.00)'), + (kernels.PolynomialBNN(degree=3, width=2), 'Polynomial(degree=3)'), + (kernels.LinearBNN(width=5), 'Linear'), + (kernels.QuadraticBNN(width=5), 'Quadratic'), + ( + kernels.MultiLayerBNN(width=10, num_layers=3, period=20), + 'MultiLayer(num_layers=3,period=20)', + ), + ) + def test_summarize(self, bnn, expected): + self.assertEqual(expected, bnn.summarize()) + + @parameterized.parameters(KERNELS) + def test_penultimate(self, kernel): + if kernel in [kernels.PeriodicBNN, kernels.MultiLayerBNN]: + bnn = kernel(period=0.1, going_to_be_multiplied=True) + else: + bnn = kernel(going_to_be_multiplied=True) + self.assertNotIn('dense2', bnn.distributions()) + params = bnn.init(jax.random.PRNGKey(0), jnp.zeros(5)) + lprior = bnn.log_prior(params) + if kernel != kernels.MultiLayerBNN: + self.assertLess(lprior, 0.0) + h = bnn.apply(params, jnp.ones(5), method=bnn.penultimate) + self.assertEqual((50,), h.shape) + + def test_polynomial_is_almost_a_polynomial(self): + poly_bnn = kernels.PolynomialBNN(degree=3) + init_params = poly_bnn.init(jax.random.PRNGKey(0), jnp.ones((10, 1))) + + # compute power series + func = lambda x: poly_bnn.apply(init_params, x)[0] + params = [func(0.)] + for _ in range(4): + func = jax.grad(func) + params.append(func(0.)) + + # Last 4th degree coefficient should be around 0. + self.assertAlmostEqual(params[-1], 0.) 
+ + # Check that the random initialization is approximately a polynomial by + # evaluating far away from the expansion. + x = 17.0 + self.assertAlmostEqual( + poly_bnn.apply(init_params, x)[0], + params[0] + x * params[1] + x**2 * params[2] / 2 + x**3 * params[3] / 6, + places=3) + + def test_make_periodic_input_warping_onedim(self): + iw = kernels.make_periodic_input_warping(4, 0, True) + np.testing.assert_allclose( + jnp.array([0, 1, 1, 2, 3, 4, 5]), + iw(jnp.array([1, 2, 3, 4, 5])), + atol=1e-6 + ) + iw = kernels.make_periodic_input_warping(4, 0, False) + np.testing.assert_allclose( + jnp.array([0, 1, 2, 3, 4, 5]), + iw(jnp.array([1, 2, 3, 4, 5])), + atol=1e-6 + ) + + def test_make_periodic_input_warping_onedim_features(self): + iw = kernels.make_periodic_input_warping(4, 0, True) + np.testing.assert_allclose( + jnp.array([[1, 0, 0], [0, 1, 1], [-1, 0, 2], [0, -1, 3], [1, 0, 4]]), + iw(jnp.array([[0], [1], [2], [3], [4]])), + atol=1e-6 + ) + iw = kernels.make_periodic_input_warping(4, 0, False) + np.testing.assert_allclose( + jnp.array([[1, 0], [0, 1], [-1, 0], [0, -1], [1, 0]]), + iw(jnp.array([[0], [1], [2], [3], [4]])), + atol=1e-6 + ) + + def test_make_periodic_input_warping_twodim(self): + iw = kernels.make_periodic_input_warping(2, 0, True) + np.testing.assert_allclose( + jnp.array([[1, 0, 0, 0], [-1, 0, 1, 1], [1, 0, 2, 4], [-1, 0, 3, 9], + [1, 0, 4, 16]]), + iw(jnp.array([[0, 0], [1, 1], [2, 4], [3, 9], [4, 16]])), + atol=1e-6 + ) + iw = kernels.make_periodic_input_warping(4, 1, True) + np.testing.assert_allclose( + jnp.array([[0, 1, 0, 0], [1, 0, 1, 1], [2, 1, 0, 4], [3, 0, 1, 9], + [4, 1, 0, 16]]), + iw(jnp.array([[0, 0], [1, 1], [2, 4], [3, 9], [4, 16]])), + atol=1e-6 + ) + iw = kernels.make_periodic_input_warping(2, 0, False) + np.testing.assert_allclose( + jnp.array([[1, 0, 0], [-1, 0, 1], [1, 0, 4], [-1, 0, 9], [1, 0, 16]]), + iw(jnp.array([[0, 0], [1, 1], [2, 4], [3, 9], [4, 16]])), + atol=1e-6 + ) + + +if __name__ == '__main__': + absltest.main() diff --git a/tensorflow_probability/python/experimental/autobnn/likelihoods.py b/tensorflow_probability/python/experimental/autobnn/likelihoods.py new file mode 100644 index 0000000000..384d9c6735 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/likelihoods.py @@ -0,0 +1,173 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Likelihood models for Bayesian Neural Networks.""" + +import dataclasses +from typing import Any +import jax +from tensorflow_probability.substrates.jax.bijectors import softplus as softplus_lib +from tensorflow_probability.substrates.jax.distributions import distribution as distribution_lib +from tensorflow_probability.substrates.jax.distributions import inflated as inflated_lib +from tensorflow_probability.substrates.jax.distributions import logistic as logistic_lib +from tensorflow_probability.substrates.jax.distributions import lognormal as lognormal_lib +from tensorflow_probability.substrates.jax.distributions import negative_binomial as negative_binomial_lib +from tensorflow_probability.substrates.jax.distributions import normal as normal_lib +from tensorflow_probability.substrates.jax.distributions import transformed_distribution as transformed_distribution_lib + + +@dataclasses.dataclass +class LikelihoodModel: + """A class that knows how to compute the likelihood of some data.""" + + def dist(self, params, nn_out) -> distribution_lib.Distribution: + """Return the distribution underlying the likelihood.""" + raise NotImplementedError() + + def sample(self, params, nn_out, seed, sample_shape=None) -> jax.Array: + """Sample from the likelihood.""" + return self.dist(params, nn_out).sample( + seed=seed, sample_shape=sample_shape + ) + + def num_outputs(self): + """The number of outputs from the neural network the model needs.""" + return 1 + + def distributions(self): + """Like BayesianModule::distributions but for the model's parameters.""" + return {} + + def log_likelihood( + self, params, nn_out: jax.Array, observations: jax.Array + ) -> jax.Array: + return self.dist(params, nn_out).log_prob(observations) + + +@dataclasses.dataclass +class DummyLikelihoodModel(LikelihoodModel): + """A likelihood model that only knows how many outputs it has.""" + num_outs: int + + def num_outputs(self): + return self.num_outs + + +class NormalLikelihoodFixedNoise(LikelihoodModel): + """Abstract base class for observations = N(nn_out, noise_scale).""" + + def dist(self, params, nn_out): + return normal_lib.Normal(loc=nn_out, scale=params['noise_scale']) + + +@dataclasses.dataclass +class NormalLikelihoodLogisticNoise(NormalLikelihoodFixedNoise): + noise_min: float = 0.0 + log_noise_scale: float = 1.0 + + def distributions(self): + noise_scale = transformed_distribution_lib.TransformedDistribution( + logistic_lib.Logistic(0.0, self.log_noise_scale), + softplus_lib.Softplus(low=self.noise_min), + ) + return {'noise_scale': noise_scale} + + +@dataclasses.dataclass +class BoundedNormalLikelihoodLogisticNoise(NormalLikelihoodLogisticNoise): + lower_bound: float = 0.0 + + def dist(self, params, nn_out): + return softplus_lib.Softplus(low=self.lower_bound)( + normal_lib.Normal(loc=nn_out, scale=params['noise_scale']) + ) + + +@dataclasses.dataclass +class NormalLikelihoodLogNormalNoise(NormalLikelihoodFixedNoise): + log_noise_mean: float = -2.0 + log_noise_scale: float = 1.0 + + def distributions(self): + return { + 'noise_scale': lognormal_lib.LogNormal( + loc=self.log_noise_mean, scale=self.log_noise_scale + ) + } + + +class NormalLikelihoodVaryingNoise(LikelihoodModel): + + def num_outputs(self): + return 2 + + def dist(self, params, nn_out): + # TODO(colcarroll): Add a prior to constrain the scale (`nn_out[..., [1]]`) + # separately before it goes into the likelihood. 
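+    # nn_out carries two channels per time step: channel 0 is used directly
+    # as the mean, and channel 1 is passed through a softplus to produce a
+    # positive noise scale.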
+ return normal_lib.Normal( + loc=nn_out[..., [0]], scale=jax.nn.softplus(nn_out[..., [1]]) + ) + + +class NegativeBinomial(LikelihoodModel): + """observations = NB(total_count = nn_out[0], logits = nn_out[1]).""" + + def num_outputs(self): + return 2 + + def dist(self, params, nn_out): + return negative_binomial_lib.NegativeBinomial( + total_count=nn_out[..., [0]], + logits=nn_out[..., [1]], + require_integer_total_count=False, + ) + + +class ZeroInflatedNegativeBinomial(LikelihoodModel): + """observations = NB(total_count = nn_out[0], logits = nn_out[1]).""" + + def num_outputs(self): + return 3 + + def dist(self, params, nn_out): + return inflated_lib.ZeroInflatedNegativeBinomial( + total_count=nn_out[..., [0]], + logits=nn_out[..., [1]], + inflated_loc_logits=nn_out[..., [2]], + require_integer_total_count=False, + ) + + +NAME_TO_LIKELIHOOD_MODEL = { + 'normal_likelihood_logistic_noise': NormalLikelihoodLogisticNoise, + 'bounded_normal_likelihood_logistic_noise': ( + BoundedNormalLikelihoodLogisticNoise + ), + 'normal_likelihood_lognormal_noise': NormalLikelihoodLogNormalNoise, + 'normal_likelihood_varying_noise': NormalLikelihoodVaryingNoise, + 'negative_binomial': NegativeBinomial, + 'zero_inflated_negative_binomial': ZeroInflatedNegativeBinomial, +} + + +def get_likelihood_model( + likelihood_model: str, likelihood_parameters: dict[str, Any] +) -> Any: + # Actually returns a Likelihood model, but pytype thinks it returns a + # Union[NegativeBinomial, ...]. + m = NAME_TO_LIKELIHOOD_MODEL[likelihood_model]() + for k, v in likelihood_parameters.items(): + if hasattr(m, k): + setattr(m, k, v) + return m diff --git a/tensorflow_probability/python/experimental/autobnn/likelihoods_test.py b/tensorflow_probability/python/experimental/autobnn/likelihoods_test.py new file mode 100644 index 0000000000..4776f951a3 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/likelihoods_test.py @@ -0,0 +1,63 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""Tests for likelihoods.py."""
+
+from absl.testing import parameterized
+import jax.numpy as jnp
+from tensorflow_probability.python.experimental.autobnn import likelihoods
+from absl.testing import absltest
+
+
+class LikelihoodTests(parameterized.TestCase):
+
+  @parameterized.parameters(
+      likelihoods.NormalLikelihoodLogisticNoise(),
+      likelihoods.NormalLikelihoodLogNormalNoise(),
+      likelihoods.NormalLikelihoodVaryingNoise(),
+      likelihoods.NegativeBinomial(),
+      likelihoods.ZeroInflatedNegativeBinomial(),
+  )
+  def test_likelihoods(self, likelihood_model):
+    lp = likelihood_model.log_likelihood(
+        params={'noise_scale': 0.4},
+        nn_out=jnp.ones(shape=(10, likelihood_model.num_outputs())),
+        observations=jnp.zeros(shape=(10, 1)),
+    )
+    self.assertEqual(lp.shape, (10, 1))
+
+  @parameterized.parameters(list(likelihoods.NAME_TO_LIKELIHOOD_MODEL.keys()))
+  def test_get_likelihood_model(self, likelihood_model):
+    m = likelihoods.get_likelihood_model(likelihood_model, {})
+    lp = m.log_likelihood(
+        params={'noise_scale': 0.4},
+        nn_out=jnp.ones(shape=(10, m.num_outputs())),
+        observations=jnp.zeros(shape=(10, 1)),
+    )
+    self.assertEqual(lp.shape, (10, 1))
+
+    m2 = likelihoods.get_likelihood_model(
+        likelihood_model,
+        {'noise_min': 0.1, 'log_noise_scale': 0.5, 'log_noise_mean': -1.0},
+    )
+    lp2 = m2.log_likelihood(
+        params={'noise_scale': 0.4},
+        nn_out=jnp.ones(shape=(10, m2.num_outputs())),
+        observations=jnp.zeros(shape=(10, 1)),
+    )
+    self.assertEqual(lp2.shape, (10, 1))
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/tensorflow_probability/python/experimental/autobnn/models.py b/tensorflow_probability/python/experimental/autobnn/models.py
new file mode 100644
index 0000000000..63c7165988
--- /dev/null
+++ b/tensorflow_probability/python/experimental/autobnn/models.py
@@ -0,0 +1,309 @@
+# Copyright 2023 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""BNF models.
+
+The "combo" model is a simple sum of linear and periodic components. The sum of
+products is the smallest example of a sum of two products over two leaves each,
+where each leaf is a continuous relaxation (using WeightedSum) of periodic and
+linear components.
+""" +import functools +from typing import Sequence +import jax.numpy as jnp +from tensorflow_probability.python.experimental.autobnn import bnn +from tensorflow_probability.python.experimental.autobnn import bnn_tree +from tensorflow_probability.python.experimental.autobnn import kernels +from tensorflow_probability.python.experimental.autobnn import likelihoods +from tensorflow_probability.python.experimental.autobnn import operators + + +Array = jnp.ndarray + + +def make_sum_of_operators_of_relaxed_leaves( + time_series_xs: Array, + width: int = 5, + periods: Sequence[float] = (0.1,), + use_mul: bool = True, + num_outputs: int = 1, +) -> bnn.BNN: + """Returns BNN model consisting of a sum of products or changeponts of leaves. + + Each leaf is a continuous relaxation over base kernels. + + Args: + time_series_xs: The x-values of the training data. + width: Width of the leaf BNNs. + periods: Periods for the PeriodicBNN kernel. + use_mul: If true, use Multiply as the depth 1 operator. If false, use + a LearnableChangepoint instead. + num_outputs: Number of outputs on the BNN. + """ + del num_outputs + def _make_continuous_relaxation( + width: int, + periods: Sequence[float], + include_eq_and_poly: bool) -> bnn.BNN: + leaves = [kernels.PeriodicBNN( + width=width, period=p, going_to_be_multiplied=use_mul) for p in periods] + leaves.append(kernels.LinearBNN( + width=width, going_to_be_multiplied=use_mul)) + if include_eq_and_poly: + leaves.extend([ + kernels.ExponentiatedQuadraticBNN( + width=width, going_to_be_multiplied=use_mul), + kernels.PolynomialBNN(width=width, going_to_be_multiplied=use_mul), + kernels.IdentityBNN(width=width, going_to_be_multiplied=use_mul), + ]) + return operators.WeightedSum(bnns=tuple(leaves), num_outputs=1, + going_to_be_multiplied=use_mul) + + leaf1 = _make_continuous_relaxation(width, periods, include_eq_and_poly=False) + leaf2 = _make_continuous_relaxation(width, periods, include_eq_and_poly=False) + + if use_mul: + op = operators.Multiply + else: + op = functools.partial(operators.LearnableChangePoint, + time_series_xs=time_series_xs) + + bnn1 = op(bnns=(leaf1, leaf2)) + + leaf3 = _make_continuous_relaxation(width, periods, include_eq_and_poly=True) + leaf4 = _make_continuous_relaxation(width, periods, include_eq_and_poly=True) + bnn2 = op(bnns=(leaf3, leaf4)) + + net = operators.Add(bnns=(bnn1, bnn2)) + return net + + +def make_sum_of_products( + time_series_xs: Array, + width: int = 5, + periods: Sequence[float] = (0.1,), + num_outputs: int = 1, +) -> bnn.BNN: + return make_sum_of_operators_of_relaxed_leaves( + time_series_xs, width, periods, use_mul=True, num_outputs=num_outputs) + + +def make_sum_of_changepoints( + time_series_xs: Array, + width: int = 5, + periods: Sequence[float] = (0.1,), + num_outputs: int = 1, +) -> bnn.BNN: + return make_sum_of_operators_of_relaxed_leaves( + time_series_xs, width, periods, use_mul=False, num_outputs=num_outputs) + + +def make_linear_plus_periodic( + time_series_xs: Array, + width: int = 5, + periods: Sequence[float] = (0.1,), + num_outputs: int = 1, +) -> bnn.BNN: + """Returns Combo model, consisting of linear and periodic leafs. + + Args: + time_series_xs: The x-values of the training data. + width: Width of the leaf BNNs. + periods: Periods for the PeriodicBNN kernel. + num_outputs: Number of outputs on the BNN. 
+ """ + del num_outputs + del time_series_xs + leaves = [kernels.PeriodicBNN(width=width, period=p) for p in periods] + leaves.append(kernels.LinearBNN(width=width)) + return operators.Add(bnns=tuple(leaves)) + + +def make_sum_of_stumps( + time_series_xs: Array, + width: int = 5, + periods: Sequence[float] = (0.1,), + num_outputs: int = 1, +) -> bnn.BNN: + """Return a sum of depth 0 trees.""" + stumps = bnn_tree.list_of_all(time_series_xs, 0, width, periods=periods) + + return operators.WeightedSum(bnns=tuple(stumps), num_outputs=num_outputs) + + +def make_sum_of_stumps_and_products( + time_series_xs: Array, + width: int = 5, + periods: Sequence[float] = (0.1,), + num_outputs: int = 1, +) -> bnn.BNN: + """Return a sum of depth 0 and depth 1 product-only trees.""" + stumps = bnn_tree.list_of_all(time_series_xs, 0, width, periods=periods) + products = bnn_tree.list_of_all( + time_series_xs, + 1, + width, + periods=periods, + include_sums=False, + include_changepoints=False, + ) + + return operators.WeightedSum( + bnns=tuple(stumps + products), num_outputs=num_outputs) + + +def make_sum_of_shallow( + time_series_xs: Array, + width: int = 5, + periods: Sequence[float] = (0.1,), + num_outputs: int = 1, +) -> bnn.BNN: + """Return a sum of depth 0 and 1 trees.""" + stumps = bnn_tree.list_of_all(time_series_xs, 0, width, periods=periods) + depth1 = bnn_tree.list_of_all( + time_series_xs, 1, width, periods=periods, include_sums=False + ) + + return operators.WeightedSum( + bnns=tuple(stumps + depth1), num_outputs=num_outputs) + + +def make_sum_of_safe_shallow( + time_series_xs: Array, + width: int = 5, + periods: Sequence[float] = (0.1,), + num_outputs: int = 1, +) -> bnn.BNN: + """Return a sum of depth 0 and 1 trees, but not unsafe products.""" + stumps = bnn_tree.list_of_all(time_series_xs, 0, width, periods=periods) + depth1 = bnn_tree.list_of_all( + time_series_xs, + 1, + width, + periods=periods, + include_sums=False, + only_safe_products=True, + ) + + return operators.WeightedSum( + bnns=tuple(stumps + depth1), num_outputs=num_outputs) + + +def make_changepoint_of_safe_products( + time_series_xs: Array, + width: int = 5, + periods: Sequence[float] = (0.1,), + num_outputs: int = 1, +) -> bnn.BNN: + """Return a changepoint over two Multiply(Linear, WeightedSum(kernels))'s.""" + # By varying the weights inside the WeightedSum (and by relying on the + # identity Changepoint(A, A) = A), this model can express + # * all base kernels, + # * all "safe" multiplies over two base kernels (i.e., one of the terms + # has a very low effective parameter count to avoid overfitting noise), and + # * all single changepoints over two of the above. 
+ + all_kernels = [ + kernels.PeriodicBNN(width=width, period=p, going_to_be_multiplied=True) + for p in periods + ] + all_kernels.extend( + [ + k(width=width, going_to_be_multiplied=True) + for k in [ + kernels.ExponentiatedQuadraticBNN, + kernels.MaternBNN, + kernels.LinearBNN, + kernels.QuadraticBNN, + ] + ] + ) + + safe_product = operators.Multiply( + bnns=( + operators.WeightedSum( + num_outputs=num_outputs, + bnns=( + kernels.IdentityBNN(width=width, going_to_be_multiplied=True), + kernels.LinearBNN(width=width, going_to_be_multiplied=True), + kernels.QuadraticBNN( + width=width, going_to_be_multiplied=True), + ), + going_to_be_multiplied=True + ), + operators.WeightedSum(bnns=tuple(all_kernels), + going_to_be_multiplied=True, + num_outputs=num_outputs), + ), + ) + + return operators.LearnableChangePoint( + time_series_xs=time_series_xs, + bnns=(safe_product, safe_product.clone(_deep_clone=True)), + ) + + +def make_mlp(num_layers: int): + """Return a make function for the MultiLayerBNN of the given depth.""" + + def make_multilayer( + time_series_xs: Array, + width: int = 5, + periods: Sequence[float] = (0.1,), + num_outputs: int = 1, + ): + del num_outputs + del time_series_xs + assert len(periods) == 1 + return kernels.MultiLayerBNN( + num_layers=num_layers, + width=width, + period=periods[0], + ) + + return make_multilayer + + +MODEL_NAME_TO_MAKE_FUNCTION = { + 'sum_of_products': make_sum_of_products, + 'sum_of_changepoints': make_sum_of_changepoints, + 'linear_plus_periodic': make_linear_plus_periodic, + 'sum_of_stumps': make_sum_of_stumps, + 'sum_of_stumps_and_products': make_sum_of_stumps_and_products, + 'sum_of_shallow': make_sum_of_shallow, + 'sum_of_safe_shallow': make_sum_of_safe_shallow, + 'changepoint_of_safe_products': make_changepoint_of_safe_products, + 'mlp_depth2': make_mlp(2), + 'mlp_depth3': make_mlp(3), + 'mlp_depth4': make_mlp(4), + 'mlp_depth5': make_mlp(5), +} + + +def make_model( + model_name: str, + likelihood_model: likelihoods.LikelihoodModel, + time_series_xs: Array, + width: int = 5, + periods: Sequence[float] = (0.1,), +) -> bnn.BNN: + """Create a BNN model by name.""" + m = MODEL_NAME_TO_MAKE_FUNCTION[model_name]( + time_series_xs=time_series_xs, + width=width, + periods=periods, + num_outputs=likelihood_model.num_outputs(), + ) + m.set_likelihood_model(likelihood_model) + return m diff --git a/tensorflow_probability/python/experimental/autobnn/models_test.py b/tensorflow_probability/python/experimental/autobnn/models_test.py new file mode 100644 index 0000000000..ed9d44ca6d --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/models_test.py @@ -0,0 +1,67 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Tests for models.py.""" + +from absl.testing import parameterized +import jax +import jax.numpy as jnp +from tensorflow_probability.python.experimental.autobnn import likelihoods +from tensorflow_probability.python.experimental.autobnn import models +from absl.testing import absltest + + +MODELS = list(models.MODEL_NAME_TO_MAKE_FUNCTION.keys()) + + +class ModelsTest(parameterized.TestCase): + + @parameterized.parameters(MODELS) + def test_make_model(self, model_name): + m = models.make_model( + model_name, + likelihoods.NormalLikelihoodLogisticNoise(), + time_series_xs=jnp.linspace(0.0, 1.0, 50), + width=5, + periods=[0.2], + ) + params = m.init(jax.random.PRNGKey(0), jnp.zeros(5)) + lp = m.log_prior(params) + self.assertTrue((lp < 0.0) or (lp > 0.0)) + + @parameterized.product( + model_name=MODELS, + # It takes too long to test all of the likelihoods, so just test a + # couple to make sure each model correctly handles num_outputs > 1. + likelihood_name=[ + 'normal_likelihood_varying_noise', + 'zero_inflated_negative_binomial', + ], + ) + def test_make_model_and_likelihood(self, model_name, likelihood_name): + ll = likelihoods.get_likelihood_model(likelihood_name, {}) + m = models.make_model( + model_name, + ll, + time_series_xs=jnp.linspace(0.0, 1.0, 50), + width=5, + periods=[0.2], + ) + params = m.init(jax.random.PRNGKey(0), jnp.zeros(5)) + lp = m.log_prior(params) + self.assertTrue((lp < 0.0) or (lp > 0.0)) + + +if __name__ == '__main__': + absltest.main() diff --git a/tensorflow_probability/python/experimental/autobnn/operators.py b/tensorflow_probability/python/experimental/autobnn/operators.py new file mode 100644 index 0000000000..6772ab56d7 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/operators.py @@ -0,0 +1,292 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Flax.linen modules for combining BNNs.""" + +from typing import Optional +from flax import linen as nn +import jax.numpy as jnp +from tensorflow_probability.python.experimental.autobnn import bnn +from tensorflow_probability.python.experimental.autobnn import likelihoods +from tensorflow_probability.substrates.jax.bijectors import chain as chain_lib +from tensorflow_probability.substrates.jax.bijectors import scale as scale_lib +from tensorflow_probability.substrates.jax.bijectors import shift as shift_lib +from tensorflow_probability.substrates.jax.distributions import beta as beta_lib +from tensorflow_probability.substrates.jax.distributions import dirichlet as dirichlet_lib +from tensorflow_probability.substrates.jax.distributions import half_normal as half_normal_lib +from tensorflow_probability.substrates.jax.distributions import normal as normal_lib +from tensorflow_probability.substrates.jax.distributions import transformed_distribution as transformed_distribution_lib + + +Array = jnp.ndarray + + +class BnnOperator(bnn.BNN): + """Base class for BNNs that are made from other BNNs.""" + bnns: tuple[bnn.BNN, ...] = tuple() + + def setup(self): + assert self.bnns, 'Forgot to pass `bnns` keyword argument?' + super().setup() + + def set_likelihood_model(self, likelihood_model: likelihoods.LikelihoodModel): + super().set_likelihood_model(likelihood_model) + # We need to set the likelihood models on the component + # bnns so that they will know how many outputs they are + # supposed to have. BUT: we also don't want to accidentally + # create any additional variables, distributions or parameters + # in them. So we set them all to having a dummy likelihood + # model that only knows how many outputs it has. + dummy_ll_model = likelihoods.DummyLikelihoodModel( + num_outs=likelihood_model.num_outputs() + ) + for b in self.bnns: + b.set_likelihood_model(dummy_ll_model) + + def log_prior(self, params): + if 'params' in params: + params = params['params'] + # params for bnns[i] are stored in params['bnns_{i}']. + lp = bnn.log_prior_of_parameters(params, self.distributions()) + for i, b in enumerate(self.bnns): + params_field = f'bnns_{i}' + if params_field in params: + lp += b.log_prior(params[params_field]) + return lp + + def get_all_distributions(self): + distributions = self.distributions() + for idx, sub_bnn in enumerate(self.bnns): + d = sub_bnn.get_all_distributions() + if d: + distributions[f'bnns_{idx}'] = d + return distributions + + def summary_join_string(self, params) -> str: + """String to use when joining the component summaries.""" + raise NotImplementedError() + + def summarize(self, params=None, full: bool = False) -> str: + """Return a string summarizing the structure of the BNN.""" + params = params or {} + if 'params' in params: + params = params['params'] + + names = [ + b.summarize(params.get(f'bnns_{i}'), full) + for i, b in enumerate(self.bnns) + ] + + return f'({self.summary_join_string(params).join(names)})' + + +class MultipliableBnnOperator(BnnOperator): + """Abstract base class for a BnnOperator that can be multiplied.""" + # Ideally, this would just inherit from both BnnOperator and + # kernels.MultipliableBNN, but pytype gets really confused by that. 
+ going_to_be_multiplied: bool = False + + def setup(self): + if self.going_to_be_multiplied: + for b in self.bnns: + assert b.going_to_be_multiplied + else: + for b in self.bnns: + assert not getattr(b, 'going_to_be_multiplied', False) + super().setup() + + def penultimate(self, inputs): + raise NotImplementedError( + 'Subclasses of MultipliableBnnOperator must define this.') + + +class Add(MultipliableBnnOperator): + """Add two or more BNNs.""" + + def penultimate(self, inputs): + penultimates = [b.penultimate(inputs) for b in self.bnns] + return jnp.sum(jnp.stack(penultimates, axis=-1), axis=-1) + + def __call__(self, inputs, deterministic=True): + return jnp.sum( + jnp.stack([b(inputs) for b in self.bnns], axis=-1), + axis=-1) + + def summary_join_string(self, params) -> str: + return '#' + + +class WeightedSum(MultipliableBnnOperator): + """Add two or more BNNs, with weights taken from a Dirichlet prior.""" + + # `alpha=1` is a uniform prior on mixing weights, higher values will favor + # weights like `1/n`, and lower weights will favor sparsity. + alpha: float = 1.0 + num_outputs: int = 1 + + def distributions(self): + bnn_concentrations = [1.0 if isinstance(b, BnnOperator) else 1.5 + for b in self.bnns] + if self.going_to_be_multiplied: + concentration = self.alpha * jnp.array(bnn_concentrations) + else: + concentration = self.alpha * jnp.array( + [bnn_concentrations for _ in range(self.num_outputs)]) + return super().distributions() | { + 'bnn_weights': dirichlet_lib.Dirichlet(concentration=concentration) + } + + def penultimate(self, inputs): + penultimates = [ + b.penultimate(inputs) * self.bnn_weights[0, i] + for i, b in enumerate(self.bnns) + ] + return jnp.sum(jnp.stack(penultimates, axis=-1), axis=-1) + + def __call__(self, inputs, deterministic=True): + return jnp.sum( + jnp.stack( + [ + b(inputs) * self.bnn_weights[0, :, i] + for i, b in enumerate(self.bnns) + ], + axis=-1, + ), + axis=-1, + ) + + def summarize(self, params=None, full: bool = False) -> str: + """Return a string summarizing the structure of the BNN.""" + params = params or {} + if 'params' in params: + params = params['params'] + + names = [ + b.summarize(params.get(f'bnns_{i}'), full) + for i, b in enumerate(self.bnns) + ] + + def pretty_print(w): + try: + s = f'{jnp.array_str(jnp.array(w), precision=3)}' + except Exception: # pylint: disable=broad-exception-caught + try: + s = f'{w:.3f}' + except Exception: # pylint: disable=broad-exception-caught + s = f'{w}' + return s.replace('\n', ' ') + + weights = params.get('bnn_weights') + if weights is not None: + weights = jnp.array(weights)[0].T.squeeze() + names = [ + f'{pretty_print(w)} {n}' + for w, n in zip(weights, names) + if full or jnp.max(w) > 0.04 + ] + + return f'({"+".join(names)})' + + +class Multiply(BnnOperator): + """Multiply two or more BNNs.""" + + def setup(self): + self.dense = nn.Dense(self.likelihood_model.num_outputs()) + for b in self.bnns: + assert hasattr(b, 'penultimate') + assert b.going_to_be_multiplied, 'Forgot to set going_to_be_multiplied?' 
+ super().setup() + + def distributions(self): + return super().distributions() | { + 'dense': { + 'kernel': normal_lib.Normal(loc=0, scale=1.0), + 'bias': normal_lib.Normal(loc=0, scale=1.0), + } + } + + def __call__(self, inputs, deterministic=True): + penultimates = [b.penultimate(inputs) for b in self.bnns] + return self.dense(jnp.prod(jnp.stack(penultimates, axis=-1), axis=-1)) + + def summary_join_string(self, params) -> str: + return '*' + + +class ChangePoint(BnnOperator): + """Switch from one BNN to another based on a time point.""" + change_point: float = 0.0 + slope: float = 1.0 + change_index: int = 0 + + def setup(self): + assert len(self.bnns) == 2 + super().setup() + + def __call__(self, inputs, deterministic=True): + time = inputs[..., self.change_index, jnp.newaxis] + y = (time - self.change_point) / self.slope + return nn.sigmoid(y) * self.bnns[1](inputs) + nn.sigmoid( + -y) * self.bnns[0](inputs) + + def summary_join_string(self, params) -> str: + return f'<[{self.change_point}]' + + +class LearnableChangePoint(BnnOperator): + """Switch from one BNN to another based on a time point.""" + time_series_xs: Optional[Array] = None + change_index: int = 0 + + def distributions(self): + assert self.time_series_xs is not None + lo = jnp.min(self.time_series_xs) + hi = jnp.max(self.time_series_xs) + # We want change_slope_scale to be the average value of + # time_series_xs[i+1] - time_series_xs[i] + change_slope_scale = (hi - lo) / self.time_series_xs.size + + # this distribution puts a lower density at the endpoints, and a reasonably + # flat distribution near the middle of the timeseries. + bij = chain_lib.Chain([shift_lib.Shift(lo), scale_lib.Scale(hi - lo)]) + dist = transformed_distribution_lib.TransformedDistribution( + distribution=beta_lib.Beta(1.5, 1.5), bijector=bij + ) + return super().distributions() | { + 'change_point': dist, + 'change_slope': half_normal_lib.HalfNormal(scale=change_slope_scale), + } + + def setup(self): + assert len(self.bnns) == 2 + assert len(self.time_series_xs) >= 2 + super().setup() + + def __call__(self, inputs, deterministic=True): + time = inputs[..., self.change_index, jnp.newaxis] + y = (time - self.change_point) / self.change_slope + return nn.sigmoid(y) * self.bnns[1](inputs) + nn.sigmoid(-y) * self.bnns[0]( + inputs + ) + + def summary_join_string(self, params) -> str: + params = params or {} + if 'params' in params: + params = params['params'] + change_point = params.get('change_point') + cp_str = '' + if change_point is not None: + cp_str = f'[{jnp.array_str(change_point, precision=2)}]' + return f'<{cp_str}' diff --git a/tensorflow_probability/python/experimental/autobnn/operators_test.py b/tensorflow_probability/python/experimental/autobnn/operators_test.py new file mode 100644 index 0000000000..63f978f003 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/operators_test.py @@ -0,0 +1,218 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Tests for operators.py.""" + +from absl.testing import parameterized +import jax +import jax.numpy as jnp +import numpy as np +from tensorflow_probability.python.experimental.autobnn import kernels +from tensorflow_probability.python.experimental.autobnn import operators +from tensorflow_probability.python.experimental.autobnn import util +from tensorflow_probability.substrates.jax.distributions import distribution as distribution_lib +from absl.testing import absltest + + +KERNELS = [ + operators.Add( + bnns=(kernels.OneLayerBNN(width=50), kernels.OneLayerBNN(width=50)) + ), + operators.Add( + bnns=(kernels.OneLayerBNN(width=50), kernels.OneLayerBNN(width=100)) + ), + operators.Add( + bnns=( + kernels.PeriodicBNN(width=50, period=0.1), + kernels.OneLayerBNN(width=50), + ) + ), + operators.WeightedSum( + bnns=(kernels.OneLayerBNN(width=50), kernels.OneLayerBNN(width=50)) + ), + operators.WeightedSum( + bnns=( + kernels.PeriodicBNN(width=50, period=0.1), + kernels.OneLayerBNN(width=50), + ), + alpha=2.0, + ), + operators.WeightedSum( + bnns=( + kernels.ExponentiatedQuadraticBNN(width=50), + kernels.ExponentiatedQuadraticBNN(width=50), + ) + ), + operators.Multiply( + bnns=( + kernels.OneLayerBNN(width=50, going_to_be_multiplied=True), + kernels.OneLayerBNN(width=50, going_to_be_multiplied=True), + ), + ), + operators.Multiply( + bnns=( + kernels.OneLayerBNN(width=50, going_to_be_multiplied=True), + kernels.OneLayerBNN(width=50, going_to_be_multiplied=True), + kernels.OneLayerBNN(width=50, going_to_be_multiplied=True), + ) + ), + operators.Multiply( + bnns=( + operators.Add( + bnns=( + kernels.OneLayerBNN(width=50, going_to_be_multiplied=True), + kernels.OneLayerBNN(width=50, going_to_be_multiplied=True), + ), + going_to_be_multiplied=True + ), + operators.Add( + bnns=( + kernels.OneLayerBNN(width=50, going_to_be_multiplied=True), + kernels.OneLayerBNN(width=50, going_to_be_multiplied=True), + ), + going_to_be_multiplied=True + ), + ) + ), + operators.ChangePoint( + bnns=(kernels.OneLayerBNN(width=50), kernels.OneLayerBNN(width=50)), + change_point=5.0, + slope=1.0, + ), + operators.LearnableChangePoint( + bnns=(kernels.OneLayerBNN(width=50), kernels.OneLayerBNN(width=50)), + time_series_xs=np.linspace(0., 5., 100), + ), +] + + +NAMES = [ + "(OneLayer#OneLayer)", + "(OneLayer#OneLayer)", + "(Periodic(period=0.10)#OneLayer)", + "(OneLayer+OneLayer)", + "(Periodic(period=0.10)+OneLayer)", + "(RBF+RBF)", + "(OneLayer*OneLayer)", + "(OneLayer*OneLayer*OneLayer)", + "((OneLayer#OneLayer)*(OneLayer#OneLayer))", + "(OneLayer<[5.0]OneLayer)", + "(OneLayer Tuple[Callable[..., Any], Callable[..., Any], Callable[..., Any]]: + """Returns unconstraining bijectors for all variables in the BNN.""" + jb = jax.tree_map( + lambda x: x.experimental_default_event_space_bijector(), + net.get_all_distributions(), + is_leaf=lambda x: isinstance(x, distribution_lib.Distribution), + ) + + def transform(params): + return {'params': jax.tree_map(lambda p, b: b(p), params['params'], jb)} + + def inverse_transform(params): + return { + 'params': jax.tree_map(lambda p, b: b.inverse(p), params['params'], jb) + } + + def inverse_log_det_jacobian(params): + return jax.tree_util.tree_reduce( + lambda a, b: a + b, + jax.tree_map( + lambda p, b: jnp.sum(b.inverse_log_det_jacobian(p)), + params['params'], + jb, + ), + initializer=0.0, + ) + + return transform, inverse_transform, inverse_log_det_jacobian + + +def suggest_periods(ys) -> 
List[float]: + """Suggest a few periods for the time series.""" + f, pxx = scipy.signal.periodogram(ys) + + top5_powers, top5_indices = jax.lax.top_k(pxx, 5) + top5_power = jnp.sum(top5_powers) + best_indices = [i for i in top5_indices if pxx[i] > 0.05 * top5_power] + # Sort in descending order so the best periods are first. + best_indices.sort(reverse=True, key=lambda i: pxx[i]) + return [1.0 / f[i] for i in best_indices if 1.0 / f[i] < 0.6 * len(ys)] + + +def load_fake_dataset(): + """Return some fake data for testing purposes.""" + x_train = jnp.arange(0.0, 120.0) / 120.0 + y_train = x_train + jnp.sin(x_train * 10.0) + x_train * x_train + x_train = x_train[..., jnp.newaxis] + return x_train, y_train[..., jnp.newaxis] diff --git a/tensorflow_probability/python/experimental/autobnn/util_test.py b/tensorflow_probability/python/experimental/autobnn/util_test.py new file mode 100644 index 0000000000..dbaa3866c9 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/util_test.py @@ -0,0 +1,72 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for util.py.""" + +import jax +import jax.numpy as jnp +import numpy as np +from tensorflow_probability.python.experimental.autobnn import kernels +from tensorflow_probability.python.experimental.autobnn import util +from tensorflow_probability.python.internal import test_util + + +class UtilTest(test_util.TestCase): + + def test_suggest_periods(self): + self.assertListEqual([], util.suggest_periods([1 for _ in range(20)])) + self.assertListEqual( + [2.0], util.suggest_periods([i % 2 for i in range(20)]) + ) + np.testing.assert_allclose( + [20.0], + util.suggest_periods( + [jnp.sin(2.0 * jnp.pi * i / 20.0) for i in range(100)] + ), + ) + # suggest_periods is robust against small linear trends ... + np.testing.assert_allclose( + [20.0], + util.suggest_periods( + [0.01 * i + jnp.sin(2.0 * jnp.pi * i / 20.0) for i in range(100)] + ), + ) + # but sort of falls apart currently for large linear trends. 
+ np.testing.assert_allclose( + [50.0, 100.0 / 3.0], + util.suggest_periods( + [i + jnp.sin(2.0 * jnp.pi * i / 20.0) for i in range(100)] + ), + ) + + def test_transform(self): + seed = jax.random.PRNGKey(20231018) + bnn = kernels.LinearBNN(width=5) + bnn.likelihood_model.noise_min = 0.2 + transform, _, _ = util.make_transforms(bnn) + p = bnn.init(seed, jnp.ones((1, 10), dtype=jnp.float32)) + + # Softplus(low=0.2) bijector + self.assertEqual(0.2 + jax.nn.softplus(p['params']['noise_scale']), + transform(p)['params']['noise_scale']) + self.assertEqual(jnp.exp(p['params']['amplitude']), + transform(p)['params']['amplitude']) + + # Identity bijector + self.assertAllEqual(p['params']['dense2']['kernel'], + transform(p)['params']['dense2']['kernel']) + + +if __name__ == '__main__': + test_util.main() diff --git a/tensorflow_probability/python/experimental/mcmc/BUILD b/tensorflow_probability/python/experimental/mcmc/BUILD index aa2843bfb9..7fb140497b 100644 --- a/tensorflow_probability/python/experimental/mcmc/BUILD +++ b/tensorflow_probability/python/experimental/mcmc/BUILD @@ -18,6 +18,8 @@ # //tensorflow_probability/python/internal/auto_batching # internally. +# Placeholder: py_library +# Placeholder: py_test load( "//tensorflow_probability/python:build_defs.bzl", "multi_substrate_py_library", @@ -546,9 +548,6 @@ multi_substrate_py_library( "//tensorflow_probability/python/internal:prefer_static", "//tensorflow_probability/python/internal:tensor_util", "//tensorflow_probability/python/internal:tensorshape_util", - "//tensorflow_probability/python/distributions:batch_reshape", - "//tensorflow_probability/python/distributions:batch_broadcast", - "//tensorflow_probability/python/distributions:independent" ], ) @@ -575,8 +574,6 @@ multi_substrate_py_test( "//tensorflow_probability/python/distributions:sample", "//tensorflow_probability/python/distributions:transformed_distribution", "//tensorflow_probability/python/distributions:uniform", - "//tensorflow_probability/python/distributions:categorical", - "//tensorflow_probability/python/distributions:hidden_markov_model", "//tensorflow_probability/python/internal:test_util", "//tensorflow_probability/python/math:gradient", # "//third_party/tensorflow/compiler/jit:xla_cpu_jit", # DisableOnExport @@ -655,7 +652,6 @@ multi_substrate_py_test( "//tensorflow_probability/python/distributions:mvn_diag", "//tensorflow_probability/python/distributions:normal", "//tensorflow_probability/python/distributions:sample", - "//tensorflow_probability/python/experimental/mcmc:sequential_monte_carlo_kernel", "//tensorflow_probability/python/distributions:uniform", "//tensorflow_probability/python/distributions/internal:statistical_testing", "//tensorflow_probability/python/internal:test_util", diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter.py b/tensorflow_probability/python/experimental/mcmc/particle_filter.py index b920b0db85..1bcbc870f4 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter.py @@ -25,11 +25,6 @@ from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import samplers from tensorflow_probability.python.mcmc.internal import util as mcmc_util -from tensorflow_probability.python.distributions import batch_reshape -from tensorflow_probability.python.distributions import batch_broadcast -from tensorflow_probability.python.distributions import normal -from 
tensorflow_probability.python.distributions import uniform - __all__ = [ 'infer_trajectories', @@ -49,39 +44,6 @@ def _default_trace_fn(state, kernel_results): kernel_results.incremental_log_marginal_likelihood) -def _default_kernel(parameters): - mean, variance = tf.nn.moments(parameters, axes=[0]) - proposal_distribution = normal.Normal(loc=tf.fill(parameters.shape, mean), scale=tf.sqrt(variance)) - return proposal_distribution - - -def _default_extra_fn(step, - state, - seed - ): - return state.extra - - -def where_fn(accept, a, b, num_outer_particles, num_inner_particles): - is_scalar = tf.rank(a) == tf.constant(0) - is_nan = tf.math.is_nan(tf.cast(a, tf.float32)) - is_all_nan = tf.reduce_all(is_nan) - if is_scalar and is_all_nan: - return a - elif a.shape == 2 and b.shape == 2: - # extra - return a - elif a.shape == num_outer_particles and b.shape == num_outer_particles: - return mcmc_util.choose(accept, a, b) - elif a.shape == [num_outer_particles, num_inner_particles] and \ - b.shape == [num_outer_particles, num_inner_particles]: - return mcmc_util.choose(accept, a, b) - elif a.shape == () and b.shape == (): - return a - else: - raise ValueError("Unexpected tensor shapes") - - particle_filter_arg_str = """\ Each latent state is a `Tensor` or nested structure of `Tensor`s, as defined by the `initial_state_prior`. @@ -473,344 +435,6 @@ def seeded_one_step(seed_state_results, _): return traced_results -def smc_squared( - inner_observations, - initial_parameter_prior, - num_outer_particles, - inner_initial_state_prior, - inner_transition_fn, - inner_observation_fn, - num_inner_particles, - outer_trace_fn=_default_trace_fn, - outer_rejuvenation_criterion_fn=None, - outer_resample_criterion_fn=None, - outer_resample_fn=weighted_resampling.resample_systematic, - inner_resample_criterion_fn=smc_kernel.ess_below_threshold, - inner_resample_fn=weighted_resampling.resample_systematic, - extra_fn=_default_extra_fn, - parameter_proposal_kernel=_default_kernel, - inner_proposal_fn=None, - inner_initial_state_proposal=None, - outer_trace_criterion_fn=_always_trace, - parallel_iterations=1, - num_transitions_per_observation=1, - static_trace_allocation_size=None, - initial_parameter_proposal=None, - unbiased_gradients=True, - seed=None, -): - init_seed, loop_seed, step_seed = samplers.split_seed(seed, n=3, salt='smc_squared') - - num_observation_steps = ps.size0(tf.nest.flatten(inner_observations)[0]) - - # TODO: The following two lines compensates for having the first empty step in smc2 - num_timesteps = (1 + num_transitions_per_observation * - (num_observation_steps - 1)) + 1 - last_obs_expanded = tf.expand_dims(inner_observations[-1], axis=0) - inner_observations = tf.concat([inner_observations, last_obs_expanded], axis=0) - - if outer_rejuvenation_criterion_fn is None: - outer_rejuvenation_criterion_fn = lambda *_: tf.constant(False) - - if outer_resample_criterion_fn is None: - outer_resample_criterion_fn = lambda *_: tf.constant(False) - - # If trace criterion is `None`, we'll return only the final results. 
- never_trace = lambda *_: False - if outer_trace_criterion_fn is None: - static_trace_allocation_size = 0 - outer_trace_criterion_fn = never_trace - - if initial_parameter_proposal is None: - initial_state = initial_parameter_prior.sample(num_outer_particles, - seed=seed) - initial_log_weights = ps.zeros_like( - initial_parameter_prior.log_prob(initial_state)) - else: - initial_state = initial_parameter_proposal.sample(num_outer_particles, - seed=seed) - initial_log_weights = ( - initial_parameter_prior.log_prob(initial_state) - - initial_parameter_proposal.log_prob(initial_state) - ) - - # Normalize the initial weights. If we used a proposal, the weights are - # normalized in expectation, but actually normalizing them reduces variance. - initial_log_weights = tf.nn.log_softmax(initial_log_weights, axis=0) - - inner_weighted_particles = _particle_filter_initial_weighted_particles( - observations=inner_observations, - observation_fn=inner_observation_fn(initial_state), - initial_state_prior=inner_initial_state_prior(0, initial_state), - initial_state_proposal=(inner_initial_state_proposal(0, initial_state) - if inner_initial_state_proposal is not None else None), - num_particles=num_inner_particles, - particles_dim=1, - seed=seed - ) - - init_state = smc_kernel.WeightedParticles(*inner_weighted_particles) - - batch_zeros = tf.zeros(ps.shape(initial_state)) - - initial_filter_results = smc_kernel.SequentialMonteCarloResults( - steps=0, - parent_indices=smc_kernel._dummy_indices_like(init_state.log_weights), - incremental_log_marginal_likelihood=batch_zeros, - accumulated_log_marginal_likelihood=batch_zeros, - seed=samplers.zeros_seed()) - - initial_state = smc_kernel.WeightedParticles( - particles=(initial_state, - inner_weighted_particles, - initial_filter_results.parent_indices, - initial_filter_results.incremental_log_marginal_likelihood, - initial_filter_results.accumulated_log_marginal_likelihood), - log_weights=initial_log_weights, - extra=(tf.constant(0), - initial_filter_results.seed) - ) - - outer_propose_and_update_log_weights_fn = ( - _outer_particle_filter_propose_and_update_log_weights_fn( - outer_rejuvenation_criterion_fn=outer_rejuvenation_criterion_fn, - inner_observations=inner_observations, - inner_transition_fn=inner_transition_fn, - inner_proposal_fn=inner_proposal_fn, - inner_observation_fn=inner_observation_fn, - inner_resample_fn=inner_resample_fn, - inner_resample_criterion_fn=inner_resample_criterion_fn, - parameter_proposal_kernel=parameter_proposal_kernel, - initial_parameter_prior=initial_parameter_prior, - num_transitions_per_observation=num_transitions_per_observation, - unbiased_gradients=unbiased_gradients, - inner_initial_state_prior=inner_initial_state_prior, - inner_initial_state_proposal=inner_initial_state_proposal, - num_inner_particles=num_inner_particles, - num_outer_particles=num_outer_particles, - extra_fn=extra_fn - ) - ) - - traced_results = sequential_monte_carlo( - initial_weighted_particles=initial_state, - propose_and_update_log_weights_fn=outer_propose_and_update_log_weights_fn, - resample_fn=outer_resample_fn, - resample_criterion_fn=outer_resample_criterion_fn, - trace_criterion_fn=outer_trace_criterion_fn, - static_trace_allocation_size=static_trace_allocation_size, - parallel_iterations=parallel_iterations, - unbiased_gradients=unbiased_gradients, - num_steps=num_timesteps, - particles_dim=0, - trace_fn=outer_trace_fn, - seed=loop_seed - ) - - return traced_results - - -def _outer_particle_filter_propose_and_update_log_weights_fn( - 
inner_observations, - inner_transition_fn, - inner_proposal_fn, - inner_observation_fn, - initial_parameter_prior, - inner_initial_state_prior, - inner_initial_state_proposal, - num_transitions_per_observation, - inner_resample_fn, - inner_resample_criterion_fn, - outer_rejuvenation_criterion_fn, - unbiased_gradients, - parameter_proposal_kernel, - num_inner_particles, - num_outer_particles, - extra_fn -): - """Build a function specifying a particle filter update step.""" - def _outer_propose_and_update_log_weights_fn(step, state, seed=None): - outside_parameters = state.particles[0] - inner_weighted_particles, log_weights = state.particles[1], state.log_weights - - filter_results = smc_kernel.SequentialMonteCarloResults( - steps=step, - parent_indices=state.particles[2], - incremental_log_marginal_likelihood=state.particles[3], - accumulated_log_marginal_likelihood=state.particles[4], - seed=state.extra[1]) - - inner_propose_and_update_log_weights_fn = ( - _particle_filter_propose_and_update_log_weights_fn( - observations=inner_observations, - transition_fn=inner_transition_fn(outside_parameters), - proposal_fn=(inner_proposal_fn(outside_parameters) - if inner_proposal_fn is not None else None), - observation_fn=inner_observation_fn(outside_parameters), - particles_dim=1, - num_transitions_per_observation=num_transitions_per_observation, - extra_fn=extra_fn - ) - ) - - kernel = smc_kernel.SequentialMonteCarlo( - propose_and_update_log_weights_fn=inner_propose_and_update_log_weights_fn, - resample_fn=inner_resample_fn, - resample_criterion_fn=inner_resample_criterion_fn, - particles_dim=1, - unbiased_gradients=unbiased_gradients - ) - - inner_weighted_particles, filter_results = kernel.one_step(inner_weighted_particles, - filter_results, - seed=seed) - - updated_log_weights = log_weights + filter_results.incremental_log_marginal_likelihood - - do_rejuvenation = outer_rejuvenation_criterion_fn(step, state) - - def rejuvenate_particles(outside_parameters, updated_log_weights, inner_weighted_particles, filter_results): - proposed_parameters = parameter_proposal_kernel(outside_parameters).sample(seed=seed) - - rej_params_log_weights = ps.zeros_like( - initial_parameter_prior.log_prob(proposed_parameters) - ) - rej_params_log_weights = tf.nn.log_softmax(rej_params_log_weights, axis=0) - - rej_inner_weighted_particles = _particle_filter_initial_weighted_particles( - observations=inner_observations, - observation_fn=inner_observation_fn(proposed_parameters), - initial_state_prior=inner_initial_state_prior(0, proposed_parameters), - initial_state_proposal=(inner_initial_state_proposal(0, proposed_parameters) - if inner_initial_state_proposal is not None else None), - num_particles=num_inner_particles, - particles_dim=1, - seed=seed) - - batch_zeros = tf.zeros(ps.shape(log_weights)) - - rej_filter_results = smc_kernel.SequentialMonteCarloResults( - steps=tf.constant(0, dtype=tf.int32), - parent_indices=smc_kernel._dummy_indices_like( - rej_inner_weighted_particles.log_weights - ), - incremental_log_marginal_likelihood=batch_zeros, - accumulated_log_marginal_likelihood=batch_zeros, - seed=samplers.zeros_seed()) - - rej_inner_particles_weights = rej_inner_weighted_particles.log_weights - - rej_inner_propose_and_update_log_weights_fn = ( - _particle_filter_propose_and_update_log_weights_fn( - observations=inner_observations, - transition_fn=inner_transition_fn(proposed_parameters), - proposal_fn=(inner_proposal_fn(proposed_parameters) - if inner_proposal_fn is not None else None), - 
observation_fn=inner_observation_fn(proposed_parameters), - extra_fn=extra_fn, - particles_dim=1, - num_transitions_per_observation=num_transitions_per_observation) - ) - - rej_kernel = smc_kernel.SequentialMonteCarlo( - propose_and_update_log_weights_fn=rej_inner_propose_and_update_log_weights_fn, - resample_fn=inner_resample_fn, - resample_criterion_fn=inner_resample_criterion_fn, - particles_dim=1, - unbiased_gradients=unbiased_gradients) - - def condition(i, - rej_inner_weighted_particles, - rej_filter_results, - rej_parameters_weights, - rej_params_log_weights): - return tf.less_equal(i, step) - - def body(i, - rej_inner_weighted_particles, - rej_filter_results, - rej_parameters_weights, - rej_params_log_weights): - - rej_inner_weighted_particles, rej_filter_results = rej_kernel.one_step( - rej_inner_weighted_particles, rej_filter_results, seed=seed - ) - - rej_parameters_weights += rej_inner_weighted_particles.log_weights - - rej_params_log_weights = rej_params_log_weights + rej_filter_results.incremental_log_marginal_likelihood - return i + 1, rej_inner_weighted_particles, rej_filter_results, rej_parameters_weights, rej_params_log_weights - - i, rej_inner_weighted_particles, rej_filter_results, rej_inner_particles_weights, rej_params_log_weights = tf.while_loop( - condition, - body, - loop_vars=[0, - rej_inner_weighted_particles, - rej_filter_results, - rej_inner_particles_weights, - rej_params_log_weights - ] - ) - - log_a = rej_filter_results.accumulated_log_marginal_likelihood - \ - filter_results.accumulated_log_marginal_likelihood + \ - parameter_proposal_kernel(proposed_parameters).log_prob(outside_parameters) - \ - parameter_proposal_kernel(outside_parameters).log_prob(proposed_parameters) - - acceptance_probs = tf.minimum(1., tf.exp(log_a)) - - random_numbers = uniform.Uniform(0., 1.).sample(num_outer_particles, seed=seed) - - # Determine if the proposed particle should be accepted or reject - accept = random_numbers > acceptance_probs - - # Update the chosen particles and filter restults based on the acceptance step - outside_parameters = tf.where(accept, outside_parameters, proposed_parameters) - updated_log_weights = tf.where(accept, updated_log_weights, rej_params_log_weights) - - inner_weighted_particles_particles = mcmc_util.choose( - accept, - inner_weighted_particles.particles, - rej_inner_weighted_particles.particles - ) - inner_weighted_particles_log_weights = mcmc_util.choose( - accept, - inner_weighted_particles.log_weights, - rej_inner_weighted_particles.log_weights - ) - - inner_weighted_particles = smc_kernel.WeightedParticles( - particles=inner_weighted_particles_particles, - log_weights=inner_weighted_particles_log_weights, - extra=inner_weighted_particles.extra - ) - - filter_results = tf.nest.map_structure( - lambda a, b: where_fn(accept, a, b, num_outer_particles, num_inner_particles), - filter_results, - rej_filter_results - ) - - return outside_parameters, updated_log_weights, inner_weighted_particles, filter_results - - outside_parameters, updated_log_weights, inner_weighted_particles, filter_results = tf.cond( - do_rejuvenation, - lambda: (rejuvenate_particles(outside_parameters, updated_log_weights, inner_weighted_particles, filter_results)), - lambda: (outside_parameters, updated_log_weights, inner_weighted_particles, filter_results) - ) - - return smc_kernel.WeightedParticles( - particles=(outside_parameters, - inner_weighted_particles, - filter_results.parent_indices, - filter_results.incremental_log_marginal_likelihood, - 
filter_results.accumulated_log_marginal_likelihood), - log_weights=updated_log_weights, - extra=(step, - filter_results.seed)) - return _outer_propose_and_update_log_weights_fn - - @docstring_util.expand_docstring( particle_filter_arg_str=particle_filter_arg_str.format(scibor_ref_idx=1)) def particle_filter(observations, @@ -818,7 +442,6 @@ def particle_filter(observations, transition_fn, observation_fn, num_particles, - extra_fn=_default_extra_fn, initial_state_proposal=None, proposal_fn=None, resample_fn=weighted_resampling.resample_systematic, @@ -903,9 +526,7 @@ def particle_filter(observations, particles_dim=particles_dim, proposal_fn=proposal_fn, observation_fn=observation_fn, - num_transitions_per_observation=num_transitions_per_observation, - extra_fn=extra_fn - )) + num_transitions_per_observation=num_transitions_per_observation)) return sequential_monte_carlo( initial_weighted_particles=initial_weighted_particles, @@ -928,7 +549,6 @@ def _particle_filter_initial_weighted_particles(observations, initial_state_proposal, num_particles, particles_dim=0, - extra=np.nan, seed=None): """Initialize a set of weighted particles including the first observation.""" # Propose an initial state. @@ -954,14 +574,6 @@ def _particle_filter_initial_weighted_particles(observations, axis=particles_dim) # Return particles weighted by the initial observation. - if extra is np.nan: - if len(ps.shape(initial_log_weights)) == 1: - # initial extra for particle filter - extra = tf.constant(0) - else: - # initial extra for inner particles of smc_squared - extra = tf.constant(0, shape=ps.shape(initial_log_weights)) - return smc_kernel.WeightedParticles( particles=initial_state, log_weights=initial_log_weights + _compute_observation_log_weights( @@ -969,8 +581,7 @@ def _particle_filter_initial_weighted_particles(observations, particles=initial_state, observations=observations, observation_fn=observation_fn, - particles_dim=particles_dim), - extra=extra) + particles_dim=particles_dim)) def _particle_filter_propose_and_update_log_weights_fn( @@ -978,7 +589,6 @@ def _particle_filter_propose_and_update_log_weights_fn( transition_fn, proposal_fn, observation_fn, - extra_fn, num_transitions_per_observation=1, particles_dim=0): """Build a function specifying a particle filter update step.""" @@ -1009,18 +619,13 @@ def propose_and_update_log_weights_fn(step, state, seed=None): else: proposed_particles = transition_dist.sample(seed=seed) - updated_extra = extra_fn(step, - state, - seed) - with tf.control_dependencies(assertions): return smc_kernel.WeightedParticles( particles=proposed_particles, log_weights=log_weights + _compute_observation_log_weights( step + 1, proposed_particles, observations, observation_fn, num_transitions_per_observation=num_transitions_per_observation, - particles_dim=particles_dim), - extra=updated_extra) + particles_dim=particles_dim)) return propose_and_update_log_weights_fn @@ -1065,8 +670,6 @@ def _compute_observation_log_weights(step, observation = tf.nest.map_structure( lambda x, step=step: tf.gather(x, observation_idx), observations) - if particles_dim == 1: - observation = tf.expand_dims(observation, axis=0) observation = tf.nest.map_structure( lambda x: tf.expand_dims(x, axis=particles_dim), observation) diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py index e190c76bda..6508eb6231 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py +++ 
b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py @@ -21,7 +21,6 @@ from tensorflow_probability.python.bijectors import shift from tensorflow_probability.python.distributions import bernoulli from tensorflow_probability.python.distributions import deterministic -from tensorflow_probability.python.distributions import independent from tensorflow_probability.python.distributions import joint_distribution_auto_batched as jdab from tensorflow_probability.python.distributions import joint_distribution_named as jdn from tensorflow_probability.python.distributions import linear_gaussian_ssm as lgssm @@ -178,6 +177,128 @@ def observation_fn(_, state): self.assertAllEqual(incremental_log_marginal_likelihoods.shape, [num_timesteps] + batch_shape) + def test_batch_of_filters_particles_dim_1(self): + + batch_shape = [3, 2] + num_particles = 1000 + num_timesteps = 40 + + # Batch of priors on object 1D positions and velocities. + initial_state_prior = jdn.JointDistributionNamed({ + 'position': normal.Normal(loc=0., scale=tf.ones(batch_shape)), + 'velocity': normal.Normal(loc=0., scale=tf.ones(batch_shape) * 0.1) + }) + + def transition_fn(_, previous_state): + return jdn.JointDistributionNamed({ + 'position': + normal.Normal( + loc=previous_state['position'] + previous_state['velocity'], + scale=0.1), + 'velocity': + normal.Normal(loc=previous_state['velocity'], scale=0.01) + }) + + def observation_fn(_, state): + return normal.Normal(loc=state['position'], scale=0.1) + + # Batch of synthetic observations, . + true_initial_positions = np.random.randn(*batch_shape).astype(self.dtype) + true_velocities = 0.1 * np.random.randn( + *batch_shape).astype(self.dtype) + observed_positions = ( + true_velocities * + np.arange(num_timesteps).astype( + self.dtype)[..., tf.newaxis, tf.newaxis] + + true_initial_positions) + + (particles, log_weights, parent_indices, + incremental_log_marginal_likelihoods) = self.evaluate( + particle_filter.particle_filter( + observations=observed_positions, + initial_state_prior=initial_state_prior, + transition_fn=transition_fn, + observation_fn=observation_fn, + num_particles=num_particles, + seed=test_util.test_seed(), + particles_dim=1)) + + self.assertAllEqual(particles['position'].shape, + [num_timesteps, + batch_shape[0], + num_particles, + batch_shape[1]]) + self.assertAllEqual(particles['velocity'].shape, + [num_timesteps, + batch_shape[0], + num_particles, + batch_shape[1]]) + self.assertAllEqual(parent_indices.shape, + [num_timesteps, + batch_shape[0], + num_particles, + batch_shape[1]]) + self.assertAllEqual(incremental_log_marginal_likelihoods.shape, + [num_timesteps] + batch_shape) + + self.assertAllClose( + self.evaluate( + tf.reduce_sum(tf.exp(log_weights) * + particles['position'], axis=2)), + observed_positions, + atol=0.3) + + velocity_means = tf.reduce_sum(tf.exp(log_weights) * + particles['velocity'], axis=2) + + self.assertAllClose( + self.evaluate(tf.reduce_mean(velocity_means, axis=0)), + true_velocities, atol=0.05) + + # Uncertainty in velocity should decrease over time. + velocity_stddev = self.evaluate( + tf.math.reduce_std(particles['velocity'], axis=2)) + self.assertAllLess((velocity_stddev[-1] - velocity_stddev[0]), 0.) 
+ + trajectories = self.evaluate( + particle_filter.reconstruct_trajectories(particles, + parent_indices, + particles_dim=1)) + self.assertAllEqual([num_timesteps, + batch_shape[0], + num_particles, + batch_shape[1]], + trajectories['position'].shape) + self.assertAllEqual([num_timesteps, + batch_shape[0], + num_particles, + batch_shape[1]], + trajectories['velocity'].shape) + + # Verify that `infer_trajectories` also works on batches. + trajectories, incremental_log_marginal_likelihoods = self.evaluate( + particle_filter.infer_trajectories( + observations=observed_positions, + initial_state_prior=initial_state_prior, + transition_fn=transition_fn, + observation_fn=observation_fn, + num_particles=num_particles, + particles_dim=1, + seed=test_util.test_seed())) + + self.assertAllEqual([num_timesteps, + batch_shape[0], + num_particles, + batch_shape[1]], + trajectories['position'].shape) + self.assertAllEqual([num_timesteps, + batch_shape[0], + num_particles, + batch_shape[1]], + trajectories['velocity'].shape) + self.assertAllEqual(incremental_log_marginal_likelihoods.shape, + [num_timesteps] + batch_shape) + def test_reconstruct_trajectories_toy_example(self): particles = tf.convert_to_tensor([[1, 2, 3], [4, 5, 6,], [7, 8, 9]]) # 1 -- 4 -- 7 @@ -613,205 +734,6 @@ def marginal_log_likelihood(level_scale, noise_scale): self.assertAllNotNone(grads) self.assertAllAssertsNested(self.assertNotAllZero, grads) - def test_smc_squared_rejuvenation_parameters(self): - def particle_dynamics(params, _, previous_state): - reshaped_params = tf.reshape(params, [params.shape[0]] + [1] * (previous_state.shape.rank - 1)) - broadcasted_params = tf.broadcast_to(reshaped_params, previous_state.shape) - return normal.Normal(previous_state + broadcasted_params + 1, 0.1) - - def rejuvenation_criterion(step, state): - # Rejuvenation every 2 steps - cond = tf.logical_and( - tf.equal(tf.math.mod(step, tf.constant(2)), tf.constant(0)), - tf.not_equal(state.extra[0], tf.constant(0)) - ) - return tf.cond(cond, lambda: tf.constant(True), lambda: tf.constant(False)) - - inner_observations = tf.range(30, dtype=tf.float32) - - num_outer_particles = 3 - num_inner_particles = 7 - - loc = tf.broadcast_to([0., 0.], [num_outer_particles, 2]) - scale_diag = tf.broadcast_to([0.05, 0.05], [num_outer_particles, 2]) - - params, inner_pt = self.evaluate(particle_filter.smc_squared( - inner_observations=inner_observations, - inner_initial_state_prior=lambda _, params: mvn_diag.MultivariateNormalDiag( - loc=loc, scale_diag=scale_diag - ), - initial_parameter_prior=normal.Normal(3., 1.), - num_outer_particles=num_outer_particles, - num_inner_particles=num_inner_particles, - outer_rejuvenation_criterion_fn=rejuvenation_criterion, - inner_transition_fn=lambda params: ( - lambda _, state: independent.Independent(particle_dynamics(params, _, state), 1)), - inner_observation_fn=lambda params: ( - lambda _, state: independent.Independent(normal.Normal(state, 2.), 1)), - outer_trace_fn=lambda s, r: ( - s.particles[0], - s.particles[1] - ), - parameter_proposal_kernel=lambda params: normal.Normal(params, 3), - seed=test_util.test_seed() - ) - ) - - abs_params = tf.abs(params) - differences = abs_params[1:] - abs_params[:-1] - mask_parameters = tf.reduce_all(tf.less_equal(differences, 0), axis=0) - - self.assertAllTrue(mask_parameters) - - def test_smc_squared_can_step_dynamics_faster_than_observations(self): - initial_state_prior = jdn.JointDistributionNamed({ - 'position': deterministic.Deterministic([1.]), - 'velocity': 
deterministic.Deterministic([0.]) - }) - - # Use 100 steps between observations to integrate a simple harmonic - # oscillator. - dt = 0.01 - def simple_harmonic_motion_transition_fn(_, state): - return jdn.JointDistributionNamed({ - 'position': - normal.Normal( - loc=state['position'] + dt * state['velocity'], - scale=dt * 0.01), - 'velocity': - normal.Normal( - loc=state['velocity'] - dt * state['position'], - scale=dt * 0.01) - }) - - def observe_position(_, state): - return normal.Normal(loc=state['position'], scale=0.01) - - particles, lps = self.evaluate(particle_filter.smc_squared( - inner_observations=tf.convert_to_tensor( - [tf.math.cos(0.), tf.math.cos(1.)]), - inner_initial_state_prior=lambda _, params: initial_state_prior, - initial_parameter_prior=deterministic.Deterministic(0.), - num_outer_particles=1, - inner_transition_fn=lambda params: simple_harmonic_motion_transition_fn, - inner_observation_fn=lambda params: observe_position, - num_inner_particles=1024, - outer_trace_fn=lambda s, r: ( - s.particles[1].particles, - s.particles[3] - ), - num_transitions_per_observation=100, - seed=test_util.test_seed()) - ) - - self.assertAllEqual(ps.shape(particles['position']), tf.constant([102, 1, 1024])) - - self.assertAllClose(tf.transpose(np.mean(particles['position'], axis=-1)), - tf.reshape(tf.math.cos(dt * np.arange(102)), [1, -1]), - atol=0.04) - - self.assertAllEqual(ps.shape(lps), [102, 1]) - self.assertGreater(lps[1][0], 1.) - self.assertGreater(lps[-1][0], 3.) - - def test_smc_squared_custom_outer_trace_fn(self): - def trace_fn(state, _): - # Traces the mean and stddev of the particle population at each step. - weights = tf.exp(state[0][1].log_weights[0]) - mean = tf.reduce_sum(weights * state[0][1].particles[0], axis=0) - variance = tf.reduce_sum( - weights * (state[0][1].particles[0] - mean[tf.newaxis, ...]) ** 2) - return {'mean': mean, - 'stddev': tf.sqrt(variance), - # In real usage we would likely not track the particles and - # weights. We keep them here just so we can double-check the - # stats, below. - 'particles': state[0][1].particles[0], - 'weights': weights} - - results = self.evaluate(particle_filter.smc_squared( - inner_observations=tf.convert_to_tensor([1., 3., 5., 7., 9.]), - inner_initial_state_prior=lambda _, params: normal.Normal([0.], 1.), - initial_parameter_prior=deterministic.Deterministic(0.), - inner_transition_fn=lambda params: (lambda _, state: normal.Normal(state, 1.)), - inner_observation_fn=lambda params: (lambda _, state: normal.Normal(state, 1.)), - num_inner_particles=1024, - num_outer_particles=1, - outer_trace_fn=trace_fn, - seed=test_util.test_seed()) - ) - - # Verify that posterior means are increasing. - self.assertAllGreater(results['mean'][1:] - results['mean'][:-1], 0.) - - # Check that our traced means and scales match values computed - # by averaging over particles after the fact. 
- all_means = self.evaluate(tf.reduce_sum( - results['weights'] * results['particles'], axis=1)) - all_variances = self.evaluate( - tf.reduce_sum( - results['weights'] * - (results['particles'] - all_means[..., tf.newaxis])**2, - axis=1)) - self.assertAllClose(results['mean'], all_means) - self.assertAllClose(results['stddev'], np.sqrt(all_variances)) - - def test_smc_squared_indices_to_trace(self): - num_outer_particles = 7 - num_inner_particles = 13 - - def rejuvenation_criterion(step, state): - # Rejuvenation every 3 steps - cond = tf.logical_and( - tf.equal(tf.math.mod(step, tf.constant(3)), tf.constant(0)), - tf.not_equal(state.extra[0], tf.constant(0)) - ) - return tf.cond(cond, lambda: tf.constant(True), lambda: tf.constant(False)) - - (parameters, weight_parameters, inner_particles, inner_log_weights, lp) = self.evaluate( - particle_filter.smc_squared( - inner_observations=tf.convert_to_tensor([1., 3., 5., 7., 9.]), - initial_parameter_prior=deterministic.Deterministic(0.), - inner_initial_state_prior=lambda _, params: normal.Normal([0.] * num_outer_particles, 1.), - inner_transition_fn=lambda params: (lambda _, state: normal.Normal(state, 10.)), - inner_observation_fn=lambda params: (lambda _, state: normal.Normal(state, 0.1)), - num_inner_particles=num_inner_particles, - num_outer_particles=num_outer_particles, - outer_rejuvenation_criterion_fn=rejuvenation_criterion, - outer_trace_fn=lambda s, r: ( # pylint: disable=g-long-lambda - s.particles[0], - s.log_weights, - s.particles[1].particles, - s.particles[1].log_weights, - r.accumulated_log_marginal_likelihood), - seed=test_util.test_seed()) - ) - - # TODO: smc_squared at the moment starts his run with an empty step - self.assertAllEqual(ps.shape(parameters), [6, 7]) - self.assertAllEqual(ps.shape(weight_parameters), [6, 7]) - self.assertAllEqual(ps.shape(inner_particles), [6, 7, 13]) - self.assertAllEqual(ps.shape(inner_log_weights), [6, 7, 13]) - self.assertAllEqual(ps.shape(lp), [6]) - - def test_extra(self): - def step_hundred(step, state, seed): - return step * 2 - - results = self.evaluate( - particle_filter.particle_filter( - observations=tf.convert_to_tensor([1., 3., 5., 7., 9.]), - initial_state_prior=normal.Normal(0., 1.), - transition_fn=lambda _, state: normal.Normal(state, 1.), - observation_fn=lambda _, state: normal.Normal(state, 1.), - num_particles=1024, - extra_fn=step_hundred, - trace_fn=lambda s, r: s.extra, - seed=test_util.test_seed()) - ) - - self.assertAllEqual(results, [0, 0, 2, 4, 6]) - # TODO(b/186068104): add tests with dynamic shapes. class ParticleFilterTestFloat32(_ParticleFilterTest): diff --git a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py index 300418c87d..73cb0f8414 100644 --- a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py +++ b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py @@ -34,7 +34,7 @@ # SequentialMonteCarlo `state` structure. class WeightedParticles(collections.namedtuple( - 'WeightedParticles', ['particles', 'log_weights', 'extra'])): + 'WeightedParticles', ['particles', 'log_weights'])): """Particles with corresponding log weights. This structure serves as the `state` for the `SequentialMonteCarlo` transition @@ -50,10 +50,6 @@ class WeightedParticles(collections.namedtuple( `exp(reduce_logsumexp(log_weights, axis=0)) == 1.`. 
      These must be used in conjunction with `particles` to compute
       expectations under the target distribution.
-    extra: a (structure of) Tensor(s) each of shape
-      `concat([[b1, ..., bN], event_shape])`, where `event_shape`
-      may differ across component `Tensor`s. This represents global state of the
-      sampling process that is not associated with individual particles.

   In some contexts, particles may be stacked across multiple inference steps,
   in which case all `Tensor` shapes will be prefixed by an additional dimension
@@ -296,7 +292,7 @@ def one_step(self, state, kernel_results, seed=None):
                      - tf.gather(normalized_log_weights, 0, axis=self.particles_dim))

       do_resample = self.resample_criterion_fn(
-          state, self.particles_dim)
+          state, particles_dim=self.particles_dim)
       # Some batch elements may require resampling and others not, so
       # we first do the resampling for all elements, then select whether to
       # use the resampled values for each batch element according to
@@ -330,8 +326,7 @@ def one_step(self, state, kernel_results, seed=None):
                            normalized_log_weights))

       return (WeightedParticles(particles=resampled_particles,
-                                log_weights=log_weights,
-                                extra=state.extra),
+                                log_weights=log_weights),
               SequentialMonteCarloResults(
                   steps=kernel_results.steps + 1,
                   parent_indices=resample_indices,
diff --git a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py
index 2e29f6c4dd..2a9302a420 100644
--- a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py
+++ b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py
@@ -42,9 +42,7 @@ def propose_and_update_log_weights_fn(_, weighted_particles, seed=None):
       return WeightedParticles(
           particles=proposed_particles,
           log_weights=weighted_particles.log_weights +
-          normal.Normal(loc=-2.6, scale=0.1).log_prob(proposed_particles),
-          extra=tf.constant(np.nan)
-      )
+          normal.Normal(loc=-2.6, scale=0.1).log_prob(proposed_particles))

     num_particles = 16
     initial_state = self.evaluate(
@@ -52,9 +50,7 @@ def propose_and_update_log_weights_fn(_, weighted_particles, seed=None):
             particles=tf.random.normal([num_particles],
                                        seed=test_util.test_seed()),
             log_weights=tf.fill([num_particles],
-                                -tf.math.log(float(num_particles))),
-            extra=tf.constant(np.nan)
-        ))
+                                -tf.math.log(float(num_particles)))))

     # Run a couple of steps.
    seeds = samplers.split_seed(
@@ -100,9 +96,7 @@ def testMarginalLikelihoodGradientIsDefined(self):
         WeightedParticles(
             particles=samplers.normal([num_particles], seed=seeds[0]),
             log_weights=tf.fill([num_particles],
-                                -tf.math.log(float(num_particles))),
-            extra=tf.constant(np.nan)
-        ))
+                                -tf.math.log(float(num_particles)))))

     def propose_and_update_log_weights_fn(_,
                                           weighted_particles,
@@ -116,9 +110,7 @@ def propose_and_update_log_weights_fn(_,
           particles=proposed_particles,
           log_weights=(weighted_particles.log_weights +
                        transition_dist.log_prob(proposed_particles) -
-                       proposal_dist.log_prob(proposed_particles)),
-          extra=tf.constant(np.nan)
-      )
+                       proposal_dist.log_prob(proposed_particles)))

     def marginal_logprob(transition_scale):
       kernel = SequentialMonteCarlo(
diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator.py b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator.py
index a7af244ec3..e8c7f52fb3 100644
--- a/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator.py
+++ b/tensorflow_probability/python/internal/backend/numpy/gen/linear_operator.py
@@ -1675,7 +1675,10 @@ def _matmul(  # pylint:disable=missing-docstring
     a_is_sparse=False,
     b_is_sparse=False,
     output_type=None,  # pylint: disable=unused-argument
-    name=None):
+    grad_a=False,  # pylint: disable=unused-argument
+    grad_b=False,  # pylint: disable=unused-argument
+    name=None,
+):
   if transpose_a or transpose_b:
     raise ValueError("Transposing not supported at this time.")
   if a_is_sparse or b_is_sparse:
diff --git a/tensorflow_probability/python/layers/internal/distribution_tensor_coercible_test.py b/tensorflow_probability/python/layers/internal/distribution_tensor_coercible_test.py
index e84d69d642..85bf8380f7 100644
--- a/tensorflow_probability/python/layers/internal/distribution_tensor_coercible_test.py
+++ b/tensorflow_probability/python/layers/internal/distribution_tensor_coercible_test.py
@@ -294,8 +294,6 @@ def testPropagatedAttributes(self):
 class MemoryLeakTest(test_util.TestCase):

   def testTypeObjectLeakage(self):
-    # TODO(b/303352281): Reenable this test.
-    self.skipTest('This test does not currently work under Python 3.11.')
     if not tf.executing_eagerly():
       self.skipTest('only relevant to eager')
diff --git a/testing/dependency_install_lib.sh b/testing/dependency_install_lib.sh
index 801d7a3361..261cc1665b 100644
--- a/testing/dependency_install_lib.sh
+++ b/testing/dependency_install_lib.sh
@@ -93,9 +93,11 @@ install_test_only_packages() {
   # The following unofficial dependencies are used only by tests.
   PIP_FLAGS=${1-}
   python -m pip install $PIP_FLAGS \
+    flax \
     hypothesis \
     jax \
     jaxlib \
+    jaxtyping \
     optax \
     matplotlib \
     mock \