From 1e00642f472efe1bff58e78cd1fc101c7819a091 Mon Sep 17 00:00:00 2001 From: xuanrui-L Date: Fri, 30 Jan 2026 11:46:16 +0800 Subject: [PATCH 1/5] fix #29: Add config validation on restart to prevent data corruption fix config example typos - Convert AppConfig, ModelConfig, TelemetryConfig from dataclass to Pydantic models for native JSON serialization support - Add ConfigSignature to store complete AppConfig snapshot in Redis - Validate config on startup (read-only), save signature after successful init - Add `check_fields` option in PersistenceConfig to configure which fields to validate (default: ["supported_models"]) - Add --refresh-persistence CLI flag to clear existing data and start fresh - Add --force-refresh-persistence to skip confirmation prompt If config mismatch is detected, server exits with clear error message suggesting to either restore original config or use --refresh-persistence. --- README.md | 44 +++- config/tuft_config.example.yaml | 22 +- scripts/install.sh | 2 +- src/tuft/cli.py | 113 ++++++++ src/tuft/config.py | 113 ++++---- src/tuft/exceptions.py | 56 ++++ src/tuft/persistence/__init__.py | 12 +- src/tuft/persistence/redis_store.py | 360 ++++++++++++++++++++++++-- src/tuft/server.py | 11 +- tests/conftest.py | 2 +- tests/test_persistence.py | 386 +++++++++++++++++++++++++++- 11 files changed, 1018 insertions(+), 103 deletions(-) diff --git a/README.md b/README.md index f3a9178..9cb9114 100644 --- a/README.md +++ b/README.md @@ -331,22 +331,22 @@ TuFT provides three persistence modes: | Mode | Description | Use Case | |------|-------------|----------| -| `disabled` | No persistence, data in-memory only | Development, testing without state recovery | -| `redis_url` | External Redis server | Production, multi-instance deployments | -| `file_redis` | File-backed store | Demos, small-scale testing | +| `DISABLE` | No persistence, data in-memory only | Development, testing without state recovery | +| `REDIS_URL` | External Redis server | 
Production, multi-instance deployments | +| `FILE_REDIS` | File-backed store | Demos, small-scale testing | ### Configuration Add a `persistence` section to your `tuft_config.yaml` configuration file and choose one of the following modes. -#### Mode 1: Disabled (Default) +#### Mode 1: DISABLE (Default) No configuration needed. All data is stored in memory and lost on restart. ```yaml # tuft_config.yaml persistence: - mode: disabled + mode: DISABLE ``` #### Mode 2: External Redis Server @@ -356,9 +356,9 @@ Use an external Redis server for production deployments: ```yaml # tuft_config.yaml persistence: - mode: redis_url + mode: REDIS_URL redis_url: "redis://localhost:6379/0" - namespace: "tuft" + namespace: "tuft" # Default: "tuft". ``` You can start a local Redis instance using Docker: @@ -374,11 +374,37 @@ Use the file-backed store for demos or small-scale testing: ```yaml # tuft_config.yaml persistence: - mode: file_redis + mode: FILE_REDIS file_path: "~/.cache/tuft/file_redis.json" - namespace: "tuft" + namespace: "tuft" # Default: "tuft" ``` +### Configuration Validation + +When persistence is enabled, TuFT validates the current configuration against the stored signature on restart. This prevents data corruption when configuration changes. By default, only `supported_models` is checked. + +You can configure which fields to validate: + +```yaml +persistence: + mode: REDIS_URL + redis_url: "redis://localhost:6379/0" + check_fields: # Default: ["SUPPORTED_MODELS"] + - SUPPORTED_MODELS # Always checked (mandatory) + - CHECKPOINT_DIR # Optional + - MODEL_OWNER # Optional +``` + +Available check fields: `SUPPORTED_MODELS`, `CHECKPOINT_DIR`, `MODEL_OWNER`, `TOY_BACKEND_SEED`, `AUTHORIZED_USERS`, `TELEMETRY`. + +If a mismatch is detected, use `--refresh-persistence` to clear existing data and start fresh: + +```bash +tuft --config config.yaml --refresh-persistence +``` + +Use `--force-refresh-persistence` to skip the confirmation prompt. 
+ ## Observability (OpenTelemetry) TuFT supports optional OpenTelemetry integration for distributed tracing, metrics, and logging. diff --git a/config/tuft_config.example.yaml b/config/tuft_config.example.yaml index 08f4784..aa0d4a0 100644 --- a/config/tuft_config.example.yaml +++ b/config/tuft_config.example.yaml @@ -81,22 +81,30 @@ authorized_users: # Configure state persistence for recovery after server restart. # # Available modes: -# - disabled: No persistence (default) -# - redis_url: External Redis server -# - file_redis: File-backed store +# - DISABLE: No persistence (default) +# - REDIS_URL: External Redis server +# - FILE_REDIS: File-backed store persistence: - mode: disabled # Options: disabled, redis_url, file_redis + mode: DISABLE # Options: DISABLE, REDIS_URL, FILE_REDIS - # For redis_url mode: + # For REDIS_URL mode: # redis_url: "redis://localhost:6379/0" - # For file_redis mode: + # For FILE_REDIS mode: # file_path: "~/.cache/tuft/file_redis.json" - # Namespace prefix for Redis keys (optional) + # Namespace prefix for Redis keys. (optional, defaults to "tuft".) # namespace: "tuft" + # Fields to validate on server restart for config consistency. + # Defaults to ["SUPPORTED_MODELS"]. SUPPORTED_MODELS is always checked. + # Available fields: SUPPORTED_MODELS, CHECKPOINT_DIR, MODEL_OWNER, + # TOY_BACKEND_SEED, AUTHORIZED_USERS, TELEMETRY. 
+ # check_fields: + # - SUPPORTED_MODELS + # - CHECKPOINT_DIR + # ============================================================================= # Telemetry Configuration (OpenTelemetry) # ============================================================================= diff --git a/scripts/install.sh b/scripts/install.sh index fede6a6..e6668e9 100755 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -409,7 +409,7 @@ authorized_users: # Optional: Persistence configuration # persistence: -# mode: disabled # Options: disabled, redis_url, file_redis +# mode: DISABLE # Options: DISABLE, REDIS_URL, FILE_REDIS # redis_url: "redis://localhost:6379/0" # namespace: "tuft" CONFIG_EOF diff --git a/src/tuft/cli.py b/src/tuft/cli.py index a65ecd6..3ca4e0f 100644 --- a/src/tuft/cli.py +++ b/src/tuft/cli.py @@ -4,12 +4,20 @@ import logging import os +import sys from pathlib import Path import typer import uvicorn from .config import AppConfig, load_yaml_config +from .exceptions import ConfigMismatchError +from .persistence import ( + flush_all_data, + get_current_namespace, + get_redis_store, + validate_config_signature, +) from .server import create_root_app from .telemetry import init_telemetry from .telemetry.metrics import ResourceMetricsCollector @@ -62,6 +70,21 @@ def _resolve_config_path(config_path: Path | None) -> Path: ) +_REFRESH_PERSISTENCE_OPTION = typer.Option( + False, + "--refresh-persistence", + help=( + "Clear all existing persistence data and start fresh. " + "Use when config has changed and you want to discard old data." + ), +) +_FORCE_REFRESH_PERSISTENCE_OPTION = typer.Option( + False, + "--force-refresh-persistence", + help="Skip confirmation prompts when using --refresh-persistence.", +) + + def _build_config( config_path: Path | None, checkpoint_dir: Path | None, @@ -79,6 +102,90 @@ def _build_config( return config +def _handle_refresh_persistence(force_refresh: bool) -> None: + """Handle the --refresh-persistence flag. 
+ + Prompts for confirmation unless --force-refresh-persistence is provided, + then clears all persistence data in the current namespace. + """ + namespace = get_current_namespace() + + if not force_refresh: + typer.secho( + "\n🚨🚨🚨 CRITICAL WARNING 🚨🚨🚨\n", + fg=typer.colors.RED, + bold=True, + ) + typer.secho( + "--refresh-persistence will PERMANENTLY DELETE ALL persistence data!\n", + fg=typer.colors.RED, + bold=True, + ) + typer.secho( + f"📦 Target namespace: '{namespace}'\n", + fg=typer.colors.YELLOW, + bold=True, + ) + typer.echo( + f"This IRREVERSIBLE action will destroy ALL data in namespace '{namespace}':\n" + " ❌ All saved sessions\n" + " ❌ All training run records and checkpoint metadata (NOT local checkpoint files)\n" + " ❌ All future records\n" + " ❌ All sampling session records\n" + " ❌ Configuration signature\n" + "\n" + "⚠️ The server will start fresh with NO previous state.\n" + "⚠️ This action CANNOT be undone!\n" + "⚠️ Local checkpoint files on disk are NOT affected.\n" + f"⚠️ Only data in namespace '{namespace}' will be affected.\n" + ) + confirmed = typer.confirm( + f"Do you REALLY want to delete all data in namespace '{namespace}'?", + default=False, + ) + if not confirmed: + typer.echo("Aborted. No data was cleared.") + raise typer.Exit(0) + + deleted_count, cleared_namespace = flush_all_data() + typer.secho( + f"✅ Cleared {deleted_count} keys from namespace '{cleared_namespace}'.", + fg=typer.colors.GREEN, + ) + typer.echo("Server will start with fresh state.\n") + + +def _validate_persistence_config( + config: AppConfig, refresh_persistence: bool, force_refresh_persistence: bool +) -> None: + """Validate that persistence config matches stored config. + + If refresh_persistence is True, clears existing data instead of validating. + If config mismatch is detected, exits with an error message. 
+ """ + if not config.persistence.enabled: + return + + # Configure the Redis store first + store = get_redis_store() + store.configure(config.persistence) + + if refresh_persistence: + _handle_refresh_persistence(force_refresh_persistence) + return + + try: + validate_config_signature(config) + except ConfigMismatchError as e: + typer.secho( + "\n🚫 FATAL ERROR: Configuration Mismatch Detected 🚫", + fg=typer.colors.RED, + bold=True, + ) + typer.echo(f"\n{e}\n") + sys.exit(1) + + def _init_telemetry(config: AppConfig, log_level: str) -> None: """Initialize OpenTelemetry if enabled.""" # Configure root logger level to ensure logs flow to OTel @@ -101,9 +208,15 @@ def launch( reload: bool = _RELOAD_OPTION, config_path: Path | None = _CONFIG_OPTION, checkpoint_dir: Path | None = _CHECKPOINT_DIR_OPTION, + refresh_persistence: bool = _REFRESH_PERSISTENCE_OPTION, + force_refresh_persistence: bool = _FORCE_REFRESH_PERSISTENCE_OPTION, ) -> None: """Launch the TuFT server.""" app_config = _build_config(config_path, checkpoint_dir) + + # Validate persistence configuration before starting + _validate_persistence_config(app_config, refresh_persistence, force_refresh_persistence) + # Initialize telemetry before starting the server _init_telemetry(app_config, log_level) logging.getLogger("tuft").info("Server starting on %s:%s", host, port) diff --git a/src/tuft/config.py b/src/tuft/config.py index c4cb4a7..a5f3360 100644 --- a/src/tuft/config.py +++ b/src/tuft/config.py @@ -2,9 +2,10 @@ from __future__ import annotations -from dataclasses import dataclass, field from pathlib import Path -from typing import Dict, Iterable, List +from typing import Any, Iterable + +from pydantic import BaseModel, Field, model_validator from .persistence import PersistenceConfig @@ -14,12 +15,7 @@ def _default_checkpoint_dir() -> Path | None: return None -def _default_persistence_config() -> PersistenceConfig: - return PersistenceConfig() - - -@dataclass -class TelemetryConfig: +class 
TelemetryConfig(BaseModel): """Configuration for OpenTelemetry integration. Attributes: @@ -32,26 +28,58 @@ class TelemetryConfig: enabled: bool = False service_name: str = "tuft" otlp_endpoint: str | None = None - resource_attributes: Dict[str, str] = field(default_factory=dict) + resource_attributes: dict[str, str] = Field(default_factory=dict) + + +class ModelConfig(BaseModel): + model_config = {"arbitrary_types_allowed": True} + + model_name: str # name used in APIs + model_path: Path # path to model checkpoint + max_model_len: int # maximum context length supported by the model + tensor_parallel_size: int = 1 # tensor parallel size + + # default sampling parameters for this model + temperature: float = 1.0 + top_p: float = 1.0 + top_k: int = -1 + logprobs: int = 0 + seed: int = 42 + min_response_tokens: int = 0 + + # default lora setting + max_lora_rank: int = 16 # maximum rank for LoRA adapters + max_loras: int = 1 # maximum number of LoRA adapters that can be applied simultaneously + + # whether to colocate sampling and training on the same device + # only for local testing purposes + colocate: bool = False + sampling_memory_fraction: float = 0.2 # fraction of GPU memory for sampling + + @model_validator(mode="after") + def validate_colocate(self) -> "ModelConfig": + if self.colocate and self.tensor_parallel_size != 1: + raise ValueError("Colocate option is only supported for tensor_parallel_size=1.") + return self -def _default_telemetry_config() -> TelemetryConfig: - return TelemetryConfig() +class AppConfig(BaseModel): + """Runtime configuration for the TuFT server. + This is a Pydantic model that can be serialized/deserialized for persistence. 
+ """ -@dataclass -class AppConfig: - """Runtime configuration for the TuFT server.""" + model_config = {"arbitrary_types_allowed": True} - checkpoint_dir: Path | None = field(default_factory=_default_checkpoint_dir) - supported_models: List[ModelConfig] = field(default_factory=list) + checkpoint_dir: Path | None = Field(default_factory=_default_checkpoint_dir) + supported_models: list[ModelConfig] = Field(default_factory=list) model_owner: str = "local-user" toy_backend_seed: int = 0 # TODO: Temporary implementation for user authorization, # replace with proper auth system later - authorized_users: Dict[str, str] = field(default_factory=dict) - persistence: PersistenceConfig = field(default_factory=_default_persistence_config) - telemetry: TelemetryConfig = field(default_factory=_default_telemetry_config) + authorized_users: dict[str, str] = Field(default_factory=dict) + persistence: PersistenceConfig = Field(default_factory=PersistenceConfig) + telemetry: TelemetryConfig = Field(default_factory=TelemetryConfig) def ensure_directories(self) -> None: if self.checkpoint_dir is not None: @@ -71,53 +99,24 @@ def check_validity(self) -> None: def with_supported_models(self, models: Iterable[ModelConfig]) -> "AppConfig": updated = list(models) if updated: - self.supported_models = updated + self.supported_models = list(updated) return self - -@dataclass -class ModelConfig: - """Configuration for a specific model.""" - - model_name: str # name used in APIs - model_path: Path # path to model checkpoint - max_model_len: int # maximum context length supported by the model - tensor_parallel_size: int = 1 # tensor parallel size - - # default sampling parameters for this model - temperature: float = 1.0 - top_p: float = 1.0 - top_k: int = -1 - logprobs: int = 0 - seed: int = 42 - min_response_tokens: int = 0 - - # default lora setting - max_lora_rank: int = 16 # maximum rank for LoRA adapters - max_loras: int = 1 # maximum number of LoRA adapters that can be applied 
simultaneously - - # whether to colocate sampling and training on the same device - # only for local testing purposes - colocate: bool = False - sampling_memory_fraction: float = 0.2 # fraction of GPU memory for sampling - - def __post_init__(self) -> None: - if self.colocate and self.tensor_parallel_size != 1: - raise ValueError("Colocate option is only supported for tensor_parallel_size=1.") + def get_config_for_persistence(self) -> dict[str, Any]: + """Get config fields for persistence signature (excludes persistence config itself).""" + return self.model_dump(mode="json", exclude={"persistence"}) def load_yaml_config(config_path: Path) -> AppConfig: """Loads an AppConfig from a YAML file.""" from omegaconf import OmegaConf - schema = OmegaConf.structured(AppConfig) loaded = OmegaConf.load(config_path) try: - config = OmegaConf.merge(schema, loaded) - app_config = OmegaConf.to_object(config) - assert isinstance(app_config, AppConfig), ( - "Loaded config is not of type AppConfig, which should not happen." - ) - return app_config + # Convert OmegaConf to plain dict for Pydantic + config_dict = OmegaConf.to_container(loaded, resolve=True) + if not isinstance(config_dict, dict): + raise ValueError("Config file must contain a dictionary at root level") + return AppConfig.model_validate(config_dict) except Exception as e: raise ValueError(f"Failed to load config from {config_path}: {e}") from e diff --git a/src/tuft/exceptions.py b/src/tuft/exceptions.py index 713ca6b..18981a6 100644 --- a/src/tuft/exceptions.py +++ b/src/tuft/exceptions.py @@ -1,5 +1,7 @@ """Some custom exceptions.""" +from typing import Any + class TuFTException(Exception): """Base exception for TuFT errors.""" @@ -136,3 +138,57 @@ def __init__(self, shapes: list): detail = f"Input tensors must have the same shape. 
Got shapes: {shapes}" super().__init__(detail) self.shapes = shapes + + +class PersistenceException(TuFTException): + """Base exception for Persistence related errors.""" + + +class ConfigMismatchError(PersistenceException): + """Raised when current config doesn't match the stored config in Redis. + + This error occurs during server startup when persistence is enabled and + the configuration has changed since the last run. This can cause data + corruption when restoring persisted state. + """ + + def __init__( + self, + diff: dict[str, dict[str, Any]], + ): + self.diff = diff + + # Build detailed diff message + diff_parts = [] + for field_name, field_diff in diff.items(): + # Handle list fields (added/removed) + added = field_diff.get("added") + removed = field_diff.get("removed") + # Handle scalar fields (current/stored) + current = field_diff.get("current") + stored = field_diff.get("stored") + + parts = [] + if added is not None: + parts.append(f"added: {added}") + if removed is not None: + parts.append(f"removed: {removed}") + if current is not None or stored is not None: + parts.append(f"current: {current}, stored: {stored}") + + if parts: + diff_parts.append(f"{field_name} ({', '.join(parts)})") + + diff_str = "; ".join(diff_parts) if diff_parts else "unknown difference" + + message = ( + f"Configuration mismatch detected: {diff_str}.\n" + "The current configuration does not match the stored configuration in Redis.\n" + "This can cause data corruption when restoring persisted state.\n\n" + "Options:\n" + " 1. Use a different Redis database (change redis_url in config)\n" + " 2. Use --refresh-persistence to clear existing data and start fresh\n" + " (WARNING: This will delete all persisted sessions, training runs, etc.)\n" + " 3. 
Restore the original configuration that matches the stored data" + ) + super().__init__(message) diff --git a/src/tuft/persistence/__init__.py b/src/tuft/persistence/__init__.py index 931cf60..445475c 100644 --- a/src/tuft/persistence/__init__.py +++ b/src/tuft/persistence/__init__.py @@ -3,30 +3,38 @@ from __future__ import annotations from .redis_store import ( - DEFAULT_FUTURE_TTL_SECONDS, + ConfigCheckField, PersistenceConfig, PersistenceMode, RedisPipeline, RedisStore, delete_record, + flush_all_data, + get_current_namespace, get_redis_store, is_persistence_enabled, load_record, + save_config_signature, save_record, save_records_atomic, + validate_config_signature, ) __all__ = [ - "DEFAULT_FUTURE_TTL_SECONDS", + "ConfigCheckField", "PersistenceConfig", "PersistenceMode", "RedisPipeline", "RedisStore", "delete_record", + "flush_all_data", + "get_current_namespace", "get_redis_store", "is_persistence_enabled", "load_record", + "save_config_signature", "save_record", "save_records_atomic", + "validate_config_signature", ] diff --git a/src/tuft/persistence/redis_store.py b/src/tuft/persistence/redis_store.py index 654b79e..ac05f85 100644 --- a/src/tuft/persistence/redis_store.py +++ b/src/tuft/persistence/redis_store.py @@ -8,9 +8,14 @@ - Nested records: {namespace}::{type}::{parent_id}::{nested_type}::{nested_id} Persistence Modes: -- disabled: No persistence, all data is in-memory only -- redis_url: Use external Redis server via URL -- file_redis: Use file-backed storage for tests and demos +- DISABLE: No persistence, all data is in-memory only +- REDIS_URL: Use external Redis server via URL +- FILE_REDIS: Use file-backed storage for tests and demos + +Config Validation: +- On startup, the current config signature is compared with the stored signature +- If mismatch is detected, server stops with an error message +- Use --refresh-persistence to override and clear existing data """ from __future__ import annotations @@ -19,12 +24,12 @@ import os import threading 
import time -from dataclasses import dataclass +from datetime import datetime, timezone from enum import Enum from pathlib import Path from typing import Any, TypeVar -from pydantic import BaseModel +from pydantic import BaseModel, Field logger = logging.getLogger(__name__) @@ -50,7 +55,7 @@ def _get_metrics(): class PersistenceMode(str, Enum): """Persistence mode options.""" - DISABLED = "disabled" # No persistence + DISABLE = "disabled" # No persistence REDIS_URL = "redis_url" # Use external Redis server FILE_REDIS = "file_redis" # Use file-backed storage for tests/demos @@ -59,33 +64,62 @@ class PersistenceMode(str, Enum): DEFAULT_FUTURE_TTL_SECONDS = 24 * 3600 # 1 day for future records (short-lived) -@dataclass -class PersistenceConfig: +class ConfigCheckField: + """Available fields that can be checked for configuration validation. + + Field names correspond directly to AppConfig attribute names. + SUPPORTED_MODELS is always required (mandatory) for restore safety. + """ + + SUPPORTED_MODELS = "supported_models" + CHECKPOINT_DIR = "checkpoint_dir" + MODEL_OWNER = "model_owner" + TOY_BACKEND_SEED = "toy_backend_seed" + AUTHORIZED_USERS = "authorized_users" + TELEMETRY = "telemetry" + + +# Default fields to check (supported_models is mandatory) +DEFAULT_CHECK_FIELDS: list[str] = [ConfigCheckField.SUPPORTED_MODELS] + + +class PersistenceConfig(BaseModel): """Configuration for Redis persistence. Attributes: - mode: Persistence mode - disabled, redis_url, or file_redis - redis_url: Redis server URL (only used when mode=redis_url) - file_path: JSON file path (only used when mode=file_redis) - namespace: Key namespace prefix + mode: Persistence mode - DISABLE, REDIS_URL, or FILE_REDIS + redis_url: Redis server URL (only used when mode=REDIS_URL) + file_path: JSON file path (only used when mode=FILE_REDIS) + namespace: Key namespace prefix for Redis keys. Defaults to "tuft". future_ttl_seconds: TTL for future records in seconds. None means no expiry. 
+ check_fields: List of AppConfig fields to validate on restart. + Defaults to ["SUPPORTED_MODELS"]. SUPPORTED_MODELS is always + checked regardless of this setting for restore safety. + Available fields: SUPPORTED_MODELS, CHECKPOINT_DIR, MODEL_OWNER, + TOY_BACKEND_SEED, AUTHORIZED_USERS, TELEMETRY. """ - mode: PersistenceMode = PersistenceMode.DISABLED + # Allow Path type + model_config = {"arbitrary_types_allowed": True} + + mode: PersistenceMode = PersistenceMode.DISABLE redis_url: str = "redis://localhost:6379/0" file_path: Path | None = None - namespace: str = "tuft" - future_ttl_seconds: int | None = DEFAULT_FUTURE_TTL_SECONDS # Futures expire after 1 day + namespace: str = "tuft" # Default namespace for Redis keys + future_ttl_seconds: int | None = DEFAULT_FUTURE_TTL_SECONDS + check_fields: list[str] = Field(default_factory=lambda: DEFAULT_CHECK_FIELDS.copy()) @property def enabled(self) -> bool: """Check if persistence is enabled.""" - return self.mode != PersistenceMode.DISABLED + return self.mode != PersistenceMode.DISABLE - @classmethod - def disabled(cls, namespace: str = "tuft") -> "PersistenceConfig": - """Create a disabled persistence config.""" - return cls(mode=PersistenceMode.DISABLED, namespace=namespace) + def get_check_fields(self) -> list[str]: + """Get the fields to check, ensuring SUPPORTED_MODELS is always included.""" + fields = list(self.check_fields) + if ConfigCheckField.SUPPORTED_MODELS not in fields: + fields.insert(0, ConfigCheckField.SUPPORTED_MODELS) + return fields @classmethod def from_redis_url( @@ -93,6 +127,7 @@ def from_redis_url( redis_url: str, namespace: str = "tuft", future_ttl_seconds: int | None = DEFAULT_FUTURE_TTL_SECONDS, + check_fields: list[str] | None = None, ) -> "PersistenceConfig": """Create a config using external Redis server.""" return cls( @@ -100,6 +135,7 @@ def from_redis_url( redis_url=redis_url, namespace=namespace, future_ttl_seconds=future_ttl_seconds, + check_fields=check_fields or 
DEFAULT_CHECK_FIELDS.copy(), ) @classmethod @@ -108,6 +144,7 @@ def from_file_redis( file_path: Path | None = None, namespace: str = "tuft", future_ttl_seconds: int | None = DEFAULT_FUTURE_TTL_SECONDS, + check_fields: list[str] | None = None, ) -> "PersistenceConfig": """Create a config using file-backed storage.""" return cls( @@ -115,6 +152,7 @@ def from_file_redis( file_path=file_path, namespace=namespace, future_ttl_seconds=future_ttl_seconds, + check_fields=check_fields or DEFAULT_CHECK_FIELDS.copy(), ) @@ -123,7 +161,7 @@ class RedisStore: Supports two modes: - External Redis server (via redis-py) - - No persistence (disabled mode) + - No persistence (DISABLE mode) """ _instance: "RedisStore | None" = None @@ -486,3 +524,283 @@ def is_persistence_enabled() -> bool: def get_redis_store() -> RedisStore: """Get the global Redis store instance.""" return RedisStore.get_instance() + + +def _now() -> datetime: + return datetime.now(timezone.utc) + + +class ConfigSignature(BaseModel): + """Stores a complete snapshot of AppConfig for validation on restart. + + Since AppConfig is now a Pydantic model, we directly store its serialized + form (excluding the persistence field which is runtime-only). 
+ """ + + # Serialized AppConfig data (excludes persistence) + config_data: dict[str, Any] = Field(default_factory=dict) + + # Metadata + created_at: datetime = Field(default_factory=_now) + namespace: str = "tuft" + + @classmethod + def from_app_config(cls, config: Any) -> "ConfigSignature": + """Create a signature by serializing the AppConfig.""" + # Use the method on AppConfig to get persistence-safe data + config_data = config.get_config_for_persistence() + namespace = config.persistence.namespace if config.persistence else "tuft" + return cls(config_data=config_data, namespace=namespace) + + def _get_field_value(self, field_name: str) -> Any: + """Get the value of a field by name.""" + return self.config_data.get(field_name) + + def _normalize_for_comparison(self, value: Any) -> Any: + if isinstance(value, list): + normalized_items = [] + for item in value: + if isinstance(item, dict): + normalized_items.append(tuple(sorted(item.items()))) + else: + normalized_items.append(item) + # Sort for order-independent comparison + return sorted(normalized_items, key=lambda x: str(x)) + return value + + def _compare_field(self, other: "ConfigSignature", field_name: str) -> bool: + """Compare a single field between two signatures.""" + current_value = self._get_field_value(field_name) + other_value = other._get_field_value(field_name) + current_normalized = self._normalize_for_comparison(current_value) + other_normalized = self._normalize_for_comparison(other_value) + + return current_normalized == other_normalized + + def _get_field_diff(self, other: "ConfigSignature", field_name: str) -> dict[str, Any] | None: + """Get the difference for a single field. + + Returns: + {"current": value, "stored": value} if different, None otherwise. 
+ """ + current_value = self._get_field_value(field_name) + other_value = other._get_field_value(field_name) + + current_normalized = self._normalize_for_comparison(current_value) + other_normalized = self._normalize_for_comparison(other_value) + + if current_normalized != other_normalized: + return {"current": current_value, "stored": other_value} + return None + + def matches( + self, + other: "ConfigSignature", + check_fields: list[str] | None = None, + ) -> bool: + """Check if this signature matches another signature. + + Args: + other: The other signature to compare against. + check_fields: List of field names to check. If None, uses DEFAULT_CHECK_FIELDS. + SUPPORTED_MODELS is always included (mandatory). + + Returns: + True if all specified fields match, False otherwise. + """ + fields_to_check = self._get_fields_to_check(check_fields) + + for field_name in fields_to_check: + if not self._compare_field(other, field_name): + return False + return True + + def get_diff( + self, + other: "ConfigSignature", + check_fields: list[str] | None = None, + ) -> dict[str, dict[str, Any]]: + """Get the differences between this signature and another. + + Args: + other: The other signature to compare against. + check_fields: List of field names to check. If None, uses DEFAULT_CHECK_FIELDS. + SUPPORTED_MODELS is always included (mandatory). + + Returns: + Dict mapping field names to their differences. 
+ """ + fields_to_check = self._get_fields_to_check(check_fields) + diff: dict[str, dict[str, Any]] = {} + + for field_name in fields_to_check: + field_diff = self._get_field_diff(other, field_name) + if field_diff is not None: + diff[field_name] = field_diff + + return diff + + def _get_fields_to_check(self, check_fields: list[str] | None) -> list[str]: + """Get the list of fields to check, ensuring mandatory fields are included.""" + if check_fields is None: + return DEFAULT_CHECK_FIELDS.copy() + + # Ensure SUPPORTED_MODELS is always included (mandatory) + fields = list(check_fields) + if ConfigCheckField.SUPPORTED_MODELS not in fields: + fields.insert(0, ConfigCheckField.SUPPORTED_MODELS) + return fields + + +CONFIG_SIGNATURE_KEY = "config_signature" + + +def save_config_signature(config: Any) -> bool: + """Save the config signature to Redis. + + Args: + config: The AppConfig to create a signature from. + + Returns: + True if saved successfully, False otherwise. + """ + store = RedisStore.get_instance() + if not store.is_enabled: + return False + + signature = ConfigSignature.from_app_config(config) + key = store.build_key(CONFIG_SIGNATURE_KEY) + + try: + json_str = signature.model_dump_json() + return store.set(key, json_str) + except Exception: + logger.exception("Failed to save config signature to Redis") + return False + + +def load_config_signature() -> ConfigSignature | None: + """Load the config signature from Redis. + + Returns: + The stored ConfigSignature, or None if not found. + """ + store = RedisStore.get_instance() + if not store.is_enabled: + return None + + key = store.build_key(CONFIG_SIGNATURE_KEY) + + try: + json_str = store.get(key) + if json_str is None: + return None + return ConfigSignature.model_validate_json(json_str) + except Exception: + logger.exception("Failed to load config signature from Redis") + return None + + +def has_existing_data() -> bool: + """Check if there is any existing data in the current namespace. 
+ + Returns: + True if any keys exist in the namespace, False otherwise. + """ + store = RedisStore.get_instance() + if not store.is_enabled: + return False + + pattern = f"{store.namespace}::*" + keys = store.keys(pattern) + return len(keys) > 0 + + +def validate_config_signature(config: Any) -> bool: + """Validate that the current config matches the stored config signature. + + This function ONLY reads from Redis, it does NOT write. + The signature should be saved after successful restore using + save_config_signature(). + + The fields to check are read from config.persistence.check_fields. + SUPPORTED_MODELS is always checked regardless of this setting. + + This function handles several cases: + 1. No signature AND no other data in namespace -> fresh start (return True) + 2. No signature BUT other data exists -> corrupted/incompatible state, raise error + 3. Signature exists and matches -> OK (return False, not fresh) + 4. Signature exists but doesn't match -> raise error + + Args: + config: The current AppConfig to validate. + + Returns: + True if this is a fresh start (no existing data), False otherwise. + + Raises: + ConfigMismatchError: If the configs don't match or state is corrupted. + """ + from tuft.exceptions import ConfigMismatchError + + stored = load_config_signature() + + if stored is None: + # Check if there's any other data in the namespace + if has_existing_data(): + # Data exists but no signature -> corrupted/incompatible state + logger.warning( + "Redis namespace has data but no config signature. " + "This indicates a corrupted or incompatible persistence state." 
+ ) + raise ConfigMismatchError( + diff={ + "_state": { + "current": "valid configuration", + "stored": "missing signature (corrupted or legacy data)", + } + } + ) + else: + # No data at all -> fresh start + logger.info("No stored config signature found - fresh start") + return True + + # Get check_fields from persistence config + check_fields = config.persistence.get_check_fields() if config.persistence else None + + current = ConfigSignature.from_app_config(config) + if not current.matches(stored, check_fields=check_fields): + diff = current.get_diff(stored, check_fields=check_fields) + raise ConfigMismatchError(diff) + + logger.debug("Config signature validated successfully") + return False + + +def get_current_namespace() -> str: + """Get the current Redis namespace. + + Returns: + The namespace string, or 'tuft' if not configured. + """ + store = RedisStore.get_instance() + return store.namespace + + +def flush_all_data() -> tuple[int, str]: + """Clear all data from the current Redis namespace. + + This removes all keys with the current namespace prefix. + Use with caution - this is destructive! + + Returns: + A tuple of (number of keys deleted, namespace that was cleared). 
+ """ + store = RedisStore.get_instance() + if not store.is_enabled: + return 0, store.namespace + + pattern = f"{store.namespace}::*" + deleted_count = store.delete_pattern(pattern) + return deleted_count, store.namespace diff --git a/src/tuft/server.py b/src/tuft/server.py index 24adc11..d393fc2 100644 --- a/src/tuft/server.py +++ b/src/tuft/server.py @@ -18,7 +18,7 @@ from .auth import User from .config import AppConfig from .exceptions import TuFTException -from .persistence import get_redis_store +from .persistence import get_redis_store, save_config_signature from .state import ServerState from .telemetry import shutdown_telemetry @@ -76,11 +76,19 @@ def _instrument_fastapi(app: FastAPI) -> None: def create_root_app(config: AppConfig | None = None) -> FastAPI: + resolved_config = config or AppConfig() + @asynccontextmanager async def lifespan(app: FastAPI): try: await app.state.server_state.async_init() logger.info("Server initialized successfully") + + # After successful init/restore, save the current config signature + if resolved_config.persistence.enabled: + save_config_signature(resolved_config) + logger.debug("Config signature saved after successful initialization") + yield finally: logger.info("Server shutting down") @@ -95,7 +103,6 @@ def require_user_dependency(route): route.dependencies = getattr(route, "dependencies", []) + [Depends(_get_user)] return route - resolved_config = config or AppConfig() if resolved_config.persistence.enabled: store = get_redis_store() store.configure(resolved_config.persistence) diff --git a/tests/conftest.py b/tests/conftest.py index 5bbac06..ae75ab0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -92,7 +92,7 @@ def configure_persistence(request): no_persistence = request.config.getoption("--no-persistence", default=False) if no_persistence: - store.configure(PersistenceConfig.disabled(namespace="tuft_test")) + store.configure(PersistenceConfig(namespace="tuft_test")) # mode=disabled by default else: # 
Persistence enabled - use Redis if available, otherwise FileRedis if _redis_available() and TEST_REDIS_URL is not None: diff --git a/tests/test_persistence.py b/tests/test_persistence.py index ce55aff..547a17a 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -11,10 +11,23 @@ from tinker.types.try_again_response import TryAgainResponse from tuft.auth import User -from tuft.config import AppConfig, ModelConfig -from tuft.exceptions import UnknownModelException +from tuft.config import AppConfig, ModelConfig, TelemetryConfig +from tuft.exceptions import ConfigMismatchError, UnknownModelException from tuft.futures import FutureStore -from tuft.persistence import get_redis_store, is_persistence_enabled +from tuft.persistence import ( + ConfigCheckField, + flush_all_data, + get_redis_store, + is_persistence_enabled, + save_config_signature, + validate_config_signature, +) +from tuft.persistence.redis_store import ( + DEFAULT_CHECK_FIELDS, + ConfigSignature, + has_existing_data, + load_config_signature, +) from tuft.sampling_controller import SamplingController, SamplingSessionRecord from tuft.state import ServerState, SessionManager from tuft.training_controller import TrainingController, TrainingRunRecord @@ -725,3 +738,370 @@ async def op(): assert "loss:sum" in new_forward.metrics await state3.future_store.shutdown() + + +# ============================================================================= +# Config Signature Validation Tests +# ============================================================================= + + +def _create_config_with_models( + checkpoint_dir: Path, + model_names: list[str], + telemetry_enabled: bool = False, + check_fields: list[str] | None = None, +) -> AppConfig: + """Create a test config with specified model names and optional telemetry.""" + from tuft.persistence import PersistenceConfig, PersistenceMode + from tuft.persistence.redis_store import DEFAULT_CHECK_FIELDS + + return AppConfig( + 
checkpoint_dir=checkpoint_dir, + supported_models=[ + ModelConfig( + model_name=name, + model_path=Path(f"/dummy/{name}"), + max_model_len=2048, + ) + for name in model_names + ], + telemetry=TelemetryConfig(enabled=telemetry_enabled), + persistence=PersistenceConfig( + mode=PersistenceMode.FILE_REDIS, + check_fields=check_fields if check_fields is not None else DEFAULT_CHECK_FIELDS.copy(), + ), + ) + + +@pytest.mark.persistence +class TestConfigSignatureValidation: + """Test configuration signature validation for persistence safety.""" + + @pytest.fixture + def setup(self, tmp_path): + """Setup file-based persistence for testing.""" + _skip_if_no_persistence() + store = get_redis_store() + yield store, tmp_path + + # Cleanup after test + flush_all_data() + + def test_config_signature_creation(self, setup): + """Test that ConfigSignature is created correctly from AppConfig.""" + _, tmp_path = setup + checkpoint_dir = tmp_path / "checkpoints" + checkpoint_dir.mkdir() + + config = _create_config_with_models(checkpoint_dir, ["ModelA", "ModelB", "ModelC"]) + + signature = ConfigSignature.from_app_config(config) + + models = signature.config_data.get("supported_models", []) + model_names = sorted([m["model_name"] for m in models]) + assert model_names == ["ModelA", "ModelB", "ModelC"] + assert signature.created_at is not None + + # Verify other fields are also stored + assert "checkpoint_dir" in signature.config_data + assert "model_owner" in signature.config_data + assert "telemetry" in signature.config_data + + def test_save_and_load_config_signature(self, setup): + """Test that config signature can be saved and loaded from Redis.""" + _, tmp_path = setup + checkpoint_dir = tmp_path / "checkpoints" + checkpoint_dir.mkdir() + + config = _create_config_with_models(checkpoint_dir, ["ModelA", "ModelB"]) + + # Save signature + result = save_config_signature(config) + assert result is True + + # Load signature + loaded = load_config_signature() + assert loaded is not None + 
models = loaded.config_data.get("supported_models", []) + model_names = sorted([m["model_name"] for m in models]) + assert model_names == ["ModelA", "ModelB"] + + def test_validate_config_signature_matching(self, setup): + """Test that validation succeeds when configs match.""" + _, tmp_path = setup + checkpoint_dir = tmp_path / "checkpoints" + checkpoint_dir.mkdir() + + config = _create_config_with_models(checkpoint_dir, ["ModelA", "ModelB"]) + + save_config_signature(config) + + is_fresh = validate_config_signature(config) + assert is_fresh is False + + def test_validate_config_signature_mismatch_raises(self, setup): + """Test that validation raises ConfigMismatchError on mismatch.""" + _, tmp_path = setup + checkpoint_dir = tmp_path / "checkpoints" + checkpoint_dir.mkdir() + + config1 = _create_config_with_models(checkpoint_dir, ["ModelA", "ModelB"]) + config2 = _create_config_with_models(checkpoint_dir, ["ModelC"]) + + save_config_signature(config1) + + with pytest.raises(ConfigMismatchError) as exc_info: + validate_config_signature(config2) + + error = exc_info.value + assert "supported_models" in str(error).lower() or "mismatch" in str(error).lower() + + model_diff = error.diff.get(ConfigCheckField.SUPPORTED_MODELS) + assert model_diff is not None + current_names = [m["model_name"] for m in model_diff["current"]] + stored_names = [m["model_name"] for m in model_diff["stored"]] + assert "ModelC" in current_names + assert "ModelA" in stored_names + assert "ModelB" in stored_names + + def test_flush_all_data_clears_signature(self, setup): + """Test that flush_all_data clears the config signature.""" + _, tmp_path = setup + checkpoint_dir = tmp_path / "checkpoints" + checkpoint_dir.mkdir() + + config = _create_config_with_models(checkpoint_dir, ["ModelA"]) + + save_config_signature(config) + assert load_config_signature() is not None + + flush_all_data() + + assert load_config_signature() is None + + def test_config_mismatch_prevents_silent_corruption(self, 
setup): + """Test the full scenario: config mismatch is detected before corruption. + + This test verifies the fix for the original issue: + 1. Create training run with ModelA + 2. Try to restart with config that only has ModelC + 3. ConfigMismatchError should be raised BEFORE any corruption happens + """ + store, tmp_path = setup + checkpoint_dir = tmp_path / "checkpoints" + checkpoint_dir.mkdir() + + # === Phase 1: Create training run with ModelA === + config1 = _create_config_with_models(checkpoint_dir, ["ModelA", "ModelB"]) + config1.ensure_directories() + + # Validate (first run - returns True for fresh start) + is_fresh = validate_config_signature(config1) + assert is_fresh is True + + controller1 = TrainingController(config1) + + # Create a training run + training_run_id = "test-run-validation" + record = TrainingRunRecord( + training_run_id=training_run_id, + base_model="ModelA", + lora_rank=8, + session_id="session-001", + model_owner="user1", + ) + record.backend = controller1.training_backends.get("ModelA") + controller1.training_runs[training_run_id] = record + controller1._save_training_run(training_run_id) + + # Simulate successful init - save signature + save_config_signature(config1) + + del controller1 + + # === Phase 2: Try to restart with different config === + config2 = _create_config_with_models(checkpoint_dir, ["ModelC"]) + config2.ensure_directories() + + # This should raise ConfigMismatchError BEFORE TrainingController is created + with pytest.raises(ConfigMismatchError): + validate_config_signature(config2) + + def test_refresh_persistence_allows_restart_with_new_config(self, setup): + """Test that after flush_all_data, a new config can be used.""" + _, tmp_path = setup + checkpoint_dir = tmp_path / "checkpoints" + checkpoint_dir.mkdir() + + config1 = _create_config_with_models(checkpoint_dir, ["ModelA"]) + config2 = _create_config_with_models(checkpoint_dir, ["ModelC"]) + + save_config_signature(config1) + + with 
pytest.raises(ConfigMismatchError): + validate_config_signature(config2) + + flush_all_data() + + is_fresh = validate_config_signature(config2) + assert is_fresh is True + + save_config_signature(config2) + + loaded = load_config_signature() + assert loaded is not None + models = loaded.config_data.get("supported_models", []) + model_names = [m["model_name"] for m in models] + assert model_names == ["ModelC"] + + def test_matches_with_custom_check_fields(self, setup): + """Test matches() with custom check_fields parameter.""" + _, tmp_path = setup + checkpoint_dir = tmp_path / "checkpoints" + checkpoint_dir.mkdir() + + # Same models, different telemetry + config1 = _create_config_with_models(checkpoint_dir, ["ModelA"], telemetry_enabled=False) + config2 = _create_config_with_models(checkpoint_dir, ["ModelA"], telemetry_enabled=True) + + sig1 = ConfigSignature.from_app_config(config1) + sig2 = ConfigSignature.from_app_config(config2) + + # With default check_fields (only model names), they should match + assert sig1.matches(sig2, check_fields=DEFAULT_CHECK_FIELDS) + + # With telemetry check included, they should NOT match + check_with_telemetry = [ + ConfigCheckField.SUPPORTED_MODELS, + ConfigCheckField.TELEMETRY, + ] + assert not sig1.matches(sig2, check_fields=check_with_telemetry) + + def test_get_diff_with_custom_check_fields(self, setup): + """Test get_diff() with custom check_fields parameter.""" + _, tmp_path = setup + checkpoint_dir = tmp_path / "checkpoints" + checkpoint_dir.mkdir() + + # Different models AND different telemetry + config1 = _create_config_with_models(checkpoint_dir, ["ModelA"], telemetry_enabled=False) + config2 = _create_config_with_models(checkpoint_dir, ["ModelB"], telemetry_enabled=True) + + sig1 = ConfigSignature.from_app_config(config1) + sig2 = ConfigSignature.from_app_config(config2) + + # With default check_fields, only model diff is returned + diff_default = sig1.get_diff(sig2, check_fields=DEFAULT_CHECK_FIELDS) + assert 
ConfigCheckField.SUPPORTED_MODELS in diff_default + assert ConfigCheckField.TELEMETRY not in diff_default + + # With telemetry check included, both diffs are returned + check_with_telemetry = [ + ConfigCheckField.SUPPORTED_MODELS, + ConfigCheckField.TELEMETRY, + ] + diff_full = sig1.get_diff(sig2, check_fields=check_with_telemetry) + assert ConfigCheckField.SUPPORTED_MODELS in diff_full + assert ConfigCheckField.TELEMETRY in diff_full + + # Check telemetry diff values (telemetry is a dict with 'enabled' field) + telemetry_diff = diff_full[ConfigCheckField.TELEMETRY] + assert telemetry_diff["current"]["enabled"] is False # current value + assert telemetry_diff["stored"]["enabled"] is True # stored value + + def test_model_names_always_checked(self, setup): + """Test that SUPPORTED_MODELS is always checked even if not in check_fields.""" + _, tmp_path = setup + checkpoint_dir = tmp_path / "checkpoints" + checkpoint_dir.mkdir() + + config1 = _create_config_with_models(checkpoint_dir, ["ModelA"]) + config2 = _create_config_with_models(checkpoint_dir, ["ModelB"]) + + sig1 = ConfigSignature.from_app_config(config1) + sig2 = ConfigSignature.from_app_config(config2) + + # Even with only telemetry in check_fields, model names should be checked + # because it's mandatory + check_only_telemetry = [ConfigCheckField.TELEMETRY] + assert not sig1.matches(sig2, check_fields=check_only_telemetry) + + diff = sig1.get_diff(sig2, check_fields=check_only_telemetry) + # Model names should still be in the diff + assert ConfigCheckField.SUPPORTED_MODELS in diff + + def test_validate_with_custom_check_fields(self, setup): + """Test validate_config_signature with custom check_fields from persistence config.""" + _, tmp_path = setup + checkpoint_dir = tmp_path / "checkpoints" + checkpoint_dir.mkdir() + + # Ensure clean state + flush_all_data() + + # First run with telemetry disabled - save signature after successful init + config1 = _create_config_with_models(checkpoint_dir, ["ModelA"], 
telemetry_enabled=False) + save_config_signature(config1) + + # Second run with same model but telemetry enabled (default check_fields) + config2_default = _create_config_with_models( + checkpoint_dir, ["ModelA"], telemetry_enabled=True + ) + + # With default check_fields, should pass (only model names checked) + validate_config_signature(config2_default) + + # With telemetry check in persistence config, should fail + check_with_telemetry = [ + ConfigCheckField.SUPPORTED_MODELS, + ConfigCheckField.TELEMETRY, + ] + config2_with_telemetry_check = _create_config_with_models( + checkpoint_dir, ["ModelA"], telemetry_enabled=True, check_fields=check_with_telemetry + ) + with pytest.raises(ConfigMismatchError): + validate_config_signature(config2_with_telemetry_check) + + def test_has_existing_data_detects_data(self, setup): + """Test that has_existing_data correctly detects existing data.""" + store, tmp_path = setup + checkpoint_dir = tmp_path / "checkpoints" + checkpoint_dir.mkdir() + + # Ensure clean state + flush_all_data() + + # Should be empty initially + assert has_existing_data() is False + + # Save some data + config = _create_config_with_models(checkpoint_dir, ["ModelA"]) + save_config_signature(config) + + # Now should have data + assert has_existing_data() is True + + # Flush and check again + flush_all_data() + assert has_existing_data() is False + + def test_validate_detects_corrupted_state(self, setup): + """Test that validation detects when data exists but no signature.""" + store, tmp_path = setup + checkpoint_dir = tmp_path / "checkpoints" + checkpoint_dir.mkdir() + + # Ensure clean state + flush_all_data() + + # Manually add some data without signature (simulating corrupted state) + store.set(store.build_key("session", "test-session"), '{"test": "data"}') + + # Now validation should fail with corrupted state error + config = _create_config_with_models(checkpoint_dir, ["ModelA"]) + with pytest.raises(ConfigMismatchError) as exc_info: + 
validate_config_signature(config) + + error = exc_info.value + assert "_state" in error.diff + assert "missing signature" in str(error).lower() or "corrupted" in str(error).lower() From 19505b34641cc9fd08c803d1842df22f50642557 Mon Sep 17 00:00:00 2001 From: xuanrui-L Date: Fri, 30 Jan 2026 11:50:01 +0800 Subject: [PATCH 2/5] change launch tips from 'tuft' to 'tuft launch' --- README.md | 6 +++--- config/tuft_config.example.yaml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 9cb9114..8d15c56 100644 --- a/README.md +++ b/README.md @@ -238,7 +238,7 @@ uv pip install "tuft[dev,backend,persistence]" The CLI starts a FastAPI server: ```bash -tuft --port 10610 --config /path/to/tuft_config.yaml +tuft launch --port 10610 --config /path/to/tuft_config.yaml ``` The config file `tuft_config.yaml` specifies server settings including available base models, authentication, persistence, and telemetry. Below is a minimal example. @@ -278,7 +278,7 @@ you can use the pre-built Docker image. -p 10610:10610 \ -v :/data \ ghcr.io/agentscope-ai/tuft:latest \ - tuft --port 10610 --config /data/tuft_config.yaml + tuft launch --port 10610 --config /data/tuft_config.yaml ``` Please replace `` with a directory on your host machine where you want to store model checkpoints and other data. @@ -400,7 +400,7 @@ Available check fields: `SUPPORTED_MODELS`, `CHECKPOINT_DIR`, `MODEL_OWNER`, `TO If a mismatch is detected, use `--refresh-persistence` to clear existing data and start fresh: ```bash -tuft --config config.yaml --refresh-persistence +tuft launch --config config.yaml --refresh-persistence ``` Use `--force-refresh-persistence` to skip the confirmation prompt. diff --git a/config/tuft_config.example.yaml b/config/tuft_config.example.yaml index aa0d4a0..817825d 100644 --- a/config/tuft_config.example.yaml +++ b/config/tuft_config.example.yaml @@ -4,7 +4,7 @@ # Copy this file to your desired location and modify as needed. 
# # Usage: -# tuft --config /path/to/your/tuft_config.yaml +# tuft launch --config /path/to/your/tuft_config.yaml # ============================================================================= # Checkpoint Directory From b6f480dc885a51fdb05e2eaeafcb45342f6360c0 Mon Sep 17 00:00:00 2001 From: xuanrui-L Date: Fri, 30 Jan 2026 12:42:12 +0800 Subject: [PATCH 3/5] update github workflow timeout to 1800s --- .github/workflows/unittest.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml index 2bb2c54..aa2a163 100644 --- a/.github/workflows/unittest.yml +++ b/.github/workflows/unittest.yml @@ -31,8 +31,8 @@ jobs: - name: Check ray status working-directory: tuft-${{ github.run_id }}/.github/workflows/docker run: | - MAX_RETRIES=20 - RETRY_INTERVAL=5 + MAX_RETRIES=60 + RETRY_INTERVAL=30 for i in $(seq 1 $MAX_RETRIES); do if docker compose exec tuft-node-1 bash -c "source /root/.tuft/venv/bin/activate && ray status"; then break From acba52ce516c15393bffe499d120e60826f5693f Mon Sep 17 00:00:00 2001 From: xuanrui-L Date: Fri, 30 Jan 2026 13:12:00 +0800 Subject: [PATCH 4/5] fix comment --- src/tuft/cli.py | 5 ++--- src/tuft/config.py | 2 +- src/tuft/exceptions.py | 7 ------- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/src/tuft/cli.py b/src/tuft/cli.py index 3ca4e0f..631e3a0 100644 --- a/src/tuft/cli.py +++ b/src/tuft/cli.py @@ -4,7 +4,6 @@ import logging import os -import sys from pathlib import Path import typer @@ -178,12 +177,12 @@ def _validate_persistence_config( validate_config_signature(config) except ConfigMismatchError as e: typer.secho( - "\n🚫 FATAL ERROR: Configuration Mismatch Detected 🚫", + "\n 🚫 FATAL ERROR: Configuration Mismatch Detected 🚫", fg=typer.colors.RED, bold=True, ) typer.echo(f"\n{e}\n") - sys.exit(1) + raise typer.Exit(1) from e def _init_telemetry(config: AppConfig, log_level: str) -> None: diff --git a/src/tuft/config.py b/src/tuft/config.py 
index a5f3360..58c570d 100644 --- a/src/tuft/config.py +++ b/src/tuft/config.py @@ -99,7 +99,7 @@ def check_validity(self) -> None: def with_supported_models(self, models: Iterable[ModelConfig]) -> "AppConfig": updated = list(models) if updated: - self.supported_models = list(updated) + self.supported_models = updated return self def get_config_for_persistence(self) -> dict[str, Any]: diff --git a/src/tuft/exceptions.py b/src/tuft/exceptions.py index 18981a6..05ea033 100644 --- a/src/tuft/exceptions.py +++ b/src/tuft/exceptions.py @@ -161,18 +161,11 @@ def __init__( # Build detailed diff message diff_parts = [] for field_name, field_diff in diff.items(): - # Handle list fields (added/removed) - added = field_diff.get("added") - removed = field_diff.get("removed") # Handle scalar fields (current/stored) current = field_diff.get("current") stored = field_diff.get("stored") parts = [] - if added is not None: - parts.append(f"added: {added}") - if removed is not None: - parts.append(f"removed: {removed}") if current is not None or stored is not None: parts.append(f"current: {current}, stored: {stored}") From 45e5f02ab28050c62bdb6f130c3a5ddba0f4480b Mon Sep 17 00:00:00 2001 From: xuanrui-L Date: Mon, 2 Feb 2026 11:02:27 +0800 Subject: [PATCH 5/5] fix comment: - Changed persistence mode names in README and configuration files from `REDIS_URL` to `REDIS` and `FILE_REDIS` to `FILE` for consistency. - Updated CLI commands to replace `--refresh-persistence` with `tuft clear persistence` for clearing existing data. - Adjusted default namespace in persistence configuration to `persistence-tuft-server`. - Enhanced documentation to reflect these changes and improve clarity on usage. 
--- README.md | 24 +++++---- config/tuft_config.example.yaml | 19 +++---- scripts/install.sh | 4 +- src/tuft/cli.py | 74 +++++++++++++++------------ src/tuft/exceptions.py | 3 +- src/tuft/persistence/redis_store.py | 59 +++++++++++---------- tests/conftest.py | 4 +- tests/test_integration_persistence.py | 2 +- tests/test_persistence.py | 2 +- 9 files changed, 104 insertions(+), 87 deletions(-) diff --git a/README.md b/README.md index 8d15c56..190d54d 100644 --- a/README.md +++ b/README.md @@ -332,8 +332,8 @@ TuFT provides three persistence modes: | Mode | Description | Use Case | |------|-------------|----------| | `DISABLE` | No persistence, data in-memory only | Development, testing without state recovery | -| `REDIS_URL` | External Redis server | Production, multi-instance deployments | -| `FILE_REDIS` | File-backed store | Demos, small-scale testing | +| `REDIS` | External Redis server | Production, multi-instance deployments | +| `FILE` | File-backed store | Demos, small-scale testing | ### Configuration @@ -356,9 +356,9 @@ Use an external Redis server for production deployments: ```yaml # tuft_config.yaml persistence: - mode: REDIS_URL + mode: REDIS redis_url: "redis://localhost:6379/0" - namespace: "tuft" # Default: "tuft". + namespace: "persistence-tuft-server" # Default: "persistence-tuft-server". 
``` You can start a local Redis instance using Docker: @@ -374,9 +374,9 @@ Use the file-backed store for demos or small-scale testing: ```yaml # tuft_config.yaml persistence: - mode: FILE_REDIS + mode: FILE file_path: "~/.cache/tuft/file_redis.json" - namespace: "tuft" # Default: "tuft" + namespace: "persistence-tuft-server" # Default: "persistence-tuft-server" ``` ### Configuration Validation @@ -387,7 +387,7 @@ You can configure which fields to validate: ```yaml persistence: - mode: REDIS_URL + mode: REDIS redis_url: "redis://localhost:6379/0" check_fields: # Default: ["SUPPORTED_MODELS"] - SUPPORTED_MODELS # Always checked (mandatory) @@ -397,13 +397,17 @@ persistence: Available check fields: `SUPPORTED_MODELS`, `CHECKPOINT_DIR`, `MODEL_OWNER`, `TOY_BACKEND_SEED`, `AUTHORIZED_USERS`, `TELEMETRY`. -If a mismatch is detected, use `--refresh-persistence` to clear existing data and start fresh: +If a mismatch is detected, use `tuft clear persistence` to clear existing data and start fresh: ```bash -tuft launch --config config.yaml --refresh-persistence +tuft clear persistence --config /path/to/tuft_config.yaml ``` -Use `--force-refresh-persistence` to skip the confirmation prompt. +Use `--force` or `-f` to skip the confirmation prompt: + +```bash +tuft clear persistence --config /path/to/tuft_config.yaml --force +``` ## Observability (OpenTelemetry) diff --git a/config/tuft_config.example.yaml b/config/tuft_config.example.yaml index 817825d..f920993 100644 --- a/config/tuft_config.example.yaml +++ b/config/tuft_config.example.yaml @@ -79,28 +79,29 @@ authorized_users: # Persistence Configuration # ============================================================================= # Configure state persistence for recovery after server restart. +# For detailed documentation, see the "Persistence" section in README.md. 
# # Available modes: # - DISABLE: No persistence (default) -# - REDIS_URL: External Redis server -# - FILE_REDIS: File-backed store +# - REDIS: External Redis server +# - FILE: File-backed store persistence: - mode: DISABLE # Options: DISABLE, REDIS_URL, FILE_REDIS + mode: DISABLE # Options: DISABLE, REDIS, FILE - # For REDIS_URL mode: + # For REDIS mode: # redis_url: "redis://localhost:6379/0" - # For FILE_REDIS mode: + # For FILE mode: # file_path: "~/.cache/tuft/file_redis.json" - # Namespace prefix for Redis keys. (optional, defaults to "tuft".) - # namespace: "tuft" + # Namespace prefix for Redis keys. (optional, defaults to "persistence-tuft-server".) + # namespace: "persistence-tuft-server" # Fields to validate on server restart for config consistency. + # For detailed documentation on available fields and config validation, + # see the "Configuration Validation" section in README.md. # Defaults to ["SUPPORTED_MODELS"]. SUPPORTED_MODELS is always checked. - # Available fields: SUPPORTED_MODELS, CHECKPOINT_DIR, MODEL_OWNER, - # TOY_BACKEND_SEED, AUTHORIZED_USERS, TELEMETRY. 
# check_fields: # - SUPPORTED_MODELS # - CHECKPOINT_DIR diff --git a/scripts/install.sh b/scripts/install.sh index e6668e9..32cfce6 100755 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -409,9 +409,9 @@ authorized_users: # Optional: Persistence configuration # persistence: -# mode: DISABLE # Options: DISABLE, REDIS_URL, FILE_REDIS +# mode: DISABLE # Options: DISABLE, REDIS, FILE # redis_url: "redis://localhost:6379/0" -# namespace: "tuft" +# namespace: "persistence-tuft-server" CONFIG_EOF fi } diff --git a/src/tuft/cli.py b/src/tuft/cli.py index 631e3a0..15677ae 100644 --- a/src/tuft/cli.py +++ b/src/tuft/cli.py @@ -23,6 +23,8 @@ app = typer.Typer(help="TuFT - Tenant-unified Fine-Tuning Server.", no_args_is_help=True) +clear_app = typer.Typer(help="Clear data commands.", no_args_is_help=True) +app.add_typer(clear_app, name="clear") # Required for Typer to recognize subcommands when using no_args_is_help=True @@ -69,21 +71,6 @@ def _resolve_config_path(config_path: Path | None) -> Path: ) -_REFRESH_PERSISTENCE_OPTION = typer.Option( - False, - "--refresh-persistence", - help=( - "Clear all existing persistence data and start fresh. " - "Use when config has changed and you want to discard old data." - ), -) -_FORCE_REFRESH_PERSISTENCE_OPTION = typer.Option( - False, - "--force-refresh-persistence", - help="Skip confirmation prompts when using --refresh-persistence.", -) - - def _build_config( config_path: Path | None, checkpoint_dir: Path | None, @@ -101,22 +88,52 @@ def _build_config( return config -def _handle_refresh_persistence(force_refresh: bool) -> None: - """Handle the --refresh-persistence flag. +_FORCE_OPTION = typer.Option( + False, + "--force", + "-f", + help="Skip confirmation prompts when clearing persistence data.", +) + + +@clear_app.command(name="persistence") +def clear_persistence( + config_path: Path | None = _CONFIG_OPTION, + force: bool = _FORCE_OPTION, +) -> None: + """Clear all persistence data and start fresh. 
- Prompts for confirmation unless --force-refresh is provided, - then clears all persistence data in the current namespace. + This command clears all existing persistence data in the configured namespace. + Use this when the configuration has changed and you want to discard old data. """ + # Build config to get persistence settings + try: + resolved_config_path = _resolve_config_path(config_path) + config = load_yaml_config(resolved_config_path) + except typer.BadParameter as e: + typer.secho(f"Error: {e}", fg=typer.colors.RED) + raise typer.Exit(1) from e + + if not config.persistence.enabled: + typer.secho( + "Persistence is disabled in the configuration. Nothing to clear.", + fg=typer.colors.YELLOW, + ) + raise typer.Exit(0) + + # Configure the store + store = get_redis_store() + store.configure(config.persistence) namespace = get_current_namespace() - if not force_refresh: + if not force: typer.secho( "\n🚨🚨🚨 CRITICAL WARNING 🚨🚨🚨\n", fg=typer.colors.RED, bold=True, ) typer.secho( - "--refresh-persistence will PERMANENTLY DELETE ALL persistence data!\n", + "This command will PERMANENTLY DELETE ALL persistence data!\n", fg=typer.colors.RED, bold=True, ) @@ -151,15 +168,12 @@ def _handle_refresh_persistence(force_refresh: bool) -> None: f"āœ… Cleared {deleted_count} keys from namespace '{cleared_namespace}'.", fg=typer.colors.GREEN, ) - typer.echo("Server will start with fresh state.\n") + typer.echo("Persistence data has been cleared. You can now start the server fresh.") -def _validate_persistence_config( - config: AppConfig, refresh_persistence: bool, force_refresh_persistence: bool -) -> None: +def _validate_persistence_config(config: AppConfig) -> None: """Validate that persistence config matches stored config. - If refresh_persistence is True, clears existing data instead of validating. If config mismatch is detected, exits with an error message. 
""" if not config.persistence.enabled: @@ -169,10 +183,6 @@ def _validate_persistence_config( store = get_redis_store() store.configure(config.persistence) - if refresh_persistence: - _handle_refresh_persistence(force_refresh_persistence) - return - try: validate_config_signature(config) except ConfigMismatchError as e: @@ -207,14 +217,12 @@ def launch( reload: bool = _RELOAD_OPTION, config_path: Path | None = _CONFIG_OPTION, checkpoint_dir: Path | None = _CHECKPOINT_DIR_OPTION, - refresh_persistence: bool = _REFRESH_PERSISTENCE_OPTION, - force_refresh_persistence: bool = _FORCE_REFRESH_PERSISTENCE_OPTION, ) -> None: """Launch the TuFT server.""" app_config = _build_config(config_path, checkpoint_dir) # Validate persistence configuration before starting - _validate_persistence_config(app_config, refresh_persistence, force_refresh_persistence) + _validate_persistence_config(app_config) # Initialize telemetry before starting the server _init_telemetry(app_config, log_level) diff --git a/src/tuft/exceptions.py b/src/tuft/exceptions.py index 05ea033..eb4b495 100644 --- a/src/tuft/exceptions.py +++ b/src/tuft/exceptions.py @@ -180,7 +180,8 @@ def __init__( "This can cause data corruption when restoring persisted state.\n\n" "Options:\n" " 1. Use a different Redis database (change redis_url in config)\n" - " 2. Use --refresh-persistence to clear existing data and start fresh\n" + " 2. Run `tuft clear persistence -c ` to clear existing data\n" + " Use `--force` or `-f` to skip confirmation prompt.\n" " (WARNING: This will delete all persisted sessions, training runs, etc.)\n" " 3. 
Restore the original configuration that matches the stored data" ) diff --git a/src/tuft/persistence/redis_store.py b/src/tuft/persistence/redis_store.py index ac05f85..0fe8b87 100644 --- a/src/tuft/persistence/redis_store.py +++ b/src/tuft/persistence/redis_store.py @@ -9,13 +9,13 @@ Persistence Modes: - DISABLE: No persistence, all data is in-memory only -- REDIS_URL: Use external Redis server via URL -- FILE_REDIS: Use file-backed storage for tests and demos +- REDIS: Use external Redis server via URL +- FILE: Use file-backed storage for tests and demos Config Validation: - On startup, the current config signature is compared with the stored signature - If mismatch is detected, server stops with an error message -- Use --refresh-persistence to override and clear existing data +- Use `tuft clear persistence` to override and clear existing data """ from __future__ import annotations @@ -55,9 +55,9 @@ def _get_metrics(): class PersistenceMode(str, Enum): """Persistence mode options.""" - DISABLE = "disabled" # No persistence - REDIS_URL = "redis_url" # Use external Redis server - FILE_REDIS = "file_redis" # Use file-backed storage for tests/demos + DISABLE = "DISABLE" # No persistence + REDIS = "REDIS" # Use external Redis server + FILE = "FILE" # Use file-backed storage for tests/demos # Default TTL values in seconds @@ -71,12 +71,12 @@ class ConfigCheckField: SUPPORTED_MODELS is always required (mandatory) for restore safety. 
""" - SUPPORTED_MODELS = "supported_models" - CHECKPOINT_DIR = "checkpoint_dir" - MODEL_OWNER = "model_owner" - TOY_BACKEND_SEED = "toy_backend_seed" - AUTHORIZED_USERS = "authorized_users" - TELEMETRY = "telemetry" + SUPPORTED_MODELS = "SUPPORTED_MODELS" + CHECKPOINT_DIR = "CHECKPOINT_DIR" + MODEL_OWNER = "MODEL_OWNER" + TOY_BACKEND_SEED = "TOY_BACKEND_SEED" + AUTHORIZED_USERS = "AUTHORIZED_USERS" + TELEMETRY = "TELEMETRY" # Default fields to check (supported_models is mandatory) @@ -87,10 +87,10 @@ class PersistenceConfig(BaseModel): """Configuration for Redis persistence. Attributes: - mode: Persistence mode - DISABLE, REDIS_URL, or FILE_REDIS - redis_url: Redis server URL (only used when mode=REDIS_URL) - file_path: JSON file path (only used when mode=FILE_REDIS) - namespace: Key namespace prefix for Redis keys. Defaults to "tuft". + mode: Persistence mode - DISABLE, REDIS, or FILE + redis_url: Redis server URL (only used when mode=REDIS) + file_path: JSON file path (only used when mode=FILE) + namespace: Key namespace prefix for Redis keys. Defaults to "persistence-tuft-server". future_ttl_seconds: TTL for future records in seconds. None means no expiry. check_fields: List of AppConfig fields to validate on restart. Defaults to ["SUPPORTED_MODELS"]. 
SUPPORTED_MODELS is always @@ -105,7 +105,7 @@ class PersistenceConfig(BaseModel): mode: PersistenceMode = PersistenceMode.DISABLE redis_url: str = "redis://localhost:6379/0" file_path: Path | None = None - namespace: str = "tuft" # Default namespace for Redis keys + namespace: str = "persistence-tuft-server" # Default namespace for Redis keys future_ttl_seconds: int | None = DEFAULT_FUTURE_TTL_SECONDS check_fields: list[str] = Field(default_factory=lambda: DEFAULT_CHECK_FIELDS.copy()) @@ -125,13 +125,13 @@ def get_check_fields(self) -> list[str]: def from_redis_url( cls, redis_url: str, - namespace: str = "tuft", + namespace: str = "persistence-tuft-server", future_ttl_seconds: int | None = DEFAULT_FUTURE_TTL_SECONDS, check_fields: list[str] | None = None, ) -> "PersistenceConfig": """Create a config using external Redis server.""" return cls( - mode=PersistenceMode.REDIS_URL, + mode=PersistenceMode.REDIS, redis_url=redis_url, namespace=namespace, future_ttl_seconds=future_ttl_seconds, @@ -139,16 +139,16 @@ def from_redis_url( ) @classmethod - def from_file_redis( + def from_file( cls, file_path: Path | None = None, - namespace: str = "tuft", + namespace: str = "persistence-tuft-server", future_ttl_seconds: int | None = DEFAULT_FUTURE_TTL_SECONDS, check_fields: list[str] | None = None, ) -> "PersistenceConfig": """Create a config using file-backed storage.""" return cls( - mode=PersistenceMode.FILE_REDIS, + mode=PersistenceMode.FILE, file_path=file_path, namespace=namespace, future_ttl_seconds=future_ttl_seconds, @@ -204,7 +204,7 @@ def _get_redis(self) -> Any: if self._redis is None or self._pid != current_pid: self._close_connections() - if self._config.mode in (PersistenceMode.REDIS_URL, PersistenceMode.FILE_REDIS): + if self._config.mode in (PersistenceMode.REDIS, PersistenceMode.FILE): logger.info("Redis connection begin") self._redis = self._create_redis_client() @@ -219,7 +219,7 @@ def _create_redis_client(self) -> Any: if self._config is None: return None 
try: - if self._config.mode == PersistenceMode.FILE_REDIS: + if self._config.mode == PersistenceMode.FILE: from .file_redis import FileRedis file_path = self._config.file_path or ( @@ -239,7 +239,7 @@ def is_enabled(self) -> bool: @property def namespace(self) -> str: - return self._config.namespace if self._config else "tuft" + return self._config.namespace if self._config else "persistence-tuft-server" @property def future_ttl(self) -> int | None: @@ -542,19 +542,22 @@ class ConfigSignature(BaseModel): # Metadata created_at: datetime = Field(default_factory=_now) - namespace: str = "tuft" + namespace: str = "persistence-tuft-server" @classmethod def from_app_config(cls, config: Any) -> "ConfigSignature": """Create a signature by serializing the AppConfig.""" # Use the method on AppConfig to get persistence-safe data config_data = config.get_config_for_persistence() - namespace = config.persistence.namespace if config.persistence else "tuft" + namespace = ( + config.persistence.namespace if config.persistence else "persistence-tuft-server" + ) return cls(config_data=config_data, namespace=namespace) def _get_field_value(self, field_name: str) -> Any: """Get the value of a field by name.""" - return self.config_data.get(field_name) + lowercase_field = field_name.lower() + return self.config_data.get(lowercase_field) def _normalize_for_comparison(self, value: Any) -> Any: if isinstance(value, list): diff --git a/tests/conftest.py b/tests/conftest.py index ae75ab0..fae48d0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -110,7 +110,7 @@ def configure_persistence(request): stacklevel=2, ) store.configure( - PersistenceConfig.from_file_redis( + PersistenceConfig.from_file( file_path=file_path, namespace="tuft_test", ) @@ -161,7 +161,7 @@ def enable_persistence(request): test_name = request.node.name file_path = _get_file_redis_path(f"enable_persistence_{test_name}") store.configure( - PersistenceConfig.from_file_redis( + PersistenceConfig.from_file( 
file_path=file_path, namespace="tuft_test", ) diff --git a/tests/test_integration_persistence.py b/tests/test_integration_persistence.py index aa12395..616f159 100644 --- a/tests/test_integration_persistence.py +++ b/tests/test_integration_persistence.py @@ -85,7 +85,7 @@ def test_checkpoint_resume_persistence(tmp_path: Path) -> None: config.authorized_users = { "tml-test-key": "default", } - config.persistence = PersistenceConfig.from_file_redis( + config.persistence = PersistenceConfig.from_file( file_path=file_redis_path, namespace="tuft_test", ) diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 547a17a..b38c479 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -767,7 +767,7 @@ def _create_config_with_models( ], telemetry=TelemetryConfig(enabled=telemetry_enabled), persistence=PersistenceConfig( - mode=PersistenceMode.FILE_REDIS, + mode=PersistenceMode.FILE, check_fields=check_fields if check_fields is not None else DEFAULT_CHECK_FIELDS.copy(), ), )