pyproject.toml (7 changes: 6 additions & 1 deletion)
@@ -12,7 +12,8 @@ dependencies = [
"fastapi>=0.125.0",
"httpx>=0.28.1",
"numpy<2.0.0",
"tinker>=0.7.0",
"tinker>=0.18.2",
"protobuf>=4.21.0",
"typer>=0.20.1",
"uvicorn[standard]>=0.38.0",
"omegaconf>=2.3.0",
@@ -73,6 +74,7 @@ line-length = 100
target-version = "py311"
exclude = [
"tinker",
"thirdparty",
]

[tool.ruff.lint]
@@ -95,12 +97,15 @@ pythonVersion = "3.11"
typeCheckingMode = "standard"
reportUnusedImport = false
reportMissingImports = false
reportPrivateImportUsage = false
reportOptionalMemberAccess = false
exclude = [
"**/.venv",
"**/__pycache__",
".pytest_cache",
".ruff_cache",
"tinker",
"thirdparty",
]

[tool.pytest.ini_options]
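The raised lower bounds above are what the rest of this diff depends on. A quick way to confirm the active environment actually satisfies the new pins (a convenience sketch, not part of this PR; standard library only):

from importlib.metadata import version

# Compare installed versions against the new lower bounds declared above.
for pkg, minimum in [("tinker", "0.18.2"), ("protobuf", "4.21.0")]:
    print(f"{pkg}: installed {version(pkg)}, required >= {minimum}")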
src/tuft/backend.py (8 changes: 4 additions & 4 deletions)
@@ -167,8 +167,8 @@ def sample(
generated = self._generate_tokens(prompt_tokens, max_tokens)
seq = types.SampledSequence(
stop_reason="length",
tokens=generated,
logprobs=[-0.3 for _ in generated],
_tokens_list=generated,
_logprobs_list=[-0.3 for _ in generated],
)
sequences.append(seq)
prompt_logprobs = None
@@ -187,8 +187,8 @@ def sample(
]
return types.SampleResponse(
sequences=sequences,
prompt_logprobs=prompt_logprobs,
topk_prompt_logprobs=topk_prompt,
_prompt_logprobs_list=prompt_logprobs,
_topk_prompt_logprobs_list=topk_prompt,
)

# ------------------------------------------------------------------
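The four renamed keywords above are the crux of the upgrade: in tinker 0.18.2, SampledSequence and SampleResponse are frozen dataclasses whose constructors take underscore-prefixed list fields instead of the 0.7-era names. A minimal before/after sketch, assuming only the keyword names visible in this diff:

from tinker import types

generated = [101, 102, 103]

# tinker 0.7.x style; on 0.18.2 this raises
# TypeError: SampledSequence.__init__() got an unexpected keyword argument 'tokens'
# seq = types.SampledSequence(stop_reason="length", tokens=generated, logprobs=[-0.3] * 3)

# tinker 0.18.2 style, as used throughout this diff:
seq = types.SampledSequence(
    stop_reason="length",
    _tokens_list=generated,
    _logprobs_list=[-0.3 for _ in generated],
)
response = types.SampleResponse(
    sequences=[seq],
    _prompt_logprobs_list=None,
    _topk_prompt_logprobs_list=None,
)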
src/tuft/backends/sampling_backend.py (218 changes: 208 additions & 10 deletions)
@@ -8,7 +8,7 @@
import time
from logging import getLogger
from pathlib import Path
from typing import Optional
from typing import Any, Optional

from opentelemetry.trace import StatusCode
from tinker import types
@@ -24,6 +24,154 @@
logger = getLogger(__name__)


def _build_sample_response(
req_output: Any,
include_prompt_logprobs: bool = False,
topk_prompt_logprobs: int = 0,
) -> types.SampleResponse:
"""Build a tinker 0.18.2 SampleResponse from vLLM RequestOutput.

WHY THIS EXISTS:
trinity-rft 0.5.1 declares ``tinker>=0.10.0`` as a dependency but its
``vLLMRolloutModel.sample()`` constructs ``SampledSequence`` and
``SampleResponse`` using the old tinker 0.7 keyword arguments
(``tokens=``, ``logprobs=``, ``prompt_logprobs=``,
``topk_prompt_logprobs=``). In tinker 0.18.2 these types are frozen
dataclasses and the old names are no longer valid constructor parameters,
causing ``TypeError: SampledSequence.__init__() got an unexpected keyword
argument 'tokens'``.

Because the trinity model runs as a **Ray remote actor in a separate
process**, we cannot monkey-patch tinker's constructors from our main
process. The only self-contained workaround is to bypass trinity's
``sample()`` entirely, call its lower-level ``_generate_internal()``
(which returns the raw vLLM ``RequestOutput`` without touching tinker
types), and build the ``SampleResponse`` ourselves here using the new
tinker 0.18.2 constructor API.

HOW TO REVERT WHEN TRINITY IS FIXED:
The recommended first step is to upgrade trinity-rft to its latest
version (``pip install --upgrade trinity-rft``). If the new version
constructs ``SampledSequence`` / ``SampleResponse`` with tinker
0.18.2-compatible keyword arguments (``_tokens_list=``,
``_logprobs_list=``, etc.), then do the following:

1. In ``VLLMSamplingBackend.sample()``, replace the call to
``engine._generate_internal.remote()`` + ``_build_sample_response()``
with a direct call to ``engine.sample.remote()``.
2. Delete this ``_build_sample_response()`` function.
3. Optionally delete ``_normalize_sample_response()`` if no longer needed.
4. Remove the ``skip_reading_prefix_cache`` workaround in
``VLLMSamplingBackend.sample()`` (trinity handles it internally).

The logic below mirrors trinity's ``vllm_model.py::sample()`` but uses
the new constructor API (``_tokens_list=``, ``_logprobs_list=``, etc.).
"""
sequences: list[types.SampledSequence] = []
topk_prompt_logprobs_list: list[list[tuple[int, float]] | None] = [None]
prompt_logprobs: list[float | None] = [None]

# collect prompt logprobs
if include_prompt_logprobs:
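        # vLLM reports no logprob for the very first prompt token, so skip index 0;
        # the [None] placeholders initialized above stand in for that position.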
for logprob_dict in req_output.prompt_logprobs[1:]:
prompt_logprobs.append(next(iter(logprob_dict.values())).logprob)
if topk_prompt_logprobs > 0:
logprob_items = sorted(logprob_dict.items(), key=lambda x: x[1].rank)
topk = logprob_items[:topk_prompt_logprobs]
topk_prompt_logprobs_list.append(
[(token_id, logprob.logprob) for token_id, logprob in topk]
)

# collect response sequences
for seq_output in req_output.outputs:
seq = types.SampledSequence(
stop_reason="length" if seq_output.finish_reason == "length" else "stop",
_tokens_list=seq_output.token_ids,
_logprobs_list=[
next(iter(logprob_dict.values())).logprob for logprob_dict in seq_output.logprobs
],
)
sequences.append(seq)

return types.SampleResponse(
sequences=sequences,
_prompt_logprobs_list=prompt_logprobs if include_prompt_logprobs else None,
_topk_prompt_logprobs_list=(
topk_prompt_logprobs_list
if include_prompt_logprobs and topk_prompt_logprobs > 0
else None
),
)


def _normalize_sample_response(raw: Any) -> types.SampleResponse:
"""Normalize engine sample response to tinker 0.18.2 SampleResponse dataclass.

Handles responses from engines that may use older tinker versions:
- If already a SampleResponse dataclass: pass through
- If dict (JSON-like): construct from dict fields
- If Pydantic-like object with .sequences attribute: extract and convert
"""
if isinstance(raw, types.SampleResponse):
return raw

# Handle dict response (e.g., from JSON serialization)
if isinstance(raw, dict):
sequences = []
for seq_data in raw.get("sequences", []):
if isinstance(seq_data, dict):
sequences.append(
types.SampledSequence(
stop_reason=seq_data["stop_reason"],
_tokens_list=seq_data.get("tokens", []),
_logprobs_list=seq_data.get("logprobs"),
)
)
else:
# Already a SampledSequence-like object
sequences.append(
types.SampledSequence(
stop_reason=seq_data.stop_reason,
_tokens_list=list(seq_data.tokens) if hasattr(seq_data, "tokens") else [],
_logprobs_list=list(seq_data.logprobs)
if hasattr(seq_data, "logprobs") and seq_data.logprobs is not None
else None,
)
)
return types.SampleResponse(
sequences=sequences,
_prompt_logprobs_list=raw.get("prompt_logprobs"),
_topk_prompt_logprobs_list=raw.get("topk_prompt_logprobs"),
)

# Handle old Pydantic-like object (has .sequences attribute)
if hasattr(raw, "sequences"):
sequences = []
for seq in raw.sequences:
tokens = list(seq.tokens) if hasattr(seq, "tokens") else []
logprobs = (
list(seq.logprobs)
if hasattr(seq, "logprobs") and seq.logprobs is not None
else None
)
sequences.append(
types.SampledSequence(
stop_reason=seq.stop_reason,
_tokens_list=tokens,
_logprobs_list=logprobs,
)
)
prompt_lp = getattr(raw, "prompt_logprobs", None)
topk_lp = getattr(raw, "topk_prompt_logprobs", None)
return types.SampleResponse(
sequences=sequences,
_prompt_logprobs_list=prompt_lp,
_topk_prompt_logprobs_list=topk_lp,
)

raise TypeError(f"Cannot normalize sample response of type {type(raw)}")


class VLLMSamplingBackend(BaseSamplingBackend):
"""A sampling backend using vLLM.

@@ -222,14 +370,64 @@ async def sample(
if lora_id is not None and lora_id not in self.lora_adapters:
raise ValueError(f"LoRA adapter {lora_id} not found in backend.")
lora_request = self.lora_adapters[lora_id] if lora_id is not None else None

# -----------------------------------------------------------------
# WORKAROUND: bypass trinity's engine.sample.remote()
#
# trinity-rft 0.5.1 uses old tinker 0.7 constructor keywords
# (tokens=, logprobs=) inside its sample() method, which crash
# with tinker 0.18.2 frozen dataclasses. The actor runs in a
# separate Ray worker process, so monkey-patching from here
# won't help.
#
# Instead we call the lower-level _generate_internal() which
# returns raw vLLM RequestOutput, then build SampleResponse
# ourselves via _build_sample_response() using the new API.
#
# TODO(trinity): First try upgrading trinity-rft to latest
# (pip install --upgrade trinity-rft). If the new version is
# compatible with tinker 0.18.2, replace this block with:
# raw_response = await self.engine.sample.remote(
# prompt=prompt,
# num_samples=num_samples,
# sampling_params=sampling_params,
# include_prompt_logprobs=include_prompt_logprobs,
# topk_prompt_logprobs=topk_prompt_logprobs,
# lora_request=lora_request,
# )
# return _normalize_sample_response(raw_response)
# -----------------------------------------------------------------
prompt_token_ids = prompt.to_ints()
params = {
"max_tokens": (
sampling_params.max_tokens if sampling_params.max_tokens is not None else 16
),
"seed": sampling_params.seed,
"top_k": sampling_params.top_k,
"top_p": sampling_params.top_p,
"temperature": sampling_params.temperature,
"n": num_samples,
"prompt_logprobs": (topk_prompt_logprobs if include_prompt_logprobs else None),
"logprobs": 0,
}
# Avoid prefix cache corruption when computing prompt logprobs.
# Trinity sets this for vLLM >= 0.12.0 to prevent OverflowError
# in vLLM's _update_prompt_logprobs when prefix cache is active.
if include_prompt_logprobs:
params["skip_reading_prefix_cache"] = True
if sampling_params.stop is not None:
params["stop"] = sampling_params.stop

# Ray @ray.remote decorator adds .remote() method dynamically
return await self.engine.sample.remote( # type: ignore[attr-defined]
prompt=prompt,
num_samples=num_samples,
sampling_params=sampling_params,
req_output = await self.engine._generate_internal.remote( # type: ignore[attr-defined]
prompt={"prompt_token_ids": prompt_token_ids},
lora_request=lora_request,
**params,
)
return _build_sample_response(
req_output=req_output,
include_prompt_logprobs=include_prompt_logprobs,
topk_prompt_logprobs=topk_prompt_logprobs,
lora_request=lora_request,
)
except Exception as e:
span.record_exception(e)
@@ -536,8 +734,8 @@ async def sample(
generated = self._generate_tokens(prompt_tokens, max_tokens)
seq = types.SampledSequence(
stop_reason="length",
tokens=generated,
logprobs=[-0.3 for _ in generated],
_tokens_list=generated,
_logprobs_list=[-0.3 for _ in generated],
)
sequences.append(seq)
prompt_logprobs = None
Expand All @@ -556,8 +754,8 @@ async def sample(
]
return types.SampleResponse(
sequences=sequences,
prompt_logprobs=prompt_logprobs,
topk_prompt_logprobs=topk_prompt,
_prompt_logprobs_list=prompt_logprobs,
_topk_prompt_logprobs_list=topk_prompt,
)

def _generate_tokens(self, prompt_tokens: list[int], max_tokens: int) -> list[int]:
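For reference, a minimal sketch of the new helper in action, driving _build_sample_response with a hand-rolled stand-in for vLLM's RequestOutput. The SimpleNamespace objects below are stand-ins that populate only the attributes the helper reads, and the import path assumes the src layout shown in this diff:

from types import SimpleNamespace

from tuft.backends.sampling_backend import _build_sample_response


class _FakeLogprob(SimpleNamespace):
    """Stand-in for vLLM's Logprob; only .logprob and .rank are read."""


fake_request_output = SimpleNamespace(
    # One generated sequence of three tokens, each with a single-entry logprob dict.
    outputs=[
        SimpleNamespace(
            finish_reason="length",
            token_ids=[7, 8, 9],
            logprobs=[
                {7: _FakeLogprob(logprob=-0.1, rank=1)},
                {8: _FakeLogprob(logprob=-0.2, rank=1)},
                {9: _FakeLogprob(logprob=-0.3, rank=1)},
            ],
        )
    ],
    # Prompt logprobs: vLLM puts None first, then one dict per remaining prompt token.
    prompt_logprobs=[
        None,
        {1: _FakeLogprob(logprob=-1.5, rank=1)},
        {2: _FakeLogprob(logprob=-0.7, rank=1)},
    ],
)

resp = _build_sample_response(
    fake_request_output,
    include_prompt_logprobs=True,
    topk_prompt_logprobs=1,
)
print(resp.sequences[0].stop_reason)  # expected: "length"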