diff --git a/genlm/control/__init__.py b/genlm/control/__init__.py index 2af3f83..6b323cc 100644 --- a/genlm/control/__init__.py +++ b/genlm/control/__init__.py @@ -9,13 +9,20 @@ WCFG, JsonSchema, CanonicalTokenization, + Ensemble, + convert_to_weighted_logop, + ByteEnsemble, ) from .sampler import ( SMC, + EnsembleSMC, + Sequences, + SequencesExt, direct_token_sampler, eager_token_sampler, topk_token_sampler, AWRS, + ByteEnsembleTokenSampler, ) from .viz import InferenceVisualizer @@ -23,6 +30,9 @@ "EOS", "EOT", "SMC", + "EnsembleSMC", + "Sequences", + "SequencesExt", "Potential", "PromptedLLM", "ByteLLM", @@ -37,4 +47,8 @@ "eager_token_sampler", "topk_token_sampler", "InferenceVisualizer", + "Ensemble", + "convert_to_weighted_logop", + "ByteEnsemble", + "ByteEnsembleTokenSampler", ] diff --git a/genlm/control/potential/__init__.py b/genlm/control/potential/__init__.py index b35572c..3e0c805 100644 --- a/genlm/control/potential/__init__.py +++ b/genlm/control/potential/__init__.py @@ -14,6 +14,9 @@ BoolFSA, JsonSchema, CanonicalTokenization, + Ensemble, + convert_to_weighted_logop, + ByteEnsemble, ) __all__ = [ @@ -28,6 +31,9 @@ "WFSA", "BoolFSA", "CanonicalTokenization", + "Ensemble", + "convert_to_weighted_logop", + "ByteEnsemble", "AutoBatchedPotential", "MultiProcPotential", "Coerced", diff --git a/genlm/control/potential/built_in/__init__.py b/genlm/control/potential/built_in/__init__.py index 8656469..0cde80a 100644 --- a/genlm/control/potential/built_in/__init__.py +++ b/genlm/control/potential/built_in/__init__.py @@ -4,6 +4,7 @@ from .json import JsonSchema from .canonical import CanonicalTokenization from .bytellm import ByteLLM +from .ensemble import Ensemble, convert_to_weighted_logop, ByteEnsemble __all__ = [ "PromptedLLM", @@ -14,4 +15,7 @@ "WFSA", "BoolFSA", "CanonicalTokenization", + "Ensemble", + "convert_to_weighted_logop", + "ByteEnsemble", ] diff --git a/genlm/control/potential/built_in/ensemble.py b/genlm/control/potential/built_in/ensemble.py 
new file mode 100644 index 0000000..cc2a56b --- /dev/null +++ b/genlm/control/potential/built_in/ensemble.py @@ -0,0 +1,572 @@ +import asyncio +import warnings +import numpy as np +from typing import Any, Callable, List, Literal, Tuple, Union +from collections import defaultdict + +from arsenal.maths import logsumexp + +from genlm.control.potential.base import Potential +from genlm.bytes import ByteBeamState, BeamParams + + +class Ensemble(Potential): + """An ensemble potential combining two language models using a weighted operation. + + The Ensemble class creates a potential that combines log-probabilities from two + base potentials (typically language models) using a specified weighted operation + (e.g., weighted geometric mean, arithmetic mean, min, max, etc.). + + Args: + p1 (Potential): First potential (language model) + p2 (Potential): Second potential (language model) + op (str): Operation name (e.g., "sum", "prod", "min", "max", "harmonic", or power means) + a (float): Weighting parameter between 0 and 1 (default 0.5 for equal weighting). + When a=0.5, models are weighted equally. For a != 0.5, the combination + is weighted: a * model1 + (1-a) * model2 + + Attributes: + p1: First potential + p2: Second potential + op: Weighted log operation function + p1_vocab_idxs: Indices mapping unified vocabulary to p1's vocabulary + p2_vocab_idxs: Indices mapping unified vocabulary to p2's vocabulary + + Example: + ```python + from genlm.control import PromptedLLM, Ensemble + + # Create two language model potentials + p1 = PromptedLLM.from_name("gpt2") + p2 = PromptedLLM.from_name("gpt2") + + # Create an ensemble with weighted geometric mean (a=0.5) + ensemble = Ensemble(p1, p2, op="prod", a=0.5) + + # Use ensemble in sampling + logw = await ensemble.prefix(context) + ``` + + Note: + The Ensemble class handles vocabulary alignment automatically. Both potentials + must have compatible vocabularies (typically the same tokenizer). 
The logw_next + method is not implemented for Ensemble; instead, use logws_next to get separate + log weights from each model, or use batch_logw_next for batched operations. + """ + + def __init__( + self, + p1: Potential, + p2: Potential, + op: str, + a: float = 0.5, + ): + self.p1 = p1 + self.p2 = p2 + self.op = convert_to_weighted_logop(op, a) + + # Warn if potentials have different vocabularies + if set(p1.vocab) != set(p2.vocab): + warnings.warn( + "Ensemble is being used with potentials that have different vocabularies. " + "Consider using ByteEnsemble instead.", + UserWarning, + stacklevel=2, + ) + + vocab = list(set(p1.vocab + p2.vocab)) + super().__init__(vocabulary=vocab) + + self.p1_vocab_idxs = [self.p1.lookup[x] for x in self.vocab_eos] + self.p2_vocab_idxs = [self.p2.lookup[x] for x in self.vocab_eos] + assert self.p1_vocab_idxs == self.p2_vocab_idxs + + async def prefix(self, context: List[str]) -> float: + """Compute log weights for the prefix using both potentials. + + Args: + context (List[str]): The context tokens + + Returns: + float: Combined log weight from both potentials using the ensemble operation + """ + p1_logw, p2_logw = await asyncio.gather( + self.p1.prefix(context), self.p2.prefix(context) + ) + return self.op(p1_logw, p2_logw) + + async def complete(self, context: List[str]) -> float: + """Compute completion log weights using both potentials. + + Args: + context (List[str]): The context tokens + + Returns: + float: Combined completion log weight from both potentials + """ + p1_logw, p2_logw = await asyncio.gather( + self.p1.complete(context), self.p2.complete(context) + ) + return self.op(p1_logw, p2_logw) + + async def logws_next(self, context: List[str]) -> Tuple[Any, Any]: + """Get log weights from both potentials separately. + + This method returns the log weights from both underlying potentials + without combining them. Useful for custom combination logic. 
+ + Args: + context (List[str]): The context tokens + + Returns: + Tuple[Any, Any]: Tuple of (p1_logw_next, p2_logw_next) + """ + return await asyncio.gather( + self.p1.logw_next(context), self.p2.logw_next(context) + ) + + async def logw_next(self, context: List[str]): + """Not implemented for Ensemble class. + + Raises: + NotImplementedError: Always raised. Use logws_next or batch_logw_next instead. + """ + raise NotImplementedError("logw_next is not implemented for Ensemble class.") + + async def batch_logw_next(self, contexts: List[List[str]]) -> List[Any]: + """Batched version of logw_next for Ensemble. + + This enables batching when multiple particles need to be extended during SMC, + which can significantly improve performance when using PromptedLLM with + batch_logw_next support. + + Args: + contexts (List[List[str]]): List of context token sequences + + Returns: + List[Any]: List of LazyWeights objects, one per context, containing the combined + log weights from both potentials + + Note: + This method is only used if the Ensemble is wrapped in AutoBatchedPotential or + called directly with multiple contexts. EnsembleTokenSampler calls p1.logw_next() + and p2.logw_next() directly, so for batching in EnsembleTokenSampler, wrap p1 + and p2 in AutoBatchedPotential before creating the Ensemble. + """ + # Get batched log weights from both potentials + Ws1, Ws2 = await asyncio.gather( + self.p1.batch_logw_next(contexts), self.p2.batch_logw_next(contexts) + ) + # Combine using the ensemble operation + return [ + self.make_lazy_weights( + self.op( + Ws1[n].weights[self.p1_vocab_idxs], + Ws2[n].weights[self.p2_vocab_idxs], + ) + ) + for n in range(len(contexts)) + ] + + +def split_with_atomic_tokens( + data: bytes, atomic_tokens: list[bytes] +) -> list[Union[int, bytes]]: + """ + Splits a bytestring into a list of either individual bytes (as integers) or atomic tokens (as bytes), + depending on whether the current position matches an atomic token. 
+ + Args: + data (bytes): The input byte string to split. + atomic_tokens (list[bytes]): A list of byte substrings that are treated as indivisible atomic tokens. + + Returns: + list[Union[int, bytes]]: A list where each element is either: + - an atomic token (as bytes) if a match is found at that position, + - or a single byte (as an int) if no atomic token matches. + + Notes: + - Matching is greedy but only left-to-right: at each position, the function checks for atomic token matches + starting from length 1 up to the maximum token length. + - Only the first match (shortest prefix match) is used; longer overlapping tokens may be missed if a shorter + prefix matches first. + - If atomic tokens overlap (e.g., b"A" and b"AB"), a warning is raised and only the shortest prefix match + will be used. + + Example: + >>> split_with_atomic_tokens(b"ABC", [b"A", b"AB"]) + [b'A', 66, 67] # b"AB" is not matched because b"A" matched first + """ + # Detect overlapping atomic tokens + for i, token1 in enumerate(atomic_tokens): + for j, token2 in enumerate(atomic_tokens): + if i != j and (token1.startswith(token2) or token2.startswith(token1)): + warnings.warn( + f"Overlapping atomic tokens detected: {token1!r} and {token2!r}. " + "Only the shortest matching prefix will be used." + ) + break + + result = [] + i = 0 + token_set = set(atomic_tokens) + max_len = max(len(t) for t in atomic_tokens) if atomic_tokens else 0 + + while i < len(data): + matched = False + for length in range(1, max_len + 1): + fragment = data[i : i + length] + if fragment in token_set: + result.append(fragment) + i += length + matched = True + break + if not matched: + result.append(data[i]) + i += 1 + + return result + + +class ByteEnsemble(Potential): + """ + An ensemble potential combining two language models at the byte level using beam search. + + ByteEnsemble manages synchronized beam states for two language models, enabling efficient + byte-level ensemble sampling. 
Unlike the standard Ensemble class that works with any + Potential, ByteEnsemble provides direct access to beam states for specialized sampling + strategies like ByteEnsembleTokenSampler. + + Attributes: + p1, p2: The base LM objects (not Potentials, but raw model objects). + op: A function to combine log-probabilities. + data_dict_1, data_dict_2: Beam state caches keyed by context (bytes). + vocabulary: Byte-level vocabulary (list of integers 0-255). + eos_tokens: List of EOS tokens from both models. + + Note: + ByteEnsemble is designed to work with ByteEnsembleTokenSampler for specialized + byte-level ensemble sampling. The prefix() and complete() methods are not fully + implemented as this class is meant to be used with custom sampling strategies + that directly access beam states via get_beam_states(). + + Example: + ```python + from genlm.backend import load_model_by_name + from genlm.bytes import BeamParams + from genlm.control.potential.built_in import ByteEnsemble + + llm1 = load_model_by_name("gpt2") + llm2 = load_model_by_name("gpt2") + + beam_params = BeamParams(K=5, prune_threshold=0.0) + ensemble = await ByteEnsemble.create( + llm1, llm2, + op="prod", + prompt1=b"Hello ", + prompt2=b"Hello ", + a=0.5 + ) + + # Use with ByteEnsembleTokenSampler for sampling + ``` + """ + + def __init__( + self, + p1: Any, + p2: Any, + op: Callable, + data_dict_1: dict, + data_dict_2: dict, + vocab: List[int], + eos_tokens: List[bytes], + ): + self.p1 = p1 + self.p2 = p2 + self.op = op + self.data_dict_1 = data_dict_1 + self.data_dict_2 = data_dict_2 + self.eos_tokens = eos_tokens + super().__init__(vocabulary=vocab) + + @classmethod + async def create( + cls, + llm1: Any, + llm2: Any, + op: str, + prompt1: bytes, + prompt2: bytes, + a: float = 0.5, + K: int = 5, + prune_threshold: float = 0.0, + verbose: bool = False, + ) -> "ByteEnsemble": + """Factory method to initialize beam states from prompts and return a ByteEnsemble instance. 
+ + Args: + llm1 (Any): First language model (from genlm.backend) + llm2 (Any): Second language model (from genlm.backend) + op (str): Operation name ('sum', 'prod', 'min', 'max', 'harmonic', or power means) + prompt1 (bytes): Prompt bytes for first model + prompt2 (bytes): Prompt bytes for second model + a (float): Weighting parameter between 0 and 1 (default 0.5 for equal weighting) + K (int): Beam width for beam search (default 5) + prune_threshold (float): Threshold for pruning low-probability beams (default 0.0) + verbose (bool): Whether to print verbose beam search output (default False) + + Returns: + ByteEnsemble: Initialized ensemble with beam states ready for sampling + + Raises: + RuntimeError: If beam states become empty after prefill + """ + + beam_params = BeamParams(K=K, prune_threshold=prune_threshold, verbose=verbose) + data_dict_1 = defaultdict() + data_dict_2 = defaultdict() + + async def setup(): + # Initialize beams sequentially to avoid overwhelming vLLM with concurrent requests + beam1 = await ByteBeamState.initial(llm1, beam_params) + beam2 = await ByteBeamState.initial(llm2, beam_params) + # Prefill sequentially as well to reduce concurrent load + beam_state_1 = await beam1.prefill(prompt1) + beam_state_2 = await beam2.prefill(prompt2) + return beam_state_1, beam_state_2 + + beam_state_1, beam_state_2 = await setup() + + # Check if beams are empty after initialization + if len(beam_state_1) == 0: + raise RuntimeError( + f"Beam1 is empty after prefill with prompt of length {len(prompt1)} bytes" + ) + if len(beam_state_2) == 0: + raise RuntimeError( + f"Beam2 is empty after prefill with prompt of length {len(prompt2)} bytes" + ) + + data_dict_1[b""] = beam_state_1 + data_dict_2[b""] = beam_state_2 + + eos_tokens = [ + llm1.byte_vocab[llm1.tokenizer.eos_token_id], + llm2.byte_vocab[llm2.tokenizer.eos_token_id], + ] + + return cls( + llm1, + llm2, + convert_to_weighted_logop(op, a), + data_dict_1, + data_dict_2, + vocab=list(range(256)), + 
eos_tokens=eos_tokens, + ) + + async def _cleanup_cache(self): + """Remove old entries to avoid cache bloat.""" + max_len = max( + ( + len(split_with_atomic_tokens(k, self.eos_tokens)) + for k in self.data_dict_1 + ), + default=0, + ) + min_len = max_len - 2 + for d in [self.data_dict_1, self.data_dict_2]: + for k in list(d.keys()): + if len(k) < min_len: + del d[k] + + async def get_beam_states( + self, context: List[int] + ) -> Tuple["ByteBeamState", "ByteBeamState"]: + """Fetch beam states for the current context. + + This method provides direct access to the underlying beam states, which + is used by ByteEnsembleTokenSampler for synchronized beam advancement. + + Args: + context (List[int]): Context as list of byte values + + Returns: + Tuple[ByteBeamState, ByteBeamState]: Beam states from both models + + Raises: + KeyError: If context not found in cache (beam states must be populated + by ByteEnsembleTokenSampler during sampling) + """ + # Convert context to bytes + if context and isinstance(context[0], bytes): + ctx_bytes = b"".join(context) + else: + ctx_bytes = bytes(context) + + await self._cleanup_cache() + beam1 = self.data_dict_1[ctx_bytes] + beam2 = self.data_dict_2[ctx_bytes] + return beam1, beam2 + + async def prefix(self, context: List[int]) -> None: + """Compute prefix weight (not fully implemented). + + ByteEnsemble is designed to be used with ByteEnsembleTokenSampler which + manages weights separately. This method is a stub to satisfy the Potential interface. + + Args: + context (List[int]): The context as list of byte values + + Returns: + None + """ + return None # pragma: no cover + + async def complete(self, context: List[int]) -> None: + """Compute completion weight (not fully implemented). + + ByteEnsemble is designed to be used with ByteEnsembleTokenSampler which + manages weights separately. This method is a stub to satisfy the Potential interface. 
+ + Args: + context (List[int]): The context as list of byte values + + Returns: + None + """ + return None # pragma: no cover + + +def _power_mean(p: float, a: float): + """Create a weighted power mean operator in log space. + + M_p(x, y; a) = (a * exp(p*x) + (1-a) * exp(p*y))^(1/p) + In log space: (1/p) * logsumexp([log(a) + p*x, log(1-a) + p*y]) + + Args: + p (float): Power parameter for the power mean + a (float): Weighting parameter between 0 and 1 + + Returns: + Callable: Function that computes weighted power mean in log space + """ + log_a, log_1_minus_a = np.log(a), np.log(1 - a) + return lambda x, y: (1.0 / p) * logsumexp( + [log_a + p * x, log_1_minus_a + p * y], axis=0 + ) + + +def _weighted_extremum(func, a: float): + """Create a weighted min/max operator. + + Args: + func (Callable): The extremum function (np.minimum or np.maximum) + a (float): Weighting parameter between 0 and 1 + + Returns: + Callable: Function that computes weighted extremum + """ + + def extremum(x, y, a): + if a <= 0.5: + return (1 - 2 * a) * x + 2 * a * func(x, y) + else: + return (2 * a - 1) * y + 2 * (1 - a) * func(x, y) + + return lambda x, y: extremum(x, y, a) + + +_POWER_MEANS = { + "pm5": -5.0, + "pm2.5": -2.5, + "p-2": -2.0, + "pm1.5": -1.5, + "pm0.5": -0.5, + "pm0.25": -0.25, + "p0.25": 0.25, + "p0.5": 0.5, + "p1.5": 1.5, + "p2": 2.0, + "p2.5": 2.5, + "p3": 3.0, + "p5": 5.0, +} + + +def convert_to_weighted_logop( + op: Literal[ + "sum", + "prod", + "min", + "max", + "harmonic", + "pm5", + "pm2.5", + "p-2", + "pm1.5", + "pm0.5", + "pm0.25", + "p0.25", + "p0.5", + "p1.5", + "p2", + "p2.5", + "p3", + "p5", + ], + a: float = 0.5, +) -> Callable[[np.ndarray, np.ndarray], np.ndarray]: + """Convert a string operation to its weighted log-space equivalent. + + This function takes an operation name and a weighting parameter and returns + a function that combines two log-probability arrays using the specified + weighted operation. + + Args: + op (str): Operation name. 
Supported operations include: + - Means: "sum" (arithmetic), "prod" (geometric), "harmonic" + - Extrema: "min", "max" + - Power means: "pm5", "pm2.5", "p-2", "pm1.5", "pm0.5", "pm0.25", + "p0.25", "p0.5", "p1.5", "p2", "p2.5", "p3", "p5" + a (float): Weighting parameter between 0 and 1. When a=0.5, equal weighting. + For weighted operations: a * model1 + (1-a) * model2 + + Returns: + Callable[[np.ndarray, np.ndarray], np.ndarray]: A function that takes two + log-probability arrays and returns their weighted combination in log space. + + Raises: + ValueError: If a is not between 0 and 1, or if op is not recognized. + + Examples: + >>> op_func = convert_to_weighted_logop("sum", a=0.5) + >>> x = np.log(np.array([0.3, 0.7])) + >>> y = np.log(np.array([0.6, 0.4])) + >>> result = op_func(x, y) # Weighted arithmetic mean in log space + """ + if not 0 < a < 1: + raise ValueError("variable a should be between 0 and 1") + + log_a, log_1_minus_a = np.log(a), np.log(1 - a) + + # Ensemble operations + if op in _POWER_MEANS: + return _power_mean(_POWER_MEANS[op], a) + + operations = { + "sum": lambda x, y: logsumexp([x + log_a, y + log_1_minus_a], axis=0), + "prod": lambda x, y: a * x + (1 - a) * y, + "harmonic": lambda x, y: -logsumexp([-x + log_a, -y + log_1_minus_a], axis=0), + "min": _weighted_extremum(np.minimum, a), + "max": _weighted_extremum(np.maximum, a), + } + + if op in operations: + return operations[op] + + valid_ops = list(operations.keys()) + list(_POWER_MEANS.keys()) + raise ValueError( + f"Invalid operation: {op}. Must be one of {', '.join(repr(o) for o in valid_ops)}." 
+ ) diff --git a/genlm/control/sampler/__init__.py b/genlm/control/sampler/__init__.py index ed5bbea..31bddcb 100644 --- a/genlm/control/sampler/__init__.py +++ b/genlm/control/sampler/__init__.py @@ -1,6 +1,6 @@ from .token import DirectTokenSampler, SetTokenSampler, AWRS, TokenSampler from .set import EagerSetSampler, TopKSetSampler -from .sequence import SMC, SequenceModel +from .sequence import SMC, EnsembleSMC, Sequences, SequencesExt, SequenceModel from .unit import ( MultiTokenUnitSampler, BoundaryPredicate, @@ -9,6 +9,7 @@ CFGBoundary, flatten_units, ) +from .byte_ensemble import ByteEnsembleTokenSampler from genlm.control.potential import Potential @@ -73,6 +74,9 @@ def topk_token_sampler(iter_potential, item_potential, K): "TokenSampler", "Importance", "SMC", + "EnsembleSMC", + "Sequences", + "SequencesExt", "SequenceModel", "MultiTokenUnitSampler", "BoundaryPredicate", @@ -80,4 +84,5 @@ def topk_token_sampler(iter_potential, item_potential, K): "FixedLengthBoundary", "CFGBoundary", "flatten_units", + "ByteEnsembleTokenSampler", ] diff --git a/genlm/control/sampler/byte_ensemble.py b/genlm/control/sampler/byte_ensemble.py new file mode 100644 index 0000000..6142582 --- /dev/null +++ b/genlm/control/sampler/byte_ensemble.py @@ -0,0 +1,214 @@ +from typing import Any, List, Literal, Tuple +from collections import defaultdict + +from cachetools import LRUCache +from arsenal.maths import logsumexp + +from genlm.control.sampler.token import TokenSampler +from genlm.control.util import fast_sample_logprobs +from genlm.control.constant import EOS +from genlm.control.sampler.sequence import EnsembleSMC + + +class ByteEnsembleTokenSampler(TokenSampler): + """ + Token sampler for byte-level ensemble using synchronized beam search. + + This sampler draws from an ensemble of two language models by advancing both + beam states synchronously with the same sampled token. This enables efficient + exploration with proper importance weighting for SMC. 
+ + Unlike standard token samplers, ByteEnsembleTokenSampler: + - Directly accesses and manipulates beam states from ByteEnsemble + - Advances both beams with the same token (synchronized exploration) + - Tracks separate log probabilities for each model + - Uses shaping weights for proper SMC proposals + + Args: + potential (ByteEnsemble): The target byte-level ensemble potential. + proposal (Literal["linear", "abs", "square", "soft n"]): Proposal strategy. + Currently only "linear" is implemented. + n_particles (int): Number of particles for SMC sampling. Defaults to 10. + eos_tokens (List[int]): List of end-of-sequence tokens (as byte values). + max_tokens (int, optional): Maximum number of tokens to generate. + models_equal (bool): Flag indicating whether the two models are identical. + Defaults to False. + + Example: + ```python + from genlm.backend import load_model_by_name + from genlm.bytes import BeamParams + from genlm.control.potential.built_in import ByteEnsemble + from genlm.control.sampler.byte_ensemble import ByteEnsembleTokenSampler + + # Load models + llm1 = load_model_by_name("gpt2") + llm2 = load_model_by_name("gpt2") + + # Create ensemble + ensemble = await ByteEnsemble.create( + llm1, llm2, + op="prod", + prompt1=b"Hello ", + prompt2=b"Hello ", + a=0.5 + ) + + # Create sampler + eos_tokens = [llm1.byte_vocab[llm1.tokenizer.eos_token_id]] + sampler = ByteEnsembleTokenSampler( + ensemble, + max_tokens=100, + eos_tokens=eos_tokens, + n_particles=10 + ) + + # Run SMC sampling + result = await sampler.smc( + n_particles=10, + ess_threshold=0.5, + max_tokens=100 + ) + ``` + """ + + def __init__( + self, + potential, + proposal: Literal["linear", "abs", "square", "soft n"] = "linear", + n_particles: int = 10, + eos_tokens: List[int] = None, + max_tokens: int = None, + models_equal: bool = False, + ): + super().__init__(target=potential) + self.potential = potential + self.proposal = proposal + self.n_particles = n_particles + self.eos_tokens = 
eos_tokens or [] + self.max_tokens = max_tokens + self.models_equal = models_equal + + # LRU caches for prefix weights + self.prefix_cache_1 = LRUCache(maxsize=3 * n_particles) + self.prefix_cache_2 = LRUCache(maxsize=3 * n_particles) + + # Track final particle probabilities + self.particle_prefix_log_prob_1 = defaultdict(lambda: float("-inf")) + self.particle_prefix_log_prob_2 = defaultdict(lambda: float("-inf")) + + # Init empty context weights + self.prefix_cache_1[()] = 0.0 + self.prefix_cache_2[()] = 0.0 + + async def start_weight(self) -> float: + """Compute the weight of the empty sequence. + + Returns: + float: Log weight of the empty sequence (always 0.0) + """ + return 0.0 + + async def sample(self, context: List[int], draw=None) -> Tuple[int, float, float]: + """Sample one token from the ensemble distribution. + + This method: + 1. Fetches beam states for both models at the current context + 2. Gets next-token distributions from both beams + 3. Combines distributions using the ensemble operation + 4. Samples a token using the combined distribution + 5. Advances both beams synchronously with the sampled token + 6. 
Updates caches with new beam states and weights + + Args: + context (List[int]): Current context as list of byte values + draw (callable, optional): Drawing function (not used, for compatibility) + + Returns: + Tuple[int, float, float]: (token, log_weight, log_prob) + - token: Sampled byte value (or EOS) + - log_weight: Log importance weight for SMC + - log_prob: Log probability under proposal distribution + """ + # Get beam states + beam1, beam2 = await self.potential.get_beam_states(context) + logp_1, logp_2 = await beam1.logp_next(), await beam2.logp_next() + + # Get cached prefix weights + ctx_tuple = tuple(context) + log_context_weight_1 = self.prefix_cache_1[ctx_tuple] + log_context_weight_2 = self.prefix_cache_2[ctx_tuple] + + # Compute next-token weights + logws1 = log_context_weight_1 + logp_1.ps + logws2 = log_context_weight_2 + logp_2.ps + + # Compute shaping weight from previous context + log_shaping_weight_prev = ( + 0 + if not context + else self.potential.op(log_context_weight_1, log_context_weight_2) + ) + + # Combine weights using ensemble operation and compute proposal + proposal_weights = self.potential.op(logws1, logws2) - log_shaping_weight_prev + logps = proposal_weights - logsumexp(proposal_weights) + + # Sample token from proposal distribution + token_idx = fast_sample_logprobs(logps)[0] + + # Decode token from trie + token = beam1.states[0].trie.trie.decode[token_idx] + assert ( + token == beam2.states[0].trie.trie.decode[token_idx] + ), "Models must have aligned vocabularies" + + # Advance both beams synchronously with sampled token + next_context = ( + bytes(context + [token]) + if isinstance(token, int) + else bytes(context) + token + ) + self.potential.data_dict_1[next_context] = await (beam1.prune() << token) + self.potential.data_dict_2[next_context] = await (beam2.prune() << token) + + # Update prefix caches + new_ctx_tuple = ctx_tuple + (token,) + self.prefix_cache_1[new_ctx_tuple] = logws1[token_idx] + 
self.prefix_cache_2[new_ctx_tuple] = logws2[token_idx] + + # Handle EOS tokens + if token in self.eos_tokens: + token = EOS + + # Store final particle weights if sequence is complete + if token == EOS or (self.max_tokens and len(ctx_tuple) + 1 == self.max_tokens): + self.particle_prefix_log_prob_1[ctx_tuple + (token,)] = logws1[token_idx] + self.particle_prefix_log_prob_2[ctx_tuple + (token,)] = logws2[token_idx] + + # Return token, importance weight, and proposal log prob + return token, proposal_weights[token_idx] - logps[token_idx], logps[token_idx] + + async def smc( + self, + n_particles: int, + ess_threshold: float, + max_tokens: int, + critic=None, + **kwargs: Any, + ): + """Run Sequential Monte Carlo inference with byte-level ensemble. + + Args: + n_particles (int): Number of particles to maintain + ess_threshold (float): ESS threshold for resampling (0-1) + max_tokens (int): Maximum tokens per sequence + critic (Potential): Critic potential for guided sampling + **kwargs: Additional arguments passed to SMC + """ + return await EnsembleSMC(self, critic)( + n_particles=n_particles, + ess_threshold=ess_threshold, + max_tokens=max_tokens, + **kwargs, + ) diff --git a/genlm/control/sampler/sequence.py b/genlm/control/sampler/sequence.py index fff7dc6..4fa6af4 100644 --- a/genlm/control/sampler/sequence.py +++ b/genlm/control/sampler/sequence.py @@ -1,4 +1,5 @@ import numpy as np +from typing import Any from genlm.grammar import Float from arsenal.maths import logsumexp from functools import cached_property @@ -346,3 +347,103 @@ def _unpack_particles(particles): ), ) return contexts, logws + + +class EnsembleSMC(SMC): + """EnsembleSMC is a specialized version of SMC for ensemble sampling. + + This class extends SMC to track individual model weights when using + ByteEnsemble with ByteEnsembleTokenSampler. It returns SequencesExt + which includes log_prefix_weights_1 and log_prefix_weights_2. 
+ + Args: + unit_sampler (TokenSampler): The sampler that generates tokens. + critic (Potential, optional): A potential function that guides the generation process. + """ + + async def __call__( + self, + n_particles: int, + ess_threshold: float, + max_tokens: int, + verbosity: int = 0, + json_path: str = None, + **kwargs: Any, + ) -> "SequencesExt": + """Generate sequences using SMC with ensemble weight tracking. + + Args: + n_particles (int): Number of particles to maintain. + ess_threshold (float): ESS threshold for resampling (0-1). + max_tokens (int): Maximum tokens to generate. + verbosity (int): Verbosity level (0=silent, 1=verbose). + json_path (str): Path to save inference visualization data. + **kwargs: Additional arguments for smc_standard. + + Returns: + SequencesExt: Sequences with individual model weights. + """ + sequences = await super().__call__( + n_particles=n_particles, + ess_threshold=ess_threshold, + max_tokens=max_tokens, + verbosity=verbosity, + json_path=json_path, + **kwargs, + ) + + log_prefix_weights_1 = [] + log_prefix_weights_2 = [] + + if hasattr(self.unit_sampler, "particle_prefix_log_prob_1"): + for ctx in sequences.contexts: + ctx_tuple = tuple(ctx) + log_prefix_weights_1.append( + self.unit_sampler.particle_prefix_log_prob_1.get( + ctx_tuple, float("-inf") + ) + ) + + if hasattr(self.unit_sampler, "particle_prefix_log_prob_2"): + for ctx in sequences.contexts: + ctx_tuple = tuple(ctx) + log_prefix_weights_2.append( + self.unit_sampler.particle_prefix_log_prob_2.get( + ctx_tuple, float("-inf") + ) + ) + + return SequencesExt( + sequences.contexts, + sequences.log_weights, + log_prefix_weights_1, + log_prefix_weights_2, + ) + + +@dataclass +class SequencesExt(Sequences): + """Extended Sequences container for ensemble sampling. + + This class extends Sequences to track individual model weights + when using ByteEnsemble with ByteEnsembleTokenSampler. + + Args: + contexts (list): List of token sequences. 
+ log_weights (list): Log importance weights for each sequence. + log_prefix_weights_1 (list): Log weights from first model. + log_prefix_weights_2 (list): Log weights from second model. + """ + + log_prefix_weights_1: list = None + log_prefix_weights_2: list = None + + def __post_init__(self): + super().__post_init__() + if self.log_prefix_weights_1 is not None: + if not isinstance(self.log_prefix_weights_1, np.ndarray): + self.log_prefix_weights_1 = np.array(self.log_prefix_weights_1) + + if self.log_prefix_weights_2 is not None: + if not isinstance(self.log_prefix_weights_2, np.ndarray): + self.log_prefix_weights_2 = np.array(self.log_prefix_weights_2) diff --git a/pyproject.toml b/pyproject.toml index 76fe16d..9a8f914 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "torch", "json-stream", "jsonschema[format-nongpl]", + "cachetools", ] [project.optional-dependencies] diff --git a/tests/potential/test_ensemble.py b/tests/potential/test_ensemble.py new file mode 100644 index 0000000..82843d0 --- /dev/null +++ b/tests/potential/test_ensemble.py @@ -0,0 +1,955 @@ +import pytest +import numpy as np +from unittest.mock import AsyncMock, MagicMock, patch +from genlm.backend import load_model_by_name +from genlm.control import ( + Ensemble, + ByteEnsemble, + ByteEnsembleTokenSampler, + Potential, + PromptedLLM, + convert_to_weighted_logop, + EOS, +) +from genlm.control.sampler.sequence import EnsembleSMC, SequencesExt, Sequences +from genlm.control.potential.built_in.ensemble import ( + split_with_atomic_tokens, + _weighted_extremum, +) +from conftest import MockPotential + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def mock_vocab(): + """Simple vocabulary for testing.""" + return ["a", "b", "c", "d"] + + +@pytest.fixture +def mock_potential_1(mock_vocab): + """Create a mock potential 
with predefined probabilities.""" + logws = np.log([0.4, 0.3, 0.2, 0.1, 0.001]) + return MockPotential(vocab=mock_vocab, next_token_logws=logws) + + +@pytest.fixture +def mock_potential_2(mock_vocab): + """Create a second mock potential with different probabilities.""" + logws = np.log([0.1, 0.2, 0.3, 0.4, 0.001]) + return MockPotential(vocab=mock_vocab, next_token_logws=logws) + + +# ============================================================================ +# Test Basic Initialization & API +# ============================================================================ + + +@pytest.mark.asyncio +async def test_ensemble_initialization(mock_potential_1, mock_potential_2): + """Test that Ensemble initializes correctly.""" + ensemble = Ensemble(mock_potential_1, mock_potential_2, op="prod", a=0.5) + assert isinstance(ensemble, Potential) + assert ensemble.p1 is mock_potential_1 + assert ensemble.p2 is mock_potential_2 + assert len(ensemble.vocab) == 4 + + +@pytest.mark.asyncio +async def test_ensemble_logws_next(mock_potential_1, mock_potential_2): + """Test that logws_next returns separate log weights from both potentials.""" + ensemble = Ensemble(mock_potential_1, mock_potential_2, op="prod", a=0.5) + p1_logw, p2_logw = await ensemble.logws_next([]) + assert hasattr(p1_logw, "weights") + assert hasattr(p2_logw, "weights") + assert p1_logw.weights.shape == p2_logw.weights.shape + + +@pytest.mark.asyncio +async def test_ensemble_logw_next_raises(mock_potential_1, mock_potential_2): + """Test that logw_next raises NotImplementedError.""" + ensemble = Ensemble(mock_potential_1, mock_potential_2, op="prod", a=0.5) + with pytest.raises(NotImplementedError): + await ensemble.logw_next([]) + + +@pytest.mark.asyncio +async def test_ensemble_batch_logw_next(mock_potential_1, mock_potential_2): + """Test batch_logw_next combines weights from both potentials.""" + ensemble = Ensemble(mock_potential_1, mock_potential_2, op="prod", a=0.5) + results = await 
ensemble.batch_logw_next([[], ["a"]]) + assert len(results) == 2 + for result in results: + assert hasattr(result, "weights") + assert len(result.weights) == len(ensemble.vocab_eos) + + +@pytest.mark.asyncio +async def test_ensemble_prefix_geometric_mean(mock_potential_1, mock_potential_2): + """Test ensemble prefix with product.""" + ensemble = Ensemble(mock_potential_1, mock_potential_2, op="prod", a=0.5) + logw = await ensemble.prefix([]) + # For product with a=0.5: result = 0.5 * log(p1) + 0.5 * log(p2) + p1_logw = await mock_potential_1.prefix([]) + p2_logw = await mock_potential_2.prefix([]) + expected = 0.5 * p1_logw + 0.5 * p2_logw + np.testing.assert_allclose(logw, expected, rtol=1e-5) + + +@pytest.mark.asyncio +async def test_ensemble_prefix_arithmetic_mean(mock_potential_1, mock_potential_2): + """Test ensemble prefix with sum.""" + ensemble = Ensemble(mock_potential_1, mock_potential_2, op="sum", a=0.5) + logw = await ensemble.prefix([]) + # For sum with a=0.5: result = log(0.5 * exp(p1) + 0.5 * exp(p2)) + p1_logw = await mock_potential_1.prefix([]) + p2_logw = await mock_potential_2.prefix([]) + expected = np.logaddexp(np.log(0.5) + p1_logw, np.log(0.5) + p2_logw) + np.testing.assert_allclose(logw, expected, rtol=1e-5) + + +@pytest.mark.asyncio +async def test_ensemble_with_context(): + """Test ensemble operations with non-empty context.""" + mock_vocab = ["a", "b", "c", "d"] + logws1 = np.log([0.4, 0.3, 0.2, 0.1, 0.001]) + logws2 = np.log([0.1, 0.2, 0.3, 0.4, 0.001]) + p1 = MockPotential(vocab=mock_vocab, next_token_logws=logws1) + p2 = MockPotential(vocab=mock_vocab, next_token_logws=logws2) + ensemble = Ensemble(p1, p2, op="prod", a=0.5) + context = ["a", "b"] + logw = await ensemble.prefix(context) + assert isinstance(logw, (int, float, np.number)) + assert np.isfinite(logw) + complete_logw = await ensemble.complete(context) + assert isinstance(complete_logw, (int, float, np.number)) + assert np.isfinite(complete_logw) + + +@pytest.mark.asyncio 
+async def test_ensemble_consistency():
+    """Test that ensemble computations are consistent across multiple calls."""
+    mock_vocab = ["x", "y"]
+    logws = np.log([0.6, 0.4, 0.001])
+    p1 = MockPotential(vocab=mock_vocab, next_token_logws=logws)
+    p2 = MockPotential(vocab=mock_vocab, next_token_logws=logws)
+    ensemble = Ensemble(p1, p2, op="prod", a=0.5)
+    results = [await ensemble.prefix(["x"]) for _ in range(3)]
+    assert all(np.isclose(results[0], r) for r in results)
+
+
+# ============================================================================
+# Test Ensemble Operations
+# ============================================================================
+
+
+@pytest.mark.asyncio
+async def test_token_ensemble_different_operations():
+    """Test different ensemble operations with mock models."""
+    mock_vocab = ["a", "b", "c"]
+    logws1 = np.array([np.log(0.6), np.log(0.3), np.log(0.1), -100.0])
+    logws2 = np.array([np.log(0.2), np.log(0.5), np.log(0.3), -100.0])
+
+    p1 = MockPotential(vocab=mock_vocab, next_token_logws=logws1)
+    p2 = MockPotential(vocab=mock_vocab, next_token_logws=logws2)
+
+    # Prod: 0.5 * log(p1) + 0.5 * log(p2)
+    ensemble_prod = Ensemble(p1, p2, op="prod", a=0.5)
+    result_prod = await ensemble_prod.batch_logw_next([[]])
+    for tok in mock_vocab:
+        idx_ens = ensemble_prod.lookup[tok]
+        idx_p1 = p1.lookup[tok]
+        expected_val = 0.5 * logws1[idx_p1] + 0.5 * logws2[p2.lookup[tok]]
+        assert result_prod[0].weights[idx_ens] == pytest.approx(expected_val, abs=1e-6)
+
+    # Sum: log(0.5 * exp(log(p1)) + 0.5 * exp(log(p2)))
+    ensemble_sum = Ensemble(p1, p2, op="sum", a=0.5)
+    result_sum = await ensemble_sum.batch_logw_next([[]])
+
+    for tok in mock_vocab:
+        idx_ens = ensemble_sum.lookup[tok]
+        idx_p1 = p1.lookup[tok]
+        expected_val = np.logaddexp(
+            np.log(0.5) + logws1[idx_p1], np.log(0.5) + logws2[p2.lookup[tok]]
+        )
+        assert result_sum[0].weights[idx_ens] == pytest.approx(expected_val, abs=1e-6)
+
+    # Min: For a=0.5, should be close to actual minimum
+    ensemble_min = 
Ensemble(p1, p2, op="min", a=0.5) + result_min = await ensemble_min.batch_logw_next([[]]) + + for tok in mock_vocab: + idx_ens = ensemble_min.lookup[tok] + idx_p1 = p1.lookup[tok] + expected_val = np.minimum(logws1[idx_p1], logws2[idx_p1]) + assert result_min[0].weights[idx_ens] == pytest.approx(expected_val, abs=0.5) + + # Max: For a=0.5, should be close to actual maximum + ensemble_max = Ensemble(p1, p2, op="max", a=0.5) + result_max = await ensemble_max.batch_logw_next([[]]) + + for tok in mock_vocab: + idx_ens = ensemble_max.lookup[tok] + idx_p1 = p1.lookup[tok] + expected_val = np.maximum(logws1[idx_p1], logws2[idx_p1]) + assert result_max[0].weights[idx_ens] == pytest.approx(expected_val, abs=0.5) + + +@pytest.mark.asyncio +async def test_ensemble_all_power_means(): + """Test all supported power mean operations.""" + mock_vocab = ["a", "b"] + logws1 = np.log([0.7, 0.3, 0.001]) # Model 1 prefers 'a' + logws2 = np.log([0.3, 0.7, 0.001]) # Model 2 prefers 'b' + + p1 = MockPotential(vocab=mock_vocab, next_token_logws=logws1) + p2 = MockPotential(vocab=mock_vocab, next_token_logws=logws2) + + power_means = [ + "pm5", + "pm2.5", + "p-2", + "pm1.5", + "pm0.5", + "pm0.25", + "p0.25", + "p0.5", + "p1.5", + "p2", + "p2.5", + "p3", + "p5", + ] + for op in power_means: + ensemble = Ensemble(p1, p2, op=op, a=0.5) + logw = await ensemble.prefix([]) + assert isinstance(logw, (int, float, np.number)) + assert np.isfinite(logw) + + +# ============================================================================ +# Test Weighting & Parameters +# ============================================================================ + + +@pytest.mark.asyncio +async def test_ensemble_weighting_affects_output(): + """Verify that changing the weighting parameter affects the output.""" + mock_vocab = ["a", "b"] + logws1 = np.array([0.0, -5.0, -100.0]) # Model 1 prefers 'a' + logws2 = np.array([-5.0, 0.0, -100.0]) # Model 2 prefers 'b' + + p1 = MockPotential(vocab=mock_vocab, 
next_token_logws=logws1)
+    p2 = MockPotential(vocab=mock_vocab, next_token_logws=logws2)
+
+    ensemble_50 = Ensemble(p1, p2, op="prod", a=0.5)
+    result_50 = await ensemble_50.batch_logw_next([[]])
+    ensemble_80 = Ensemble(p1, p2, op="prod", a=0.8)  # Weight (0.8) on model 1
+    result_80 = await ensemble_80.batch_logw_next([[]])
+    ensemble_20 = Ensemble(p1, p2, op="prod", a=0.2)  # Weight (0.2) on model 2
+    result_20 = await ensemble_20.batch_logw_next([[]])
+
+    logws_50 = result_50[0].weights
+    logws_80 = result_80[0].weights
+    logws_20 = result_20[0].weights
+    assert not np.allclose(logws_50, logws_80, rtol=1e-5)
+    assert not np.allclose(logws_50, logws_20, rtol=1e-5)
+    assert not np.allclose(logws_80, logws_20, rtol=1e-5)
+
+    a_idx_50 = ensemble_50.lookup["a"]
+    b_idx_50 = ensemble_50.lookup["b"]
+    a_idx_80 = ensemble_80.lookup["a"]
+    b_idx_20 = ensemble_20.lookup["b"]
+    assert logws_80[a_idx_80] > logws_50[a_idx_50]
+    assert logws_20[b_idx_20] > logws_50[b_idx_50]
+    diff_50 = abs(logws_50[a_idx_50] - logws_50[b_idx_50])
+    diff_80 = logws_80[a_idx_80] - logws_80[1 - a_idx_80]
+    diff_20 = logws_20[b_idx_20] - logws_20[1 - b_idx_20]
+    assert abs(diff_80) > diff_50 or abs(diff_20) > diff_50
+
+
+@pytest.mark.asyncio
+async def test_ensemble_with_differently_conditioned_models():
+    """Test ensemble with different weighting simulating different prompt strategies."""
+    vocab = ["SELECT", "FROM", "WHERE", "JOIN"]
+    logws1 = np.array([0.0, -1.0, -2.0, -3.0, -100.0])
+    logws2 = np.array([-3.0, -2.0, -1.0, 0.0, -100.0])
+    p1 = MockPotential(vocab=vocab, next_token_logws=logws1)
+    p2 = MockPotential(vocab=vocab, next_token_logws=logws2)
+    ensemble_balanced = Ensemble(
+        p1, p2, op="prod", a=0.5
+    )  # Ensemble with different weights
+    ensemble_favor_p1 = Ensemble(p1, p2, op="prod", a=0.7)
+
+    result_balanced = await ensemble_balanced.batch_logw_next([[]])
+    result_favor_p1 = await 
ensemble_favor_p1.batch_logw_next([[]]) + combined_balanced = result_balanced[0].weights + combined_favor_p1 = result_favor_p1[0].weights + + select_idx_bal = ensemble_balanced.lookup[ + "SELECT" + ] # When favoring p1, SELECT is more likely + select_idx_fav = ensemble_favor_p1.lookup["SELECT"] + join_idx_bal = ensemble_balanced.lookup[ + "JOIN" + ] # When favoring p1, JOIN is less likely + join_idx_fav = ensemble_favor_p1.lookup["JOIN"] + + assert combined_favor_p1[select_idx_fav] > combined_balanced[select_idx_bal] + assert combined_favor_p1[join_idx_fav] < combined_balanced[join_idx_bal] + + +# ============================================================================ +# Test Vocabulary Handling +# ============================================================================ + + +@pytest.mark.asyncio +async def test_ensemble_warns_on_different_vocabularies(): + """Test Ensemble warns when using potentials with different vocabularies.""" + vocab1 = ["a", "b", "c", "d"] + vocab2 = ["a", "b", "x", "y"] + logws1 = np.log([0.25, 0.25, 0.25, 0.25, 0.001]) + logws2 = np.log([0.25, 0.25, 0.25, 0.25, 0.001]) + p1 = MockPotential(vocab=vocab1, next_token_logws=logws1) + p2 = MockPotential(vocab=vocab2, next_token_logws=logws2) + with pytest.warns(UserWarning, match="different vocabularies"): + with pytest.raises((KeyError, AssertionError)): + _ = Ensemble(p1, p2, op="prod", a=0.5) + + +@pytest.mark.asyncio +async def test_ensemble_vocab_alignment(mock_vocab): + """Test that ensemble handles vocabulary alignment correctly.""" + logws = np.log([0.25, 0.25, 0.25, 0.25, 0.001]) + p1 = MockPotential(vocab=mock_vocab, next_token_logws=logws) + p2 = MockPotential(vocab=mock_vocab, next_token_logws=logws) + + ensemble = Ensemble(p1, p2, op="prod", a=0.5) + + # Vocabulary indices should be correctly aligned + assert len(ensemble.p1_vocab_idxs) == len(ensemble.vocab_eos) + assert len(ensemble.p2_vocab_idxs) == len(ensemble.vocab_eos) + assert ensemble.p1_vocab_idxs == 
ensemble.p2_vocab_idxs + + +@pytest.mark.asyncio +async def test_ensemble_respects_vocab_alignment(): + """Verify ensemble correctly handles vocabulary alignment with reordering.""" + vocab = ["x", "y", "z"] + logws1 = np.array([0.0, -1.0, -2.0, -100.0]) + logws2 = np.array([-1.0, 0.0, -2.0, -100.0]) + + p1 = MockPotential(vocab=vocab, next_token_logws=logws1) + p2 = MockPotential(vocab=vocab, next_token_logws=logws2) + ensemble = Ensemble(p1, p2, op="prod", a=0.5) + result = await ensemble.batch_logw_next([[]]) + combined = result[0].weights + + # Each token should get correct combined weight + for tok in ["x", "y", "z"]: + ensemble_idx = ensemble.lookup[tok] + p1_idx = p1.lookup[tok] + p2_idx = p2.lookup[tok] + expected = 0.5 * logws1[p1_idx] + 0.5 * logws2[p2_idx] + actual = combined[ensemble_idx] + assert actual == pytest.approx(expected, abs=1e-5), f"Token {tok} mismatch" + + +# ============================================================================ +# Test Integration with Real Models (GPT-2) +# ============================================================================ + + +@pytest.mark.asyncio +async def test_token_ensemble_with_different_prompts(): + """Test token-level Ensemble with different prompts - basic functionality check.""" + llm1 = PromptedLLM.from_name("gpt2") + llm2 = PromptedLLM.from_name("gpt2") + llm1.set_prompt_from_str("Write a SQL query: ") + llm2.set_prompt_from_str("SQL code: ") + + ensemble = Ensemble(llm1, llm2, op="prod", a=0.5) + assert ensemble.p1 is llm1 + assert ensemble.p2 is llm2 + + ensemble_result = await ensemble.batch_logw_next([[]]) + ensemble_logws = ensemble_result[0].weights + + assert len(ensemble_logws) > 0 + assert np.all(np.isfinite(ensemble_logws)) + assert len(ensemble_logws) == len( + ensemble.vocab_eos + ) # Should have same length as vocab + + +@pytest.mark.asyncio +async def test_token_ensemble_complementary_prompts(): + """Test token-level Ensemble combining complementary prompting strategies.""" + llm1 
= PromptedLLM.from_name("gpt2") + llm2 = PromptedLLM.from_name("gpt2") + + llm1.set_prompt_from_str("Task: Generate structured SQL.\n") + llm2.set_prompt_from_str("Task: Generate correct SQL.\n") + + ensemble = Ensemble(llm1, llm2, op="prod", a=0.5) + p1_result = await llm1.batch_logw_next([[]]) + p2_result = await llm2.batch_logw_next([[]]) + ensemble_result = await ensemble.batch_logw_next([[]]) + + p1_logws = p1_result[0].weights + p2_logws = p2_result[0].weights + ensemble_logws = ensemble_result[0].weights + + assert not np.allclose(ensemble_logws, p1_logws, rtol=0.1) + assert not np.allclose(ensemble_logws, p2_logws, rtol=0.1) + assert np.all(np.isfinite(ensemble_logws)) + + +# ============================================================================ +# Test ByteEnsemble +# ============================================================================ + + +@pytest.mark.asyncio +async def test_byte_ensemble_creation(): + """Test ByteEnsemble creation with identical prompts.""" + llm1 = load_model_by_name("gpt2", backend="hf") + llm2 = load_model_by_name("gpt2", backend="hf") + prompt = b"Test" + ensemble = await ByteEnsemble.create( + llm1, llm2, op="prod", prompt1=prompt, prompt2=prompt, a=0.5 + ) + assert ensemble.p1 is llm1 + assert ensemble.p2 is llm2 + assert len(ensemble.vocab) == 256 + assert isinstance(ensemble.vocab, list) + assert all(isinstance(v, int) and 0 <= v < 256 for v in ensemble.vocab) + assert b"" in ensemble.data_dict_1 + assert b"" in ensemble.data_dict_2 + + +@pytest.mark.asyncio +async def test_byte_ensemble_different_prompts(): + """Test ByteEnsemble with different prompts - the key use case for ensembling.""" + llm1 = load_model_by_name("gpt2", backend="hf") + llm2 = load_model_by_name("gpt2", backend="hf") + + prompt1 = b"Write a SQL query to find all users: " + prompt2 = b"SQL: Find all users in the database: " + + ensemble = await ByteEnsemble.create( + llm1, llm2, op="prod", prompt1=prompt1, prompt2=prompt2, a=0.5 + ) + assert 
ensemble.p1 is llm1 + assert ensemble.p2 is llm2 + assert len(ensemble.vocab) == 256 + assert b"" in ensemble.data_dict_1 + assert b"" in ensemble.data_dict_2 + beam1, beam2 = await ensemble.get_beam_states([]) + assert beam1 is not None + assert beam2 is not None + + +@pytest.mark.asyncio +async def test_byte_ensemble_get_beam_states(): + """Test that ByteEnsemble.get_beam_states() provides access to beams.""" + llm1 = load_model_by_name("gpt2", backend="hf") + llm2 = load_model_by_name("gpt2", backend="hf") + ensemble = await ByteEnsemble.create( + llm1, llm2, op="prod", prompt1=b"Hi", prompt2=b"Hello", a=0.5 + ) + beam1, beam2 = await ensemble.get_beam_states([]) + assert beam1 is not None + assert beam2 is not None + assert hasattr(beam1, "states") + assert hasattr(beam2, "states") + assert len(beam1) > 0 + assert len(beam2) > 0 + + +@pytest.mark.asyncio +async def test_byte_ensemble_token_sampler_initialization(): + """Test ByteEnsembleTokenSampler initialization with different prompts.""" + llm1 = load_model_by_name("gpt2", backend="hf") + llm2 = load_model_by_name("gpt2", backend="hf") + ensemble = await ByteEnsemble.create( + llm1, llm2, op="prod", prompt1=b"Answer: ", prompt2=b"Response: ", a=0.5 + ) + eos_tokens = [llm1.byte_vocab[llm1.tokenizer.eos_token_id]] + sampler = ByteEnsembleTokenSampler( + ensemble, max_tokens=50, eos_tokens=eos_tokens, n_particles=5 + ) + assert sampler.potential is ensemble + assert sampler.max_tokens == 50 + assert sampler.eos_tokens == eos_tokens + assert sampler.n_particles == 5 + # check caches + assert () in sampler.prefix_cache_1 + assert () in sampler.prefix_cache_2 + assert sampler.prefix_cache_1[()] == 0.0 + assert sampler.prefix_cache_2[()] == 0.0 + + +@pytest.mark.asyncio +async def test_byte_ensemble_sampler_sample(): + """Test ByteEnsembleTokenSampler samples with different prompts.""" + llm1 = load_model_by_name("gpt2", backend="hf") + llm2 = load_model_by_name("gpt2", backend="hf") + + ensemble = await 
ByteEnsemble.create( + llm1, llm2, op="prod", prompt1=b"The cat is ", prompt2=b"A cat is ", a=0.5 + ) + eos_tokens = [llm1.byte_vocab[llm1.tokenizer.eos_token_id]] + sampler = ByteEnsembleTokenSampler( + ensemble, max_tokens=10, eos_tokens=eos_tokens, n_particles=3 + ) + token, logw, logp = await sampler.sample([]) + assert isinstance(token, (int, bytes)) + assert isinstance(logw, (int, float, np.number)) + assert isinstance(logp, (int, float, np.number)) + assert np.isfinite(logw) + assert np.isfinite(logp) + if isinstance(token, int): + next_context_bytes = bytes([token]) + else: + next_context_bytes = token + + assert next_context_bytes in ensemble.data_dict_1 + assert next_context_bytes in ensemble.data_dict_2 + + +@pytest.mark.asyncio +async def test_byte_ensemble_weighted_different_prompts(): + """Test ByteEnsemble with unequal weights on different prompts.""" + llm1 = load_model_by_name("gpt2", backend="hf") + llm2 = load_model_by_name("gpt2", backend="hf") + + prompt1 = b"Correct approach: " + prompt2 = b"Alternative: " + + ensemble = await ByteEnsemble.create( + llm1, llm2, op="prod", prompt1=prompt1, prompt2=prompt2, a=0.7 + ) + + assert ensemble.p1 is llm1 + assert ensemble.p2 is llm2 + + eos_tokens = [llm1.byte_vocab[llm1.tokenizer.eos_token_id]] + sampler = ByteEnsembleTokenSampler( + ensemble, max_tokens=5, eos_tokens=eos_tokens, n_particles=2 + ) + token, logw, logp = await sampler.sample([]) + assert isinstance(token, (int, bytes)) + assert np.isfinite(logw) + + +# ============================================================================ +# Test Realistic Ensemble Applications +# ============================================================================ + + +@pytest.mark.asyncio +async def test_ensemble_with_different_model_preferences(): + """Test ensemble where models have same vocab but different preferences.""" + vocab = ["a", "b", "c", "d"] + logws1 = np.array([0.0, -0.5, -2.0, -3.0, -100.0]) # prefers 'a' > 'b' > 'c' > 'd' + logws2 = 
np.array([-3.0, -2.0, -0.5, 0.0, -100.0]) # prefers 'd' > 'c' > 'b' > 'a' + p1 = MockPotential(vocab=vocab, next_token_logws=logws1) + p2 = MockPotential(vocab=vocab, next_token_logws=logws2) + ensemble = Ensemble(p1, p2, op="prod", a=0.5) + + for tok in vocab: + assert tok in ensemble.vocab + result = await ensemble.batch_logw_next([[]]) + combined = result[0].weights + + for tok in vocab: + ensemble_idx = ensemble.lookup[tok] + p1_idx = p1.lookup[tok] + p2_idx = p2.lookup[tok] + expected = 0.5 * logws1[p1_idx] + 0.5 * logws2[p2_idx] + actual = combined[ensemble_idx] + assert actual == pytest.approx(expected, abs=1e-5), f"Token {tok} mismatch" + + b_idx = ensemble.lookup["b"] + c_idx = ensemble.lookup["c"] + a_idx = ensemble.lookup["a"] + d_idx = ensemble.lookup["d"] + assert combined[b_idx] > min(combined[a_idx], combined[d_idx]) + assert combined[c_idx] > min(combined[a_idx], combined[d_idx]) + + +@pytest.mark.asyncio +async def test_ensemble_with_complementary_knowledge(): + """Test ensemble where models show different performance on different tokens.""" + vocab1 = ["SELECT", "FROM", "WHERE", "LIMIT"] + logws1 = np.array( + [ + np.log(0.4), # SELECT: confident + np.log(0.3), # FROM: confident + np.log(0.2), # WHERE: confident + np.log(0.1), # LIMIT: less confident + -100.0, + ] + ) + vocab2 = ["SELECT", "FROM", "WHERE", "LIMIT"] + logws2 = np.array( + [ + np.log(0.1), # SELECT: not confident + np.log(0.2), # FROM: not confident + np.log(0.3), # WHERE: somewhat confident + np.log(0.4), # LIMIT: very confident + -100.0, + ] + ) + + p1 = MockPotential(vocab=vocab1, next_token_logws=logws1) + p2 = MockPotential(vocab=vocab2, next_token_logws=logws2) + ensemble = Ensemble(p1, p2, op="prod", a=0.5) + result = await ensemble.batch_logw_next([[]]) + combined = result[0].weights + + for tok in vocab1: + idx = ensemble.lookup[tok] + prob = np.exp(combined[idx]) + assert prob > 0.05, f"{tok} should have reasonable probability in ensemble" + assert prob < 0.95, f"{tok} 
shouldn't dominate in balanced ensemble" + + p1_result = await p1.batch_logw_next([[]]) + p2_result = await p2.batch_logw_next([[]]) + + assert not np.allclose(combined, p1_result[0].weights, rtol=0.1) + assert not np.allclose(combined, p2_result[0].weights, rtol=0.1) + + +@pytest.mark.asyncio +async def test_ensemble_helps_uncertain_model(): + """Test that ensembling helps when one model is uncertain but the other is confident.""" + mock_vocab = ["correct", "wrong1", "wrong2"] + logws1 = np.array( + [np.log(0.33), np.log(0.33), np.log(0.34), -100.0] + ) # Model 1 is uncertain + logws2 = np.array( + [np.log(0.9), np.log(0.05), np.log(0.05), -100.0] + ) # Model 2 is confident + + p1 = MockPotential(vocab=mock_vocab, next_token_logws=logws1) + p2 = MockPotential(vocab=mock_vocab, next_token_logws=logws2) + ensemble = Ensemble(p1, p2, op="prod", a=0.5) + result = await ensemble.batch_logw_next([[]]) + combined = result[0].weights + + correct_idx = ensemble.lookup["correct"] + wrong1_idx = ensemble.lookup["wrong1"] + # ensemble should favor 'correct' more than model 1 alone + p1_result = await p1.batch_logw_next([[]]) + p1_logws = p1_result[0].weights + p1_correct = p1_logws[p1.lookup["correct"]] + p1_wrong1 = p1_logws[p1.lookup["wrong1"]] + ensemble_correct = combined[correct_idx] + ensemble_wrong1 = combined[wrong1_idx] + + # Ensemble should have stronger preference for 'correct' than uncertain model 1 + p1_gap = p1_correct - p1_wrong1 + ensemble_gap = ensemble_correct - ensemble_wrong1 + assert ( + ensemble_gap > p1_gap + ), "Ensemble should be more confident than uncertain model" + + # But less confident than model 2 alone + p2_result = await p2.batch_logw_next([[]]) + p2_logws = p2_result[0].weights + p2_gap = p2_logws[p2.lookup["correct"]] - p2_logws[p2.lookup["wrong1"]] + assert ( + ensemble_gap < p2_gap + ), "Ensemble should be less confident than very confident model" + + +# ============================================================================ +# Test 
Utility Functions +# ============================================================================ + + +def test_convert_to_weighted_logop_invalid_a(): + """Test that invalid 'a' parameter raises ValueError.""" + with pytest.raises(ValueError, match="variable a should be between 0 and 1"): + convert_to_weighted_logop("prod", a=1.5) + + with pytest.raises(ValueError, match="variable a should be between 0 and 1"): + convert_to_weighted_logop("prod", a=-0.1) + + +def test_convert_to_weighted_logop_invalid_op(): + """Test that invalid operation raises ValueError.""" + with pytest.raises(ValueError, match="Invalid operation"): + convert_to_weighted_logop("invalid_op", a=0.5) + + +def test_convert_to_weighted_logop_operations(): + """Test convert_to_weighted_logop returns correct operation for prod.""" + x = np.log(np.array([0.3, 0.7])) + y = np.log(np.array([0.6, 0.4])) + + # Test prod with analytical verification + op_prod = convert_to_weighted_logop("prod", a=0.5) + result_prod = op_prod(x, y) + expected_prod = 0.5 * x + 0.5 * y + np.testing.assert_allclose(result_prod, expected_prod, rtol=1e-5) + + +@pytest.mark.asyncio +async def test_byte_ensemble_token_sampler_start_weight(): + """Test ByteEnsembleTokenSampler.start_weight() returns 0.0.""" + llm = load_model_by_name("gpt2", backend="hf") + ensemble = await ByteEnsemble.create( + llm, llm, op="prod", prompt1=b"Hi", prompt2=b"Hi", a=0.5 + ) + eos_tokens = [llm.byte_vocab[llm.tokenizer.eos_token_id]] + sampler = ByteEnsembleTokenSampler( + ensemble, max_tokens=10, eos_tokens=eos_tokens, n_particles=5 + ) + start_weight = await sampler.start_weight() + assert start_weight == 0.0 + + +@pytest.mark.asyncio +async def test_byte_ensemble_sampler_eos_handling(): + """Test ByteEnsembleTokenSampler properly handles EOS tokens and max_tokens.""" + llm = load_model_by_name("gpt2", backend="hf") + ensemble = await ByteEnsemble.create( + llm, llm, op="prod", prompt1=b"Hi", prompt2=b"Hi", a=0.5 + ) + eos_byte = 
llm.byte_vocab[llm.tokenizer.eos_token_id] + sampler = ByteEnsembleTokenSampler( + ensemble, max_tokens=5, eos_tokens=[eos_byte], n_particles=5 + ) + _, _, _ = await sampler.sample([]) + if len(sampler.particle_prefix_log_prob_1) > 0: + assert len(sampler.particle_prefix_log_prob_1) >= 0 + assert len(sampler.particle_prefix_log_prob_2) >= 0 + + +@pytest.mark.asyncio +async def test_byte_ensemble_empty_beam_error(): + """Test ByteEnsemble raises RuntimeError when beam is empty after prefill.""" + mock_llm = MagicMock() + mock_llm.byte_vocab = {0: b"a"} + mock_llm.tokenizer.eos_token_id = 0 + empty_beam = MagicMock() + empty_beam.prefill = AsyncMock(return_value=[]) + with patch( + "genlm.control.potential.built_in.ensemble.ByteBeamState.initial", + AsyncMock(return_value=empty_beam), + ): + with pytest.raises(RuntimeError, match="Beam1 is empty after prefill"): + await ByteEnsemble.create( + mock_llm, mock_llm, op="prod", prompt1=b"test", prompt2=b"test", a=0.5 + ) + + +def test_split_with_atomic_tokens_overlapping(): + """Test split_with_atomic_tokens with overlapping tokens.""" + with pytest.warns(UserWarning, match="Overlapping atomic tokens detected"): + result = split_with_atomic_tokens(b"ABC", [b"A", b"AB"]) + assert result == [b"A", 66, 67] + + +def test_split_with_atomic_tokens_no_match(): + """Test split_with_atomic_tokens when no atomic tokens match.""" + result = split_with_atomic_tokens(b"XYZ", [b"A", b"B"]) + assert result == [88, 89, 90] + + +def test_weighted_extremum_different_weights(): + """Test _weighted_extremum with different weight values.""" + x = np.array([-1.0, -2.0, -3.0]) + y = np.array([-2.0, -1.5, -3.5]) + max_op_favoring_y = _weighted_extremum(np.maximum, a=0.7) + result = max_op_favoring_y(x, y) + expected = (2 * 0.7 - 1) * y + 2 * (1 - 0.7) * np.maximum(x, y) + np.testing.assert_allclose(result, expected, rtol=1e-5) + min_op_favoring_x = _weighted_extremum(np.minimum, a=0.3) + result2 = min_op_favoring_x(x, y) + expected2 = (1 - 2 * 
0.3) * x + 2 * 0.3 * np.minimum(x, y) + np.testing.assert_allclose(result2, expected2, rtol=1e-5) + + +@pytest.mark.asyncio +async def test_ensemble_smc_weight_extraction(): + """Test EnsembleSMC extracts individual model weights correctly.""" + from genlm.control.sampler.token import TokenSampler + + class MockTokenSampler(TokenSampler): + def __init__(self): + self.particle_prefix_log_prob_1 = { + ("a",): -1.0, + ("b",): -2.0, + } + self.particle_prefix_log_prob_2 = { + ("a",): -1.5, + ("b",): -2.5, + } + + async def start_weight(self): + return 0.0 + + async def sample(self, context, draw=None): + return EOS, 0.0, 0.0 + + mock_sampler = MockTokenSampler() + smc = EnsembleSMC(mock_sampler, None) + mock_sequences = Sequences( + contexts=[["a"], ["b"]], + log_weights=[-0.5, -0.7], + ) + + with patch.object( + EnsembleSMC.__bases__[0], + "__call__", + AsyncMock(return_value=mock_sequences), + ): + result = await smc(n_particles=2, ess_threshold=0.5, max_tokens=10) + + assert isinstance(result, SequencesExt) + assert hasattr(result, "log_prefix_weights_1") + assert hasattr(result, "log_prefix_weights_2") + assert len(result.log_prefix_weights_1) == 2 + assert len(result.log_prefix_weights_2) == 2 + assert result.log_prefix_weights_1[0] == -1.0 + assert result.log_prefix_weights_1[1] == -2.0 + assert result.log_prefix_weights_2[0] == -1.5 + assert result.log_prefix_weights_2[1] == -2.5 + + +def test_sequences_ext_post_init(): + """Test SequencesExt.__post_init__ converts lists to numpy arrays.""" + seq = SequencesExt( + contexts=[["a", "b"], ["c", "d"]], + log_weights=[0.1, 0.2], + log_prefix_weights_1=[0.15, 0.25], + log_prefix_weights_2=[0.12, 0.22], + ) + assert isinstance(seq.log_prefix_weights_1, np.ndarray) + assert isinstance(seq.log_prefix_weights_2, np.ndarray) + seq2 = SequencesExt(contexts=[["a"]], log_weights=[0.1], log_prefix_weights_1=None) + assert seq2.log_prefix_weights_1 is None + + +def test_sequences_ext_post_init_with_none(): + """Test 
SequencesExt.__post_init__ handles None values correctly.""" + seq = SequencesExt( + contexts=[["a", "b"]], + log_weights=[0.1], + log_prefix_weights_1=None, + log_prefix_weights_2=None, + ) + assert seq.log_prefix_weights_1 is None + assert seq.log_prefix_weights_2 is None + + +@pytest.mark.asyncio +async def test_byte_ensemble_cleanup_cache_deletes_short_keys(): + """Test ByteEnsemble._cleanup_cache() deletes short keys.""" + gpt2 = load_model_by_name("gpt2") + prompt1 = b"The capital of France is" + prompt2 = b"Paris, the capital city of France, is" + ensemble = await ByteEnsemble.create( + gpt2, gpt2, "sum", prompt1, prompt2, a=0.5, K=3 + ) + ensemble.data_dict_1 = { + (1,): "short1", + (1, 2): "short2", + (1, 2, 3): "keep3", + (1, 2, 3, 4): "keep4", + (1, 2, 3, 4, 5): "keep5", + (1, 2, 3, 4, 5, 6): "keep6", + } + ensemble.data_dict_2 = { + (10,): "short1", + (10, 20): "short2", + (10, 20, 30): "keep3", + (10, 20, 30, 40): "keep4", + (10, 20, 30, 40, 50): "keep5", + (10, 20, 30, 40, 50, 60): "keep6", + } + await ensemble._cleanup_cache() + for d in [ensemble.data_dict_1, ensemble.data_dict_2]: + for k in d.keys(): + assert len(k) >= 4, f"Key {k} should have been deleted" + assert len(ensemble.data_dict_1) == 3 + assert len(ensemble.data_dict_2) == 3 + + +@pytest.mark.asyncio +async def test_byte_ensemble_empty_beam_error_covered(): + gpt2 = load_model_by_name("gpt2") + prompt1 = b"\xff\xfe\xfd" # Invalid UTF-8 bytes + prompt2 = b"Test" + with pytest.raises(RuntimeError, match="is empty after prefill"): + await ByteEnsemble.create( + gpt2, gpt2, "sum", prompt1, prompt2, a=0.5, K=1, prune_threshold=100.0 + ) + + +@pytest.mark.asyncio +async def test_byte_ensemble_sampler_stores_particle_weights(): + """Test ByteEnsembleTokenSampler stores particle weights at EOS.""" + gpt2 = load_model_by_name("gpt2") + prompt1 = b"Hi" + prompt2 = b"Hi" + + ensemble = await ByteEnsemble.create( + gpt2, gpt2, "sum", prompt1, prompt2, a=0.5, K=3 + ) + sampler = 
ByteEnsembleTokenSampler(ensemble, max_tokens=1) + context = [] + token, _, _ = await sampler.sample(context) + new_ctx_tuple = (token,) + assert new_ctx_tuple in sampler.particle_prefix_log_prob_1 + assert new_ctx_tuple in sampler.particle_prefix_log_prob_2 + + +@pytest.mark.asyncio +async def test_byte_ensemble_sampler_eos_conversion(): + """Test ByteEnsembleTokenSampler EOS conversion path.""" + gpt2 = load_model_by_name("gpt2") + prompt1 = b"Hi" + prompt2 = b"Hi" + ensemble = await ByteEnsemble.create( + gpt2, gpt2, "sum", prompt1, prompt2, a=0.5, K=3 + ) + sampler = ByteEnsembleTokenSampler(ensemble) + context = [] + token, _, _ = await sampler.sample(context) + assert token is not None + + +@pytest.mark.asyncio +async def test_byte_ensemble_sampler_smc_calls_ensemble_smc(): + """Test ByteEnsembleTokenSampler.smc() method invokes EnsembleSMC.""" + gpt2 = load_model_by_name("gpt2") + prompt1 = b"Hi" + prompt2 = b"Hi" + ensemble = await ByteEnsemble.create( + gpt2, gpt2, "sum", prompt1, prompt2, a=0.5, K=3 + ) + sampler = ByteEnsembleTokenSampler(ensemble) + assert hasattr(sampler, "smc") + assert callable(sampler.smc) + try: + result = await sampler.smc( + n_particles=1, + ess_threshold=0.5, + max_tokens=1, + critic=None, + ) + assert isinstance(result, SequencesExt) + except (AssertionError, KeyError) as e: + if "Beam is empty" in str(e) or "not found in cache" in str(e): + pass + else: + raise