From 711fe527d45d5fc8a38c28283c0f8ea74472cd31 Mon Sep 17 00:00:00 2001 From: samuki Date: Fri, 30 Jan 2026 17:34:23 +0100 Subject: [PATCH 1/6] Start adding ensemble --- genlm/control/potential/__init__.py | 6 + genlm/control/potential/built_in/__init__.py | 4 + genlm/control/potential/built_in/ensemble.py | 515 +++++++++++++++++++ 3 files changed, 525 insertions(+) create mode 100644 genlm/control/potential/built_in/ensemble.py diff --git a/genlm/control/potential/__init__.py b/genlm/control/potential/__init__.py index b35572c..3e0c805 100644 --- a/genlm/control/potential/__init__.py +++ b/genlm/control/potential/__init__.py @@ -14,6 +14,9 @@ BoolFSA, JsonSchema, CanonicalTokenization, + Ensemble, + convert_to_weighted_logop, + ByteEnsemble, ) __all__ = [ @@ -28,6 +31,9 @@ "WFSA", "BoolFSA", "CanonicalTokenization", + "Ensemble", + "convert_to_weighted_logop", + "ByteEnsemble", "AutoBatchedPotential", "MultiProcPotential", "Coerced", diff --git a/genlm/control/potential/built_in/__init__.py b/genlm/control/potential/built_in/__init__.py index 8656469..0cde80a 100644 --- a/genlm/control/potential/built_in/__init__.py +++ b/genlm/control/potential/built_in/__init__.py @@ -4,6 +4,7 @@ from .json import JsonSchema from .canonical import CanonicalTokenization from .bytellm import ByteLLM +from .ensemble import Ensemble, convert_to_weighted_logop, ByteEnsemble __all__ = [ "PromptedLLM", @@ -14,4 +15,7 @@ "WFSA", "BoolFSA", "CanonicalTokenization", + "Ensemble", + "convert_to_weighted_logop", + "ByteEnsemble", ] diff --git a/genlm/control/potential/built_in/ensemble.py b/genlm/control/potential/built_in/ensemble.py new file mode 100644 index 0000000..97221ee --- /dev/null +++ b/genlm/control/potential/built_in/ensemble.py @@ -0,0 +1,515 @@ +import asyncio +import warnings +import numpy as np +from typing import Callable, List, Literal, Union +from collections import defaultdict + +from arsenal.maths import logsumexp + +from genlm.control.potential.base import Potential + + +class Ensemble(Potential): + """An ensemble potential combining two language models using a weighted operation. + + The Ensemble class creates a potential that combines log-probabilities from two + base potentials (typically language models) using a specified weighted operation + (e.g., weighted geometric mean, arithmetic mean, min, max, etc.). + + Args: + p1: First potential (language model) + p2: Second potential (language model) + op: Operation name (e.g., "sum", "prod", "min", "max", "harmonic", or power means) + a: Weighting parameter between 0 and 1 (default 0.5 for equal weighting). + When a=0.5, models are weighted equally. For a != 0.5, the combination + is weighted: a * model1 + (1-a) * model2 + + Attributes: + p1: First potential + p2: Second potential + op: Weighted log operation function + p1_vocab_idxs: Indices mapping unified vocabulary to p1's vocabulary + p2_vocab_idxs: Indices mapping unified vocabulary to p2's vocabulary + + Example: + ```python + from genlm.control import PromptedLLM, Ensemble + + # Create two language model potentials + p1 = PromptedLLM.from_name("gpt2") + p2 = PromptedLLM.from_name("gpt2") + + # Create an ensemble with weighted geometric mean (a=0.5) + ensemble = Ensemble(p1, p2, op="prod", a=0.5) + + # Use ensemble in sampling + logw = await ensemble.prefix(context) + ``` + + Note: + The Ensemble class handles vocabulary alignment automatically. Both potentials + must have compatible vocabularies (typically the same tokenizer). The logw_next + method is not implemented for Ensemble; instead, use logws_next to get separate + log weights from each model, or use batch_logw_next for batched operations. + """ + + def __init__(self, p1, p2, op, a=0.5): + self.p1 = p1 + self.p2 = p2 + self.op = convert_to_weighted_logop(op, a) + vocab = list(set(p1.vocab + p2.vocab)) + super().__init__(vocabulary=vocab) + + self.p1_vocab_idxs = [self.p1.lookup[x] for x in self.vocab_eos] + self.p2_vocab_idxs = [self.p2.lookup[x] for x in self.vocab_eos] + assert self.p1_vocab_idxs == self.p2_vocab_idxs + + async def prefix(self, context): + """Compute log weights for the prefix using both potentials. + + Args: + context: The context tokens + + Returns: + Combined log weight from both potentials using the ensemble operation + """ + p1_logw, p2_logw = await asyncio.gather( + self.p1.prefix(context), self.p2.prefix(context) + ) + return self.op(p1_logw, p2_logw) + + async def complete(self, context): + """Compute completion log weights using both potentials. + + Args: + context: The context tokens + + Returns: + Combined completion log weight from both potentials + """ + p1_logw, p2_logw = await asyncio.gather( + self.p1.complete(context), self.p2.complete(context) + ) + return self.op(p1_logw, p2_logw) + + async def logws_next(self, context): + """Get log weights from both potentials separately. + + This method returns the log weights from both underlying potentials + without combining them. Useful for custom combination logic. + + Args: + context: The context tokens + + Returns: + Tuple of (p1_logw_next, p2_logw_next) + """ + return await asyncio.gather( + self.p1.logw_next(context), self.p2.logw_next(context) + ) + + async def logw_next(self, context): + """Not implemented for Ensemble class. + + Raises: + NotImplementedError: Always raised. Use logws_next or batch_logw_next instead. + """ + raise NotImplementedError("logw_next is not implemented for Ensemble class.") + + async def batch_logw_next(self, contexts): + """Batched version of logw_next for Ensemble. + + This enables batching when multiple particles need to be extended during SMC, + which can significantly improve performance when using PromptedLLM with + batch_logw_next support. + + Args: + contexts: List of context token sequences + + Returns: + List of LazyWeights objects, one per context, containing the combined + log weights from both potentials + + Note: + This method is only used if the Ensemble is wrapped in AutoBatchedPotential or + called directly with multiple contexts. EnsembleTokenSampler calls p1.logw_next() + and p2.logw_next() directly, so for batching in EnsembleTokenSampler, wrap p1 + and p2 in AutoBatchedPotential before creating the Ensemble. + """ + # Get batched log weights from both potentials + Ws1, Ws2 = await asyncio.gather( + self.p1.batch_logw_next(contexts), self.p2.batch_logw_next(contexts) + ) + # Combine using the ensemble operation + return [ + self.make_lazy_weights( + self.op( + Ws1[n].weights[self.p1_vocab_idxs], + Ws2[n].weights[self.p2_vocab_idxs], + ) + ) + for n in range(len(contexts)) + ] + + +def split_with_atomic_tokens( + data: bytes, atomic_tokens: list[bytes] +) -> list[Union[int, bytes]]: + """ + Splits a bytestring into a list of either individual bytes (as integers) or atomic tokens (as bytes), + depending on whether the current position matches an atomic token. + + Args: + data (bytes): The input byte string to split. + atomic_tokens (list[bytes]): A list of byte substrings that are treated as indivisible atomic tokens. + + Returns: + list[Union[int, bytes]]: A list where each element is either: + - an atomic token (as bytes) if a match is found at that position, + - or a single byte (as an int) if no atomic token matches. + + Notes: + - Matching is greedy but only left-to-right: at each position, the function checks for atomic token matches + starting from length 1 up to the maximum token length. + - Only the first match (shortest prefix match) is used; longer overlapping tokens may be missed if a shorter + prefix matches first. + - If atomic tokens overlap (e.g., b"A" and b"AB"), a warning is raised and only the shortest prefix match + will be used. + + Example: + >>> split_with_atomic_tokens(b"ABC", [b"A", b"AB"]) + [b'A', 66, 67] # b"AB" is not matched because b"A" matched first + """ + # Detect overlapping atomic tokens + for i, token1 in enumerate(atomic_tokens): + for j, token2 in enumerate(atomic_tokens): + if i != j and (token1.startswith(token2) or token2.startswith(token1)): + warnings.warn( + f"Overlapping atomic tokens detected: {token1!r} and {token2!r}. " + "Only the shortest matching prefix will be used." + ) + break # One warning is enough + + result = [] + i = 0 + token_set = set(atomic_tokens) + max_len = max(len(t) for t in atomic_tokens) if atomic_tokens else 0 + + while i < len(data): + matched = False + for length in range(1, max_len + 1): + fragment = data[i : i + length] + if fragment in token_set: + result.append(fragment) + i += length + matched = True + break + if not matched: + result.append(data[i]) + i += 1 + + return result + + +class ByteEnsemble(Potential): + """ + An ensemble potential combining two language models at the byte level using beam search. + + ByteEnsemble manages synchronized beam states for two language models, enabling efficient + byte-level ensemble sampling. Unlike the standard Ensemble class that works with any + Potential, ByteEnsemble provides direct access to beam states for specialized sampling + strategies like ByteEnsembleTokenSampler. + + Attributes: + p1, p2: The base LM objects (not Potentials, but raw model objects). + op: A function to combine log-probabilities. + data_dict_1, data_dict_2: Beam state caches keyed by context (bytes). + vocabulary: Byte-level vocabulary (list of integers 0-255). + eos_tokens: List of EOS tokens from both models. + + Note: + ByteEnsemble is designed to work with ByteEnsembleTokenSampler for specialized + byte-level ensemble sampling. The prefix() and complete() methods are not fully + implemented as this class is meant to be used with custom sampling strategies + that directly access beam states via get_beam_states(). + + Example: + ```python + from genlm.backend import load_model_by_name + from genlm.bytes import BeamParams + from genlm.control.potential.built_in import ByteEnsemble + + llm1 = load_model_by_name("gpt2") + llm2 = load_model_by_name("gpt2") + + beam_params = BeamParams(K=5, prune_threshold=0.0) + ensemble = await ByteEnsemble.create( + llm1, llm2, + op="prod", + prompt1=b"Hello ", + prompt2=b"Hello ", + a=0.5 + ) + + # Use with ByteEnsembleTokenSampler for sampling + ``` + """ + + def __init__( + self, p1, p2, op: Callable, data_dict_1, data_dict_2, vocab, eos_tokens + ): + self.p1 = p1 + self.p2 = p2 + self.op = op + self.data_dict_1 = data_dict_1 + self.data_dict_2 = data_dict_2 + self.eos_tokens = eos_tokens + super().__init__(vocabulary=vocab) + + @classmethod + async def create( + cls, llm1, llm2, op: str, prompt1: bytes, prompt2: bytes, a: float = 0.5 + ): + """Factory method to initialize beam states from prompts and return a ByteEnsemble instance. + + Args: + llm1: First language model (from genlm.backend) + llm2: Second language model (from genlm.backend) + op: Operation name ('sum', 'prod', 'min', 'max', 'harmonic', or power means) + prompt1: Prompt bytes for first model + prompt2: Prompt bytes for second model + a: Weighting parameter between 0 and 1 (default 0.5 for equal weighting) + + Returns: + ByteEnsemble: Initialized ensemble with beam states ready for sampling + + Raises: + RuntimeError: If beam states become empty after prefill + """ + from genlm.bytes import ByteBeamState, BeamParams + + # Use reasonable beam parameters - K=5 with moderate pruning + # K=1 was too aggressive and caused empty beams + beam_params = BeamParams(K=5, prune_threshold=0.0, verbose=False) + data_dict_1 = defaultdict() + data_dict_2 = defaultdict() + + async def setup(): + # Initialize beams sequentially to avoid overwhelming vLLM with concurrent requests + # This can help prevent "Background loop has errored" errors + beam1 = await ByteBeamState.initial(llm1, beam_params) + beam2 = await ByteBeamState.initial(llm2, beam_params) + # Prefill sequentially as well to reduce concurrent load + beam_state_1 = await beam1.prefill(prompt1) + beam_state_2 = await beam2.prefill(prompt2) + return beam_state_1, beam_state_2 + + beam_state_1, beam_state_2 = await setup() + + # Check if beams are empty after initialization + if len(beam_state_1) == 0: + raise RuntimeError( + f"Beam1 is empty after prefill with prompt of length {len(prompt1)} bytes" + ) + if len(beam_state_2) == 0: + raise RuntimeError( + f"Beam2 is empty after prefill with prompt of length {len(prompt2)} bytes" + ) + + data_dict_1[b""] = beam_state_1 + data_dict_2[b""] = beam_state_2 + + eos_tokens = [ + llm1.byte_vocab[llm1.tokenizer.eos_token_id], + llm2.byte_vocab[llm2.tokenizer.eos_token_id], + ] + + return cls( + llm1, + llm2, + convert_to_weighted_logop(op, a), + data_dict_1, + data_dict_2, + vocab=list(range(256)), + eos_tokens=eos_tokens, + ) + + async def _cleanup_cache(self): + """Remove old entries to avoid cache bloat.""" + max_len = max( + ( + len(split_with_atomic_tokens(k, self.eos_tokens)) + for k in self.data_dict_1 + ), + default=0, + ) + min_len = max_len - 2 + for d in [self.data_dict_1, self.data_dict_2]: + for k in list(d.keys()): + if len(k) < min_len: + del d[k] + + async def get_beam_states(self, context: List[int]): + """Fetch beam states for the current context. + + This method provides direct access to the underlying beam states, which + is used by ByteEnsembleTokenSampler for synchronized beam advancement. + + Args: + context (List[int]): Context as list of byte values + + Returns: + Tuple[ByteBeamState, ByteBeamState]: Beam states from both models + + Raises: + KeyError: If context not found in cache (beam states must be populated + by ByteEnsembleTokenSampler during sampling) + """ + ctx_bytes = bytes(context) + await self._cleanup_cache() + beam1 = self.data_dict_1[ctx_bytes] + beam2 = self.data_dict_2[ctx_bytes] + return beam1, beam2 + + async def prefix(self, context: List[int]): + """Compute prefix weight (not fully implemented). + + ByteEnsemble is designed to be used with ByteEnsembleTokenSampler which + manages weights separately. This method is a stub to satisfy the Potential interface. + + Returns: + None + """ + return None + + async def complete(self, context: List[int]): + """Compute completion weight (not fully implemented). + + ByteEnsemble is designed to be used with ByteEnsembleTokenSampler which + manages weights separately. This method is a stub to satisfy the Potential interface. + + Returns: + None + """ + return None + + +def _power_mean(p: float, a: float): + """Create a weighted power mean operator in log space. + + M_p(x, y; a) = (a * exp(p*x) + (1-a) * exp(p*y))^(1/p) + In log space: (1/p) * logsumexp([log(a) + p*x, log(1-a) + p*y]) + """ + log_a, log_1_minus_a = np.log(a), np.log(1 - a) + return lambda x, y: (1.0 / p) * logsumexp( + [log_a + p * x, log_1_minus_a + p * y], axis=0 + ) + + +def _weighted_extremum(func, a: float): + """Create a weighted min/max operator.""" + + def extremum(x, y, a): + if a <= 0.5: + return (1 - 2 * a) * x + 2 * a * func(x, y) + else: + return (2 * a - 1) * y + 2 * (1 - a) * func(x, y) + + return lambda x, y: extremum(x, y, a) + + +# Map operation names to their power values +_POWER_MEANS = { + "pm5": -5.0, + "pm2.5": -2.5, + "p-2": -2.0, + "pm1.5": -1.5, + "pm0.5": -0.5, + "pm0.25": -0.25, + "p0.25": 0.25, + "p0.5": 0.5, + "p1.5": 1.5, + "p2": 2.0, + "p2.5": 2.5, + "p3": 3.0, + "p5": 5.0, +} + + +def convert_to_weighted_logop( + op: Literal[ + "sum", + "prod", + "min", + "max", + "harmonic", + "pm5", + "pm2.5", + "p-2", + "pm1.5", + "pm0.5", + "pm0.25", + "p0.25", + "p0.5", + "p1.5", + "p2", + "p2.5", + "p3", + "p5", + ], + a: float = 0.5, +): + """Convert a string operation to its weighted log-space equivalent. + + This function takes an operation name and a weighting parameter and returns + a function that combines two log-probability arrays using the specified + weighted operation. + + Args: + op: Operation name. Supported operations include: + - Means: "sum" (arithmetic), "prod" (geometric), "harmonic" + - Extrema: "min", "max" + - Power means: "pm5", "pm2.5", "p-2", "pm1.5", "pm0.5", "pm0.25", + "p0.25", "p0.5", "p1.5", "p2", "p2.5", "p3", "p5" + a: Weighting parameter between 0 and 1. When a=0.5, equal weighting. + For weighted operations: a * model1 + (1-a) * model2 + + Returns: + A function that takes two log-probability arrays and returns their + weighted combination in log space. + + Raises: + ValueError: If a is not between 0 and 1, or if op is not recognized. + + Examples: + >>> op_func = convert_to_weighted_logop("sum", a=0.5) + >>> x = np.log(np.array([0.3, 0.7])) + >>> y = np.log(np.array([0.6, 0.4])) + >>> result = op_func(x, y) # Weighted arithmetic mean in log space + """ + if not 0 < a < 1: + raise ValueError("variable a should be between 0 and 1") + + log_a, log_1_minus_a = np.log(a), np.log(1 - a) + + # Power means - all follow the same pattern + if op in _POWER_MEANS: + return _power_mean(_POWER_MEANS[op], a) + + # Basic operations + operations = { + "sum": lambda x, y: logsumexp([x + log_a, y + log_1_minus_a], axis=0), + "prod": lambda x, y: a * x + (1 - a) * y, + "harmonic": lambda x, y: -logsumexp([-x + log_a, -y + log_1_minus_a], axis=0), + "min": _weighted_extremum(np.minimum, a), + "max": _weighted_extremum(np.maximum, a), + } + + if op in operations: + return operations[op] + + # If we get here, operation is invalid + valid_ops = list(operations.keys()) + list(_POWER_MEANS.keys()) + raise ValueError( + f"Invalid operation: {op}. Must be one of {', '.join(repr(o) for o in valid_ops)}." + ) From 1e865e1ebd40534e81abba4ffbff26d93a5dcfd1 Mon Sep 17 00:00:00 2001 From: samuki Date: Sun, 1 Feb 2026 01:36:19 +0100 Subject: [PATCH 2/6] Add byte level ensembling classes --- genlm/control/__init__.py | 14 ++ genlm/control/potential/built_in/ensemble.py | 5 +- genlm/control/sampler/__init__.py | 7 +- genlm/control/sampler/byte_ensemble.py | 220 +++++++++++++++++++ genlm/control/sampler/sequence.py | 122 ++++++++++ pyproject.toml | 1 + 6 files changed, 364 insertions(+), 5 deletions(-) create mode 100644 genlm/control/sampler/byte_ensemble.py diff --git a/genlm/control/__init__.py b/genlm/control/__init__.py index 2af3f83..6b323cc 100644 --- a/genlm/control/__init__.py +++ b/genlm/control/__init__.py @@ -9,13 +9,20 @@ WCFG, JsonSchema, CanonicalTokenization, + Ensemble, + convert_to_weighted_logop, + ByteEnsemble, ) from .sampler import ( SMC, + EnsembleSMC, + Sequences, + SequencesExt, direct_token_sampler, eager_token_sampler, topk_token_sampler, AWRS, + ByteEnsembleTokenSampler, ) from .viz import InferenceVisualizer @@ -23,6 +30,9 @@ "EOS", "EOT", "SMC", + "EnsembleSMC", + "Sequences", + "SequencesExt", "Potential", "PromptedLLM", "ByteLLM", @@ -37,4 +47,8 @@ "eager_token_sampler", "topk_token_sampler", "InferenceVisualizer", + "Ensemble", + "convert_to_weighted_logop", + "ByteEnsemble", + "ByteEnsembleTokenSampler", ] diff --git a/genlm/control/potential/built_in/ensemble.py b/genlm/control/potential/built_in/ensemble.py index 97221ee..1c64142 100644 --- a/genlm/control/potential/built_in/ensemble.py +++ b/genlm/control/potential/built_in/ensemble.py @@ -418,7 +418,6 @@ def extremum(x, y, a): return lambda x, y: extremum(x, y, a) -# Map operation names to their power values _POWER_MEANS = { "pm5": -5.0, "pm2.5": -2.5, @@ -492,11 +491,10 @@ def convert_to_weighted_logop( log_a, log_1_minus_a = np.log(a), np.log(1 - a) - # Power means - all follow the same pattern + # Ensemble operations if op in _POWER_MEANS: return _power_mean(_POWER_MEANS[op], a) - # Basic operations operations = { "sum": lambda x, y: logsumexp([x + log_a, y + log_1_minus_a], axis=0), "prod": lambda x, y: a * x + (1 - a) * y, @@ -508,7 +506,6 @@ def convert_to_weighted_logop( if op in operations: return operations[op] - # If we get here, operation is invalid valid_ops = list(operations.keys()) + list(_POWER_MEANS.keys()) raise ValueError( f"Invalid operation: {op}. Must be one of {', '.join(repr(o) for o in valid_ops)}." diff --git a/genlm/control/sampler/__init__.py b/genlm/control/sampler/__init__.py index ed5bbea..31bddcb 100644 --- a/genlm/control/sampler/__init__.py +++ b/genlm/control/sampler/__init__.py @@ -1,6 +1,6 @@ from .token import DirectTokenSampler, SetTokenSampler, AWRS, TokenSampler from .set import EagerSetSampler, TopKSetSampler -from .sequence import SMC, SequenceModel +from .sequence import SMC, EnsembleSMC, Sequences, SequencesExt, SequenceModel from .unit import ( MultiTokenUnitSampler, BoundaryPredicate, @@ -9,6 +9,7 @@ CFGBoundary, flatten_units, ) +from .byte_ensemble import ByteEnsembleTokenSampler from genlm.control.potential import Potential @@ -73,6 +74,9 @@ def topk_token_sampler(iter_potential, item_potential, K): "TokenSampler", "Importance", "SMC", + "EnsembleSMC", + "Sequences", + "SequencesExt", "SequenceModel", "MultiTokenUnitSampler", "BoundaryPredicate", @@ -80,4 +84,5 @@ def topk_token_sampler(iter_potential, item_potential, K): "FixedLengthBoundary", "CFGBoundary", "flatten_units", + "ByteEnsembleTokenSampler", ] diff --git a/genlm/control/sampler/byte_ensemble.py b/genlm/control/sampler/byte_ensemble.py new file mode 100644 index 0000000..be28a33 --- /dev/null +++ b/genlm/control/sampler/byte_ensemble.py @@ -0,0 +1,220 @@ +from typing import List, Literal, Tuple +from collections import defaultdict + +from cachetools import LRUCache +from arsenal.maths import logsumexp + +from genlm.control.sampler.token import TokenSampler +from genlm.control.util import fast_sample_logprobs +from genlm.control.constant import EOS + + +class ByteEnsembleTokenSampler(TokenSampler): + """ + Token sampler for byte-level ensemble using synchronized beam search. + + This sampler draws from an ensemble of two language models by advancing both + beam states synchronously with the same sampled token. This enables efficient + exploration with proper importance weighting for SMC. + + Unlike standard token samplers, ByteEnsembleTokenSampler: + - Directly accesses and manipulates beam states from ByteEnsemble + - Advances both beams with the same token (synchronized exploration) + - Tracks separate log probabilities for each model + - Uses shaping weights for proper SMC proposals + + Args: + potential (ByteEnsemble): The target byte-level ensemble potential. + proposal (Literal["linear", "abs", "square", "soft n"]): Proposal strategy. + Currently only "linear" is implemented. + n_particles (int): Number of particles for SMC sampling. Defaults to 10. + eos_tokens (List[int]): List of end-of-sequence tokens (as byte values). + max_tokens (int, optional): Maximum number of tokens to generate. + models_equal (bool): Flag indicating whether the two models are identical. + Defaults to False. + + Example: + ```python + from genlm.backend import load_model_by_name + from genlm.bytes import BeamParams + from genlm.control.potential.built_in import ByteEnsemble + from genlm.control.sampler.byte_ensemble import ByteEnsembleTokenSampler + + # Load models + llm1 = load_model_by_name("gpt2") + llm2 = load_model_by_name("gpt2") + + # Create ensemble + ensemble = await ByteEnsemble.create( + llm1, llm2, + op="prod", + prompt1=b"Hello ", + prompt2=b"Hello ", + a=0.5 + ) + + # Create sampler + eos_tokens = [llm1.byte_vocab[llm1.tokenizer.eos_token_id]] + sampler = ByteEnsembleTokenSampler( + ensemble, + max_tokens=100, + eos_tokens=eos_tokens, + n_particles=10 + ) + + # Run SMC sampling + result = await sampler.smc( + n_particles=10, + ess_threshold=0.5, + max_tokens=100 + ) + ``` + """ + + def __init__( + self, + potential, + proposal: Literal["linear", "abs", "square", "soft n"] = "linear", + n_particles: int = 10, + eos_tokens: List[int] = None, + max_tokens: int = None, + models_equal: bool = False, + ): + super().__init__(target=potential) + self.potential = potential + self.proposal = proposal + self.n_particles = n_particles + self.eos_tokens = eos_tokens or [] + self.max_tokens = max_tokens + self.models_equal = models_equal + + # LRU caches for prefix weights + self.prefix_cache_1 = LRUCache(maxsize=3 * n_particles) + self.prefix_cache_2 = LRUCache(maxsize=3 * n_particles) + + # Track final particle probabilities + self.particle_prefix_log_prob_1 = defaultdict(lambda: float("-inf")) + self.particle_prefix_log_prob_2 = defaultdict(lambda: float("-inf")) + + # Init empty context weights + self.prefix_cache_1[()] = 0.0 + self.prefix_cache_2[()] = 0.0 + + async def start_weight(self) -> float: + """Compute the weight of the empty sequence.""" + return 0.0 + + async def sample(self, context: List[int], draw=None) -> Tuple[int, float, float]: + """Sample one token from the ensemble distribution. + + This method: + 1. Fetches beam states for both models at the current context + 2. Gets next-token distributions from both beams + 3. Combines distributions using the ensemble operation + 4. Samples a token using the combined distribution + 5. Advances both beams synchronously with the sampled token + 6. Updates caches with new beam states and weights + + Args: + context (List[int]): Current context as list of byte values + draw (callable, optional): Drawing function (not used, for compatibility) + + Returns: + Tuple[int, float, float]: (token, log_weight, log_prob) + - token: Sampled byte value (or EOS) + - log_weight: Log importance weight for SMC + - log_prob: Log probability under proposal distribution + """ + # Get beam states + beam1, beam2 = await self.potential.get_beam_states(context) + logp_1, logp_2 = await beam1.logp_next(), await beam2.logp_next() + + # Get cached prefix weights + ctx_tuple = tuple(context) + log_context_weight_1 = self.prefix_cache_1[ctx_tuple] + log_context_weight_2 = self.prefix_cache_2[ctx_tuple] + + # Compute next-token weights + logws1 = log_context_weight_1 + logp_1.ps + logws2 = log_context_weight_2 + logp_2.ps + + # Compute shaping weight from previous context + log_shaping_weight_prev = ( + 0 + if not context + else self.potential.op(log_context_weight_1, log_context_weight_2) + ) + + # Combine weights using ensemble operation and compute proposal + proposal_weights = self.potential.op(logws1, logws2) - log_shaping_weight_prev + logps = proposal_weights - logsumexp(proposal_weights) + + # Sample token from proposal distribution + token_idx = fast_sample_logprobs(logps)[0] + + # Decode token from trie + token = beam1.states[0].trie.trie.decode[token_idx] + assert ( + token == beam2.states[0].trie.trie.decode[token_idx] + ), "Models must have aligned vocabularies" + + # Advance both beams synchronously with sampled token + next_context = ( + bytes(context + [token]) + if isinstance(token, int) + else bytes(context) + token + ) + self.potential.data_dict_1[next_context] = await (beam1.prune() << token) + self.potential.data_dict_2[next_context] = await (beam2.prune() << token) + + # Update prefix caches + new_ctx_tuple = ctx_tuple + (token,) + self.prefix_cache_1[new_ctx_tuple] = logws1[token_idx] + self.prefix_cache_2[new_ctx_tuple] = logws2[token_idx] + + # Handle EOS tokens + if token in self.eos_tokens: + token = EOS + + # Store final particle weights if sequence is complete + if token == EOS or (self.max_tokens and len(ctx_tuple) + 1 == self.max_tokens): + self.particle_prefix_log_prob_1[ctx_tuple + (token,)] = logws1[token_idx] + self.particle_prefix_log_prob_2[ctx_tuple + (token,)] = logws2[token_idx] + + # Return token, importance weight, and proposal log prob + return token, proposal_weights[token_idx] - logps[token_idx], logps[token_idx] + + async def smc( + self, + n_particles: int, + ess_threshold: float, + max_tokens: int, + critic=None, + **kwargs, + ): + """Run Sequential Monte Carlo inference with byte-level ensemble. + + This method requires EnsembleSMC to be available in the sampler.sequence module. + If not available, falls back to standard SMC. + + Args: + n_particles (int): Number of particles to maintain + ess_threshold (float): ESS threshold for resampling (0-1) + max_tokens (int): Maximum tokens per sequence + critic (Potential, optional): Critic potential for guided sampling + **kwargs: Additional arguments passed to SMC + + Returns: + Sequences or SequencesExt: Generated sequences with weights + + Raises: + ImportError: If required SMC components are not available + """ + from genlm.control.sampler.sequence import EnsembleSMC + + return await EnsembleSMC(self, critic)( + n_particles=n_particles, + ess_threshold=ess_threshold, + max_tokens=max_tokens, + **kwargs, + ) diff --git a/genlm/control/sampler/sequence.py b/genlm/control/sampler/sequence.py index fff7dc6..a777928 100644 --- a/genlm/control/sampler/sequence.py +++ b/genlm/control/sampler/sequence.py @@ -346,3 +346,125 @@ def _unpack_particles(particles): ), ) return contexts, logws + + +class EnsembleSMC(SMC): + """EnsembleSMC is a specialized version of SMC for ensemble sampling. + + This class extends SMC to track individual model weights when using + ByteEnsemble with ByteEnsembleTokenSampler. It returns SequencesExt + which includes log_prefix_weights_1 and log_prefix_weights_2. + + Args: + unit_sampler (TokenSampler): The sampler that generates tokens. + critic (Potential, optional): A potential function that guides the generation process. + """ + + async def __call__( + self, + n_particles, + ess_threshold, + max_tokens, + verbosity=0, + json_path=None, + **kwargs, + ): + """Generate sequences using SMC with ensemble weight tracking. + + Args: + n_particles (int): Number of particles to maintain. + ess_threshold (float): ESS threshold for resampling (0-1). + max_tokens (int): Maximum tokens to generate. + verbosity (int, optional): Verbosity level (0=silent, 1=verbose). + json_path (str, optional): Path to save inference visualization data. + **kwargs: Additional arguments for smc_standard. + + Returns: + (SequencesExt): Sequences with individual model weights. + """ + try: + original_max_tokens = self.model.max_tokens + original_verbosity = self.model.verbosity + original_twist_with_critic = self.model.twist_with_critic + self.model.max_tokens = max_tokens + self.model.verbosity = verbosity + self.model.twist_with_critic = ess_threshold > 0 + + particles = await smc_standard( + model=self.model, + n_particles=n_particles, + ess_threshold=ess_threshold, + json_file=json_path, + **kwargs, + ) + finally: + self.model.max_tokens = original_max_tokens + self.model.verbosity = original_verbosity + self.model.twist_with_critic = original_twist_with_critic + + # Extract individual model weights if available + log_prefix_weights_1 = [] + log_prefix_weights_2 = [] + + if hasattr(self.unit_sampler, "particle_prefix_log_prob_1"): + for p in particles: + ctx_tuple = tuple(p.token_ctx) + log_prefix_weights_1.append( + self.unit_sampler.particle_prefix_log_prob_1.get( + ctx_tuple, float("-inf") + ) + ) + + if hasattr(self.unit_sampler, "particle_prefix_log_prob_2"): + for p in particles: + ctx_tuple = tuple(p.token_ctx) + log_prefix_weights_2.append( + self.unit_sampler.particle_prefix_log_prob_2.get( + ctx_tuple, float("-inf") + ) + ) + + contexts, logws = map( + list, + zip( + *[ + (p.token_ctx, float("-inf") if np.isnan(p.weight) else p.weight) + for p in particles + ] + ), + ) + + return SequencesExt( + contexts, + logws, + log_prefix_weights_1, + log_prefix_weights_2, + ) + + +@dataclass +class SequencesExt(Sequences): + """Extended Sequences container for ensemble sampling. + + This class extends Sequences to track individual model weights + when using ByteEnsemble with ByteEnsembleTokenSampler. + + Args: + contexts (list): List of token sequences. + log_weights (list): Log importance weights for each sequence. + log_prefix_weights_1 (list): Log weights from first model. + log_prefix_weights_2 (list): Log weights from second model. + """ + + log_prefix_weights_1: list = None + log_prefix_weights_2: list = None + + def __post_init__(self): + super().__post_init__() + if self.log_prefix_weights_1 is not None: + if not isinstance(self.log_prefix_weights_1, np.ndarray): + self.log_prefix_weights_1 = np.array(self.log_prefix_weights_1) + + if self.log_prefix_weights_2 is not None: + if not isinstance(self.log_prefix_weights_2, np.ndarray): + self.log_prefix_weights_2 = np.array(self.log_prefix_weights_2) diff --git a/pyproject.toml b/pyproject.toml index 76fe16d..9a8f914 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "torch", "json-stream", "jsonschema[format-nongpl]", + "cachetools", ] [project.optional-dependencies] From f9b81163ccbb73a5bb8dd35ccee239a0698ff8ae Mon Sep 17 00:00:00 2001 From: samuki Date: Sun, 1 Feb 2026 18:23:54 +0100 Subject: [PATCH 3/6] Adjust parameters and add tests --- genlm/control/potential/built_in/ensemble.py | 33 +- genlm/control/sampler/byte_ensemble.py | 3 +- tests/potential/test_ensemble.py | 694 +++++++++++++++++++ 3 files changed, 721 insertions(+), 9 deletions(-) create mode 100644 tests/potential/test_ensemble.py diff --git a/genlm/control/potential/built_in/ensemble.py b/genlm/control/potential/built_in/ensemble.py index 1c64142..b637937 100644 --- a/genlm/control/potential/built_in/ensemble.py +++ b/genlm/control/potential/built_in/ensemble.py @@ -7,6 +7,7 @@ from arsenal.maths import logsumexp from genlm.control.potential.base import Potential +from genlm.bytes import ByteBeamState, BeamParams class Ensemble(Potential): @@ -57,6 +58,16 @@ def __init__(self, p1, p2, op, a=0.5): self.p1 = p1 self.p2 = p2 self.op = convert_to_weighted_logop(op, a) + + # Warn if potentials have different vocabularies + if set(p1.vocab) != set(p2.vocab): + warnings.warn( + "Ensemble is being used with potentials that have different vocabularies. " + "Consider using ByteEnsemble instead.", + UserWarning, + stacklevel=2, + ) + vocab = list(set(p1.vocab + p2.vocab)) super().__init__(vocabulary=vocab) @@ -188,7 +199,7 @@ def split_with_atomic_tokens( f"Overlapping atomic tokens detected: {token1!r} and {token2!r}. " "Only the shortest matching prefix will be used." ) - break # One warning is enough + break result = [] i = 0 @@ -268,7 +279,16 @@ def __init__( @classmethod async def create( - cls, llm1, llm2, op: str, prompt1: bytes, prompt2: bytes, a: float = 0.5 + cls, + llm1, + llm2, + op: str, + prompt1: bytes, + prompt2: bytes, + a: float = 0.5, + K: int = 5, + prune_threshold: float = 0.0, + verbose: bool = False, ): """Factory method to initialize beam states from prompts and return a ByteEnsemble instance. @@ -279,6 +299,9 @@ async def create( prompt1: Prompt bytes for first model prompt2: Prompt bytes for second model a: Weighting parameter between 0 and 1 (default 0.5 for equal weighting) + K: Beam width for beam search (default 5) + prune_threshold: Threshold for pruning low-probability beams (default 0.0) + verbose: Whether to print verbose beam search output (default False) Returns: ByteEnsemble: Initialized ensemble with beam states ready for sampling @@ -286,17 +309,13 @@ async def create( Raises: RuntimeError: If beam states become empty after prefill """ - from genlm.bytes import ByteBeamState, BeamParams - # Use reasonable beam parameters - K=5 with moderate pruning - # K=1 was too aggressive and caused empty beams - beam_params = BeamParams(K=5, prune_threshold=0.0, verbose=False) + beam_params = BeamParams(K=K, prune_threshold=prune_threshold, verbose=verbose) data_dict_1 = defaultdict() data_dict_2 = defaultdict() async def setup(): # Initialize beams sequentially to avoid overwhelming vLLM with concurrent requests - # This can help prevent "Background loop has errored" errors beam1 = await ByteBeamState.initial(llm1, beam_params) beam2 = await ByteBeamState.initial(llm2, beam_params) # Prefill sequentially as well to reduce concurrent load diff --git a/genlm/control/sampler/byte_ensemble.py b/genlm/control/sampler/byte_ensemble.py index be28a33..c267964 100644 --- a/genlm/control/sampler/byte_ensemble.py +++ b/genlm/control/sampler/byte_ensemble.py @@ -7,6 +7,7 @@ from genlm.control.sampler.token import TokenSampler from genlm.control.util import fast_sample_logprobs from genlm.control.constant import EOS +from genlm.control.sampler.sequence import EnsembleSMC class ByteEnsembleTokenSampler(TokenSampler): @@ -210,8 +211,6 @@ async def smc( Raises: ImportError: If required SMC components are not available """ - from genlm.control.sampler.sequence import EnsembleSMC - return await EnsembleSMC(self, critic)( n_particles=n_particles, ess_threshold=ess_threshold, diff --git a/tests/potential/test_ensemble.py b/tests/potential/test_ensemble.py new file mode 100644 index 0000000..e8dc807 --- /dev/null +++ b/tests/potential/test_ensemble.py @@ -0,0 +1,694 @@ +import pytest +import numpy as np +from genlm.backend import load_model_by_name +from genlm.control import ( + Ensemble, + ByteEnsemble, + ByteEnsembleTokenSampler, + Potential, + PromptedLLM, + convert_to_weighted_logop, +) +from conftest import MockPotential + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def mock_vocab(): + """Simple vocabulary for testing.""" + return ["a", "b", "c", "d"] + + +@pytest.fixture +def mock_potential_1(mock_vocab): + """Create a mock potential with predefined probabilities.""" + logws = np.log([0.4, 0.3, 0.2, 0.1, 0.001]) + return MockPotential(vocab=mock_vocab, next_token_logws=logws) + + +@pytest.fixture +def mock_potential_2(mock_vocab): + """Create a second mock potential with different probabilities.""" + logws = np.log([0.1, 0.2, 0.3, 0.4, 0.001]) + return MockPotential(vocab=mock_vocab, next_token_logws=logws) + + +# ============================================================================ +# Test Basic Initialization & API +# ============================================================================ + + +@pytest.mark.asyncio +async def test_ensemble_initialization(mock_potential_1, mock_potential_2): + """Test that Ensemble initializes correctly.""" + ensemble = Ensemble(mock_potential_1, mock_potential_2, op="prod", a=0.5) + assert isinstance(ensemble, Potential) + assert ensemble.p1 is mock_potential_1 + assert ensemble.p2 is mock_potential_2 + assert len(ensemble.vocab) == 4 + + +@pytest.mark.asyncio +async def test_ensemble_logws_next(mock_potential_1, mock_potential_2): + """Test that logws_next returns separate log weights from both potentials.""" + ensemble = Ensemble(mock_potential_1, mock_potential_2, op="prod", a=0.5) + p1_logw, p2_logw = await ensemble.logws_next([]) + assert hasattr(p1_logw, "weights") + assert hasattr(p2_logw, "weights") + assert p1_logw.weights.shape == p2_logw.weights.shape + + +@pytest.mark.asyncio +async def test_ensemble_logw_next_raises(mock_potential_1, mock_potential_2): + """Test that logw_next raises NotImplementedError.""" + ensemble = Ensemble(mock_potential_1, mock_potential_2, op="prod", a=0.5) + with pytest.raises(NotImplementedError): + await ensemble.logw_next([]) + + +@pytest.mark.asyncio +async def test_ensemble_batch_logw_next(mock_potential_1, mock_potential_2): + """Test batch_logw_next combines weights from both potentials.""" + ensemble = Ensemble(mock_potential_1, mock_potential_2, op="prod", a=0.5) + results = await ensemble.batch_logw_next([[], ["a"]]) + assert len(results) == 2 + for result in results: + assert hasattr(result, "weights") + assert len(result.weights) == len(ensemble.vocab_eos) + + +@pytest.mark.asyncio +async def test_ensemble_prefix_geometric_mean(mock_potential_1, mock_potential_2): + """Test ensemble prefix with product.""" + ensemble = Ensemble(mock_potential_1, mock_potential_2, op="prod", a=0.5) + logw = await ensemble.prefix([]) + # For product with a=0.5: result = 0.5 * log(p1) + 0.5 * log(p2) + p1_logw = await mock_potential_1.prefix([]) + p2_logw = await mock_potential_2.prefix([]) + expected = 0.5 * p1_logw + 0.5 * p2_logw + np.testing.assert_allclose(logw, expected, rtol=1e-5) + + +@pytest.mark.asyncio +async def test_ensemble_prefix_arithmetic_mean(mock_potential_1, mock_potential_2): + """Test ensemble prefix with sum.""" + ensemble = Ensemble(mock_potential_1, mock_potential_2, op="sum", a=0.5) + logw = await ensemble.prefix([]) + # For sum with a=0.5: result = log(0.5 * exp(p1) + 0.5 * exp(p2)) + p1_logw = await mock_potential_1.prefix([]) + p2_logw = await mock_potential_2.prefix([]) + expected = np.logaddexp(np.log(0.5) + p1_logw, np.log(0.5) + p2_logw) + np.testing.assert_allclose(logw, expected, rtol=1e-5) + + +@pytest.mark.asyncio +async def test_ensemble_with_context(): + """Test ensemble operations with non-empty context.""" + mock_vocab = ["a", "b", "c", "d"] + logws1 = np.log([0.4, 0.3, 0.2, 0.1, 0.001]) + logws2 = np.log([0.1, 0.2, 0.3, 0.4, 0.001]) + p1 = MockPotential(vocab=mock_vocab, next_token_logws=logws1) + p2 = MockPotential(vocab=mock_vocab, next_token_logws=logws2) + ensemble = Ensemble(p1, p2, op="prod", a=0.5) + context = ["a", "b"] + logw = await ensemble.prefix(context) + assert isinstance(logw, (int, float, np.number)) + assert np.isfinite(logw) + complete_logw = await ensemble.complete(context) + assert isinstance(complete_logw, (int, float, np.number)) + assert np.isfinite(complete_logw) + + +@pytest.mark.asyncio +async def test_ensemble_consistency(): + """Test that ensemble computations are consistent across multiple calls.""" + mock_vocab = ["x", "y"] + logws = np.log([0.6, 0.4, 0.001]) + p1 = MockPotential(vocab=mock_vocab, next_token_logws=logws) + p2 = MockPotential(vocab=mock_vocab, next_token_logws=logws) + ensemble = Ensemble(p1, p2, op="prod", a=0.5) + results = [await ensemble.prefix(["x"]) for _ in range(3)] + assert all(np.isclose(results[0], r) for r in results) + + +# ============================================================================ +# Test Ensemble Operations +# ============================================================================ + + +@pytest.mark.asyncio +async def test_token_ensemble_different_operations(): + """Test different ensemble operations with mock models.""" + mock_vocab = ["a", "b", "c"] + logws1 = np.array([np.log(0.6), np.log(0.3), np.log(0.1), -100.0]) + logws2 = np.array([np.log(0.2), np.log(0.5), np.log(0.3), -100.0]) + + p1 = MockPotential(vocab=mock_vocab, next_token_logws=logws1) + p2 = MockPotential(vocab=mock_vocab, next_token_logws=logws2) + + # Prod: 0.5 * log(p1) + 0.5 * log(p2) + ensemble_prod = Ensemble(p1, p2, op="prod", a=0.5) + result_prod = await ensemble_prod.batch_logw_next([[]]) + for tok in mock_vocab: + idx_ens = ensemble_prod.lookup[tok] + idx_p1 = p1.lookup[tok] + expected_val = 0.5 * logws1[idx_p1] + 0.5 * logws2[idx_p1] + assert result_prod[0].weights[idx_ens] == pytest.approx(expected_val, abs=1e-6) + + # Sum: log(0.5 * exp(log(p1)) + 0.5 * exp(log(p2))) + ensemble_sum = Ensemble(p1, p2, op="sum", a=0.5) + result_sum = await ensemble_sum.batch_logw_next([[]]) + + for tok in mock_vocab: + idx_ens = ensemble_sum.lookup[tok] + idx_p1 = p1.lookup[tok] + expected_val = np.logaddexp( + np.log(0.5) + logws1[idx_p1], np.log(0.5) + logws2[idx_p1] + ) + assert result_sum[0].weights[idx_ens] == pytest.approx(expected_val, abs=1e-6) + + # Min: For a=0.5, should be close to actual minimum + ensemble_min = Ensemble(p1, p2, op="min", a=0.5) + result_min = await ensemble_min.batch_logw_next([[]]) + + for tok in mock_vocab: + idx_ens = ensemble_min.lookup[tok] + idx_p1 = p1.lookup[tok] + expected_val = np.minimum(logws1[idx_p1], logws2[idx_p1]) + assert result_min[0].weights[idx_ens] == pytest.approx(expected_val, abs=0.5) + + # Max: For a=0.5, should be close to actual maximum + ensemble_max = Ensemble(p1, p2, op="max", a=0.5) + result_max = await ensemble_max.batch_logw_next([[]]) + + for tok in mock_vocab: + idx_ens = ensemble_max.lookup[tok] + idx_p1 = p1.lookup[tok] + expected_val = np.maximum(logws1[idx_p1], logws2[idx_p1]) + assert result_max[0].weights[idx_ens] == pytest.approx(expected_val, abs=0.5) + + +@pytest.mark.asyncio +async def test_ensemble_all_power_means(): + """Test all supported power mean operations.""" + mock_vocab = ["a", "b"] + logws1 = np.log([0.7, 0.3, 0.001]) # Model 1 prefers 'a' + logws2 = np.log([0.3, 0.7, 0.001]) # Model 2 prefers 'b' + + p1 = MockPotential(vocab=mock_vocab, next_token_logws=logws1) + p2 = MockPotential(vocab=mock_vocab, next_token_logws=logws2) + + power_means = [ + "pm5", + "pm2.5", + "p-2", + "pm1.5", + "pm0.5", + "pm0.25", + "p0.25", + "p0.5", + "p1.5", + "p2", + "p2.5", + "p3", + "p5", + ] + for op in power_means: + ensemble = Ensemble(p1, p2, op=op, a=0.5) + logw = await ensemble.prefix([]) + assert isinstance(logw, (int, float, np.number)) + assert np.isfinite(logw) + + +# ============================================================================ +# Test Weighting & Parameters +# ============================================================================ + + +@pytest.mark.asyncio +async def test_ensemble_weighting_affects_output(): + """Verify that changing the weighting parameter affects the output.""" + mock_vocab = ["a", "b"] + logws1 = np.array([0.0, -5.0, -100.0]) # Model 1 prefers 'a' + logws2 = np.array([-5.0, 0.0, -100.0]) # Model 2 prefers 'b' + + p1 = MockPotential(vocab=mock_vocab, next_token_logws=logws1) + p2 = MockPotential(vocab=mock_vocab, next_token_logws=logws2) + + ensemble_50 = Ensemble(p1, p2, op="prod", a=0.5) + result_50 = await ensemble_50.batch_logw_next([[]]) + ensemble_80 = Ensemble(p1, p2, op="prod", a=0.8) # Weight (0.8) on model 1 + result_80 = await ensemble_80.batch_logw_next([[]]) + ensemble_20 = Ensemble(p1, p2, op="prod", a=0.2) # Weight (0.2) on model 2 + result_20 = await ensemble_20.batch_logw_next([[]]) + + logws_50 = result_50[0].weights + logws_80 = result_80[0].weights + logws_20 = result_20[0].weights + assert not np.allclose(logws_50, logws_80, rtol=1e-5) + assert not np.allclose(logws_50, logws_20, rtol=1e-5) + assert not np.allclose(logws_80, logws_20, rtol=1e-5) + + a_idx_50 = ensemble_50.lookup["a"] + b_idx_50 = ensemble_50.lookup["b"] + a_idx_80 = ensemble_80.lookup["a"] + b_idx_20 = ensemble_20.lookup["b"] + assert logws_80[a_idx_80] > logws_50[a_idx_50] + assert logws_20[b_idx_20] > logws_50[b_idx_50] + diff_50 = abs(logws_50[a_idx_50] - logws_50[b_idx_50]) + diff_80 = logws_80[a_idx_80] - logws_80[a_idx_80 if a_idx_80 == 0 else 1 - a_idx_80] + diff_20 = logws_20[b_idx_20] - logws_20[b_idx_20 if b_idx_20 == 0 else 1 - b_idx_20] + assert abs(diff_80) > diff_50 or abs(diff_20) > diff_50 + + +@pytest.mark.asyncio +async def test_ensemble_with_differently_conditioned_models(): + """Test ensemble with different weighting simulating different prompt strategies.""" + vocab = ["SELECT", "FROM", "WHERE", "JOIN"] + logws1 = np.array([0.0, -1.0, -2.0, -3.0, -100.0]) + logws2 = np.array([-3.0, -2.0, -1.0, 0.0, -100.0]) + p1 = MockPotential(vocab=vocab, next_token_logws=logws1) + p2 = MockPotential(vocab=vocab, next_token_logws=logws2) + ensemble_balanced = Ensemble( + p1, p2, op="prod", a=0.5 + ) # Ensemble with different weights + ensemble_favor_p1 = Ensemble(p1, p2, op="prod", a=0.7) + + result_balanced = await ensemble_balanced.batch_logw_next([[]]) + result_favor_p1 = await ensemble_favor_p1.batch_logw_next([[]]) + combined_balanced = result_balanced[0].weights + combined_favor_p1 = result_favor_p1[0].weights + + select_idx_bal = ensemble_balanced.lookup[ + "SELECT" + ] # When favoring p1, SELECT is more likely + select_idx_fav = ensemble_favor_p1.lookup["SELECT"] + join_idx_bal = ensemble_balanced.lookup[ + "JOIN" + ] # When favoring p1, JOIN is less likely + join_idx_fav = ensemble_favor_p1.lookup["JOIN"] + + assert combined_favor_p1[select_idx_fav] > combined_balanced[select_idx_bal] + assert combined_favor_p1[join_idx_fav] < combined_balanced[join_idx_bal] + + +# ============================================================================ +# Test Vocabulary Handling +# ============================================================================ + + +@pytest.mark.asyncio +async def test_ensemble_warns_on_different_vocabularies(): + """Test Ensemble warns when using potentials with different vocabularies.""" + vocab1 = ["a", "b", "c", "d"] + vocab2 = ["a", "b", "x", "y"] + logws1 = np.log([0.25, 0.25, 0.25, 0.25, 0.001]) + logws2 = np.log([0.25, 0.25, 0.25, 0.25, 0.001]) + p1 = MockPotential(vocab=vocab1, next_token_logws=logws1) + p2 = MockPotential(vocab=vocab2, next_token_logws=logws2) + with pytest.warns(UserWarning, match="different vocabularies"): + with pytest.raises((KeyError, AssertionError)): + ensemble = Ensemble(p1, p2, op="prod", a=0.5) + + +@pytest.mark.asyncio +async def test_ensemble_vocab_alignment(mock_vocab): + """Test that ensemble handles vocabulary alignment correctly.""" + logws = np.log([0.25, 0.25, 0.25, 0.25, 0.001]) + p1 = MockPotential(vocab=mock_vocab, next_token_logws=logws) + p2 = MockPotential(vocab=mock_vocab, next_token_logws=logws) + + ensemble = Ensemble(p1, p2, op="prod", a=0.5) + + # Vocabulary indices should be correctly aligned + assert len(ensemble.p1_vocab_idxs) == len(ensemble.vocab_eos) + assert len(ensemble.p2_vocab_idxs) == len(ensemble.vocab_eos) + assert ensemble.p1_vocab_idxs == ensemble.p2_vocab_idxs + + +@pytest.mark.asyncio +async def test_ensemble_respects_vocab_alignment(): + """Verify ensemble correctly handles vocabulary alignment with reordering.""" + vocab = ["x", "y", "z"] + logws1 = np.array([0.0, -1.0, -2.0, -100.0]) + logws2 = np.array([-1.0, 0.0, -2.0, -100.0]) + + p1 = MockPotential(vocab=vocab, next_token_logws=logws1) + p2 = MockPotential(vocab=vocab, next_token_logws=logws2) + ensemble = Ensemble(p1, p2, op="prod", a=0.5) + result = await ensemble.batch_logw_next([[]]) + combined = result[0].weights + + # Each token should get correct combined weight + for tok in ["x", "y", "z"]: + ensemble_idx = ensemble.lookup[tok] + p1_idx = p1.lookup[tok] + p2_idx = p2.lookup[tok] + expected = 0.5 * logws1[p1_idx] + 0.5 * logws2[p2_idx] + actual = combined[ensemble_idx] + assert actual == pytest.approx(expected, abs=1e-5), f"Token {tok} mismatch" + + +# ============================================================================ +# Test Integration with Real Models (GPT-2) +# ============================================================================ + + +@pytest.mark.asyncio +async def test_token_ensemble_with_different_prompts(): + """Test token-level Ensemble with different prompts - basic functionality check.""" + llm1 = PromptedLLM.from_name("gpt2") + llm2 = PromptedLLM.from_name("gpt2") + llm1.set_prompt_from_str("Write a SQL query: ") + llm2.set_prompt_from_str("SQL code: ") + + ensemble = Ensemble(llm1, llm2, op="prod", a=0.5) + assert ensemble.p1 is llm1 + assert ensemble.p2 is llm2 + + ensemble_result = await ensemble.batch_logw_next([[]]) + ensemble_logws = ensemble_result[0].weights + + assert len(ensemble_logws) > 0 + assert np.all(np.isfinite(ensemble_logws)) + assert len(ensemble_logws) == len( + ensemble.vocab_eos + ) # Should have same length as vocab + + +@pytest.mark.asyncio +async def test_token_ensemble_complementary_prompts(): + """Test token-level Ensemble combining complementary prompting strategies.""" + llm1 = PromptedLLM.from_name("gpt2") + llm2 = PromptedLLM.from_name("gpt2") + + llm1.set_prompt_from_str("Task: Generate structured SQL.\n") + llm2.set_prompt_from_str("Task: Generate correct SQL.\n") + + ensemble = Ensemble(llm1, llm2, op="prod", a=0.5) + p1_result = await llm1.batch_logw_next([[]]) + p2_result = await llm2.batch_logw_next([[]]) + ensemble_result = await ensemble.batch_logw_next([[]]) + + p1_logws = p1_result[0].weights + p2_logws = p2_result[0].weights + ensemble_logws = ensemble_result[0].weights + + assert not np.allclose(ensemble_logws, p1_logws, rtol=0.1) + assert not np.allclose(ensemble_logws, p2_logws, rtol=0.1) + assert np.all(np.isfinite(ensemble_logws)) + + +# ============================================================================ +# Test ByteEnsemble +# ============================================================================ + + +@pytest.mark.asyncio +async def test_byte_ensemble_creation(): + """Test ByteEnsemble creation with identical prompts.""" + llm1 = load_model_by_name("gpt2", backend="hf") + llm2 = load_model_by_name("gpt2", backend="hf") + prompt = b"Test" + ensemble = await ByteEnsemble.create( + llm1, llm2, op="prod", prompt1=prompt, prompt2=prompt, a=0.5 + ) + assert ensemble.p1 is llm1 + assert ensemble.p2 is llm2 + assert len(ensemble.vocab) == 256 + assert isinstance(ensemble.vocab, list) + assert all(isinstance(v, int) and 0 <= v < 256 for v in ensemble.vocab) + assert b"" in ensemble.data_dict_1 + assert b"" in ensemble.data_dict_2 + + +@pytest.mark.asyncio +async def test_byte_ensemble_different_prompts(): + """Test ByteEnsemble with different prompts - the key use case for ensembling.""" + llm1 = load_model_by_name("gpt2", backend="hf") + llm2 = load_model_by_name("gpt2", backend="hf") + + prompt1 = b"Write a SQL query to find all users: " + prompt2 = b"SQL: Find all users in the database: " + + ensemble = await ByteEnsemble.create( + llm1, llm2, op="prod", prompt1=prompt1, prompt2=prompt2, a=0.5 + ) + assert ensemble.p1 is llm1 + assert ensemble.p2 is llm2 + assert len(ensemble.vocab) == 256 + assert b"" in ensemble.data_dict_1 + assert b"" in ensemble.data_dict_2 + beam1, beam2 = await ensemble.get_beam_states([]) + assert beam1 is not None + assert beam2 is not None + + +@pytest.mark.asyncio +async def test_byte_ensemble_get_beam_states(): + """Test that ByteEnsemble.get_beam_states() provides access to beams.""" + llm1 = load_model_by_name("gpt2", backend="hf") + llm2 = load_model_by_name("gpt2", backend="hf") + ensemble = await ByteEnsemble.create( + llm1, llm2, op="prod", prompt1=b"Hi", prompt2=b"Hello", a=0.5 + ) + beam1, beam2 = await ensemble.get_beam_states([]) + assert beam1 is not None + assert beam2 is not None + assert hasattr(beam1, "states") + assert hasattr(beam2, "states") + assert len(beam1) > 0 + assert len(beam2) > 0 + + +@pytest.mark.asyncio +async def test_byte_ensemble_token_sampler_initialization(): + """Test ByteEnsembleTokenSampler initialization with different prompts.""" + llm1 = load_model_by_name("gpt2", backend="hf") + llm2 = load_model_by_name("gpt2", backend="hf") + ensemble = await ByteEnsemble.create( + llm1, llm2, op="prod", prompt1=b"Answer: ", prompt2=b"Response: ", a=0.5 + ) + eos_tokens = [llm1.byte_vocab[llm1.tokenizer.eos_token_id]] + sampler = ByteEnsembleTokenSampler( + ensemble, max_tokens=50, eos_tokens=eos_tokens, n_particles=5 + ) + assert sampler.potential is ensemble + assert sampler.max_tokens == 50 + assert sampler.eos_tokens == eos_tokens + assert sampler.n_particles == 5 + # check caches + assert () in sampler.prefix_cache_1 + assert () in sampler.prefix_cache_2 + assert sampler.prefix_cache_1[()] == 0.0 + assert sampler.prefix_cache_2[()] == 0.0 + + +@pytest.mark.asyncio +async def test_byte_ensemble_sampler_sample(): + """Test ByteEnsembleTokenSampler samples with different prompts.""" + llm1 = load_model_by_name("gpt2", backend="hf") + llm2 = load_model_by_name("gpt2", backend="hf") + + ensemble = await ByteEnsemble.create( + llm1, llm2, op="prod", prompt1=b"The cat is ", prompt2=b"A cat is ", a=0.5 + ) + eos_tokens = [llm1.byte_vocab[llm1.tokenizer.eos_token_id]] + sampler = ByteEnsembleTokenSampler( + ensemble, max_tokens=10, eos_tokens=eos_tokens, n_particles=3 + ) + token, logw, logp = await sampler.sample([]) + assert isinstance(token, (int, bytes)) + assert isinstance(logw, (int, float, np.number)) + assert isinstance(logp, (int, float, np.number)) + assert np.isfinite(logw) + assert np.isfinite(logp) + if isinstance(token, int): + next_context_bytes = bytes([token]) + else: + next_context_bytes = token + + assert next_context_bytes in ensemble.data_dict_1 + assert next_context_bytes in ensemble.data_dict_2 + + +@pytest.mark.asyncio +async def test_byte_ensemble_weighted_different_prompts(): + """Test ByteEnsemble with unequal weights on different prompts.""" + llm1 = load_model_by_name("gpt2", backend="hf") + llm2 = load_model_by_name("gpt2", backend="hf") + + prompt1 = b"Correct approach: " + prompt2 = b"Alternative: " + + ensemble = await ByteEnsemble.create( + llm1, llm2, op="prod", prompt1=prompt1, prompt2=prompt2, a=0.7 + ) + + assert ensemble.p1 is llm1 + assert ensemble.p2 is llm2 + + eos_tokens = [llm1.byte_vocab[llm1.tokenizer.eos_token_id]] + sampler = ByteEnsembleTokenSampler( + ensemble, max_tokens=5, eos_tokens=eos_tokens, n_particles=2 + ) + token, logw, logp = await sampler.sample([]) + assert isinstance(token, (int, bytes)) + assert np.isfinite(logw) + + +# ============================================================================ +# Test Realistic Ensemble Applications +# ============================================================================ + + +@pytest.mark.asyncio +async def test_ensemble_with_different_model_preferences(): + """Test ensemble where models have same vocab but different preferences.""" + vocab = ["a", "b", "c", "d"] + logws1 = np.array([0.0, -0.5, -2.0, -3.0, -100.0]) # prefers 'a' > 'b' > 'c' > 'd' + logws2 = np.array([-3.0, -2.0, -0.5, 0.0, -100.0]) # prefers 'd' > 'c' > 'b' > 'a' + p1 = MockPotential(vocab=vocab, next_token_logws=logws1) + p2 = MockPotential(vocab=vocab, next_token_logws=logws2) + ensemble = Ensemble(p1, p2, op="prod", a=0.5) + + for tok in vocab: + assert tok in ensemble.vocab + result = await ensemble.batch_logw_next([[]]) + combined = result[0].weights + + for tok in vocab: + ensemble_idx = ensemble.lookup[tok] + p1_idx = p1.lookup[tok] + p2_idx = p2.lookup[tok] + expected = 0.5 * logws1[p1_idx] + 0.5 * logws2[p2_idx] + actual = combined[ensemble_idx] + assert actual == pytest.approx(expected, abs=1e-5), f"Token {tok} mismatch" + + b_idx = ensemble.lookup["b"] + c_idx = ensemble.lookup["c"] + a_idx = ensemble.lookup["a"] + d_idx = ensemble.lookup["d"] + assert combined[b_idx] > min(combined[a_idx], combined[d_idx]) + assert combined[c_idx] > min(combined[a_idx], combined[d_idx]) + + +@pytest.mark.asyncio +async def test_ensemble_with_complementary_knowledge(): + """Test ensemble where models show different performance on different tokens.""" + vocab1 = ["SELECT", "FROM", "WHERE", "LIMIT"] + logws1 = np.array( + [ + np.log(0.4), # SELECT: confident + np.log(0.3), # FROM: confident + np.log(0.2), # WHERE: confident + np.log(0.1), # LIMIT: less confident + -100.0, + ] + ) + vocab2 = ["SELECT", "FROM", "WHERE", "LIMIT"] + logws2 = np.array( + [ + np.log(0.1), # SELECT: not confident + np.log(0.2), # FROM: not confident + np.log(0.3), # WHERE: somewhat confident + np.log(0.4), # LIMIT: very confident + -100.0, + ] + ) + + p1 = MockPotential(vocab=vocab1, next_token_logws=logws1) + p2 = MockPotential(vocab=vocab2, next_token_logws=logws2) + ensemble = Ensemble(p1, p2, op="prod", a=0.5) + result = await ensemble.batch_logw_next([[]]) + combined = result[0].weights + + for tok in vocab1: + idx = ensemble.lookup[tok] + prob = np.exp(combined[idx]) + assert prob > 0.05, f"{tok} should have reasonable probability in ensemble" + assert prob < 0.95, f"{tok} shouldn't dominate in balanced ensemble" + + p1_result = await p1.batch_logw_next([[]]) + p2_result = await p2.batch_logw_next([[]]) + + assert not np.allclose(combined, p1_result[0].weights, rtol=0.1) + assert not np.allclose(combined, p2_result[0].weights, rtol=0.1) + + +@pytest.mark.asyncio +async def test_ensemble_helps_uncertain_model(): + """Test that ensembling helps when one model is uncertain but the other is confident.""" + mock_vocab = ["correct", "wrong1", "wrong2"] + logws1 = np.array( + [np.log(0.33), np.log(0.33), np.log(0.34), -100.0] + ) # Model 1 is uncertain + logws2 = np.array( + [np.log(0.9), np.log(0.05), np.log(0.05), -100.0] + ) # Model 2 is confident + + p1 = MockPotential(vocab=mock_vocab, next_token_logws=logws1) + p2 = MockPotential(vocab=mock_vocab, next_token_logws=logws2) + ensemble = Ensemble(p1, p2, op="prod", a=0.5) + result = await ensemble.batch_logw_next([[]]) + combined = result[0].weights + + correct_idx = ensemble.lookup["correct"] + wrong1_idx = ensemble.lookup["wrong1"] + # ensemble should favor 'correct' more than model 1 alone + p1_result = await p1.batch_logw_next([[]]) + p1_logws = p1_result[0].weights + p1_correct = p1_logws[p1.lookup["correct"]] + p1_wrong1 = p1_logws[p1.lookup["wrong1"]] + ensemble_correct = combined[correct_idx] + ensemble_wrong1 = combined[wrong1_idx] + + # Ensemble should have stronger preference for 'correct' than uncertain model 1 + p1_gap = p1_correct - p1_wrong1 + ensemble_gap = ensemble_correct - ensemble_wrong1 + assert ( + ensemble_gap > p1_gap + ), "Ensemble should be more confident than uncertain model" + + # But less confident than model 2 alone + p2_result = await p2.batch_logw_next([[]]) + p2_logws = p2_result[0].weights + p2_gap = p2_logws[p2.lookup["correct"]] - p2_logws[p2.lookup["wrong1"]] + assert ( + ensemble_gap < p2_gap + ), "Ensemble should be less confident than very confident model" + + +# ============================================================================ +# Test Utility Functions +# ============================================================================ + + +def test_convert_to_weighted_logop_invalid_a(): + """Test that invalid 'a' parameter raises ValueError.""" + with pytest.raises(ValueError, match="variable a should be between 0 and 1"): + convert_to_weighted_logop("prod", a=1.5) + + with pytest.raises(ValueError, match="variable a should be between 0 and 1"): + convert_to_weighted_logop("prod", a=-0.1) + + +def test_convert_to_weighted_logop_invalid_op(): + """Test that invalid operation raises ValueError.""" + with pytest.raises(ValueError, match="Invalid operation"): + convert_to_weighted_logop("invalid_op", a=0.5) + + +def test_convert_to_weighted_logop_operations(): + """Test convert_to_weighted_logop returns correct operation for prod.""" + x = np.log(np.array([0.3, 0.7])) + y = np.log(np.array([0.6, 0.4])) + + # Test prod with analytical verification + op_prod = convert_to_weighted_logop("prod", a=0.5) + result_prod = op_prod(x, y) + expected_prod = 0.5 * x + 0.5 * y + np.testing.assert_allclose(result_prod, expected_prod, rtol=1e-5) From 6ec539d5c2c11d1b9c2a4abeaceffbf41b7e2253 Mon Sep 17 00:00:00 2001 From: samuki Date: Sun, 1 Feb 2026 22:27:10 +0100 Subject: [PATCH 4/6] Update docstrings and types --- genlm/control/potential/built_in/ensemble.py | 130 ++++++++++++------- genlm/control/sampler/byte_ensemble.py | 12 +- genlm/control/sampler/sequence.py | 18 +-- tests/potential/test_ensemble.py | 2 +- 4 files changed, 101 insertions(+), 61 deletions(-) diff --git a/genlm/control/potential/built_in/ensemble.py b/genlm/control/potential/built_in/ensemble.py index b637937..c4d05b5 100644 --- a/genlm/control/potential/built_in/ensemble.py +++ b/genlm/control/potential/built_in/ensemble.py @@ -1,7 +1,7 @@ import asyncio import warnings import numpy as np -from typing import Callable, List, Literal, Union +from typing import Any, Callable, List, Literal, Tuple, Union from collections import defaultdict from arsenal.maths import logsumexp @@ -18,12 +18,12 @@ class Ensemble(Potential): (e.g., weighted geometric mean, arithmetic mean, min, max, etc.). Args: - p1: First potential (language model) - p2: Second potential (language model) - op: Operation name (e.g., "sum", "prod", "min", "max", "harmonic", or power means) - a: Weighting parameter between 0 and 1 (default 0.5 for equal weighting). - When a=0.5, models are weighted equally. For a != 0.5, the combination - is weighted: a * model1 + (1-a) * model2 + p1 (Potential): First potential (language model) + p2 (Potential): Second potential (language model) + op (str): Operation name (e.g., "sum", "prod", "min", "max", "harmonic", or power means) + a (float): Weighting parameter between 0 and 1 (default 0.5 for equal weighting). + When a=0.5, models are weighted equally. For a != 0.5, the combination + is weighted: a * model1 + (1-a) * model2 Attributes: p1: First potential @@ -54,7 +54,13 @@ class Ensemble(Potential): log weights from each model, or use batch_logw_next for batched operations. """ - def __init__(self, p1, p2, op, a=0.5): + def __init__( + self, + p1: Potential, + p2: Potential, + op: str, + a: float = 0.5, + ): self.p1 = p1 self.p2 = p2 self.op = convert_to_weighted_logop(op, a) @@ -75,51 +81,51 @@ def __init__(self, p1, p2, op, a=0.5): self.p2_vocab_idxs = [self.p2.lookup[x] for x in self.vocab_eos] assert self.p1_vocab_idxs == self.p2_vocab_idxs - async def prefix(self, context): + async def prefix(self, context: List[str]) -> float: """Compute log weights for the prefix using both potentials. Args: - context: The context tokens + context (List[str]): The context tokens Returns: - Combined log weight from both potentials using the ensemble operation + float: Combined log weight from both potentials using the ensemble operation """ p1_logw, p2_logw = await asyncio.gather( self.p1.prefix(context), self.p2.prefix(context) ) return self.op(p1_logw, p2_logw) - async def complete(self, context): + async def complete(self, context: List[str]) -> float: """Compute completion log weights using both potentials. Args: - context: The context tokens + context (List[str]): The context tokens Returns: - Combined completion log weight from both potentials + float: Combined completion log weight from both potentials """ p1_logw, p2_logw = await asyncio.gather( self.p1.complete(context), self.p2.complete(context) ) return self.op(p1_logw, p2_logw) - async def logws_next(self, context): + async def logws_next(self, context: List[str]) -> Tuple[Any, Any]: """Get log weights from both potentials separately. This method returns the log weights from both underlying potentials without combining them. Useful for custom combination logic. Args: - context: The context tokens + context (List[str]): The context tokens Returns: - Tuple of (p1_logw_next, p2_logw_next) + Tuple[Any, Any]: Tuple of (p1_logw_next, p2_logw_next) """ return await asyncio.gather( self.p1.logw_next(context), self.p2.logw_next(context) ) - async def logw_next(self, context): + async def logw_next(self, context: List[str]): """Not implemented for Ensemble class. Raises: @@ -127,7 +133,7 @@ async def logw_next(self, context): """ raise NotImplementedError("logw_next is not implemented for Ensemble class.") - async def batch_logw_next(self, contexts): + async def batch_logw_next(self, contexts: List[List[str]]) -> List[Any]: """Batched version of logw_next for Ensemble. This enables batching when multiple particles need to be extended during SMC, @@ -135,11 +141,11 @@ async def batch_logw_next(self, contexts): batch_logw_next support. Args: - contexts: List of context token sequences + contexts (List[List[str]]): List of context token sequences Returns: - List of LazyWeights objects, one per context, containing the combined - log weights from both potentials + List[Any]: List of LazyWeights objects, one per context, containing the combined + log weights from both potentials Note: This method is only used if the Ensemble is wrapped in AutoBatchedPotential or @@ -267,7 +273,14 @@ class ByteEnsemble(Potential): """ def __init__( - self, p1, p2, op: Callable, data_dict_1, data_dict_2, vocab, eos_tokens + self, + p1: Any, + p2: Any, + op: Callable, + data_dict_1: dict, + data_dict_2: dict, + vocab: List[int], + eos_tokens: List[bytes], ): self.p1 = p1 self.p2 = p2 @@ -280,8 +293,8 @@ def __init__( @classmethod async def create( cls, - llm1, - llm2, + llm1: Any, + llm2: Any, op: str, prompt1: bytes, prompt2: bytes, @@ -289,19 +302,19 @@ async def create( K: int = 5, prune_threshold: float = 0.0, verbose: bool = False, - ): + ) -> "ByteEnsemble": """Factory method to initialize beam states from prompts and return a ByteEnsemble instance. Args: - llm1: First language model (from genlm.backend) - llm2: Second language model (from genlm.backend) - op: Operation name ('sum', 'prod', 'min', 'max', 'harmonic', or power means) - prompt1: Prompt bytes for first model - prompt2: Prompt bytes for second model - a: Weighting parameter between 0 and 1 (default 0.5 for equal weighting) - K: Beam width for beam search (default 5) - prune_threshold: Threshold for pruning low-probability beams (default 0.0) - verbose: Whether to print verbose beam search output (default False) + llm1 (Any): First language model (from genlm.backend) + llm2 (Any): Second language model (from genlm.backend) + op (str): Operation name ('sum', 'prod', 'min', 'max', 'harmonic', or power means) + prompt1 (bytes): Prompt bytes for first model + prompt2 (bytes): Prompt bytes for second model + a (float): Weighting parameter between 0 and 1 (default 0.5 for equal weighting) + K (int): Beam width for beam search (default 5) + prune_threshold (float): Threshold for pruning low-probability beams (default 0.0) + verbose (bool): Whether to print verbose beam search output (default False) Returns: ByteEnsemble: Initialized ensemble with beam states ready for sampling @@ -368,7 +381,9 @@ async def _cleanup_cache(self): if len(k) < min_len: del d[k] - async def get_beam_states(self, context: List[int]): + async def get_beam_states( + self, context: List[int] + ) -> Tuple["ByteBeamState", "ByteBeamState"]: """Fetch beam states for the current context. This method provides direct access to the underlying beam states, which @@ -382,7 +397,7 @@ async def get_beam_states(self, context: List[int]): Raises: KeyError: If context not found in cache (beam states must be populated - by ByteEnsembleTokenSampler during sampling) + by ByteEnsembleTokenSampler during sampling) """ ctx_bytes = bytes(context) await self._cleanup_cache() @@ -390,23 +405,29 @@ async def get_beam_states(self, context: List[int]): beam2 = self.data_dict_2[ctx_bytes] return beam1, beam2 - async def prefix(self, context: List[int]): + async def prefix(self, context: List[int]) -> None: """Compute prefix weight (not fully implemented). ByteEnsemble is designed to be used with ByteEnsembleTokenSampler which manages weights separately. This method is a stub to satisfy the Potential interface. + Args: + context (List[int]): The context as list of byte values + Returns: None """ return None - async def complete(self, context: List[int]): + async def complete(self, context: List[int]) -> None: """Compute completion weight (not fully implemented). ByteEnsemble is designed to be used with ByteEnsembleTokenSampler which manages weights separately. This method is a stub to satisfy the Potential interface. + Args: + context (List[int]): The context as list of byte values + Returns: None """ @@ -418,6 +439,13 @@ def _power_mean(p: float, a: float): M_p(x, y; a) = (a * exp(p*x) + (1-a) * exp(p*y))^(1/p) In log space: (1/p) * logsumexp([log(a) + p*x, log(1-a) + p*y]) + + Args: + p (float): Power parameter for the power mean + a (float): Weighting parameter between 0 and 1 + + Returns: + Callable: Function that computes weighted power mean in log space """ log_a, log_1_minus_a = np.log(a), np.log(1 - a) return lambda x, y: (1.0 / p) * logsumexp( @@ -426,7 +454,15 @@ def _power_mean(p: float, a: float): def _weighted_extremum(func, a: float): - """Create a weighted min/max operator.""" + """Create a weighted min/max operator. + + Args: + func (Callable): The extremum function (np.minimum or np.maximum) + a (float): Weighting parameter between 0 and 1 + + Returns: + Callable: Function that computes weighted extremum + """ def extremum(x, y, a): if a <= 0.5: @@ -476,7 +512,7 @@ def convert_to_weighted_logop( "p5", ], a: float = 0.5, -): +) -> Callable[[np.ndarray, np.ndarray], np.ndarray]: """Convert a string operation to its weighted log-space equivalent. This function takes an operation name and a weighting parameter and returns @@ -484,17 +520,17 @@ def convert_to_weighted_logop( weighted operation. Args: - op: Operation name. Supported operations include: + op (str): Operation name. Supported operations include: - Means: "sum" (arithmetic), "prod" (geometric), "harmonic" - Extrema: "min", "max" - Power means: "pm5", "pm2.5", "p-2", "pm1.5", "pm0.5", "pm0.25", - "p0.25", "p0.5", "p1.5", "p2", "p2.5", "p3", "p5" - a: Weighting parameter between 0 and 1. When a=0.5, equal weighting. - For weighted operations: a * model1 + (1-a) * model2 + "p0.25", "p0.5", "p1.5", "p2", "p2.5", "p3", "p5" + a (float): Weighting parameter between 0 and 1. When a=0.5, equal weighting. + For weighted operations: a * model1 + (1-a) * model2 Returns: - A function that takes two log-probability arrays and returns their - weighted combination in log space. + Callable[[np.ndarray, np.ndarray], np.ndarray]: A function that takes two + log-probability arrays and returns their weighted combination in log space. Raises: ValueError: If a is not between 0 and 1, or if op is not recognized. diff --git a/genlm/control/sampler/byte_ensemble.py b/genlm/control/sampler/byte_ensemble.py index c267964..d0062ec 100644 --- a/genlm/control/sampler/byte_ensemble.py +++ b/genlm/control/sampler/byte_ensemble.py @@ -102,7 +102,11 @@ def __init__( self.prefix_cache_2[()] = 0.0 async def start_weight(self) -> float: - """Compute the weight of the empty sequence.""" + """Compute the weight of the empty sequence. + + Returns: + float: Log weight of the empty sequence (always 0.0) + """ return 0.0 async def sample(self, context: List[int], draw=None) -> Tuple[int, float, float]: @@ -192,7 +196,7 @@ async def smc( max_tokens: int, critic=None, **kwargs, - ): + ) -> "SequencesExt": """Run Sequential Monte Carlo inference with byte-level ensemble. This method requires EnsembleSMC to be available in the sampler.sequence module. @@ -202,11 +206,11 @@ async def smc( n_particles (int): Number of particles to maintain ess_threshold (float): ESS threshold for resampling (0-1) max_tokens (int): Maximum tokens per sequence - critic (Potential, optional): Critic potential for guided sampling + critic (Potential): Critic potential for guided sampling **kwargs: Additional arguments passed to SMC Returns: - Sequences or SequencesExt: Generated sequences with weights + SequencesExt: Generated sequences with weights Raises: ImportError: If required SMC components are not available diff --git a/genlm/control/sampler/sequence.py b/genlm/control/sampler/sequence.py index a777928..4740efe 100644 --- a/genlm/control/sampler/sequence.py +++ b/genlm/control/sampler/sequence.py @@ -362,25 +362,25 @@ class EnsembleSMC(SMC): async def __call__( self, - n_particles, - ess_threshold, - max_tokens, - verbosity=0, - json_path=None, + n_particles: int, + ess_threshold: float, + max_tokens: int, + verbosity: int = 0, + json_path: str = None, **kwargs, - ): + ) -> "SequencesExt": """Generate sequences using SMC with ensemble weight tracking. Args: n_particles (int): Number of particles to maintain. ess_threshold (float): ESS threshold for resampling (0-1). max_tokens (int): Maximum tokens to generate. - verbosity (int, optional): Verbosity level (0=silent, 1=verbose). - json_path (str, optional): Path to save inference visualization data. + verbosity (int): Verbosity level (0=silent, 1=verbose). + json_path (str): Path to save inference visualization data. **kwargs: Additional arguments for smc_standard. Returns: - (SequencesExt): Sequences with individual model weights. + SequencesExt: Sequences with individual model weights. """ try: original_max_tokens = self.model.max_tokens diff --git a/tests/potential/test_ensemble.py b/tests/potential/test_ensemble.py index e8dc807..b26c36d 100644 --- a/tests/potential/test_ensemble.py +++ b/tests/potential/test_ensemble.py @@ -311,7 +311,7 @@ async def test_ensemble_warns_on_different_vocabularies(): p2 = MockPotential(vocab=vocab2, next_token_logws=logws2) with pytest.warns(UserWarning, match="different vocabularies"): with pytest.raises((KeyError, AssertionError)): - ensemble = Ensemble(p1, p2, op="prod", a=0.5) + _ = Ensemble(p1, p2, op="prod", a=0.5) @pytest.mark.asyncio From c3ae593833fe1724c7e3d1e1f9e22ae58e5da559 Mon Sep 17 00:00:00 2001 From: samuki Date: Sun, 1 Feb 2026 23:10:59 +0100 Subject: [PATCH 5/6] Update coverage, simplify EnsembleSMC --- genlm/control/sampler/byte_ensemble.py | 2 +- genlm/control/sampler/sequence.py | 50 ++++------- tests/potential/test_ensemble.py | 110 +++++++++++++++++++++++++ 3 files changed, 125 insertions(+), 37 deletions(-) diff --git a/genlm/control/sampler/byte_ensemble.py b/genlm/control/sampler/byte_ensemble.py index d0062ec..f1fffbe 100644 --- a/genlm/control/sampler/byte_ensemble.py +++ b/genlm/control/sampler/byte_ensemble.py @@ -196,7 +196,7 @@ async def smc( max_tokens: int, critic=None, **kwargs, - ) -> "SequencesExt": + ): """Run Sequential Monte Carlo inference with byte-level ensemble. This method requires EnsembleSMC to be available in the sampler.sequence module. diff --git a/genlm/control/sampler/sequence.py b/genlm/control/sampler/sequence.py index 4740efe..ff8bdde 100644 --- a/genlm/control/sampler/sequence.py +++ b/genlm/control/sampler/sequence.py @@ -382,33 +382,21 @@ async def __call__( Returns: SequencesExt: Sequences with individual model weights. """ - try: - original_max_tokens = self.model.max_tokens - original_verbosity = self.model.verbosity - original_twist_with_critic = self.model.twist_with_critic - self.model.max_tokens = max_tokens - self.model.verbosity = verbosity - self.model.twist_with_critic = ess_threshold > 0 - - particles = await smc_standard( - model=self.model, - n_particles=n_particles, - ess_threshold=ess_threshold, - json_file=json_path, - **kwargs, - ) - finally: - self.model.max_tokens = original_max_tokens - self.model.verbosity = original_verbosity - self.model.twist_with_critic = original_twist_with_critic + sequences = await super().__call__( + n_particles=n_particles, + ess_threshold=ess_threshold, + max_tokens=max_tokens, + verbosity=verbosity, + json_path=json_path, + **kwargs, + ) - # Extract individual model weights if available log_prefix_weights_1 = [] log_prefix_weights_2 = [] if hasattr(self.unit_sampler, "particle_prefix_log_prob_1"): - for p in particles: - ctx_tuple = tuple(p.token_ctx) + for ctx in sequences.contexts: + ctx_tuple = tuple(ctx) log_prefix_weights_1.append( self.unit_sampler.particle_prefix_log_prob_1.get( ctx_tuple, float("-inf") @@ -416,27 +404,17 @@ async def __call__( ) if hasattr(self.unit_sampler, "particle_prefix_log_prob_2"): - for p in particles: - ctx_tuple = tuple(p.token_ctx) + for ctx in sequences.contexts: + ctx_tuple = tuple(ctx) log_prefix_weights_2.append( self.unit_sampler.particle_prefix_log_prob_2.get( ctx_tuple, float("-inf") ) ) - contexts, logws = map( - list, - zip( - *[ - (p.token_ctx, float("-inf") if np.isnan(p.weight) else p.weight) - for p in particles - ] - ), - ) - return SequencesExt( - contexts, - logws, + sequences.contexts, + sequences.log_weights, log_prefix_weights_1, log_prefix_weights_2, ) diff --git a/tests/potential/test_ensemble.py b/tests/potential/test_ensemble.py index b26c36d..2264e74 100644 --- a/tests/potential/test_ensemble.py +++ b/tests/potential/test_ensemble.py @@ -1,5 +1,6 @@ import pytest import numpy as np +from unittest.mock import AsyncMock, MagicMock, patch from genlm.backend import load_model_by_name from genlm.control import ( Ensemble, @@ -8,6 +9,12 @@ Potential, PromptedLLM, convert_to_weighted_logop, + EOS, +) +from genlm.control.sampler.sequence import EnsembleSMC, SequencesExt +from genlm.control.potential.built_in.ensemble import ( + split_with_atomic_tokens, + _weighted_extremum, ) from conftest import MockPotential @@ -692,3 +699,106 @@ def test_convert_to_weighted_logop_operations(): result_prod = op_prod(x, y) expected_prod = 0.5 * x + 0.5 * y np.testing.assert_allclose(result_prod, expected_prod, rtol=1e-5) + + +@pytest.mark.asyncio +async def test_byte_ensemble_token_sampler_start_weight(): + """Test ByteEnsembleTokenSampler.start_weight() returns 0.0.""" + llm = load_model_by_name("gpt2", backend="hf") + ensemble = await ByteEnsemble.create( + llm, llm, op="prod", prompt1=b"Hi", prompt2=b"Hi", a=0.5 + ) + eos_tokens = [llm.byte_vocab[llm.tokenizer.eos_token_id]] + sampler = ByteEnsembleTokenSampler( + ensemble, max_tokens=10, eos_tokens=eos_tokens, n_particles=5 + ) + start_weight = await sampler.start_weight() + assert start_weight == 0.0 + + +@pytest.mark.asyncio +async def test_byte_ensemble_sampler_eos_handling(): + """Test ByteEnsembleTokenSampler properly handles EOS tokens and max_tokens.""" + llm = load_model_by_name("gpt2", backend="hf") + ensemble = await ByteEnsemble.create( + llm, llm, op="prod", prompt1=b"Hi", prompt2=b"Hi", a=0.5 + ) + eos_byte = llm.byte_vocab[llm.tokenizer.eos_token_id] + sampler = ByteEnsembleTokenSampler( + ensemble, max_tokens=5, eos_tokens=[eos_byte], n_particles=5 + ) + _, _, _ = await sampler.sample([]) + if len(sampler.particle_prefix_log_prob_1) > 0: + assert len(sampler.particle_prefix_log_prob_1) >= 0 + assert len(sampler.particle_prefix_log_prob_2) >= 0 + + +@pytest.mark.asyncio +async def test_byte_ensemble_empty_beam_error(): + """Test ByteEnsemble raises RuntimeError when beam is empty after prefill.""" + mock_llm = MagicMock() + mock_llm.byte_vocab = {0: b"a"} + mock_llm.tokenizer.eos_token_id = 0 + empty_beam = MagicMock() + empty_beam.prefill = AsyncMock(return_value=[]) + with patch( + "genlm.control.potential.built_in.ensemble.ByteBeamState.initial", + AsyncMock(return_value=empty_beam), + ): + with pytest.raises(RuntimeError, match="Beam1 is empty after prefill"): + await ByteEnsemble.create( + mock_llm, mock_llm, op="prod", prompt1=b"test", prompt2=b"test", a=0.5 + ) + + +def test_split_with_atomic_tokens_overlapping(): + """Test split_with_atomic_tokens with overlapping tokens.""" + with pytest.warns(UserWarning, match="Overlapping atomic tokens detected"): + result = split_with_atomic_tokens(b"ABC", [b"A", b"AB"]) + assert result == [b"A", 66, 67] + + +def test_split_with_atomic_tokens_no_match(): + """Test split_with_atomic_tokens when no atomic tokens match.""" + result = split_with_atomic_tokens(b"XYZ", [b"A", b"B"]) + assert result == [88, 89, 90] + + +def test_weighted_extremum_different_weights(): + """Test _weighted_extremum with different weight values.""" + x = np.array([-1.0, -2.0, -3.0]) + y = np.array([-2.0, -1.5, -3.5]) + max_op_favoring_y = _weighted_extremum(np.maximum, a=0.7) + result = max_op_favoring_y(x, y) + expected = (2 * 0.7 - 1) * y + 2 * (1 - 0.7) * np.maximum(x, y) + np.testing.assert_allclose(result, expected, rtol=1e-5) + min_op_favoring_x = _weighted_extremum(np.minimum, a=0.3) + result2 = min_op_favoring_x(x, y) + expected2 = (1 - 2 * 0.3) * x + 2 * 0.3 * np.minimum(x, y) + np.testing.assert_allclose(result2, expected2, rtol=1e-5) + + +def test_sequences_ext_post_init(): + """Test SequencesExt.__post_init__ converts lists to numpy arrays.""" + seq = SequencesExt( + contexts=[["a", "b"], ["c", "d"]], + log_weights=[0.1, 0.2], + log_prefix_weights_1=[0.15, 0.25], + log_prefix_weights_2=[0.12, 0.22], + ) + assert isinstance(seq.log_prefix_weights_1, np.ndarray) + assert isinstance(seq.log_prefix_weights_2, np.ndarray) + seq2 = SequencesExt(contexts=[["a"]], log_weights=[0.1], log_prefix_weights_1=None) + assert seq2.log_prefix_weights_1 is None + + +def test_sequences_ext_post_init_with_none(): + """Test SequencesExt.__post_init__ handles None values correctly.""" + seq = SequencesExt( + contexts=[["a", "b"]], + log_weights=[0.1], + log_prefix_weights_1=None, + log_prefix_weights_2=None, + ) + assert seq.log_prefix_weights_1 is None + assert seq.log_prefix_weights_2 is None From 3f0a0b01d107e12b4af3bc976520c8605f4850bc Mon Sep 17 00:00:00 2001 From: samuki Date: Mon, 2 Feb 2026 00:19:47 +0100 Subject: [PATCH 6/6] Coverage updates --- genlm/control/potential/built_in/ensemble.py | 11 +- genlm/control/sampler/byte_ensemble.py | 13 +- genlm/control/sampler/sequence.py | 3 +- tests/potential/test_ensemble.py | 153 ++++++++++++++++++- 4 files changed, 164 insertions(+), 16 deletions(-) diff --git a/genlm/control/potential/built_in/ensemble.py b/genlm/control/potential/built_in/ensemble.py index c4d05b5..cc2a56b 100644 --- a/genlm/control/potential/built_in/ensemble.py +++ b/genlm/control/potential/built_in/ensemble.py @@ -399,7 +399,12 @@ async def get_beam_states( KeyError: If context not found in cache (beam states must be populated by ByteEnsembleTokenSampler during sampling) """ - ctx_bytes = bytes(context) + # Convert context to bytes + if context and isinstance(context[0], bytes): + ctx_bytes = b"".join(context) + else: + ctx_bytes = bytes(context) + await self._cleanup_cache() beam1 = self.data_dict_1[ctx_bytes] beam2 = self.data_dict_2[ctx_bytes] @@ -417,7 +422,7 @@ async def prefix(self, context: List[int]) -> None: Returns: None """ - return None + return None # pragma: no cover async def complete(self, context: List[int]) -> None: """Compute completion weight (not fully implemented). @@ -431,7 +436,7 @@ async def complete(self, context: List[int]) -> None: Returns: None """ - return None + return None # pragma: no cover def _power_mean(p: float, a: float): diff --git a/genlm/control/sampler/byte_ensemble.py b/genlm/control/sampler/byte_ensemble.py index f1fffbe..6142582 100644 --- a/genlm/control/sampler/byte_ensemble.py +++ b/genlm/control/sampler/byte_ensemble.py @@ -1,4 +1,4 @@ -from typing import List, Literal, Tuple +from typing import Any, List, Literal, Tuple from collections import defaultdict from cachetools import LRUCache @@ -195,25 +195,16 @@ async def smc( ess_threshold: float, max_tokens: int, critic=None, - **kwargs, + **kwargs: Any, ): """Run Sequential Monte Carlo inference with byte-level ensemble. - This method requires EnsembleSMC to be available in the sampler.sequence module. - If not available, falls back to standard SMC. - Args: n_particles (int): Number of particles to maintain ess_threshold (float): ESS threshold for resampling (0-1) max_tokens (int): Maximum tokens per sequence critic (Potential): Critic potential for guided sampling **kwargs: Additional arguments passed to SMC - - Returns: - SequencesExt: Generated sequences with weights - - Raises: - ImportError: If required SMC components are not available """ return await EnsembleSMC(self, critic)( n_particles=n_particles, diff --git a/genlm/control/sampler/sequence.py b/genlm/control/sampler/sequence.py index ff8bdde..4fa6af4 100644 --- a/genlm/control/sampler/sequence.py +++ b/genlm/control/sampler/sequence.py @@ -1,4 +1,5 @@ import numpy as np +from typing import Any from genlm.grammar import Float from arsenal.maths import logsumexp from functools import cached_property @@ -367,7 +368,7 @@ async def __call__( max_tokens: int, verbosity: int = 0, json_path: str = None, - **kwargs, + **kwargs: Any, ) -> "SequencesExt": """Generate sequences using SMC with ensemble weight tracking. diff --git a/tests/potential/test_ensemble.py b/tests/potential/test_ensemble.py index 2264e74..82843d0 100644 --- a/tests/potential/test_ensemble.py +++ b/tests/potential/test_ensemble.py @@ -11,7 +11,7 @@ convert_to_weighted_logop, EOS, ) -from genlm.control.sampler.sequence import EnsembleSMC, SequencesExt +from genlm.control.sampler.sequence import EnsembleSMC, SequencesExt, Sequences from genlm.control.potential.built_in.ensemble import ( split_with_atomic_tokens, _weighted_extremum, @@ -778,6 +778,53 @@ def test_weighted_extremum_different_weights(): np.testing.assert_allclose(result2, expected2, rtol=1e-5) +@pytest.mark.asyncio +async def test_ensemble_smc_weight_extraction(): + """Test EnsembleSMC extracts individual model weights correctly.""" + from genlm.control.sampler.token import TokenSampler + + class MockTokenSampler(TokenSampler): + def __init__(self): + self.particle_prefix_log_prob_1 = { + ("a",): -1.0, + ("b",): -2.0, + } + self.particle_prefix_log_prob_2 = { + ("a",): -1.5, + ("b",): -2.5, + } + + async def start_weight(self): + return 0.0 + + async def sample(self, context, draw=None): + return EOS, 0.0, 0.0 + + mock_sampler = MockTokenSampler() + smc = EnsembleSMC(mock_sampler, None) + mock_sequences = Sequences( + contexts=[["a"], ["b"]], + log_weights=[-0.5, -0.7], + ) + + with patch.object( + EnsembleSMC.__bases__[0], + "__call__", + AsyncMock(return_value=mock_sequences), + ): + result = await smc(n_particles=2, ess_threshold=0.5, max_tokens=10) + + assert isinstance(result, SequencesExt) + assert hasattr(result, "log_prefix_weights_1") + assert hasattr(result, "log_prefix_weights_2") + assert len(result.log_prefix_weights_1) == 2 + assert len(result.log_prefix_weights_2) == 2 + assert result.log_prefix_weights_1[0] == -1.0 + assert result.log_prefix_weights_1[1] == -2.0 + assert result.log_prefix_weights_2[0] == -1.5 + assert result.log_prefix_weights_2[1] == -2.5 + + def test_sequences_ext_post_init(): """Test SequencesExt.__post_init__ converts lists to numpy arrays.""" seq = SequencesExt( @@ -802,3 +849,107 @@ def test_sequences_ext_post_init_with_none(): ) assert seq.log_prefix_weights_1 is None assert seq.log_prefix_weights_2 is None + + +@pytest.mark.asyncio +async def test_byte_ensemble_cleanup_cache_deletes_short_keys(): + """Test ByteEnsemble._cleanup_cache() deletes short keys.""" + gpt2 = load_model_by_name("gpt2") + prompt1 = b"The capital of France is" + prompt2 = b"Paris, the capital city of France, is" + ensemble = await ByteEnsemble.create( + gpt2, gpt2, "sum", prompt1, prompt2, a=0.5, K=3 + ) + ensemble.data_dict_1 = { + (1,): "short1", + (1, 2): "short2", + (1, 2, 3): "keep3", + (1, 2, 3, 4): "keep4", + (1, 2, 3, 4, 5): "keep5", + (1, 2, 3, 4, 5, 6): "keep6", + } + ensemble.data_dict_2 = { + (10,): "short1", + (10, 20): "short2", + (10, 20, 30): "keep3", + (10, 20, 30, 40): "keep4", + (10, 20, 30, 40, 50): "keep5", + (10, 20, 30, 40, 50, 60): "keep6", + } + await ensemble._cleanup_cache() + for d in [ensemble.data_dict_1, ensemble.data_dict_2]: + for k in d.keys(): + assert len(k) >= 4, f"Key {k} should have been deleted" + assert len(ensemble.data_dict_1) == 3 + assert len(ensemble.data_dict_2) == 3 + + +@pytest.mark.asyncio +async def test_byte_ensemble_empty_beam_error_covered(): + gpt2 = load_model_by_name("gpt2") + prompt1 = b"\xff\xfe\xfd" # Invalid UTF-8 bytes + prompt2 = b"Test" + with pytest.raises(RuntimeError, match="is empty after prefill"): + await ByteEnsemble.create( + gpt2, gpt2, "sum", prompt1, prompt2, a=0.5, K=1, prune_threshold=100.0 + ) + + +@pytest.mark.asyncio +async def test_byte_ensemble_sampler_stores_particle_weights(): + """Test ByteEnsembleTokenSampler stores particle weights at EOS.""" + gpt2 = load_model_by_name("gpt2") + prompt1 = b"Hi" + prompt2 = b"Hi" + + ensemble = await ByteEnsemble.create( + gpt2, gpt2, "sum", prompt1, prompt2, a=0.5, K=3 + ) + sampler = ByteEnsembleTokenSampler(ensemble, max_tokens=1) + context = [] + token, _, _ = await sampler.sample(context) + new_ctx_tuple = (token,) + assert new_ctx_tuple in sampler.particle_prefix_log_prob_1 + assert new_ctx_tuple in sampler.particle_prefix_log_prob_2 + + +@pytest.mark.asyncio +async def test_byte_ensemble_sampler_eos_conversion(): + """Test ByteEnsembleTokenSampler EOS conversion path.""" + gpt2 = load_model_by_name("gpt2") + prompt1 = b"Hi" + prompt2 = b"Hi" + ensemble = await ByteEnsemble.create( + gpt2, gpt2, "sum", prompt1, prompt2, a=0.5, K=3 + ) + sampler = ByteEnsembleTokenSampler(ensemble) + context = [] + token, _, _ = await sampler.sample(context) + assert token is not None + + +@pytest.mark.asyncio +async def test_byte_ensemble_sampler_smc_calls_ensemble_smc(): + """Test ByteEnsembleTokenSampler.smc() method invokes EnsembleSMC.""" + gpt2 = load_model_by_name("gpt2") + prompt1 = b"Hi" + prompt2 = b"Hi" + ensemble = await ByteEnsemble.create( + gpt2, gpt2, "sum", prompt1, prompt2, a=0.5, K=3 + ) + sampler = ByteEnsembleTokenSampler(ensemble) + assert hasattr(sampler, "smc") + assert callable(sampler.smc) + try: + result = await sampler.smc( + n_particles=1, + ess_threshold=0.5, + max_tokens=1, + critic=None, + ) + assert isinstance(result, SequencesExt) + except (AssertionError, KeyError) as e: + if "Beam is empty" in str(e) or "not found in cache" in str(e): + pass + else: + raise