Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 150 additions & 0 deletions llamppl/inference/resampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import numpy as np


def systematic_resample(weights):
    """Systematic resampling with a single random offset.

    Separates the sample space into N equal divisions and uses one
    random offset for all divisions. Every sample is exactly 1/N apart.
    Lowest variance among standard resampling methods.

    Adapted from FilterPy (R. Labbe):
    https://filterpy.readthedocs.io/en/latest/monte_carlo/resampling.html

    Args:
        weights (array-like): Normalized probability weights summing to 1.

    Returns:
        ndarray: Integer array of ancestor indices.
    """
    N = len(weights)
    # One shared uniform offset; positions are evenly spaced and sorted.
    positions = (np.random.random() + np.arange(N)) / N

    cumulative_sum = np.cumsum(weights)
    cumulative_sum[-1] = 1.0  # avoid round-off errors

    # positions is sorted, so a right-bisection returns, for each position p,
    # the first j with cumulative_sum[j] > p — exactly what the original
    # two-pointer scan computed, but vectorized in C and immune to an
    # unbounded scan if the weights are malformed.
    return np.searchsorted(cumulative_sum, positions, side="right").astype(np.intp)


def stratified_resample(weights):
    """Stratified resampling with one random draw per stratum.

    Divides the cumulative sum into N equal strata and draws one
    sample uniformly from each. Guarantees samples are between
    0 and 2/N apart.

    Adapted from FilterPy (R. Labbe):
    https://filterpy.readthedocs.io/en/latest/monte_carlo/resampling.html

    Args:
        weights (array-like): Normalized probability weights summing to 1.

    Returns:
        ndarray: Integer array of ancestor indices.
    """
    N = len(weights)
    # Draw one uniform per stratum; position i lies in [i/N, (i+1)/N), so the
    # positions array is sorted by construction.
    positions = (np.random.random(N) + np.arange(N)) / N

    cumulative_sum = np.cumsum(weights)
    cumulative_sum[-1] = 1.0  # avoid round-off errors

    # Vectorized replacement for the original two-pointer scan: for each
    # position p, find the first j with cumulative_sum[j] > p.
    return np.searchsorted(cumulative_sum, positions, side="right").astype(np.intp)


def residual_resample(weights):
    """Residual resampling: deterministic floor copies + multinomial remainder.

    Takes floor(N * w_i) copies of each particle deterministically,
    then resamples the remaining slots from the fractional residuals
    using multinomial resampling.

    Adapted from FilterPy (R. Labbe):
    https://filterpy.readthedocs.io/en/latest/monte_carlo/resampling.html

    Args:
        weights (array-like): Normalized probability weights summing to 1.

    Returns:
        ndarray: Integer array of ancestor indices.
    """
    N = len(weights)
    weights = np.asarray(weights, dtype=float)
    indexes = np.zeros(N, dtype=np.intp)

    # Deterministic copies: floor(N * w_i) slots for particle i. Since the
    # weights sum to 1, the floors sum to at most N. np.repeat reproduces the
    # original nested Python loop (indices emitted in ascending order) in C.
    num_copies = np.floor(N * weights).astype(int)
    k = int(num_copies.sum())
    indexes[:k] = np.repeat(np.arange(N), num_copies)

    # Multinomial resample on the fractional residuals for the remaining slots.
    if k < N:
        residual = N * weights - num_copies
        residual /= residual.sum()  # sums to (N - k)/N > 0 whenever k < N
        cumulative_sum = np.cumsum(residual)
        cumulative_sum[-1] = 1.0  # avoid round-off errors
        indexes[k:] = np.searchsorted(cumulative_sum, np.random.random(N - k))

    return indexes


def multinomial_resample(weights):
    """Multinomial resampling: independent categorical draws.

    Each of the N ancestor indices is drawn independently from the
    categorical distribution defined by the weights.

    Args:
        weights (array-like): Normalized probability weights summing to 1.

    Returns:
        ndarray: Integer array of ancestor indices.
    """
    N = len(weights)
    # A single vectorized call draws the same i.i.d. categorical distribution
    # as the original per-draw loop, but avoids rebuilding the CDF N times
    # (the loop was effectively O(N^2)).
    return np.random.choice(N, size=N, p=weights).astype(np.intp)


# Registry mapping public method names to their resampling implementations;
# looked up by get_resampling_fn() and exposed in its error message.
RESAMPLING_METHODS = {
    "systematic": systematic_resample,
    "stratified": stratified_resample,
    "residual": residual_resample,
    "multinomial": multinomial_resample,
}


def get_resampling_fn(method):
    """Look up a resampling function by its registered name.

    Args:
        method (str): One of 'systematic', 'stratified', 'residual', 'multinomial'.

    Returns:
        callable: Resampling function that takes weights and returns indices.

    Raises:
        ValueError: If method is not recognized.
    """
    try:
        # EAFP: a single dict lookup instead of a membership test plus lookup.
        return RESAMPLING_METHODS[method]
    except KeyError:
        raise ValueError(
            f"Unknown resampling method '{method}'. "
            f"Must be one of: {', '.join(RESAMPLING_METHODS.keys())}"
        ) from None
23 changes: 16 additions & 7 deletions llamppl/inference/smc_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,33 @@

from ..util import logsumexp
from .smc_record import SMCRecord
from .resampling import get_resampling_fn


async def smc_standard(
model, n_particles, ess_threshold=0.5, visualization_dir=None, json_file=None
model,
n_particles,
ess_threshold=0.5,
visualization_dir=None,
json_file=None,
resampling_method="multinomial",
):
"""
Standard sequential Monte Carlo algorithm with multinomial resampling.
Standard sequential Monte Carlo algorithm.

Args:
model (llamppl.modeling.Model): The model to perform inference on.
n_particles (int): Number of particles to execute concurrently.
ess_threshold (float): Effective sample size below which resampling is triggered, given as a fraction of `n_particles`.
visualization_dir (str): Path to the directory where the visualization server is running.
json_file (str): Path to the JSON file to save the record of the inference, relative to `visualization_dir` if provided.
resampling_method (str): One of 'multinomial', 'stratified', 'systematic', or 'residual'. Defaults to 'multinomial'.

Returns:
particles (list[llamppl.modeling.Model]): The completed particles after inference.
"""
resample_fn = get_resampling_fn(resampling_method)

particles = [copy.deepcopy(model) for _ in range(n_particles)]
await asyncio.gather(*[p.start() for p in particles])
record = visualization_dir is not None or json_file is not None
Expand All @@ -48,19 +57,19 @@ async def smc_standard(

# Normalize weights
W = np.array([p.weight for p in particles])
if np.all(W == -np.inf):
# All particles dead — skip resampling to avoid NaNs
did_resample = False
continue
w_sum = logsumexp(W)
normalized_weights = W - w_sum

# Resample if necessary
if -logsumexp(normalized_weights * 2) < np.log(ess_threshold) + np.log(
n_particles
):
# Alternative implementation uses a multinomial distribution and only makes n-1 copies, reusing existing one, but fine for now
probs = np.exp(normalized_weights)
ancestor_indices = [
np.random.choice(range(len(particles)), p=probs)
for _ in range(n_particles)
]
ancestor_indices = resample_fn(probs).tolist()

if record:
# Sort the ancestor indices
Expand Down
3 changes: 3 additions & 0 deletions llamppl/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@


def logsumexp(nums):
    """Return log(sum(exp(nums))) computed with the max-shift trick."""
    arr = np.asarray(nums)
    # Guard: if every entry is -inf, the max-shift below would compute
    # (-inf) - (-inf) = NaN; the correct answer is simply -inf.
    if np.all(arr == -np.inf):
        return -np.inf
    peak = np.max(arr)
    # Shift by the maximum so the largest exponentiated term is exp(0) = 1,
    # avoiding overflow.
    return peak + np.log(np.sum(np.exp(arr - peak)))

Expand Down
Loading
Loading