diff --git a/llamppl/inference/resampling.py b/llamppl/inference/resampling.py
new file mode 100644
index 0000000..3a45430
--- /dev/null
+++ b/llamppl/inference/resampling.py
@@ -0,0 +1,156 @@
+import numpy as np
+
+
+def systematic_resample(weights):
+    """Systematic resampling with a single random offset.
+
+    Generates N equally spaced probe points in [0, 1] with a single
+    random offset between 0 and 1/N. Indices are derived by mapping
+    these points through the inverse CDF of the categorical distribution
+    defined by the weights. Each index i is resampled exactly
+    floor(N * w_i) or ceil(N * w_i) times.
+
+    Unlike stratified and residual resampling, systematic resampling
+    is not provably lower-variance than multinomial in all cases;
+    see Douc et al. (2005), Sec. 3.4: https://arxiv.org/abs/cs/0507025
+
+    Adapted from FilterPy (R. Labbe):
+    https://filterpy.readthedocs.io/en/latest/monte_carlo/resampling.html
+
+    Args:
+        weights (array-like): Normalized probability weights summing to 1.
+
+    Returns:
+        ndarray: Integer array of ancestor indices.
+    """
+    N = len(weights)
+    positions = (np.random.random() + np.arange(N)) / N
+
+    indexes = np.zeros(N, "i")
+    cumulative_sum = np.cumsum(weights)
+    cumulative_sum[-1] = 1.0  # avoid round-off errors
+    i, j = 0, 0
+    while i < N:
+        if positions[i] < cumulative_sum[j]:
+            indexes[i] = j
+            i += 1
+        else:
+            j += 1
+    return indexes
+
+
+def stratified_resample(weights):
+    """Stratified resampling with one random draw per stratum.
+
+    Divides [0, 1] into N equal strata and draws one uniform point
+    independently within each, so that consecutive points are between
+    0 and 2/N apart. Indices are derived by mapping these points
+    through the inverse CDF of the categorical distribution defined
+    by the weights.
+
+    Adapted from FilterPy (R. Labbe):
+    https://filterpy.readthedocs.io/en/latest/monte_carlo/resampling.html
+
+    Args:
+        weights (array-like): Normalized probability weights summing to 1.
+
+    Returns:
+        ndarray: Integer array of ancestor indices.
+    """
+    N = len(weights)
+    positions = (np.random.random(N) + np.arange(N)) / N
+
+    indexes = np.zeros(N, "i")
+    cumulative_sum = np.cumsum(weights)
+    cumulative_sum[-1] = 1.0
+    i, j = 0, 0
+    while i < N:
+        if positions[i] < cumulative_sum[j]:
+            indexes[i] = j
+            i += 1
+        else:
+            j += 1
+    return indexes
+
+
+def residual_resample(weights):
+    """Residual resampling: deterministic floor copies + multinomial remainder.
+
+    Takes floor(N * w_i) copies of each particle deterministically,
+    then resamples the remaining slots from the fractional residuals
+    using multinomial resampling.
+
+    Adapted from FilterPy (R. Labbe):
+    https://filterpy.readthedocs.io/en/latest/monte_carlo/resampling.html
+
+    Args:
+        weights (array-like): Normalized probability weights summing to 1.
+
+    Returns:
+        ndarray: Integer array of ancestor indices.
+    """
+    N = len(weights)
+    weights = np.asarray(weights, dtype=float)
+    indexes = np.zeros(N, "i")
+
+    # Deterministic copies
+    num_copies = np.floor(N * weights).astype(int)
+    k = 0
+    for i in range(N):
+        for _ in range(num_copies[i]):
+            indexes[k] = i
+            k += 1
+
+    # Multinomial resample on the residual
+    if k < N:
+        residual = weights * N - num_copies
+        residual /= residual.sum()
+        cumulative_sum = np.cumsum(residual)
+        cumulative_sum[-1] = 1.0
+        indexes[k:N] = np.searchsorted(cumulative_sum, np.random.random(N - k))
+
+    return indexes
+
+
+def multinomial_resample(weights):
+    """Multinomial resampling: independent categorical draws.
+
+    Each of the N ancestor indices is drawn independently from the
+    categorical distribution defined by the weights.
+
+    Args:
+        weights (array-like): Normalized probability weights summing to 1.
+
+    Returns:
+        ndarray: Integer array of ancestor indices.
+ """ + N = len(weights) + return np.random.choice(N, size=N, replace=True, p=weights) + + +RESAMPLING_METHODS = { + "systematic": systematic_resample, + "stratified": stratified_resample, + "residual": residual_resample, + "multinomial": multinomial_resample, +} + + +def get_resampling_fn(method): + """Get a resampling function by name. + + Args: + method (str): One of 'systematic', 'stratified', 'residual', 'multinomial'. + + Returns: + callable: Resampling function that takes weights and returns indices. + + Raises: + ValueError: If method is not recognized. + """ + if method not in RESAMPLING_METHODS: + raise ValueError( + f"Unknown resampling method '{method}'. " + f"Must be one of: {', '.join(RESAMPLING_METHODS.keys())}" + ) + return RESAMPLING_METHODS[method] diff --git a/llamppl/inference/smc_standard.py b/llamppl/inference/smc_standard.py index 99f0623..ff801f2 100644 --- a/llamppl/inference/smc_standard.py +++ b/llamppl/inference/smc_standard.py @@ -6,13 +6,19 @@ from ..util import logsumexp from .smc_record import SMCRecord +from .resampling import get_resampling_fn async def smc_standard( - model, n_particles, ess_threshold=0.5, visualization_dir=None, json_file=None + model, + n_particles, + ess_threshold=0.5, + visualization_dir=None, + json_file=None, + resampling_method="multinomial", ): """ - Standard sequential Monte Carlo algorithm with multinomial resampling. + Standard sequential Monte Carlo algorithm. Args: model (llamppl.modeling.Model): The model to perform inference on. @@ -20,10 +26,13 @@ async def smc_standard( ess_threshold (float): Effective sample size below which resampling is triggered, given as a fraction of `n_particles`. visualization_dir (str): Path to the directory where the visualization server is running. json_file (str): Path to the JSON file to save the record of the inference, relative to `visualization_dir` if provided. + resampling_method (str): One of 'multinomial', 'stratified', 'systematic', or 'residual'. 
Defaults to 'multinomial'. Returns: particles (list[llamppl.modeling.Model]): The completed particles after inference. """ + resample_fn = get_resampling_fn(resampling_method) + particles = [copy.deepcopy(model) for _ in range(n_particles)] await asyncio.gather(*[p.start() for p in particles]) record = visualization_dir is not None or json_file is not None @@ -48,6 +57,10 @@ async def smc_standard( # Normalize weights W = np.array([p.weight for p in particles]) + if np.all(W == -np.inf): + # All particles dead — skip resampling to avoid NaNs + did_resample = False + continue w_sum = logsumexp(W) normalized_weights = W - w_sum @@ -55,12 +68,8 @@ async def smc_standard( if -logsumexp(normalized_weights * 2) < np.log(ess_threshold) + np.log( n_particles ): - # Alternative implementation uses a multinomial distribution and only makes n-1 copies, reusing existing one, but fine for now probs = np.exp(normalized_weights) - ancestor_indices = [ - np.random.choice(range(len(particles)), p=probs) - for _ in range(n_particles) - ] + ancestor_indices = resample_fn(probs).tolist() if record: # Sort the ancestor indices diff --git a/llamppl/util.py b/llamppl/util.py index 2831945..5021ee2 100644 --- a/llamppl/util.py +++ b/llamppl/util.py @@ -4,6 +4,9 @@ def logsumexp(nums): + nums = np.asarray(nums) + if np.all(nums == -np.inf): + return -np.inf m = np.max(nums) return np.log(np.sum(np.exp(nums - m))) + m diff --git a/tests/test_resampling.py b/tests/test_resampling.py new file mode 100644 index 0000000..d95199d --- /dev/null +++ b/tests/test_resampling.py @@ -0,0 +1,187 @@ +"""Tests for resampling methods.""" + +import numpy as np +import pytest +from llamppl.inference.resampling import ( + multinomial_resample, + stratified_resample, + systematic_resample, + residual_resample, + get_resampling_fn, + RESAMPLING_METHODS, +) + + +ALL_METHODS = list(RESAMPLING_METHODS.keys()) + + +class TestResamplingBasics: + """Basic correctness tests shared across all methods.""" + + 
@pytest.mark.parametrize("method", ALL_METHODS) + def test_returns_correct_length(self, method): + fn = get_resampling_fn(method) + weights = np.array([0.2, 0.3, 0.1, 0.15, 0.25]) + np.random.seed(42) + indices = fn(weights) + assert len(indices) == len(weights) + + @pytest.mark.parametrize("method", ALL_METHODS) + def test_indices_in_range(self, method): + fn = get_resampling_fn(method) + weights = np.array([0.1, 0.2, 0.3, 0.4]) + np.random.seed(42) + indices = fn(weights) + assert all(0 <= i < len(weights) for i in indices) + + @pytest.mark.parametrize("method", ALL_METHODS) + def test_returns_integer_array(self, method): + fn = get_resampling_fn(method) + weights = np.array([0.5, 0.3, 0.2]) + np.random.seed(42) + indices = fn(weights) + assert indices.dtype in (np.int32, np.int64, np.intp) + + @pytest.mark.parametrize("method", ALL_METHODS) + def test_degenerate_single_particle(self, method): + fn = get_resampling_fn(method) + weights = np.array([1.0]) + indices = fn(weights) + assert len(indices) == 1 + assert indices[0] == 0 + + @pytest.mark.parametrize("method", ALL_METHODS) + def test_degenerate_all_weight_on_one(self, method): + """When one particle has all the weight, all indices should be that particle.""" + fn = get_resampling_fn(method) + weights = np.array([0.0, 0.0, 1.0, 0.0, 0.0]) + np.random.seed(42) + indices = fn(weights) + assert all(i == 2 for i in indices) + + @pytest.mark.parametrize("method", ALL_METHODS) + def test_uniform_weights(self, method): + """With uniform weights, each particle should appear roughly once.""" + fn = get_resampling_fn(method) + N = 100 + weights = np.ones(N) / N + np.random.seed(42) + indices = fn(weights) + counts = np.bincount(indices, minlength=N) + # With uniform weights, no particle should have 0 or >3 copies + # (very unlikely with any reasonable method) + assert counts.min() >= 0 + assert counts.max() <= 5 + + @pytest.mark.parametrize("method", ALL_METHODS) + def test_expected_counts_large_n(self, method): + 
"""With many runs, empirical counts should match weights.""" + fn = get_resampling_fn(method) + weights = np.array([0.5, 0.3, 0.1, 0.1]) + N = len(weights) + n_runs = 5000 + total_counts = np.zeros(N) + for seed in range(n_runs): + np.random.seed(seed) + indices = fn(weights) + total_counts += np.bincount(indices, minlength=N) + empirical = total_counts / (n_runs * N) + # Should be close to the weights + np.testing.assert_allclose(empirical, weights, atol=0.02) + + +class TestLowVarianceMethods: + """Tests specific to low-variance resampling (stratified, systematic, residual).""" + + @pytest.mark.parametrize("method", ["stratified", "systematic", "residual"]) + def test_lower_variance_than_multinomial(self, method): + """Low-variance methods should have less resampling variance.""" + weights = np.array([0.4, 0.3, 0.2, 0.05, 0.05]) + N = len(weights) + n_runs = 2000 + + def variance_of_counts(fn): + all_counts = [] + for seed in range(n_runs): + np.random.seed(seed) + indices = fn(weights) + counts = np.bincount(indices, minlength=N) + all_counts.append(counts) + return np.var(all_counts, axis=0).sum() + + multi_var = variance_of_counts(multinomial_resample) + method_var = variance_of_counts(get_resampling_fn(method)) + assert method_var <= multi_var, ( + f"{method} variance ({method_var:.4f}) should be <= " + f"multinomial ({multi_var:.4f})" + ) + + @pytest.mark.parametrize("method", ["stratified", "systematic"]) + def test_guaranteed_representation(self, method): + """Particles with weight >= 1/N must appear at least once.""" + fn = get_resampling_fn(method) + # Particle 0 has weight 0.5 > 1/5 = 0.2, must always appear + weights = np.array([0.5, 0.2, 0.15, 0.1, 0.05]) + for seed in range(100): + np.random.seed(seed) + indices = fn(weights) + counts = np.bincount(indices, minlength=len(weights)) + assert counts[0] >= 1, f"Particle 0 (w=0.5) missing with seed={seed}" + + +class TestSystematic: + """Tests specific to systematic resampling.""" + + def 
test_evenly_spaced(self): + """Systematic uses exactly 1/N spacing.""" + weights = np.array([0.25, 0.25, 0.25, 0.25]) + np.random.seed(0) + indices = systematic_resample(weights) + # With uniform weights, should get exactly [0, 1, 2, 3] + np.testing.assert_array_equal(sorted(indices), [0, 1, 2, 3]) + + def test_deterministic_given_seed(self): + weights = np.array([0.4, 0.3, 0.2, 0.1]) + np.random.seed(42) + a = systematic_resample(weights) + np.random.seed(42) + b = systematic_resample(weights) + np.testing.assert_array_equal(a, b) + + +class TestResidual: + """Tests specific to residual resampling.""" + + def test_deterministic_floor(self): + """Floor copies should be deterministic.""" + # w = [0.5, 0.3, 0.2], N=10 + # floor(10*w) = [5, 3, 2] = 10, no residual needed + weights = np.array([0.5, 0.3, 0.2]) + N = 10 + # Simulate what residual does with N=10 by repeating weights + weights_10 = np.repeat(weights, 1) # stays same for N=len(weights) + # For the actual test: with 3 particles + np.random.seed(42) + indices = residual_resample(weights) + counts = np.bincount(indices, minlength=3) + # With N=3: floor(3*[0.5, 0.3, 0.2]) = [1, 0, 0], residual = [0.5, 0.9, 0.6] + # At least 1 copy of particle 0 guaranteed + assert counts[0] >= 1 + + +class TestGetResamplingFn: + """Tests for the get_resampling_fn helper.""" + + def test_valid_methods(self): + for name in ["multinomial", "stratified", "systematic", "residual"]: + fn = get_resampling_fn(name) + assert callable(fn) + + def test_invalid_method(self): + with pytest.raises(ValueError, match="Unknown resampling method"): + get_resampling_fn("invalid") + + def test_all_methods_registered(self): + assert set(RESAMPLING_METHODS.keys()) == { + "multinomial", "stratified", "systematic", "residual" + }