Skip to content

Commit 4044b05

Browse files
authored
Merge pull request #35 from handley-lab/slice_acceptance
Implement maximum step limits and acceptance handling for slice sampling
2 parents 171ff14 + 5956a83 commit 4044b05

File tree

3 files changed

+86
-40
lines changed

3 files changed

+86
-40
lines changed

blackjax/mcmc/ss.py

Lines changed: 64 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
from blackjax.base import SamplingAlgorithm
3636
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
37+
from blackjax.mcmc.proposal import static_binomial_sampling
3738

3839
__all__ = [
3940
"SliceState",
@@ -71,6 +72,8 @@ class SliceInfo(NamedTuple):
7172
7273
Attributes
7374
----------
75+
is_accepted
76+
A boolean indicating whether the proposed sample was accepted.
7477
constraint
7578
The constraint values at the final accepted position.
7679
l_steps
@@ -82,15 +85,12 @@ class SliceInfo(NamedTuple):
8285
s_steps
8386
The number of steps taken during the "shrinking" phase to find an
8487
acceptable sample.
85-
evals
86-
The total number of log-density evaluations performed during the step.
8788
"""
8889

90+
is_accepted: bool
8991
constraint: Array = jnp.array([])
90-
l_steps: int = 0
91-
r_steps: int = 0
92-
s_steps: int = 0
93-
evals: int = 0
92+
num_steps: int = 0
93+
num_shrink: int = 0
9494

9595

9696
def init(position: ArrayTree, logdensity_fn: Callable) -> SliceState:
@@ -113,6 +113,8 @@ def init(position: ArrayTree, logdensity_fn: Callable) -> SliceState:
113113

114114
def build_kernel(
115115
stepper_fn: Callable,
116+
max_steps: int = 10,
117+
max_shrinkage: int = 100,
116118
) -> Callable:
117119
"""Build a Slice Sampling kernel.
118120
@@ -149,8 +151,8 @@ def kernel(
149151
strict: Array,
150152
) -> tuple[SliceState, SliceInfo]:
151153
rng_key, vs_key, hs_key = jax.random.split(rng_key, 3)
152-
intermediate_state, vs_info = vertical_slice(vs_key, state)
153-
new_state, hs_info = horizontal_slice(
154+
intermediate_state, v_info = vertical_slice(vs_key, state)
155+
new_state, info = horizontal_slice(
154156
hs_key,
155157
intermediate_state,
156158
d,
@@ -159,22 +161,30 @@ def kernel(
159161
constraint_fn,
160162
constraint,
161163
strict,
164+
max_steps,
165+
max_shrinkage,
162166
)
163-
164-
info = SliceInfo(
165-
constraint=hs_info.constraint,
166-
l_steps=hs_info.l_steps,
167-
r_steps=hs_info.r_steps,
168-
s_steps=hs_info.s_steps,
169-
evals=vs_info.evals + hs_info.evals,
167+
info = info._replace(is_accepted=v_info.is_accepted & info.is_accepted)
168+
new_state = jax.lax.cond(
169+
info.is_accepted,
170+
lambda _: new_state,
171+
lambda _: state,
172+
operand=None,
170173
)
174+
# info = SliceInfo(
175+
# constraint=hs_info.constraint,
176+
# expanding_steps=hs_info.expanding_steps,
177+
# slice_steps=hs_info.slice_steps,
178+
# shrink_steps=hs_info.shrink_steps,
179+
180+
# )
171181

172182
return new_state, info
173183

174184
return kernel
175185

176186

177-
def vertical_slice(rng_key: PRNGKey, state: SliceState) -> tuple[SliceState, SliceInfo]:
187+
def vertical_slice(rng_key: PRNGKey, state: SliceState) -> SliceState:
178188
"""Define the vertical slice for the Slice Sampling algorithm.
179189
180190
This function determines the height `y` for the horizontal slice by sampling
@@ -197,7 +207,8 @@ def vertical_slice(rng_key: PRNGKey, state: SliceState) -> tuple[SliceState, Sli
197207
"""
198208
logslice = state.logdensity + jnp.log(jax.random.uniform(rng_key))
199209
new_state = state._replace(logslice=logslice)
200-
info = SliceInfo()
210+
is_accepted = logslice < state.logdensity
211+
info = SliceInfo(is_accepted=is_accepted)
201212
return new_state, info
202213

203214

@@ -210,6 +221,8 @@ def horizontal_slice(
210221
constraint_fn: Callable,
211222
constraint: Array,
212223
strict: Array,
224+
m: int,
225+
max_shrinkage: int,
213226
) -> tuple[SliceState, SliceInfo]:
214227
"""Propose a new sample using the stepping-out and shrinking procedures.
215228
@@ -224,8 +237,8 @@ def horizontal_slice(
224237
----------
225238
rng_key
226239
A JAX PRNG key.
227-
x0
228-
The current position (PyTree).
240+
state
241+
The current slice sampling state.
229242
d
230243
The direction (PyTree) for proposing moves.
231244
stepper_fn
@@ -248,6 +261,11 @@ def horizontal_slice(
248261
An array of boolean flags indicating whether each constraint should be
249262
strict (constraint_fn(x) > constraint) or non-strict
250263
(constraint_fn(x) >= constraint).
264+
m
265+
The maximum number of steps to take when expanding the interval in
266+
each direction during the stepping-out phase.
267+
max_shrinkage
268+
The maximum number of shrinking steps to perform to avoid infinite loops.
251269
252270
Returns
253271
-------
@@ -258,11 +276,13 @@ def horizontal_slice(
258276
"""
259277
# Initial bounds
260278
rng_key, subkey = jax.random.split(rng_key)
261-
u = jax.random.uniform(subkey)
279+
u, v = jax.random.uniform(subkey, 2)
280+
j = jnp.floor(m * v).astype(int)
281+
k = (m - 1) - j
262282
x0 = state.position
263283

264-
def body_fun(carry):
265-
_, s, t, n = carry
284+
def step_body_fun(carry):
285+
_, s, t, i = carry
266286
t += s
267287
x = stepper_fn(x0, d, t)
268288
logdensity_x = logdensity_fn(x)
@@ -272,21 +292,22 @@ def body_fun(carry):
272292
)
273293
constraints = jnp.append(constraints, logdensity_x >= state.logslice)
274294
within = jnp.all(constraints)
275-
n += 1
276-
return within, s, t, n
295+
i -= 1
296+
return within, s, t, i
277297

278-
def cond_fun(carry):
298+
def step_cond_fun(carry):
279299
within = carry[0]
280-
return within
300+
i = carry[-1]
301+
return within & (i > 0)
281302

282303
# Expand
283-
_, _, l, l_steps = jax.lax.while_loop(cond_fun, body_fun, (True, -1, -u, 0))
284-
_, _, r, r_steps = jax.lax.while_loop(cond_fun, body_fun, (True, +1, 1 - u, 0))
304+
_, _, l, j = jax.lax.while_loop(step_cond_fun, step_body_fun, (True, -1, -u, j))
305+
_, _, r, k = jax.lax.while_loop(step_cond_fun, step_body_fun, (True, +1, 1 - u, k))
285306

286307
# Shrink
287308
def shrink_body_fun(carry):
288-
_, l, r, _, _, _, rng_key, s_steps = carry
289-
s_steps += 1
309+
_, l, r, _, _, _, rng_key, s = carry
310+
s += 1
290311

291312
rng_key, subkey = jax.random.split(rng_key)
292313
u = jax.random.uniform(subkey, minval=l, maxval=r)
@@ -303,24 +324,30 @@ def shrink_body_fun(carry):
303324
l = jnp.where(u < 0, u, l)
304325
r = jnp.where(u > 0, u, r)
305326

306-
return within, l, r, x, logdensity_x, constraint_x, rng_key, s_steps
327+
return within, l, r, x, logdensity_x, constraint_x, rng_key, s
307328

308329
def shrink_cond_fun(carry):
309330
within = carry[0]
310-
return ~within
331+
s = carry[-1]
332+
return ~within & (s < max_shrinkage + 1)
311333

312334
carry = (False, l, r, x0, -jnp.inf, constraint, rng_key, 0)
313335
carry = jax.lax.while_loop(shrink_cond_fun, shrink_body_fun, carry)
314-
_, l, r, x, logdensity_x, constraint_x, rng_key, s_steps = carry
315-
slice_state = SliceState(x, logdensity_x)
316-
evals = l_steps + r_steps + s_steps
317-
slice_info = SliceInfo(constraint_x, l_steps, r_steps, s_steps, evals)
336+
_, l, r, x, logdensity_x, constraint_x, rng_key, s = carry
337+
338+
end_state = SliceState(x, logdensity_x)
339+
slice_state, (is_accepted, _, _) = static_binomial_sampling(
340+
rng_key, jnp.log(s < max_shrinkage + 1), state, end_state
341+
)
342+
343+
slice_info = SliceInfo(is_accepted, constraint_x, j+k, s)
318344
return slice_state, slice_info
319345

320346

321347
def build_hrss_kernel(
322348
generate_slice_direction_fn: Callable,
323349
stepper_fn: Callable,
350+
max_steps: int = 10,
324351
) -> Callable:
325352
"""Build a Hit-and-Run Slice Sampling kernel.
326353
@@ -347,7 +374,7 @@ def build_hrss_kernel(
347374
A kernel function that takes a PRNG key, the current `SliceState`, and
348375
the log-density function, and returns a new `SliceState` and `SliceInfo`.
349376
"""
350-
slice_kernel = build_kernel(stepper_fn)
377+
slice_kernel = build_kernel(stepper_fn, max_steps)
351378

352379
def kernel(
353380
rng_key: PRNGKey, state: SliceState, logdensity_fn: Callable

blackjax/ns/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -414,8 +414,9 @@ def delete_fn(
414414
"""
415415
loglikelihood = state.loglikelihood
416416
neg_dead_loglikelihood, dead_idx = jax.lax.top_k(-loglikelihood, num_delete)
417-
constraint = loglikelihood > -neg_dead_loglikelihood.min()
418-
weights = jnp.array(constraint, dtype=jnp.float32)
417+
constraint_loglikelihood = loglikelihood > -neg_dead_loglikelihood.min()
418+
weights = jnp.array(constraint_loglikelihood, dtype=jnp.float32)
419+
weights = jnp.where(weights.sum() > 0., weights, jnp.ones_like(weights))
419420
start_idx = jax.random.choice(
420421
rng_key,
421422
len(weights),

blackjax/ns/nss.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ def build_kernel(
148148
stepper_fn: Callable = default_stepper_fn,
149149
adapt_direction_params_fn: Callable = compute_covariance_from_particles,
150150
generate_slice_direction_fn: Callable = sample_direction_from_covariance,
151+
max_steps: int = 10,
152+
max_shrinkage: int = 100,
151153
) -> Callable:
152154
"""Builds the Nested Slice Sampling kernel.
153155
@@ -179,6 +181,12 @@ def build_kernel(
179181
A function `(rng_key, **params) -> direction_pytree` that generates a
180182
normalized direction for HRSS, using parameters from `adapt_direction_params_fn`.
181183
Defaults to `sample_direction_from_covariance`.
184+
max_steps
185+
The maximum number of steps to take when expanding the interval in
186+
each direction during the stepping-out phase. Defaults to 10.
187+
max_shrinkage
188+
The maximum number of shrinking steps to perform to avoid infinite loops.
189+
Defaults to 100.
182190
183191
Returns
184192
-------
@@ -188,7 +196,7 @@ def build_kernel(
188196
the `NSInfo` for the step.
189197
"""
190198

191-
slice_kernel = build_slice_kernel(stepper_fn)
199+
slice_kernel = build_slice_kernel(stepper_fn, max_steps, max_shrinkage)
192200

193201
@repeat_kernel(num_inner_steps)
194202
def inner_kernel(
@@ -235,6 +243,8 @@ def as_top_level_api(
235243
stepper_fn: Callable = default_stepper_fn,
236244
adapt_direction_params_fn: Callable = compute_covariance_from_particles,
237245
generate_slice_direction_fn: Callable = sample_direction_from_covariance,
246+
max_steps: int = 10,
247+
max_shrinkage: int = 100,
238248
) -> SamplingAlgorithm:
239249
"""Creates an adaptive Nested Slice Sampling (NSS) algorithm.
240250
@@ -266,6 +276,12 @@ def as_top_level_api(
266276
A function `(rng_key, **params) -> direction_pytree` that generates a
267277
normalized direction for HRSS, using parameters from `adapt_direction_params_fn`.
268278
Defaults to `sample_direction_from_covariance`.
279+
max_steps
280+
The maximum number of steps to take when expanding the interval in
281+
each direction during the stepping-out phase. Defaults to 10.
282+
max_shrinkage
283+
The maximum number of shrinking steps to perform to avoid infinite loops.
284+
Defaults to 100.
269285
270286
Returns
271287
-------
@@ -283,6 +299,8 @@ def as_top_level_api(
283299
stepper_fn=stepper_fn,
284300
adapt_direction_params_fn=adapt_direction_params_fn,
285301
generate_slice_direction_fn=generate_slice_direction_fn,
302+
max_steps=max_steps,
303+
max_shrinkage=max_shrinkage,
286304
)
287305
init_fn = partial(
288306
init,

0 commit comments

Comments
 (0)