Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions examples/haiku.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,12 @@ def string_for_serialization(self):


async def run_example(
LLM, poem_title, syllable_pattern=[5, 7, 5], n_particles=20, ess_threshold=0.5
LLM,
poem_title,
syllable_pattern=[5, 7, 5],
n_particles=20,
ess_threshold=0.5,
resampling_method="multinomial",
):
# Construct prompt
prompt = f"""{EXAMPLE_POEMS}
Expand All @@ -132,7 +137,12 @@ async def run_example(

# Run inference
particles = await smc_standard(
haiku_model, n_particles, ess_threshold, "html", "results/haiku.json"
haiku_model,
n_particles,
ess_threshold,
"html",
"results/haiku.json",
resampling_method,
)

return particles
Expand Down
15 changes: 13 additions & 2 deletions examples/hard_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,13 @@ def immutable_properties(self):
3."""


async def run_example(LLM, max_tokens=50, n_particles=20, ess_threshold=0.5):
async def run_example(
LLM,
max_tokens=50,
n_particles=20,
ess_threshold=0.5,
resampling_method="multinomial",
):
# Cache the key value vectors for the prompt.
LLM.cache_kv(LLM.tokenizer.encode(prompt))

Expand All @@ -84,7 +90,12 @@ async def run_example(LLM, max_tokens=50, n_particles=20, ess_threshold=0.5):

# Run inference.
particles = await smc_standard(
constraint_model, n_particles, ess_threshold, "html", "results/output.json"
constraint_model,
n_particles,
ess_threshold,
"html",
"results/output.json",
resampling_method,
)
for p in particles:
print(f"{p.context}")
Expand Down
2 changes: 1 addition & 1 deletion html/results/output.json

Large diffs are not rendered by default.

78 changes: 69 additions & 9 deletions llamppl/inference/smc_standard.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,53 @@
import asyncio
import copy
from datetime import datetime

import warnings
import numpy as np

from ..util import logsumexp
from .smc_record import SMCRecord


def stratified_resample(weights):
    # source: https://filterpy.readthedocs.io/en/latest/_modules/filterpy/monte_carlo/resampling.html#stratified_resample
    """Performs the stratified resampling algorithm used by particle filters.

    This algorithm aims to make selections relatively uniformly across the
    particles. It divides the cumulative sum of the weights into N equal
    divisions, and then selects one particle randomly from each division. This
    guarantees that each sample is between 0 and 2/N apart.

    Args:
        weights (list-like of float): List of normalized weights as floats
            (assumed to sum to 1).

    Returns:
        (ndarray): Array of indexes into the weights defining the resample.
            The index of the zeroth resample is indexes[0], etc.
    """
    N = len(weights)
    # Make N subdivisions, and choose a random position within each one.
    # np.arange (not range) keeps the broadcast purely in numpy.
    positions = (np.random.random(N) + np.arange(N)) / N

    indexes = np.zeros(N, "i")
    cumulative_sum = np.cumsum(weights)
    # Walk positions and the cumulative weight sum in lockstep: each
    # position selects the first particle whose cumulative weight exceeds it.
    i, j = 0, 0
    while i < N:
        if positions[i] < cumulative_sum[j]:
            indexes[i] = j
            i += 1
        else:
            j += 1
    return indexes


async def smc_standard(
model, n_particles, ess_threshold=0.5, visualization_dir=None, json_file=None
model,
n_particles,
ess_threshold=0.5,
visualization_dir=None,
json_file=None,
resampling_method="multinomial",
):
"""
Standard sequential Monte Carlo algorithm with multinomial resampling.
Expand All @@ -20,10 +58,21 @@
ess_threshold (float): Effective sample size below which resampling is triggered, given as a fraction of `n_particles`.
visualization_dir (str): Path to the directory where the visualization server is running.
json_file (str): Path to the JSON file to save the record of the inference, relative to `visualization_dir` if provided.
resampling_method (str): The method to use for resampling. Must be one of 'multinomial' or 'stratified'. Defaults to 'multinomial'.

Returns:
particles (list[llamppl.modeling.Model]): The completed particles after inference.
"""
if not (0 <= ess_threshold <= 1):
raise ValueError(

Check warning on line 67 in llamppl/inference/smc_standard.py

View check run for this annotation

Codecov / codecov/patch

llamppl/inference/smc_standard.py#L67

Added line #L67 was not covered by tests
f"Effective sample size threshold must be between 0 and 1. Got {ess_threshold}."
)

if resampling_method not in ["multinomial", "stratified"]:
raise ValueError(

Check warning on line 72 in llamppl/inference/smc_standard.py

View check run for this annotation

Codecov / codecov/patch

llamppl/inference/smc_standard.py#L72

Added line #L72 was not covered by tests
f"Invalid resampling method: {resampling_method}. Must be one of 'multinomial' or 'stratified'."
)

particles = [copy.deepcopy(model) for _ in range(n_particles)]
await asyncio.gather(*[p.start() for p in particles])
record = visualization_dir is not None or json_file is not None
Expand All @@ -48,19 +97,30 @@

# Normalize weights
W = np.array([p.weight for p in particles])
if np.all(W == -np.inf):
# Avoid a nans in the normalized weights.
# Could just terminate inference here, but keep running for now.
did_resample = False
continue
w_sum = logsumexp(W)
normalized_weights = W - w_sum

# Resample if necessary
if -logsumexp(normalized_weights * 2) < np.log(ess_threshold) + np.log(
n_particles
if ess_threshold > 0 and (
ess_threshold - logsumexp(normalized_weights * 2)
< np.log(ess_threshold) + np.log(n_particles)
):
# Alternative implementation uses a multinomial distribution and only makes n-1 copies, reusing existing one, but fine for now
probs = np.exp(normalized_weights)
ancestor_indices = [
np.random.choice(range(len(particles)), p=probs)
for _ in range(n_particles)
]
if resampling_method == "multinomial":
# Alternative implementation uses a multinomial distribution and only makes n-1 copies, reusing existing one, but fine for now
ancestor_indices = [
np.random.choice(range(len(particles)), p=probs)
for _ in range(n_particles)
]
elif resampling_method == "stratified":
ancestor_indices = stratified_resample(probs)
else:
raise ValueError(f"Invalid resampling method: {resampling_method}.")

Check warning on line 123 in llamppl/inference/smc_standard.py

View check run for this annotation

Codecov / codecov/patch

llamppl/inference/smc_standard.py#L123

Added line #L123 was not covered by tests

if record:
# Sort the ancestor indices
Expand Down
4 changes: 4 additions & 0 deletions llamppl/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@


def logsumexp(nums):
    """Compute log(sum(exp(nums))) in a numerically stable way.

    Args:
        nums (array-like of float): Log-space values to sum.

    Returns:
        (float): log(sum(exp(nums))), or -inf when every entry is -inf.
    """
    nums = np.asarray(nums)
    # Short-circuit the all--inf case: np.max would return -inf and the
    # subtraction below would produce (-inf) - (-inf) = nan warnings.
    if np.all(nums == -np.inf):
        return -np.inf

    # Shift by the max so the largest exponent is exp(0) = 1 (avoids overflow).
    m = np.max(nums)
    return np.log(np.sum(np.exp(nums - m))) + m

Expand Down
24 changes: 22 additions & 2 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,36 @@ def LLM(backend):


@pytest.mark.parametrize("backend", backends)
def test_hard_constraints(LLM, n_particles=5, max_tokens=25):
    """Smoke-test the hard-constraints example with both resampling methods."""
    # Default (multinomial) resampling should yield the full particle set.
    particles = asyncio.run(
        run_hard_constraints(LLM, max_tokens=max_tokens, n_particles=n_particles)
    )
    assert len(particles) == n_particles

    # Stratified resampling should also yield the full particle set.
    particles = asyncio.run(
        run_hard_constraints(
            LLM,
            max_tokens=max_tokens,
            n_particles=n_particles,
            resampling_method="stratified",
        )
    )
    assert len(particles) == n_particles


@pytest.mark.parametrize("backend", backends)
def test_haiku(LLM, n_particles=5):
    """Smoke-test the haiku example with both resampling methods."""
    # Default (multinomial) resampling should yield the full particle set.
    particles = asyncio.run(
        run_haiku(LLM, poem_title="The beauty of testing", n_particles=n_particles)
    )
    assert len(particles) == n_particles

    # Stratified resampling should also yield the full particle set.
    particles = asyncio.run(
        run_haiku(
            LLM,
            poem_title="The beauty of testing",
            n_particles=n_particles,
            resampling_method="stratified",
        )
    )
    assert len(particles) == n_particles