From b7f01df4645304c235e248c16185b9629646ebd0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=86=A0=E9=B8=A3?=
Date: Wed, 6 May 2026 17:02:05 +0800
Subject: [PATCH 1/3] feat: upgrade tinker to 0.18.2 with dataclass compat layer

---
 pyproject.toml                        |   3 +-
 src/tuft/backend.py                   |   8 +-
 src/tuft/backends/sampling_backend.py | 221 ++++++++++++++++++++++++--
 src/tuft/compat.py                    | 121 ++++++++++++++
 src/tuft/server.py                    |  59 ++++++-
 5 files changed, 396 insertions(+), 16 deletions(-)
 create mode 100644 src/tuft/compat.py

diff --git a/pyproject.toml b/pyproject.toml
index 7b97acb..bc014cc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,7 +12,8 @@ dependencies = [
     "fastapi>=0.125.0",
     "httpx>=0.28.1",
     "numpy<2.0.0",
-    "tinker>=0.7.0",
+    "tinker>=0.18.2",
+    "protobuf>=4.21.0",
     "typer>=0.20.1",
     "uvicorn[standard]>=0.38.0",
     "omegaconf>=2.3.0",
diff --git a/src/tuft/backend.py b/src/tuft/backend.py
index 26f2b26..0be48f8 100644
--- a/src/tuft/backend.py
+++ b/src/tuft/backend.py
@@ -167,8 +167,8 @@ def sample(
         generated = self._generate_tokens(prompt_tokens, max_tokens)
         seq = types.SampledSequence(
             stop_reason="length",
-            tokens=generated,
-            logprobs=[-0.3 for _ in generated],
+            _tokens_list=generated,
+            _logprobs_list=[-0.3 for _ in generated],
         )
         sequences.append(seq)
         prompt_logprobs = None
@@ -187,8 +187,8 @@ def sample(
     ]
     return types.SampleResponse(
         sequences=sequences,
-        prompt_logprobs=prompt_logprobs,
-        topk_prompt_logprobs=topk_prompt,
+        _prompt_logprobs_list=prompt_logprobs,
+        _topk_prompt_logprobs_list=topk_prompt,
     )

 # ------------------------------------------------------------------
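For reference, the constructor change these hunks implement, as a standalone sketch. It assumes the tinker 0.18.2 dataclass behavior described later in this patch: private ``_tokens_list=`` / ``_logprobs_list=`` constructor fields, with public ``.tokens`` / ``.logprobs`` exposed as cached-property views.

    from tinker import types

    # tinker 0.7 (Pydantic) -- no longer valid constructor keywords:
    #   types.SampledSequence(stop_reason="length", tokens=[1, 2], logprobs=[-0.3, -0.3])

    # tinker 0.18.2 (frozen dataclass) -- pass the private list fields instead:
    seq = types.SampledSequence(
        stop_reason="length",
        _tokens_list=[1, 2],
        _logprobs_list=[-0.3, -0.3],
    )
    assert seq.tokens == [1, 2]  # public view derived from _tokens_list
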
diff --git a/src/tuft/backends/sampling_backend.py b/src/tuft/backends/sampling_backend.py
index 58a3f1f..ee7d3e9 100644
--- a/src/tuft/backends/sampling_backend.py
+++ b/src/tuft/backends/sampling_backend.py
@@ -8,7 +8,7 @@
 import time
 from logging import getLogger
 from pathlib import Path
-from typing import Optional
+from typing import Any, Optional

 from opentelemetry.trace import StatusCode
 from tinker import types
@@ -24,6 +24,155 @@
 logger = getLogger(__name__)


+def _build_sample_response(
+    req_output: Any,
+    include_prompt_logprobs: bool = False,
+    topk_prompt_logprobs: int = 0,
+) -> types.SampleResponse:
+    """Build a tinker 0.18.2 SampleResponse from vLLM RequestOutput.
+
+    WHY THIS EXISTS:
+    trinity-rft 0.5.1 declares ``tinker>=0.10.0`` as a dependency but its
+    ``vLLMRolloutModel.sample()`` constructs ``SampledSequence`` and
+    ``SampleResponse`` using the old tinker 0.7 keyword arguments
+    (``tokens=``, ``logprobs=``, ``prompt_logprobs=``,
+    ``topk_prompt_logprobs=``). In tinker 0.18.2 these types are frozen
+    dataclasses and the old names are no longer valid constructor parameters,
+    causing ``TypeError: SampledSequence.__init__() got an unexpected keyword
+    argument 'tokens'``.
+
+    Because the trinity model runs as a **Ray remote actor in a separate
+    process**, we cannot monkey-patch tinker's constructors from our main
+    process. The only self-contained workaround is to bypass trinity's
+    ``sample()`` entirely, call its lower-level ``_generate_internal()``
+    (which returns the raw vLLM ``RequestOutput`` without touching tinker
+    types), and build the ``SampleResponse`` ourselves here using the new
+    tinker 0.18.2 constructor API.
+
+    HOW TO REVERT WHEN TRINITY IS FIXED:
+    The recommended first step is to upgrade trinity-rft to its latest
+    version (``pip install --upgrade trinity-rft``). If the new version
+    constructs ``SampledSequence`` / ``SampleResponse`` with tinker
+    0.18.2-compatible keyword arguments (``_tokens_list=``,
+    ``_logprobs_list=``, etc.), then do the following:
+
+    1. In ``VLLMSamplingBackend.sample()``, replace the call to
+       ``engine._generate_internal.remote()`` + ``_build_sample_response()``
+       with a direct call to ``engine.sample.remote()``.
+    2. Delete this ``_build_sample_response()`` function.
+    3. Optionally delete ``_normalize_sample_response()`` if no longer needed.
+    4. Remove the ``skip_reading_prefix_cache`` workaround in
+       ``VLLMSamplingBackend.sample()`` (trinity handles it internally).
+
+    The logic below mirrors trinity's ``vllm_model.py::sample()`` but uses
+    the new constructor API (``_tokens_list=``, ``_logprobs_list=``, etc.).
+    """
+    sequences: list[types.SampledSequence] = []
+    topk_prompt_logprobs_list: list[list[tuple[int, float]] | None] = [None]
+    prompt_logprobs: list[float | None] = [None]
+
+    # collect prompt logprobs
+    if include_prompt_logprobs:
+        for logprob_dict in req_output.prompt_logprobs[1:]:
+            prompt_logprobs.append(next(iter(logprob_dict.values())).logprob)
+            if topk_prompt_logprobs > 0:
+                logprob_items = sorted(logprob_dict.items(), key=lambda x: x[1].rank)
+                topk = logprob_items[:topk_prompt_logprobs]
+                topk_prompt_logprobs_list.append(
+                    [(token_id, logprob.logprob) for token_id, logprob in topk]
+                )
+
+    # collect response sequences
+    for seq_output in req_output.outputs:
+        seq = types.SampledSequence(
+            stop_reason="length" if seq_output.finish_reason == "length" else "stop",
+            _tokens_list=seq_output.token_ids,
+            _logprobs_list=[
+                next(iter(logprob_dict.values())).logprob
+                for logprob_dict in seq_output.logprobs
+            ],
+        )
+        sequences.append(seq)
+
+    return types.SampleResponse(
+        sequences=sequences,
+        _prompt_logprobs_list=prompt_logprobs if include_prompt_logprobs else None,
+        _topk_prompt_logprobs_list=(
+            topk_prompt_logprobs_list
+            if include_prompt_logprobs and topk_prompt_logprobs > 0
+            else None
+        ),
+    )
+
+
+def _normalize_sample_response(raw: Any) -> types.SampleResponse:
+    """Normalize engine sample response to tinker 0.18.2 SampleResponse dataclass.
+
+    Handles responses from engines that may use older tinker versions:
+    - If already a SampleResponse dataclass: pass through
+    - If dict (JSON-like): construct from dict fields
+    - If Pydantic-like object with .sequences attribute: extract and convert
+    """
+    if isinstance(raw, types.SampleResponse):
+        return raw
+
+    # Handle dict response (e.g., from JSON serialization)
+    if isinstance(raw, dict):
+        sequences = []
+        for seq_data in raw.get("sequences", []):
+            if isinstance(seq_data, dict):
+                sequences.append(
+                    types.SampledSequence(
+                        stop_reason=seq_data["stop_reason"],
+                        _tokens_list=seq_data.get("tokens", []),
+                        _logprobs_list=seq_data.get("logprobs"),
+                    )
+                )
+            else:
+                # Already a SampledSequence-like object
+                sequences.append(
+                    types.SampledSequence(
+                        stop_reason=seq_data.stop_reason,
+                        _tokens_list=list(seq_data.tokens) if hasattr(seq_data, "tokens") else [],
+                        _logprobs_list=list(seq_data.logprobs)
+                        if hasattr(seq_data, "logprobs") and seq_data.logprobs is not None
+                        else None,
+                    )
+                )
+        return types.SampleResponse(
+            sequences=sequences,
+            _prompt_logprobs_list=raw.get("prompt_logprobs"),
+            _topk_prompt_logprobs_list=raw.get("topk_prompt_logprobs"),
+        )
+
+    # Handle old Pydantic-like object (has .sequences attribute)
+    if hasattr(raw, "sequences"):
+        sequences = []
+        for seq in raw.sequences:
+            tokens = list(seq.tokens) if hasattr(seq, "tokens") else []
+            logprobs = (
+                list(seq.logprobs)
+                if hasattr(seq, "logprobs") and seq.logprobs is not None
+                else None
+            )
+            sequences.append(
+                types.SampledSequence(
+                    stop_reason=seq.stop_reason,
+                    _tokens_list=tokens,
+                    _logprobs_list=logprobs,
+                )
+            )
+        prompt_lp = getattr(raw, "prompt_logprobs", None)
+        topk_lp = getattr(raw, "topk_prompt_logprobs", None)
+        return types.SampleResponse(
+            sequences=sequences,
+            _prompt_logprobs_list=prompt_lp,
+            _topk_prompt_logprobs_list=topk_lp,
+        )
+
+    raise TypeError(f"Cannot normalize sample response of type {type(raw)}")
+
+
 class VLLMSamplingBackend(BaseSamplingBackend):
     """A sampling backend using vLLM.
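For instance, the dict branch of ``_normalize_sample_response()`` accepts a JSON-style payload in the old field names and rebuilds the 0.18.2 dataclasses. A minimal sketch (payload values are illustrative; shapes follow the docstring above):

    raw = {
        "sequences": [
            {"stop_reason": "stop", "tokens": [5, 7, 11], "logprobs": [-0.1, -0.2, -0.3]},
        ],
        "prompt_logprobs": None,
        "topk_prompt_logprobs": None,
    }
    resp = _normalize_sample_response(raw)
    assert isinstance(resp, types.SampleResponse)
    assert resp.sequences[0].tokens == [5, 7, 11]  # via the public property view
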
@@ -222,14 +371,66 @@ async def sample(
             if lora_id is not None and lora_id not in self.lora_adapters:
                 raise ValueError(f"LoRA adapter {lora_id} not found in backend.")
             lora_request = self.lora_adapters[lora_id] if lora_id is not None else None
+
+            # -----------------------------------------------------------------
+            # WORKAROUND: bypass trinity's engine.sample.remote()
+            #
+            # trinity-rft 0.5.1 uses old tinker 0.7 constructor keywords
+            # (tokens=, logprobs=) inside its sample() method, which crash
+            # with tinker 0.18.2 frozen dataclasses. The actor runs in a
+            # separate Ray worker process, so monkey-patching from here
+            # won't help.
+            #
+            # Instead we call the lower-level _generate_internal() which
+            # returns raw vLLM RequestOutput, then build SampleResponse
+            # ourselves via _build_sample_response() using the new API.
+            #
+            # TODO(trinity): First try upgrading trinity-rft to latest
+            # (pip install --upgrade trinity-rft). If the new version is
+            # compatible with tinker 0.18.2, replace this block with:
+            #     raw_response = await self.engine.sample.remote(
+            #         prompt=prompt,
+            #         num_samples=num_samples,
+            #         sampling_params=sampling_params,
+            #         include_prompt_logprobs=include_prompt_logprobs,
+            #         topk_prompt_logprobs=topk_prompt_logprobs,
+            #         lora_request=lora_request,
+            #     )
+            #     return _normalize_sample_response(raw_response)
+            # -----------------------------------------------------------------
+            prompt_token_ids = prompt.to_ints()
+            params = {
+                "max_tokens": (
+                    sampling_params.max_tokens
+                    if sampling_params.max_tokens is not None
+                    else 16
+                ),
+                "seed": sampling_params.seed,
+                "top_k": sampling_params.top_k,
+                "top_p": sampling_params.top_p,
+                "temperature": sampling_params.temperature,
+                "n": num_samples,
+                "prompt_logprobs": (topk_prompt_logprobs if include_prompt_logprobs else None),
+                "logprobs": 0,
+            }
+            # Avoid prefix cache corruption when computing prompt logprobs.
+            # Trinity sets this for vLLM >= 0.12.0 to prevent OverflowError
+            # in vLLM's _update_prompt_logprobs when prefix cache is active.
+            if include_prompt_logprobs:
+                params["skip_reading_prefix_cache"] = True
+            if sampling_params.stop is not None:
+                params["stop"] = sampling_params.stop
+
             # Ray @ray.remote decorator adds .remote() method dynamically
-            return await self.engine.sample.remote(  # type: ignore[attr-defined]
-                prompt=prompt,
-                num_samples=num_samples,
-                sampling_params=sampling_params,
+            req_output = await self.engine._generate_internal.remote(  # type: ignore[attr-defined]
+                prompt={"prompt_token_ids": prompt_token_ids},
+                lora_request=lora_request,
+                **params,
+            )
+            return _build_sample_response(
+                req_output=req_output,
                 include_prompt_logprobs=include_prompt_logprobs,
                 topk_prompt_logprobs=topk_prompt_logprobs,
-                lora_request=lora_request,
             )
         except Exception as e:
             span.record_exception(e)
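The raw engine output that ``_build_sample_response()`` consumes can be sketched with plain stand-in types. The ``Fake*`` classes below are invented for illustration; the attribute names (``prompt_logprobs``, ``outputs``, ``finish_reason``, ``token_ids``, ``logprobs``, ``logprob``, ``rank``) follow vLLM's ``RequestOutput`` / ``Logprob`` as used above:

    from dataclasses import dataclass

    @dataclass
    class FakeLogprob:        # stands in for vllm.sequence.Logprob
        logprob: float
        rank: int

    @dataclass
    class FakeCompletion:     # one entry of RequestOutput.outputs
        finish_reason: str
        token_ids: list[int]
        logprobs: list[dict[int, FakeLogprob]]

    @dataclass
    class FakeRequestOutput:  # the raw output _build_sample_response() reads
        prompt_logprobs: list[dict[int, FakeLogprob] | None]
        outputs: list[FakeCompletion]

    out = FakeRequestOutput(
        prompt_logprobs=[None, {42: FakeLogprob(logprob=-1.5, rank=1)}],
        outputs=[
            FakeCompletion(
                finish_reason="length",
                token_ids=[7, 9],
                logprobs=[{7: FakeLogprob(-0.2, 1)}, {9: FakeLogprob(-0.4, 1)}],
            )
        ],
    )
    resp = _build_sample_response(out, include_prompt_logprobs=True, topk_prompt_logprobs=1)
    assert resp.sequences[0].stop_reason == "length"
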
@@ -536,8 +737,8 @@ async def sample(
         generated = self._generate_tokens(prompt_tokens, max_tokens)
         seq = types.SampledSequence(
             stop_reason="length",
-            tokens=generated,
-            logprobs=[-0.3 for _ in generated],
+            _tokens_list=generated,
+            _logprobs_list=[-0.3 for _ in generated],
         )
         sequences.append(seq)
         prompt_logprobs = None
@@ -556,8 +757,8 @@ async def sample(
         ]
         return types.SampleResponse(
             sequences=sequences,
-            prompt_logprobs=prompt_logprobs,
-            topk_prompt_logprobs=topk_prompt,
+            _prompt_logprobs_list=prompt_logprobs,
+            _topk_prompt_logprobs_list=topk_prompt,
         )

     def _generate_tokens(self, prompt_tokens: list[int], max_tokens: int) -> list[int]:
diff --git a/src/tuft/compat.py b/src/tuft/compat.py
new file mode 100644
index 0000000..ec0c1d6
--- /dev/null
+++ b/src/tuft/compat.py
@@ -0,0 +1,121 @@
+"""Compatibility helpers for tinker 0.7 → 0.18.2 migration.
+
+Provides serialization of tinker 0.18.2 dataclass types to the JSON wire format
+expected by the tinker SDK client.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+import numpy as np
+from tinker.proto import tinker_public_pb2 as public_pb
+from tinker.types.sample_response import SampleResponse
+
+
+def serialize_sample_response(response: SampleResponse) -> dict[str, Any]:
+    """Serialize a SampleResponse dataclass to the JSON wire format.
+
+    The tinker SDK client expects the old Pydantic field names:
+    - sequences[].tokens (list[int])
+    - sequences[].logprobs (list[float] | None)
+    - sequences[].stop_reason (str)
+    - prompt_logprobs (list[float|None] | None)
+    - topk_prompt_logprobs (list[list[tuple[int,float]]|None] | None)
+    - type: "sample"
+    """
+    sequences = []
+    for seq in response.sequences:
+        seq_dict: dict[str, Any] = {
+            "stop_reason": seq.stop_reason,
+            "tokens": seq.tokens,  # uses @cached_property (lazy conversion from np)
+        }
+        if seq.logprobs is not None:
+            seq_dict["logprobs"] = seq.logprobs
+        else:
+            seq_dict["logprobs"] = None
+        sequences.append(seq_dict)
+
+    result: dict[str, Any] = {
+        "type": "sample",
+        "sequences": sequences,
+        "prompt_logprobs": response.prompt_logprobs,
+        "topk_prompt_logprobs": response.topk_prompt_logprobs,
+    }
+    return result
+
+
+def maybe_serialize_payload(payload: Any) -> Any:
+    """If payload is a SampleResponse dataclass, serialize it to dict.
+
+    Other Pydantic-based types (ForwardBackwardOutput, OptimStepResponse, etc.)
+    are handled natively by FastAPI and need no conversion.
+    """
+    if isinstance(payload, SampleResponse):
+        return serialize_sample_response(payload)
+    return payload
+
+
+# Proto enum mapping: SDK string -> proto enum value
+_STOP_REASON_TO_PROTO: dict[str, int] = {
+    "stop": public_pb.STOP_REASON_STOP,
+    "length": public_pb.STOP_REASON_LENGTH,
+}
+
+
+def serialize_sample_response_proto(response: SampleResponse) -> bytes:
+    """Serialize a SampleResponse to protobuf wire format.
+
+    Proto schema (from tinker_public_pb2):
+    - SampledSequence: stop_reason (enum), tokens (bytes=int32[]), logprobs (bytes=float32[])
+    - SampleResponse: sequences[], prompt_logprobs (bytes=float32[]), topk_prompt_logprobs
+    """
+    proto = public_pb.SampleResponse()
+
+    for seq in response.sequences:
+        proto_seq = proto.sequences.add()
+        proto_seq.stop_reason = _STOP_REASON_TO_PROTO.get(  # type: ignore[assignment]
+            seq.stop_reason, public_pb.STOP_REASON_LENGTH
+        )
+        # Convert tokens to int32 bytes
+        tokens = seq.tokens  # @cached_property, returns list[int]
+        proto_seq.tokens = np.array(tokens, dtype=np.int32).tobytes()
+        # Convert logprobs to float32 bytes (optional)
+        logprobs = seq.logprobs
+        if logprobs is not None:
+            proto_seq.logprobs = np.array(logprobs, dtype=np.float32).tobytes()
+
+    # Prompt logprobs: float32 array with NaN for None positions
+    prompt_lp = response.prompt_logprobs
+    if prompt_lp is not None:
+        lp_array = np.array(
+            [v if v is not None else float("nan") for v in prompt_lp],
+            dtype=np.float32,
+        )
+        proto.prompt_logprobs = lp_array.tobytes()
+
+    # Top-k prompt logprobs: dense N*K matrices
+    topk_lp = response.topk_prompt_logprobs
+    if topk_lp is not None:
+        # Determine k from first non-None entry
+        k = 0
+        for entry in topk_lp:
+            if entry is not None:
+                k = max(k, len(entry))
+                break
+        if k > 0:
+            n = len(topk_lp)
+            token_ids = np.zeros((n, k), dtype=np.int32)
+            logprobs_matrix = np.full((n, k), -99999.0, dtype=np.float32)
+            for i, entry in enumerate(topk_lp):
+                if entry is not None:
+                    for j, (tid, lp) in enumerate(entry[:k]):
+                        token_ids[i, j] = tid
+                        logprobs_matrix[i, j] = lp
+            topk_msg = proto.topk_prompt_logprobs
+            topk_msg.token_ids = token_ids.tobytes()
+            topk_msg.logprobs = logprobs_matrix.tobytes()
+            topk_msg.k = k
+            topk_msg.prompt_length = n
+
+    return proto.SerializeToString()
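A quick round-trip of the JSON path, assuming the 0.18.2 constructor fields behave as the docstrings above describe (``.tokens`` returning ``list[int]``):

    from tinker import types
    from tuft.compat import maybe_serialize_payload

    resp = types.SampleResponse(
        sequences=[
            types.SampledSequence(
                stop_reason="stop", _tokens_list=[3, 4], _logprobs_list=[-0.5, -0.7]
            )
        ],
        _prompt_logprobs_list=None,
        _topk_prompt_logprobs_list=None,
    )
    wire = maybe_serialize_payload(resp)
    assert wire["type"] == "sample"
    assert wire["sequences"][0]["tokens"] == [3, 4]  # old Pydantic names on the wire

Keeping the old field names on the wire is what lets an unmodified tinker SDK client keep parsing responses across the upgrade.
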
diff --git a/src/tuft/server.py b/src/tuft/server.py
index 9d0fc2f..56b41f8 100644
--- a/src/tuft/server.py
+++ b/src/tuft/server.py
@@ -17,6 +17,7 @@ from tinker import types

 from .auth import User
+from .compat import maybe_serialize_payload, serialize_sample_response_proto
 from .config import AppConfig
 from .exceptions import TuFTException
 from .oai import create_oai_router
@@ -556,6 +557,7 @@ async def asample(
     @app.post("/api/v1/retrieve_future")
     async def retrieve_future(
         request: types.FutureRetrieveRequest,
+        raw_request: Request,
         state: ServerState = Depends(_get_state),
         user: User = Depends(_get_user),
     ) -> Any:
@@ -573,7 +575,22 @@ async def retrieve_future(
                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                 detail=f"Failed to retrieve future: {str(exc)}",
             ) from exc
-        return payload  # FastAPI will serialize the stored Tinker type
+
+        # Content negotiation: prefer protobuf for SampleResponse if client accepts it
+        from tinker.types.sample_response import SampleResponse as SampleResponseDataclass
+
+        accept_header = raw_request.headers.get("accept", "")
+        if (
+            isinstance(payload, SampleResponseDataclass)
+            and "application/x-protobuf" in accept_header
+        ):
+            proto_bytes = serialize_sample_response_proto(payload)
+            return Response(
+                content=proto_bytes,
+                media_type="application/x-protobuf",
+            )
+
+        return maybe_serialize_payload(payload)

     @app.get(
         "/api/v1/training_runs",
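Client side, the negotiation added above keys off the Accept header. A minimal httpx sketch against a running TuFT server; the URL, credential, and request body fields are placeholders, not the server's actual schema:

    import httpx

    r = httpx.post(
        "http://localhost:8000/api/v1/retrieve_future",
        json={"future_id": "fut-123"},            # placeholder request body
        headers={
            "Authorization": "Bearer <api-key>",  # placeholder credential
            "Accept": "application/x-protobuf",   # opt in to the proto wire format
        },
    )
    if r.headers["content-type"] == "application/x-protobuf":
        raw = r.content  # parse with tinker_public_pb2.SampleResponse()
    else:
        data = r.json()  # JSON fallback produced by maybe_serialize_payload()
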
@@ -763,6 +780,46 @@ async def get_sampler(
     ) -> types.GetSamplerResponse:
         return state.get_sampler_info(sampler_id, user.user_id)

+    @app.post(
+        "/api/v1/auth/token",
+        response_model=types.AuthTokenResponse,
+    )
+    async def auth_token(
+        user: User = Depends(_get_user),
+    ) -> types.AuthTokenResponse:
+        # TuFT uses API key auth directly; return a pass-through token
+        # that the SDK can use for subsequent requests.
+        import base64
+        import json
+        import time as _time
+
+        header = base64.urlsafe_b64encode(json.dumps({"alg": "none"}).encode()).decode().rstrip("=")
+        payload_data = {
+            "sub": user.user_id,
+            "exp": int(_time.time()) + 3600,
+        }
+        payload_b64 = (
+            base64.urlsafe_b64encode(json.dumps(payload_data).encode()).decode().rstrip("=")
+        )
+        token = f"{header}.{payload_b64}."
+ return types.AuthTokenResponse(jwt=token) + + @app.post( + "/api/v1/client/config", + response_model=types.ClientConfigResponse, + ) + async def client_config( + request: types.ClientConfigRequest, + user: User = Depends(_get_user), + ) -> types.ClientConfigResponse: + return types.ClientConfigResponse( + pjwt_auth_enabled=False, + credential_default_source="api_key", + sample_dispatch_bytes_semaphore_size=10 * 1024 * 1024, + inflight_response_bytes_semaphore_size=50 * 1024 * 1024, + parallel_fwdbwd_chunks=False, + ) + for route in app.routes: path = getattr(route, "path", None) or "" # Skip healthz and OAI routes (OAI routes use their own auth via _get_user_oai) From ad380793da5a0a2e657de278a4ee5cc207002a24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=86=A0=E9=B8=A3?= Date: Thu, 7 May 2026 14:12:50 +0800 Subject: [PATCH 2/3] fix pre-commit --- pyproject.toml | 2 ++ src/tuft/backends/sampling_backend.py | 7 ++----- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bc014cc..698acbd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,7 @@ line-length = 100 target-version = "py311" exclude = [ "tinker", + "thirdparty", ] [tool.ruff.lint] @@ -102,6 +103,7 @@ exclude = [ ".pytest_cache", ".ruff_cache", "tinker", + "thirdparty", ] [tool.pytest.ini_options] diff --git a/src/tuft/backends/sampling_backend.py b/src/tuft/backends/sampling_backend.py index ee7d3e9..950f75f 100644 --- a/src/tuft/backends/sampling_backend.py +++ b/src/tuft/backends/sampling_backend.py @@ -88,8 +88,7 @@ def _build_sample_response( stop_reason="length" if seq_output.finish_reason == "length" else "stop", _tokens_list=seq_output.token_ids, _logprobs_list=[ - next(iter(logprob_dict.values())).logprob - for logprob_dict in seq_output.logprobs + next(iter(logprob_dict.values())).logprob for logprob_dict in seq_output.logprobs ], ) sequences.append(seq) @@ -401,9 +400,7 @@ async def sample( prompt_token_ids = prompt.to_ints() params = { "max_tokens": ( - sampling_params.max_tokens - if sampling_params.max_tokens is not None - else 16 + sampling_params.max_tokens if sampling_params.max_tokens is not None else 16 ), "seed": sampling_params.seed, "top_k": sampling_params.top_k, From 5cbc8f91548bcd1ab4f23754d2606d5807b00d6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=86=A0=E9=B8=A3?= Date: Thu, 7 May 2026 14:24:42 +0800 Subject: [PATCH 3/3] fix precommit pyright --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 698acbd..5d4ed89 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,6 +97,8 @@ pythonVersion = "3.11" typeCheckingMode = "standard" reportUnusedImport = false reportMissingImports = false +reportPrivateImportUsage = false +reportOptionalMemberAccess = false exclude = [ "**/.venv", "**/__pycache__",