[Feature] Support model vocab size being less than tokenizer #237

Merged
merged 2 commits on Mar 12, 2025
cpp/pybind/pybind.cc (4 changes: 2 additions & 2 deletions)

@@ -24,8 +24,8 @@ PYBIND11_MODULE(xgrammar_bindings, m) {
       .def_property_readonly("stop_token_ids", &TokenizerInfo::GetStopTokenIds)
       .def_property_readonly("special_token_ids", &TokenizerInfo::GetSpecialTokenIds)
       .def("dump_metadata", &TokenizerInfo::DumpMetadata)
-      .def_static("from_huggingface", &TokenizerInfo::FromHuggingFace)
-      .def_static("from_vocab_and_metadata", &TokenizerInfo::FromVocabAndMetadata);
+      .def_static("from_vocab_and_metadata", &TokenizerInfo::FromVocabAndMetadata)
+      .def_static("_detect_metadata_from_hf", &TokenizerInfo::DetectMetadataFromHF);
 
   auto pyGrammar = py::class_<Grammar>(m, "Grammar");
   pyGrammar.def("to_string", &Grammar::ToString)
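On the Python side, the renamed binding is reached as a private static method on the wrapped class. A minimal sketch of a call, assuming the compiled extension is importable as xgrammar_bindings (the package normally loads it through its own internal base module) and a HF fast tokenizer is available:

    import json

    import xgrammar_bindings as _core  # assumed import path for the compiled extension
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
    backend_str = tokenizer.backend_tokenizer.to_str()  # serialized fast-tokenizer JSON

    # The binding returns a JSON string such as {"vocab_type": 2, "add_prefix_space": false}.
    metadata = json.loads(_core.TokenizerInfo._detect_metadata_from_hf(backend_str))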
cpp/pybind/python_methods.cc (21 changes: 10 additions & 11 deletions)

@@ -22,25 +22,24 @@ namespace xgrammar {
 
 TokenizerInfo TokenizerInfo_Init(
     const std::vector<std::string>& encoded_vocab,
-    std::string vocab_type,
+    int vocab_type,
     std::optional<int> vocab_size,
     std::optional<std::vector<int32_t>> stop_token_ids,
     bool add_prefix_space
 ) {
-  const std::unordered_map<std::string, VocabType> VOCAB_TYPE_MAP = {
-      {"RAW", VocabType::RAW},
-      {"BYTE_FALLBACK", VocabType::BYTE_FALLBACK},
-      {"BYTE_LEVEL", VocabType::BYTE_LEVEL},
-  };
-  XGRAMMAR_CHECK(VOCAB_TYPE_MAP.count(vocab_type)) << "Invalid vocab type: " << vocab_type;
+  XGRAMMAR_CHECK(vocab_type == 0 || vocab_type == 1 || vocab_type == 2)
+      << "Invalid vocab type: " << vocab_type;
   return TokenizerInfo(
-      encoded_vocab, VOCAB_TYPE_MAP.at(vocab_type), vocab_size, stop_token_ids, add_prefix_space
+      encoded_vocab,
+      static_cast<VocabType>(vocab_type),
+      vocab_size,
+      stop_token_ids,
+      add_prefix_space
   );
 }
 
-std::string TokenizerInfo_GetVocabType(const TokenizerInfo& tokenizer) {
-  const std::string VOCAB_TYPE_NAMES[] = {"RAW", "BYTE_FALLBACK", "BYTE_LEVEL"};
-  return VOCAB_TYPE_NAMES[static_cast<int>(tokenizer.GetVocabType())];
+int TokenizerInfo_GetVocabType(const TokenizerInfo& tokenizer) {
+  return static_cast<int>(tokenizer.GetVocabType());
 }
 
 std::vector<pybind11::bytes> TokenizerInfo_GetDecodedVocab(const TokenizerInfo& tokenizer) {
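With the string map gone, Python passes the enum's integer value across the pybind boundary and the C++ side range-checks it before casting. A sketch of the equivalent contract, with the VocabType values assumed to stay in sync between the two languages:

    from enum import Enum

    class VocabType(Enum):
        RAW = 0
        BYTE_FALLBACK = 1
        BYTE_LEVEL = 2

    def to_binding_arg(vocab_type: VocabType) -> int:
        # Mirrors the XGRAMMAR_CHECK above: only 0, 1, and 2 are valid.
        if vocab_type.value not in (0, 1, 2):
            raise ValueError(f"Invalid vocab type: {vocab_type.value}")
        return vocab_type.value

    assert VocabType(to_binding_arg(VocabType.BYTE_LEVEL)) is VocabType.BYTE_LEVEL

One trade-off of this change: the numeric values are now part of the metadata contract, so reordering the enum would silently break previously dumped metadata.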
cpp/pybind/python_methods.h (4 changes: 2 additions & 2 deletions)

@@ -20,13 +20,13 @@ namespace xgrammar {
 
 TokenizerInfo TokenizerInfo_Init(
     const std::vector<std::string>& encoded_vocab,
-    std::string vocab_type,
+    int vocab_type,
     std::optional<int> vocab_size,
     std::optional<std::vector<int32_t>> stop_token_ids,
     bool add_prefix_space
 );
 
-std::string TokenizerInfo_GetVocabType(const TokenizerInfo& tokenizer);
+int TokenizerInfo_GetVocabType(const TokenizerInfo& tokenizer);
 
 std::vector<pybind11::bytes> TokenizerInfo_GetDecodedVocab(const TokenizerInfo& tokenizer);
 
cpp/tokenizer_info.cc (58 changes: 19 additions & 39 deletions)

@@ -38,17 +38,14 @@ class TokenizerInfo::Impl {
   const std::vector<std::pair<int32_t, std::string>>& GetSortedDecodedVocab() const {
     return sorted_decoded_vocab_;
   }
 
   std::string DumpMetadata() const;
+
   static std::shared_ptr<Impl> FromVocabAndMetadata(
       const std::vector<std::string>& encoded_vocab, const std::string& metadata
   );
-  static std::shared_ptr<Impl> FromHuggingFace(
-      const std::vector<std::string>& encoded_vocab,
-      const std::string& backend_str,
-      std::optional<int> vocab_size,
-      std::optional<std::vector<int32_t>> stop_token_ids
-  );
+
+  static std::string DetectMetadataFromHF(const std::string& backend_str);
 
  private:
   static bool IsSpecialToken(const std::string& decoded_token);
@@ -369,9 +366,8 @@ TokenizerInfo::Impl::Impl(
 }
 
 std::string TokenizerInfo::Impl::DumpMetadata() const {
-  static const std::string VOCAB_TYPE_NAMES[] = {"RAW", "BYTE_FALLBACK", "BYTE_LEVEL"};
   picojson::object obj;
-  obj["vocab_type"] = picojson::value(VOCAB_TYPE_NAMES[static_cast<int>(vocab_type_)]);
+  obj["vocab_type"] = picojson::value(static_cast<int64_t>(vocab_type_));
   obj["vocab_size"] = picojson::value(static_cast<int64_t>(vocab_size_));
   obj["add_prefix_space"] = picojson::value(add_prefix_space_);
   picojson::array stop_token_ids_array;
@@ -386,26 +382,18 @@ std::string TokenizerInfo::Impl::DumpMetadata() const {
 std::shared_ptr<TokenizerInfo::Impl> TokenizerInfo::Impl::FromVocabAndMetadata(
     const std::vector<std::string>& encoded_vocab, const std::string& metadata
 ) {
-  static const std::unordered_map<std::string, VocabType> VOCAB_TYPE_MAP = {
-      {"RAW", VocabType::RAW},
-      {"BYTE_FALLBACK", VocabType::BYTE_FALLBACK},
-      {"BYTE_LEVEL", VocabType::BYTE_LEVEL},
-  };
-
   picojson::value v;
   std::string err = picojson::parse(v, metadata);
   XGRAMMAR_CHECK(err.empty()) << "Failed to parse metadata: " << err;
 
   const picojson::object& obj = v.get<picojson::object>();
-  XGRAMMAR_CHECK(obj.count("vocab_type") && obj["vocab_type"].is<std::string>())
+
+  XGRAMMAR_CHECK(obj.count("vocab_type") && obj["vocab_type"].is<std::int64_t>())
       << "Missing or invalid 'vocab_type' in metadata";
-  std::string vocab_type_str = obj["vocab_type"].get<std::string>();
-  VocabType vocab_type;
-  if (auto it = VOCAB_TYPE_MAP.find(vocab_type_str); it != VOCAB_TYPE_MAP.end()) {
-    vocab_type = it->second;
-  } else {
-    XGRAMMAR_CHECK(false) << "Invalid vocab_type in metadata: " << vocab_type_str;
-  }
+  int vocab_type_int = static_cast<int>(obj["vocab_type"].get<int64_t>());
+  XGRAMMAR_CHECK(vocab_type_int == 0 || vocab_type_int == 1 || vocab_type_int == 2)
+      << "Invalid vocab_type in metadata: " << vocab_type_int;
+  VocabType vocab_type = static_cast<VocabType>(vocab_type_int);
 
   XGRAMMAR_CHECK(obj.count("vocab_size") && obj["vocab_size"].is<int64_t>())
       << "Missing or invalid 'vocab_size' in metadata";
@@ -427,21 +415,19 @@ std::shared_ptr<TokenizerInfo::Impl> TokenizerInfo::Impl::FromVocabAndMetadata(
   );
 }
 
-std::shared_ptr<TokenizerInfo::Impl> TokenizerInfo::Impl::FromHuggingFace(
-    const std::vector<std::string>& encoded_vocab,
-    const std::string& backend_str,
-    std::optional<int> vocab_size,
-    std::optional<std::vector<int32_t>> stop_token_ids
-) {
+std::string TokenizerInfo::Impl::DetectMetadataFromHF(const std::string& backend_str) {
   picojson::value v;
   std::string err = picojson::parse(v, backend_str);
   XGRAMMAR_CHECK(err.empty() && v.is<picojson::object>()) << "Failed to parse JSON object: " << err;
   const picojson::object& obj = v.get<picojson::object>();
   VocabType vocab_type = HFTokenizerAnalyzer::DetectVocabType(obj);
   bool add_prefix_space = HFTokenizerAnalyzer::DetectAddPrefixSpace(obj);
-  return std::make_shared<Impl>(
-      encoded_vocab, vocab_type, vocab_size, stop_token_ids, add_prefix_space
-  );
+
+  // Serialize the metadata
+  picojson::object metadata_obj;
+  metadata_obj["vocab_type"] = picojson::value(static_cast<int64_t>(vocab_type));
+  metadata_obj["add_prefix_space"] = picojson::value(add_prefix_space);
+  return picojson::value(metadata_obj).serialize(false);
 }
 
 /************* TokenizerInfo *************/
@@ -481,14 +467,8 @@ TokenizerInfo TokenizerInfo::FromVocabAndMetadata(
   return TokenizerInfo(Impl::FromVocabAndMetadata(encoded_vocab, metadata));
 }
 
-TokenizerInfo TokenizerInfo::FromHuggingFace(
-    const std::vector<std::string>& encoded_vocab,
-    const std::string& backend_str,
-    std::optional<int> vocab_size,
-    std::optional<std::vector<int32_t>> stop_token_ids
-) {
-  return TokenizerInfo(Impl::FromHuggingFace(encoded_vocab, backend_str, vocab_size, stop_token_ids)
-  );
-}
+std::string TokenizerInfo::DetectMetadataFromHF(const std::string& backend_str) {
+  return Impl::DetectMetadataFromHF(backend_str);
+}
 
 }  // namespace xgrammar
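Since vocab_type is now serialized as an integer, the metadata on both paths is plain JSON. A sketch of the two shapes, with field values illustrative rather than taken from a real tokenizer:

    import json

    # DumpMetadata() emits the full metadata consumed by FromVocabAndMetadata.
    dumped = json.dumps(
        {"vocab_type": 1, "vocab_size": 32064, "add_prefix_space": True, "stop_token_ids": [2]}
    )

    # DetectMetadataFromHF() emits only the two fields detectable from the HF backend JSON.
    detected = json.loads('{"vocab_type": 1, "add_prefix_space": true}')

    # FromVocabAndMetadata range-checks the integer before casting to the enum.
    assert detected["vocab_type"] in (0, 1, 2)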
include/xgrammar/tokenizer_info.h (8 changes: 2 additions & 6 deletions)

@@ -44,12 +44,8 @@ class TokenizerInfo {
   static TokenizerInfo FromVocabAndMetadata(
       const std::vector<std::string>& encoded_vocab, const std::string& metadata
   );
-  static TokenizerInfo FromHuggingFace(
-      const std::vector<std::string>& encoded_vocab,
-      const std::string& backend_str,
-      std::optional<int> vocab_size = std::nullopt,
-      std::optional<std::vector<int32_t>> stop_token_ids = std::nullopt
-  );
+
+  static std::string DetectMetadataFromHF(const std::string& backend_str);
 
   XGRAMMAR_DEFINE_PIMPL_METHODS(TokenizerInfo);
 };
python/xgrammar/tokenizer_info.py (93 changes: 62 additions & 31 deletions)

@@ -1,7 +1,8 @@
 """This module provides the tokenizer info class to handle the tokenizer information."""
 
+import json
 from enum import Enum
-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import sentencepiece
 import tiktoken
@@ -37,9 +38,9 @@ class VocabType(Enum):
     meta-llama/Meta-Llama-3.1-8B-Instruct, etc.
     """
 
-    RAW = "RAW"
-    BYTE_FALLBACK = "BYTE_FALLBACK"
-    BYTE_LEVEL = "BYTE_LEVEL"
+    RAW = 0
+    BYTE_FALLBACK = 1
+    BYTE_LEVEL = 2
 
 
 class TokenizerInfo(XGRObject):
@@ -100,7 +101,8 @@ def _is_tiktoken_tokenizer(tokenizer: PreTrainedTokenizerBase) -> bool:
         )
 
         filename_pattern = (
-            "vocab_file" in tokenizer.vocab_files_names
+            hasattr(tokenizer, "vocab_files_names")
+            and "vocab_file" in tokenizer.vocab_files_names
            and "tiktoken" in tokenizer.vocab_files_names["vocab_file"]
         )
 
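The added hasattr guard matters because the check probes an attribute that not every tokenizer class defines. A quick sketch of the failure mode it prevents, using a hypothetical bare tokenizer object:

    class BareTokenizer:  # hypothetical stand-in with no vocab_files_names
        pass

    t = BareTokenizer()
    # The old check would raise AttributeError on t.vocab_files_names;
    # the guarded version short-circuits to False instead.
    is_tiktoken_style = (
        hasattr(t, "vocab_files_names")
        and "vocab_file" in t.vocab_files_names
        and "tiktoken" in t.vocab_files_names["vocab_file"]
    )
    assert is_tiktoken_style is False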
@@ -132,34 +134,43 @@ def from_huggingface(
         various tokenizer backends, including the huggingface fast tokenizer and tiktoken tokenizer.
         Necessary information is automatically detected from the tokenizer.
 
-        Note that some models (e.g. Phi-3 and Deepseek-V2) may pad the vocabulary to a multiple
-        of 32. In this case, the model's vocab_size is larger than the tokenizer's vocabulary
-        size. Please pass the model's vocab_size (this should be defined in the model config)
-        to the vocab_size parameter in the constructor, because this information is used to
-        determine the size of the token mask.
+        The vocab_size parameter is introduced to handle the misalignment between the model's
+        vocab_size and the tokenizer's vocabulary size. Users should pass the model's vocab_size
+        (usually defined in the model config) here. See the docs of vocab_size for more details.
 
-        Some models can have more than one stop token ids, and auto detection may not find all
-        of them. In this case, you can specify the stop token ids manually.
+        The stop token ids default to the eos_token_id of the tokenizer. If there are other
+        stop tokens, you can specify them manually.
 
         Parameters
         ----------
         tokenizer : PreTrainedTokenizerBase
             The huggingface tokenizer.
 
         vocab_size : Optional[int], default: None
-            The size of the vocabulary. If not provided, the vocabulary size will be
-            len(encoded_vocab).
+            The vocabulary size **defined by the model** (**not the tokenizer**). This equals
+            the vocab dimension of the model's lm_head. This is the size of the token mask.
+
+            It can be:
+            1. the same as the tokenizer's vocabulary size. This is the most common case.
+            2. larger than the tokenizer's vocabulary size. This happens when the model has
+               padding in lm_head, possibly to align lm_head to a power of 2.
+               E.g. Phi-3 and Deepseek-V2.
+            3. smaller than the tokenizer's vocabulary size. This happens when the tokenizer
+               has some added tokens that are not supported by the model. E.g.
+               Llama-3.2 Vision and Molmo-72B-0924 have padded <|image|> tokens, but they
+               will not be considered in lm_head or generated by the model.
+
+            vocab_size needs to be provided for cases 2 and 3. If not provided, it will be
+            set to the tokenizer's vocabulary size.
 
         stop_token_ids : Optional[List[int]], default: None
-            The stop token ids. If not provided, the stop token ids will be auto detected
-            (but may not be correct).
+            The stop token ids. If not provided, the eos_token_id of the tokenizer will be used.
 
         Returns
         -------
         tokenizer_info : TokenizerInfo
             The tokenizer info.
         """
-
         if isinstance(stop_token_ids, int):
             stop_token_ids = [stop_token_ids]
         if isinstance(stop_token_ids, list) and len(stop_token_ids) == 0:
@@ -174,19 +185,18 @@
             )
             raise ValueError(msg) from e
 
-        max_id = max(vocab_dict.values()) if vocab_dict else -1
-        detected_vocab_size = max(len(vocab_dict), max_id + 1)
-        if vocab_size is None:
-            vocab_size = detected_vocab_size
-        else:
-            if vocab_size < detected_vocab_size:
-                msg = f"Input vocab_size less than minimum viable vocab size for tokenizer {type(tokenizer)}."
-                raise ValueError(msg)
+        # Some tokenizers don't have token ids 0, 1, or 2, so max_id can be larger than the
+        # number of tokens.
+        max_id = max(vocab_dict.values())
+        tokenizer_vocab_size = max(len(vocab_dict), max_id + 1)
+
+        vocab_size = vocab_size or tokenizer_vocab_size
 
         # maintain tokenizer's indexing
-        encoded_vocab = ["" for _ in range(vocab_size)]
+        encoded_vocab = [""] * vocab_size
         for token, idx in vocab_dict.items():
-            encoded_vocab[idx] = token
+            if idx < vocab_size:
+                encoded_vocab[idx] = token
 
         if isinstance(tokenizer, PreTrainedTokenizerFast):
             # huggingface fast tokenizer
Expand All @@ -207,11 +217,15 @@ def from_huggingface(
"stop_token_ids is neither provided by user nor found from the tokenizer. "
"It will be automatically detected."
)
return TokenizerInfo._create_from_handle(
_core.TokenizerInfo.from_huggingface(
encoded_vocab, backend_str, vocab_size, stop_token_ids
)
metadata = TokenizerInfo._detect_metadata_from_hf(backend_str)
return TokenizerInfo(
encoded_vocab,
vocab_type=metadata["vocab_type"],
vocab_size=vocab_size,
stop_token_ids=stop_token_ids,
add_prefix_space=metadata["add_prefix_space"],
)

elif TokenizerInfo._is_tiktoken_tokenizer(tokenizer):
# tiktoken tokenizer
# e.g. Phi-3-small-8k-instruct, Qwen-7B-Chat, stablelm-2-12b-chat (previously)
Expand All @@ -231,6 +245,7 @@ def from_huggingface(
stop_token_ids=stop_token_ids,
add_prefix_space=False,
)

elif TokenizerInfo._is_sentencepiece_tokenizer(tokenizer):
# sentencepiece tokenizer
# e.g. Chatglm3-6b
Expand Down Expand Up @@ -265,6 +280,7 @@ def from_huggingface(
stop_token_ids=stop_token_ids,
add_prefix_space=True,
)

else:
# TODO(yixin): unsupported tokenizer
raise ValueError(f"Unsupported tokenizer type: {type(tokenizer)}")
@@ -337,3 +353,18 @@
         return TokenizerInfo._create_from_handle(
             _core.TokenizerInfo.from_vocab_and_metadata(encoded_vocab, metadata)
         )
+
+    @staticmethod
+    def _detect_metadata_from_hf(backend_str: str) -> Dict[str, Any]:
+        """Detect the metadata from the huggingface tokenizer backend string. For
+        implementation use only.
+
+        It returns {"vocab_type": VocabType, "add_prefix_space": bool}.
+        """
+        # the metadata_str should be in the format {"vocab_type": int, "add_prefix_space": bool}
+        metadata_str = _core.TokenizerInfo._detect_metadata_from_hf(backend_str)
+        metadata = json.loads(metadata_str)
+        return {
+            "vocab_type": VocabType(metadata["vocab_type"]),
+            "add_prefix_space": metadata["add_prefix_space"],
+        }
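Taken together, the docstring's three cases reduce to one rule: pass the model's vocab size so the token mask matches the lm_head dimension. A usage sketch, with the model id illustrative (Phi-3 is the docstring's example of a padded lm_head) rather than verified here:

    from transformers import AutoConfig, AutoTokenizer

    import xgrammar as xgr

    model_id = "microsoft/Phi-3-mini-4k-instruct"  # illustrative
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    config = AutoConfig.from_pretrained(model_id)

    # Cases 2 and 3: the model's vocab_size, not the tokenizer's, sizes the token mask.
    info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=config.vocab_size)
    assert info.vocab_size == config.vocab_size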