diff --git a/.gitignore b/.gitignore index ccf7a96..72f16ad 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,21 @@ __pycache__/ *.py[cod] *.pyo *.pyd -*.egg-info/ \ No newline at end of file +*.egg-info/ +_DS_Store +*.bak +*.log +*.log.* +*.log.*.* +*.log.*.*.* +*.log.*.*.*.* +*.log.*.*.*.*.* +DS_Store +.DS_Store +.DS_Store.* +.DS_Store.*.* +.DS_Store.*.*.* +.DS_Store.*.*.*.* +.DS_Store.*.*.*.*.* +.DS_Store.*.*.*.*.*.* +.DS_Store.*.*.*.*.*.*.* \ No newline at end of file diff --git a/README.md b/README.md index e42dba9..d48f53b 100644 --- a/README.md +++ b/README.md @@ -12,4 +12,15 @@ pip install deforum ```bash python -m build python -m twine upload dist/* +``` + +## Structure +``` + src/ + ├── cli/ # Interface layer CLI + │ └── main.py + └── deforum/ # Core library + ├── config/ # Configuration management (settings, validation, etc.) + ├── core/ # Core shared utilities (exceptions, logging, etc.) + └── utils/ # Utility functions (file handling, etc.) ``` \ No newline at end of file diff --git a/src/cli/__init__.py b/src/cli/__init__.py new file mode 100644 index 0000000..40c6d25 --- /dev/null +++ b/src/cli/__init__.py @@ -0,0 +1 @@ +"""Package initialization.""" diff --git a/src/cli/main.py b/src/cli/main.py new file mode 100644 index 0000000..bee0bb1 --- /dev/null +++ b/src/cli/main.py @@ -0,0 +1,456 @@ +#!/usr/bin/env python3 +""" +Main CLI for Deforum Flux + +""" + +import argparse +import sys +import os +import json +import time +from pathlib import Path +from typing import Dict, Any, Optional, List + +from deforum.config.settings import Config, get_preset, DeforumConfig +from deforum_flux.bridge import FluxDeforumBridge +from deforum.core.logging_config import setup_logging +from deforum.core.exceptions import DeforumException +from deforum.utils.file_utils import FileUtils +from deforum.utils.validation import InputValidator +from .parameter_adapter import FluxDeforumParameterAdapter + + +class FluxDeforumCLI: + """Main CLI class for Flux-Deforum integration.""" + + def __init__(self): + """Initialize the CLI.""" + self.adapter = FluxDeforumParameterAdapter() + self.validator = InputValidator() + self.file_utils = FileUtils() + + def create_parser(self) -> argparse.ArgumentParser: + """Create argument parser.""" + parser = argparse.ArgumentParser( + description="Flux + Deforum Animation CLI - Production Ready", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Basic animation + %(prog)s "a serene mountain landscape" --frames 10 + + # Custom motion + %(prog)s "cosmic nebula" --frames 20 --zoom "0:(1.0), 10:(1.3), 20:(1.0)" --angle "0:(0), 20:(15)" + + # Use configuration file + %(prog)s --config config.json + + # Test mode (no Flux required) + %(prog)s --test --frames 5 + """ + ) + + # Basic parameters + parser.add_argument( + "prompt", + nargs="?", + default="a beautiful landscape with gentle motion", + help="Text prompt for generation" + ) + + parser.add_argument( + "--config", + type=str, + help="Configuration file path" + ) + + parser.add_argument( + "--preset", + choices=["fast", "quality", "balanced", "production"], + default="balanced", + help="Configuration preset" + ) + + # Generation parameters + parser.add_argument( + "--frames", + type=int, + default=10, + help="Number of frames to generate" + ) + + parser.add_argument( + "--width", + type=int, + default=1024, + help="Image width" + ) + + parser.add_argument( + "--height", + type=int, + default=1024, + help="Image height" + ) + + parser.add_argument( + "--steps", + type=int, + help="Generation steps (overrides preset)" + ) + + parser.add_argument( + "--guidance", + type=float, + help="Guidance scale (overrides preset)" + ) + + parser.add_argument( + "--seed", + type=int, + help="Random seed" + ) + + # Motion parameters + parser.add_argument( + "--zoom", + default="0:(1.0)", + help="Zoom schedule (e.g., '0:(1.0), 10:(1.2)')" + ) + + parser.add_argument( + "--angle", + default="0:(0)", + help="Rotation angle schedule" + ) + + parser.add_argument( + "--translation-x", + default="0:(0)", + help="X translation schedule" + ) + + parser.add_argument( + "--translation-y", + default="0:(0)", + help="Y translation schedule" + ) + + # Output options + parser.add_argument( + "--output", + default="./outputs", + help="Output directory" + ) + + parser.add_argument( + "--prefix", + default="frame", + help="Output filename prefix" + ) + + parser.add_argument( + "--format", + choices=["png", "jpg", "jpeg"], + default="png", + help="Output image format" + ) + + parser.add_argument( + "--video", + action="store_true", + help="Create video from frames" + ) + + parser.add_argument( + "--fps", + type=int, + default=24, + help="Video FPS" + ) + + # System options + parser.add_argument( + "--device", + choices=["cuda", "cpu", "mps"], + default="cuda", + help="Device to use" + ) + + parser.add_argument( + "--model", + choices=["flux-schnell", "flux-dev"], + help="Flux model to use (overrides preset)" + ) + + parser.add_argument( + "--test", + action="store_true", + help="Test mode (generate without Flux)" + ) + + parser.add_argument( + "--verbose", + action="store_true", + help="Verbose logging" + ) + + parser.add_argument( + "--log-file", + help="Log file path" + ) + + return parser + + def validate_args(self, args: argparse.Namespace) -> None: + """Validate command line arguments.""" + # Validate prompt + if not args.test: + self.validator.validate_prompt(args.prompt) + + # Validate dimensions + self.validator.validate_dimensions(args.width, args.height) + + # Validate generation parameters + steps = args.steps or 20 + guidance = args.guidance or 7.5 + self.validator.validate_generation_params(steps, guidance, args.seed) + + # Validate device + self.validator.validate_device_string(args.device) + + # Validate frame count + if args.frames <= 0 or args.frames > 1000: + raise ValueError(f"Frame count must be between 1 and 1000, got {args.frames}") + + def create_config_from_args(self, args: argparse.Namespace) -> Config: + """Create configuration from command line arguments.""" + # Start with preset + if args.config: + config = Config.from_file(args.config) + else: + config = get_preset(args.preset) + + # Override with command line arguments + overrides = {} + + if args.model: + overrides["model_name"] = args.model + if args.device: + overrides["device"] = args.device + if args.width: + overrides["width"] = args.width + if args.height: + overrides["height"] = args.height + if args.steps: + overrides["steps"] = args.steps + if args.guidance: + overrides["guidance_scale"] = args.guidance + if args.frames: + overrides["max_frames"] = args.frames + if args.output: + overrides["output_dir"] = args.output + if args.log_file: + overrides["log_file"] = args.log_file + if args.verbose: + overrides["log_level"] = "DEBUG" + + return config.update(**overrides) if overrides else config + + def create_deforum_config(self, args: argparse.Namespace) -> DeforumConfig: + """Create Deforum configuration from arguments.""" + return DeforumConfig( + max_frames=args.frames, + zoom=args.zoom, + angle=args.angle, + translation_x=args.translation_x, + translation_y=args.translation_y, + positive_prompts={"0": args.prompt} + ) + + def generate_test_animation(self, args: argparse.Namespace) -> List[str]: + """Generate test animation without Flux.""" + print("🧪 Running in test mode (no Flux required)") + + import numpy as np + from PIL import Image + + output_dir = Path(args.output) + self.file_utils.ensure_directory(output_dir) + + frames = [] + + for i in range(args.frames): + # Create test image with motion + width, height = args.width, args.height + + # Generate colorful test pattern + x = np.linspace(0, 4*np.pi, width) + y = np.linspace(0, 4*np.pi, height) + X, Y = np.meshgrid(x, y) + + # Add time-based animation + t = i / args.frames * 2 * np.pi + pattern = np.sin(X + t) * np.cos(Y + t/2) + np.sin(X/2 + t/3) * np.cos(Y/3) + + # Normalize and convert to RGB + pattern = (pattern + 2) / 4 # Normalize to [0, 1] + rgb = np.stack([ + pattern, + np.roll(pattern, width//4, axis=1), + np.roll(pattern, -width//4, axis=1) + ], axis=2) + + rgb = (rgb * 255).astype(np.uint8) + + # Save frame + image = Image.fromarray(rgb) + filename = f"{args.prefix}_{i:04d}.{args.format}" + filepath = output_dir / filename + image.save(filepath) + frames.append(str(filepath)) + + print(f"Generated test frame {i+1}/{args.frames}") + + return frames + + def run(self, args: List[str] = None) -> int: + """Run the CLI.""" + parser = self.create_parser() + args = parser.parse_args(args) + + try: + # Setup logging + setup_logging( + level="DEBUG" if args.verbose else "INFO", + console_output=True, + log_file=args.log_file, + structured_logging=False + ) + + print("🎬 Flux + Deforum CLI - Production Ready") + print("=" * 50) + + # Validate arguments + self.validate_args(args) + + # Handle test mode + if args.test: + frames = self.generate_test_animation(args) + print(f"\n ++[√]++ Test animation completed!") + print(f"Generated {len(frames)} frames in {args.output}") + + if args.video: + try: + video_path = Path(args.output) / "animation.mp4" + self.file_utils.create_video_from_frames( + args.output, video_path, args.fps + ) + print(f"Video saved: {video_path}") + except Exception as e: + print(f"⚠️ Video creation failed: {e}") + + return 0 + + # Create configuration + config = self.create_config_from_args(args) + deforum_config = self.create_deforum_config(args) + + print(f"Configuration:") + print(f" Model: {config.model_name}") + print(f" Device: {config.device}") + print(f" Resolution: {config.width}x{config.height}") + print(f" Steps: {config.steps}") + print(f" Frames: {config.max_frames}") + + # Initialize bridge + print(f"\n🔧 Initializing Flux-Deforum Bridge...") + bridge = FluxDeforumBridge(config) + + # Create animation configuration + motion_schedule = deforum_config.to_motion_schedule() + + animation_config = { + "prompt": args.prompt, + "max_frames": args.frames, + "width": config.width, + "height": config.height, + "steps": config.steps, + "guidance_scale": config.guidance_scale, + "motion_schedule": motion_schedule, + "seed": args.seed + } + + print(f"\n🎥 Generating animation...") + print(f" Prompt: {args.prompt}") + print(f" Motion: zoom={args.zoom}, angle={args.angle}") + + # Generate animation + start_time = time.time() + frames = bridge.generate_animation(animation_config) + generation_time = time.time() - start_time + + # Save frames + print(f"\n💾 Saving frames...") + saved_files = self.file_utils.save_animation_frames( + frames, args.output, args.prefix, args.format + ) + + print(f"\n ++[√]++ Animation completed!") + print(f" Generated {len(frames)} frames in {generation_time:.2f}s") + print(f" Average time per frame: {generation_time/len(frames):.2f}s") + print(f" Output directory: {args.output}") + + # Create video if requested + if args.video: + try: + video_path = Path(args.output) / "animation.mp4" + self.file_utils.create_video_from_frames( + args.output, video_path, args.fps + ) + print(f" Video saved: {video_path}") + except Exception as e: + print(f"⚠️ Video creation failed: {e}") + + # Save configuration for reference + config_path = Path(args.output) / "config.json" + config_data = { + "config": config.to_dict(), + "animation_config": animation_config, + "generation_stats": bridge.get_stats() + } + self.file_utils.save_config(config_data, config_path) + print(f" Configuration saved: {config_path}") + + # Cleanup + bridge.cleanup() + + return 0 + + except DeforumException as e: + print(f"\n==[X]== Deforum error: {e}") + if hasattr(e, 'details') and e.details: + print(f"Details: {e.details}") + return 1 + + except KeyboardInterrupt: + print(f"\n⚠️ Interrupted by user") + return 130 + + except Exception as e: + print(f"\n==[X]== Unexpected error: {e}") + if args.verbose: + import traceback + traceback.print_exc() + return 1 + + +def main(): + """Main entry point.""" + cli = FluxDeforumCLI() + return cli.run() + + +if __name__ == "__main__": + sys.exit(main()) \ No newline at end of file diff --git a/src/deforum/__init__.py b/src/deforum/__init__.py index 49f34f4..552d273 100644 --- a/src/deforum/__init__.py +++ b/src/deforum/__init__.py @@ -1 +1,12 @@ -__version__ = "0.2.0" \ No newline at end of file +""" +Deforum Core - CLI and Configuration + +This package provides the core CLI interface and configuration +for the Deforum ecosystem. Animation logic has been moved to deforum_flux. +""" + +from .config import Config +from .core import exceptions, logging_config + +__version__ = "0.2.0" +__all__ = ["Config", "exceptions", "logging_config"] diff --git a/src/deforum/config/__init__.py b/src/deforum/config/__init__.py new file mode 100644 index 0000000..e52e564 --- /dev/null +++ b/src/deforum/config/__init__.py @@ -0,0 +1 @@ +from .settings import Config, get_preset diff --git a/src/deforum/config/settings.py b/src/deforum/config/settings.py new file mode 100644 index 0000000..4db6e75 --- /dev/null +++ b/src/deforum/config/settings.py @@ -0,0 +1,351 @@ +""" +Centralized Configuration Management for Deforum Flux + +This module provides a unified configuration system that merges the previous +Config and DeforumConfig classes, eliminating duplication and providing a +single source of truth for all configuration settings. + +""" + +import os +import json +from dataclasses import dataclass, field +from typing import Dict, Any, Optional, List +from pathlib import Path +import torch + +# Unified Deforum configuration - merges Config and DeforumConfig + + +@dataclass +class Config: + """ + Unified configuration class for Deforum Flux backend. + + This class merges the previous Config and DeforumConfig classes to eliminate + duplication and provide a single, comprehensive configuration system. + + Configuration is organized into logical groups: + - Core Settings: Basic system configuration + - Generation Settings: Image/video generation parameters + - Animation Settings: Motion and keyframe parameters + - Performance Settings: Optimization and memory management + - API Settings: Server and network configuration + - Security Settings: Authentication and validation + - Testing Settings: Test-specific options + """ + + # ===== CORE SETTINGS ===== + device: str = "auto" + models_path: str = "models" + output_path: str = "outputs" + cache_path: str = "cache" + model_name: str = "flux-schnell" + + # ===== GENERATION SETTINGS ===== + width: int = 512 + height: int = 512 + steps: int = 20 + guidance_scale: float = 7.5 + seed: Optional[int] = None + prompt: Optional[str] = None + max_prompt_length: int = 256 + + # ===== ANIMATION SETTINGS ===== + # Basic animation parameters + animation_mode: str = "2D" # "2D", "3D", "Video Input" + max_frames: int = 10 + fps: int = 24 + + # Motion schedules (keyframe strings) - from DeforumConfig + zoom: str = "0:(1.0)" + angle: str = "0:(0)" + translation_x: str = "0:(0)" + translation_y: str = "0:(0)" + translation_z: str = "0:(0)" + rotation_3d_x: str = "0:(0)" + rotation_3d_y: str = "0:(0)" + rotation_3d_z: str = "0:(0)" + + # Strength schedules - from DeforumConfig + strength_schedule: str = "0:(0.65)" + noise_schedule: str = "0:(0.02)" + contrast_schedule: str = "0:(1.0)" + + # Classic Deforum motion settings (simplified parameters) - from Config + enable_learned_motion: bool = False # Classic mode only + motion_strength: float = 0.5 + motion_coherence: float = 0.7 + motion_schedule: str = "0:(0.5)" + depth_strength: float = 0.3 + perspective_flip_theta: str = "0:(0)" + perspective_flip_phi: str = "0:(0)" + perspective_flip_gamma: str = "0:(0)" + perspective_flip_fv: str = "0:(53)" + motion_mode: str = "geometric" # "geometric", "learned", "hybrid" + + # 3D settings - from DeforumConfig + midas_weight: float = 0.3 + near_plane: int = 200 + far_plane: int = 10000 + fov: int = 40 + + # Prompts (frame -> prompt mapping) - from DeforumConfig + positive_prompts: Dict[str, str] = field(default_factory=lambda: {"0": "a beautiful landscape"}) + negative_prompts: Dict[str, str] = field(default_factory=dict) + + # ===== PERFORMANCE SETTINGS ===== + batch_size: int = 1 + memory_efficient: bool = True + enable_attention_slicing: bool = True + enable_vae_tiling: bool = False + enable_vae_slicing: bool = False + enable_cpu_offload: bool = False + enable_sequential_cpu_offload: bool = False + offload: bool = False # Model offloading to CPU (alias for enable_cpu_offload) + precision: str = "fp16" # fp16, fp32, bf16 + enable_xformers: bool = True + enable_flash_attention: bool = False + + # ===== SAMPLING SETTINGS ===== + scheduler: str = "euler" + eta: float = 0.0 + clip_skip: int = 1 + + # ===== QUANTIZATION SETTINGS ===== + enable_quantization: bool = False + quantization_type: str = "none" # "none", "fp8", "fp4", "bnb4" + + # ===== LOGGING SETTINGS ===== + log_level: str = "INFO" + enable_tensorboard: bool = False + + # ===== API SETTINGS ===== + api_host: str = "127.0.0.1" + api_port: int = 7860 + enable_cors: bool = True + + # ===== SECURITY SETTINGS ===== + api_key_required: bool = False + api_key: Optional[str] = None + + + def __post_init__(self): + """Post-initialization validation and setup.""" + # Validate device + if self.device == "auto": + import torch + self.device = "cuda" if torch.cuda.is_available() else "cpu" + + # Sync offload settings + if self.offload: + self.enable_cpu_offload = True + + # Create directories + for path_attr in ["models_path", "output_path", "cache_path"]: + path = Path(getattr(self, path_attr)) + path.mkdir(parents=True, exist_ok=True) + + # Environment variable overrides + self.api_host = os.getenv("API_HOST", self.api_host) + self.api_port = int(os.getenv("API_PORT", str(self.api_port))) + self.log_level = os.getenv("LOG_LEVEL", self.log_level) + + # GPU Cloud optimizations + if os.getenv("GPU_CLOUD_MODE"): + self.api_host = "0.0.0.0" + self.enable_cpu_offload = False # Keep models in GPU memory + self.memory_efficient = True + + print(f"Config initialized - device: {self.device}, animation_mode: {self.animation_mode}") + + def to_dict(self) -> Dict[str, Any]: + """Convert config to dictionary.""" + return { + field.name: getattr(self, field.name) + for field in self.__dataclass_fields__.values() + } + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any]) -> "Config": + """Create config from dictionary.""" + # Filter out any keys that aren't valid DeforumConfig fields + valid_fields = {field.name for field in cls.__dataclass_fields__.values()} + filtered_dict = {k: v for k, v in config_dict.items() if k in valid_fields} + return cls(**filtered_dict) + + @classmethod + def from_file(cls, config_path: str) -> "Config": + """Create config from JSON file.""" + with open(config_path, "r") as f: + config_data = json.load(f) + return cls.from_dict(config_data) + + def update(self, **kwargs) -> "Config": + """Create a new DeforumConfig with updated values.""" + current_values = { + field.name: getattr(self, field.name) + for field in self.__dataclass_fields__.values() + } + current_values.update(kwargs) + return Config(**current_values) + + def get_motion_parameters(self) -> Dict[str, Any]: + """Get motion-related parameters for animation (simplified parameters).""" + return { + "motion_strength": self.motion_strength, + "motion_coherence": self.motion_coherence, + "motion_schedule": self.motion_schedule, + "depth_strength": self.depth_strength, + "perspective_flip_theta": self.perspective_flip_theta, + "perspective_flip_phi": self.perspective_flip_phi, + "perspective_flip_gamma": self.perspective_flip_gamma, + "perspective_flip_fv": self.perspective_flip_fv, + } + + def to_motion_schedule(self) -> Dict[str, Any]: + """Convert keyframe strings to motion schedule dictionary (comprehensive parameters).""" + motion_schedule = {} + + # Parse each motion parameter + motion_params = { + "zoom": self.zoom, + "angle": self.angle, + "translation_x": self.translation_x, + "translation_y": self.translation_y, + "translation_z": self.translation_z, + "rotation_3d_x": self.rotation_3d_x, + "rotation_3d_y": self.rotation_3d_y, + "rotation_3d_z": self.rotation_3d_z + } + + for param_name, keyframe_string in motion_params.items(): + parsed = self._parse_keyframes(keyframe_string) + for frame, value in parsed.items(): + if frame not in motion_schedule: + motion_schedule[frame] = {} + motion_schedule[frame][param_name] = value + + return motion_schedule + + def _parse_keyframes(self, keyframe_string: str) -> Dict[int, float]: + """Parse keyframe string like '0:(1.0), 10:(1.5)' into frame->value dict.""" + result = {} + if not keyframe_string: + return result + + try: + parts = keyframe_string.split(",") + for part in parts: + part = part.strip() + if ":" in part and "(" in part: + frame_part, value_part = part.split(":", 1) + frame = int(frame_part.strip()) + value_str = value_part.strip() + if value_str.startswith("(") and value_str.endswith(")"): + value = float(value_str[1:-1]) + result[frame] = value + except (ValueError, IndexError): + # If parsing fails, return default + result[0] = 1.0 if "zoom" in keyframe_string else 0.0 + + return result + + +# Default configuration instance +DEFAULT_CONFIG = Config() + + +def get_config() -> Config: + """Get the current configuration.""" + return DEFAULT_CONFIG + + +def update_config(updates: Dict[str, Any]) -> None: + """Update the global configuration.""" + global DEFAULT_CONFIG + for key, value in updates.items(): + if hasattr(DEFAULT_CONFIG, key): + setattr(DEFAULT_CONFIG, key, value) + else: + print(f"Warning: Unknown configuration key: {key}") + + +# ===== CONFIGURATION PRESETS ===== +PRESETS = { + "fast": Config( + model_name="flux-schnell", + steps=18, + guidance_scale=3.5, + width=1024, + height=1024, + enable_cpu_offload=True, + enable_attention_slicing=True, + ), + "balanced": Config( + model_name="flux-dev", + steps=20, + guidance_scale=7.5, + width=1024, + height=1024, + enable_cpu_offload=False, + enable_attention_slicing=True, + ), + "quality": Config( + model_name="flux-dev", + steps=28, + guidance_scale=7.5, + width=1024, + height=1024, + enable_cpu_offload=False, + enable_attention_slicing=False, + ), + "production": Config( + model_name="flux-dev", + steps=50, + guidance_scale=7.5, + width=2048, + height=2048, + enable_cpu_offload=False, + enable_attention_slicing=False, + memory_efficient=False, + ), + # TEST-SPECIFIC PRESETS + "test_minimal": Config( + model_name="flux-schnell", + steps=2, + guidance_scale=3.5, + width=768, + height=768, + skip_model_loading=True, + allow_mocks=True, # CI/unit testing only + max_frames=2, + enable_cpu_offload=True, + ), + "test_GPU_Cloud": Config( + model_name="flux-dev", + steps=4, + guidance_scale=7.5, + width=1024, + height=1024, + skip_model_loading=False, + api_host="0.0.0.0", + api_port=7860, + memory_efficient=True, + ) +} + + +def get_preset(preset_name: str) -> Config: + """Get a configuration preset by name.""" + if preset_name not in PRESETS: + available = ", ".join(PRESETS.keys()) + raise ValueError(f"Unknown preset '{preset_name}'. Available presets: {available}") + + # Return a copy to avoid modifying the original preset + preset = PRESETS[preset_name] + return Config( + **{field.name: getattr(preset, field.name) for field in preset.__dataclass_fields__.values()} + ) + + diff --git a/src/deforum/config/validation_rules.py b/src/deforum/config/validation_rules.py new file mode 100644 index 0000000..a54a9a3 --- /dev/null +++ b/src/deforum/config/validation_rules.py @@ -0,0 +1,164 @@ +""" +Centralized Validation Rules for Deforum Flux + +This module provides a single source of truth for all validation constants, +eliminating duplication across the codebase and ensuring consistency. +""" + +from typing import Dict, Tuple, List + + +class ValidationRules: + """Centralized validation constants and rules.""" + + # Image dimensions + DIMENSIONS = { + "min": 64, + "max": 4096, + "divisible_by": 8 + } + + # Generation parameters + STEPS = { + "min": 1, + "max": 200 + } + + GUIDANCE_SCALE = { + "min": 0.0, + "max": 30.0 + } + + BATCH_SIZE = { + "min": 1, + "max": 32, + "performance_max": 16 + } + + # Animation parameters + MAX_FRAMES = { + "min": 1, + "max": 10000 + } + + FPS = { + "min": 1, + "max": 120 + } + + # Motion parameters with their valid ranges + MOTION_RANGES = { + "zoom": (0.1, 10.0), + "angle": (-360.0, 360.0), + "translation_x": (-2000.0, 2000.0), + "translation_y": (-2000.0, 2000.0), + "translation_z": (-2000.0, 2000.0), + "rotation_3d_x": (-360.0, 360.0), + "rotation_3d_y": (-360.0, 360.0), + "rotation_3d_z": (-360.0, 360.0) + } + + # Strength and scheduling parameters + STRENGTH_RANGES = { + "midas_weight": (0.0, 1.0), + "strength_schedule": (0.0, 1.0), + "noise_schedule": (0.0, 1.0), + "contrast_schedule": (0.0, 10.0), + "motion_strength": (0.0, 1.0), + "motion_coherence": (0.0, 1.0), + "depth_strength": (0.0, 1.0) + } + + # 3D rendering parameters + RENDERING_3D = { + "near_plane": {"min": 1, "max": 1000}, + "far_plane": {"min": 100, "max": 50000}, + "fov": {"min": 1, "max": 180} + } + + # Prompt limits + PROMPT = { + "max_length": 2048, + "min_length": 1 + } + + # Device types + VALID_DEVICES = ["cpu", "cuda", "mps"] + + # Model names + VALID_MODELS = ["flux-schnell", "flux-dev"] + + # Animation modes + VALID_ANIMATION_MODES = ["2D", "3D", "Video Input", "Interpolation"] + + # Motion modes + VALID_MOTION_MODES = ["grouped", "independent", "mixed"] + + # Log levels + VALID_LOG_LEVELS = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + + # Seed limits + SEED = { + "min": 0, + "max": 2**32 - 1 + } + + # File extensions + ALLOWED_IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"] + ALLOWED_VIDEO_EXTENSIONS = [".mp4", ".avi", ".mov", ".mkv", ".webm"] + ALLOWED_CONFIG_EXTENSIONS = [".json", ".yaml", ".yml"] + + @classmethod + def get_dimension_range(cls) -> Tuple[int, int]: + """Get dimension min/max as tuple.""" + return cls.DIMENSIONS["min"], cls.DIMENSIONS["max"] + + @classmethod + def get_steps_range(cls) -> Tuple[int, int]: + """Get steps min/max as tuple.""" + return cls.STEPS["min"], cls.STEPS["max"] + + @classmethod + def get_guidance_range(cls) -> Tuple[float, float]: + """Get guidance scale min/max as tuple.""" + return cls.GUIDANCE_SCALE["min"], cls.GUIDANCE_SCALE["max"] + + @classmethod + def get_motion_range(cls, motion_param: str) -> Tuple[float, float]: + """Get motion parameter range.""" + if motion_param not in cls.MOTION_RANGES: + raise ValueError(f"Unknown motion parameter: {motion_param}") + return cls.MOTION_RANGES[motion_param] + + @classmethod + def get_strength_range(cls, strength_param: str) -> Tuple[float, float]: + """Get strength parameter range.""" + if strength_param not in cls.STRENGTH_RANGES: + raise ValueError(f"Unknown strength parameter: {strength_param}") + return cls.STRENGTH_RANGES[strength_param] + + @classmethod + def is_valid_device(cls, device: str) -> bool: + """Check if device string is valid.""" + device_type = device.split(":")[0] # Handle cuda:0, cuda:1, etc. + return device_type in cls.VALID_DEVICES + + @classmethod + def is_valid_model(cls, model_name: str) -> bool: + """Check if model name is valid.""" + return model_name in cls.VALID_MODELS + + @classmethod + def is_valid_animation_mode(cls, mode: str) -> bool: + """Check if animation mode is valid.""" + return mode in cls.VALID_ANIMATION_MODES + + @classmethod + def is_valid_motion_mode(cls, mode: str) -> bool: + """Check if motion mode is valid.""" + return mode in cls.VALID_MOTION_MODES + + @classmethod + def is_valid_log_level(cls, level: str) -> bool: + """Check if log level is valid.""" + return level in cls.VALID_LOG_LEVELS diff --git a/src/deforum/config/validation_utils.py b/src/deforum/config/validation_utils.py new file mode 100644 index 0000000..9e8d678 --- /dev/null +++ b/src/deforum/config/validation_utils.py @@ -0,0 +1,392 @@ +""" +Reusable Validation Utilities for Deforum Flux + +This module provides common validation helper functions that can be used +across different modules, reducing code duplication and ensuring consistency. +""" + +import re +import os +from typing import Any, List, Dict, Optional, Union, Tuple +from pathlib import Path + +from .validation_rules import ValidationRules + + +class ValidationUtils: + """Reusable validation utility functions.""" + + @staticmethod + def is_in_range(value: Union[int, float], min_val: Union[int, float], max_val: Union[int, float]) -> bool: + """ + Check if value is within specified range (inclusive). + + Args: + value: Value to check + min_val: Minimum allowed value + max_val: Maximum allowed value + + Returns: + True if value is in range, False otherwise + """ + return min_val <= value <= max_val + + @staticmethod + def validate_type(value: Any, expected_type: type, param_name: str) -> List[str]: + """ + Validate value type and return errors if invalid. + + Args: + value: Value to validate + expected_type: Expected type + param_name: Parameter name for error messages + + Returns: + List of error messages (empty if valid) + """ + errors = [] + if not isinstance(value, expected_type): + errors.append(f"{param_name} must be {expected_type.__name__}, got {type(value).__name__}") + return errors + + @staticmethod + def validate_range( + value: Union[int, float], + min_val: Union[int, float], + max_val: Union[int, float], + param_name: str, + value_type: type = None + ) -> List[str]: + """ + Validate value is within range and optionally validate type. + + Args: + value: Value to validate + min_val: Minimum allowed value + max_val: Maximum allowed value + param_name: Parameter name for error messages + value_type: Optional type to validate + + Returns: + List of error messages (empty if valid) + """ + errors = [] + + # Type validation if specified + if value_type and not isinstance(value, value_type): + errors.append(f"{param_name} must be {value_type.__name__}, got {type(value).__name__}") + return errors # Return early if type is wrong + + # Range validation + if not ValidationUtils.is_in_range(value, min_val, max_val): + errors.append(f"{param_name} must be between {min_val} and {max_val}, got {value}") + + return errors + + @staticmethod + def validate_positive_integer(value: Any, param_name: str, max_val: Optional[int] = None) -> List[str]: + """ + Validate value is a positive integer. + + Args: + value: Value to validate + param_name: Parameter name for error messages + max_val: Optional maximum value + + Returns: + List of error messages (empty if valid) + """ + errors = [] + + if not isinstance(value, int): + errors.append(f"{param_name} must be an integer, got {type(value).__name__}") + return errors + + if value <= 0: + errors.append(f"{param_name} must be positive, got {value}") + + if max_val is not None and value > max_val: + errors.append(f"{param_name} must be <= {max_val}, got {value}") + + return errors + + @staticmethod + def validate_divisible_by(value: int, divisor: int, param_name: str) -> List[str]: + """ + Validate value is divisible by divisor. + + Args: + value: Value to validate + divisor: Required divisor + param_name: Parameter name for error messages + + Returns: + List of error messages (empty if valid) + """ + errors = [] + if value % divisor != 0: + errors.append(f"{param_name} must be divisible by {divisor}, got {value}") + return errors + + @staticmethod + def validate_string_not_empty(value: Any, param_name: str, max_length: Optional[int] = None) -> List[str]: + """ + Validate string is not empty and optionally check length. + + Args: + value: Value to validate + param_name: Parameter name for error messages + max_length: Optional maximum length + + Returns: + List of error messages (empty if valid) + """ + errors = [] + + if not isinstance(value, str): + errors.append(f"{param_name} must be a string, got {type(value).__name__}") + return errors + + if not value.strip(): + errors.append(f"{param_name} cannot be empty") + + if max_length is not None and len(value) > max_length: + errors.append(f"{param_name} too long: {len(value)} > {max_length}") + + return errors + + @staticmethod + def validate_choice(value: Any, choices: List[Any], param_name: str) -> List[str]: + """ + Validate value is one of allowed choices. + + Args: + value: Value to validate + choices: List of allowed values + param_name: Parameter name for error messages + + Returns: + List of error messages (empty if valid) + """ + errors = [] + if value not in choices: + errors.append(f"Invalid {param_name}: {value}. Must be one of {choices}") + return errors + + @staticmethod + def validate_keyframe_syntax(keyframe_string: str) -> bool: + """ + Validate keyframe syntax (e.g., "0:(1.0), 30:(1.5)"). + + Args: + keyframe_string: Keyframe string to validate + + Returns: + True if valid, False otherwise + """ + if not isinstance(keyframe_string, str): + return False + + try: + # Basic validation - should contain frame:value pairs + if ":" not in keyframe_string or "(" not in keyframe_string: + return False + + # Split by comma and validate each part + parts = keyframe_string.split(",") + for part in parts: + part = part.strip() + if not part: + continue + + if ":" not in part or "(" not in part or ")" not in part: + return False + + frame_part, value_part = part.split(":", 1) + frame_num = int(frame_part.strip()) + value = value_part.strip() + + if not value.startswith("(") or not value.endswith(")"): + return False + + # Try to parse the value + float(value[1:-1]) + + # Frame number should be non-negative + if frame_num < 0: + return False + + return True + + except (ValueError, IndexError): + return False + + @staticmethod + def validate_file_path( + file_path: str, + must_exist: bool = True, + allowed_extensions: Optional[List[str]] = None, + param_name: str = "file_path" + ) -> List[str]: + """ + Validate file path with security checks. + + Args: + file_path: Path to validate + must_exist: Whether file must exist + allowed_extensions: List of allowed file extensions + param_name: Parameter name for error messages + + Returns: + List of error messages (empty if valid) + """ + errors = [] + + if not isinstance(file_path, str): + errors.append(f"{param_name} must be a string, got {type(file_path).__name__}") + return errors + + try: + path = Path(file_path).resolve() + except (OSError, ValueError) as e: + errors.append(f"Invalid {param_name}: {e}") + return errors + + if must_exist and not path.exists(): + errors.append(f"File does not exist: {file_path}") + + if allowed_extensions: + if path.suffix.lower() not in [ext.lower() for ext in allowed_extensions]: + errors.append(f"File extension not allowed. Got {path.suffix}, allowed: {allowed_extensions}") + + return errors + + @staticmethod + def validate_frame_number(frame: Any, param_name: str = "frame") -> List[str]: + """ + Validate frame number is a non-negative integer. + + Args: + frame: Frame value to validate + param_name: Parameter name for error messages + + Returns: + List of error messages (empty if valid) + """ + errors = [] + + try: + frame_num = int(frame) + if frame_num < 0: + errors.append(f"{param_name} number must be non-negative, got {frame_num}") + except (ValueError, TypeError): + errors.append(f"Invalid {param_name} number: {frame}") + + return errors + + @staticmethod + def sanitize_filename(filename: str) -> str: + """ + Sanitize filename for safe file system usage. + + Args: + filename: Original filename + + Returns: + Sanitized filename + """ + # Remove or replace problematic characters + sanitized = re.sub(r'[<>:"/\\|?*]', '_', filename) + + # Remove control characters + sanitized = ''.join(c for c in sanitized if ord(c) >= 32) + + # Limit length + if len(sanitized) > 255: + name, ext = os.path.splitext(sanitized) + sanitized = name[:255-len(ext)] + ext + + # Ensure it's not empty + if not sanitized.strip(): + sanitized = "untitled" + + return sanitized + + @staticmethod + def collect_errors(*error_lists: List[str]) -> List[str]: + """ + Collect and flatten multiple error lists. + + Args: + *error_lists: Variable number of error lists + + Returns: + Flattened list of all errors + """ + all_errors = [] + for error_list in error_lists: + all_errors.extend(error_list) + return all_errors + + +# Convenience functions using ValidationRules +class DomainValidators: + """Domain-specific validators using ValidationRules and ValidationUtils.""" + + @staticmethod + def validate_dimensions(width: int, height: int) -> List[str]: + """Validate image dimensions using centralized rules.""" + min_dim, max_dim = ValidationRules.get_dimension_range() + divisor = ValidationRules.DIMENSIONS["divisible_by"] + + errors = [] + errors.extend(ValidationUtils.validate_range(width, min_dim, max_dim, "width", int)) + errors.extend(ValidationUtils.validate_range(height, min_dim, max_dim, "height", int)) + errors.extend(ValidationUtils.validate_divisible_by(width, divisor, "width")) + errors.extend(ValidationUtils.validate_divisible_by(height, divisor, "height")) + + return errors + + @staticmethod + def validate_generation_params(steps: int, guidance_scale: float, seed: Optional[int] = None) -> List[str]: + """Validate generation parameters using centralized rules.""" + min_steps, max_steps = ValidationRules.get_steps_range() + min_guidance, max_guidance = ValidationRules.get_guidance_range() + + errors = [] + errors.extend(ValidationUtils.validate_range(steps, min_steps, max_steps, "steps", int)) + errors.extend(ValidationUtils.validate_range(guidance_scale, min_guidance, max_guidance, "guidance_scale", (int, float))) + + if seed is not None: + min_seed, max_seed = ValidationRules.SEED["min"], ValidationRules.SEED["max"] + errors.extend(ValidationUtils.validate_range(seed, min_seed, max_seed, "seed", int)) + + return errors + + @staticmethod + def validate_motion_params(motion_params: Dict[str, float]) -> List[str]: + """Validate motion parameters using centralized rules.""" + errors = [] + + for param_name, param_value in motion_params.items(): + if param_name not in ValidationRules.MOTION_RANGES: + errors.append(f"Unknown motion parameter: {param_name}") + continue + + min_val, max_val = ValidationRules.get_motion_range(param_name) + errors.extend(ValidationUtils.validate_range(param_value, min_val, max_val, param_name, (int, float))) + + return errors + + @staticmethod + def validate_animation_settings(max_frames: int, fps: int) -> List[str]: + """Validate animation settings using centralized rules.""" + errors = [] + + min_frames, max_frames_limit = ValidationRules.MAX_FRAMES["min"], ValidationRules.MAX_FRAMES["max"] + min_fps, max_fps = ValidationRules.FPS["min"], ValidationRules.FPS["max"] + + errors.extend(ValidationUtils.validate_range(max_frames, min_frames, max_frames_limit, "max_frames", int)) + errors.extend(ValidationUtils.validate_range(fps, min_fps, max_fps, "fps", int)) + + return errors diff --git a/src/deforum/core/__init__.py b/src/deforum/core/__init__.py new file mode 100644 index 0000000..38790ec --- /dev/null +++ b/src/deforum/core/__init__.py @@ -0,0 +1,6 @@ +"""Core package initialization.""" + +from . import exceptions +from . import logging_config + +__all__ = ["exceptions", "logging_config"] diff --git a/src/deforum/core/exceptions.py b/src/deforum/core/exceptions.py new file mode 100644 index 0000000..62dc9b6 --- /dev/null +++ b/src/deforum/core/exceptions.py @@ -0,0 +1,382 @@ +""" +Exception hierarchy for Deforum Flux + +This module provides a comprehensive exception hierarchy to replace the +scattered error handling identified in the audit. +""" + +from typing import Optional, Dict, Any + + +class DeforumException(Exception): + """Base exception for all Deforum-related errors.""" + + def __init__(self, message: str, details: Optional[Dict[str, Any]] = None, original_error: Optional[Exception] = None): + """ + Initialize Deforum exception. + + Args: + message: Human-readable error message + details: Optional dictionary with additional error details + original_error: Optional original exception that caused this error + """ + super().__init__(message) + self.message = message + self.details = details or {} + self.original_error = original_error + + def __str__(self) -> str: + """Return string representation of the exception.""" + if self.details: + details_str = ", ".join(f"{k}={v}" for k, v in self.details.items()) + return f"{self.message} (Details: {details_str})" + return self.message + + def to_dict(self) -> Dict[str, Any]: + """Convert exception to dictionary for logging/serialization.""" + result = { + "exception_type": self.__class__.__name__, + "message": self.message, + "details": self.details + } + if self.original_error: + result["original_error"] = { + "type": type(self.original_error).__name__, + "message": str(self.original_error) + } + return result + + +class FluxModelError(DeforumException): + """Errors related to Flux model loading, initialization, or inference.""" + + def __init__(self, message: str, model_name: Optional[str] = None, + device: Optional[str] = None, **kwargs): + """ + Initialize Flux model error. + + Args: + message: Error message + model_name: Name of the Flux model that caused the error + device: Device where the error occurred + **kwargs: Additional error details + """ + details = kwargs.copy() + if model_name: + details["model_name"] = model_name + if device: + details["device"] = device + + super().__init__(message, details) + + +class ModelLoadingError(FluxModelError): + """Specific error for model loading failures.""" + + def __init__(self, message: str, model_path: Optional[str] = None, **kwargs): + """ + Initialize model loading error. + + Args: + message: Error message + model_path: Path to the model that failed to load + **kwargs: Additional error details + """ + details = kwargs.copy() + if model_path: + details["model_path"] = model_path + + super().__init__(message, **details) + + +class DeforumConfigError(DeforumException): + """Errors related to configuration validation or processing.""" + + def __init__(self, message: str, config_field: Optional[str] = None, + config_value: Optional[Any] = None, **kwargs): + """ + Initialize configuration error. + + Args: + message: Error message + config_field: Name of the configuration field that caused the error + config_value: Value that caused the error + **kwargs: Additional error details + """ + details = kwargs.copy() + if config_field: + details["config_field"] = config_field + if config_value is not None: + details["config_value"] = str(config_value) + + super().__init__(message, details) + + +class ValidationError(DeforumException): + """Errors related to input validation.""" + + def __init__(self, message: str, validation_errors: Optional[list] = None, + field_name: Optional[str] = None, **kwargs): + """ + Initialize validation error. + + Args: + message: Error message + validation_errors: List of specific validation errors + field_name: Name of the field that failed validation + **kwargs: Additional error details + """ + details = kwargs.copy() + if validation_errors: + details["validation_errors"] = validation_errors + if field_name: + details["field_name"] = field_name + + super().__init__(message, details) + + +class ParameterError(DeforumException): + """Errors related to parameter parsing or processing.""" + + def __init__(self, message: str, parameter_name: Optional[str] = None, + parameter_value: Optional[Any] = None, **kwargs): + """ + Initialize parameter error. + + Args: + message: Error message + parameter_name: Name of the parameter that caused the error + parameter_value: Value that caused the error + **kwargs: Additional error details + """ + details = kwargs.copy() + if parameter_name: + details["parameter_name"] = parameter_name + if parameter_value is not None: + details["parameter_value"] = str(parameter_value) + + super().__init__(message, details) + + +class MotionProcessingError(DeforumException): + """Errors related to motion processing and animation generation.""" + + def __init__(self, message: str, frame_index: Optional[int] = None, + motion_params: Optional[Dict[str, Any]] = None, original_error: Optional[Exception] = None, **kwargs): + """ + Initialize motion processing error. + + Args: + message: Error message + frame_index: Index of the frame where the error occurred + motion_params: Motion parameters that caused the error + original_error: Optional original exception that caused this error + **kwargs: Additional error details + """ + details = kwargs.copy() + if frame_index is not None: + details["frame_index"] = frame_index + if motion_params: + details["motion_params"] = motion_params + + super().__init__(message, details, original_error) + + +class TensorProcessingError(DeforumException): + """Errors related to tensor operations and processing.""" + + def __init__(self, message: str, tensor_shape: Optional[tuple] = None, + expected_shape: Optional[tuple] = None, **kwargs): + """ + Initialize tensor processing error. + + Args: + message: Error message + tensor_shape: Actual tensor shape that caused the error + expected_shape: Expected tensor shape + **kwargs: Additional error details + """ + details = kwargs.copy() + if tensor_shape: + details["tensor_shape"] = tensor_shape + if expected_shape: + details["expected_shape"] = expected_shape + + super().__init__(message, details) + + +class ResourceError(DeforumException): + """Errors related to system resources (memory, disk, etc.).""" + + def __init__(self, message: str, resource_type: Optional[str] = None, + available: Optional[str] = None, required: Optional[str] = None, **kwargs): + """ + Initialize resource error. + + Args: + message: Error message + resource_type: Type of resource (memory, disk, etc.) + available: Available resource amount + required: Required resource amount + **kwargs: Additional error details + """ + details = kwargs.copy() + if resource_type: + details["resource_type"] = resource_type + if available: + details["available"] = available + if required: + details["required"] = required + + super().__init__(message, details) + + +class TimeoutError(DeforumException): + """Errors related to operation timeouts.""" + + def __init__(self, message: str, timeout_seconds: Optional[float] = None, + operation: Optional[str] = None, **kwargs): + """ + Initialize timeout error. + + Args: + message: Error message + timeout_seconds: Timeout duration in seconds + operation: Name of the operation that timed out + **kwargs: Additional error details + """ + details = kwargs.copy() + if timeout_seconds: + details["timeout_seconds"] = timeout_seconds + if operation: + details["operation"] = operation + + super().__init__(message, details) + + +class APIError(DeforumException): + """Errors related to API calls and external services.""" + + def __init__(self, message: str, status_code: Optional[int] = None, + endpoint: Optional[str] = None, **kwargs): + """ + Initialize API error. + + Args: + message: Error message + status_code: HTTP status code + endpoint: API endpoint that caused the error + **kwargs: Additional error details + """ + details = kwargs.copy() + if status_code: + details["status_code"] = status_code + if endpoint: + details["endpoint"] = endpoint + + super().__init__(message, details) + + +# Exception mapping for common error patterns +EXCEPTION_MAPPING = { + "model_loading": ModelLoadingError, + "flux_model": FluxModelError, + "config": DeforumConfigError, + "validation": ValidationError, + "parameter": ParameterError, + "motion": MotionProcessingError, + "tensor": TensorProcessingError, + "resource": ResourceError, + "timeout": TimeoutError, + "api": APIError +} + + +def create_exception(error_type: str, message: str, **kwargs) -> DeforumException: + """ + Create an exception of the appropriate type. + + Args: + error_type: Type of error (key in EXCEPTION_MAPPING) + message: Error message + **kwargs: Additional error details + + Returns: + Appropriate exception instance + """ + exception_class = EXCEPTION_MAPPING.get(error_type, DeforumException) + return exception_class(message, **kwargs) + + +def handle_exception(func): + """ + Decorator to handle exceptions and convert them to appropriate Deforum exceptions. + + Args: + func: Function to decorate + + Returns: + Decorated function + """ + import functools + + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except DeforumException: + # Re-raise Deforum exceptions as-is + raise + except FileNotFoundError as e: + raise DeforumConfigError(f"File not found: {e}", file_path=str(e)) + except ValueError as e: + raise ValidationError(f"Invalid value: {e}") + except RuntimeError as e: + if "CUDA" in str(e) or "GPU" in str(e): + raise FluxModelError(f"GPU/CUDA error: {e}") + raise DeforumException(f"Runtime error: {e}") + except MemoryError as e: + raise ResourceError(f"Out of memory: {e}", resource_type="memory") + except KeyError as e: + # Handle specific KeyError issues (like LogRecord problems) + if "module" in str(e) and "LogRecord" in str(type(e)): + raise DeforumException(f"Logging configuration error: {e}", details={"original_exception": type(e)}) + raise ValidationError(f"Missing key: {e}") + except Exception as e: + # More detailed error information + error_details = { + "original_exception": type(e).__name__, + "error_message": str(e), + "function": func.__name__ if hasattr(func, '__name__') else 'unknown' + } + raise DeforumException(f"Unexpected error: \"{e}\"", details=error_details) + + return wrapper + +class SecurityError(DeforumException): + """Errors related to security violations and input validation.""" + + def __init__(self, message: str, security_violation: Optional[str] = None, + input_value: Optional[str] = None, **kwargs): + """ + Initialize security error. + + Args: + message: Error message + security_violation: Type of security violation + input_value: Input that caused the security violation + **kwargs: Additional error details + """ + details = kwargs.copy() + if security_violation: + details["security_violation"] = security_violation + if input_value: + details["input_value"] = str(input_value) + + super().__init__(message, details) + + +# Update the exception mapping +EXCEPTION_MAPPING.update({ + "security": SecurityError +}) diff --git a/src/deforum/core/logging_config.py b/src/deforum/core/logging_config.py new file mode 100644 index 0000000..94c37ad --- /dev/null +++ b/src/deforum/core/logging_config.py @@ -0,0 +1,393 @@ +""" +Logging configuration for Deforum Flux + +This module provides centralized logging configuration with performance monitoring, +structured logging, and multiple output formats. +""" + +import logging +import logging.handlers +import sys +import time +import functools +import json +from pathlib import Path +from typing import Optional, Dict, Any, Union +from datetime import datetime + + +class PerformanceFilter(logging.Filter): + """Filter to add performance metrics to log records.""" + + def filter(self, record): + """Add performance information to log record.""" + record.timestamp = time.time() + record.iso_timestamp = datetime.now().isoformat() + return True + + +class StructuredFormatter(logging.Formatter): + """Formatter that outputs structured JSON logs.""" + + def format(self, record): + """Format log record as structured JSON.""" + log_entry = { + "timestamp": getattr(record, "iso_timestamp", datetime.now().isoformat()), + "level": record.levelname, + "logger": record.name, + "message": record.getMessage(), + "source_module": getattr(record, "filename", "unknown"), # Fixed: Use filename instead of module + "function": record.funcName, + "line": record.lineno + } + + # Add exception information if present + if record.exc_info: + log_entry["exception"] = self.formatException(record.exc_info) + + # Add extra fields (avoid overwriting core fields) + reserved_keys = { + "name", "msg", "args", "levelname", "levelno", "pathname", + "filename", "module", "lineno", "funcName", "created", + "msecs", "relativeCreated", "thread", "threadName", + "processName", "process", "getMessage", "exc_info", + "exc_text", "stack_info", "timestamp", "iso_timestamp", + "level", "logger", "message", "source_module", "function", "line" + } + + for key, value in record.__dict__.items(): + if key not in reserved_keys: + # Use the key directly if it doesn't conflict + log_entry[key] = value + + return json.dumps(log_entry) + + +class ColoredFormatter(logging.Formatter): + """Formatter that adds colors to console output.""" + + COLORS = { + 'DEBUG': '\033[36m', # Cyan + 'INFO': '\033[32m', # Green + 'WARNING': '\033[33m', # Yellow + 'ERROR': '\033[31m', # Red + 'CRITICAL': '\033[35m', # Magenta + 'RESET': '\033[0m' # Reset + } + + def format(self, record): + """Format log record with colors.""" + log_color = self.COLORS.get(record.levelname, self.COLORS['RESET']) + reset_color = self.COLORS['RESET'] + + # Create colored level name + colored_levelname = f"{log_color}{record.levelname}{reset_color}" + + # Format the message + formatted_message = super().format(record) + + # Replace levelname with colored version + formatted_message = formatted_message.replace(record.levelname, colored_levelname) + + return formatted_message + + +def setup_logging( + level: Union[str, int] = logging.INFO, + log_file: Optional[str] = None, + console_output: bool = True, + structured_logging: bool = False, + enable_performance_logging: bool = True, + max_file_size: int = 10 * 1024 * 1024, # 10MB + backup_count: int = 5 +) -> logging.Logger: + """ + Set up comprehensive logging configuration. + + Args: + level: Logging level (string or int) + log_file: Path to log file (optional) + console_output: Whether to output to console + structured_logging: Whether to use structured JSON logging + enable_performance_logging: Whether to enable performance logging + max_file_size: Maximum size of log file before rotation + backup_count: Number of backup log files to keep + + Returns: + Configured root logger + """ + # Convert string level to int if needed + if isinstance(level, str): + level = getattr(logging, level.upper()) + + # Get root logger + root_logger = logging.getLogger() + root_logger.setLevel(level) + + # Clear existing handlers + root_logger.handlers.clear() + + # Add performance filter if enabled + if enable_performance_logging: + perf_filter = PerformanceFilter() + root_logger.addFilter(perf_filter) + + # Console handler + if console_output: + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(level) + + if structured_logging: + console_formatter = StructuredFormatter() + else: + console_formatter = ColoredFormatter( + '%(asctime)s | %(levelname)-8s | %(name)s | %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + + console_handler.setFormatter(console_formatter) + root_logger.addHandler(console_handler) + + # File handler + if log_file: + log_path = Path(log_file) + log_path.parent.mkdir(parents=True, exist_ok=True) + + # Use rotating file handler + file_handler = logging.handlers.RotatingFileHandler( + log_file, + maxBytes=max_file_size, + backupCount=backup_count + ) + file_handler.setLevel(level) + + if structured_logging: + file_formatter = StructuredFormatter() + else: + file_formatter = logging.Formatter( + '%(asctime)s | %(levelname)-8s | %(name)s | %(funcName)s:%(lineno)d | %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + + file_handler.setFormatter(file_formatter) + root_logger.addHandler(file_handler) + + # Log the setup + setup_logger = logging.getLogger(__name__) + setup_logger.info(f"Logging initialized - Level: {logging.getLevelName(level)}") + if log_file: + setup_logger.info(f"Log file: {log_file}") + if structured_logging: + setup_logger.info("Structured JSON logging enabled") + if enable_performance_logging: + setup_logger.info("Performance logging enabled") + + return root_logger + + +def get_logger(name: str) -> logging.Logger: + """ + Get a logger with the specified name. + + Args: + name: Logger name (usually __name__) + + Returns: + Logger instance + """ + return logging.getLogger(name) + + +def log_performance(func): + """ + Decorator to log function performance metrics. + + Args: + func: Function to decorate + + Returns: + Decorated function + """ + @functools.wraps(func) + def wrapper(*args, **kwargs): + logger = get_logger(func.__module__) + start_time = time.time() + + # Log function start + logger.debug(f"Starting {func.__name__}", extra={ + "function": func.__name__, + "source_module": func.__module__, + "args_count": len(args), + "kwargs_count": len(kwargs) + }) + + try: + result = func(*args, **kwargs) + execution_time = time.time() - start_time + + # Log successful completion + logger.info(f"Completed {func.__name__} in {execution_time:.3f}s", extra={ + "function": func.__name__, + "source_module": func.__module__, + "execution_time": execution_time, + "status": "success" + }) + + return result + + except Exception as e: + execution_time = time.time() - start_time + + # Log error + logger.error(f"Error in {func.__name__} after {execution_time:.3f}s: {e}", extra={ + "function": func.__name__, + "source_module": func.__module__, + "execution_time": execution_time, + "status": "error", + "error_type": type(e).__name__, + "error_message": str(e) + }) + + raise + + return wrapper + + +def log_memory_usage(func): + """ + Decorator to log memory usage of functions. + + Args: + func: Function to decorate + + Returns: + Decorated function + """ + @functools.wraps(func) + def wrapper(*args, **kwargs): + logger = get_logger(func.__module__) + + try: + import psutil + process = psutil.Process() + + # Get memory before + mem_before = process.memory_info().rss / 1024 / 1024 # MB + + result = func(*args, **kwargs) + + # Get memory after + mem_after = process.memory_info().rss / 1024 / 1024 # MB + mem_delta = mem_after - mem_before + + logger.debug(f"Memory usage for {func.__name__}: {mem_before:.1f}MB -> {mem_after:.1f}MB (Δ{mem_delta:+.1f}MB)", extra={ + "function": func.__name__, + "memory_before_mb": mem_before, + "memory_after_mb": mem_after, + "memory_delta_mb": mem_delta + }) + + return result + + except ImportError: + # psutil not available, just run the function + logger.debug(f"Memory logging unavailable for {func.__name__} (psutil not installed)") + return func(*args, **kwargs) + + return wrapper + + +class LogContext: + """Context manager for structured logging with additional context.""" + + def __init__(self, logger: logging.Logger, operation: str, **context): + """ + Initialize log context. + + Args: + logger: Logger to use + operation: Name of the operation + **context: Additional context to include in logs + """ + self.logger = logger + self.operation = operation + self.context = context + self.start_time = None + + def __enter__(self): + """Enter the context.""" + self.start_time = time.time() + self.logger.info(f"Starting {self.operation}", extra=self.context) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Exit the context.""" + execution_time = time.time() - self.start_time if self.start_time else 0 + + context = self.context.copy() + context["execution_time"] = execution_time + + if exc_type is None: + context["status"] = "success" + self.logger.info(f"Completed {self.operation} in {execution_time:.3f}s", extra=context) + else: + context["status"] = "error" + context["error_type"] = exc_type.__name__ + context["error_message"] = str(exc_val) + self.logger.error(f"Failed {self.operation} after {execution_time:.3f}s: {exc_val}", extra=context) + + def log(self, message: str, level: int = logging.INFO, **extra_context): + """Log a message with the current context.""" + context = self.context.copy() + context.update(extra_context) + self.logger.log(level, message, extra=context) + + +# Pre-configured logger instances +def get_bridge_logger() -> logging.Logger: + """Get logger for bridge operations.""" + return get_logger("deforum.bridge") + + +def get_model_logger() -> logging.Logger: + """Get logger for model operations.""" + return get_logger("deforum.model") + + +def get_motion_logger() -> logging.Logger: + """Get logger for motion processing.""" + return get_logger("deforum.motion") + + +def get_config_logger() -> logging.Logger: + """Get logger for configuration operations.""" + return get_logger("deforum.config") + + +# Example usage patterns +if __name__ == "__main__": + # Example setup + setup_logging( + level="INFO", + log_file="deforum.log", + structured_logging=True, + enable_performance_logging=True + ) + + test_logger = get_logger(__name__) + + # Example usage + with LogContext(test_logger, "test_operation", user_id="test", operation_type="example"): + test_logger.info("This is a test message") + time.sleep(1) # Simulate work + + # Example decorated function + @log_performance + @log_memory_usage + def example_function(): + test_logger.info("Doing some work...") + time.sleep(0.5) + return "result" + + result = example_function() + test_logger.info(f"Got result: {result}") \ No newline at end of file diff --git a/src/deforum/utils/__init__.py b/src/deforum/utils/__init__.py new file mode 100644 index 0000000..40c6d25 --- /dev/null +++ b/src/deforum/utils/__init__.py @@ -0,0 +1 @@ +"""Package initialization.""" diff --git a/src/deforum/utils/device_utils.py b/src/deforum/utils/device_utils.py new file mode 100644 index 0000000..dd3f080 --- /dev/null +++ b/src/deforum/utils/device_utils.py @@ -0,0 +1,282 @@ +""" +Device utilities for consistent device string normalization across the Deforum Flux system. + +This module provides utilities to handle device string inconsistencies between +"cuda" and "cuda:0" formats that can cause tensor device mismatches. +""" + +import torch +from typing import Union, Optional, Dict, Any + + +def normalize_device(device: Union[str, torch.device]) -> str: + """ + Normalize device string to consistent format. + + Converts between "cuda:0" ↔ "cuda" formats to prevent tensor device mismatches. + + Args: + device: Device string or torch.device object + + Returns: + Normalized device string ("cuda" or "cpu") + + Examples: + >>> normalize_device("cuda:0") + "cuda" + >>> normalize_device("cuda") + "cuda" + >>> normalize_device("cpu") + "cpu" + """ + if isinstance(device, torch.device): + device = str(device) + + device_str = str(device).lower().strip() + + # Normalize CUDA devices to "cuda" (without device index) + if device_str.startswith("cuda"): + return "cuda" + elif device_str == "cpu": + return "cpu" + elif device_str == "mps": + return "mps" + else: + # Default to CPU for unknown devices + return "cpu" + + +def get_torch_device(device: Union[str, torch.device], fallback_cpu: bool = True) -> torch.device: + """ + Get torch.device object with proper device index handling. + + Args: + device: Device string or torch.device object + fallback_cpu: Whether to fallback to CPU if CUDA is not available + + Returns: + torch.device object + + Examples: + >>> get_torch_device("cuda") + device(type='cuda', index=0) + >>> get_torch_device("cpu") + device(type='cpu') + >>> get_torch_device("mps") + device(type='mps') + """ + normalized = normalize_device(device) + + if normalized == "cuda": + if torch.cuda.is_available(): + return torch.device("cuda", 0) # Explicitly use device 0 + elif fallback_cpu: + return torch.device("cpu") + else: + raise RuntimeError("CUDA requested but not available") + elif normalized == "mps": + if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + return torch.device("mps") + elif fallback_cpu: + return torch.device("cpu") + else: + raise RuntimeError("MPS requested but not available") + else: + return torch.device("cpu") + + +def ensure_tensor_device(tensor: torch.Tensor, target_device: Union[str, torch.device]) -> torch.Tensor: + """ + Ensure tensor is on the target device with proper device normalization. + + Args: + tensor: Input tensor + target_device: Target device + + Returns: + Tensor on target device + """ + target_torch_device = get_torch_device(target_device) + + # Check if tensor is already on the correct device + if tensor.device.type == target_torch_device.type: + if target_torch_device.type == "cpu" or tensor.device.index == target_torch_device.index: + return tensor + + return tensor.to(target_torch_device) + + +def device_matches(device1: Union[str, torch.device], device2: Union[str, torch.device]) -> bool: + """ + Check if two devices are equivalent (handling cuda:0 vs cuda normalization). + + Args: + device1: First device + device2: Second device + + Returns: + True if devices are equivalent + + Examples: + >>> device_matches("cuda", "cuda:0") + True + >>> device_matches("cpu", "cpu") + True + >>> device_matches("cuda", "cpu") + False + """ + return normalize_device(device1) == normalize_device(device2) + + +def get_device(prefer_cuda: bool = True) -> str: + """ + Get the best available device with full MPS, CUDA, and CPU support. + + Args: + prefer_cuda: Whether to prefer CUDA if available + + Returns: + Device string ("mps", "cuda", or "cpu") + + Examples: + >>> get_device() + "cuda" # if CUDA available + >>> get_device() + "mps" # if on M1 Mac with MPS + >>> get_device() + "cpu" # fallback + """ + # Check for MPS (Apple Silicon) support first - priority on M1 Macs + if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + # On M1 Macs, prefer MPS over CUDA (better performance and compatibility) + try: + # Test MPS functionality with a small tensor + test_tensor = torch.tensor([1.0], device='mps') + del test_tensor # Clean up + return "mps" + except Exception: + # MPS available but not functional, fall through to other options + pass + + # Then check CUDA + if prefer_cuda and torch.cuda.is_available(): + return "cuda" + + # Fallback to CPU + return "cpu" + + + + +def get_memory_stats(device: Optional[str] = None) -> Dict[str, float]: + """ + Get memory statistics for the specified device. + + Args: + device: Device to get stats for. If None, uses current device. + + Returns: + Dictionary with memory statistics in MB + + Examples: + >>> get_memory_stats("cuda") + {"allocated": 1024.0, "cached": 2048.0, "total": 8192.0} + """ + if device is None: + device = get_device() + + device = normalize_device(device) + + stats = { + "allocated": 0.0, + "cached": 0.0, + "total": 0.0 + } + + if device == "cuda" and torch.cuda.is_available(): + # Convert bytes to MB + stats["allocated"] = torch.cuda.memory_allocated() / (1024 * 1024) + stats["cached"] = torch.cuda.memory_reserved() / (1024 * 1024) + if torch.cuda.device_count() > 0: + stats["total"] = torch.cuda.get_device_properties(0).total_memory / (1024 * 1024) + elif device == "mps" and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + # MPS doesn't have detailed memory stats, return basic info + stats["allocated"] = torch.mps.current_allocated_memory() / (1024 * 1024) if hasattr(torch.mps, 'current_allocated_memory') else 0.0 + stats["cached"] = 0.0 # MPS doesn't expose cached memory + stats["total"] = 0.0 # MPS doesn't expose total memory + + return stats + + +def get_device_info(device: Optional[str] = None) -> Dict[str, Any]: + """ + Get detailed device information. + + Args: + device: Device to get info for. If None, uses current device. + + Returns: + Dictionary with device information + + Examples: + >>> get_device_info("cuda") + {"type": "cuda", "name": "RTX 4090", "memory_gb": 24.0, "compute_capability": (8, 9)} + """ + if device is None: + device = get_device() + + # Don't normalize MPS to CPU - check original device string first + original_device = str(device).lower().strip() + normalized_device = normalize_device(device) + + info = { + "type": normalized_device, + "name": "Unknown", + "memory_gb": 0.0, + "available": False + } + + if normalized_device == "cuda" and torch.cuda.is_available(): + info["available"] = True + info["name"] = torch.cuda.get_device_name(0) + props = torch.cuda.get_device_properties(0) + info["memory_gb"] = props.total_memory / (1024**3) + info["compute_capability"] = (props.major, props.minor) + info["multiprocessor_count"] = props.multi_processor_count + elif original_device == "mps" and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + info["type"] = "mps" + info["available"] = True + info["name"] = "Apple Silicon GPU" + info["memory_gb"] = 0.0 # MPS doesn't expose total memory + elif normalized_device == "cpu": + info["available"] = True + info["name"] = "CPU" + try: + import psutil + info["memory_gb"] = psutil.virtual_memory().total / (1024**3) + except ImportError: + info["memory_gb"] = 0.0 # Fallback if psutil not available + + return info + + +def log_device_info(logger, context: str = "device_info"): + """ + Log current device information for debugging. + + Args: + logger: Logger instance + context: Context string for logging + """ + cuda_available = torch.cuda.is_available() + device_count = torch.cuda.device_count() if cuda_available else 0 + + info = { + "cuda_available": cuda_available, + "device_count": device_count, + "current_device": torch.cuda.current_device() if cuda_available else None, + "memory_allocated": torch.cuda.memory_allocated() if cuda_available else 0, + "memory_cached": torch.cuda.memory_reserved() if cuda_available else 0 + } + + logger.info(f"{context}: {info}") \ No newline at end of file diff --git a/src/deforum/utils/file_utils.py b/src/deforum/utils/file_utils.py new file mode 100644 index 0000000..ede18cf --- /dev/null +++ b/src/deforum/utils/file_utils.py @@ -0,0 +1,475 @@ +""" +File utilities for Deforum Flux + +This module provides utilities for file operations, including saving animations, +managing output directories, and handling configuration files. + +SECURITY ENHANCEMENTS: +- Input validation and sanitization for all user inputs +- Path traversal attack prevention +- Command injection protection +- Secure subprocess execution +""" + +import os +import json +import shutil +import tempfile +import re +from typing import Dict, Any, List, Optional, Union +from pathlib import Path +import numpy as np + +from deforum.core.exceptions import DeforumException, SecurityError +from deforum.core.logging_config import get_logger + + +class SecurityValidator: + """Security validation utilities for input sanitization.""" + + # Allowed characters for file patterns (alphanumeric, underscore, dash, dot, percent, digit specifiers) + SAFE_PATTERN_REGEX = re.compile(r'^[a-zA-Z0-9_\-\.%d]+$') + + # Maximum allowed path depth to prevent excessive directory traversal + MAX_PATH_DEPTH = 20 + + @staticmethod + def validate_file_pattern(pattern: str) -> str: + """ + Validate and sanitize file pattern for ffmpeg. + + Args: + pattern: File pattern string + + Returns: + Sanitized pattern + + Raises: + SecurityError: If pattern contains unsafe characters + """ + if not isinstance(pattern, str): + raise SecurityError(f"Pattern must be string, got {type(pattern)}") + + if not pattern: + raise SecurityError("Pattern cannot be empty") + + if len(pattern) > 255: # Reasonable max filename length + raise SecurityError("Pattern too long (max 255 characters)") + + # Check for suspicious patterns + dangerous_patterns = ['..', '/', '\\', '|', ';', '&', '$', '`', '(', ')', '{', '}', '[', ']', '<', '>'] + for dangerous in dangerous_patterns: + if dangerous in pattern: + raise SecurityError(f"Pattern contains unsafe sequence: {dangerous}") + + # Allow only safe characters + if not SecurityValidator.SAFE_PATTERN_REGEX.match(pattern): + raise SecurityError(f"Pattern contains unsafe characters: {pattern}") + + return pattern + + @staticmethod + def validate_safe_path(path: Union[str, Path], base_path: Optional[Union[str, Path]] = None) -> Path: + """ + Validate path against traversal attacks. + + Args: + path: Path to validate + base_path: Base path to restrict operations to (optional) + + Returns: + Validated Path object + + Raises: + SecurityError: If path is unsafe + """ + if not isinstance(path, (str, Path)): + raise SecurityError(f"Path must be string or Path, got {type(path)}") + + path_obj = Path(path).resolve() + + # Check path depth + parts = path_obj.parts + if len(parts) > SecurityValidator.MAX_PATH_DEPTH: + raise SecurityError(f"Path too deep (max {SecurityValidator.MAX_PATH_DEPTH} levels)") + + # Check for suspicious path components + for part in parts: + if part in ['..', '.', '']: + continue # These are handled by resolve() + if part.startswith('.') and len(part) > 1: + # Allow .gitignore, .env, etc. but be cautious + pass + # Check for suspicious characters in path components + if any(char in part for char in ['|', ';', '&', '$', '`']): + raise SecurityError(f"Path component contains unsafe characters: {part}") + + # If base_path provided, ensure path is within it + if base_path is not None: + base_path_obj = Path(base_path).resolve() + try: + path_obj.relative_to(base_path_obj) + except ValueError: + raise SecurityError(f"Path {path_obj} is outside allowed base path {base_path_obj}") + + return path_obj + + +class FileUtils: + """Utility class for file operations with security enhancements.""" + + def __init__(self): + """Initialize file utilities.""" + self.logger = get_logger(__name__) + + @staticmethod + def ensure_directory(directory: Union[str, Path], base_path: Optional[Union[str, Path]] = None) -> Path: + """ + Ensure directory exists, create if it doesn't. + + SECURITY: Validates path against traversal attacks. + + Args: + directory: Directory path + base_path: Base path to restrict operations to (optional) + + Returns: + Path object of the directory + + Raises: + SecurityError: If path is unsafe + DeforumException: If directory creation fails + """ + try: + # Validate path security + dir_path = SecurityValidator.validate_safe_path(directory, base_path) + + # Create directory securely + dir_path.mkdir(parents=True, exist_ok=True, mode=0o755) # Secure permissions + + logger = get_logger(__name__) + logger.debug(f"Ensured directory exists: {dir_path}") + + return dir_path + + except SecurityError: + raise # Re-raise security errors + except Exception as e: + raise DeforumException(f"Failed to create directory {directory}: {e}") + + @staticmethod + def save_animation_frames( + frames: List[np.ndarray], + output_dir: Union[str, Path], + prefix: str = "frame", + format: str = "png" + ) -> List[Path]: + """ + Save animation frames to files. + + SECURITY: Validates output directory and filename components. + + Args: + frames: List of frame arrays + output_dir: Output directory + prefix: Filename prefix + format: Image format + + Returns: + List of saved file paths + + Raises: + SecurityError: If inputs are unsafe + DeforumException: If saving fails + """ + # Validate inputs + if not isinstance(frames, list) or not frames: + raise DeforumException("Frames must be a non-empty list") + + # Validate prefix for safety + if not re.match(r'^[a-zA-Z0-9_\-]+$', prefix): + raise SecurityError(f"Unsafe filename prefix: {prefix}") + + # Validate format + allowed_formats = {'png', 'jpg', 'jpeg', 'bmp', 'tiff'} + if format.lower() not in allowed_formats: + raise SecurityError(f"Unsupported format: {format}") + + output_path = FileUtils.ensure_directory(output_dir) + saved_files = [] + + try: + from PIL import Image + + for i, frame in enumerate(frames): + # Secure filename generation + filename = f"{prefix}_{i:04d}.{format}" + file_path = output_path / filename + + # Validate final path + SecurityValidator.validate_safe_path(file_path, output_path) + + # Convert numpy array to PIL Image + if frame.dtype != np.uint8: + frame = (np.clip(frame, 0, 1) * 255).astype(np.uint8) + + image = Image.fromarray(frame) + image.save(file_path) + saved_files.append(file_path) + + return saved_files + + except ImportError: + raise DeforumException("PIL (Pillow) is required for saving images. Install with: pip install Pillow") + except SecurityError: + raise # Re-raise security errors + except Exception as e: + raise DeforumException(f"Failed to save animation frames: {e}") + + @staticmethod + def create_video_from_frames( + frame_dir: Union[str, Path], + output_path: Union[str, Path], + fps: int = 24, + pattern: str = "frame_%04d.png" + ) -> Path: + """ + Create video from frame images using ffmpeg. + + SECURITY: Validates all inputs and uses secure subprocess execution. + + Args: + frame_dir: Directory containing frames + output_path: Output video path + fps: Frames per second + pattern: Frame filename pattern + + Returns: + Path to created video + + Raises: + SecurityError: If inputs are unsafe + DeforumException: If video creation fails + """ + # Validate FPS + if not isinstance(fps, int) or fps <= 0 or fps > 120: + raise SecurityError(f"Invalid FPS value: {fps} (must be 1-120)") + + # Validate and sanitize pattern (CRITICAL SECURITY FIX) + pattern = SecurityValidator.validate_file_pattern(pattern) + + # Validate paths + frame_dir = SecurityValidator.validate_safe_path(frame_dir) + output_path = SecurityValidator.validate_safe_path(output_path) + + # Ensure frame directory exists + if not frame_dir.exists() or not frame_dir.is_dir(): + raise DeforumException(f"Frame directory does not exist: {frame_dir}") + + # Ensure output directory exists + FileUtils.ensure_directory(output_path.parent) + + try: + import subprocess + + # Build ffmpeg command with secure parameters + input_pattern = str(frame_dir / pattern) + + # Use explicit parameter list to prevent injection + cmd = [ + "ffmpeg", + "-y", # Overwrite output file + "-framerate", str(fps), # Convert to string securely + "-i", input_pattern, + "-c:v", "libx264", + "-pix_fmt", "yuv420p", + "-crf", "18", # High quality + str(output_path) + ] + + logger = get_logger(__name__) + logger.info(f"Executing ffmpeg command: {' '.join(cmd)}") + + # Execute with security measures + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=300, # 5 minute timeout + cwd=str(frame_dir.parent), # Set working directory + env={"PATH": os.environ.get("PATH", "")}, # Minimal environment + ) + + if result.returncode != 0: + logger.error(f"ffmpeg stderr: {result.stderr}") + raise DeforumException(f"ffmpeg failed with return code {result.returncode}: {result.stderr}") + + logger.info(f"Successfully created video: {output_path}") + return output_path + + except subprocess.TimeoutExpired: + raise DeforumException("ffmpeg command timed out (5 minutes)") + except FileNotFoundError: + raise DeforumException("ffmpeg not found. Please install ffmpeg to create videos.") + except SecurityError: + raise # Re-raise security errors + except Exception as e: + raise DeforumException(f"Failed to create video: {e}") + + @staticmethod + def save_config(config: Dict[str, Any], file_path: Union[str, Path]) -> None: + """ + Save configuration to JSON file. + + SECURITY: Validates file path and config content. + + Args: + config: Configuration dictionary + file_path: Output file path + + Raises: + SecurityError: If inputs are unsafe + DeforumException: If saving fails + """ + # Validate file path + file_path = SecurityValidator.validate_safe_path(file_path) + + # Validate config content + if not isinstance(config, dict): + raise SecurityError("Config must be a dictionary") + + FileUtils.ensure_directory(file_path.parent) + + # Convert any non-serializable objects + serializable_config = FileUtils._make_serializable(config) + + try: + with open(file_path, 'w', encoding='utf-8') as f: + json.dump(serializable_config, f, indent=2, ensure_ascii=False) + + logger = get_logger(__name__) + logger.debug(f"Saved config to: {file_path}") + + except Exception as e: + raise DeforumException(f"Failed to save config: {e}") + + @staticmethod + def load_config(file_path: Union[str, Path]) -> Dict[str, Any]: + """ + Load configuration from JSON file. + + SECURITY: Validates file path and size limits. + + Args: + file_path: Input file path + + Returns: + Configuration dictionary + + Raises: + SecurityError: If file is unsafe + DeforumException: If loading fails + """ + # Validate file path + file_path = SecurityValidator.validate_safe_path(file_path) + + if not file_path.exists(): + raise DeforumException(f"Configuration file not found: {file_path}") + + # Check file size (prevent huge files) + file_size = file_path.stat().st_size + max_size = 10 * 1024 * 1024 # 10MB limit + if file_size > max_size: + raise SecurityError(f"Config file too large: {file_size} bytes (max {max_size})") + + try: + with open(file_path, 'r', encoding='utf-8') as f: + config = json.load(f) + + if not isinstance(config, dict): + raise DeforumException("Config file must contain a JSON object") + + logger = get_logger(__name__) + logger.debug(f"Loaded config from: {file_path}") + + return config + + except json.JSONDecodeError as e: + raise DeforumException(f"Invalid JSON in configuration file: {e}") + except Exception as e: + raise DeforumException(f"Failed to load configuration: {e}") + + @staticmethod + def _make_serializable(obj: Any) -> Any: + """ + Convert object to JSON-serializable format. + + SECURITY: Prevents code injection through object serialization. + + Args: + obj: Object to convert + + Returns: + Serializable object + """ + if isinstance(obj, dict): + return {k: FileUtils._make_serializable(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [FileUtils._make_serializable(item) for item in obj] + elif isinstance(obj, (str, int, float, bool)) or obj is None: + return obj + elif hasattr(obj, '__dict__'): + # Only serialize safe attributes (no private/dunder attributes) + safe_dict = {} + for k, v in obj.__dict__.items(): + if not k.startswith('_'): # Skip private attributes + safe_dict[k] = FileUtils._make_serializable(v) + return safe_dict + else: + # Convert to string but sanitize + str_repr = str(obj) + if len(str_repr) > 1000: # Prevent huge strings + str_repr = str_repr[:1000] + "..." + return str_repr + + # Additional secure file operations with remaining methods... + @staticmethod + def backup_file(file_path: Union[str, Path], backup_dir: Optional[Union[str, Path]] = None) -> Path: + """Create backup of a file with security validation.""" + file_path = SecurityValidator.validate_safe_path(file_path) + + if not file_path.exists(): + raise DeforumException(f"File to backup does not exist: {file_path}") + + if backup_dir is None: + backup_dir = file_path.parent / "backups" + else: + backup_dir = SecurityValidator.validate_safe_path(backup_dir) + + FileUtils.ensure_directory(backup_dir) + + # Create unique backup filename with timestamp + import datetime + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + backup_name = f"{file_path.stem}_{timestamp}{file_path.suffix}" + backup_path = backup_dir / backup_name + + shutil.copy2(file_path, backup_path) + return backup_path + + @staticmethod + def find_files(directory: Union[str, Path], pattern: str = "*", recursive: bool = False) -> List[Path]: + """Find files with security validation.""" + directory = SecurityValidator.validate_safe_path(directory) + + if not directory.exists(): + return [] + + if recursive: + return list(directory.rglob(pattern)) + else: + return list(directory.glob(pattern)) + + +# Export classes +__all__ = ["FileUtils", "SecurityValidator"] diff --git a/src/deforum/utils/tensor_utils.py b/src/deforum/utils/tensor_utils.py new file mode 100644 index 0000000..ef78c57 --- /dev/null +++ b/src/deforum/utils/tensor_utils.py @@ -0,0 +1,396 @@ +""" +Tensor processing utilities for Deforum Flux + +This module provides utilities for tensor operations, conversions, and processing +commonly used in the Flux-Deforum pipeline. +""" + +import torch +import torch.nn.functional as F +import numpy as np +from typing import Tuple, Optional, Union, List +from PIL import Image + +from deforum.core.exceptions import TensorProcessingError +from deforum.core.logging_config import get_logger + + +class TensorUtils: + """Utility class for tensor operations and conversions.""" + + def __init__(self): + """Initialize tensor utilities.""" + self.logger = get_logger(__name__) + + @staticmethod + def validate_tensor_shape( + tensor: torch.Tensor, + expected_shape: Optional[Tuple[int, ...]] = None, + expected_dims: Optional[int] = None, + name: str = "tensor" + ) -> None: + """ + Validate tensor shape and dimensions. + + Args: + tensor: Tensor to validate + expected_shape: Expected exact shape (optional) + expected_dims: Expected number of dimensions (optional) + name: Name of tensor for error messages + + Raises: + TensorProcessingError: If validation fails + """ + if not isinstance(tensor, torch.Tensor): + raise TensorProcessingError(f"{name} must be a torch.Tensor, got {type(tensor)}") + + if expected_dims is not None and tensor.ndim != expected_dims: + raise TensorProcessingError( + f"{name} must have {expected_dims} dimensions, got {tensor.ndim}", + tensor_shape=tensor.shape + ) + + if expected_shape is not None: + if tensor.shape != expected_shape: + raise TensorProcessingError( + f"{name} shape mismatch", + tensor_shape=tensor.shape, + expected_shape=expected_shape + ) + + @staticmethod + def tensor_to_numpy(tensor: torch.Tensor, normalize: bool = True) -> np.ndarray: + """ + Convert tensor to numpy array with proper scaling. + + Args: + tensor: Input tensor + normalize: Whether to normalize to [0, 255] range + + Returns: + Numpy array + """ + # Move to CPU and convert to float32 + # Optimized tensor conversion (CRITICAL PERFORMANCE FIX) + # Minimize intermediate allocations by checking current state first + if tensor.is_cuda: + # Only move to CPU if not already there + tensor_cpu = tensor.detach().cpu() + else: + tensor_cpu = tensor.detach() + + # Only convert to float if not already float32/float64 + if tensor_cpu.dtype not in (torch.float32, torch.float64): + tensor_cpu = tensor_cpu.float() + + # Convert to numpy with minimal memory footprint + array = tensor_cpu.numpy() + + # Handle batch dimension + if array.ndim == 4 and array.shape[0] == 1: + array = array[0] + + # Transpose from CHW to HWC if needed + if array.ndim == 3 and array.shape[0] in [1, 3, 4]: + array = np.transpose(array, (1, 2, 0)) + + if normalize: + # Clip and scale to [0, 255] + array = np.clip(array, 0, 1) + array = (array * 255).astype(np.uint8) + + return array + + @staticmethod + def numpy_to_tensor( + array: np.ndarray, + device: str = "cpu", + normalize: bool = True + ) -> torch.Tensor: + """ + Convert numpy array to tensor. + + Args: + array: Input numpy array + device: Target device + normalize: Whether to normalize from [0, 255] to [0, 1] + + Returns: + Tensor + """ + if normalize and array.dtype == np.uint8: + array = array.astype(np.float32) / 255.0 + + # Convert to tensor + tensor = torch.from_numpy(array).to(device) + + # Handle channel dimension + if tensor.ndim == 3: # HWC -> CHW + tensor = tensor.permute(2, 0, 1) + + # Add batch dimension if needed + if tensor.ndim == 3: + tensor = tensor.unsqueeze(0) + + return tensor + + @staticmethod + def pil_to_tensor(image: Image.Image, device: str = "cpu") -> torch.Tensor: + """ + Convert PIL Image to tensor. + + Args: + image: PIL Image + device: Target device + + Returns: + Tensor in format (1, C, H, W) + """ + # Convert to numpy + array = np.array(image) + + # Handle grayscale + if array.ndim == 2: + array = array[:, :, np.newaxis] + + return TensorUtils.numpy_to_tensor(array, device) + + @staticmethod + def tensor_to_pil(tensor: torch.Tensor) -> Image.Image: + """ + Convert tensor to PIL Image. + + Args: + tensor: Input tensor + + Returns: + PIL Image + """ + array = TensorUtils.tensor_to_numpy(tensor, normalize=True) + + if array.ndim == 3 and array.shape[2] == 1: + array = array.squeeze(2) + + return Image.fromarray(array) + + @staticmethod + def resize_tensor( + tensor: torch.Tensor, + size: Tuple[int, int], + mode: str = "bilinear", + align_corners: bool = False + ) -> torch.Tensor: + """ + Resize tensor using interpolation. + + Args: + tensor: Input tensor (B, C, H, W) + size: Target size (height, width) + mode: Interpolation mode + align_corners: Whether to align corners + + Returns: + Resized tensor + """ + TensorUtils.validate_tensor_shape(tensor, expected_dims=4, name="input tensor") + + return F.interpolate( + tensor, + size=size, + mode=mode, + align_corners=align_corners + ) + + @staticmethod + def apply_geometric_transform( + tensor: torch.Tensor, + zoom: float = 1.0, + angle: float = 0.0, + translation_x: float = 0.0, + translation_y: float = 0.0, + mode: str = "bilinear", + padding_mode: str = "reflection" + ) -> torch.Tensor: + """ + Apply geometric transformation to tensor. + + Args: + tensor: Input tensor (B, C, H, W) + zoom: Zoom factor + angle: Rotation angle in degrees + translation_x: X translation in pixels + translation_y: Y translation in pixels + mode: Interpolation mode + padding_mode: Padding mode + + Returns: + Transformed tensor + """ + TensorUtils.validate_tensor_shape(tensor, expected_dims=4, name="input tensor") + + batch_size, channels, height, width = tensor.shape + device = tensor.device + + # Convert angle to radians + angle_rad = torch.tensor(angle * np.pi / 180.0, device=device) + cos_angle = torch.cos(angle_rad) + sin_angle = torch.sin(angle_rad) + + # Create transformation matrix + # [zoom*cos, -zoom*sin, tx] + # [zoom*sin, zoom*cos, ty] + theta = torch.tensor([ + [zoom * cos_angle, -zoom * sin_angle, translation_x / width * 2], + [zoom * sin_angle, zoom * cos_angle, translation_y / height * 2] + ], device=device, dtype=tensor.dtype).unsqueeze(0).repeat(batch_size, 1, 1) + + # Create sampling grid + grid = F.affine_grid(theta, tensor.size(), align_corners=False) + + # Apply transformation + transformed = F.grid_sample( + tensor, grid, + mode=mode, + padding_mode=padding_mode, + align_corners=False + ) + + return transformed + + @staticmethod + def blend_tensors( + tensor1: torch.Tensor, + tensor2: torch.Tensor, + alpha: float + ) -> torch.Tensor: + """ + Blend two tensors with alpha blending. + + Args: + tensor1: First tensor + tensor2: Second tensor + alpha: Blending factor (0.0 = tensor1, 1.0 = tensor2) + + Returns: + Blended tensor + """ + if tensor1.shape != tensor2.shape: + raise TensorProcessingError( + "Tensors must have the same shape for blending", + tensor_shape=tensor1.shape, + expected_shape=tensor2.shape + ) + + return (1 - alpha) * tensor1 + alpha * tensor2 + + @staticmethod + def get_tensor_stats(tensor: torch.Tensor) -> dict: + """ + Get statistical information about a tensor. + + Args: + tensor: Input tensor + + Returns: + Dictionary with statistics + """ + with torch.no_grad(): + stats = { + "shape": tuple(tensor.shape), + "dtype": str(tensor.dtype), + "device": str(tensor.device), + "mean": tensor.mean().item(), + "std": tensor.std().item(), + "min": tensor.min().item(), + "max": tensor.max().item(), + "memory_mb": tensor.numel() * tensor.element_size() / 1024 / 1024 + } + + # Check for problematic values + stats["has_nan"] = torch.isnan(tensor).any().item() + stats["has_inf"] = torch.isinf(tensor).any().item() + + return stats + + @staticmethod + def normalize_tensor( + tensor: torch.Tensor, + method: str = "minmax", + dim: Optional[Union[int, Tuple[int, ...]]] = None + ) -> torch.Tensor: + """ + Normalize tensor values. + + Args: + tensor: Input tensor + method: Normalization method ('minmax', 'zscore', 'unit') + dim: Dimensions to normalize over + + Returns: + Normalized tensor + """ + if method == "minmax": + if dim is None: + min_val = tensor.min() + max_val = tensor.max() + else: + min_val = tensor.min(dim=dim, keepdim=True)[0] + max_val = tensor.max(dim=dim, keepdim=True)[0] + + return (tensor - min_val) / (max_val - min_val + 1e-8) + + elif method == "zscore": + if dim is None: + mean = tensor.mean() + std = tensor.std() + else: + mean = tensor.mean(dim=dim, keepdim=True) + std = tensor.std(dim=dim, keepdim=True) + + return (tensor - mean) / (std + 1e-8) + + elif method == "unit": + if dim is None: + norm = tensor.norm() + else: + norm = tensor.norm(dim=dim, keepdim=True) + + return tensor / (norm + 1e-8) + + else: + raise ValueError(f"Unknown normalization method: {method}") + + @staticmethod + def safe_tensor_operation(func, *tensors, **kwargs): + """ + Safely perform tensor operations with error handling. + + Args: + func: Function to apply + *tensors: Input tensors + **kwargs: Additional arguments + + Returns: + Result of the operation + + Raises: + TensorProcessingError: If operation fails + """ + try: + return func(*tensors, **kwargs) + except RuntimeError as e: + if "CUDA out of memory" in str(e): + raise TensorProcessingError( + "CUDA out of memory during tensor operation", + operation=func.__name__ + ) + else: + raise TensorProcessingError( + f"Tensor operation failed: {e}", + operation=func.__name__ + ) + except Exception as e: + raise TensorProcessingError( + f"Unexpected error in tensor operation: {e}", + operation=func.__name__ + ) \ No newline at end of file diff --git a/src/deforum/utils/validation.py b/src/deforum/utils/validation.py new file mode 100644 index 0000000..0ff86a2 --- /dev/null +++ b/src/deforum/utils/validation.py @@ -0,0 +1,296 @@ +""" +Input validation utilities for Deforum Flux + +This module provides comprehensive input validation using the hybrid validation +approach with centralized rules and reusable utilities from the config module. +""" + +import re +import os +from typing import Any, List, Dict, Optional, Union +from pathlib import Path + +from deforum.core.exceptions import ValidationError +from deforum.core.logging_config import get_logger +from deforum.config.validation_rules import ValidationRules +from deforum.config.validation_utils import ValidationUtils, DomainValidators + + +class InputValidator: + """Utility class for validating various types of inputs using hybrid validation approach.""" + + def __init__(self, max_prompt_length: Optional[int] = None): + """ + Initialize input validator. + + Args: + max_prompt_length: Maximum allowed prompt length (uses ValidationRules default if None) + """ + self.max_prompt_length = max_prompt_length or ValidationRules.PROMPT["max_length"] + self.logger = get_logger(__name__) + + def validate_prompt(self, prompt: str) -> None: + """ + Validate text prompt using centralized rules. + + Args: + prompt: Text prompt to validate + + Raises: + ValidationError: If prompt is invalid + """ + errors = ValidationUtils.validate_string_not_empty(prompt, "prompt", self.max_prompt_length) + + # Check for potentially problematic characters + if re.search(r'[<>{}]', prompt): + self.logger.warning("Prompt contains potentially problematic characters: < > { }") + + if errors: + raise ValidationError("Prompt validation failed", validation_errors=errors) + + def validate_dimensions(self, width: int, height: int) -> None: + """ + Validate image dimensions using domain validators. + + Args: + width: Image width + height: Image height + + Raises: + ValidationError: If dimensions are invalid + """ + errors = DomainValidators.validate_dimensions(width, height) + + if errors: + raise ValidationError("Dimension validation failed", validation_errors=errors) + + def validate_generation_params( + self, + steps: int, + guidance_scale: float, + seed: Optional[int] = None + ) -> None: + """ + Validate generation parameters using domain validators. + + Args: + steps: Number of generation steps + guidance_scale: Guidance scale value + seed: Random seed (optional) + + Raises: + ValidationError: If parameters are invalid + """ + errors = DomainValidators.validate_generation_params(steps, guidance_scale, seed) + + if errors: + raise ValidationError("Generation parameter validation failed", validation_errors=errors) + + def validate_motion_params(self, motion_params: Dict[str, float]) -> None: + """ + Validate motion parameters using domain validators. + + Args: + motion_params: Dictionary of motion parameters + + Raises: + ValidationError: If parameters are invalid + """ + errors = DomainValidators.validate_motion_params(motion_params) + + if errors: + raise ValidationError("Motion parameter validation failed", validation_errors=errors) + + def validate_file_path( + self, + file_path: str, + must_exist: bool = True, + allowed_extensions: Optional[List[str]] = None + ) -> None: + """ + Validate file path using centralized utilities. + + Args: + file_path: Path to validate + must_exist: Whether file must exist + allowed_extensions: List of allowed file extensions + + Raises: + ValidationError: If path is invalid + """ + errors = ValidationUtils.validate_file_path(file_path, must_exist, allowed_extensions) + + if errors: + raise ValidationError("File path validation failed", validation_errors=errors) + + def validate_animation_config(self, config: Dict[str, Any]) -> None: + """ + Validate complete animation configuration using centralized approach. + + Args: + config: Animation configuration dictionary + + Raises: + ValidationError: If configuration is invalid + """ + errors = [] + + # Required fields validation + required_fields = ["prompt", "max_frames"] + for field in required_fields: + if field not in config: + errors.append(f"Missing required field: {field}") + + # Validate prompt + if "prompt" in config: + prompt_errors = ValidationUtils.validate_string_not_empty( + config["prompt"], "prompt", self.max_prompt_length + ) + errors.extend(prompt_errors) + + # Validate dimensions if present + if "width" in config and "height" in config: + dimension_errors = DomainValidators.validate_dimensions( + config["width"], config["height"] + ) + errors.extend(dimension_errors) + + # Validate generation parameters if present + if all(key in config for key in ["steps", "guidance_scale"]): + gen_errors = DomainValidators.validate_generation_params( + config["steps"], + config["guidance_scale"], + config.get("seed") + ) + errors.extend(gen_errors) + + # Validate animation settings + if "max_frames" in config and "fps" in config: + anim_errors = DomainValidators.validate_animation_settings( + config["max_frames"], config["fps"] + ) + errors.extend(anim_errors) + elif "max_frames" in config: + # Validate just max_frames + min_frames, max_frames_limit = ValidationRules.MAX_FRAMES["min"], ValidationRules.MAX_FRAMES["max"] + frame_errors = ValidationUtils.validate_range( + config["max_frames"], min_frames, max_frames_limit, "max_frames", int + ) + errors.extend(frame_errors) + + # Validate motion schedule + if "motion_schedule" in config: + motion_schedule = config["motion_schedule"] + if not isinstance(motion_schedule, dict): + errors.append("motion_schedule must be a dictionary") + else: + for frame, motion_params in motion_schedule.items(): + frame_errors = ValidationUtils.validate_frame_number(frame, "frame") + errors.extend(frame_errors) + + if isinstance(motion_params, dict): + motion_errors = DomainValidators.validate_motion_params(motion_params) + errors.extend(motion_errors) + else: + errors.append(f"Motion parameters for frame {frame} must be a dictionary") + + if errors: + raise ValidationError("Animation configuration validation failed", validation_errors=errors) + + def sanitize_filename(self, filename: str) -> str: + """ + Sanitize filename for safe file system usage using centralized utilities. + + Args: + filename: Original filename + + Returns: + Sanitized filename + """ + return ValidationUtils.sanitize_filename(filename) + + def validate_device_string(self, device: str) -> None: + """ + Validate device string using centralized rules. + + Args: + device: Device string to validate + + Raises: + ValidationError: If device string is invalid + """ + if not ValidationRules.is_valid_device(device): + raise ValidationError(f"Invalid device: {device}. Valid devices: {ValidationRules.VALID_DEVICES}") + + def validate_batch_size(self, batch_size: int, max_batch_size: Optional[int] = None) -> None: + """ + Validate batch size using centralized rules. + + Args: + batch_size: Batch size to validate + max_batch_size: Maximum allowed batch size (uses ValidationRules default if None) + + Raises: + ValidationError: If batch size is invalid + """ + max_batch = max_batch_size or ValidationRules.BATCH_SIZE["max"] + min_batch = ValidationRules.BATCH_SIZE["min"] + + errors = ValidationUtils.validate_range(batch_size, min_batch, max_batch, "batch_size", int) + + if errors: + raise ValidationError("Batch size validation failed", validation_errors=errors) + + def validate_model_name(self, model_name: str) -> None: + """ + Validate model name using centralized rules. + + Args: + model_name: Model name to validate + + Raises: + ValidationError: If model name is invalid + """ + if not ValidationRules.is_valid_model(model_name): + raise ValidationError(f"Invalid model: {model_name}. Valid models: {ValidationRules.VALID_MODELS}") + + def validate_animation_mode(self, animation_mode: str) -> None: + """ + Validate animation mode using centralized rules. + + Args: + animation_mode: Animation mode to validate + + Raises: + ValidationError: If animation mode is invalid + """ + if not ValidationRules.is_valid_animation_mode(animation_mode): + raise ValidationError(f"Invalid animation mode: {animation_mode}. Valid modes: {ValidationRules.VALID_ANIMATION_MODES}") + + def validate_log_level(self, log_level: str) -> None: + """ + Validate log level using centralized rules. + + Args: + log_level: Log level to validate + + Raises: + ValidationError: If log level is invalid + """ + if not ValidationRules.is_valid_log_level(log_level): + raise ValidationError(f"Invalid log level: {log_level}. Valid levels: {ValidationRules.VALID_LOG_LEVELS}") + + def validate_image_file(self, file_path: str, must_exist: bool = True) -> None: + """ + Validate image file path using centralized rules. + + Args: + file_path: Path to image file + must_exist: Whether file must exist + + Raises: + ValidationError: If file is invalid + """ + self.validate_file_path(file_path, must_exist, ValidationRules.ALLOWED_IMAGE_EXTENSIONS) + +