Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion genlm/backend/llm/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def batch_evaluate_queries(self):
else:
pasts = None

pasts = DynamicCache.from_legacy_cache(pasts)
pasts = DynamicCache(pasts) if pasts is not None else None

results = self.model(
input_ids,
Expand Down
128 changes: 65 additions & 63 deletions genlm/backend/llm/vllm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import torch
import logging
import warnings
Expand All @@ -7,7 +8,7 @@

try:
from vllm import AsyncLLMEngine, SamplingParams, AsyncEngineArgs
from vllm.utils import Counter
from vllm.utils.counter import Counter
from vllm.inputs import TokensPrompt

from vllm.distributed.parallel_state import (
Expand Down Expand Up @@ -40,27 +41,7 @@ def from_name(cls, *args, **kwargs): # pragma: no cover
else:
logging.getLogger("vllm.engine.async_llm_engine").setLevel(logging.WARNING)

class PassThroughLogitsProcessor:
"""A logits processor that stores the logprobs and passes the logits through."""

def __init__(self):
self.log_probs = None

def __call__(self, past_token_ids, logits):
assert self.log_probs is None, (
"Log probs already set. This should never happen."
)
self.log_probs = torch.log_softmax(logits, dim=-1, dtype=logits.dtype)
return logits

class AsyncVirtualLM(AsyncLM):
default_params = {
"max_tokens": 1,
"n": 1,
"detokenize": False,
"stop": None,
"ignore_eos": True,
}

def __init__(self, async_llm_engine, cache_size=0, cache_opts={}):
"""Initialize an `AsyncVirtualLM` instance.
Expand All @@ -74,26 +55,37 @@ def __init__(self, async_llm_engine, cache_size=0, cache_opts={}):
The cache stores the log probabilities for previously seen token sequences to avoid redundant requests. KV caching is handled internally by the vLLM engine.
"""
self.async_llm_engine = async_llm_engine
self.tokenizer = async_llm_engine.engine.get_tokenizer()
self.tokenizer = async_llm_engine.tokenizer
self.request_counter = Counter()
self.cache = (
OutputCache(maxsize=cache_size, **cache_opts)
if cache_size > 0
else None
)

async_llm_engine.engine.log_stats = False

super().__init__(tokenizer=self.tokenizer)

# Store vocab size for logprobs requests
self._vocab_size = len(self.tokenizer)

# Default sampling params for logprobs - request full vocab logprobs
self._logprobs_params = SamplingParams(
max_tokens=1,
n=1,
detokenize=False,
stop=None,
ignore_eos=True,
logprobs=self._vocab_size,
)

@classmethod
def from_name(cls, model_name, engine_opts=None, **kwargs):
"""Create a `AsyncVirtualLM` instance from a model name.

Args:
model_name (str): Name of the model to load.
engine_opts (dict): Additional options to pass to the `AsyncLLMEngine`. The engine will be
configured with prefix caching enabled and async output processing disabled by default.
configured with prefix caching enabled by default.
**kwargs: Additional arguments passed to `AsyncVirtualLM` constructor.

Returns:
Expand All @@ -111,10 +103,15 @@ def from_name(cls, model_name, engine_opts=None, **kwargs):
"custom sampling functionality."
)

# Get vocab size to set max_logprobs - we need to load tokenizer first
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)
vocab_size = len(tokenizer)

engine_opts = {
"enable_prefix_caching": True,
"disable_log_requests": True,
"disable_async_output_proc": True, # This parameter forces vLLM to use v0, which is currently what we want to do.
"max_logprobs": vocab_size, # Enable full vocab logprobs
**(engine_opts or {}),
}

Expand All @@ -126,7 +123,10 @@ def from_name(cls, model_name, engine_opts=None, **kwargs):

@property
def underlying_model(self):
return self.async_llm_engine.engine.model_executor.driver_worker.model_runner.model
raise NotImplementedError(
"underlying_model is not available with vLLM V1 engine. "
"The V1 engine does not expose direct model access."
)

async def next_token_logprobs(self, token_ids):
"""Request log probabilities of next token asynchronously with output caching.
Expand Down Expand Up @@ -165,22 +165,28 @@ async def _next_token_logprobs(self, token_ids):
req_id = str(next(self.request_counter))
prompt = TokensPrompt(prompt_token_ids=token_ids)

outputs = []
processor = PassThroughLogitsProcessor()
async for output in self.async_llm_engine.generate(
prompt=prompt,
sampling_params=SamplingParams(
**self.default_params, logits_processors=[processor]
),
sampling_params=self._logprobs_params,
request_id=req_id,
):
if output.finished:
outputs.append(output)
# Extract logprobs from output
# output.outputs[0].logprobs is a list of dicts (one per generated token)
# Each dict maps token_id -> LogProb object with .logprob attribute
logprobs_dict = output.outputs[0].logprobs[0]

# Convert to tensor - logprobs_dict contains LogProb objects
# We need to create a full vocab tensor
log_probs = torch.full(
(self._vocab_size,), float("-inf"), dtype=torch.float32
)
for token_id, logprob_obj in logprobs_dict.items():
log_probs[token_id] = logprob_obj.logprob

assert processor.log_probs is not None, (
"Log probs should be set by the logits processor."
)
return processor.log_probs
return log_probs

raise RuntimeError("No output received from vLLM engine")

def next_token_logprobs_sync(self, token_ids):
"""Request log probabilities of next token synchronously.
Expand All @@ -203,32 +209,28 @@ def batch_next_token_logprobs_sync(self, token_ids_list):
Returns:
(torch.Tensor): A tensor of normalized log probability tensors, one for each prompt in the input list.
"""
req_ids = []
req_id2processors = {}
for token_ids in token_ids_list:
req_id = str(next(self.request_counter))
req_ids.append(req_id)
processor = PassThroughLogitsProcessor()
req_id2processors[req_id] = processor
self.async_llm_engine.engine.add_request(
prompt=TokensPrompt(prompt_token_ids=token_ids),
params=SamplingParams(
**self.default_params, logits_processors=[processor]
),
request_id=req_id,
)

while self.async_llm_engine.engine.has_unfinished_requests():
output = self.async_llm_engine.engine.step()
for out in output:
if out.finished:
assert out.request_id in req_id2processors, (
f"{out.request_id} not in requested IDs"
)
async def _batch_async():
results = []
for token_ids in token_ids_list:
result = await self._next_token_logprobs(tuple(token_ids))
results.append(result)
return torch.stack(results)

return torch.stack(
[req_id2processors[req_id].log_probs for req_id in req_ids]
)
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None

if loop is not None:
# We're already in an async context, create a new task
import concurrent.futures

with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(asyncio.run, _batch_async())
return future.result()
else:
return asyncio.run(_batch_async())

def clear_cache(self):
"""Clear output cache."""
Expand All @@ -242,7 +244,7 @@ def __del__(self):
def _cleanup_engine(self):
"""Clean up the vLLM engine and associated resources."""
if async_engine := getattr(self, "async_llm_engine", None):
async_engine.shutdown_background_loop()
async_engine.shutdown()
destroy_model_parallel()
destroy_distributed_environment()

Expand Down
12 changes: 10 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ authors = [
]
dependencies = [
"torch",
"transformers>=4.36.0,<5.0.0.rc0",
"transformers>=4.57.1,<5.0.0.rc0",
"sentencepiece",
"protobuf",
"accelerate",
"bitsandbytes; sys_platform == 'linux'",
"numba",
"vllm>=0.6.6,<=0.10.0; sys_platform == 'linux'",
"vllm>=0.11.0; sys_platform == 'linux'",
"triton>=3.2.0; sys_platform == 'linux'",
]

Expand All @@ -41,4 +41,12 @@ exclude = ["benchmark*", "tests*"]
requires = ["setuptools>=64.0", "setuptools-scm>=8"]
build-backend = "setuptools.build_meta"

[dependency-groups]
dev = [
"arsenal>=3.1",
"hypothesis>=6.150.2",
"pytest>=9.0.2",
"pytest-asyncio>=1.3.0",
]

[tool.setuptools_scm]
23 changes: 14 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,9 @@ class ReferenceVirtualLM:

def __init__(self, llm):
self.llm = llm
self.tokenizer = llm.llm_engine.get_tokenizer()
self.tokenizer = llm.get_tokenizer()
self.byte_vocab, self.str_vocab = decode_vocab(self.tokenizer)
self.vocab_length = len(self.byte_vocab)
self.llm.llm_engine.get_model_config().max_logprobs = self.vocab_length
self.DEFAULT_SAMPLING_PARAMS = SamplingParams(
max_tokens=1,
n=1,
Expand All @@ -143,16 +142,22 @@ def __init__(self, llm):
ignore_eos=True,
)

self.llm.llm_engine.log_stats = False

@classmethod
def from_name(cls, model_name, llm_opts=None):
if not HAS_VLLM:
raise ImportError("vLLM not installed.")

# Get vocab size to set max_logprobs
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)
vocab_size = len(tokenizer)

llm_opts = {
"enable_prefix_caching": True,
"disable_log_stats": True,
"dtype": "float16",
"max_logprobs": vocab_size, # Enable full vocab logprobs
**(llm_opts or {}),
}
llm = LLM(model=model_name, tokenizer=model_name, **llm_opts)
Expand Down Expand Up @@ -198,8 +203,8 @@ async def batch_next_token_logprobs(self, token_ids_list):
return logprobs

def __del__(self):
if llm_engine := getattr(self.llm, "llm_engine"):
if executor := getattr(llm_engine, "model_executor"):
destroy_model_parallel()
destroy_distributed_environment()
del executor
try:
destroy_model_parallel()
destroy_distributed_environment()
except Exception:
pass