34
34
35
35
from blackjax .base import SamplingAlgorithm
36
36
from blackjax .types import Array , ArrayLikeTree , ArrayTree , PRNGKey
37
+ from blackjax .mcmc .proposal import static_binomial_sampling
37
38
38
39
__all__ = [
39
40
"SliceState" ,
@@ -71,6 +72,8 @@ class SliceInfo(NamedTuple):
71
72
72
73
Attributes
73
74
----------
75
+ is_accepted
76
+ A boolean indicating whether the proposed sample was accepted.
74
77
constraint
75
78
The constraint values at the final accepted position.
76
79
l_steps
@@ -82,15 +85,12 @@ class SliceInfo(NamedTuple):
82
85
s_steps
83
86
The number of steps taken during the "shrinking" phase to find an
84
87
acceptable sample.
85
- evals
86
- The total number of log-density evaluations performed during the step.
87
88
"""
88
89
90
+ is_accepted : bool
89
91
constraint : Array = jnp .array ([])
90
- l_steps : int = 0
91
- r_steps : int = 0
92
- s_steps : int = 0
93
- evals : int = 0
92
+ num_steps : int = 0
93
+ num_shrink : int = 0
94
94
95
95
96
96
def init (position : ArrayTree , logdensity_fn : Callable ) -> SliceState :
@@ -113,6 +113,8 @@ def init(position: ArrayTree, logdensity_fn: Callable) -> SliceState:
113
113
114
114
def build_kernel (
115
115
stepper_fn : Callable ,
116
+ max_steps : int = 10 ,
117
+ max_shrinkage : int = 100 ,
116
118
) -> Callable :
117
119
"""Build a Slice Sampling kernel.
118
120
@@ -149,8 +151,8 @@ def kernel(
149
151
strict : Array ,
150
152
) -> tuple [SliceState , SliceInfo ]:
151
153
rng_key , vs_key , hs_key = jax .random .split (rng_key , 3 )
152
- intermediate_state , vs_info = vertical_slice (vs_key , state )
153
- new_state , hs_info = horizontal_slice (
154
+ intermediate_state , v_info = vertical_slice (vs_key , state )
155
+ new_state , info = horizontal_slice (
154
156
hs_key ,
155
157
intermediate_state ,
156
158
d ,
@@ -159,22 +161,30 @@ def kernel(
159
161
constraint_fn ,
160
162
constraint ,
161
163
strict ,
164
+ max_steps ,
165
+ max_shrinkage ,
162
166
)
163
-
164
- info = SliceInfo (
165
- constraint = hs_info .constraint ,
166
- l_steps = hs_info .l_steps ,
167
- r_steps = hs_info .r_steps ,
168
- s_steps = hs_info .s_steps ,
169
- evals = vs_info .evals + hs_info .evals ,
167
+ info = info ._replace (is_accepted = v_info .is_accepted & info .is_accepted )
168
+ new_state = jax .lax .cond (
169
+ info .is_accepted ,
170
+ lambda _ : new_state ,
171
+ lambda _ : state ,
172
+ operand = None ,
170
173
)
174
+ # info = SliceInfo(
175
+ # constraint=hs_info.constraint,
176
+ # expanding_steps=hs_info.expanding_steps,
177
+ # slice_steps=hs_info.slice_steps,
178
+ # shrink_steps=hs_info.shrink_steps,
179
+
180
+ # )
171
181
172
182
return new_state , info
173
183
174
184
return kernel
175
185
176
186
177
- def vertical_slice (rng_key : PRNGKey , state : SliceState ) -> tuple [ SliceState , SliceInfo ] :
187
+ def vertical_slice (rng_key : PRNGKey , state : SliceState ) -> SliceState :
178
188
"""Define the vertical slice for the Slice Sampling algorithm.
179
189
180
190
This function determines the height `y` for the horizontal slice by sampling
@@ -197,7 +207,8 @@ def vertical_slice(rng_key: PRNGKey, state: SliceState) -> tuple[SliceState, Sli
197
207
"""
198
208
logslice = state .logdensity + jnp .log (jax .random .uniform (rng_key ))
199
209
new_state = state ._replace (logslice = logslice )
200
- info = SliceInfo ()
210
+ is_accepted = logslice < state .logdensity
211
+ info = SliceInfo (is_accepted = is_accepted )
201
212
return new_state , info
202
213
203
214
@@ -210,6 +221,8 @@ def horizontal_slice(
210
221
constraint_fn : Callable ,
211
222
constraint : Array ,
212
223
strict : Array ,
224
+ m : int ,
225
+ max_shrinkage : int ,
213
226
) -> tuple [SliceState , SliceInfo ]:
214
227
"""Propose a new sample using the stepping-out and shrinking procedures.
215
228
@@ -224,8 +237,8 @@ def horizontal_slice(
224
237
----------
225
238
rng_key
226
239
A JAX PRNG key.
227
- x0
228
- The current position (PyTree) .
240
+ state
241
+ The current slice sampling state .
229
242
d
230
243
The direction (PyTree) for proposing moves.
231
244
stepper_fn
@@ -248,6 +261,11 @@ def horizontal_slice(
248
261
An array of boolean flags indicating whether each constraint should be
249
262
strict (constraint_fn(x) > constraint) or non-strict
250
263
(constraint_fn(x) >= constraint).
264
+ m
265
+ The maximum number of steps to take when expanding the interval in
266
+ each direction during the stepping-out phase.
267
+ max_shrinkage
268
+ The maximum number of shrinking steps to perform to avoid infinite loops.
251
269
252
270
Returns
253
271
-------
@@ -258,11 +276,13 @@ def horizontal_slice(
258
276
"""
259
277
# Initial bounds
260
278
rng_key , subkey = jax .random .split (rng_key )
261
- u = jax .random .uniform (subkey )
279
+ u , v = jax .random .uniform (subkey , 2 )
280
+ j = jnp .floor (m * v ).astype (int )
281
+ k = (m - 1 ) - j
262
282
x0 = state .position
263
283
264
- def body_fun (carry ):
265
- _ , s , t , n = carry
284
+ def step_body_fun (carry ):
285
+ _ , s , t , i = carry
266
286
t += s
267
287
x = stepper_fn (x0 , d , t )
268
288
logdensity_x = logdensity_fn (x )
@@ -272,21 +292,22 @@ def body_fun(carry):
272
292
)
273
293
constraints = jnp .append (constraints , logdensity_x >= state .logslice )
274
294
within = jnp .all (constraints )
275
- n + = 1
276
- return within , s , t , n
295
+ i - = 1
296
+ return within , s , t , i
277
297
278
- def cond_fun (carry ):
298
+ def step_cond_fun (carry ):
279
299
within = carry [0 ]
280
- return within
300
+ i = carry [- 1 ]
301
+ return within & (i > 0 )
281
302
282
303
# Expand
283
- _ , _ , l , l_steps = jax .lax .while_loop (cond_fun , body_fun , (True , - 1 , - u , 0 ))
284
- _ , _ , r , r_steps = jax .lax .while_loop (cond_fun , body_fun , (True , + 1 , 1 - u , 0 ))
304
+ _ , _ , l , j = jax .lax .while_loop (step_cond_fun , step_body_fun , (True , - 1 , - u , j ))
305
+ _ , _ , r , k = jax .lax .while_loop (step_cond_fun , step_body_fun , (True , + 1 , 1 - u , k ))
285
306
286
307
# Shrink
287
308
def shrink_body_fun (carry ):
288
- _ , l , r , _ , _ , _ , rng_key , s_steps = carry
289
- s_steps += 1
309
+ _ , l , r , _ , _ , _ , rng_key , s = carry
310
+ s += 1
290
311
291
312
rng_key , subkey = jax .random .split (rng_key )
292
313
u = jax .random .uniform (subkey , minval = l , maxval = r )
@@ -303,24 +324,30 @@ def shrink_body_fun(carry):
303
324
l = jnp .where (u < 0 , u , l )
304
325
r = jnp .where (u > 0 , u , r )
305
326
306
- return within , l , r , x , logdensity_x , constraint_x , rng_key , s_steps
327
+ return within , l , r , x , logdensity_x , constraint_x , rng_key , s
307
328
308
329
def shrink_cond_fun (carry ):
309
330
within = carry [0 ]
310
- return ~ within
331
+ s = carry [- 1 ]
332
+ return ~ within & (s < max_shrinkage + 1 )
311
333
312
334
carry = (False , l , r , x0 , - jnp .inf , constraint , rng_key , 0 )
313
335
carry = jax .lax .while_loop (shrink_cond_fun , shrink_body_fun , carry )
314
- _ , l , r , x , logdensity_x , constraint_x , rng_key , s_steps = carry
315
- slice_state = SliceState (x , logdensity_x )
316
- evals = l_steps + r_steps + s_steps
317
- slice_info = SliceInfo (constraint_x , l_steps , r_steps , s_steps , evals )
336
+ _ , l , r , x , logdensity_x , constraint_x , rng_key , s = carry
337
+
338
+ end_state = SliceState (x , logdensity_x )
339
+ slice_state , (is_accepted , _ , _ ) = static_binomial_sampling (
340
+ rng_key , jnp .log (s < max_shrinkage + 1 ), state , end_state
341
+ )
342
+
343
+ slice_info = SliceInfo (is_accepted , constraint_x , j + k , s )
318
344
return slice_state , slice_info
319
345
320
346
321
347
def build_hrss_kernel (
322
348
generate_slice_direction_fn : Callable ,
323
349
stepper_fn : Callable ,
350
+ max_steps : int = 10 ,
324
351
) -> Callable :
325
352
"""Build a Hit-and-Run Slice Sampling kernel.
326
353
@@ -347,7 +374,7 @@ def build_hrss_kernel(
347
374
A kernel function that takes a PRNG key, the current `SliceState`, and
348
375
the log-density function, and returns a new `SliceState` and `SliceInfo`.
349
376
"""
350
- slice_kernel = build_kernel (stepper_fn )
377
+ slice_kernel = build_kernel (stepper_fn , max_steps )
351
378
352
379
def kernel (
353
380
rng_key : PRNGKey , state : SliceState , logdensity_fn : Callable
0 commit comments