Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 29 additions & 10 deletions genlm/bytes/byte_lm/beam.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import asyncio
import warnings
import numpy as np
from arsenal import colors
from dataclasses import dataclass
from scipy.special import logsumexp as scipy_logsumexp
from functools import cached_property
from genlm.backend.tokenization.bytes import get_byte_vocab

from ..util import logsumexp, LazyByteProbs
from ..trie import AsyncTokenByteTrie
Expand All @@ -22,7 +22,7 @@ class BeamParams:
prune_threshold (float, optional): Probability threshold for pruning candidates.
Candidates with probability below this are removed. Defaults to 0.0
verbose (bool, optional): Whether to print the beam state at each step. Defaults to False
eos_tokens (list[bytes], optional): List of tokens that should be treated as EOS. When configured,
eos_byte_strings (list[bytes], optional): List of tokens that should be treated as EOS. When configured,
EOS tokens will terminate generation when sampled. Defaults to None
heal (bool, optional): Whether to enable adaptive token healing. Defaults to True
heal_max_backoff (int, optional): Maximum number of bytes to back off when healing. Defaults to None
Expand All @@ -32,23 +32,37 @@ class BeamParams:
K: int
prune_threshold: float = 0.0
verbose: bool = False
eos_tokens: list[bytes] = None
eos_byte_strings: list[bytes] = None
heal: bool = True
heal_max_backoff: int | None = None
# Optional cap on how many intra-partial commits are allowed during a
# single healing attempt. None means unlimited. Set to 0 to disable
# multi-split behavior (i.e., single-split only).
heal_max_splits: int | None = None
# Deprecated alias for eos_byte_strings
eos_tokens: list[bytes] = None

def __post_init__(self):
    """Validate the parameters and normalise the EOS configuration.

    Steps, in order:
      1. Migrate the deprecated ``eos_tokens`` alias into
         ``eos_byte_strings`` (emitting a ``DeprecationWarning``) and
         reset the alias to ``None``.
      2. Reject a negative ``prune_threshold``.
      3. Derive ``log_prune_threshold`` (``-inf`` when the threshold is 0).
      4. Store ``eos_byte_strings`` as a set (empty when unset).

    Raises:
        TypeError: If both ``eos_byte_strings`` and the deprecated
            ``eos_tokens`` are provided.
        ValueError: If ``prune_threshold`` is negative.
    """
    if self.eos_tokens is not None:
        if self.eos_byte_strings is not None:
            raise TypeError(
                "Cannot specify both 'eos_byte_strings' and the deprecated 'eos_tokens'."
            )
        # NOTE(review): stacklevel=2 attributes the warning to the direct
        # caller of __post_init__; if this hook runs from a generated
        # __init__, 3 may be needed to point at user code -- confirm.
        warnings.warn(
            "'eos_tokens' is deprecated, use 'eos_byte_strings' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.eos_byte_strings = self.eos_tokens
        # Clear the alias so that running this hook again on the same
        # values does not trip the "both specified" check above.
        self.eos_tokens = None

    if self.prune_threshold < 0:
        raise ValueError(
            f"prune_threshold must be non-negative, got {self.prune_threshold}"
        )
    # log(0) is undefined, so a zero threshold maps to -inf.
    self.log_prune_threshold = (
        np.log(self.prune_threshold) if self.prune_threshold > 0 else -np.inf
    )

    # Only the canonical field is normalised; the deprecated alias must
    # stay None after migration (the pre-rename normalisation of
    # ``eos_tokens`` is intentionally not kept).
    self.eos_byte_strings = (
        set(self.eos_byte_strings) if self.eos_byte_strings else set()
    )


class ByteBeamState(StatefulByteLM):
Expand All @@ -71,7 +85,7 @@ async def initial(cls, llm, params, trie_opts=None):
"""Creates initial beam state.

Args:
llm (StatefulTokenizedLM): Token-level language model to use.
llm (genlm.backend.AsyncLM): Token-level language model to use.
params (BeamParams): Beam search parameters.
trie_opts (dict, optional): Additional keyword arguments passed to
AsyncTokenByteTrie.from_vocab. For example, {"max_batch_size": 100}.
Expand All @@ -81,10 +95,11 @@ async def initial(cls, llm, params, trie_opts=None):
"""
# Handle EOS tokens
trie_opts = trie_opts or {}
trie_opts["eos_tokens"] = params.eos_tokens
trie_opts["eos_byte_strings"] = params.eos_byte_strings

# Use llm.byte_vocab which contains Token objects (supports duplicate byte strings)
async_trie = AsyncTokenByteTrie.from_vocab(
get_byte_vocab(llm.tokenizer), **trie_opts
llm.byte_vocab, **trie_opts
)
state = LazyTrieState.initial(llm, async_trie, mode=TrieMode.WITH_EOS)
return cls([await state.materialize()], params)
Expand Down Expand Up @@ -162,18 +177,22 @@ async def logp_next(self):
async def extend(self, logZ):
"""Attempts to advance each candidate in the beam by a token (EOT).

For each candididate with EOT available, this ends the current token and
For each candidate with EOT available, this ends the current token and
starts a new one in preparation for the next byte.

With duplicate tokens (multiple token IDs mapping to the same byte string),
a single state can have multiple extensions - one for each possible token.

Args:
logZ (float): Current estimated of the partition function for pruning
logZ (float): Current estimate of the partition function for pruning

Returns:
(list[LazyTrieState]): New candidate states after extension
"""
extends = []
for state in self:
if new_state := state.extend():
# extend_all() returns all possible extensions (one per token at this position)
for new_state in state.extend_all():
logZ = np.logaddexp(logZ, new_state.weight)
extends.append(new_state)

Expand Down
91 changes: 73 additions & 18 deletions genlm/bytes/byte_lm/heal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,26 @@
from ..util import format_byte


def _find_all_eot_edges(children, eot_sentinel):
    """Collect every EOT edge from a node's children mapping.

    EOT edges are stored under tuple keys of the form
    ``(eot_sentinel, token_id)``. With duplicate tokens, multiple token IDs
    can map to the same byte string, so a node may carry several such edges.

    Args:
        children (dict): Mapping from edge key to child node for one node.
        eot_sentinel: Value found in position 0 of every EOT edge key.

    Returns:
        list[tuple]: ``(node, token_id)`` pairs, one per EOT edge, in the
            mapping's iteration order. Empty when there are no EOT edges.
    """
    return [
        (node, key[1])
        for key, node in children.items()
        # Non-tuple keys are ordinary transitions; tuple keys are EOT edges
        # only when their first element is the sentinel.
        if isinstance(key, tuple) and key[0] == eot_sentinel
    ]


class TokenHealer:
"""Handles adaptive token healing for ByteBeamState.
Token healing finds alternative tokenizations when the current tokenization
cannot consume the next byte. It works by:
1. Trying different "backoff" positions k (commit partial[:k] as a token)
2. Replaying the remaining bytes (partial[k:]) from fresh root
3. Using extend() when stuck to commit intermediate tokens
3. Using extend_all() when stuck to commit intermediate tokens
4. Finally consuming the target next_byte

Args:
Expand Down Expand Up @@ -69,6 +82,9 @@ async def try_heal(self, state, next_byte: int):
async def _try_at_k(self, state, trie, base_weight: float, k: int, next_byte: int):
"""Try healing by committing partial[:k], replaying partial[k:], then consuming next_byte.

With duplicate tokens, there can be multiple EOT edges at position k.
This method tries all of them until one succeeds.

Args:
state: The original state to heal from
trie: The trie structure (state.trie.trie)
Expand All @@ -89,18 +105,49 @@ async def _try_at_k(self, state, trie, base_weight: float, k: int, next_byte: in
if node_at_k is None:
return None # Path doesn't exist

# Check if there's an EOT at position k
eot_node = children[node_at_k].get(trie.eot_token)
if eot_node is None:
# Find all EOT edges at position k
# With duplicate tokens, multiple token IDs can map to the same byte string
eot_edges = _find_all_eot_edges(children[node_at_k], trie.eot_sentinel)
if not eot_edges:
if self.verbose:
print(f"[heal] k={k}: no EOT at {bytes(partial[:k])!r}")
return None

# Commit at position k
# Try each possible EOT edge
for eot_node, eot_token_id in eot_edges:
result = await self._try_eot_at_k(
state, trie, base_weight, k, next_byte, eot_node, eot_token_id
)
if result is not None:
return result

return None

async def _try_eot_at_k(
self, state, trie, base_weight: float, k: int, next_byte: int,
eot_node: int, eot_token_id: int
):
"""Try healing with a specific EOT edge at position k.

Args:
state: The original state to heal from
trie: The trie structure (state.trie.trie)
base_weight: Precomputed weight after undoing current path
k: Backoff position
next_byte: The byte we want to consume
eot_node: The EOT node to commit
eot_token_id: The token ID for this EOT

Returns:
LazyTrieState if successful, None otherwise
"""
partial = state.partial

# Commit at position k with this specific token
weight_after_commit = base_weight + (
state.mass[eot_node] - state.mass[trie.root]
)
token_id = int(trie.leaf2token_id[eot_node])
token_id = int(eot_token_id)

current = LazyTrieState(
lm_state=(state.lm_state << token_id),
Expand All @@ -114,8 +161,10 @@ async def _try_at_k(self, state, trie, base_weight: float, k: int, next_byte: in
current = await current.materialize()

if self.verbose:
# trie.decode contains Token objects, get byte_string for display
token_bytes = trie.decode[token_id].byte_string
print(
f"[heal] k={k}: commit {trie.decode[token_id]!r}, w={weight_after_commit:.2f}"
f"[heal] k={k}: commit {token_bytes!r} (token_id={token_id}), w={weight_after_commit:.2f}"
)

# Replay suffix bytes then consume next_byte
Expand All @@ -134,26 +183,32 @@ async def _try_at_k(self, state, trie, base_weight: float, k: int, next_byte: in
print(f"[heal] k={k}: hit max_splits={self.max_splits}")
return None

extended = current.extend()
if extended is None:
# extend_all() returns list of all possible extensions
extensions = current.extend_all()
if not extensions:
if self.verbose:
print(f"[heal] k={k}: can't extend at {bytes(current.partial)!r}")
return None

current = await extended.materialize()
splits_used += 1
if self.verbose:
print(f"[heal] k={k}: split #{splits_used}, w={current.weight:.2f}")

# Retry consuming the byte after extend
next_state = current << b
if next_state is None:
# Try each possible extension
for extended in extensions:
materialized = await extended.materialize()
splits_used += 1
if self.verbose:
print(f"[heal] k={k}: split #{splits_used}, w={materialized.weight:.2f}")

# Retry consuming the byte after extend
next_state = materialized << b
if next_state is not None:
current = next_state
break # Found a valid extension
else:
# None of the extensions worked
if self.verbose:
print(
f"[heal] k={k}: couldn't consume {format_byte(b)} even after extend"
)
return None
current = next_state

if self.verbose:
print(f"[heal] SUCCESS at k={k}: w={current.weight:.2f}")
Expand Down
3 changes: 2 additions & 1 deletion genlm/bytes/byte_lm/lm_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@ async def logp_next(self):
return await self.model.next_token_logprobs(self.context)

def __repr__(self):
    """Return the context as purple-coloured, ``|``-separated token bytes."""
    # byte_vocab contains Token objects, so we access .byte_string
    return colors.purple % (
        "|".join(escape(self.model.byte_vocab[x].byte_string) for x in self.context)
    )


Expand Down
68 changes: 53 additions & 15 deletions genlm/bytes/byte_lm/trie_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,24 @@ def actions(self):
"""Returns possible byte transitions from current node."""
return self.children[self.node]

def get_EOT(self):
"""Returns the end-of-token node if available from current position in the trie."""
return self.children[self.node].get(self.trie.trie.eot_token)
def get_all_EOT(self):
    """Returns all EOT edges from the current position in the trie.

    With duplicate tokens, multiple token IDs can map to the same byte string,
    resulting in multiple EOT edges at the same node.

    Returns:
        list[tuple[int, int]]: List of (eot_node, token_id) tuples for each EOT edge.
            Empty list if no EOT edges exist.
    """
    eot_sentinel = self.trie.trie.eot_sentinel
    edges = self.children[self.node]
    return [
        (node, key[1])
        for key, node in edges.items()
        # EOT edges are tuples: (eot_sentinel, token_id)
        if isinstance(key, tuple) and key[0] == eot_sentinel
    ]

def __lshift__(self, b):
"""Transitions to a new state by consuming a byte.
Expand All @@ -131,23 +146,29 @@ def __lshift__(self, b):
terminated=b == EOS,
)

def extend(self):
"""Extends current state by consuming an end-of-token if possible.
def extend_all(self):
    """Extends current state by consuming an end-of-token, returning all possible extensions.

    With duplicate tokens (multiple token IDs with the same byte string), there can be
    multiple valid extensions at the same position. Each extension corresponds to a
    different token being committed, which affects future LM predictions.

    The result is computed once and cached on ``self._extend``; later calls
    return the same list object.

    Returns:
        list[LazyTrieState]: List of new states after consuming EOT, one per possible token.
            Empty list if no EOT edges exist.
    """
    if self._extend is None:
        mass = self.mass
        self._extend = [
            LazyTrieState(
                lm_state=self.lm_state << int(token_id),
                trie=self.trie,
                node=self.root,
                # Swap the mass at the current node for the mass at the
                # EOT node reached by committing this token.
                weight=self.weight + mass[eot_node] - mass[self.node],
                mode=self.mode,
            )
            for eot_node, token_id in self.get_all_EOT()
        ]
    return self._extend

@cached_property
Expand All @@ -161,8 +182,25 @@ def logp_next(self):
mass = self.mass
logZ = mass[self.node]

for byte, node in self.actions().items():
logps[byte if byte is not None else 256] = mass[node] - logZ
for key, node in self.actions().items():
# Handle different edge types:
# - tuple: (eot_sentinel, token_id) for EOT edges to leaves
# - int 0-255: byte transitions
# - int 257: EOS transition
if isinstance(key, tuple):
# EOT edge - use index 256
# For duplicates, sum their masses using logaddexp
if logps[256] == -np.inf:
logps[256] = mass[node] - logZ
else:
logps[256] = np.logaddexp(logps[256], mass[node] - logZ)
elif isinstance(key, int):
logps[key] = mass[node] - logZ
else: # pragma: no cover
raise ValueError(
f"Unexpected edge key type: {type(key).__name__} (value: {key!r}). "
f"Expected tuple (EOT edge) or int (byte/EOS transition)."
)

return LazyByteProbs(logps)

Expand All @@ -181,7 +219,7 @@ async def materialize(self):
self._mass = mass.cpu().numpy()
return self

def __repr__(self):
def __repr__(self): # pragma: no cover
context = colors.green % ("|" + escape(bytes(self.partial)))
if self.terminated:
context += colors.yellow % "<EOS>"
Expand Down
Loading
Loading