diff --git a/tensorflow_probability/python/experimental/autobnn/BUILD b/tensorflow_probability/python/experimental/autobnn/BUILD new file mode 100644 index 0000000000..1a7f1a5381 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/BUILD @@ -0,0 +1,227 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# Code for AutoBNN. See README.md for more information. + +# Placeholder: py_library +# Placeholder: py_test + +licenses(["notice"]) + +package( + # default_applicable_licenses + default_visibility = ["//visibility:public"], +) + +py_library( + name = "bnn", + srcs = ["bnn.py"], + deps = [ + ":likelihoods", + # flax:core dep, + # jax dep, + # jaxtyping dep, + "//tensorflow_probability/python/distributions:distribution.jax", + ], +) + +py_test( + name = "bnn_test", + srcs = ["bnn_test.py"], + deps = [ + ":bnn", + # absl/testing:absltest dep, + # google/protobuf:use_fast_cpp_protos dep, + # jax dep, + "//tensorflow_probability:jax", + "//tensorflow_probability/python/distributions:lognormal.jax", + "//tensorflow_probability/python/distributions:normal.jax", + ], +) + +py_library( + name = "kernels", + srcs = ["kernels.py"], + deps = [ + ":bnn", + # flax dep, + # flax:core dep, + # jax dep, + "//tensorflow_probability/python/distributions:lognormal.jax", + "//tensorflow_probability/python/distributions:normal.jax", + "//tensorflow_probability/python/distributions:student_t.jax", + "//tensorflow_probability/python/distributions:uniform.jax", + ], +) + +py_test( + name = "kernels_test", + srcs = ["kernels_test.py"], + deps = [ + ":kernels", + ":util", + # absl/testing:absltest dep, + # absl/testing:parameterized dep, + # google/protobuf:use_fast_cpp_protos dep, + # jax dep, + "//tensorflow_probability/python/distributions:lognormal.jax", + ], +) + +py_library( + name = "likelihoods", + srcs = ["likelihoods.py"], + deps = [ + # flax:core dep, + # jax dep, + # jaxtyping dep, + "//tensorflow_probability:jax", + "//tensorflow_probability/python/bijectors:softplus.jax", + "//tensorflow_probability/python/distributions:distribution.jax", + "//tensorflow_probability/python/distributions:inflated.jax", + "//tensorflow_probability/python/distributions:logistic.jax", + "//tensorflow_probability/python/distributions:lognormal.jax", + "//tensorflow_probability/python/distributions:negative_binomial.jax", + "//tensorflow_probability/python/distributions:normal.jax", + "//tensorflow_probability/python/distributions:transformed_distribution.jax", + ], +) + +py_test( + name = "likelihoods_test", + srcs = ["likelihoods_test.py"], + deps = [ + ":likelihoods", + # absl/testing:absltest dep, + # absl/testing:parameterized dep, + # jax dep, + ], +) + +py_library( + name = "models", + srcs = ["models.py"], + deps = [ + ":bnn", + ":bnn_tree", + ":kernels", + ":likelihoods", + ":operators", + # jax dep, + ], +) + +py_test( + name = "models_test", + srcs = ["models_test.py"], + shard_count = 3, + deps = [ + ":likelihoods", + ":models", + ":operators", + # absl/testing:absltest dep, + # absl/testing:parameterized dep, + # jax dep, + ], +) + +py_library( + name = "operators", + srcs = ["operators.py"], + deps = [ + ":bnn", + ":likelihoods", + # flax:core dep, + # jax dep, + "//tensorflow_probability:jax", + "//tensorflow_probability/python/bijectors:chain.jax", + "//tensorflow_probability/python/bijectors:scale.jax", + "//tensorflow_probability/python/bijectors:shift.jax", + "//tensorflow_probability/python/distributions:beta.jax", + "//tensorflow_probability/python/distributions:dirichlet.jax", + "//tensorflow_probability/python/distributions:half_normal.jax", + "//tensorflow_probability/python/distributions:normal.jax", + "//tensorflow_probability/python/distributions:transformed_distribution.jax", + ], +) + +py_test( + name = "operators_test", + srcs = ["operators_test.py"], + deps = [ + ":kernels", + ":operators", + ":util", + # absl/testing:absltest dep, + # absl/testing:parameterized dep, + # google/protobuf:use_fast_cpp_protos dep, + # jax dep, + # numpy dep, + "//tensorflow_probability/python/distributions:distribution.jax", + ], +) + +py_library( + name = "bnn_tree", + srcs = ["bnn_tree.py"], + deps = [ + ":bnn", + ":kernels", + ":operators", + ":util", + # flax:core dep, + # jax dep, + ], +) + +py_test( + name = "bnn_tree_test", + timeout = "long", + srcs = ["bnn_tree_test.py"], + shard_count = 3, + deps = [ + ":bnn_tree", + ":kernels", + # absl/testing:absltest dep, + # absl/testing:parameterized dep, + # flax dep, + # google/protobuf:use_fast_cpp_protos dep, + # jax dep, + ], +) + +py_library( + name = "util", + srcs = ["util.py"], + deps = [ + ":bnn", + # jax dep, + # numpy dep, + # scipy dep, + "//tensorflow_probability/python/distributions:distribution.jax", + ], +) + +py_test( + name = "util_test", + srcs = ["util_test.py"], + deps = [ + ":kernels", + ":util", + # google/protobuf:use_fast_cpp_protos dep, + # jax dep, + # numpy dep, + "//tensorflow_probability/python/internal:test_util", + ], +) diff --git a/tensorflow_probability/python/experimental/autobnn/README.md b/tensorflow_probability/python/experimental/autobnn/README.md new file mode 100644 index 0000000000..c10a446ada --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/README.md @@ -0,0 +1,25 @@ +# AutoBNN + +This library contains code to specify BNNs that correspond to various useful GP +kernels and assemble them into models using operators such as Addition, +Multiplication and Changepoint. + +It is based on the ideas in the following papers: + +* Lassi Meronen, Martin Trapp, Arno Solin. _Periodic Activation Functions +Induce Stationarity_. NeurIPS 2021. + +* Tim Pearce, Russell Tsuchida, Mohamed Zaki, Alexandra Brintrup, Andy Neely. +_Expressive Priors in Bayesian Neural Networks: Kernel Combinations and +Periodic Functions_. UAI 2019. + +* Feras A. Saad, Brian J. Patton, Matthew D. Hoffman, Rif A. Saurous, +Vikash K. Mansinghka. _Sequential Monte Carlo Learning for Time Series +Structure Discovery_. ICML 2023. + + +## Setup + +AutoBNN has three additional dependencies beyond those used by the core +Tensorflow Probability package: flax, scipy and jaxtyping. These +can be installed by running `setup\_autobnn.sh`. diff --git a/tensorflow_probability/python/experimental/autobnn/bnn.py b/tensorflow_probability/python/experimental/autobnn/bnn.py new file mode 100644 index 0000000000..cb33e373c1 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/bnn.py @@ -0,0 +1,170 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Base class for Bayesian Neural Networks.""" + +import dataclasses + +import flax +from flax import linen as nn +import jax.numpy as jnp +from jaxtyping import Array, Float, PyTree # pylint: disable=g-importing-member,g-multiple-import +from tensorflow_probability.python.experimental.autobnn import likelihoods +from tensorflow_probability.substrates.jax.distributions import distribution as distribution_lib + + +def log_prior_of_parameters(params, distributions) -> Float: + """Return the prior of the parameters according to the distributions.""" + if 'params' in params: + params = params['params'] + # We can't use jax.tree_util.tree_map here because params is allowed to + # have extra things (like bnn_0, ... for a BnnOperator) that aren't in + # distributions. + lp = 0.0 + for k, v in distributions.items(): + p = params[k] + if isinstance(v, distribution_lib.Distribution): + lp += jnp.sum(v.log_prob(p)) + else: + lp += log_prior_of_parameters(p, v) + return lp + + +class BayesianModule(nn.Module): + """A linen.Module with distributions over its parameters. + + Example usage: + class MyModule(BayesianModule): + + def distributions(self): + return {'dense': {'kernel': tfd.Normal(loc=0, scale=1), + 'bias': tfd.Normal(loc=0, scale=1)}, + 'amplitude': tfd.LogNormal(loc=0, scale=1)} + + def setup(self): + self.dense = nn.Dense(50) + super().setup() # <-- Very important, do not forget! + + def __call__(self, inputs): + return self.amplitude * self.dense(inputs) + + + my_bnn = MyModule() + params = my_bnn.init(jax.random.PRNGKey(0), jnp.zeros(10)) + lp = my_bnn.log_prior(params) + + Note that in this example, self.amplitude will be initialized using + the given tfd.LogNormal distribution, but the self.dense's parameters + will be initialized using the nn.Dense's default initializers. However, + the log_prior score will take into account all of the parameters. + """ + + def distributions(self): + """Return a nested dictionary of distributions for the model's params. + + The nested dictionary should have the same structure as the + variables returned by the init() method, except all leaves should + be tensorflow probability Distributions. + """ + # TODO(thomaswc): Consider having this optionally also be able to + # return a tfd.JointNamedDistribution, so as to support dependencies + # between the subdistributions. + raise NotImplementedError('Subclasses of BNN must define this.') + + def setup(self): + """Children classes must call this from their setup() !""" + + def make_sample_func(dist): + def sample_func(key, shape): + return dist.sample(sample_shape=shape, seed=key) + + return sample_func + + for k, v in self.distributions().items(): + # Create a variable for every distribution that doesn't already + # have one. If you define a variable in your setup, we assume + # you initialize it correctly. + if not hasattr(self, k): + try: + setattr(self, k, self.param(k, make_sample_func(v), 1)) + except flax.errors.NameInUseError: + # Sometimes subclasses will have parameters where the + # parameter name doesn't exactly correspond to the name of + # the object field. This can happen with arrays of parameters + # (like PolynomialBBN's hidden parameters.) for example. I + # don't know of any way to detect this beforehand except by + # trying to call self.params and having it fail with NameInUseError. + # (For example, self.variables doesn't exist at setup() time.) + pass + + def log_prior(self, params) -> float: + """Return the log probability of the params according to the prior.""" + return log_prior_of_parameters(params, self.distributions()) + + def shortname(self) -> str: + """Return the class name, minus any BNN suffix.""" + return type(self).__name__.removesuffix('BNN') + + def summarize(self, params=None, full: bool = False) -> str: + """Return a string summarizing the structure of the BNN.""" + return self.shortname() + + +class BNN(BayesianModule): + """A Bayesian Neural Network. + + A BNN's __call__ method must accept a tensor of shape (..., num_features) + and return a tensor of shape (..., likelihood_model.num_outputs()). + Given that, it provides log_likelihood and log_prob methods based + on the provided likelihood_model. + """ + + likelihood_model: likelihoods.LikelihoodModel = dataclasses.field( + default_factory=likelihoods.NormalLikelihoodLogisticNoise + ) + + def distributions(self): + # Children classes must call super().distributions() to include this! + return self.likelihood_model.distributions() + + def set_likelihood_model(self, likelihood_model: likelihoods.LikelihoodModel): + self.likelihood_model = likelihood_model + + def log_likelihood( + self, + params: PyTree, + data: Float[Array, 'time features'], + observations: Float[Array, 'time'], + ) -> Float[Array, '']: + """Return the likelihood of the data given the model.""" + nn_out = self.apply(params, data) + if 'params' in params: + params = params['params'] + # Sum over all axes here - user should use `vmap` for batching. + return jnp.sum( + self.likelihood_model.log_likelihood(params, nn_out, observations) + ) + + def log_prob( + self, + params: PyTree, + data: Float[Array, 'time features'], + observations: Float[Array, 'time'], + ) -> Float[Array, '']: + return self.log_prior(params) + self.log_likelihood( + params, data, observations + ) + + def get_all_distributions(self): + return self.distributions() diff --git a/tensorflow_probability/python/experimental/autobnn/bnn_test.py b/tensorflow_probability/python/experimental/autobnn/bnn_test.py new file mode 100644 index 0000000000..53bb7b09e4 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/bnn_test.py @@ -0,0 +1,68 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for bnn.py.""" + +from flax import linen as nn +import jax +import jax.numpy as jnp +from tensorflow_probability.python.experimental.autobnn import bnn +from tensorflow_probability.substrates.jax.distributions import lognormal as lognormal_lib +from tensorflow_probability.substrates.jax.distributions import normal as normal_lib +from absl.testing import absltest + + +class MyBNN(bnn.BNN): + + def distributions(self): + return super().distributions() | { + 'dense': { + 'kernel': normal_lib.Normal(loc=0, scale=1), + 'bias': normal_lib.Normal(loc=0, scale=1), + }, + 'amplitude': lognormal_lib.LogNormal(loc=0, scale=1), + } + + def setup(self): + self.dense = nn.Dense(50) + super().setup() + + def __call__(self, inputs): + return self.amplitude * jnp.sum(self.dense(inputs)) + + +class BnnTests(absltest.TestCase): + + def test_mybnn(self): + my_bnn = MyBNN() + d = my_bnn.distributions() + self.assertIn('noise_scale', d) + sample_noise = d['noise_scale'].sample(1, seed=jax.random.PRNGKey(0)) + self.assertEqual((1,), sample_noise.shape) + + params = my_bnn.init(jax.random.PRNGKey(0), jnp.zeros(1)) + lp1 = my_bnn.log_prior(params) + params['params']['amplitude'] += 50 + lp2 = my_bnn.log_prior(params) + self.assertLess(lp2, lp1) + + data = jnp.array([[0], [1], [2], [3], [4], [5]], dtype=jnp.float32) + obs = jnp.array([1, 0, 1, 0, 1, 0], dtype=jnp.float32) + ll = my_bnn.log_likelihood(params, data, obs) + lp = my_bnn.log_prob(params, data, obs) + self.assertLess(jnp.sum(lp), jnp.sum(ll)) + + +if __name__ == '__main__': + absltest.main() diff --git a/tensorflow_probability/python/experimental/autobnn/bnn_tree.py b/tensorflow_probability/python/experimental/autobnn/bnn_tree.py new file mode 100644 index 0000000000..4a02f23252 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/bnn_tree.py @@ -0,0 +1,171 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Routines for making tree-structured BNNs.""" + +from typing import Iterable, List + +from flax import linen as nn +import jax +import jax.numpy as jnp +from tensorflow_probability.python.experimental.autobnn import bnn +from tensorflow_probability.python.experimental.autobnn import kernels +from tensorflow_probability.python.experimental.autobnn import operators +from tensorflow_probability.python.experimental.autobnn import util + +Array = jnp.ndarray + + +LEAVES = [ + kernels.ExponentiatedQuadraticBNN, + kernels.MaternBNN, + kernels.LinearBNN, + kernels.QuadraticBNN, + kernels.PeriodicBNN, + kernels.OneLayerBNN, +] + + +OPERATORS = [ + operators.Multiply, + operators.Add, + operators.WeightedSum, + operators.ChangePoint, + operators.LearnableChangePoint +] + + +NON_PERIODIC_KERNELS = [ + kernels.ExponentiatedQuadraticBNN, + kernels.MaternBNN, + kernels.LinearBNN, + kernels.QuadraticBNN, + kernels.OneLayerBNN, +] + + +def list_of_all( + time_series_xs: Array, + depth: int = 2, + width: int = 50, + periods: Iterable[float] = (), + parent_is_multiply: bool = False, + include_sums: bool = True, + include_changepoints: bool = True, + only_safe_products: bool = False +) -> List[bnn.BNN]: + """Return a list of all BNNs of the given depth.""" + all_bnns = [] + if depth == 0: + all_bnns.extend(k(width=width, going_to_be_multiplied=parent_is_multiply) + for k in NON_PERIODIC_KERNELS) + for p in periods: + all_bnns.append(kernels.PeriodicBNN( + width=width, period=p, going_to_be_multiplied=parent_is_multiply)) + return all_bnns + + multiply_children = list_of_all( + time_series_xs, depth-1, width, periods, True) + if parent_is_multiply: + non_multiply_children = multiply_children + else: + non_multiply_children = list_of_all( + time_series_xs, depth-1, width, periods, False) + + # Abelian operators that aren't Multiply. + if include_sums: + for i, c1 in enumerate(non_multiply_children): + for j in range(i + 1): + c2 = non_multiply_children[j] + # Add is also abelian, but WeightedSum is more general. + all_bnns.append( + operators.WeightedSum( + bnns=(c1.clone(_deep_clone=True), c2.clone(_deep_clone=True)) + ) + ) + + if parent_is_multiply: + # Remaining operators don't expose .penultimate() method. + return all_bnns + + # Multiply + for i, c1 in enumerate(multiply_children): + if only_safe_products: + # The only safe kernels to multiply by are Linear and Quadratic. + if not isinstance(c1, kernels.PolynomialBNN): + continue + for j in range(i+1): + c2 = multiply_children[j] + all_bnns.append(operators.Multiply(bnns=( + c1.clone(_deep_clone=True), c2.clone(_deep_clone=True)))) + + # Non-abelian operators + if include_changepoints: + for c1 in non_multiply_children: + for c2 in non_multiply_children: + # ChangePoint is also non-abelian, but requires that we know + # what the change point is. + all_bnns.append(operators.LearnableChangePoint( + bnns=(c1.clone(_deep_clone=True), c2.clone(_deep_clone=True)), + time_series_xs=time_series_xs)) + + return all_bnns + + +def weighted_sum_of_all(time_series_xs: Array, + time_series_ys: Array, + depth: int = 2, width: int = 50, + alpha: float = 1.0) -> bnn.BNN: + """Return a weighted sum of all BNNs of the given depth.""" + periods = util.suggest_periods(time_series_ys) + + all_bnns = list_of_all(time_series_xs, depth, width, periods, False) + + return operators.WeightedSum(bnns=tuple(all_bnns), alpha=alpha) + + +def random_tree(key: jax.Array, depth: int, width: int, period: float, + parent_is_multiply: bool = False) -> nn.Module: + """Return a random complete tree BNN of the given depth. + + Args: + key: Random number key. + depth: Return a BNN of this tree depth. Zero based, so depth=0 returns + a leaf BNN. + width: The number of hidden nodes in the leaf layers. + period: The period of any PeriodicBNN kernels in the tree. + parent_is_multiply: If true, don't create a weight layer after the hidden + nodes of any leaf kernels and only use addition as an internal node. + + Returns: + A BNN of the specified tree depth. + """ + if depth == 0: + c = jax.random.choice(key, len(LEAVES)) + return LEAVES[c]( + width=width, going_to_be_multiplied=parent_is_multiply, + period=period) + + key1, key2, key3 = jax.random.split(key, 3) + if parent_is_multiply: + c = 1 # Can't multiply Multiply or ChangePoints + is_multiply = True + else: + c = jax.random.choice(key1, len(OPERATORS)) + is_multiply = (c == 0) + + sub1 = random_tree(key2, depth - 1, width, period, is_multiply) + sub2 = random_tree(key3, depth - 1, width, period, is_multiply) + + return OPERATORS[c](bnns=(sub1, sub2)) diff --git a/tensorflow_probability/python/experimental/autobnn/bnn_tree_test.py b/tensorflow_probability/python/experimental/autobnn/bnn_tree_test.py new file mode 100644 index 0000000000..10b38b24c2 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/bnn_tree_test.py @@ -0,0 +1,117 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for bnn_tree.py.""" + +from absl.testing import parameterized +from flax import linen as nn +import jax +import jax.numpy as jnp +from tensorflow_probability.python.experimental.autobnn import bnn_tree +from tensorflow_probability.python.experimental.autobnn import kernels +from absl.testing import absltest + + +class TreeTest(parameterized.TestCase): + + def test_list_of_all(self): + l0 = bnn_tree.list_of_all(jnp.linspace(0.0, 100.0, 100), 0) + # With no periods, there should be five kernels. + self.assertLen(l0, 5) + for k in l0: + self.assertFalse(k.going_to_be_multiplied) + + l0 = bnn_tree.list_of_all(100, 0, 50, [20.0, 40.0], parent_is_multiply=True) + self.assertLen(l0, 7) + for k in l0: + self.assertTrue(k.going_to_be_multiplied) + + l1 = bnn_tree.list_of_all(jnp.linspace(0.0, 100.0, 100), 1) + # With no periods, there should be + # 15 trees with a Multiply top node, + # 15 trees with a WeightedSum top node, and + # 25 trees with a LearnableChangePoint top node. + self.assertLen(l1, 55) + + # Check that all of the BNNs in the tree can be trained. + for k in l1: + params = k.init(jax.random.PRNGKey(0), jnp.zeros(5)) + lp = k.log_prior(params) + self.assertLess(lp, 0.0) + output = k.apply(params, jnp.ones(5)) + self.assertEqual((1,), output.shape) + + l1 = bnn_tree.list_of_all( + jnp.linspace(0.0, 100.0, 100), + 1, + 50, + [20.0, 40.0], + parent_is_multiply=True, + ) + # With 2 periods and parent_is_multiply, there are only WeightedSum top + # nodes, with 7*8/2 = 28 trees. + self.assertLen(l1, 28) + + l2 = bnn_tree.list_of_all(jnp.linspace(0.0, 100.0, 100), 2) + # With no periods, there should be + # 15*16/2 = 120 trees with a Multiply top node, + # 55*56/2 = 1540 trees with a WeightedSum top node, and + # 55*55 = 3025 trees with a LearnableChangePoint top node. + self.assertLen(l2, 4685) + + @parameterized.parameters(0, 1) # depth=2 segfaults on my desktop :( + def test_weighted_sum_of_all(self, depth): + soa = bnn_tree.weighted_sum_of_all( + jnp.linspace(0.0, 1.0, 100), jnp.ones(100), depth=depth + ) + params = soa.init(jax.random.PRNGKey(0), jnp.zeros(5)) + lp = soa.log_prior(params) + self.assertLess(lp, 0.0) + output = soa.apply(params, jnp.ones(5)) + self.assertEqual((1,), output.shape) + + def test_random_tree(self): + r0 = bnn_tree.random_tree( + jax.random.PRNGKey(0), depth=0, width=50, period=7 + ) + self.assertIsInstance(r0, kernels.OneLayerBNN) + params = r0.init(jax.random.PRNGKey(1), jnp.zeros(5)) + lp = r0.log_prior(params) + self.assertLess(lp, 0.0) + output = r0.apply(params, jnp.ones(5)) + self.assertEqual((1,), output.shape) + + r1 = bnn_tree.random_tree( + jax.random.PRNGKey(0), depth=1, width=50, period=24 + ) + self.assertIsInstance(r1, nn.Module) + params = r1.init(jax.random.PRNGKey(1), jnp.zeros(5)) + lp = r1.log_prior(params) + self.assertLess(lp, 0.0) + output = r1.apply(params, jnp.ones(5)) + self.assertEqual((1,), output.shape) + + r2 = bnn_tree.random_tree( + jax.random.PRNGKey(0), depth=2, width=50, period=52 + ) + self.assertIsInstance(r2, nn.Module) + params = r2.init(jax.random.PRNGKey(1), jnp.zeros(5)) + lp = r2.log_prior(params) + self.assertLess(lp, 0.0) + output = r2.apply(params, jnp.ones(5)) + self.assertEqual((1,), output.shape) + + +if __name__ == '__main__': + absltest.main() diff --git a/tensorflow_probability/python/experimental/autobnn/kernels.py b/tensorflow_probability/python/experimental/autobnn/kernels.py new file mode 100644 index 0000000000..b02ffb7316 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/kernels.py @@ -0,0 +1,324 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""`Leaf` BNNs, most of which correspond to some known GP kernel.""" + +from flax import linen as nn +from flax.linen import initializers +import jax +import jax.numpy as jnp +from tensorflow_probability.python.experimental.autobnn import bnn +from tensorflow_probability.substrates.jax.distributions import lognormal as lognormal_lib +from tensorflow_probability.substrates.jax.distributions import normal as normal_lib +from tensorflow_probability.substrates.jax.distributions import student_t as student_t_lib +from tensorflow_probability.substrates.jax.distributions import uniform as uniform_lib + + +Array = jnp.ndarray + + +SQRT_TWO = 1.41421356237309504880168872420969807856967187537694807317667 + + +class MultipliableBNN(bnn.BNN): + """Abstract base class for BNN's that can be multiplied.""" + width: int = 50 + going_to_be_multiplied: bool = False + + def penultimate(self, inputs): + raise NotImplementedError('Subclasses of MultipliableBNN must define this.') + + +class IdentityBNN(MultipliableBNN): + """A BNN that always predicts 1.""" + + def penultimate(self, inputs): + return jnp.ones(shape=inputs.shape[:-1] + (self.width,)) + + def __call__(self, inputs, deterministic=True): + out_shape = inputs.shape[:-1] + (self.likelihood_model.num_outputs(),) + return jnp.ones(shape=out_shape) + + +class OneLayerBNN(MultipliableBNN): + """A BNN with one hidden layer.""" + + # Period is currently only used by the PeriodicBNN class, but we declare it + # here so it can be passed to a "generic" OneLayerBNN instance. + period: float = 0.0 + + bias_scale: float = 1.0 + + def setup(self): + if not hasattr(self, 'input_warping'): + self.input_warping = lambda x: x + if not hasattr(self, 'activation_function'): + self.activation_function = nn.relu + if not hasattr(self, 'kernel_init'): + self.kernel_init = initializers.lecun_normal() + if not hasattr(self, 'bias_init'): + self.bias_init = initializers.zeros_init() + self.dense1 = nn.Dense(self.width, + kernel_init=self.kernel_init, + bias_init=self.bias_init) + if not self.going_to_be_multiplied: + self.dense2 = nn.Dense( + self.likelihood_model.num_outputs(), + kernel_init=nn.initializers.normal(1. / jnp.sqrt(self.width)), + bias_init=nn.initializers.zeros) + else: + def fake_dense2(x): + out_shape = x.shape[:-1] + (self.likelihood_model.num_outputs(),) + return jnp.ones(out_shape) + self.dense2 = fake_dense2 + super().setup() + + def distributions(self): + # Strictly speaking, these distributions don't exactly correspond to + # the initializations used in setup(). lecun_normal uses a truncated + # normal, for example, and the zeros_init used for the bias certainly + # isn't a sample from a normal. + d = { + 'dense1': { + 'kernel': normal_lib.Normal( + loc=0, scale=1.0 / jnp.sqrt(self.width) + ), + 'bias': normal_lib.Normal(loc=0, scale=self.bias_scale), + } + } + if not self.going_to_be_multiplied: + d['dense2'] = { + 'kernel': normal_lib.Normal(loc=0, scale=1.0 / jnp.sqrt(self.width)), + 'bias': normal_lib.Normal(loc=0, scale=self.bias_scale), + } + return super().distributions() | d + + def penultimate(self, inputs): + y = self.input_warping(inputs) + return self.activation_function(self.dense1(y)) + + def __call__(self, inputs, deterministic=True): + return self.dense2(self.penultimate(inputs)) + + +class ExponentiatedQuadraticBNN(OneLayerBNN): + """A BNN corresponding to the Radial Basis Function kernel.""" + amplitude_scale: float = 1.0 + length_scale_scale: float = 1.0 + + def setup(self): + if not hasattr(self, 'activation_function'): + self.activation_function = lambda x: SQRT_TWO * jnp.sin(x) + if not hasattr(self, 'input_warping'): + self.input_warping = lambda x: x / self.length_scale + self.kernel_init = nn.initializers.normal(1.0) + def uniform_init(seed, shape, dtype): + return nn.initializers.uniform(scale=2.0 * jnp.pi)( + seed, shape, dtype=dtype) - jnp.pi + self.bias_init = uniform_init + super().setup() + + def distributions(self): + d = super().distributions() + return d | { + 'amplitude': lognormal_lib.LogNormal(loc=0, scale=self.amplitude_scale), + 'length_scale': lognormal_lib.LogNormal( + loc=0, scale=self.length_scale_scale + ), + 'dense1': { + 'kernel': normal_lib.Normal(loc=0, scale=1.0), + 'bias': uniform_lib.Uniform(low=-jnp.pi, high=jnp.pi), + }, + } + + def __call__(self, inputs, deterministic=True): + return self.amplitude * self.dense2(self.penultimate(inputs)) + + def shortname(self) -> str: + sn = super().shortname() + return 'RBF' if sn == 'ExponentiatedQuadratic' else sn + + +class MaternBNN(ExponentiatedQuadraticBNN): + """A BNN corresponding to the Matern kernel.""" + degrees_of_freedom: float = 2.5 + + def setup(self): + def kernel_init(seed, shape, unused_dtype): + return student_t_lib.StudentT( + df=2.0 * self.degrees_of_freedom, loc=0.0, scale=1.0 + ).sample(shape, seed=seed) + self.kernel_init = kernel_init + super().setup() + + def summarize(self, params=None, full: bool = False) -> str: + """Return a string summarizing the structure of the BNN.""" + return f'{self.shortname()}({self.degrees_of_freedom})' + + +class PolynomialBNN(OneLayerBNN): + """A BNN where samples are polynomial functions.""" + degree: int = 2 + shift_mean: float = 0.0 + shift_scale: float = 1.0 + amplitude_scale: float = 1.0 + bias_init_amplitude: float = 0.0 + + def distributions(self): + d = super().distributions() + del d['dense1'] + for i in range(self.degree): + # Do not scale these layers by 1/sqrt(width), because we also + # multiply these weights by the learned `amplitude` parameter. + d[f'hiddens_{i}'] = { + 'kernel': normal_lib.Normal(loc=0, scale=1.0), + 'bias': normal_lib.Normal(loc=0, scale=self.bias_scale), + } + return d | { + 'shift': normal_lib.Normal(loc=self.shift_mean, scale=self.shift_scale), + 'amplitude': lognormal_lib.LogNormal(loc=0, scale=self.amplitude_scale), + } + + def setup(self): + kernel_init = nn.initializers.normal(1.0) + def bias_init(seed, shape, dtype=jnp.float32): + return self.bias_init_amplitude * jax.random.normal( + seed, shape, dtype=dtype) + self.hiddens = [ + nn.Dense(self.width, kernel_init=kernel_init, bias_init=bias_init) + for _ in range(self.degree)] + super().setup() + + def penultimate(self, inputs): + x = inputs - self.shift + ys = jnp.stack([h(x) for h in self.hiddens], axis=-1) + return self.amplitude * jnp.prod(ys, axis=-1) + + def summarize(self, params=None, full: bool = False) -> str: + """Return a string summarizing the structure of the BNN.""" + return f'{self.shortname()}(degree={self.degree})' + + +class LinearBNN(PolynomialBNN): + """A BNN where samples are lines.""" + degree: int = 1 + + def summarize(self, params=None, full: bool = False) -> str: + return self.shortname() + + +class QuadraticBNN(PolynomialBNN): + """A BNN where samples are parabolas.""" + + degree: int = 2 + + def summarize(self, params=None, full: bool = False) -> str: + return self.shortname() + + +def make_periodic_input_warping(period, periodic_index, include_original): + """Return an input warping function that adds Fourier features. + + Args: + period: The added features will repeat this many time steps. + periodic_index: Look for the time feature in input[..., periodic_index]. + include_original: If true, don't replace the time feature with the + new Fourier features. + + Returns: + A function that takes an input tensor of shape [..., n] and returns a + tensor of shape [..., n+2] if include_original is True and of shape + [..., n+1] if include_original is False. + """ + def input_warping(x): + time = x[..., periodic_index] + y = 2.0 * jnp.pi * time / period + features = [jnp.cos(y), jnp.sin(y)] + if include_original: + features.append(time) + if jnp.ndim(x) == 1: + features = jnp.array(features).T + else: + features = jnp.vstack(features).T + return jnp.concatenate( + [ + x[..., :periodic_index], + features, + x[..., periodic_index + 1:], + ], + -1, + ) + + return input_warping + + +class PeriodicBNN(ExponentiatedQuadraticBNN): + """A BNN corresponding to a periodic kernel.""" + periodic_index: int = 0 + + def setup(self): + # TODO(colcarroll): Figure out how to assert that self.period is positive. + + self.input_warping = make_periodic_input_warping( + self.period, self.periodic_index, include_original=False + ) + super().setup() + + def summarize(self, params=None, full: bool = False) -> str: + """Return a string summarizing the structure of the BNN.""" + return f'{self.shortname()}(period={self.period:.2f})' + + +class MultiLayerBNN(OneLayerBNN): + """Multi-layer BNN that also has access to periodic features.""" + num_layers: int = 3 + periodic_index: int = 0 + + def setup(self): + if not hasattr(self, 'kernel_init'): + self.kernel_init = initializers.lecun_normal() + if not hasattr(self, 'bias_init'): + self.bias_init = initializers.zeros_init() + self.input_warping = make_periodic_input_warping( + self.period, self.periodic_index, include_original=True + ) + self.dense = [ + nn.Dense( + self.width, kernel_init=self.kernel_init, bias_init=self.bias_init + ) + for _ in range(self.num_layers) + ] + super().setup() + + def distributions(self): + d = super().distributions() + del d['dense1'] + for i in range(self.num_layers): + d[f'dense_{i}'] = { + 'kernel': normal_lib.Normal(loc=0, scale=1.0 / jnp.sqrt(self.width)), + 'bias': normal_lib.Normal(loc=0, scale=self.bias_scale), + } + return d + + def penultimate(self, inputs): + y = self.input_warping(inputs) + for i in range(self.num_layers): + y = self.activation_function(self.dense[i](y)) + return y + + def summarize(self, params=None, full: bool = False) -> str: + """Return a string summarizing the structure of the BNN.""" + return ( + f'{self.shortname()}(num_layers={self.num_layers},period={self.period})' + ) diff --git a/tensorflow_probability/python/experimental/autobnn/kernels_test.py b/tensorflow_probability/python/experimental/autobnn/kernels_test.py new file mode 100644 index 0000000000..67e574d517 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/kernels_test.py @@ -0,0 +1,242 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for kernels.py.""" + +from absl.testing import parameterized +import jax +import jax.numpy as jnp +import numpy as np +from tensorflow_probability.python.experimental.autobnn import kernels +from tensorflow_probability.python.experimental.autobnn import util +from tensorflow_probability.substrates.jax.distributions import lognormal as lognormal_lib + +from absl.testing import absltest + + +KERNELS = [ + kernels.IdentityBNN, + kernels.OneLayerBNN, + kernels.ExponentiatedQuadraticBNN, + kernels.MaternBNN, + kernels.PeriodicBNN, + kernels.PolynomialBNN, + kernels.LinearBNN, + kernels.MultiLayerBNN, +] + + +class ReproduceExperimentTest(absltest.TestCase): + + def get_bnn_and_params(self): + x_train, y_train = util.load_fake_dataset() + linear_bnn = kernels.OneLayerBNN(width=50) + seed = jax.random.PRNGKey(0) + init_params = linear_bnn.init(seed, x_train) + constant_params = jax.tree_map( + lambda x: jnp.full(x.shape, 0.1), init_params) + constant_params['params']['noise_scale'] = jnp.array([0.005 ** 0.5]) + return linear_bnn, constant_params, x_train, y_train + + # This now uses a Logistic noise model, not Normal as in Pearce + @absltest.expectedFailure + def test_log_prior_matches(self): + # Pearce has a set `noise_scale` of 0.005 ** 0.5 that we must account for. + linear_bnn, constant_params, _, _ = self.get_bnn_and_params() + diff = lognormal_lib.LogNormal( + linear_bnn.noise_min, linear_bnn.log_noise_scale + ).log_prob(0.005**0.5) + self.assertAlmostEqual( + linear_bnn.log_prior(constant_params) - diff, + 31.59, # Hardcoded from reference implementation. + places=2) + + def test_log_likelihood_matches(self): + linear_bnn, constant_params, x_train, y_train = self.get_bnn_and_params() + self.assertAlmostEqual( + linear_bnn.log_likelihood(constant_params, x_train, y_train), + -7808.4434, + places=2) + + # This now uses a Logistic noise model, not Normal as in Pearce + @absltest.expectedFailure + def test_log_prob_matches(self): + # Pearce has a set `noise_scale` of 0.005 ** 0.5 that we must account for. + linear_bnn, constant_params, x_train, y_train = self.get_bnn_and_params() + diff = lognormal_lib.LogNormal( + linear_bnn.noise_min, linear_bnn.log_noise_scale + ).log_prob(0.005**0.5) + self.assertAlmostEqual( + linear_bnn.log_prob(constant_params, x_train, y_train) - diff, + -14505.76, # Hardcoded from reference implementation. + places=2) + + +class KernelsTest(parameterized.TestCase): + + @parameterized.product( + shape=[(5,), (5, 1), (5, 5)], + kernel=KERNELS, + ) + def test_default_kernels(self, shape, kernel): + if kernel in [kernels.PeriodicBNN, kernels.MultiLayerBNN]: + bnn = kernel(period=0.1, periodic_index=shape[-1]//2) + else: + bnn = kernel() + if isinstance(bnn, kernels.PolynomialBNN): + self.assertIn('shift', bnn.distributions()) + elif isinstance(bnn, kernels.MultiLayerBNN): + self.assertIn('dense_1', bnn.distributions()) + elif isinstance(bnn, kernels.IdentityBNN): + pass + else: + self.assertIn('dense1', bnn.distributions()) + if not isinstance(bnn, kernels.IdentityBNN): + self.assertIn('dense2', bnn.distributions()) + params = bnn.init(jax.random.PRNGKey(0), jnp.zeros(shape)) + lprior = bnn.log_prior(params) + params2 = params + if 'params' in params2: + params2 = params2['params'] + params2['noise_scale'] = params2['noise_scale'] + 100.0 + lprior2 = bnn.log_prior(params2) + self.assertLess(lprior2, lprior) + output = bnn.apply(params, jnp.ones(shape)) + self.assertEqual(shape[:-1] + (1,), output.shape) + + @parameterized.parameters(KERNELS) + def test_likelihood(self, kernel): + if kernel in [kernels.PeriodicBNN, kernels.MultiLayerBNN]: + bnn = kernel(period=0.1) + else: + bnn = kernel() + params = bnn.init(jax.random.PRNGKey(1), jnp.zeros(1)) + data = jnp.array([[0], [1], [2], [3], [4], [5]], dtype=jnp.float32) + obs = jnp.array([1, 0, 1, 0, 1, 0], dtype=jnp.float32) + ll = bnn.log_likelihood(params, data, obs) + lp = bnn.log_prob(params, data, obs) + # We are mostly just testing that ll and lp are both float-ish numbers + # than can be compared. In general, there is no reason to expect that + # lp < ll because there is no reason to expect in general that the + # log_prior will be negative. + if kernel == kernels.MultiLayerBNN: + self.assertLess(ll, lp) + else: + self.assertLess(lp, ll) + + @parameterized.parameters( + (kernels.OneLayerBNN(width=10), 'OneLayer'), + (kernels.ExponentiatedQuadraticBNN(width=5), 'RBF'), + (kernels.MaternBNN(width=5), 'Matern(2.5)'), + (kernels.PeriodicBNN(period=10, width=10), 'Periodic(period=10.00)'), + (kernels.PolynomialBNN(degree=3, width=2), 'Polynomial(degree=3)'), + (kernels.LinearBNN(width=5), 'Linear'), + (kernels.QuadraticBNN(width=5), 'Quadratic'), + ( + kernels.MultiLayerBNN(width=10, num_layers=3, period=20), + 'MultiLayer(num_layers=3,period=20)', + ), + ) + def test_summarize(self, bnn, expected): + self.assertEqual(expected, bnn.summarize()) + + @parameterized.parameters(KERNELS) + def test_penultimate(self, kernel): + if kernel in [kernels.PeriodicBNN, kernels.MultiLayerBNN]: + bnn = kernel(period=0.1, going_to_be_multiplied=True) + else: + bnn = kernel(going_to_be_multiplied=True) + self.assertNotIn('dense2', bnn.distributions()) + params = bnn.init(jax.random.PRNGKey(0), jnp.zeros(5)) + lprior = bnn.log_prior(params) + if kernel != kernels.MultiLayerBNN: + self.assertLess(lprior, 0.0) + h = bnn.apply(params, jnp.ones(5), method=bnn.penultimate) + self.assertEqual((50,), h.shape) + + def test_polynomial_is_almost_a_polynomial(self): + poly_bnn = kernels.PolynomialBNN(degree=3) + init_params = poly_bnn.init(jax.random.PRNGKey(0), jnp.ones((10, 1))) + + # compute power series + func = lambda x: poly_bnn.apply(init_params, x)[0] + params = [func(0.)] + for _ in range(4): + func = jax.grad(func) + params.append(func(0.)) + + # Last 4th degree coefficient should be around 0. + self.assertAlmostEqual(params[-1], 0.) + + # Check that the random initialization is approximately a polynomial by + # evaluating far away from the expansion. + x = 17.0 + self.assertAlmostEqual( + poly_bnn.apply(init_params, x)[0], + params[0] + x * params[1] + x**2 * params[2] / 2 + x**3 * params[3] / 6, + places=3) + + def test_make_periodic_input_warping_onedim(self): + iw = kernels.make_periodic_input_warping(4, 0, True) + np.testing.assert_allclose( + jnp.array([0, 1, 1, 2, 3, 4, 5]), + iw(jnp.array([1, 2, 3, 4, 5])), + atol=1e-6 + ) + iw = kernels.make_periodic_input_warping(4, 0, False) + np.testing.assert_allclose( + jnp.array([0, 1, 2, 3, 4, 5]), + iw(jnp.array([1, 2, 3, 4, 5])), + atol=1e-6 + ) + + def test_make_periodic_input_warping_onedim_features(self): + iw = kernels.make_periodic_input_warping(4, 0, True) + np.testing.assert_allclose( + jnp.array([[1, 0, 0], [0, 1, 1], [-1, 0, 2], [0, -1, 3], [1, 0, 4]]), + iw(jnp.array([[0], [1], [2], [3], [4]])), + atol=1e-6 + ) + iw = kernels.make_periodic_input_warping(4, 0, False) + np.testing.assert_allclose( + jnp.array([[1, 0], [0, 1], [-1, 0], [0, -1], [1, 0]]), + iw(jnp.array([[0], [1], [2], [3], [4]])), + atol=1e-6 + ) + + def test_make_periodic_input_warping_twodim(self): + iw = kernels.make_periodic_input_warping(2, 0, True) + np.testing.assert_allclose( + jnp.array([[1, 0, 0, 0], [-1, 0, 1, 1], [1, 0, 2, 4], [-1, 0, 3, 9], + [1, 0, 4, 16]]), + iw(jnp.array([[0, 0], [1, 1], [2, 4], [3, 9], [4, 16]])), + atol=1e-6 + ) + iw = kernels.make_periodic_input_warping(4, 1, True) + np.testing.assert_allclose( + jnp.array([[0, 1, 0, 0], [1, 0, 1, 1], [2, 1, 0, 4], [3, 0, 1, 9], + [4, 1, 0, 16]]), + iw(jnp.array([[0, 0], [1, 1], [2, 4], [3, 9], [4, 16]])), + atol=1e-6 + ) + iw = kernels.make_periodic_input_warping(2, 0, False) + np.testing.assert_allclose( + jnp.array([[1, 0, 0], [-1, 0, 1], [1, 0, 4], [-1, 0, 9], [1, 0, 16]]), + iw(jnp.array([[0, 0], [1, 1], [2, 4], [3, 9], [4, 16]])), + atol=1e-6 + ) + + +if __name__ == '__main__': + absltest.main() diff --git a/tensorflow_probability/python/experimental/autobnn/likelihoods.py b/tensorflow_probability/python/experimental/autobnn/likelihoods.py new file mode 100644 index 0000000000..384d9c6735 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/likelihoods.py @@ -0,0 +1,173 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Likelihood models for Bayesian Neural Networks.""" + +import dataclasses +from typing import Any +import jax +from tensorflow_probability.substrates.jax.bijectors import softplus as softplus_lib +from tensorflow_probability.substrates.jax.distributions import distribution as distribution_lib +from tensorflow_probability.substrates.jax.distributions import inflated as inflated_lib +from tensorflow_probability.substrates.jax.distributions import logistic as logistic_lib +from tensorflow_probability.substrates.jax.distributions import lognormal as lognormal_lib +from tensorflow_probability.substrates.jax.distributions import negative_binomial as negative_binomial_lib +from tensorflow_probability.substrates.jax.distributions import normal as normal_lib +from tensorflow_probability.substrates.jax.distributions import transformed_distribution as transformed_distribution_lib + + +@dataclasses.dataclass +class LikelihoodModel: + """A class that knows how to compute the likelihood of some data.""" + + def dist(self, params, nn_out) -> distribution_lib.Distribution: + """Return the distribution underlying the likelihood.""" + raise NotImplementedError() + + def sample(self, params, nn_out, seed, sample_shape=None) -> jax.Array: + """Sample from the likelihood.""" + return self.dist(params, nn_out).sample( + seed=seed, sample_shape=sample_shape + ) + + def num_outputs(self): + """The number of outputs from the neural network the model needs.""" + return 1 + + def distributions(self): + """Like BayesianModule::distributions but for the model's parameters.""" + return {} + + def log_likelihood( + self, params, nn_out: jax.Array, observations: jax.Array + ) -> jax.Array: + return self.dist(params, nn_out).log_prob(observations) + + +@dataclasses.dataclass +class DummyLikelihoodModel(LikelihoodModel): + """A likelihood model that only knows how many outputs it has.""" + num_outs: int + + def num_outputs(self): + return self.num_outs + + +class NormalLikelihoodFixedNoise(LikelihoodModel): + """Abstract base class for observations = N(nn_out, noise_scale).""" + + def dist(self, params, nn_out): + return normal_lib.Normal(loc=nn_out, scale=params['noise_scale']) + + +@dataclasses.dataclass +class NormalLikelihoodLogisticNoise(NormalLikelihoodFixedNoise): + noise_min: float = 0.0 + log_noise_scale: float = 1.0 + + def distributions(self): + noise_scale = transformed_distribution_lib.TransformedDistribution( + logistic_lib.Logistic(0.0, self.log_noise_scale), + softplus_lib.Softplus(low=self.noise_min), + ) + return {'noise_scale': noise_scale} + + +@dataclasses.dataclass +class BoundedNormalLikelihoodLogisticNoise(NormalLikelihoodLogisticNoise): + lower_bound: float = 0.0 + + def dist(self, params, nn_out): + return softplus_lib.Softplus(low=self.lower_bound)( + normal_lib.Normal(loc=nn_out, scale=params['noise_scale']) + ) + + +@dataclasses.dataclass +class NormalLikelihoodLogNormalNoise(NormalLikelihoodFixedNoise): + log_noise_mean: float = -2.0 + log_noise_scale: float = 1.0 + + def distributions(self): + return { + 'noise_scale': lognormal_lib.LogNormal( + loc=self.log_noise_mean, scale=self.log_noise_scale + ) + } + + +class NormalLikelihoodVaryingNoise(LikelihoodModel): + + def num_outputs(self): + return 2 + + def dist(self, params, nn_out): + # TODO(colcarroll): Add a prior to constrain the scale (`nn_out[..., [1]]`) + # separately before it goes into the likelihood. + return normal_lib.Normal( + loc=nn_out[..., [0]], scale=jax.nn.softplus(nn_out[..., [1]]) + ) + + +class NegativeBinomial(LikelihoodModel): + """observations = NB(total_count = nn_out[0], logits = nn_out[1]).""" + + def num_outputs(self): + return 2 + + def dist(self, params, nn_out): + return negative_binomial_lib.NegativeBinomial( + total_count=nn_out[..., [0]], + logits=nn_out[..., [1]], + require_integer_total_count=False, + ) + + +class ZeroInflatedNegativeBinomial(LikelihoodModel): + """observations = NB(total_count = nn_out[0], logits = nn_out[1]).""" + + def num_outputs(self): + return 3 + + def dist(self, params, nn_out): + return inflated_lib.ZeroInflatedNegativeBinomial( + total_count=nn_out[..., [0]], + logits=nn_out[..., [1]], + inflated_loc_logits=nn_out[..., [2]], + require_integer_total_count=False, + ) + + +NAME_TO_LIKELIHOOD_MODEL = { + 'normal_likelihood_logistic_noise': NormalLikelihoodLogisticNoise, + 'bounded_normal_likelihood_logistic_noise': ( + BoundedNormalLikelihoodLogisticNoise + ), + 'normal_likelihood_lognormal_noise': NormalLikelihoodLogNormalNoise, + 'normal_likelihood_varying_noise': NormalLikelihoodVaryingNoise, + 'negative_binomial': NegativeBinomial, + 'zero_inflated_negative_binomial': ZeroInflatedNegativeBinomial, +} + + +def get_likelihood_model( + likelihood_model: str, likelihood_parameters: dict[str, Any] +) -> Any: + # Actually returns a Likelihood model, but pytype thinks it returns a + # Union[NegativeBinomial, ...]. + m = NAME_TO_LIKELIHOOD_MODEL[likelihood_model]() + for k, v in likelihood_parameters.items(): + if hasattr(m, k): + setattr(m, k, v) + return m diff --git a/tensorflow_probability/python/experimental/autobnn/likelihoods_test.py b/tensorflow_probability/python/experimental/autobnn/likelihoods_test.py new file mode 100644 index 0000000000..4776f951a3 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/likelihoods_test.py @@ -0,0 +1,63 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for bnn.py.""" + +from absl.testing import parameterized +import jax.numpy as jnp +from tensorflow_probability.python.experimental.autobnn import likelihoods +from absl.testing import absltest + + +class LikelihoodTests(parameterized.TestCase): + + @parameterized.parameters( + likelihoods.NormalLikelihoodLogisticNoise(), + likelihoods.NormalLikelihoodLogNormalNoise(), + likelihoods.NormalLikelihoodVaryingNoise(), + likelihoods.NegativeBinomial(), + likelihoods.ZeroInflatedNegativeBinomial(), + ) + def test_likelihoods(self, likelihood_model): + lp = likelihood_model.log_likelihood( + params={'noise_scale': 0.4}, + nn_out=jnp.ones(shape=(10, likelihood_model.num_outputs())), + observations=jnp.zeros(shape=(10, 1)), + ) + self.assertEqual(lp.shape, (10, 1)) + + @parameterized.parameters(list(likelihoods.NAME_TO_LIKELIHOOD_MODEL.keys())) + def test_get_likelihood_model(self, likelihood_model): + m = likelihoods.get_likelihood_model(likelihood_model, {}) + lp = m.log_likelihood( + params={'noise_scale': 0.4}, + nn_out=jnp.ones(shape=(10, m.num_outputs())), + observations=jnp.zeros(shape=(10, 1)), + ) + self.assertEqual(lp.shape, (10, 1)) + + m2 = likelihoods.get_likelihood_model( + likelihood_model, + {'noise_min': 0.1, 'log_noise_scale': 0.5, 'log_noise_mean': -1.0}, + ) + lp2 = m2.log_likelihood( + params={'noise_scale': 0.4}, + nn_out=jnp.ones(shape=(10, m2.num_outputs())), + observations=jnp.zeros(shape=(10, 1)), + ) + self.assertEqual(lp2.shape, (10, 1)) + + +if __name__ == '__main__': + absltest.main() diff --git a/tensorflow_probability/python/experimental/autobnn/models.py b/tensorflow_probability/python/experimental/autobnn/models.py new file mode 100644 index 0000000000..63c7165988 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/models.py @@ -0,0 +1,309 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""BNF models. + +The "combo" model is a simple sum of linear and periodic components. The sum of +products is the smallest example of a sum of two products over two leaves each, +where each leaf is a continuous relaxiation (using WeightedSum) of periodic and +linear components. +""" +import functools +from typing import Sequence +import jax.numpy as jnp +from tensorflow_probability.python.experimental.autobnn import bnn +from tensorflow_probability.python.experimental.autobnn import bnn_tree +from tensorflow_probability.python.experimental.autobnn import kernels +from tensorflow_probability.python.experimental.autobnn import likelihoods +from tensorflow_probability.python.experimental.autobnn import operators + + +Array = jnp.ndarray + + +def make_sum_of_operators_of_relaxed_leaves( + time_series_xs: Array, + width: int = 5, + periods: Sequence[float] = (0.1,), + use_mul: bool = True, + num_outputs: int = 1, +) -> bnn.BNN: + """Returns BNN model consisting of a sum of products or changeponts of leaves. + + Each leaf is a continuous relaxation over base kernels. + + Args: + time_series_xs: The x-values of the training data. + width: Width of the leaf BNNs. + periods: Periods for the PeriodicBNN kernel. + use_mul: If true, use Multiply as the depth 1 operator. If false, use + a LearnableChangepoint instead. + num_outputs: Number of outputs on the BNN. + """ + del num_outputs + def _make_continuous_relaxation( + width: int, + periods: Sequence[float], + include_eq_and_poly: bool) -> bnn.BNN: + leaves = [kernels.PeriodicBNN( + width=width, period=p, going_to_be_multiplied=use_mul) for p in periods] + leaves.append(kernels.LinearBNN( + width=width, going_to_be_multiplied=use_mul)) + if include_eq_and_poly: + leaves.extend([ + kernels.ExponentiatedQuadraticBNN( + width=width, going_to_be_multiplied=use_mul), + kernels.PolynomialBNN(width=width, going_to_be_multiplied=use_mul), + kernels.IdentityBNN(width=width, going_to_be_multiplied=use_mul), + ]) + return operators.WeightedSum(bnns=tuple(leaves), num_outputs=1, + going_to_be_multiplied=use_mul) + + leaf1 = _make_continuous_relaxation(width, periods, include_eq_and_poly=False) + leaf2 = _make_continuous_relaxation(width, periods, include_eq_and_poly=False) + + if use_mul: + op = operators.Multiply + else: + op = functools.partial(operators.LearnableChangePoint, + time_series_xs=time_series_xs) + + bnn1 = op(bnns=(leaf1, leaf2)) + + leaf3 = _make_continuous_relaxation(width, periods, include_eq_and_poly=True) + leaf4 = _make_continuous_relaxation(width, periods, include_eq_and_poly=True) + bnn2 = op(bnns=(leaf3, leaf4)) + + net = operators.Add(bnns=(bnn1, bnn2)) + return net + + +def make_sum_of_products( + time_series_xs: Array, + width: int = 5, + periods: Sequence[float] = (0.1,), + num_outputs: int = 1, +) -> bnn.BNN: + return make_sum_of_operators_of_relaxed_leaves( + time_series_xs, width, periods, use_mul=True, num_outputs=num_outputs) + + +def make_sum_of_changepoints( + time_series_xs: Array, + width: int = 5, + periods: Sequence[float] = (0.1,), + num_outputs: int = 1, +) -> bnn.BNN: + return make_sum_of_operators_of_relaxed_leaves( + time_series_xs, width, periods, use_mul=False, num_outputs=num_outputs) + + +def make_linear_plus_periodic( + time_series_xs: Array, + width: int = 5, + periods: Sequence[float] = (0.1,), + num_outputs: int = 1, +) -> bnn.BNN: + """Returns Combo model, consisting of linear and periodic leafs. + + Args: + time_series_xs: The x-values of the training data. + width: Width of the leaf BNNs. + periods: Periods for the PeriodicBNN kernel. + num_outputs: Number of outputs on the BNN. + """ + del num_outputs + del time_series_xs + leaves = [kernels.PeriodicBNN(width=width, period=p) for p in periods] + leaves.append(kernels.LinearBNN(width=width)) + return operators.Add(bnns=tuple(leaves)) + + +def make_sum_of_stumps( + time_series_xs: Array, + width: int = 5, + periods: Sequence[float] = (0.1,), + num_outputs: int = 1, +) -> bnn.BNN: + """Return a sum of depth 0 trees.""" + stumps = bnn_tree.list_of_all(time_series_xs, 0, width, periods=periods) + + return operators.WeightedSum(bnns=tuple(stumps), num_outputs=num_outputs) + + +def make_sum_of_stumps_and_products( + time_series_xs: Array, + width: int = 5, + periods: Sequence[float] = (0.1,), + num_outputs: int = 1, +) -> bnn.BNN: + """Return a sum of depth 0 and depth 1 product-only trees.""" + stumps = bnn_tree.list_of_all(time_series_xs, 0, width, periods=periods) + products = bnn_tree.list_of_all( + time_series_xs, + 1, + width, + periods=periods, + include_sums=False, + include_changepoints=False, + ) + + return operators.WeightedSum( + bnns=tuple(stumps + products), num_outputs=num_outputs) + + +def make_sum_of_shallow( + time_series_xs: Array, + width: int = 5, + periods: Sequence[float] = (0.1,), + num_outputs: int = 1, +) -> bnn.BNN: + """Return a sum of depth 0 and 1 trees.""" + stumps = bnn_tree.list_of_all(time_series_xs, 0, width, periods=periods) + depth1 = bnn_tree.list_of_all( + time_series_xs, 1, width, periods=periods, include_sums=False + ) + + return operators.WeightedSum( + bnns=tuple(stumps + depth1), num_outputs=num_outputs) + + +def make_sum_of_safe_shallow( + time_series_xs: Array, + width: int = 5, + periods: Sequence[float] = (0.1,), + num_outputs: int = 1, +) -> bnn.BNN: + """Return a sum of depth 0 and 1 trees, but not unsafe products.""" + stumps = bnn_tree.list_of_all(time_series_xs, 0, width, periods=periods) + depth1 = bnn_tree.list_of_all( + time_series_xs, + 1, + width, + periods=periods, + include_sums=False, + only_safe_products=True, + ) + + return operators.WeightedSum( + bnns=tuple(stumps + depth1), num_outputs=num_outputs) + + +def make_changepoint_of_safe_products( + time_series_xs: Array, + width: int = 5, + periods: Sequence[float] = (0.1,), + num_outputs: int = 1, +) -> bnn.BNN: + """Return a changepoint over two Multiply(Linear, WeightedSum(kernels))'s.""" + # By varying the weights inside the WeightedSum (and by relying on the + # identity Changepoint(A, A) = A), this model can express + # * all base kernels, + # * all "safe" multiplies over two base kernels (i.e., one of the terms + # has a very low effective parameter count to avoid overfitting noise), and + # * all single changepoints over two of the above. + + all_kernels = [ + kernels.PeriodicBNN(width=width, period=p, going_to_be_multiplied=True) + for p in periods + ] + all_kernels.extend( + [ + k(width=width, going_to_be_multiplied=True) + for k in [ + kernels.ExponentiatedQuadraticBNN, + kernels.MaternBNN, + kernels.LinearBNN, + kernels.QuadraticBNN, + ] + ] + ) + + safe_product = operators.Multiply( + bnns=( + operators.WeightedSum( + num_outputs=num_outputs, + bnns=( + kernels.IdentityBNN(width=width, going_to_be_multiplied=True), + kernels.LinearBNN(width=width, going_to_be_multiplied=True), + kernels.QuadraticBNN( + width=width, going_to_be_multiplied=True), + ), + going_to_be_multiplied=True + ), + operators.WeightedSum(bnns=tuple(all_kernels), + going_to_be_multiplied=True, + num_outputs=num_outputs), + ), + ) + + return operators.LearnableChangePoint( + time_series_xs=time_series_xs, + bnns=(safe_product, safe_product.clone(_deep_clone=True)), + ) + + +def make_mlp(num_layers: int): + """Return a make function for the MultiLayerBNN of the given depth.""" + + def make_multilayer( + time_series_xs: Array, + width: int = 5, + periods: Sequence[float] = (0.1,), + num_outputs: int = 1, + ): + del num_outputs + del time_series_xs + assert len(periods) == 1 + return kernels.MultiLayerBNN( + num_layers=num_layers, + width=width, + period=periods[0], + ) + + return make_multilayer + + +MODEL_NAME_TO_MAKE_FUNCTION = { + 'sum_of_products': make_sum_of_products, + 'sum_of_changepoints': make_sum_of_changepoints, + 'linear_plus_periodic': make_linear_plus_periodic, + 'sum_of_stumps': make_sum_of_stumps, + 'sum_of_stumps_and_products': make_sum_of_stumps_and_products, + 'sum_of_shallow': make_sum_of_shallow, + 'sum_of_safe_shallow': make_sum_of_safe_shallow, + 'changepoint_of_safe_products': make_changepoint_of_safe_products, + 'mlp_depth2': make_mlp(2), + 'mlp_depth3': make_mlp(3), + 'mlp_depth4': make_mlp(4), + 'mlp_depth5': make_mlp(5), +} + + +def make_model( + model_name: str, + likelihood_model: likelihoods.LikelihoodModel, + time_series_xs: Array, + width: int = 5, + periods: Sequence[float] = (0.1,), +) -> bnn.BNN: + """Create a BNN model by name.""" + m = MODEL_NAME_TO_MAKE_FUNCTION[model_name]( + time_series_xs=time_series_xs, + width=width, + periods=periods, + num_outputs=likelihood_model.num_outputs(), + ) + m.set_likelihood_model(likelihood_model) + return m diff --git a/tensorflow_probability/python/experimental/autobnn/models_test.py b/tensorflow_probability/python/experimental/autobnn/models_test.py new file mode 100644 index 0000000000..ed9d44ca6d --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/models_test.py @@ -0,0 +1,67 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for models.py.""" + +from absl.testing import parameterized +import jax +import jax.numpy as jnp +from tensorflow_probability.python.experimental.autobnn import likelihoods +from tensorflow_probability.python.experimental.autobnn import models +from absl.testing import absltest + + +MODELS = list(models.MODEL_NAME_TO_MAKE_FUNCTION.keys()) + + +class ModelsTest(parameterized.TestCase): + + @parameterized.parameters(MODELS) + def test_make_model(self, model_name): + m = models.make_model( + model_name, + likelihoods.NormalLikelihoodLogisticNoise(), + time_series_xs=jnp.linspace(0.0, 1.0, 50), + width=5, + periods=[0.2], + ) + params = m.init(jax.random.PRNGKey(0), jnp.zeros(5)) + lp = m.log_prior(params) + self.assertTrue((lp < 0.0) or (lp > 0.0)) + + @parameterized.product( + model_name=MODELS, + # It takes too long to test all of the likelihoods, so just test a + # couple to make sure each model correctly handles num_outputs > 1. + likelihood_name=[ + 'normal_likelihood_varying_noise', + 'zero_inflated_negative_binomial', + ], + ) + def test_make_model_and_likelihood(self, model_name, likelihood_name): + ll = likelihoods.get_likelihood_model(likelihood_name, {}) + m = models.make_model( + model_name, + ll, + time_series_xs=jnp.linspace(0.0, 1.0, 50), + width=5, + periods=[0.2], + ) + params = m.init(jax.random.PRNGKey(0), jnp.zeros(5)) + lp = m.log_prior(params) + self.assertTrue((lp < 0.0) or (lp > 0.0)) + + +if __name__ == '__main__': + absltest.main() diff --git a/tensorflow_probability/python/experimental/autobnn/operators.py b/tensorflow_probability/python/experimental/autobnn/operators.py new file mode 100644 index 0000000000..6772ab56d7 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/operators.py @@ -0,0 +1,292 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Flax.linen modules for combining BNNs.""" + +from typing import Optional +from flax import linen as nn +import jax.numpy as jnp +from tensorflow_probability.python.experimental.autobnn import bnn +from tensorflow_probability.python.experimental.autobnn import likelihoods +from tensorflow_probability.substrates.jax.bijectors import chain as chain_lib +from tensorflow_probability.substrates.jax.bijectors import scale as scale_lib +from tensorflow_probability.substrates.jax.bijectors import shift as shift_lib +from tensorflow_probability.substrates.jax.distributions import beta as beta_lib +from tensorflow_probability.substrates.jax.distributions import dirichlet as dirichlet_lib +from tensorflow_probability.substrates.jax.distributions import half_normal as half_normal_lib +from tensorflow_probability.substrates.jax.distributions import normal as normal_lib +from tensorflow_probability.substrates.jax.distributions import transformed_distribution as transformed_distribution_lib + + +Array = jnp.ndarray + + +class BnnOperator(bnn.BNN): + """Base class for BNNs that are made from other BNNs.""" + bnns: tuple[bnn.BNN, ...] = tuple() + + def setup(self): + assert self.bnns, 'Forgot to pass `bnns` keyword argument?' + super().setup() + + def set_likelihood_model(self, likelihood_model: likelihoods.LikelihoodModel): + super().set_likelihood_model(likelihood_model) + # We need to set the likelihood models on the component + # bnns so that they will know how many outputs they are + # supposed to have. BUT: we also don't want to accidentally + # create any additional variables, distributions or parameters + # in them. So we set them all to having a dummy likelihood + # model that only knows how many outputs it has. + dummy_ll_model = likelihoods.DummyLikelihoodModel( + num_outs=likelihood_model.num_outputs() + ) + for b in self.bnns: + b.set_likelihood_model(dummy_ll_model) + + def log_prior(self, params): + if 'params' in params: + params = params['params'] + # params for bnns[i] are stored in params['bnns_{i}']. + lp = bnn.log_prior_of_parameters(params, self.distributions()) + for i, b in enumerate(self.bnns): + params_field = f'bnns_{i}' + if params_field in params: + lp += b.log_prior(params[params_field]) + return lp + + def get_all_distributions(self): + distributions = self.distributions() + for idx, sub_bnn in enumerate(self.bnns): + d = sub_bnn.get_all_distributions() + if d: + distributions[f'bnns_{idx}'] = d + return distributions + + def summary_join_string(self, params) -> str: + """String to use when joining the component summaries.""" + raise NotImplementedError() + + def summarize(self, params=None, full: bool = False) -> str: + """Return a string summarizing the structure of the BNN.""" + params = params or {} + if 'params' in params: + params = params['params'] + + names = [ + b.summarize(params.get(f'bnns_{i}'), full) + for i, b in enumerate(self.bnns) + ] + + return f'({self.summary_join_string(params).join(names)})' + + +class MultipliableBnnOperator(BnnOperator): + """Abstract base class for a BnnOperator that can be multiplied.""" + # Ideally, this would just inherit from both BnnOperator and + # kernels.MultipliableBNN, but pytype gets really confused by that. + going_to_be_multiplied: bool = False + + def setup(self): + if self.going_to_be_multiplied: + for b in self.bnns: + assert b.going_to_be_multiplied + else: + for b in self.bnns: + assert not getattr(b, 'going_to_be_multiplied', False) + super().setup() + + def penultimate(self, inputs): + raise NotImplementedError( + 'Subclasses of MultipliableBnnOperator must define this.') + + +class Add(MultipliableBnnOperator): + """Add two or more BNNs.""" + + def penultimate(self, inputs): + penultimates = [b.penultimate(inputs) for b in self.bnns] + return jnp.sum(jnp.stack(penultimates, axis=-1), axis=-1) + + def __call__(self, inputs, deterministic=True): + return jnp.sum( + jnp.stack([b(inputs) for b in self.bnns], axis=-1), + axis=-1) + + def summary_join_string(self, params) -> str: + return '#' + + +class WeightedSum(MultipliableBnnOperator): + """Add two or more BNNs, with weights taken from a Dirichlet prior.""" + + # `alpha=1` is a uniform prior on mixing weights, higher values will favor + # weights like `1/n`, and lower weights will favor sparsity. + alpha: float = 1.0 + num_outputs: int = 1 + + def distributions(self): + bnn_concentrations = [1.0 if isinstance(b, BnnOperator) else 1.5 + for b in self.bnns] + if self.going_to_be_multiplied: + concentration = self.alpha * jnp.array(bnn_concentrations) + else: + concentration = self.alpha * jnp.array( + [bnn_concentrations for _ in range(self.num_outputs)]) + return super().distributions() | { + 'bnn_weights': dirichlet_lib.Dirichlet(concentration=concentration) + } + + def penultimate(self, inputs): + penultimates = [ + b.penultimate(inputs) * self.bnn_weights[0, i] + for i, b in enumerate(self.bnns) + ] + return jnp.sum(jnp.stack(penultimates, axis=-1), axis=-1) + + def __call__(self, inputs, deterministic=True): + return jnp.sum( + jnp.stack( + [ + b(inputs) * self.bnn_weights[0, :, i] + for i, b in enumerate(self.bnns) + ], + axis=-1, + ), + axis=-1, + ) + + def summarize(self, params=None, full: bool = False) -> str: + """Return a string summarizing the structure of the BNN.""" + params = params or {} + if 'params' in params: + params = params['params'] + + names = [ + b.summarize(params.get(f'bnns_{i}'), full) + for i, b in enumerate(self.bnns) + ] + + def pretty_print(w): + try: + s = f'{jnp.array_str(jnp.array(w), precision=3)}' + except Exception: # pylint: disable=broad-exception-caught + try: + s = f'{w:.3f}' + except Exception: # pylint: disable=broad-exception-caught + s = f'{w}' + return s.replace('\n', ' ') + + weights = params.get('bnn_weights') + if weights is not None: + weights = jnp.array(weights)[0].T.squeeze() + names = [ + f'{pretty_print(w)} {n}' + for w, n in zip(weights, names) + if full or jnp.max(w) > 0.04 + ] + + return f'({"+".join(names)})' + + +class Multiply(BnnOperator): + """Multiply two or more BNNs.""" + + def setup(self): + self.dense = nn.Dense(self.likelihood_model.num_outputs()) + for b in self.bnns: + assert hasattr(b, 'penultimate') + assert b.going_to_be_multiplied, 'Forgot to set going_to_be_multiplied?' + super().setup() + + def distributions(self): + return super().distributions() | { + 'dense': { + 'kernel': normal_lib.Normal(loc=0, scale=1.0), + 'bias': normal_lib.Normal(loc=0, scale=1.0), + } + } + + def __call__(self, inputs, deterministic=True): + penultimates = [b.penultimate(inputs) for b in self.bnns] + return self.dense(jnp.prod(jnp.stack(penultimates, axis=-1), axis=-1)) + + def summary_join_string(self, params) -> str: + return '*' + + +class ChangePoint(BnnOperator): + """Switch from one BNN to another based on a time point.""" + change_point: float = 0.0 + slope: float = 1.0 + change_index: int = 0 + + def setup(self): + assert len(self.bnns) == 2 + super().setup() + + def __call__(self, inputs, deterministic=True): + time = inputs[..., self.change_index, jnp.newaxis] + y = (time - self.change_point) / self.slope + return nn.sigmoid(y) * self.bnns[1](inputs) + nn.sigmoid( + -y) * self.bnns[0](inputs) + + def summary_join_string(self, params) -> str: + return f'<[{self.change_point}]' + + +class LearnableChangePoint(BnnOperator): + """Switch from one BNN to another based on a time point.""" + time_series_xs: Optional[Array] = None + change_index: int = 0 + + def distributions(self): + assert self.time_series_xs is not None + lo = jnp.min(self.time_series_xs) + hi = jnp.max(self.time_series_xs) + # We want change_slope_scale to be the average value of + # time_series_xs[i+1] - time_series_xs[i] + change_slope_scale = (hi - lo) / self.time_series_xs.size + + # this distribution puts a lower density at the endpoints, and a reasonably + # flat distribution near the middle of the timeseries. + bij = chain_lib.Chain([shift_lib.Shift(lo), scale_lib.Scale(hi - lo)]) + dist = transformed_distribution_lib.TransformedDistribution( + distribution=beta_lib.Beta(1.5, 1.5), bijector=bij + ) + return super().distributions() | { + 'change_point': dist, + 'change_slope': half_normal_lib.HalfNormal(scale=change_slope_scale), + } + + def setup(self): + assert len(self.bnns) == 2 + assert len(self.time_series_xs) >= 2 + super().setup() + + def __call__(self, inputs, deterministic=True): + time = inputs[..., self.change_index, jnp.newaxis] + y = (time - self.change_point) / self.change_slope + return nn.sigmoid(y) * self.bnns[1](inputs) + nn.sigmoid(-y) * self.bnns[0]( + inputs + ) + + def summary_join_string(self, params) -> str: + params = params or {} + if 'params' in params: + params = params['params'] + change_point = params.get('change_point') + cp_str = '' + if change_point is not None: + cp_str = f'[{jnp.array_str(change_point, precision=2)}]' + return f'<{cp_str}' diff --git a/tensorflow_probability/python/experimental/autobnn/operators_test.py b/tensorflow_probability/python/experimental/autobnn/operators_test.py new file mode 100644 index 0000000000..63f978f003 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/operators_test.py @@ -0,0 +1,218 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for operators.py.""" + +from absl.testing import parameterized +import jax +import jax.numpy as jnp +import numpy as np +from tensorflow_probability.python.experimental.autobnn import kernels +from tensorflow_probability.python.experimental.autobnn import operators +from tensorflow_probability.python.experimental.autobnn import util +from tensorflow_probability.substrates.jax.distributions import distribution as distribution_lib +from absl.testing import absltest + + +KERNELS = [ + operators.Add( + bnns=(kernels.OneLayerBNN(width=50), kernels.OneLayerBNN(width=50)) + ), + operators.Add( + bnns=(kernels.OneLayerBNN(width=50), kernels.OneLayerBNN(width=100)) + ), + operators.Add( + bnns=( + kernels.PeriodicBNN(width=50, period=0.1), + kernels.OneLayerBNN(width=50), + ) + ), + operators.WeightedSum( + bnns=(kernels.OneLayerBNN(width=50), kernels.OneLayerBNN(width=50)) + ), + operators.WeightedSum( + bnns=( + kernels.PeriodicBNN(width=50, period=0.1), + kernels.OneLayerBNN(width=50), + ), + alpha=2.0, + ), + operators.WeightedSum( + bnns=( + kernels.ExponentiatedQuadraticBNN(width=50), + kernels.ExponentiatedQuadraticBNN(width=50), + ) + ), + operators.Multiply( + bnns=( + kernels.OneLayerBNN(width=50, going_to_be_multiplied=True), + kernels.OneLayerBNN(width=50, going_to_be_multiplied=True), + ), + ), + operators.Multiply( + bnns=( + kernels.OneLayerBNN(width=50, going_to_be_multiplied=True), + kernels.OneLayerBNN(width=50, going_to_be_multiplied=True), + kernels.OneLayerBNN(width=50, going_to_be_multiplied=True), + ) + ), + operators.Multiply( + bnns=( + operators.Add( + bnns=( + kernels.OneLayerBNN(width=50, going_to_be_multiplied=True), + kernels.OneLayerBNN(width=50, going_to_be_multiplied=True), + ), + going_to_be_multiplied=True + ), + operators.Add( + bnns=( + kernels.OneLayerBNN(width=50, going_to_be_multiplied=True), + kernels.OneLayerBNN(width=50, going_to_be_multiplied=True), + ), + going_to_be_multiplied=True + ), + ) + ), + operators.ChangePoint( + bnns=(kernels.OneLayerBNN(width=50), kernels.OneLayerBNN(width=50)), + change_point=5.0, + slope=1.0, + ), + operators.LearnableChangePoint( + bnns=(kernels.OneLayerBNN(width=50), kernels.OneLayerBNN(width=50)), + time_series_xs=np.linspace(0., 5., 100), + ), +] + + +NAMES = [ + "(OneLayer#OneLayer)", + "(OneLayer#OneLayer)", + "(Periodic(period=0.10)#OneLayer)", + "(OneLayer+OneLayer)", + "(Periodic(period=0.10)+OneLayer)", + "(RBF+RBF)", + "(OneLayer*OneLayer)", + "(OneLayer*OneLayer*OneLayer)", + "((OneLayer#OneLayer)*(OneLayer#OneLayer))", + "(OneLayer<[5.0]OneLayer)", + "(OneLayer Tuple[Callable[..., Any], Callable[..., Any], Callable[..., Any]]: + """Returns unconstraining bijectors for all variables in the BNN.""" + jb = jax.tree_map( + lambda x: x.experimental_default_event_space_bijector(), + net.get_all_distributions(), + is_leaf=lambda x: isinstance(x, distribution_lib.Distribution), + ) + + def transform(params): + return {'params': jax.tree_map(lambda p, b: b(p), params['params'], jb)} + + def inverse_transform(params): + return { + 'params': jax.tree_map(lambda p, b: b.inverse(p), params['params'], jb) + } + + def inverse_log_det_jacobian(params): + return jax.tree_util.tree_reduce( + lambda a, b: a + b, + jax.tree_map( + lambda p, b: jnp.sum(b.inverse_log_det_jacobian(p)), + params['params'], + jb, + ), + initializer=0.0, + ) + + return transform, inverse_transform, inverse_log_det_jacobian + + +def suggest_periods(ys) -> List[float]: + """Suggest a few periods for the time series.""" + f, pxx = scipy.signal.periodogram(ys) + + top5_powers, top5_indices = jax.lax.top_k(pxx, 5) + top5_power = jnp.sum(top5_powers) + best_indices = [i for i in top5_indices if pxx[i] > 0.05 * top5_power] + # Sort in descending order so the best periods are first. + best_indices.sort(reverse=True, key=lambda i: pxx[i]) + return [1.0 / f[i] for i in best_indices if 1.0 / f[i] < 0.6 * len(ys)] + + +def load_fake_dataset(): + """Return some fake data for testing purposes.""" + x_train = jnp.arange(0.0, 120.0) / 120.0 + y_train = x_train + jnp.sin(x_train * 10.0) + x_train * x_train + x_train = x_train[..., jnp.newaxis] + return x_train, y_train[..., jnp.newaxis] diff --git a/tensorflow_probability/python/experimental/autobnn/util_test.py b/tensorflow_probability/python/experimental/autobnn/util_test.py new file mode 100644 index 0000000000..dbaa3866c9 --- /dev/null +++ b/tensorflow_probability/python/experimental/autobnn/util_test.py @@ -0,0 +1,72 @@ +# Copyright 2023 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for util.py.""" + +import jax +import jax.numpy as jnp +import numpy as np +from tensorflow_probability.python.experimental.autobnn import kernels +from tensorflow_probability.python.experimental.autobnn import util +from tensorflow_probability.python.internal import test_util + + +class UtilTest(test_util.TestCase): + + def test_suggest_periods(self): + self.assertListEqual([], util.suggest_periods([1 for _ in range(20)])) + self.assertListEqual( + [2.0], util.suggest_periods([i % 2 for i in range(20)]) + ) + np.testing.assert_allclose( + [20.0], + util.suggest_periods( + [jnp.sin(2.0 * jnp.pi * i / 20.0) for i in range(100)] + ), + ) + # suggest_periods is robust against small linear trends ... + np.testing.assert_allclose( + [20.0], + util.suggest_periods( + [0.01 * i + jnp.sin(2.0 * jnp.pi * i / 20.0) for i in range(100)] + ), + ) + # but sort of falls apart currently for large linear trends. + np.testing.assert_allclose( + [50.0, 100.0 / 3.0], + util.suggest_periods( + [i + jnp.sin(2.0 * jnp.pi * i / 20.0) for i in range(100)] + ), + ) + + def test_transform(self): + seed = jax.random.PRNGKey(20231018) + bnn = kernels.LinearBNN(width=5) + bnn.likelihood_model.noise_min = 0.2 + transform, _, _ = util.make_transforms(bnn) + p = bnn.init(seed, jnp.ones((1, 10), dtype=jnp.float32)) + + # Softplus(low=0.2) bijector + self.assertEqual(0.2 + jax.nn.softplus(p['params']['noise_scale']), + transform(p)['params']['noise_scale']) + self.assertEqual(jnp.exp(p['params']['amplitude']), + transform(p)['params']['amplitude']) + + # Identity bijector + self.assertAllEqual(p['params']['dense2']['kernel'], + transform(p)['params']['dense2']['kernel']) + + +if __name__ == '__main__': + test_util.main() diff --git a/testing/dependency_install_lib.sh b/testing/dependency_install_lib.sh index 801d7a3361..261cc1665b 100644 --- a/testing/dependency_install_lib.sh +++ b/testing/dependency_install_lib.sh @@ -93,9 +93,11 @@ install_test_only_packages() { # The following unofficial dependencies are used only by tests. PIP_FLAGS=${1-} python -m pip install $PIP_FLAGS \ + flax \ hypothesis \ jax \ jaxlib \ + jaxtyping \ optax \ matplotlib \ mock \