diff --git a/genlm/bytes/byte_lm/beam.py b/genlm/bytes/byte_lm/beam.py index 7d0a0e3..8a6acc3 100644 --- a/genlm/bytes/byte_lm/beam.py +++ b/genlm/bytes/byte_lm/beam.py @@ -1,10 +1,10 @@ import asyncio +import warnings import numpy as np from arsenal import colors from dataclasses import dataclass from scipy.special import logsumexp as scipy_logsumexp from functools import cached_property -from genlm.backend.tokenization.bytes import get_byte_vocab from ..util import logsumexp, LazyByteProbs from ..trie import AsyncTokenByteTrie @@ -22,7 +22,7 @@ class BeamParams: prune_threshold (float, optional): Probability threshold for pruning candidates. Candidates with probability below this are removed. Defaults to 0.0 verbose (bool, optional): Whether to print the beam state at each step. Defaults to False - eos_tokens (list[bytes], optional): List of tokens that should be treated as EOS. When configured, + eos_byte_strings (list[bytes], optional): List of tokens that should be treated as EOS. When configured, EOS tokens will terminate generation when sampled. Defaults to None heal (bool, optional): Whether to enable adaptive token healing. Defaults to True heal_max_backoff (int, optional): Maximum number of bytes to back off when healing. Defaults to None @@ -32,15 +32,29 @@ class BeamParams: K: int prune_threshold: float = 0.0 verbose: bool = False - eos_tokens: list[bytes] = None + eos_byte_strings: list[bytes] = None heal: bool = True heal_max_backoff: int | None = None # Optional cap on how many intra-partial commits are allowed during a # single healing attempt. None means unlimited. Set to 0 to disable # multi-split behavior (i.e., single-split only). heal_max_splits: int | None = None + # Deprecated alias for eos_byte_strings + eos_tokens: list[bytes] = None def __post_init__(self): + if self.eos_tokens is not None: + if self.eos_byte_strings is not None: + raise TypeError( + "Cannot specify both 'eos_byte_strings' and the deprecated 'eos_tokens'." 
+ ) + warnings.warn( + "'eos_tokens' is deprecated, use 'eos_byte_strings' instead.", + DeprecationWarning, + stacklevel=2, + ) + self.eos_byte_strings = self.eos_tokens + self.eos_tokens = None if self.prune_threshold < 0: raise ValueError( f"prune_threshold must be non-negative, got {self.prune_threshold}" @@ -48,7 +62,7 @@ def __post_init__(self): self.log_prune_threshold = ( np.log(self.prune_threshold) if self.prune_threshold > 0 else -np.inf ) - self.eos_tokens = set(self.eos_tokens) if self.eos_tokens else set() + self.eos_byte_strings = set(self.eos_byte_strings) if self.eos_byte_strings else set() class ByteBeamState(StatefulByteLM): @@ -71,7 +85,7 @@ async def initial(cls, llm, params, trie_opts=None): """Creates initial beam state. Args: - llm (StatefulTokenizedLM): Token-level language model to use. + llm (genlm.backend.AsyncLM): Token-level language model to use. params (BeamParams): Beam search parameters. trie_opts (dict, optional): Additional keyword arguments passed to AsyncTokenByteTrie.from_vocab. For example, {"max_batch_size": 100}. @@ -81,10 +95,11 @@ async def initial(cls, llm, params, trie_opts=None): """ # Handle EOS tokens trie_opts = trie_opts or {} - trie_opts["eos_tokens"] = params.eos_tokens + trie_opts["eos_byte_strings"] = params.eos_byte_strings + # Use llm.byte_vocab which contains Token objects (supports duplicate byte strings) async_trie = AsyncTokenByteTrie.from_vocab( - get_byte_vocab(llm.tokenizer), **trie_opts + llm.byte_vocab, **trie_opts ) state = LazyTrieState.initial(llm, async_trie, mode=TrieMode.WITH_EOS) return cls([await state.materialize()], params) @@ -162,18 +177,22 @@ async def logp_next(self): async def extend(self, logZ): """Attempts to advance each candidate in the beam by a token (EOT). - For each candididate with EOT available, this ends the current token and + For each candidate with EOT available, this ends the current token and starts a new one in preparation for the next byte. 
+ With duplicate tokens (multiple token IDs mapping to the same byte string), + a single state can have multiple extensions - one for each possible token. + Args: - logZ (float): Current estimated of the partition function for pruning + logZ (float): Current estimate of the partition function for pruning Returns: (list[LazyTrieState]): New candidate states after extension """ extends = [] for state in self: - if new_state := state.extend(): + # extend_all() returns all possible extensions (one per token at this position) + for new_state in state.extend_all(): logZ = np.logaddexp(logZ, new_state.weight) extends.append(new_state) diff --git a/genlm/bytes/byte_lm/heal.py b/genlm/bytes/byte_lm/heal.py index a9bfe28..951d6af 100644 --- a/genlm/bytes/byte_lm/heal.py +++ b/genlm/bytes/byte_lm/heal.py @@ -2,13 +2,26 @@ from ..util import format_byte +def _find_all_eot_edges(children, eot_sentinel): + """Find all EOT edges in children dict. Returns list of (node, token_id) tuples. + + EOT edges are stored as tuple keys: (eot_sentinel, token_id). + With duplicate tokens, multiple token IDs can map to the same byte string. + """ + results = [] + for key, node in children.items(): + if isinstance(key, tuple) and key[0] == eot_sentinel: + results.append((node, key[1])) + return results + + class TokenHealer: """Handles adaptive token healing for ByteBeamState. Token healing finds alternative tokenizations when the current tokenization cannot consume the next byte. It works by: 1. Trying different "backoff" positions k (commit partial[:k] as a token) 2. Replaying the remaining bytes (partial[k:]) from fresh root - 3. Using extend() when stuck to commit intermediate tokens + 3. Using extend_all() when stuck to commit intermediate tokens 4. 
Finally consuming the target next_byte Args: @@ -69,6 +82,9 @@ async def try_heal(self, state, next_byte: int): async def _try_at_k(self, state, trie, base_weight: float, k: int, next_byte: int): """Try healing by committing partial[:k], replaying partial[k:], then consuming next_byte. + With duplicate tokens, there can be multiple EOT edges at position k. + This method tries all of them until one succeeds. + Args: state: The original state to heal from trie: The trie structure (state.trie.trie) @@ -89,18 +105,49 @@ async def _try_at_k(self, state, trie, base_weight: float, k: int, next_byte: in if node_at_k is None: return None # Path doesn't exist - # Check if there's an EOT at position k - eot_node = children[node_at_k].get(trie.eot_token) - if eot_node is None: + # Find all EOT edges at position k + # With duplicate tokens, multiple token IDs can map to the same byte string + eot_edges = _find_all_eot_edges(children[node_at_k], trie.eot_sentinel) + if not eot_edges: if self.verbose: print(f"[heal] k={k}: no EOT at {bytes(partial[:k])!r}") return None - # Commit at position k + # Try each possible EOT edge + for eot_node, eot_token_id in eot_edges: + result = await self._try_eot_at_k( + state, trie, base_weight, k, next_byte, eot_node, eot_token_id + ) + if result is not None: + return result + + return None + + async def _try_eot_at_k( + self, state, trie, base_weight: float, k: int, next_byte: int, + eot_node: int, eot_token_id: int + ): + """Try healing with a specific EOT edge at position k. 
+ + Args: + state: The original state to heal from + trie: The trie structure (state.trie.trie) + base_weight: Precomputed weight after undoing current path + k: Backoff position + next_byte: The byte we want to consume + eot_node: The EOT node to commit + eot_token_id: The token ID for this EOT + + Returns: + LazyTrieState if successful, None otherwise + """ + partial = state.partial + + # Commit at position k with this specific token weight_after_commit = base_weight + ( state.mass[eot_node] - state.mass[trie.root] ) - token_id = int(trie.leaf2token_id[eot_node]) + token_id = int(eot_token_id) current = LazyTrieState( lm_state=(state.lm_state << token_id), @@ -114,8 +161,10 @@ async def _try_at_k(self, state, trie, base_weight: float, k: int, next_byte: in current = await current.materialize() if self.verbose: + # trie.decode contains Token objects, get byte_string for display + token_bytes = trie.decode[token_id].byte_string print( - f"[heal] k={k}: commit {trie.decode[token_id]!r}, w={weight_after_commit:.2f}" + f"[heal] k={k}: commit {token_bytes!r} (token_id={token_id}), w={weight_after_commit:.2f}" ) # Replay suffix bytes then consume next_byte @@ -134,26 +183,32 @@ async def _try_at_k(self, state, trie, base_weight: float, k: int, next_byte: in print(f"[heal] k={k}: hit max_splits={self.max_splits}") return None - extended = current.extend() - if extended is None: + # extend_all() returns list of all possible extensions + extensions = current.extend_all() + if not extensions: if self.verbose: print(f"[heal] k={k}: can't extend at {bytes(current.partial)!r}") return None - current = await extended.materialize() + # Try each possible extension (duplicates = same split, different token_id) splits_used += 1 - if self.verbose: - print(f"[heal] k={k}: split #{splits_used}, w={current.weight:.2f}") - - # Retry consuming the byte after extend - next_state = current << b - if next_state is None: + for extended in extensions: + materialized = await 
extended.materialize() + if self.verbose: + print(f"[heal] k={k}: split #{splits_used}, w={materialized.weight:.2f}") + + # Retry consuming the byte after extend + next_state = materialized << b + if next_state is not None: + current = next_state + break # Found a valid extension + else: + # None of the extensions worked if self.verbose: print( f"[heal] k={k}: couldn't consume {format_byte(b)} even after extend" ) return None - current = next_state if self.verbose: print(f"[heal] SUCCESS at k={k}: w={current.weight:.2f}") diff --git a/genlm/bytes/byte_lm/lm_state.py b/genlm/bytes/byte_lm/lm_state.py index a059221..69d69cf 100644 --- a/genlm/bytes/byte_lm/lm_state.py +++ b/genlm/bytes/byte_lm/lm_state.py @@ -67,8 +67,9 @@ async def logp_next(self): return await self.model.next_token_logprobs(self.context) def __repr__(self): + # byte_vocab contains Token objects, so we access .byte_string return colors.purple % ( - "|".join([escape(self.model.byte_vocab[x]) for x in self.context]) + "|".join([escape(self.model.byte_vocab[x].byte_string) for x in self.context]) ) diff --git a/genlm/bytes/byte_lm/trie_state.py b/genlm/bytes/byte_lm/trie_state.py index 8d7375d..10feeca 100644 --- a/genlm/bytes/byte_lm/trie_state.py +++ b/genlm/bytes/byte_lm/trie_state.py @@ -103,9 +103,24 @@ def actions(self): """Returns possible byte transitions from current node.""" return self.children[self.node] - def get_EOT(self): - """Returns the end-of-token node if available from current position in the trie.""" - return self.children[self.node].get(self.trie.trie.eot_token) + def get_all_EOT(self): + """Returns all EOT edges from the current position in the trie. + + With duplicate tokens, multiple token IDs can map to the same byte string, + resulting in multiple EOT edges at the same node. + + Returns: + list[tuple[int, int]]: List of (eot_node, token_id) tuples for each EOT edge. + Empty list if no EOT edges exist. 
+ """ + eot_sentinel = self.trie.trie.eot_sentinel + results = [] + for key, node in self.children[self.node].items(): + # EOT edges are tuples: (eot_sentinel, token_id) + if isinstance(key, tuple) and key[0] == eot_sentinel: + token_id = key[1] + results.append((node, token_id)) + return results def __lshift__(self, b): """Transitions to a new state by consuming a byte. @@ -131,23 +146,29 @@ def __lshift__(self, b): terminated=b == EOS, ) - def extend(self): - """Extends current state by consuming an end-of-token if possible. + def extend_all(self): + """Extends current state by consuming an end-of-token, returning all possible extensions. + + With duplicate tokens (multiple token IDs with the same byte string), there can be + multiple valid extensions at the same position. Each extension corresponds to a + different token being committed, which affects future LM predictions. Returns: - (LazyTrieState|None): New state after consuming EOT, or None if not possible + list[LazyTrieState]: List of new states after consuming EOT, one per possible token. + Empty list if no EOT edges exist. 
""" if self._extend is None: - if (eot_node := self.get_EOT()) is not None: - mass = self.mass - self._extend = LazyTrieState( - lm_state=self.lm_state - << int(self.trie.trie.leaf2token_id[eot_node]), + extensions = [] + mass = self.mass + for eot_node, token_id in self.get_all_EOT(): + extensions.append(LazyTrieState( + lm_state=self.lm_state << int(token_id), trie=self.trie, node=self.root, weight=self.weight + mass[eot_node] - mass[self.node], mode=self.mode, - ) + )) + self._extend = extensions return self._extend @cached_property @@ -161,8 +182,25 @@ def logp_next(self): mass = self.mass logZ = mass[self.node] - for byte, node in self.actions().items(): - logps[byte if byte is not None else 256] = mass[node] - logZ + for key, node in self.actions().items(): + # Handle different edge types: + # - tuple: (eot_sentinel, token_id) for EOT edges to leaves + # - int 0-255: byte transitions + # - int 257: EOS transition + if isinstance(key, tuple): + # EOT edge - use index 256 + # For duplicates, sum their masses using logaddexp + if logps[256] == -np.inf: + logps[256] = mass[node] - logZ + else: + logps[256] = np.logaddexp(logps[256], mass[node] - logZ) + elif isinstance(key, int): + logps[key] = mass[node] - logZ + else: # pragma: no cover + raise ValueError( + f"Unexpected edge key type: {type(key).__name__} (value: {key!r}). " + f"Expected tuple (EOT edge) or int (byte/EOS transition)." 
+ ) return LazyByteProbs(logps) @@ -181,7 +219,7 @@ async def materialize(self): self._mass = mass.cpu().numpy() return self - def __repr__(self): + def __repr__(self): # pragma: no cover context = colors.green % ("|" + escape(bytes(self.partial))) if self.terminated: context += colors.yellow % "" diff --git a/genlm/bytes/trie.py b/genlm/bytes/trie.py index 842d150..90e5de8 100644 --- a/genlm/bytes/trie.py +++ b/genlm/bytes/trie.py @@ -1,10 +1,13 @@ import torch import asyncio import logging +import warnings import numpy as np from enum import Enum from collections import defaultdict +from genlm.backend.tokenization import Token + EOS = 257 logger = logging.getLogger(__name__) @@ -17,41 +20,64 @@ class TrieMode(Enum): class TokenByteTrie: - """A trie data structure for efficient token-to-byte mapping.""" + """A trie data structure for efficient token-to-byte mapping. + + Requires Token objects (from genlm.backend.tokenization) which allow handling + models with duplicate byte strings (multiple token IDs mapping to the same bytes). + """ def __init__( self, decode, device=None, - atomic_tokens=None, - eot_token=None, - eos_tokens=None, + atomic_byte_strings=None, + eot_sentinel=None, + eos_byte_strings=None, max_batch_size=64, ): """Initialize a `TokenByteTrie`. Args: - decode (list[bytes]): List representing the token vocabulary. + decode (list[Token]): List of Token objects representing the token vocabulary. + Each Token must have both token_id and byte_string attributes. device (str, optional): Device to use for weight sum and max computations ('cpu' or 'cuda'). - atomic_tokens (list[bytes], optional): List of tokens that should be treated as atomic units rather than being split into bytes. - eot_token (bytes|None, optional): End-of-token token. Default is None, which represents EOT as None. - eos_tokens (set[bytes], optional): Set of tokens that should be treated as EOS (End of Sequence). 
+ atomic_byte_strings (list[bytes], optional): List of byte strings that should be treated as atomic units rather than being split into individual bytes. + eot_sentinel (bytes|None, optional): End-of-token sentinel value. Default is None, which represents EOT as None. + eos_byte_strings (set[bytes], optional): Set of tokens that should be treated as EOS (End of Sequence). max_batch_size (int, optional): Maximum batch size for weight sum sparse matrix multiplication. """ + if not decode: + raise ValueError("decode cannot be empty") + if Token.is_plain_bytes(decode[0]): + warnings.warn( + "Passing plain bytes to TokenByteTrie is deprecated. " + "Use Token objects from decode_vocab() instead.", + DeprecationWarning, + stacklevel=2, + ) + decode = [Token(token_id=i, byte_string=b) for i, b in enumerate(decode)] + elif not isinstance(decode[0], Token): + raise TypeError( + f"decode must contain Token objects, got {type(decode[0]).__name__}. " + f"Use genlm.backend.tokenization.decode_vocab() to get Token objects." + ) + self.decode = decode + self._byte_decode = [t.byte_string for t in decode] self.max_batch_size = max_batch_size self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") if self.device not in ["cpu", "cuda"]: raise ValueError(f"Invalid device: {device}. 
Must be 'cpu', 'cuda' or None") - self.eot_token = eot_token - self.eos_tokens = set(eos_tokens or []) + self.eot_sentinel = eot_sentinel + self.eos_byte_strings = set(eos_byte_strings or []) self.eos_token_ids = [ - i for i, token in enumerate(decode) if token in self.eos_tokens + token.token_id for token in self.decode + if token.byte_string in self.eos_byte_strings ] - self._build_trie(atomic_tokens or []) + self._build_trie(atomic_byte_strings or []) self._renumber() self._build_node2prefix() self._build_reachability_matrix() @@ -59,19 +85,25 @@ def __init__( self.token_id_to_leaf[:, 0], dtype=torch.long, device=self.device ) - def _build_trie(self, atomic_tokens): + def _build_trie(self, atomic_byte_strings): """Builds a trie data structure from the vocabulary. + Handles duplicate byte strings by using (byte_string, token_id) as keys. + Each token gets its own leaf node, even if multiple tokens share the same bytes. + Returns: (dict): A dictionary where keys are token IDs and values are lists of characters. 
""" - for token in atomic_tokens: - if token not in self.decode: - raise ValueError(f"Atomic token {token} not in vocabulary") + # Check atomic_byte_strings against byte representations + byte_set = set(self._byte_decode) + for bs in atomic_byte_strings: + if bs not in byte_set: + raise ValueError(f"Atomic byte string {bs!r} not in vocabulary") - for token in self.eos_tokens: - if token not in self.decode: - raise ValueError(f"EOS token {token} not in vocabulary") + # Check eos_byte_strings against byte representations + for bs in self.eos_byte_strings: + if bs not in byte_set: + raise ValueError(f"EOS byte string {bs!r} not in vocabulary") self.word2leaf = {} self.children = [{}] # First node is root @@ -79,24 +111,34 @@ def _build_trie(self, atomic_tokens): self.token_id_to_leaf = [] self.lookup = {} - for token_id, word in enumerate(self.decode): - if word in self.lookup: - raise ValueError(f"Duplicate word in vocabulary: {word}") - self.lookup[word] = token_id + for token in self.decode: + token_id = token.token_id + word = token.byte_string + + # Use (word, token_id) as lookup key to allow duplicates + lookup_key = (word, token_id) + if lookup_key in self.lookup: # pragma: no cover + # This should never happen since Token objects have unique token_ids + raise ValueError(f"Duplicate token in vocabulary: {token}, lookup_key: {lookup_key}") + self.lookup[lookup_key] = token_id # Build ALL tokens in trie (including EOS tokens for conditioning mode) curr = self.root - letters = [word] if word in atomic_tokens else word + letters = [word] if word in atomic_byte_strings else word for letter in letters: if letter not in self.children[curr]: self.children[curr][letter] = len(self.children) self.children.append({}) curr = self.children[curr][letter] - self.children[curr][self.eot_token] = last = len(self.children) + # Each token gets its own leaf, using (eot_sentinel, token_id) as edge key + # This allows multiple tokens with the same byte_string to have separate leaves 
+ leaf_edge_key = (self.eot_sentinel, token_id) + self.children[curr][leaf_edge_key] = last = len(self.children) self.children.append({}) - assert word not in self.word2leaf - self.word2leaf[word] = last + + # Use (word, token_id) as key in word2leaf to handle duplicates + self.word2leaf[(word, token_id)] = last self.token_id_to_leaf.append((token_id, last)) self.eos_node = len(self.children) @@ -130,7 +172,11 @@ def _order(self, node): int: Node indices in topological order """ for a in self.children[node]: - if a is not None: + # Skip leaf edges (tuples like (eot_sentinel, token_id)) from ordering + # but include all other edges including EOS (257) + if isinstance(a, tuple): + pass # Skip leaf edges in ordering + else: yield from self._order(self.children[node][a]) yield node @@ -189,11 +235,14 @@ def _build_node2prefix(self): node2prefix = {self.root: []} for x in reversed(range(len(self.children))): for letter, y in self.children[x].items(): - if letter is None: + # Handle leaf edges: (eot_sentinel, token_id) tuples + if isinstance(letter, tuple): + # This is a leaf edge, prefix stays the same node2prefix[y] = node2prefix[x] elif isinstance(letter, bytes): node2prefix[y] = node2prefix[x] + list(letter) else: + # Regular byte transition (int) node2prefix[y] = node2prefix[x] + [letter] self.node2prefix = node2prefix @@ -229,11 +278,12 @@ def _build_reachability_matrix(self): for i, node in enumerate(leaf_indices): token_id = self.token_id_to_leaf[i, 0] token = self.decode[token_id] + token_bytes = token.byte_string # self-connection rows_no_eos.append(i) cols_no_eos.append(node) - if token not in self.eos_tokens: + if token_bytes not in self.eos_byte_strings: rows_with_eos.append(i) cols_with_eos.append(node) else: @@ -248,7 +298,7 @@ def _build_reachability_matrix(self): ancestor = parent[current] rows_no_eos.append(i) cols_no_eos.append(ancestor) - if token not in self.eos_tokens: + if token_bytes not in self.eos_byte_strings: rows_with_eos.append(i) 
cols_with_eos.append(ancestor) current = ancestor @@ -464,10 +514,13 @@ def visualize(self, ws=None): for node_id, children in enumerate(self.children): for char, child_id in children.items(): - if char is not None: - edge_label = str(char) + # Handle leaf edges: (eot_sentinel, token_id) tuples + if isinstance(char, tuple): + _, token_id = char + edge_label = f"EOT (ID: {token_id})" else: - edge_label = "End-of-Token" + # Regular byte transition (int) or EOS + edge_label = str(char) dot.edge(str(node_id), str(child_id), label=edge_label) @@ -499,9 +552,10 @@ def from_vocab(cls, vocab, **kwargs): """Creates an `AsyncTokenByteTrie` from a vocabulary. Args: - vocab (list): The vocabulary over which the trie will be defined. + vocab (list[Token]): List of Token objects representing the vocabulary. + Use genlm.backend.tokenization.decode_vocab() to get Token objects from a tokenizer. **kwargs (dict): Additional arguments passed to the trie constructor. - Can include 'eos_tokens' for EOS support. + Can include 'eos_byte_strings' for EOS support. Returns: (AsyncTokenByteTrie): The initialized asynchronous trie instance. 
@@ -595,7 +649,7 @@ async def _background_loop(self): ) # pragma: no cover # MAX operations don't need mode, so use the original batch_weight_max results = self.trie.batch_weight_max(ws_list) - else: + else: # pragma: no cover raise ValueError(f"Unknown trie operation: {op}") for future, result in zip(futures, results): diff --git a/pyproject.toml b/pyproject.toml index 672d467..838d90b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ readme = "README.md" requires-python = ">=3.11" authors = [ { name = "Ben LeBrun", email = "benlebrun1@gmail.com" }, { name = "Tim Vieira"} ] dependencies = [ - "genlm-backend>=0.1.1", + "genlm-backend>=0.2.0", "arsenal", "IPython", ] diff --git a/tests/test_beam.py b/tests/test_beam.py index e024b6a..863d87e 100644 --- a/tests/test_beam.py +++ b/tests/test_beam.py @@ -6,6 +6,14 @@ from genlm.bytes.trie import EOS +def find_token_id_by_bytes(byte_vocab, target_bytes): + """Find the token ID for a given byte string in a list of Token objects.""" + for token in byte_vocab: + if token.byte_string == target_bytes: + return token.token_id + raise ValueError(f"{target_bytes} is not in byte_vocab") + + @pytest.fixture(scope="module") def llm(): return load_model_by_name("gpt2-medium", backend="hf") @@ -82,17 +90,30 @@ def test_invalid_prune_threshold(): BeamParams(K=1, prune_threshold=-0.1) +def test_beam_params_eos_tokens_deprecation(): + """Test that the deprecated eos_tokens kwarg works and warns.""" + with pytest.warns(DeprecationWarning, match="eos_tokens.*deprecated"): + params = BeamParams(K=3, eos_tokens=[b".", b"!"]) + assert params.eos_byte_strings == {b".", b"!"} + + +def test_beam_params_eos_tokens_and_byte_strings_conflict(): + """Test that specifying both eos_tokens and eos_byte_strings raises.""" + with pytest.raises(TypeError, match="Cannot specify both"): + BeamParams(K=3, eos_byte_strings=[b"."], eos_tokens=[b"!"]) + + # EOS-specific tests @pytest.mark.asyncio async def test_eos_manual_configuration(llm): 
"""Test manual EOS token configuration.""" manual_eos = [b".", b"!", b"?"] - params = BeamParams(K=3, eos_tokens=manual_eos) + params = BeamParams(K=3, eos_byte_strings=manual_eos) state = await ByteBeamState.initial(llm, params) try: for state in state.states: - assert state.trie.trie.eos_tokens == set(manual_eos) + assert state.trie.trie.eos_byte_strings == set(manual_eos) assert state.trie.trie.eos_node is not None finally: @@ -102,12 +123,12 @@ async def test_eos_manual_configuration(llm): @pytest.mark.asyncio async def test_eos_disabled(llm): """Test EOS functionality disabled.""" - params = BeamParams(K=3, eos_tokens=set()) # Empty set = no EOS + params = BeamParams(K=3, eos_byte_strings=set()) # Empty set = no EOS state = await ByteBeamState.initial(llm, params) try: # Check that no EOS tokens were configured - assert not any(state.trie.trie.eos_tokens for state in state.states) + assert not any(state.trie.trie.eos_byte_strings for state in state.states) # check that EOS isn't available logp_next = await state.logp_next() @@ -122,14 +143,14 @@ async def test_eos_disabled(llm): @pytest.mark.asyncio async def test_eos_termination(llm): """Test that EOS byte terminates sequences properly.""" - params = BeamParams(K=3, eos_tokens=[b"!"]) + params = BeamParams(K=3, eos_byte_strings=[b"!"]) state = await ByteBeamState.initial(llm, params) try: new_state = await (state << EOS) assert all(state.terminated for state in new_state.states) - eos_token_id = llm.byte_vocab.index(b"!") + eos_token_id = find_token_id_by_bytes(llm.byte_vocab, b"!") lm_context = [llm.tokenizer.eos_token_id] target_weight = (await llm.next_token_logprobs(lm_context))[eos_token_id] @@ -143,14 +164,14 @@ async def test_eos_termination(llm): @pytest.mark.asyncio async def test_can_generate_with_eos_in_prompt(llm): - params = BeamParams(K=10, eos_tokens=[b"\n", b"\n\n"]) + params = BeamParams(K=10, eos_byte_strings=[b"\n", b"\n\n"]) state = await ByteBeamState.initial(llm, params) try: for 
trie_state in state.states: trie = trie_state.trie.trie - assert b"\n" in trie.eos_tokens - assert b"\n\n" in trie.eos_tokens + assert b"\n" in trie.eos_byte_strings + assert b"\n\n" in trie.eos_byte_strings # Test prefill with model EOS token (conditioning mode) context_with_eos = b"Hello world" + b"\n" + b" This continues." @@ -189,14 +210,14 @@ async def test_can_generate_with_eos_in_prompt(llm): async def test_eos_logp_next_probability_sum(llm): """Test that EOS probability in logp_next equals sum of specified EOS token probabilities.""" - eos_tokens = [b".", b"\n", b"\n\n"] - params = BeamParams(K=5, eos_tokens=eos_tokens) + eos_byte_strings = [b".", b"\n", b"\n\n"] + params = BeamParams(K=5, eos_byte_strings=eos_byte_strings) beam = await ByteBeamState.initial(llm, params) try: first_state = beam.states[0] logps = await first_state.lm_state.logp_next() - eos_token_ids = [llm.byte_vocab.index(t) for t in eos_tokens] + eos_token_ids = [find_token_id_by_bytes(llm.byte_vocab, t) for t in eos_byte_strings] logps_eos = torch.logsumexp(logps[eos_token_ids], dim=0) logp_next = await beam.logp_next() @@ -205,3 +226,122 @@ async def test_eos_logp_next_probability_sum(llm): np.testing.assert_allclose(eos_logp, logps_eos, rtol=1e-5) finally: await beam.cleanup() + + +@pytest.mark.asyncio +async def test_trie_state_mass_not_materialized(llm): + """Test that accessing mass before materializing raises an error.""" + from genlm.bytes.byte_lm.trie_state import LazyTrieState + from genlm.bytes.trie import AsyncTokenByteTrie + + eos_token = llm.byte_vocab[llm.tokenizer.eos_token_id].byte_string + trie = AsyncTokenByteTrie.from_vocab(llm.byte_vocab, eos_byte_strings={eos_token}) + + try: + # Create a state without materializing + state = LazyTrieState.initial(llm, trie) + + # Accessing mass before materialize should raise + with pytest.raises(ValueError, match="not yet materialized"): + _ = state.mass + finally: + await trie.cleanup() + + +@pytest.mark.asyncio +async def 
test_trie_state_lshift_terminated(llm): + """Test that lshift on terminated state returns None.""" + eos_token = llm.byte_vocab[llm.tokenizer.eos_token_id].byte_string + params = BeamParams(K=3, eos_byte_strings=[eos_token]) + beam = await ByteBeamState.initial(llm, params) + + try: + # Prefill and get a state + beam = await beam.prefill(b"Hello") + state = beam.states[0] + + # Manually set terminated to True to test the branch + state.terminated = True + + # lshift on terminated state should return None + result = state << ord("a") + assert result is None + finally: + await beam.cleanup() + + +def test_lm_state_max_context_length(llm): + """Test that StatefulTokenizedLM truncates context when max_context_length is reached.""" + from genlm.bytes.byte_lm.lm_state import StatefulTokenizedLM + + # Create a state with max_context_length=3 and context already at limit + # This tests the truncation branch + state = StatefulTokenizedLM.initial(llm, initial_context=[1, 2, 3], max_context_length=3) + assert len(state.context) == 3 + + # Adding a token should trigger truncation: [1, 2, 3] -> [2, 3] -> [2, 3, 4] + new_state = state << 4 + # The truncation happens on the original state before creating new one + # New state context = truncated_context + [new_token] = [2, 3] + [4] = [2, 3, 4] + assert len(new_state.context) == 3 + assert new_state.context == [2, 3, 4] + + +@pytest.mark.asyncio +async def test_logp_next_with_duplicate_eot_edges(): + """Test that logp_next correctly aggregates probabilities for duplicate EOT edges.""" + import numpy as np + from genlm.backend.tokenization import Token + from genlm.bytes.trie import AsyncTokenByteTrie + from genlm.bytes.byte_lm.trie_state import LazyTrieState + + # Create vocab with duplicate byte strings (same prefix leads to multiple EOT edges) + vocab = [ + Token(token_id=0, byte_string=b"a"), + Token(token_id=1, byte_string=b"a"), # Duplicate - same byte string as token 0 + Token(token_id=2, byte_string=b"b"), + ] + + trie = 
AsyncTokenByteTrie.from_vocab(vocab) + try: + # Create mock lm_state and mass + class MockLMState: + async def logp_next(self): + # Return log probs for 3 tokens as tensor + return torch.log(torch.tensor([0.3, 0.4, 0.3])) + + lm_state = MockLMState() + + # Create LazyTrieState at root + state = LazyTrieState( + lm_state=lm_state, + trie=trie, + node=trie.trie.root, + weight=0.0, + mode=None, + ) + + # Materialize to get masses + state = await state.materialize() + + # Advance to "a" node where both tokens 0 and 1 have EOT edges + advanced_state = state << ord("a") + assert advanced_state is not None + + # Materialize the advanced state to have masses + advanced_state = await advanced_state.materialize() + + # Access logp_next - this should trigger the logaddexp branch + # because both token 0 and 1 are EOT edges at this position + logps = advanced_state.logp_next + + # The EOT probability (index 256) should be valid (not -inf) + # indicating that both duplicate EOT edges contributed via logaddexp + eot_logp = logps[256] + assert eot_logp > -np.inf, "EOT logp should be valid when duplicate EOT edges exist" + + # Verify we're at a position with multiple EOT edges (the duplicate case) + eot_edges = advanced_state.get_all_EOT() + assert len(eot_edges) == 2, f"Expected 2 EOT edges for duplicates, got {len(eot_edges)}" + finally: + await trie.cleanup() diff --git a/tests/test_eos_logic.py b/tests/test_eos_logic.py index 50e8827..c0e3d93 100644 --- a/tests/test_eos_logic.py +++ b/tests/test_eos_logic.py @@ -4,23 +4,36 @@ from genlm.bytes.trie import TokenByteTrie, EOS from genlm.bytes.byte_lm.trie_state import TrieMode +from genlm.backend.tokenization import Token + + +def find_token_id_by_bytes(decode, target_bytes): + """Find the first token ID for a given byte string in a list of Token objects. + + Note: Returns only the first match if multiple tokens share the same byte string. + This is fine for these tests since the test vocabularies have unique byte strings. 
+ """ + for token in decode: + if token.byte_string == target_bytes: + return token.token_id + raise ValueError(f"{target_bytes} is not in decode") @pytest.fixture(scope="module") def eos_trie(): """Provides a pre-configured trie with multiple EOS tokens.""" vocab = [ - b"hello", # 0 - b"world", # 1 - b"!", # 2 (EOS) - b"!!", # 3 - b".", # 4 (EOS) - b"normal", # 5 - b"end", # 6 - b"", # 7 (EOS) + Token(token_id=0, byte_string=b"hello"), + Token(token_id=1, byte_string=b"world"), + Token(token_id=2, byte_string=b"!"), # EOS + Token(token_id=3, byte_string=b"!!"), + Token(token_id=4, byte_string=b"."), # EOS + Token(token_id=5, byte_string=b"normal"), + Token(token_id=6, byte_string=b"end"), + Token(token_id=7, byte_string=b""), # EOS ] - eos_tokens = [b"!", b".", b""] - return TokenByteTrie(decode=vocab, eos_tokens=eos_tokens) + eos_byte_strings = [b"!", b".", b""] + return TokenByteTrie(decode=vocab, eos_byte_strings=eos_byte_strings) def test_trie_structure(eos_trie: TokenByteTrie): @@ -33,8 +46,10 @@ def test_trie_structure(eos_trie: TokenByteTrie): assert eos_trie.children[eos_trie.root].get(EOS) == eos_trie.eos_node # The original EOS tokens should still exist as leaf nodes in the trie for conditioning - for token in eos_trie.eos_tokens: - assert token in eos_trie.word2leaf + # Now word2leaf uses (bytes, token_id) as keys + for eos_token_bytes in eos_trie.eos_byte_strings: + token_id = find_token_id_by_bytes(eos_trie.decode, eos_token_bytes) + assert (eos_token_bytes, token_id) in eos_trie.word2leaf def test_without_eos_mode_mass_distribution(eos_trie: TokenByteTrie): @@ -49,9 +64,9 @@ def test_without_eos_mode_mass_distribution(eos_trie: TokenByteTrie): node_for_exclamation = eos_trie.children[eos_trie.root][ord("!")] # The mass at this node should be the sum of probabilities of all tokens starting with "!", including "!" itself. 
- expected_mass = ( - weights[eos_trie.decode.index(b"!")] + weights[eos_trie.decode.index(b"!!")] - ) # P("!") + P("!!") + idx_excl = find_token_id_by_bytes(eos_trie.decode, b"!") + idx_excl2 = find_token_id_by_bytes(eos_trie.decode, b"!!") + expected_mass = weights[idx_excl] + weights[idx_excl2] # P("!") + P("!!") assert np.isclose(masses[node_for_exclamation].item(), expected_mass.item()) # The EOS node should have zero mass in no_eos mode @@ -70,7 +85,8 @@ def test_with_eos_mode_mass_distribution(eos_trie: TokenByteTrie): node_for_exclamation = eos_trie.children[eos_trie.root][ord("!")] # The mass at this node should ONLY be the sum of non-EOS tokens starting with "!" - expected_mass = weights[eos_trie.decode.index(b"!!")] # Only P("!!") + idx_excl2 = find_token_id_by_bytes(eos_trie.decode, b"!!") + expected_mass = weights[idx_excl2] # Only P("!!") assert np.isclose(masses[node_for_exclamation].item(), expected_mass.item()) @@ -83,11 +99,10 @@ def test_with_eos_mode_eos_node_aggregation(eos_trie: TokenByteTrie): masses = eos_trie.weight_sum(weights, mode=TrieMode.WITH_EOS) # The mass of the EOS node should be the sum of all defined EOS tokens - expected_eos_mass = ( - weights[eos_trie.decode.index(b"!")] - + weights[eos_trie.decode.index(b".")] - + weights[eos_trie.decode.index(b"")] - ) + idx_excl = find_token_id_by_bytes(eos_trie.decode, b"!") + idx_dot = find_token_id_by_bytes(eos_trie.decode, b".") + idx_eos = find_token_id_by_bytes(eos_trie.decode, b"") + expected_eos_mass = weights[idx_excl] + weights[idx_dot] + weights[idx_eos] actual_eos_mass = masses[eos_trie.eos_node] assert np.isclose(actual_eos_mass.item(), expected_eos_mass.item()) diff --git a/tests/test_healing.py b/tests/test_healing.py index cfc4590..f760092 100644 --- a/tests/test_healing.py +++ b/tests/test_healing.py @@ -2,7 +2,11 @@ import numpy as np from genlm.backend import load_model_by_name +from genlm.backend.tokenization import Token from genlm.bytes import ByteBeamState, BeamParams 
+from genlm.bytes.trie import AsyncTokenByteTrie +from genlm.bytes.byte_lm.heal import TokenHealer +from genlm.bytes.byte_lm.trie_state import LazyTrieState TEXT = ". Boulter starred in the 2011 film Mercenaries directed by Paris Leonti ." @@ -25,12 +29,13 @@ async def _advance_bytes( llm, text: str, heal: bool, heal_max_backoff=None, heal_max_splits=None ): """Helper to advance through text bytes and check if healing works.""" - eos_token = llm.byte_vocab[llm.tokenizer.eos_token_id] + # byte_vocab contains Token objects - get the byte_string for eos_byte_strings + eos_token = llm.byte_vocab[llm.tokenizer.eos_token_id].byte_string beam = await ByteBeamState.initial( llm, BeamParams( K=1, - eos_tokens=[eos_token], + eos_byte_strings=[eos_token], heal=heal, heal_max_backoff=heal_max_backoff, heal_max_splits=heal_max_splits, @@ -150,12 +155,12 @@ async def logp_next(self): @pytest.mark.asyncio async def test_healer_with_custom_trie_path_not_found(): """Test healing when partial path doesn't exist""" - from genlm.bytes.trie import AsyncTokenByteTrie - from genlm.bytes.byte_lm.heal import TokenHealer - from genlm.bytes.byte_lm.trie_state import LazyTrieState - # Simple vocab - vocab = [b"a", b"ab", b"x"] + vocab = [ + Token(token_id=0, byte_string=b"a"), + Token(token_id=1, byte_string=b"ab"), + Token(token_id=2, byte_string=b"x"), + ] async_trie = AsyncTokenByteTrie.from_vocab(vocab, device="cpu") lm_state = MinimalLMState(vocab_size=len(vocab)) @@ -198,12 +203,12 @@ def __init__(self, real_state): @pytest.mark.asyncio async def test_healer_with_custom_trie_cant_extend(): """Test when extend fails - no EOT at current position""" - from genlm.bytes.trie import AsyncTokenByteTrie - from genlm.bytes.byte_lm.heal import TokenHealer - from genlm.bytes.byte_lm.trie_state import LazyTrieState - # Vocab where "ab" exists but NOT "a" - so after consuming 'a' there's no EOT - vocab = [b"ab", b"x", b"y"] + vocab = [ + Token(token_id=0, byte_string=b"ab"), + Token(token_id=1, 
byte_string=b"x"), + Token(token_id=2, byte_string=b"y"), + ] async_trie = AsyncTokenByteTrie.from_vocab(vocab, device="cpu") lm_state = MinimalLMState(vocab_size=len(vocab)) @@ -252,13 +257,13 @@ def __init__(self, real_state, partial_bytes): @pytest.mark.asyncio async def test_healer_with_custom_trie_cant_consume_after_extend(): """Test when byte can't be consumed even after extend""" - from genlm.bytes.trie import AsyncTokenByteTrie - from genlm.bytes.byte_lm.heal import TokenHealer - from genlm.bytes.byte_lm.trie_state import LazyTrieState - # Vocab: "a", "ab" - 'a' exists so we CAN extend after consuming 'a' # But 'z' isn't in trie at all - vocab = [b"a", b"ab", b"x"] + vocab = [ + Token(token_id=0, byte_string=b"a"), + Token(token_id=1, byte_string=b"ab"), + Token(token_id=2, byte_string=b"x"), + ] async_trie = AsyncTokenByteTrie.from_vocab(vocab, device="cpu") lm_state = MinimalLMState(vocab_size=len(vocab)) @@ -296,17 +301,148 @@ def __init__(self, real_state, partial_bytes): assert result is None +def find_eot_edge(children, eot_sentinel): + """Find an EOT edge in children dict. Returns (node, token_id) or (None, None).""" + for key, node in children.items(): + if isinstance(key, tuple) and key[0] == eot_sentinel: + return node, key[1] + return None, None + + +def find_all_eot_edges(children, eot_sentinel): + """Find all EOT edges in children dict. Returns list of (node, token_id).""" + results = [] + for key, node in children.items(): + if isinstance(key, tuple) and key[0] == eot_sentinel: + results.append((node, key[1])) + return results + + +@pytest.mark.asyncio +async def test_healer_with_duplicate_tokens(): + """Test healing when there are duplicate tokens (multiple EOT edges at same position). + + This tests the scenario where multiple token IDs decode to the same byte string. + The healer should try all possible EOT edges until one leads to a successful path. 
+ """ + # Vocab with duplicate tokens: both token 0 and token 1 decode to "a" + # Token 2 = "x" for the next byte we want to consume + vocab = [ + Token(token_id=0, byte_string=b"a"), # First "a" + Token(token_id=1, byte_string=b"a"), # Duplicate "a" + Token(token_id=2, byte_string=b"x"), + ] + async_trie = AsyncTokenByteTrie.from_vocab(vocab, device="cpu") + + # Verify the trie has two EOT edges at the node after 'a' + trie = async_trie.trie + node_after_a = trie.children[trie.root].get(ord("a")) + assert node_after_a is not None, "Should have node after 'a'" + + eot_edges = find_all_eot_edges(trie.children[node_after_a], trie.eot_sentinel) + assert len(eot_edges) == 2, f"Expected 2 EOT edges for duplicate 'a', got {len(eot_edges)}" + + lm_state = MinimalLMState(vocab_size=len(vocab)) + state = LazyTrieState( + lm_state=lm_state, + trie=async_trie, + node=async_trie.trie.root, + weight=0.0, + mass=None, + mode="without_eos", + terminated=False, + ) + state = await state.materialize() + + # Consume 'a' to get to a state where we have partial="a" + state_after_a = state << ord("a") + assert state_after_a is not None + + # Try to consume 'x' - should fail normally since 'x' is not a continuation of 'a' + cant_continue = state_after_a << ord("x") + assert cant_continue is None, "Should not be able to consume 'x' after 'a'" + + # Heal to consume 'x' - should succeed by committing one of the "a" tokens + healer = TokenHealer(verbose=True) + healed = await healer.try_heal(state_after_a, next_byte=ord("x")) + assert healed is not None, "Healing should succeed with duplicate tokens" + + # Verify we're now at partial containing 'x' + assert healed.partial == [ord("x")], f"Expected partial [120], got {healed.partial}" + + +@pytest.mark.asyncio +async def test_healer_extend_all_with_duplicates(): + """Test that extend_all is used correctly during healing replay. + + When stuck during replay and extend_all returns multiple extensions, + healing should try all of them. 
+ """ + # Scenario: + # - two token ids (0 and 1) decode to the same bytes "a" + # - consume 'a' then 'b' from the root, so partial="ab" + # - 'x' is not a continuation of "ab", so the direct advance returns None + # - healing is expected to leave partial == [ord("x")] + # + # NOTE(review): the commit named below is "ab" (token 2), which is unique. + # The duplicate "a" tokens are in the trie, but whether healing reaches + # extend_all with several extensions here depends on TokenHealer + # internals not visible in this test - confirm. The direct duplicate-EOT + # case is covered by test_healer_with_duplicate_tokens. + # + + # Vocab: + # - Token 0 = "a" (first) + # - Token 1 = "a" (duplicate) + # - Token 2 = "ab" + # - Token 3 = "x" + vocab = [ + Token(token_id=0, byte_string=b"a"), + Token(token_id=1, byte_string=b"a"), # duplicate + Token(token_id=2, byte_string=b"ab"), + Token(token_id=3, byte_string=b"x"), + ] + async_trie = AsyncTokenByteTrie.from_vocab(vocab, device="cpu") + + lm_state = MinimalLMState(vocab_size=len(vocab)) + state = LazyTrieState( + lm_state=lm_state, + trie=async_trie, + node=async_trie.trie.root, + weight=0.0, + mass=None, + mode="without_eos", + terminated=False, + ) + state = await state.materialize() + + # Consume "ab" to get partial="ab" + state = state << ord("a") + state = state << ord("b") + assert state is not None + + # Try to consume 'x' - should fail + cant_continue = state << ord("x") + assert cant_continue is None + + # Heal should work by committing "ab" (token 2) and then consuming 'x' + healer = TokenHealer(verbose=True) + healed = await healer.try_heal(state, next_byte=ord("x")) + assert healed is not None, "Healing should succeed" + assert healed.partial == [ord("x")] + + +@pytest.mark.asyncio async def test_healer_with_custom_trie_final_extend(): """Test final extend path""" - from genlm.bytes.trie import AsyncTokenByteTrie - from genlm.bytes.byte_lm.heal import TokenHealer - from genlm.bytes.byte_lm.trie_state import LazyTrieState - # Vocab: "a", "ab", "x" # After consuming "ab", there's an EOT # next_byte 'x' is at root but NOT after "ab" - vocab = [b"a", b"ab", b"x"] + vocab = [ 
+ Token(token_id=0, byte_string=b"a"), + Token(token_id=1, byte_string=b"ab"), + Token(token_id=2, byte_string=b"x"), + ] async_trie = AsyncTokenByteTrie.from_vocab(vocab, device="cpu") lm_state = MinimalLMState(vocab_size=len(vocab)) @@ -352,12 +488,12 @@ async def test_healer_weight_calculation(): Verifies the healed state weight matches manually computed expected value. """ - from genlm.bytes.trie import AsyncTokenByteTrie - from genlm.bytes.byte_lm.heal import TokenHealer - from genlm.bytes.byte_lm.trie_state import LazyTrieState - # Vocab: token 0 = "a", token 1 = "ab", token 2 = "x" - vocab = [b"a", b"ab", b"x"] + vocab = [ + Token(token_id=0, byte_string=b"a"), + Token(token_id=1, byte_string=b"ab"), + Token(token_id=2, byte_string=b"x"), + ] async_trie = AsyncTokenByteTrie.from_vocab(vocab, device="cpu") lm_state = MinimalLMState(vocab_size=len(vocab)) @@ -404,7 +540,7 @@ async def test_healer_weight_calculation(): # Find the EOT node for "a" node_after_a = trie.children[trie.root].get(ord("a")) - eot_node_for_a = trie.children[node_after_a].get(trie.eot_token) + eot_node_for_a, _ = find_eot_edge(trie.children[node_after_a], trie.eot_sentinel) assert eot_node_for_a is not None # base_weight undoes the path from root to node_after_a diff --git a/tests/test_trie.py b/tests/test_trie.py index 7cac103..6023ec2 100644 --- a/tests/test_trie.py +++ b/tests/test_trie.py @@ -5,6 +5,7 @@ from transformers import AutoTokenizer from genlm.backend.llm import MockAsyncLM +from genlm.backend.tokenization import Token from genlm.bytes import TokenByteTrie, AsyncTokenByteTrie from genlm.bytes.byte_lm.trie_state import TrieMode @@ -13,7 +14,12 @@ @pytest.fixture() def decode(): - return [b"a", b"b", b"ab", b""] + return [ + Token(token_id=0, byte_string=b"a"), + Token(token_id=1, byte_string=b"b"), + Token(token_id=2, byte_string=b"ab"), + Token(token_id=3, byte_string=b""), + ] @pytest.fixture(scope="module") @@ -23,20 +29,23 @@ def mock_llm(): @st.composite def 
tokens_and_weights(draw, n_weights): - vocab = draw( + byte_vocab = draw( st.lists( st.binary(min_size=1, max_size=5), min_size=1, max_size=10, unique=True ) ) # Ensure we have at least two tokens with a shared prefix. - for token in vocab: + for token in byte_vocab: if len(token) > 1: new_token = token[:-1] - if new_token not in vocab: - vocab.append(new_token) + if new_token not in byte_vocab: + byte_vocab.append(new_token) break + # Convert to Token objects + vocab = [Token(token_id=i, byte_string=b) for i, b in enumerate(byte_vocab)] + weights = [] for _ in range(n_weights): weights.append( @@ -57,13 +66,18 @@ def make_wants(trie, weights, op, f): leaf_wants = {} for token, weight in zip(trie.decode, weights): - assert token in trie.word2leaf - leaf_wants[token] = weight + # Token objects: use (byte_string, token_id) as key for word2leaf + token_bytes = token.byte_string + token_id = token.token_id + word2leaf_key = (token_bytes, token_id) + assert word2leaf_key in trie.word2leaf, f"Key {word2leaf_key} not in word2leaf" + leaf_wants[word2leaf_key] = weight internal_wants = {} for token, weight in zip(trie.decode, weights): - for i in range(len(token) + 1): - prefix = token[:i] + token_bytes = token.byte_string + for i in range(len(token_bytes) + 1): + prefix = token_bytes[:i] if prefix not in internal_wants: internal_wants[f(prefix)] = weight else: @@ -86,23 +100,28 @@ def assert_weights_close(trie, leaf_wants, internal_wants, haves, f): want = internal_wants[f(prefix)] assert np.isclose(have, want, rtol=1e-5, atol=1e-8), [have, want, prefix] - for word in trie.decode: - assert word in trie.word2leaf - node = trie.word2leaf[word] + for token in trie.decode: + # Token objects: use (byte_string, token_id) as key for word2leaf + token_bytes = token.byte_string + token_id = token.token_id + word2leaf_key = (token_bytes, token_id) + assert word2leaf_key in trie.word2leaf + node = trie.word2leaf[word2leaf_key] have = haves[node] - want = leaf_wants[word] - assert 
np.isclose(have, want, rtol=1e-5, atol=1e-8), [have, want, word] + want = leaf_wants[word2leaf_key] + assert np.isclose(have, want, rtol=1e-5, atol=1e-8), [have, want, token_bytes] def test_weight_sum_single(decode): trie = TokenByteTrie(decode=decode) haves = trie.weight_sum(torch.tensor([0.1, 0.2, 0.2, 0.5])) + # leaf_wants now uses (bytes, token_id) keys leaf_wants = { - b"a": 0.1, - b"b": 0.2, - b"ab": 0.2, - b"": 0.5, + (b"a", 0): 0.1, + (b"b", 1): 0.2, + (b"ab", 2): 0.2, + (b"", 3): 0.5, } internal_wants = { b"": 1, @@ -120,14 +139,15 @@ def test_weight_sum_single(decode): def test_weight_sum_single_atomic(decode): - trie = TokenByteTrie(decode=decode, atomic_tokens=[b"ab"]) + trie = TokenByteTrie(decode=decode, atomic_byte_strings=[b"ab"]) haves = trie.weight_sum(torch.tensor([0.1, 0.2, 0.2, 0.5])) + # leaf_wants now uses (bytes, token_id) keys leaf_wants = { - b"a": 0.1, - b"b": 0.2, - b"ab": 0.2, - b"": 0.5, + (b"a", 0): 0.1, + (b"b", 1): 0.2, + (b"ab", 2): 0.2, + (b"", 3): 0.5, } internal_wants = { b"": 1, @@ -282,15 +302,19 @@ def test_visualize(decode): @pytest.mark.asyncio async def test_eos_token_configuration(): """Test EOS token configuration in trie.""" - vocab = [b"hello", b"world", b""] - eos_tokens = [b""] + vocab = [ + Token(token_id=0, byte_string=b"hello"), + Token(token_id=1, byte_string=b"world"), + Token(token_id=2, byte_string=b""), + ] + eos_byte_strings = [b""] # Test trie with EOS tokens - trie = TokenByteTrie(decode=vocab, eos_tokens=eos_tokens) + trie = TokenByteTrie(decode=vocab, eos_byte_strings=eos_byte_strings) - # EOS token should be in the eos_tokens set - assert b"" in trie.eos_tokens - assert len(trie.eos_tokens) == 1 + # EOS token should be in the eos_byte_strings set + assert b"" in trie.eos_byte_strings + assert len(trie.eos_byte_strings) == 1 # EOS token IDs should be populated assert len(trie.eos_token_ids) == 1 @@ -307,10 +331,14 @@ async def test_eos_token_configuration(): @pytest.mark.asyncio async def 
test_eos_dual_matrix_behavior(): """Test dual matrix behavior for propagate_eos vs no_eos modes.""" - vocab = [b"hello", b"world", b""] - eos_tokens = [b""] - - trie = TokenByteTrie(decode=vocab, eos_tokens=eos_tokens) + vocab = [ + Token(token_id=0, byte_string=b"hello"), + Token(token_id=1, byte_string=b"world"), + Token(token_id=2, byte_string=b""), + ] + eos_byte_strings = [b""] + + trie = TokenByteTrie(decode=vocab, eos_byte_strings=eos_byte_strings) weights = torch.tensor([0.3, 0.4, 0.3]) # hello, world, # Test no_eos mode (no EOS node mass) @@ -338,10 +366,14 @@ async def test_eos_dual_matrix_behavior(): @pytest.mark.asyncio async def test_eos_weight_sum_with_eos(): """Test weight_sum_with_eos method.""" - vocab = [b"hello", b"world", b""] - eos_tokens = [b""] - - trie = TokenByteTrie(decode=vocab, eos_tokens=eos_tokens) + vocab = [ + Token(token_id=0, byte_string=b"hello"), + Token(token_id=1, byte_string=b"world"), + Token(token_id=2, byte_string=b""), + ] + eos_byte_strings = [b""] + + trie = TokenByteTrie(decode=vocab, eos_byte_strings=eos_byte_strings) weights = torch.tensor([0.3, 0.4, 0.1]) # hello, world, # Test with_eos mode @@ -359,15 +391,20 @@ async def test_eos_weight_sum_with_eos(): @pytest.mark.asyncio async def test_eos_multiple_tokens(): """Test with multiple EOS tokens.""" - vocab = [b"hello", b"world", b"dog", b"dogs"] - eos_tokens = [b"dog", b"dogs"] + vocab = [ + Token(token_id=0, byte_string=b"hello"), + Token(token_id=1, byte_string=b"world"), + Token(token_id=2, byte_string=b"dog"), + Token(token_id=3, byte_string=b"dogs"), + ] + eos_byte_strings = [b"dog", b"dogs"] - trie = TokenByteTrie(decode=vocab, eos_tokens=eos_tokens) + trie = TokenByteTrie(decode=vocab, eos_byte_strings=eos_byte_strings) # Should have both EOS tokens - assert len(trie.eos_tokens) == 2 - assert b"dog" in trie.eos_tokens - assert b"dogs" in trie.eos_tokens + assert len(trie.eos_byte_strings) == 2 + assert b"dog" in trie.eos_byte_strings + assert b"dogs" in 
trie.eos_byte_strings # Should have both EOS token IDs assert len(trie.eos_token_ids) == 2 @@ -385,15 +422,116 @@ async def test_eos_multiple_tokens(): def test_invalid_device(): + vocab = [ + Token(token_id=0, byte_string=b"a"), + Token(token_id=1, byte_string=b"b"), + Token(token_id=2, byte_string=b"c"), + ] with pytest.raises(ValueError): - TokenByteTrie(decode=["a", "b", "c"], device="invalid") + TokenByteTrie(decode=vocab, device="invalid") + +def test_plain_bytes_decode_deprecation(): + """Test that passing plain bytes to TokenByteTrie warns and converts.""" + with pytest.warns(DeprecationWarning, match="Passing plain bytes to TokenByteTrie is deprecated"): + trie = TokenByteTrie(decode=[b"a", b"b", b"ab"]) + assert all(isinstance(t, Token) for t in trie.decode) + assert trie.decode[0].token_id == 0 + assert trie.decode[0].byte_string == b"a" -def test_invalid_eos_tokens(): + +def test_invalid_eos_byte_strings(): + vocab = [ + Token(token_id=0, byte_string=b"a"), + Token(token_id=1, byte_string=b"b"), + Token(token_id=2, byte_string=b"c"), + ] with pytest.raises(ValueError): - TokenByteTrie(decode=["a", "b", "c"], eos_tokens=["d"]) + TokenByteTrie(decode=vocab, eos_byte_strings=[b"d"]) -def test_invalid_atomic_tokens(): +def test_invalid_atomic_byte_strings(): + vocab = [ + Token(token_id=0, byte_string=b"a"), + Token(token_id=1, byte_string=b"b"), + Token(token_id=2, byte_string=b"c"), + ] with pytest.raises(ValueError): - TokenByteTrie(decode=["a", "b", "c"], atomic_tokens=["d"]) + TokenByteTrie(decode=vocab, atomic_byte_strings=[b"d"]) + + +def test_duplicate_byte_strings_with_tokens(): + """Test that trie correctly handles multiple tokens with the same byte string.""" + # Create Token objects with duplicate byte strings + vocab = [ + Token(token_id=0, byte_string=b"a"), + Token(token_id=1, byte_string=b"hello"), + Token(token_id=2, byte_string=b"hello"), # Duplicate byte string! 
+ Token(token_id=3, byte_string=b"world"), + ] + + trie = TokenByteTrie(decode=vocab) + + # Verify that all tokens got their own leaf nodes + assert len(trie.token_id_to_leaf) == 4 + + # Get the leaf nodes for duplicate tokens + leaf_1 = trie.token_id_to_leaf[1][1] + leaf_2 = trie.token_id_to_leaf[2][1] + + # Should have different leaf nodes + assert leaf_1 != leaf_2, "Tokens with same byte_string should have different leaves" + + # Both should be valid leaf nodes + assert leaf_1 in trie.leaf2word.keys() + assert leaf_2 in trie.leaf2word.keys() + + +def test_duplicate_byte_strings_weight_sum(): + """Test that weight sums work correctly with duplicate byte strings.""" + vocab = [ + Token(token_id=0, byte_string=b"a"), + Token(token_id=1, byte_string=b"hello"), + Token(token_id=2, byte_string=b"hello"), # Duplicate! + Token(token_id=3, byte_string=b"world"), + ] + + trie = TokenByteTrie(decode=vocab) + + # Assign different weights to the duplicate tokens + weights = torch.tensor([0.1, 0.3, 0.5, 0.1]) + + node_weights = trie.weight_sum(weights) + + # Get the leaf weights for the duplicate tokens + leaf_1 = trie.token_id_to_leaf[1][1] + leaf_2 = trie.token_id_to_leaf[2][1] + + # Each leaf should have its own weight + assert np.isclose(node_weights[leaf_1].item(), 0.3, rtol=1e-5) + assert np.isclose(node_weights[leaf_2].item(), 0.5, rtol=1e-5) + + # The parent node (at "hello" prefix) should have the sum of both + # Find the shared parent node (before EOT edges) + hello_prefix = list(b"hello") + for node, prefix in trie.node2prefix.items(): + if prefix == hello_prefix and node not in trie.leaf2word: + # This is the internal node before the EOT edges + expected_sum = 0.3 + 0.5 # Sum of both "hello" tokens + assert np.isclose(node_weights[node].item(), expected_sum, rtol=1e-5) + break + + +def test_requires_token_objects(): + """Test that TokenByteTrie warns for raw bytes and rejects non-bytes types.""" + with pytest.warns(DeprecationWarning, match="Passing plain bytes"): 
+ TokenByteTrie(decode=[b"a", b"b", b"c"]) + + with pytest.raises(TypeError, match="decode must contain Token objects"): + TokenByteTrie(decode=["a", "b", "c"]) + + +def test_empty_decode_raises(): + """Test that empty decode raises ValueError.""" + with pytest.raises(ValueError, match="decode cannot be empty"): + TokenByteTrie(decode=[])