Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ book_schema = {
# (You will need to login via the Hugging Face CLI and have access to the model.)
llm = PromptedLLM.from_name(
"meta-llama/Llama-3.2-1B-Instruct",
eos_tokens=[b"<|eom_id|>", b"<|eot_id|>"],
eos_byte_strings=[b"<|eom_id|>", b"<|eot_id|>"],
temperature=0.8
)

Expand Down
4 changes: 2 additions & 2 deletions docs/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ First, let's look at basic language model sampling using a [`PromptedLLM`][genlm
from genlm.control import PromptedLLM, direct_token_sampler

# Load gpt2 (or any other Hugging Face model)
mtl_llm = PromptedLLM.from_name("gpt2", temperature=0.5, eos_tokens=[b'.'])
mtl_llm = PromptedLLM.from_name("gpt2", temperature=0.5, eos_byte_strings=[b'.'])

# Set the fixed prompt prefix for the language model
# All language model predictions will be conditioned on this prompt
Expand All @@ -30,7 +30,7 @@ sequences.posterior
sequences.decoded_posterior
```

Note: Sequences are lists of `bytes` objects because each token in the language model's vocabulary is represented as a bytes object.
Note: Sequences are lists of `Token` objects. Each `Token` carries a `token_id` and a `byte_string`, and it subclasses `bytes` for backwards compatibility (so, e.g., `b"".join(sequence)` still works).

## Prompt Intersection

Expand Down
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ book_schema = {
# (You will need to login via the Hugging Face CLI and have access to the model.)
llm = PromptedLLM.from_name(
"meta-llama/Llama-3.2-1B-Instruct",
eos_tokens=[b"<|eom_id|>", b"<|eot_id|>"],
eos_byte_strings=[b"<|eom_id|>", b"<|eot_id|>"],
temperature=0.8
)

Expand Down
4 changes: 2 additions & 2 deletions docs/potentials.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Potentials guide text generation by:

### Vocabulary

Each potential has a **vocabulary** which defines the set of tokens it operates on. Most built-in potentials operate on vocabularies whose tokens are `bytes` or `int` objects (the latter often representing individual bytes).
Each potential has a **vocabulary** which defines the set of tokens it operates on. Language model potentials (`PromptedLLM`) use `Token` objects (which carry both a `token_id` and `byte_string`). Constraint potentials (FSAs, CFGs) typically operate on `int` objects representing individual bytes.

### Weight assignment

Expand Down Expand Up @@ -60,7 +60,7 @@ llm = PromptedLLM.from_name("gpt2", temperature=0.5)
llm.set_prompt_from_str("Montreal is")
```

`PromptedLLM`s have a vocabulary of `bytes` tokens, obtained from the language model's tokenizer.
`PromptedLLM`s have a vocabulary of `Token` objects, obtained from the language model's tokenizer. Each `Token` carries a `token_id` and a `byte_string`, and subclasses `bytes` for backwards compatibility. Note that multiple tokens can share the same byte string.

### Finite-state automata

Expand Down
9 changes: 4 additions & 5 deletions genlm/control/potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ def __init__(self, vocabulary, token_type=None, eos=None):
vocabulary (list): List of tokens that make up the vocabulary.
token_type (TokenType, optional): Optional TokenType of all elements of the vocabulary.
If None, will be inferred from vocabulary.
eos (EndOfSequence, optional): Special token to use as end-of-sequence. Defaults to `EOS`.
In general, this should not be set by users.
eos (EndOfSequence, optional): Special token to use as end-of-sequence. Defaults to the `EOS` sentinel.

Raises:
ValueError: If vocabulary is empty.
Expand All @@ -62,10 +61,10 @@ def __init__(self, vocabulary, token_type=None, eos=None):
if not all(token_type.check(x) for x in vocabulary):
raise TypeError(f"Tokens in vocabulary must be of type {token_type}.")

if eos and not isinstance(eos, EndOfSequence):
raise ValueError(f"EOS must be an instance of EndOfSequence, got {eos!r}.")
if eos is not None and not isinstance(eos, EndOfSequence):
raise ValueError("EOS must be an instance of EndOfSequence")

self.eos = eos or EOS
self.eos = eos if eos is not None else EOS

self.token_type = token_type
self.vocab = vocabulary
Expand Down
5 changes: 3 additions & 2 deletions genlm/control/potential/built_in/bytellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@ class ByteLLM(Potential):
Args:
llm: The language model to use (from `genlm.backend`).
beam_params (BeamParams): Configuration for beam search, including beam width `K`,
`eos_tokens`, and healing parameters (`heal`, `heal_max_backoff`, `heal_max_splits`).
`eos_byte_strings` (list of EOS byte sequences), and healing parameters
(`heal`, `heal_max_backoff`, `heal_max_splits`).
cache_size (int): Maximum number of beam states to cache. Defaults to 1024.

Example:
```python
from genlm.bytes import BeamParams
from genlm.control import ByteLLM

beam_params = BeamParams(K=5, eos_tokens={b"<|endoftext|>"}, heal=True)
beam_params = BeamParams(K=5, eos_byte_strings=[b"<|endoftext|>"], heal=True)
async with ByteLLM.from_name("gpt2", beam_params) as byte_llm:
byte_llm.set_prompt_from_str("Hello")
logp = await byte_llm.prefix([b" ", b"w", b"o", b"r", b"l", b"d"])
Expand Down
40 changes: 23 additions & 17 deletions genlm/control/potential/built_in/canonical.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import scipy.sparse as sp
from collections import defaultdict
from genlm.control.potential.base import Potential
from genlm.backend.tokenization import decode_vocab
from genlm.backend.tokenization import decode_vocab, Token
from genlm.control.potential.built_in.llm import PromptedLLM

VERYLARGE = 10000000
Expand Down Expand Up @@ -127,17 +127,14 @@ def __call__(self, context):
if context == ():
mask = np.ones(self.V, dtype=bool)
else:
(_, last_token) = context
try:
left_id = self._encode[last_token] # Get the ID of the last token
except KeyError as e:
raise KeyError(
f"Last token {last_token!r} not found in encode map."
) from e

mask = self._vectorized_conflicting_next_tokens(
left_id
) # Get base mask from BPE rules
_, last_token = context
last_token_bytes = Token.as_bytes(last_token)

left_id = self._encode.get(last_token_bytes)
if left_id is None:
raise KeyError(f"Last token {last_token_bytes!r} not found in encode map.")

mask = self._vectorized_conflicting_next_tokens(left_id)

# Apply overrides: Ensure overridden tokens are allowed (True)
if left_id in self.overrides:
Expand Down Expand Up @@ -234,16 +231,22 @@ def _vectorized_conflicting_next_tokens(self, left: int):

@classmethod
def from_tokenizer(cls, tokenizer, eos_token_ids=None):
_decode, _ = decode_vocab(tokenizer)
if len(_decode) != len(set(_decode)):
_decode_tokens, _ = decode_vocab(tokenizer) # Returns (List[Token], List[str])

# Extract byte strings and check for duplicates
byte_strings = [token.byte_string for token in _decode_tokens]
if len(byte_strings) != len(set(byte_strings)):
raise ValueError(
"Duplicate byte sequences found in vocabulary. Cannot create unique byte->ID mapping (_encode)."
)

_merges = _extract_bpe_merges(tokenizer)

# Build _encode (bytes -> token_id map) from _decode
_encode = {b: i for i, b in enumerate(_decode) if b is not None}
# Build _encode (bytes -> token_id map) from Token objects
_encode = {token.byte_string: token.token_id for token in _decode_tokens}

# For _decode, we keep Token objects to maintain token_id information
_decode = _decode_tokens

# Build _encode_byte (single byte -> token_id map)
_encode_byte = [None] * 256
Expand Down Expand Up @@ -278,6 +281,7 @@ def __init__(self, canonicality_filter):

# IMPORTANT: In the base Potential class, EOS will be added to vocab automatically
# So we should NOT add it ourselves to the vocabulary we pass to super().__init__
# Use Token objects directly as vocabulary to maintain token_id information
vocabulary = self.canonicality_filter._decode
super().__init__(vocabulary)

Expand Down Expand Up @@ -401,7 +405,9 @@ def _check_canonicality(self, context):
# print("percent of mask: ", np.sum(mask)*100 / len(mask))

# Find token_id in the canonicality filter's vocabulary
token_id = self.canonicality_filter._encode[current_token]
current_token_bytes = Token.as_bytes(current_token)

token_id = self.canonicality_filter._encode[current_token_bytes]
if not mask[token_id]:
return False

Expand Down
Loading
Loading