
Commit

merged origin
aleslamitz committed Nov 14, 2023
1 parent 62b7701 commit 0368dff
Showing 241 changed files with 3,160 additions and 3,125 deletions.
4 changes: 2 additions & 2 deletions STYLE_GUIDE.md
@@ -187,8 +187,8 @@ they supersede all previous conventions.
1. Submodule names should be singular, except where they overlap to TF.
   Justification: Having plural looks strange in user code, ie,
-   tf.optimizer.Foo reads nicer than tf.optimizers.Foo since submodules are
-   only used to access a single, specific thing (at a time).
+   tf.optimizer.Foo reads nicer than tf_keras.optimizers.Foo since submodules
+   are only used to access a single, specific thing (at a time).
1. Use `tf.newaxis` rather than `None` to `tf.expand_dims`.
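
Note on the rule above: `tf.newaxis` and `None` behave identically in indexing; the style guide simply prefers the explicit name. A minimal sketch (plain TensorFlow, illustrative only, not part of this commit):

```python
import tensorflow as tf

x = tf.constant([1., 2., 3.])        # shape [3]

col = x[:, tf.newaxis]               # preferred: shape [3, 1]
col_none = x[:, None]                # equivalent, but discouraged spelling
col_fn = tf.expand_dims(x, axis=-1)  # functional form of the same op
```
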
6 changes: 3 additions & 3 deletions SUBSTRATES.md
@@ -75,11 +75,11 @@ vmap, etc.), we will special-case using an `if JAX_MODE:` block.
  tests, TFP impl, etc), with `tfp.math.value_and_gradient` or similar. Then,
  we can special-case `JAX_MODE` inside the body of `value_and_gradient`.

-* __`tf.Variable`, `tf.optimizers.Optimizer`__
+* __`tf.Variable`, `tf_keras.optimizers.Optimizer`__

  TF provides a `Variable` abstraction so that graph functions may modify
-  state, including using the TF `Optimizer` subclasses like `Adam`. JAX, in
-  contrast, operates only on pure functions. In general, TFP is fairly
+  state, including using the Keras `Optimizer` subclasses like `Adam`. JAX,
+  in contrast, operates only on pure functions. In general, TFP is fairly
  functional (e.g. `tfp.optimizer.lbfgs_minimize`), but in some cases (e.g.
  `tfp.vi.fit_surrogate_posterior`,
  `tfp.optimizer.StochasticGradientLangevinDynamics`) we have felt the
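
Note: the TF-vs-JAX contrast above can be made concrete with a toy example (hypothetical code, assuming only `jax`; not from this commit). Where TF mutates a `tf.Variable` through an optimizer, JAX threads the state through a pure update function:

```python
import jax
import jax.numpy as jnp

def loss_fn(params):
  return jnp.sum(params ** 2)  # toy quadratic objective

def sgd_step(params, learning_rate=0.1):
  # Pure function: takes state, returns new state, mutates nothing.
  return params - learning_rate * jax.grad(loss_fn)(params)

params = jnp.array([1.0, -2.0])
for _ in range(100):
  params = sgd_step(params)  # re-bind instead of updating in place
```
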
18 changes: 9 additions & 9 deletions discussion/adaptive_malt/adaptive_malt.py
@@ -350,7 +350,7 @@ def adaptive_mcmc_step(
    target_log_prob_fn: fun_mc.PotentialFn,
    num_mala_steps: int,
    num_adaptation_steps: int,
-    seed: jax.random.KeyArray,
+    seed: jax.Array,
    method: str = 'hmc',
    damping: Optional[jnp.ndarray] = None,
    scalar_step_size: Optional[jnp.ndarray] = None,
@@ -778,7 +778,7 @@ def adaptive_nuts_step(
    target_log_prob_fn: fun_mc.PotentialFn,
    num_mala_steps: int,
    num_adaptation_steps: int,
-    seed: jax.random.KeyArray,
+    seed: jax.Array,
    scalar_step_size: Optional[jnp.ndarray] = None,
    vector_step_size: Optional[jnp.ndarray] = None,
    rvar_factor: int = 8,
@@ -1040,7 +1040,7 @@ class MeadsExtra(NamedTuple):


def meads_init(state: jnp.ndarray, target_log_prob_fn: fun_mc.PotentialFn,
-               num_folds: int, seed: jax.random.KeyArray):
+               num_folds: int, seed: jax.Array):
  """Initializes MEADS."""
  num_dimensions = state.shape[-1]
  num_chains = state.shape[0]
@@ -1062,7 +1062,7 @@ def meads_init(state: jnp.ndarray, target_log_prob_fn: fun_mc.PotentialFn,

def meads_step(meads_state: MeadsState,
               target_log_prob_fn: fun_mc.PotentialFn,
-               seed: jax.random.KeyArray,
+               seed: jax.Array,
               vector_step_size: Optional[jnp.ndarray] = None,
               damping: Optional[jnp.ndarray] = None,
               step_size_multiplier: float = 0.5,
@@ -1221,7 +1221,7 @@ def run_adaptive_mcmc_on_target(
    init_step_size: jnp.ndarray,
    num_adaptation_steps: int,
    num_results: int,
-    seed: jax.random.KeyArray,
+    seed: jax.Array,
    num_mala_steps: int = 100,
    rvar_smoothing: int = 0,
    trajectory_opt_kwargs: Mapping[str, Any] = immutabledict.immutabledict({
@@ -1358,7 +1358,7 @@ def run_adaptive_nuts_on_target(
    init_step_size: jnp.ndarray,
    num_adaptation_steps: int,
    num_results: int,
-    seed: jax.random.KeyArray,
+    seed: jax.Array,
    num_mala_steps: int = 100,
    rvar_smoothing: int = 0,
    num_chains: Optional[int] = None,
@@ -1478,7 +1478,7 @@ def run_meads_on_target(
    num_adaptation_steps: int,
    num_results: int,
    thinning: int,
-    seed: jax.random.KeyArray,
+    seed: jax.Array,
    num_folds: int,
    num_chains: Optional[int] = None,
    init_x: Optional[jnp.ndarray] = None,
@@ -1596,7 +1596,7 @@ def run_fixed_mcmc_on_target(
    target: gym.targets.Model,
    init_x: jnp.ndarray,
    method: str,
-    seed: jax.random.KeyArray,
+    seed: jax.Array,
    num_warmup_steps: int,
    num_results: int,
    scalar_step_size: jnp.ndarray,
@@ -1706,7 +1706,7 @@ def run_vi_on_target(
    init_x: jnp.ndarray,
    num_steps: int,
    learning_rate: float,
-    seed: jax.random.KeyArray,
+    seed: jax.Array,
):
  """Run VI on a target.
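
Note: every change in this file is the same mechanical rename. `jax.random.KeyArray` was deprecated and later removed from JAX; `jax.Array` is the type that `jax.random.PRNGKey` actually returns in current releases. A sketch of the updated annotation style (hypothetical function, not from this file):

```python
import jax
import jax.numpy as jnp

def perturb(x: jnp.ndarray, seed: jax.Array) -> jnp.ndarray:
  """Adds standard-normal noise to `x`, consuming an explicit PRNG key."""
  return x + jax.random.normal(seed, x.shape)

seed = jax.random.PRNGKey(0)  # a jax.Array under current JAX
y = perturb(jnp.zeros(3), seed)
```
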
1 change: 0 additions & 1 deletion required_packages.py
@@ -26,7 +26,6 @@
    'cloudpickle>=1.3',
    'gast>=0.3.2',  # For autobatching
    'dm-tree',  # For NumPy/JAX backends (hence, also for prefer_static)
-    'typing-extensions<4.6.0',  # TODO(b/284106340): Remove this pin
]

if __name__ == '__main__':
3 changes: 1 addition & 2 deletions setup.py
@@ -70,7 +70,7 @@ def has_ext_modules(self):
    url='http://github.com/tensorflow/probability',
    license='Apache 2.0',
    packages=find_packages(),
-    python_requires='>=3.8',
+    python_requires='>=3.9',
    install_requires=REQUIRED_PACKAGES,
    # Add in any packaged data.
    include_package_data=True,
@@ -88,7 +88,6 @@ def has_ext_modules(self):
    'Intended Audience :: Science/Research',
    'License :: OSI Approved :: Apache Software License',
    'Programming Language :: Python :: 3',
-    'Programming Language :: Python :: 3.8',
    'Programming Language :: Python :: 3.9',
    'Programming Language :: Python :: 3.10',
    'Programming Language :: Python :: 3.11',
4 changes: 3 additions & 1 deletion spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py
@@ -97,7 +97,9 @@ def make_tensor_seed(seed):
  """Converts a seed to a `Tensor` seed."""
  if seed is None:
    raise ValueError('seed must not be None when using JAX')
-  if isinstance(seed, jax.random.PRNGKeyArray):
+  if hasattr(seed, 'dtype') and jax.dtypes.issubdtype(
+      seed.dtype, jax.dtypes.prng_key
+  ):
    return seed
  return jnp.asarray(seed, jnp.uint32)

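
Note: the removed `isinstance(seed, jax.random.PRNGKeyArray)` check relied on a class that newer JAX no longer exposes; the dtype test recognizes new-style typed keys instead. Roughly (assuming a JAX recent enough to have `jax.random.key`):

```python
import jax

typed_key = jax.random.key(0)       # new-style typed PRNG key
legacy_key = jax.random.PRNGKey(0)  # legacy raw uint32 key array

jax.dtypes.issubdtype(typed_key.dtype, jax.dtypes.prng_key)   # True
jax.dtypes.issubdtype(legacy_key.dtype, jax.dtypes.prng_key)  # False
```
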
50 changes: 31 additions & 19 deletions spinoffs/fun_mc/fun_mc/fun_mc_lib.py
@@ -732,9 +732,9 @@ def maybe_broadcast_structure(from_structure: Any,
                              to_structure: Any) -> Any:
  """Maybe broadcasts `from_structure` to `to_structure`.

-  If `from_structure` is a singleton, it is tiled to match the structure of
-  `to_structure`. Note that the elements in `from_structure` are not copied if
-  this tiling occurs.
+  This assumes that `from_structure` is a shallow version of `to_structure`.
+  Subtrees of `to_structure` are set to the leaf values of `from_structure` that
+  those subtrees correspond to.

  Args:
    from_structure: A structure.
@@ -743,11 +743,12 @@
  Returns:
    new_from_structure: Same structure as `to_structure`.
  """
-  flat_from = util.flatten_tree(from_structure)
-  flat_to = util.flatten_tree(to_structure)
-  if len(flat_from) == 1:
-    flat_from *= len(flat_to)
-  return util.unflatten_tree(to_structure, flat_from)
+  def _broadcast_leaf(from_val, to_subtree):
+    return util.map_tree(lambda _: from_val, to_subtree)
+
+  return util.map_tree_up_to(
+      from_structure, _broadcast_leaf, from_structure, to_structure
+  )


def reparameterize_potential_fn(
@@ -3420,8 +3421,11 @@ def _default_log_weight_fn(old_state, new_state, stage, transition_extra):

@util.named_call
def systematic_resample(
-    particles: State, log_weights: FloatTensor,
-    seed: Any) -> (tuple[tuple[State, FloatTensor], IntTensor]):
+    particles: State,
+    log_weights: FloatTensor,
+    seed: Any,
+    do_resample: Optional[BooleanTensor] = None,
+) -> tuple[tuple[State, FloatTensor], IntTensor]:
  """Systematically resamples particles in proportion to their weights.

  This uses the algorithm from [1].
@@ -3430,6 +3434,8 @@
    particles: The particles.
    log_weights: Un-normalized weights.
    seed: PRNG seed.
+    do_resample: Whether to perform the resample. If None, resampling is
+      performed unconditionally.

  Returns:
    particles_and_weights: tuple of resampled particles and weights.
@@ -3453,30 +3459,36 @@
  repeats = tf.cast(util.diff(tf.floor(pie), prepend=0), tf.int32)
  parent_idxs = util.repeat(
      tf.range(num_particles), repeats, total_repeat_length=num_particles)
+  if do_resample is not None:
+    parent_idxs = tf.where(do_resample, parent_idxs, tf.range(num_particles))
  new_particles = util.map_tree(lambda x: tf.gather(x, parent_idxs), particles)
  new_log_weights = tf.fill(log_weights.shape,
                            tfp.math.reduce_logmeanexp(log_weights))
+  if do_resample is not None:
+    new_log_weights = tf.where(do_resample, new_log_weights, log_weights)
  return (new_particles, new_log_weights), parent_idxs


@util.named_call
def annealed_importance_sampling_resample(
    ais_state: AnnealedImportanceSamplingState,
    resample_fn: Callable[
-        [State, FloatTensor, Any], tuple[tuple[State, tf.Tensor], ResampleExtra]
+        [State, FloatTensor, Any, BooleanTensor],
+        tuple[tuple[State, tf.Tensor], ResampleExtra],
    ] = systematic_resample,
    min_ess_threshold: FloatTensor = 0.5,
    seed: Any = None,
) -> tuple[AnnealedImportanceSamplingState, ResampleExtra]:
  """Resamples the particles in AnnealedImportanceSamplingState."""

-  (state, log_weight), extra = resample_fn(ais_state.state,
-                                           ais_state.log_weight, seed)
-  state, log_weight = choose(
-      ais_state.ess() <
-      tf.cast(log_weight.shape[0], log_weight.dtype) * min_ess_threshold,
-      (state, log_weight),
-      (ais_state.state, ais_state.log_weight),
-  )
+  log_weight = tf.convert_to_tensor(ais_state.log_weight)
+  do_resample = (
+      ais_state.ess()
+      < tf.cast(log_weight.shape[0], log_weight.dtype)
+      * min_ess_threshold
+  )
+  (state, log_weight), extra = resample_fn(
+      ais_state.state, ais_state.log_weight, seed, do_resample
+  )
  return ais_state._replace(state=state, log_weight=log_weight), extra

@@ -3500,7 +3512,7 @@ def geometric_annealing_path(
    initial_target_log_prob_fn: PotentialFn,
    final_target_log_prob_fn: PotentialFn,
    fraction_fn: Optional[Callable[[FloatTensor], tf.Tensor]] = None,
-) -> Callable[[Stage], PotentialFn]:
+) -> PotentialFn:
  """Returns a geometrically interpolated target density function.

  This interpolates between `initial_target_log_prob_fn` and
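
Two behavioral notes on this file. `systematic_resample` gains an optional `do_resample` flag so that `annealed_importance_sampling_resample` can gate resampling on the ESS threshold inside `resample_fn` instead of patching the result afterwards with `choose`. And `maybe_broadcast_structure` now broadcasts leaf-wise over a shallow structure rather than only handling singletons; the expected values below come from the updated test, while the import idiom is an assumption:

```python
# Assumed import idiom for the fun_mc spinoff:
from fun_mc import using_tensorflow as fun_mc

# Already-matching structures are returned unchanged, as before:
fun_mc.maybe_broadcast_structure([3, 4], [1, 2])   # -> [3, 4]

# New: each leaf of the shallow `from_structure` is tiled over the
# corresponding subtree of `to_structure`:
fun_mc.maybe_broadcast_structure([1, 2], [[0, 0], [0, 0, 0]])
# -> [[1, 1], [2, 2, 2]]
```
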
35 changes: 34 additions & 1 deletion spinoffs/fun_mc/fun_mc/fun_mc_test.py
@@ -22,7 +22,7 @@

from absl.testing import parameterized
import jax
-from jax.config import config as jax_config
+from jax import config as jax_config
import numpy as np
import scipy.stats
import tensorflow.compat.v2 as real_tf
@@ -361,6 +361,9 @@ def testBroadcastStructure(self):
    struct = fun_mc.maybe_broadcast_structure([3, 4], [1, 2])
    self.assertEqual([3, 4], struct)

+    struct = fun_mc.maybe_broadcast_structure([1, 2], [[0, 0], [0, 0, 0]])
+    self.assertEqual([[1, 1], [2, 2, 2]], struct)

def testCallPotentialFn(self):

def potential(x):
@@ -1885,6 +1888,36 @@ def body(seed):
new_log_weights,
tf.fill(probs.shape, tfp.math.reduce_logmeanexp(log_weights)))

+  def testSystematicResampleAncestors(self):
+    log_weights = self._constant([-float('inf'), 0.])
+    particles = tf.range(log_weights.shape[0])
+    seed = self._make_seed(_test_seed())
+
+    (new_particles, new_log_weights), ancestors = fun_mc.systematic_resample(
+        particles, log_weights, seed=seed
+    )
+    self.assertAllEqual(new_particles, tf.ones_like(particles))
+    self.assertAllEqual(
+        new_log_weights, tf.math.log(self._constant([0.5, 0.5]))
+    )
+    self.assertAllEqual(ancestors, tf.ones_like(particles))
+
+    (new_particles, new_log_weights), ancestors = fun_mc.systematic_resample(
+        particles, log_weights, do_resample=True, seed=seed
+    )
+    self.assertAllEqual(new_particles, tf.ones_like(particles))
+    self.assertAllEqual(
+        new_log_weights, tf.math.log(self._constant([0.5, 0.5]))
+    )
+    self.assertAllEqual(ancestors, tf.ones_like(particles))
+
+    (new_particles, new_log_weights), ancestors = fun_mc.systematic_resample(
+        particles, log_weights, do_resample=False, seed=seed
+    )
+    self.assertAllEqual(new_particles, particles)
+    self.assertAllEqual(new_log_weights, log_weights)
+    self.assertAllEqual(ancestors, particles)

def testAIS(self):

def tlp_1(x):
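
Note: the `from jax.config import config` → `from jax import config` edit recurring in this and the remaining test files tracks JAX's removal of the `jax.config` module-import path; the config object and its API are unchanged:

```python
from jax import config as jax_config

# Typical test-suite usage: enable 64-bit floats.
jax_config.update('jax_enable_x64', True)
```
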
2 changes: 1 addition & 1 deletion spinoffs/fun_mc/fun_mc/malt_test.py
@@ -20,7 +20,7 @@
# Dependency imports

import jax
-from jax.config import config as jax_config
+from jax import config as jax_config
import numpy as np
import tensorflow.compat.v2 as real_tf

2 changes: 1 addition & 1 deletion spinoffs/fun_mc/fun_mc/prefab_test.py
@@ -20,7 +20,7 @@
# Dependency imports

import jax
-from jax.config import config as jax_config
+from jax import config as jax_config
import numpy as np
import tensorflow.compat.v2 as real_tf

2 changes: 1 addition & 1 deletion spinoffs/fun_mc/fun_mc/sga_hmc_test.py
@@ -21,7 +21,7 @@

from absl.testing import parameterized
import jax
-from jax.config import config as jax_config
+from jax import config as jax_config
import tensorflow.compat.v2 as real_tf

from tensorflow_probability.python.internal import test_util as tfp_test_util
2 changes: 1 addition & 1 deletion spinoffs/fun_mc/fun_mc/util_tfp_test.py
@@ -17,7 +17,7 @@
# Dependency imports

from absl.testing import parameterized
-from jax.config import config as jax_config
+from jax import config as jax_config
import numpy as np
import tensorflow.compat.v2 as real_tf

3 changes: 0 additions & 3 deletions spinoffs/inference_gym/inference_gym/BUILD
@@ -18,7 +18,6 @@

# Placeholder: py_library
# [internal] load pytype.bzl (pytype_strict_library)
-# [internal] load dummy dependency

package(
    # default_applicable_licenses
@@ -98,5 +97,3 @@ py_library(
    name = "backend_tensorflow",
    srcs = ["dynamic/backend_tensorflow/__init__.py"],
)
-
-# third_party_dependency(package = "py/inference_gym")  # DisableOnExport
1 change: 1 addition & 0 deletions tensorflow_probability/examples/BUILD
@@ -84,6 +84,7 @@ py_library(
        # six dep,
        # tensorflow dep,
        "//tensorflow_probability",
+        "//tensorflow_probability/python/internal:tf_keras",
    ],
)
