diff --git a/.gitignore b/.gitignore
index 10b9e7ecde..ebb18dec3c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,9 @@
outputs/
assets/
+# logs
+**/logs/*
+
# setup
data_juicer.egg-info/
py_data_juicer.egg-info/
@@ -16,6 +19,7 @@ wandb/
__pycache__
.vscode/
.ipynb_checkpoints/
+performance_test_results*.json
# label studio related
label_studio_data/
@@ -31,3 +35,6 @@ tests/ops/data/*dup*
tests/tools/tmp_*/
tests/ops/deduplicator/chinese_dedup/
tests/ops/deduplicator/english_dedup/
+
+# perf bench data
+perf_bench_data
diff --git a/configs/config_all.yaml b/configs/config_all.yaml
index fb00e16800..2ff7e1a49c 100644
--- a/configs/config_all.yaml
+++ b/configs/config_all.yaml
@@ -68,6 +68,11 @@ eoc_special_token: '<|__dj__eoc|>' # the special token
executor_type: default # type of executor, support "default" or "ray" for now.
ray_address: auto # the address of the Ray cluster.
+# Core optimizer configuration
+enable_optimizer: false # enable/disable core optimizer
+optimizer_strategies: ['op_reorder'] # list of optimization strategies to apply
+ # available strategies: op_reorder, filter_fusion, mapper_fusion
+
# only for data analysis
percentiles: [0.25, 0.5, 0.75] # percentiles to analyze the dataset distribution
export_original_dataset: false # whether to export the original dataset with stats. If you only need the stats of the dataset, setting it to false could speed up the exporting.
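For context, a minimal sketch of how a pipeline could act on these two new keys once the config is loaded (the loader and print-out below are illustrative, not part of this change):

```python
# Hypothetical sketch: gating the core optimizer on the new config keys.
import yaml

with open("configs/config_all.yaml") as f:
    cfg = yaml.safe_load(f)

if cfg.get("enable_optimizer", False):
    # 'op_reorder', 'filter_fusion', 'mapper_fusion' per the comment above
    for strategy in cfg.get("optimizer_strategies", []):
        print(f"would apply optimization strategy: {strategy}")
```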
diff --git a/configs/config_min.yaml b/configs/config_min.yaml
index ad85f7a538..5afd876bc4 100644
--- a/configs/config_min.yaml
+++ b/configs/config_min.yaml
@@ -11,3 +11,8 @@ executor_type: default # type of executor,
ray_address: auto # the address of the Ray cluster.
suffixes: null
add_suffix: false
+
+# Core optimizer configuration
+enable_optimizer: false # enable/disable core optimizer
+optimizer_strategies: ['op_reorder'] # list of optimization strategies to apply
+ # available strategies: op_reorder, filter_fusion, mapper_fusion
diff --git a/configs/demo/fused_operations_demo.yaml b/configs/demo/fused_operations_demo.yaml
new file mode 100644
index 0000000000..77a7a02f58
--- /dev/null
+++ b/configs/demo/fused_operations_demo.yaml
@@ -0,0 +1,92 @@
+# Fused Operations Demo Configuration
+# This config demonstrates how to use fused operations for optimal performance
+
+project_name: 'fused_operations_demo'
+dataset_path: 'path/to/your/dataset.jsonl' # Replace with your dataset path
+export_path: 'output/fused_processed_dataset.jsonl'
+export_shard_size: 0
+export_in_parallel: false
+np: 4
+text_keys: 'text'
+suffixes: []
+turbo: false
+skip_op_error: true
+use_cache: true
+ds_cache_dir: null
+open_monitor: true
+use_checkpoint: false
+temp_dir: null
+open_tracer: false
+op_list_to_trace: []
+trace_num: 10
+
+# Enable fused operations for optimal performance
+op_fusion: true
+fusion_strategy: 'probe' # Use probe strategy for optimal ordering
+cache_compress: null
+keep_stats_in_res_ds: false
+keep_hashes_in_res_ds: false
+adaptive_batch_size: false
+
+# For multimodal data processing
+image_key: 'images'
+image_special_token: '<__dj__image>'
+audio_key: 'audios'
+audio_special_token: '<__dj__audio>'
+video_key: 'videos'
+video_special_token: '<__dj__video>'
+eoc_special_token: '<|__dj__eoc|>'
+
+# Executor configuration
+executor_type: default
+ray_address: auto
+
+# Process pipeline with operations that will be automatically fused
+process:
+ # Phase 1: Text cleaning mappers (these run first)
+ - clean_html_mapper: {} # Remove HTML tags
+ - clean_links_mapper: {} # Remove URLs
+ - clean_email_mapper: {} # Remove email addresses
+ - clean_copyright_mapper: {} # Remove copyright notices
+
+ # Phase 2: Text quality filters (these will be fused automatically)
+ # Basic text characteristics
+ - text_length_filter: # Filter by text length
+ min_len: 50
+ max_len: 2000
+ - words_num_filter: # Filter by word count
+ min_num: 10
+ max_num: 500
+ - character_repetition_filter: # Filter repetitive characters
+      max_ratio: 0.8
+ - word_repetition_filter: # Filter repetitive words
+ min_ratio: 0.0
+ max_ratio: 0.5
+ - special_characters_filter: # Filter special character ratio
+ min_ratio: 0.0
+ max_ratio: 0.3
+ - alphanumeric_filter: # Filter alphanumeric ratio
+ min_ratio: 0.3
+ - average_line_length_filter: # Filter by average line length
+ min_len: 10
+ max_len: 100
+ - maximum_line_length_filter: # Filter by maximum line length
+ min_len: 10
+ max_len: 200
+
+ # Phase 3: Content quality filters (these will also be fused)
+ - perplexity_filter: # Filter by language model perplexity
+ max_ppl: 1500
+ - stopwords_filter: # Filter by stopword ratio
+ min_ratio: 0.1
+ - flagged_words_filter: # Filter by flagged word ratio
+ max_ratio: 0.05
+ - language_id_score_filter: # Filter by language confidence
+ lang: 'en'
+ min_score: 0.5
+ max_score: 1.0
+
+ # Phase 4: Text transformation mappers (these run after filtering)
+ - expand_macro_mapper: {} # Expand LaTeX macros
+ - chinese_convert_mapper: # Convert Chinese text
+ mode: 's2t' # Simplified to Traditional
diff --git a/configs/optimization/op_reorder_showcase.yaml b/configs/optimization/op_reorder_showcase.yaml
new file mode 100644
index 0000000000..393e813a8f
--- /dev/null
+++ b/configs/optimization/op_reorder_showcase.yaml
@@ -0,0 +1,53 @@
+# Configuration to showcase operation reordering optimization
+# This config has a suboptimal order that should be reordered by the optimizer
+# GOAL: Show dramatic performance difference by putting expensive operations first
+
+project_name: 'op-reorder-showcase'
+dataset_path: 'perf_bench_data/text/wiki-10k.jsonl'
+export_path: 'outputs/op_reorder_showcase/res.jsonl'
+np: 4
+use_cache: false
+
+process:
+ # VERY EXPENSIVE OPERATIONS (should be moved after filtering)
+ # These are resource-intensive operations that waste computation on filtered data
+ - text_chunk_mapper:
+ chunk_size: 500 # Smaller chunks = more processing
+ text_key: 'text'
+ mem_required: '2GB'
+
+ - text_entity_dependency_filter:
+ min_score: 0.9 # Very strict filtering
+ text_key: 'text'
+ mem_required: '3GB'
+
+ - text_pair_similarity_filter:
+ min_score: 0.8
+ text_key: 'text'
+ mem_required: '2GB'
+
+ # LIGHT FILTERS (should be moved to front)
+ # These are fast filters that should run early to reduce data volume
+ - text_length_filter:
+ min_len: 50 # Less restrictive to keep more data
+ max_len: 5000
+ text_key: 'text'
+
+ - text_action_filter:
+ action_types: ['question', 'command', 'statement'] # Keep all types
+ text_key: 'text'
+
+ # DEPENDENCY CHAIN (must stay in order)
+ # language_id must come before perplexity
+ - language_id_score_filter:
+ lang: 'en'
+ min_score: 0.5 # Much less strict to keep more data
+ text_key: 'text'
+
+ - perplexity_filter:
+ lang: 'en'
+ min_score: 0.1 # Much less strict to keep more data
+ text_key: 'text'
+
+    # NOTE: text_pair_similarity_filter above stands in for text_embd_similarity_filter,
+    # so no additional expensive operations are listed here.
diff --git a/data_juicer/benchmark/__init__.py b/data_juicer/benchmark/__init__.py
new file mode 100644
index 0000000000..53bf1226d9
--- /dev/null
+++ b/data_juicer/benchmark/__init__.py
@@ -0,0 +1,32 @@
+"""
+Data-Juicer Performance Benchmark Framework
+
+A comprehensive framework for A/B testing optimization strategies
+across different workloads, modalities, and operation complexities.
+"""
+
+from .core.benchmark_runner import BenchmarkConfig, BenchmarkRunner
+from .core.metrics_collector import MetricsCollector
+from .core.report_generator import ReportGenerator
+from .core.result_analyzer import ResultAnalyzer
+from .strategies.ab_test import ABTestConfig, StrategyABTest
+from .strategies.strategy_library import STRATEGY_LIBRARY, OptimizationStrategy
+from .utils.config_manager import ConfigManager
+from .workloads.workload_suite import WORKLOAD_SUITE, WorkloadDefinition, WorkloadSuite
+
+__version__ = "1.0.0"
+__all__ = [
+ "BenchmarkRunner",
+ "BenchmarkConfig",
+ "MetricsCollector",
+ "ResultAnalyzer",
+ "ReportGenerator",
+ "OptimizationStrategy",
+ "STRATEGY_LIBRARY",
+ "StrategyABTest",
+ "ABTestConfig",
+ "WorkloadSuite",
+ "WorkloadDefinition",
+ "WORKLOAD_SUITE",
+ "ConfigManager",
+]
diff --git a/data_juicer/benchmark/core/__init__.py b/data_juicer/benchmark/core/__init__.py
new file mode 100644
index 0000000000..860c40e2da
--- /dev/null
+++ b/data_juicer/benchmark/core/__init__.py
@@ -0,0 +1,15 @@
+"""Core benchmark framework components."""
+
+from .benchmark_runner import BenchmarkRunner
+from .metrics_collector import BenchmarkMetrics, MetricsCollector
+from .report_generator import ReportGenerator
+from .result_analyzer import ComparisonResult, ResultAnalyzer
+
+__all__ = [
+ "BenchmarkRunner",
+ "MetricsCollector",
+ "BenchmarkMetrics",
+ "ResultAnalyzer",
+ "ComparisonResult",
+ "ReportGenerator",
+]
diff --git a/data_juicer/benchmark/core/benchmark_runner.py b/data_juicer/benchmark/core/benchmark_runner.py
new file mode 100644
index 0000000000..0282bbea6d
--- /dev/null
+++ b/data_juicer/benchmark/core/benchmark_runner.py
@@ -0,0 +1,480 @@
+#!/usr/bin/env python3
+"""
+Main benchmark runner for executing performance tests.
+"""
+
+import hashlib
+import os
+import shutil
+import subprocess
+import time
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from loguru import logger
+
+from ..utils.config_manager import ConfigManager
+from .metrics_collector import BenchmarkMetrics, MetricsCollector
+
+
+@dataclass
+class BenchmarkConfig:
+ """Configuration for a benchmark run."""
+
+ dataset_path: str
+ config_path: str
+ output_dir: str
+ iterations: int = 3
+ warmup_runs: int = 1
+ timeout_seconds: int = 3600
+ strategy_name: str = "baseline"
+    strategy_config: Optional[Dict[str, Any]] = None
+ sample_ratio: float = 1.0
+ sample_method: str = "random"
+
+
+class BenchmarkRunner:
+ """Main benchmark execution engine."""
+
+ def __init__(self, config: BenchmarkConfig):
+ self.config = config
+ self.metrics_collector = MetricsCollector()
+ self.config_manager = ConfigManager()
+
+ # Ensure output directory exists
+ Path(self.config.output_dir).mkdir(parents=True, exist_ok=True)
+
+ def run_benchmark(self) -> BenchmarkMetrics:
+ """Run a single benchmark iteration."""
+ logger.info(f"Starting benchmark: {self.config.strategy_name}")
+
+ # Start metrics collection
+ self.metrics_collector.start_monitoring()
+
+ try:
+ # Prepare configuration
+ config_file = self._prepare_config()
+
+ # Run the benchmark
+ start_time = time.time()
+ result = self._execute_benchmark(config_file)
+ end_time = time.time()
+
+ # Stop metrics collection
+ metrics = self.metrics_collector.stop_monitoring()
+
+ # Enhance metrics with benchmark-specific data
+ if metrics:
+ metrics.total_wall_time = end_time - start_time
+ metrics.strategy_name = self.config.strategy_name
+ metrics.config_hash = self._get_config_hash()
+
+ # Add benchmark-specific metrics
+ if result:
+ metrics.samples_processed = result.get("samples_processed", 0)
+ metrics.samples_retained = result.get("samples_retained", 0)
+ metrics.samples_per_second = (
+ metrics.samples_processed / metrics.total_wall_time if metrics.total_wall_time > 0 else 0
+ )
+ metrics.data_retention_rate = (
+ metrics.samples_retained / metrics.samples_processed if metrics.samples_processed > 0 else 0
+ )
+
+ return metrics
+
+ except Exception as e:
+ logger.error(f"Benchmark failed: {e}")
+ # Still try to get partial metrics
+ metrics = self.metrics_collector.stop_monitoring()
+ if metrics:
+ metrics.strategy_name = self.config.strategy_name
+ metrics.config_hash = self._get_config_hash()
+ return metrics
+
+ def run_benchmark_suite(self) -> List[BenchmarkMetrics]:
+ """Run multiple iterations of the benchmark."""
+ logger.info(f"Running benchmark suite: {self.config.iterations} iterations")
+
+ all_metrics = []
+
+ # Warmup runs (not counted in results)
+ for i in range(self.config.warmup_runs):
+ logger.info(f"Warmup run {i+1}/{self.config.warmup_runs}")
+ self.run_benchmark()
+
+ # Actual benchmark runs
+ for i in range(self.config.iterations):
+ logger.info(f"Benchmark iteration {i+1}/{self.config.iterations}")
+ metrics = self.run_benchmark()
+ if metrics:
+ all_metrics.append(metrics)
+
+ return all_metrics
+
+ def _prepare_config(self) -> str:
+ """Prepare configuration file for the benchmark."""
+ # Load base configuration
+ base_config = self.config_manager.load_config(self.config.config_path)
+
+ # Apply strategy-specific modifications
+ if self.config.strategy_config:
+ # Check if this is a core optimizer strategy
+ if self.config.strategy_config.get("_benchmark_optimizer_enabled"):
+ # For core optimizer strategies, we use environment variables instead of config keys
+ # This avoids config validation issues
+ logger.info("๐ง Core optimizer strategy detected - will use environment variables")
+ else:
+ # For regular config strategies, apply them directly
+ base_config = self.config_manager.apply_strategy_config(base_config, self.config.strategy_config)
+
+ # Apply sampling if needed
+ if self.config.sample_ratio < 1.0:
+ base_config = self._apply_sampling_config(base_config)
+
+ # Save modified configuration
+ config_output_path = os.path.join(self.config.output_dir, f"config_{self.config.strategy_name}.yaml")
+ self.config_manager.save_config(base_config, config_output_path)
+
+ return config_output_path
+
+ def _apply_sampling_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
+ """Apply sampling configuration to the config."""
+ # If sampling is enabled, we'll need to create a sampled dataset
+ if self.config.sample_ratio < 1.0:
+ sampled_dataset_path = self._create_sampled_dataset()
+ config["dataset_path"] = sampled_dataset_path
+
+ return config
+
+ def _create_sampled_dataset(self) -> str:
+ """Create a sampled version of the dataset."""
+ import random
+
+ # Read the original dataset
+ with open(self.config.dataset_path, "r", encoding="utf-8") as f:
+ lines = f.readlines()
+
+ total_samples = len(lines)
+ sample_size = int(total_samples * self.config.sample_ratio)
+
+ # Sample the data based on method
+ if self.config.sample_method == "random":
+ sampled_lines = random.sample(lines, sample_size)
+ elif self.config.sample_method == "first":
+ sampled_lines = lines[:sample_size]
+ elif self.config.sample_method == "last":
+ sampled_lines = lines[-sample_size:]
+ else:
+ raise ValueError(f"Unknown sampling method: {self.config.sample_method}")
+
+ # Create sampled dataset file
+ sampled_path = os.path.join(
+ self.config.output_dir, f"sampled_dataset_{self.config.sample_ratio}_{self.config.sample_method}.jsonl"
+ )
+
+ with open(sampled_path, "w", encoding="utf-8") as f:
+ f.writelines(sampled_lines)
+
+ logger.info(f"Created sampled dataset: {sampled_path} ({sample_size}/{total_samples} samples)")
+ return sampled_path
+
+ def _execute_benchmark(self, config_file: str) -> Optional[Dict[str, Any]]:
+ """Execute the actual benchmark using data-juicer."""
+ try:
+ # Clean up output directory before running benchmark
+ self._cleanup_output_directory()
+
+ # Check if core optimizer should be enabled
+ if self.config.strategy_config and self.config.strategy_config.get("_benchmark_optimizer_enabled"):
+ # Use custom executor that applies core optimizer
+ return self._execute_with_core_optimizer(config_file)
+ else:
+ # Use standard subprocess execution
+ return self._execute_standard_benchmark(config_file)
+
+ except subprocess.TimeoutExpired:
+ logger.error(f"Benchmark timed out after {self.config.timeout_seconds} seconds")
+ return None
+ except Exception as e:
+ logger.error(f"Error executing benchmark: {e}")
+ return None
+
+ def _execute_standard_benchmark(self, config_file: str) -> Optional[Dict[str, Any]]:
+ """Execute benchmark using standard subprocess."""
+ # Build command
+ cmd = [
+ "python",
+ "-m",
+ "data_juicer.tools.process_data",
+ "--config",
+ config_file,
+ "--export_path",
+ os.path.join(self.config.output_dir, "output.jsonl"),
+ ]
+
+ # Only add dataset_path if it's provided (config might have it instead)
+ if self.config.dataset_path:
+ cmd.extend(["--dataset_path", self.config.dataset_path])
+
+ logger.debug(f"Executing command: {' '.join(cmd)}")
+
+ # Run the benchmark
+ result = subprocess.run(
+ cmd, capture_output=True, text=True, timeout=self.config.timeout_seconds, cwd=os.getcwd()
+ )
+
+ if result.returncode != 0:
+ logger.error(f"Benchmark execution failed: {result.stderr}")
+ return None
+
+ # Log the subprocess output for debugging
+ logger.info("=== Subprocess STDOUT ===")
+ logger.info(result.stdout)
+ logger.info("=== Subprocess STDERR ===")
+ logger.info(result.stderr)
+ logger.info("=== End Subprocess Output ===")
+
+        # Parse output for metrics (data-juicer logs to stderr, not stdout, so pass both
+        # streams in case file-based metrics are unavailable and text parsing is needed)
+        return self._parse_benchmark_output(result.stdout + result.stderr)
+
+ def _execute_with_core_optimizer(self, config_file: str) -> Optional[Dict[str, Any]]:
+ """Execute benchmark with core optimizer applied."""
+ try:
+ # Get the enabled optimizer strategies
+ enabled_strategies = self.config.strategy_config.get("_benchmark_optimizer_strategies", [])
+ logger.info(f"๐ง Applying core optimizer with strategies: {enabled_strategies}")
+
+ # Use environment variables to pass optimizer information
+ env = os.environ.copy()
+ env["DJ_ENABLE_CORE_OPTIMIZER"] = "true"
+ env["DJ_OPTIMIZER_STRATEGIES"] = ",".join(enabled_strategies)
+
+ # Build command with the original config
+ cmd = [
+ "python",
+ "-m",
+ "data_juicer.tools.process_data",
+ "--config",
+ config_file,
+ "--export_path",
+ os.path.join(self.config.output_dir, "output.jsonl"),
+ ]
+
+ # Only add dataset_path if it's provided (config might have it instead)
+ if self.config.dataset_path:
+ cmd.extend(["--dataset_path", self.config.dataset_path])
+
+ logger.debug(f"Executing command with core optimizer: {' '.join(cmd)}")
+
+ # Run the benchmark with environment variables
+ result = subprocess.run(
+ cmd, capture_output=True, text=True, timeout=self.config.timeout_seconds, cwd=os.getcwd(), env=env
+ )
+
+ # No cleanup needed since we're using environment variables
+
+ if result.returncode != 0:
+ logger.error(f"Benchmark execution failed: {result.stderr}")
+ return None
+
+ # Log the subprocess output for debugging
+ logger.info("=== Subprocess STDOUT ===")
+ logger.info(result.stdout)
+ logger.info("=== Subprocess STDERR ===")
+ logger.info(result.stderr)
+ logger.info("=== End Subprocess Output ===")
+
+        # Parse output for metrics (data-juicer logs to stderr, not stdout, so pass both
+        # streams in case file-based metrics are unavailable and text parsing is needed)
+        return self._parse_benchmark_output(result.stdout + result.stderr)
+
+ except Exception as e:
+ logger.error(f"Error executing benchmark with core optimizer: {e}")
+ return None
+
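The two environment variables set above form the handshake with the child process. The consumer side is outside this file; a hypothetical executor-side reader, using only the variable names defined above, could look like this:

```python
# Hypothetical consumer-side sketch of the DJ_* environment-variable handshake.
import os


def optimizer_settings_from_env():
    enabled = os.environ.get("DJ_ENABLE_CORE_OPTIMIZER", "").lower() == "true"
    strategies = [s for s in os.environ.get("DJ_OPTIMIZER_STRATEGIES", "").split(",") if s]
    return enabled, strategies


# With DJ_ENABLE_CORE_OPTIMIZER=true and DJ_OPTIMIZER_STRATEGIES=op_reorder,filter_fusion
# this returns (True, ["op_reorder", "filter_fusion"]).
```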
+ def _cleanup_output_directory(self):
+ """Clean up the output directory before running benchmark to prevent multiple outputs."""
+        try:
+            output_path = os.path.join(self.config.output_dir, "output.jsonl")
+
+ # Remove existing output directory if it exists
+ if os.path.exists(output_path):
+ logger.info(f"Cleaning up existing output directory: {output_path}")
+ if os.path.isdir(output_path):
+ shutil.rmtree(output_path)
+ else:
+ os.remove(output_path)
+ logger.info("Output directory cleaned up successfully")
+ else:
+ logger.info("No existing output directory to clean up")
+
+ except Exception as e:
+ logger.warning(f"Failed to clean up output directory: {e}")
+ # Don't fail the benchmark if cleanup fails
+
+ def _get_config_hash(self) -> str:
+ """Generate hash of current configuration."""
+ config_str = f"{self.config.dataset_path}_{self.config.config_path}_{self.config.strategy_name}"
+ if self.config.strategy_config:
+ config_str += str(sorted(self.config.strategy_config.items()))
+
+ return hashlib.md5(config_str.encode()).hexdigest()[:8]
+
+ def _parse_benchmark_output(self, output: str) -> Dict[str, Any]:
+ """Parse benchmark output to extract metrics."""
+
+ # Try file-based metrics first (more reliable)
+ file_metrics = self._get_file_based_metrics()
+ if file_metrics:
+ logger.info(f"=== File-Based Metrics ===")
+ logger.info(f"Initial samples: {file_metrics.get('samples_processed', 'N/A')}")
+ logger.info(f"Final samples: {file_metrics.get('samples_retained', 'N/A')}")
+ logger.info(f"Retention rate: {file_metrics.get('retention_rate', 'N/A')}")
+ logger.info(f"=== End File-Based Metrics ===")
+ return file_metrics
+
+ # Fallback to text parsing (less reliable)
+ logger.info("File-based metrics not available, falling back to text parsing...")
+ return self._parse_text_output(output)
+
+ def _get_file_based_metrics(self) -> Optional[Dict[str, Any]]:
+ """Get metrics by counting actual files (more reliable than text parsing)."""
+ try:
+ # Get initial dataset size
+ initial_samples = self._count_input_records()
+
+ # Get final dataset size from output files
+ final_samples = self._count_output_records()
+
+ if initial_samples is None or final_samples is None:
+ logger.warning("Could not determine initial or final sample counts from files")
+ return None
+
+ retention_rate = final_samples / initial_samples if initial_samples > 0 else 0
+
+ return {
+ "samples_processed": initial_samples,
+ "samples_retained": final_samples,
+ "retention_rate": retention_rate,
+ }
+
+ except Exception as e:
+ logger.error(f"Error getting file-based metrics: {e}")
+ return None
+
+ def _count_input_records(self) -> Optional[int]:
+ """Count records in the input dataset."""
+ try:
+ import subprocess
+
+ import yaml
+
+ # Get dataset path - either from config or from config file
+ dataset_path = self.config.dataset_path
+ if dataset_path is None:
+ # Try to extract dataset_path from config file
+ try:
+ with open(self.config.config_path, "r") as f:
+ config_data = yaml.safe_load(f)
+ dataset_path = config_data.get("dataset_path")
+ except Exception as e:
+ logger.warning(f"Could not read dataset_path from config file: {e}")
+ return None
+
+ if dataset_path is None:
+ logger.warning("No dataset path available for counting input records")
+ return None
+
+ result = subprocess.run(["wc", "-l", dataset_path], capture_output=True, text=True, timeout=30)
+ if result.returncode == 0:
+ return int(result.stdout.split()[0])
+ except Exception as e:
+ logger.error(f"Error counting input records: {e}")
+ return None
+
+ def _count_output_records(self) -> Optional[int]:
+ """Count records in the output files."""
+ try:
+ import os
+ import subprocess
+
+ output_dir = os.path.join(self.config.output_dir, "output.jsonl")
+ if not os.path.exists(output_dir):
+ logger.warning(f"Output directory not found: {output_dir}")
+ return None
+
+ # Count all JSON files in the output directory
+ result = subprocess.run(
+ ["find", output_dir, "-name", "*.json", "-exec", "wc", "-l", "{}", "+"],
+ capture_output=True,
+ text=True,
+ timeout=60,
+ )
+
+ if result.returncode == 0:
+ # Parse the output to get total count
+ lines = result.stdout.strip().split("\n")
+
+ # Check if we have a "total" line (from using + syntax)
+ for line in lines:
+ if "total" in line.lower():
+ parts = line.split()
+ if parts and parts[0].isdigit():
+ return int(parts[0])
+
+ # Fallback: sum individual counts (for \; syntax)
+ total = 0
+ for line in lines:
+ if line.strip() and "total" not in line.lower():
+ parts = line.split()
+ if parts and parts[0].isdigit():
+ total += int(parts[0])
+ return total
+
+ except Exception as e:
+ logger.error(f"Error counting output records: {e}")
+ return None
+
+ def _parse_text_output(self, output: str) -> Dict[str, Any]:
+ """Fallback text parsing method (less reliable)."""
+ metrics = {}
+ lines = output.split("\n")
+ initial_samples = None
+ final_samples = None
+
+ for line in lines:
+ # Look for initial sample count patterns
+ if "samples left after filtering empty text" in line:
+ try:
+ parts = line.split()
+ for part in parts:
+ if part.isdigit():
+ initial_samples = int(part)
+ break
+ except Exception as e:
+ logger.error(f"Error parsing initial samples: {e}")
+
+ # Look for final sample count patterns
+ if "Left" in line and "samples" in line and "Done in" in line:
+ try:
+ parts = line.split()
+ for i, part in enumerate(parts):
+ if part == "Left" and i + 1 < len(parts):
+ if parts[i + 1].isdigit():
+ final_samples = int(parts[i + 1])
+ break
+ except Exception as e:
+ logger.error(f"Error parsing final samples: {e}")
+
+ if initial_samples is not None:
+ metrics["samples_processed"] = initial_samples
+ if final_samples is not None:
+ metrics["samples_retained"] = final_samples
+
+ logger.info(f"=== Text Parsing Results ===")
+ logger.info(f"Initial samples found: {initial_samples}")
+ logger.info(f"Final samples found: {final_samples}")
+ logger.info(f"Parsed metrics: {metrics}")
+ logger.info(f"=== End Text Parsing ===")
+ return metrics
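To make the fallback parser's expectations concrete, here is an illustration of the kind of log lines `_parse_text_output` looks for (the exact data-juicer log wording is assumed, not taken from this diff):

```python
# Illustrative input for BenchmarkRunner._parse_text_output; the log wording is assumed.
sample_output = "\n".join(
    [
        "10000 samples left after filtering empty text",
        "OP [text_length_filter] Done in 1.23s. Left 8200 samples.",
    ]
)
# Feeding this through the parser would yield
# {"samples_processed": 10000, "samples_retained": 8200}.
```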
diff --git a/data_juicer/benchmark/core/metrics_collector.py b/data_juicer/benchmark/core/metrics_collector.py
new file mode 100644
index 0000000000..e6c505d129
--- /dev/null
+++ b/data_juicer/benchmark/core/metrics_collector.py
@@ -0,0 +1,181 @@
+#!/usr/bin/env python3
+"""
+Metrics collection system for performance benchmarking.
+"""
+
+import os
+import sys
+import threading
+import time
+from dataclasses import dataclass, field
+from typing import Any, Dict
+
+import psutil
+from loguru import logger
+
+
+@dataclass
+class BenchmarkMetrics:
+ """Comprehensive performance metrics for a benchmark run."""
+
+ # Timing metrics
+ total_wall_time: float
+ processing_time: float
+ io_time: float
+ overhead_time: float
+
+ # Throughput metrics
+ samples_per_second: float
+ bytes_per_second: float
+ operations_per_second: float
+
+ # Resource metrics
+ peak_memory_mb: float
+ average_cpu_percent: float
+ peak_cpu_percent: float
+
+ # Quality metrics
+ samples_processed: int
+ samples_retained: int
+ data_retention_rate: float
+
+ # Configuration
+ config_hash: str
+ strategy_name: str
+
+ # Additional metadata
+ metadata: Dict[str, Any] = field(default_factory=dict)
+
+ def speedup_vs_baseline(self, baseline: "BenchmarkMetrics") -> float:
+ """Calculate speedup vs baseline."""
+ if baseline.total_wall_time == 0:
+ return 0.0
+ return baseline.total_wall_time / self.total_wall_time
+
+ def memory_efficiency(self, baseline: "BenchmarkMetrics") -> float:
+ """Calculate memory efficiency vs baseline."""
+ if self.peak_memory_mb == 0:
+ return 0.0
+ return baseline.peak_memory_mb / self.peak_memory_mb
+
+ def throughput_improvement(self, baseline: "BenchmarkMetrics") -> float:
+ """Calculate throughput improvement vs baseline."""
+ if baseline.samples_per_second == 0:
+ return 0.0
+ return self.samples_per_second / baseline.samples_per_second
+
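+# For intuition on the BenchmarkMetrics comparison helpers above (hypothetical numbers):
+# a baseline run with total_wall_time=120 s and peak_memory_mb=4000 versus a test run at
+# 80 s and 2500 MB gives speedup_vs_baseline = 120 / 80 = 1.5 and
+# memory_efficiency = 4000 / 2500 = 1.6; a throughput rise from 100 to 150 samples/s
+# gives throughput_improvement = 1.5.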
+
+class MetricsCollector:
+ """Collects comprehensive performance metrics during benchmark runs."""
+
+ def __init__(self):
+ self.monitoring = False
+ self.monitor_thread = None
+ self.metrics_data = []
+ self.start_time = None
+ self.end_time = None
+
+ def start_monitoring(self):
+ """Start monitoring system resources."""
+ self.monitoring = True
+ self.metrics_data = []
+ self.start_time = time.time()
+
+ # Start monitoring thread
+ self.monitor_thread = threading.Thread(target=self._monitor_resources)
+ self.monitor_thread.daemon = True
+ self.monitor_thread.start()
+
+ logger.debug("Started metrics monitoring")
+
+ def stop_monitoring(self):
+ """Stop monitoring and return collected metrics."""
+ if not self.monitoring:
+ return None
+
+ self.monitoring = False
+ self.end_time = time.time()
+
+ if self.monitor_thread:
+ self.monitor_thread.join(timeout=1.0)
+
+ logger.debug("Stopped metrics monitoring")
+ return self._calculate_metrics()
+
+ def _monitor_resources(self):
+ """Monitor system resources in background thread."""
+ process = psutil.Process()
+
+ while self.monitoring:
+ try:
+ # Get current metrics
+ cpu_percent = process.cpu_percent()
+ memory_info = process.memory_info()
+ memory_mb = memory_info.rss / 1024 / 1024 # Convert to MB
+
+ self.metrics_data.append({"timestamp": time.time(), "cpu_percent": cpu_percent, "memory_mb": memory_mb})
+
+ time.sleep(0.1) # Sample every 100ms
+
+ except Exception as e:
+ logger.warning(f"Error monitoring resources: {e}")
+ break
+
+ def _calculate_metrics(self) -> BenchmarkMetrics:
+ """Calculate final metrics from collected data."""
+ if not self.metrics_data:
+ return BenchmarkMetrics(
+ total_wall_time=0,
+ processing_time=0,
+ io_time=0,
+ overhead_time=0,
+ samples_per_second=0,
+ bytes_per_second=0,
+ operations_per_second=0,
+ peak_memory_mb=0,
+ average_cpu_percent=0,
+ peak_cpu_percent=0,
+ samples_processed=0,
+ samples_retained=0,
+ data_retention_rate=0,
+ config_hash="",
+ strategy_name="",
+ )
+
+ # Calculate timing metrics
+ total_wall_time = self.end_time - self.start_time if self.end_time and self.start_time else 0
+
+ # Calculate resource metrics
+ memory_values = [d["memory_mb"] for d in self.metrics_data]
+ cpu_values = [d["cpu_percent"] for d in self.metrics_data]
+
+ peak_memory = max(memory_values) if memory_values else 0
+ avg_cpu = sum(cpu_values) / len(cpu_values) if cpu_values else 0
+ peak_cpu = max(cpu_values) if cpu_values else 0
+
+ return BenchmarkMetrics(
+ total_wall_time=total_wall_time,
+ processing_time=total_wall_time, # Simplified for now
+ io_time=0, # Would need more detailed tracking
+ overhead_time=0, # Would need more detailed tracking
+ samples_per_second=0, # Will be set by benchmark runner
+ bytes_per_second=0, # Will be set by benchmark runner
+ operations_per_second=0, # Will be set by benchmark runner
+ peak_memory_mb=peak_memory,
+ average_cpu_percent=avg_cpu,
+ peak_cpu_percent=peak_cpu,
+ samples_processed=0, # Will be set by benchmark runner
+ samples_retained=0, # Will be set by benchmark runner
+ data_retention_rate=0, # Will be set by benchmark runner
+ config_hash="", # Will be set by benchmark runner
+ strategy_name="", # Will be set by benchmark runner
+ )
+
+ def get_system_info(self) -> Dict[str, Any]:
+ """Get system information for context."""
+ return {
+ "cpu_count": psutil.cpu_count(),
+ "memory_total_gb": psutil.virtual_memory().total / (1024**3),
+ "disk_free_gb": psutil.disk_usage("/").free / (1024**3),
+ "python_version": os.sys.version,
+ "platform": os.name,
+ }
diff --git a/data_juicer/benchmark/core/report_generator.py b/data_juicer/benchmark/core/report_generator.py
new file mode 100644
index 0000000000..4874e3f528
--- /dev/null
+++ b/data_juicer/benchmark/core/report_generator.py
@@ -0,0 +1,336 @@
+#!/usr/bin/env python3
+"""
+Report generation for benchmark results.
+"""
+
+import json
+from datetime import datetime
+from pathlib import Path
+from typing import Dict, List
+
+import numpy as np
+from loguru import logger
+
+from .metrics_collector import BenchmarkMetrics
+from .result_analyzer import ComparisonResult
+
+
+class ReportGenerator:
+ """Generates comprehensive reports from benchmark results."""
+
+ def __init__(self, output_dir: str):
+ self.output_dir = Path(output_dir)
+ self.output_dir.mkdir(parents=True, exist_ok=True)
+
+ def _convert_numpy_types(self, obj):
+ """Convert numpy types to Python native types for JSON serialization."""
+ if isinstance(obj, np.bool_):
+ return bool(obj)
+ elif isinstance(obj, np.integer):
+ return int(obj)
+ elif isinstance(obj, np.floating):
+ return float(obj)
+ elif isinstance(obj, np.ndarray):
+ return obj.tolist()
+ elif isinstance(obj, dict):
+ return {key: self._convert_numpy_types(value) for key, value in obj.items()}
+ elif isinstance(obj, list):
+ return [self._convert_numpy_types(item) for item in obj]
+ else:
+ return obj
+
+ def generate_ab_test_report(
+ self,
+ results: Dict[str, List[BenchmarkMetrics]],
+ comparisons: Dict[str, ComparisonResult],
+ test_name: str = "A/B Test",
+ ) -> str:
+ """Generate a comprehensive A/B test report."""
+
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ report_file = self.output_dir / f"ab_test_report_{timestamp}.html"
+
+ # Generate HTML report
+ html_content = self._generate_html_report(results, comparisons, test_name)
+
+ with open(report_file, "w") as f:
+ f.write(html_content)
+
+ # Also generate JSON data
+ json_file = self.output_dir / f"ab_test_data_{timestamp}.json"
+ self._save_json_data(results, comparisons, json_file)
+
+ logger.info(f"Generated A/B test report: {report_file}")
+ return str(report_file)
+
+ def generate_workload_report(
+ self, workload_results: Dict[str, Dict[str, List[BenchmarkMetrics]]], test_name: str = "Workload Benchmark"
+ ) -> str:
+ """Generate a report for workload testing across different scenarios."""
+
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ report_file = self.output_dir / f"workload_report_{timestamp}.html"
+
+ html_content = self._generate_workload_html_report(workload_results, test_name)
+
+ with open(report_file, "w") as f:
+ f.write(html_content)
+
+ logger.info(f"Generated workload report: {report_file}")
+ return str(report_file)
+
+ def _generate_html_report(
+ self, results: Dict[str, List[BenchmarkMetrics]], comparisons: Dict[str, ComparisonResult], test_name: str
+ ) -> str:
+ """Generate HTML report content."""
+
+ html = f"""
+
+
+
+ {test_name} - Benchmark Report
+
+
+
+
+
+
+
Executive Summary
+ {self._generate_summary_html(comparisons)}
+
+
+
+ {self._generate_metrics_cards(results)}
+
+
+ Detailed Comparisons
+ {self._generate_comparisons_html(comparisons)}
+
+ Raw Data
+ {self._generate_raw_data_table(results)}
+
+
+
+ """
+
+ return html
+
+ def _generate_summary_html(self, comparisons: Dict[str, ComparisonResult]) -> str:
+ """Generate summary section HTML."""
+        if not comparisons:
+            return "<p>No comparisons available.</p>"
+
+        summary_html = "<ul>"
+        for strategy_name, comparison in comparisons.items():
+            status_class = (
+                "improvement" if comparison.is_improvement() else "regression" if comparison.is_regression() else ""
+            )
+            summary_html += f"""
+            <li class="{status_class}">
+                {strategy_name}: {comparison.summary}
+            </li>
+            """
+        summary_html += "</ul>"
+
+ return summary_html
+
+ def _generate_metrics_cards(self, results: Dict[str, List[BenchmarkMetrics]]) -> str:
+ """Generate metrics cards HTML."""
+ cards_html = ""
+
+ for strategy_name, metrics_list in results.items():
+ if not metrics_list:
+ continue
+
+ # Calculate aggregate metrics
+ avg_time = sum(m.total_wall_time for m in metrics_list) / len(metrics_list)
+ avg_throughput = sum(m.samples_per_second for m in metrics_list) / len(metrics_list)
+ max_memory = max(m.peak_memory_mb for m in metrics_list)
+ avg_retention = sum(m.data_retention_rate for m in metrics_list) / len(metrics_list)
+
+ cards_html += f"""
+
+
{avg_time:.2f}s
+
Avg Time - {strategy_name}
+
+
+
{avg_throughput:.1f}
+
Samples/sec - {strategy_name}
+
+
+
{max_memory:.0f}MB
+
Peak Memory - {strategy_name}
+
+
+
{avg_retention:.1%}
+
Retention Rate - {strategy_name}
+
+ """
+
+ return cards_html
+
+ def _generate_comparisons_html(self, comparisons: Dict[str, ComparisonResult]) -> str:
+ """Generate comparisons section HTML."""
+        if not comparisons:
+            return "<p>No comparisons available.</p>"
+
+        comparisons_html = ""
+        for strategy_name, comparison in comparisons.items():
+            status_class = (
+                "improvement" if comparison.is_improvement() else "regression" if comparison.is_regression() else ""
+            )
+
+            comparisons_html += f"""
+            <div class="{status_class}">
+                <h3>{strategy_name} vs {comparison.baseline_name}</h3>
+                <p>Speedup: {comparison.speedup:.2f}x</p>
+                <p>Throughput Improvement: {comparison.throughput_improvement:.2f}x</p>
+                <p>Memory Efficiency: {comparison.memory_efficiency:.2f}x</p>
+                <p>Statistical Significance: {comparison.is_significant} (p={comparison.p_value:.4f})</p>
+                <p>Summary: {comparison.summary}</p>
+            </div>
+            """
+
+ return comparisons_html
+
+ def _generate_raw_data_table(self, results: Dict[str, List[BenchmarkMetrics]]) -> str:
+ """Generate raw data table HTML."""
+ table_html = "| Strategy | Run | Time (s) | Throughput | Memory (MB) | Retention |
"
+
+ for strategy_name, metrics_list in results.items():
+ for i, metrics in enumerate(metrics_list):
+ table_html += f"""
+
+ | {strategy_name} |
+ {i+1} |
+ {metrics.total_wall_time:.2f} |
+ {metrics.samples_per_second:.1f} |
+ {metrics.peak_memory_mb:.0f} |
+ {metrics.data_retention_rate:.1%} |
+
+ """
+
+ table_html += "
"
+ return table_html
+
+ def _generate_workload_html_report(
+ self, workload_results: Dict[str, Dict[str, List[BenchmarkMetrics]]], test_name: str
+ ) -> str:
+ """Generate HTML report for workload testing."""
+ # Similar structure but organized by workload
+ html = f"""
+
+
+
+ {test_name} - Workload Report
+
+
+
+
+"""
+
+ for workload_name, workload_data in workload_results.items():
+ html += f"""
+
+
+
{workload_name}
+
+ {self._generate_metrics_cards(workload_data)}
+
+ """
+
+ html += ""
+ return html
+
+ def _save_json_data(
+ self, results: Dict[str, List[BenchmarkMetrics]], comparisons: Dict[str, ComparisonResult], json_file: Path
+ ):
+ """Save raw data as JSON."""
+ data = {"timestamp": datetime.now().isoformat(), "results": {}, "comparisons": {}}
+
+ # Convert results to serializable format
+ for strategy_name, metrics_list in results.items():
+ data["results"][strategy_name] = [
+ {
+ "total_wall_time": m.total_wall_time,
+ "samples_per_second": m.samples_per_second,
+ "peak_memory_mb": m.peak_memory_mb,
+ "average_cpu_percent": m.average_cpu_percent,
+ "samples_processed": m.samples_processed,
+ "samples_retained": m.samples_retained,
+ "data_retention_rate": m.data_retention_rate,
+ "strategy_name": m.strategy_name,
+ "config_hash": m.config_hash,
+ }
+ for m in metrics_list
+ ]
+
+ # Convert comparisons to serializable format
+ for strategy_name, comparison in comparisons.items():
+ data["comparisons"][strategy_name] = {
+ "baseline_name": comparison.baseline_name,
+ "test_name": comparison.test_name,
+ "speedup": comparison.speedup,
+ "throughput_improvement": comparison.throughput_improvement,
+ "memory_efficiency": comparison.memory_efficiency,
+ "is_significant": comparison.is_significant,
+ "confidence_level": comparison.confidence_level,
+ "p_value": comparison.p_value,
+ "summary": comparison.summary,
+ # Note: baseline_metrics and test_metrics are excluded as they contain non-serializable BenchmarkMetrics objects
+ }
+
+ try:
+ # Convert numpy types to Python native types
+ data = self._convert_numpy_types(data)
+ with open(json_file, "w") as f:
+ json.dump(data, f, indent=2)
+ logger.info(f"Successfully saved JSON data to {json_file}")
+ except Exception as e:
+ logger.error(f"Failed to save JSON data to {json_file}: {e}")
+ # Try to save a minimal version without problematic fields
+ try:
+ minimal_data = {
+ "timestamp": data["timestamp"],
+ "results": data["results"],
+ "comparisons": {
+ name: {
+ "baseline_name": comp.baseline_name,
+ "test_name": comp.test_name,
+ "speedup": comp.speedup,
+ "summary": comp.summary,
+ }
+ for name, comp in comparisons.items()
+ },
+ }
+ minimal_data = self._convert_numpy_types(minimal_data)
+ with open(json_file, "w") as f:
+ json.dump(minimal_data, f, indent=2)
+ logger.warning(f"Saved minimal JSON data to {json_file}")
+ except Exception as e2:
+ logger.error(f"Failed to save even minimal JSON data: {e2}")
+ raise
diff --git a/data_juicer/benchmark/core/result_analyzer.py b/data_juicer/benchmark/core/result_analyzer.py
new file mode 100644
index 0000000000..d9b0d47ea6
--- /dev/null
+++ b/data_juicer/benchmark/core/result_analyzer.py
@@ -0,0 +1,224 @@
+#!/usr/bin/env python3
+"""
+Result analysis and comparison tools for benchmark results.
+"""
+
+import statistics
+from dataclasses import dataclass, replace
+from typing import Any, Dict, List, Tuple
+
+from loguru import logger
+
+from .metrics_collector import BenchmarkMetrics
+
+
+@dataclass
+class ComparisonResult:
+ """Results of comparing two benchmark configurations."""
+
+ # Configuration info
+ baseline_name: str
+ test_name: str
+
+ # Performance comparison
+    speedup: float  # baseline_time / test_time (values > 1 mean the test run is faster)
+ throughput_improvement: float # test_throughput / baseline_throughput
+ memory_efficiency: float # baseline_memory / test_memory
+
+ # Statistical significance
+ is_significant: bool
+ confidence_level: float
+ p_value: float
+
+ # Raw metrics
+ baseline_metrics: BenchmarkMetrics
+ test_metrics: BenchmarkMetrics
+
+ # Summary
+ summary: str
+
+ def is_improvement(self, threshold: float = 1.05) -> bool:
+ """Check if the test shows significant improvement."""
+ return self.speedup > threshold and self.is_significant
+
+ def is_regression(self, threshold: float = 0.95) -> bool:
+ """Check if the test shows significant regression."""
+ return self.speedup < threshold and self.is_significant
+
+
+class ResultAnalyzer:
+ """Analyzes and compares benchmark results."""
+
+ def __init__(self, confidence_level: float = 0.95):
+ self.confidence_level = confidence_level
+
+ def analyze_single_run(self, metrics: BenchmarkMetrics) -> Dict[str, Any]:
+ """Analyze a single benchmark run."""
+ return {
+ "total_time": metrics.total_wall_time,
+ "throughput": metrics.samples_per_second,
+ "memory_peak": metrics.peak_memory_mb,
+ "cpu_avg": metrics.average_cpu_percent,
+ "retention_rate": metrics.data_retention_rate,
+ "efficiency_score": self._calculate_efficiency_score(metrics),
+ }
+
+ def compare_runs(
+ self,
+ baseline_metrics: List[BenchmarkMetrics],
+ test_metrics: List[BenchmarkMetrics],
+ baseline_name: str = "baseline",
+ test_name: str = "test",
+ ) -> ComparisonResult:
+ """Compare two sets of benchmark runs."""
+
+ if not baseline_metrics or not test_metrics:
+ raise ValueError("Both baseline and test metrics must be provided")
+
+ # Calculate aggregate metrics
+ baseline_agg = self._aggregate_metrics(baseline_metrics)
+ test_agg = self._aggregate_metrics(test_metrics)
+
+ # Calculate comparisons
+ speedup = baseline_agg.total_wall_time / test_agg.total_wall_time if test_agg.total_wall_time > 0 else 0
+ throughput_improvement = (
+ test_agg.samples_per_second / baseline_agg.samples_per_second if baseline_agg.samples_per_second > 0 else 0
+ )
+ memory_efficiency = baseline_agg.peak_memory_mb / test_agg.peak_memory_mb if test_agg.peak_memory_mb > 0 else 0
+
+ # Statistical significance test
+ is_significant, p_value = self._test_significance(
+ [m.total_wall_time for m in baseline_metrics], [m.total_wall_time for m in test_metrics]
+ )
+
+ # Generate summary
+ summary = self._generate_summary(
+ speedup, throughput_improvement, memory_efficiency, is_significant, baseline_name, test_name
+ )
+
+ return ComparisonResult(
+ baseline_name=baseline_name,
+ test_name=test_name,
+ speedup=speedup,
+ throughput_improvement=throughput_improvement,
+ memory_efficiency=memory_efficiency,
+ is_significant=is_significant,
+ confidence_level=self.confidence_level,
+ p_value=p_value,
+ baseline_metrics=baseline_agg,
+ test_metrics=test_agg,
+ summary=summary,
+ )
+
+ def analyze_ab_test(self, results: Dict[str, List[BenchmarkMetrics]]) -> Dict[str, ComparisonResult]:
+ """Analyze results from an A/B test with multiple strategies."""
+ if len(results) < 2:
+ raise ValueError("A/B test requires at least 2 strategies")
+
+ # Use first strategy as baseline
+ baseline_name = list(results.keys())[0]
+ baseline_metrics = results[baseline_name]
+
+ comparisons = {}
+
+ for strategy_name, strategy_metrics in results.items():
+ if strategy_name == baseline_name:
+ continue
+
+ comparison = self.compare_runs(baseline_metrics, strategy_metrics, baseline_name, strategy_name)
+ comparisons[strategy_name] = comparison
+
+ return comparisons
+
+ def _aggregate_metrics(self, metrics_list: List[BenchmarkMetrics]) -> BenchmarkMetrics:
+ """Aggregate multiple metrics into a single representative metric."""
+ if not metrics_list:
+ raise ValueError("Cannot aggregate empty metrics list")
+
+ # Calculate means for most metrics
+ total_wall_time = statistics.mean([m.total_wall_time for m in metrics_list])
+ samples_per_second = statistics.mean([m.samples_per_second for m in metrics_list])
+ peak_memory_mb = max([m.peak_memory_mb for m in metrics_list]) # Use max for peak
+ average_cpu_percent = statistics.mean([m.average_cpu_percent for m in metrics_list])
+ peak_cpu_percent = max([m.peak_cpu_percent for m in metrics_list])
+ samples_processed = sum([m.samples_processed for m in metrics_list]) // len(metrics_list)
+ samples_retained = sum([m.samples_retained for m in metrics_list]) // len(metrics_list)
+ data_retention_rate = statistics.mean([m.data_retention_rate for m in metrics_list])
+
+        # Copy the first metric as a template and overwrite the aggregated fields,
+        # so the caller's per-run metrics are not mutated in place
+        aggregated = replace(metrics_list[0])
+ aggregated.total_wall_time = total_wall_time
+ aggregated.samples_per_second = samples_per_second
+ aggregated.peak_memory_mb = peak_memory_mb
+ aggregated.average_cpu_percent = average_cpu_percent
+ aggregated.peak_cpu_percent = peak_cpu_percent
+ aggregated.samples_processed = samples_processed
+ aggregated.samples_retained = samples_retained
+ aggregated.data_retention_rate = data_retention_rate
+
+ return aggregated
+
+ def _test_significance(self, baseline_times: List[float], test_times: List[float]) -> Tuple[bool, float]:
+ """Test statistical significance between two sets of timing data."""
+ try:
+ # Simple t-test for now - could be enhanced with more sophisticated tests
+ from scipy import stats
+
+ # Perform Welch's t-test (unequal variances)
+ statistic, p_value = stats.ttest_ind(test_times, baseline_times, equal_var=False)
+
+ is_significant = p_value < (1 - self.confidence_level)
+ return is_significant, p_value
+
+ except ImportError:
+ # Fallback to simple comparison if scipy not available
+ baseline_mean = statistics.mean(baseline_times)
+ test_mean = statistics.mean(test_times)
+
+ # Simple threshold-based significance
+ difference = abs(test_mean - baseline_mean)
+ threshold = baseline_mean * 0.1 # 10% threshold
+
+ is_significant = difference > threshold
+ p_value = 0.5 if is_significant else 1.0
+
+ return is_significant, p_value
+ except Exception as e:
+ logger.warning(f"Error in significance test: {e}")
+ return False, 1.0
+
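A quick illustration of the scipy path above, with invented timings (runnable only if scipy is installed):

```python
# Illustration of the Welch's t-test used in _test_significance (timing values invented).
from scipy import stats

baseline_times = [10.2, 10.5, 10.1, 10.4]
test_times = [8.1, 8.3, 8.0, 8.4]
statistic, p_value = stats.ttest_ind(test_times, baseline_times, equal_var=False)
# With confidence_level=0.95, the difference counts as significant when p_value < 0.05;
# these samples are well separated, so the comparison would be reported as significant.
```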
+ def _calculate_efficiency_score(self, metrics: BenchmarkMetrics) -> float:
+ """Calculate an overall efficiency score."""
+ # Weighted combination of throughput and memory efficiency
+ throughput_score = min(metrics.samples_per_second / 1000, 1.0) # Normalize to 0-1
+ memory_score = max(0, 1.0 - metrics.peak_memory_mb / 10000) # Penalize high memory
+
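+        # Worked example (invented numbers): at 500 samples/s with a 2,000 MB peak,
+        # throughput_score = min(500 / 1000, 1.0) = 0.5 and memory_score = 1.0 - 2000 / 10000 = 0.8,
+        # so the score is 0.5 * 0.7 + 0.8 * 0.3 = 0.59.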
+ return throughput_score * 0.7 + memory_score * 0.3
+
+ def _generate_summary(
+ self,
+ speedup: float,
+ throughput_improvement: float,
+ memory_efficiency: float,
+ is_significant: bool,
+ baseline_name: str,
+ test_name: str,
+ ) -> str:
+ """Generate a human-readable summary of the comparison."""
+ if speedup > 1.05:
+ direction = "faster"
+ improvement = f"{speedup:.2f}x"
+ elif speedup < 0.95:
+ direction = "slower"
+ improvement = f"{1/speedup:.2f}x"
+ else:
+ direction = "similar"
+ improvement = "~1x"
+
+ significance = "statistically significant" if is_significant else "not statistically significant"
+
+ return (
+ f"{test_name} is {improvement} {direction} than {baseline_name} "
+ f"({significance}). Throughput: {throughput_improvement:.2f}x, "
+ f"Memory efficiency: {memory_efficiency:.2f}x"
+ )
diff --git a/data_juicer/benchmark/example_usage.py b/data_juicer/benchmark/example_usage.py
new file mode 100644
index 0000000000..bddd20fb9b
--- /dev/null
+++ b/data_juicer/benchmark/example_usage.py
@@ -0,0 +1,164 @@
+#!/usr/bin/env python3
+"""
+Example usage of the benchmark framework.
+"""
+
+import os
+
+from data_juicer.benchmark import (
+ STRATEGY_LIBRARY,
+ WORKLOAD_SUITE,
+ ABTestConfig,
+ BenchmarkConfig,
+ BenchmarkRunner,
+ StrategyABTest,
+)
+
+
+def example_single_benchmark():
+ """Example: Run a single benchmark."""
+ print("=== Single Benchmark Example ===")
+
+ # Create benchmark configuration
+ config = BenchmarkConfig(
+ dataset_path="demos/data/text_data.jsonl",
+ config_path="configs/demo/process.yaml",
+ output_dir="benchmark_results/single",
+ iterations=3,
+ strategy_name="op_fusion_greedy",
+ strategy_config={"op_fusion": True, "fusion_strategy": "greedy"},
+ )
+
+ # Run benchmark
+ runner = BenchmarkRunner(config)
+ results = runner.run_benchmark_suite()
+
+ # Print results
+ for i, metrics in enumerate(results):
+ print(f"Run {i+1}: {metrics.total_wall_time:.2f}s, " f"{metrics.samples_per_second:.1f} samples/sec")
+
+
+def example_ab_test():
+ """Example: Run A/B test between strategies."""
+ print("\n=== A/B Test Example ===")
+
+ # Get workload
+ workload = WORKLOAD_SUITE.get_workload("text_simple")
+ if not workload:
+ print("Workload not found!")
+ return
+
+ # Create strategy configurations
+ baseline = STRATEGY_LIBRARY.create_strategy_config("baseline")
+ test_strategy = STRATEGY_LIBRARY.create_strategy_config("op_fusion_greedy")
+
+ # Create A/B test configuration
+ ab_config = ABTestConfig(
+ name="fusion_vs_baseline",
+ baseline_strategy=baseline,
+ test_strategies=[test_strategy],
+ workload=workload,
+ iterations=3,
+ output_dir="benchmark_results/ab_test",
+ )
+
+ # Run A/B test
+ ab_test = StrategyABTest(ab_config)
+ results = ab_test.run_ab_test()
+
+ # Print results
+ for strategy_name, comparison in results.items():
+ print(f"{strategy_name}: {comparison.speedup:.2f}x speedup")
+ print(f" Summary: {comparison.summary}")
+
+
+def example_workload_suite():
+ """Example: Run tests across multiple workloads."""
+ print("\n=== Workload Suite Example ===")
+
+ # Get multiple workloads
+ workloads = [WORKLOAD_SUITE.get_workload("text_simple"), WORKLOAD_SUITE.get_workload("image_simple")]
+ workloads = [w for w in workloads if w] # Filter out None values
+
+ if not workloads:
+ print("No workloads found!")
+ return
+
+ # Create strategies
+ strategies = [
+ STRATEGY_LIBRARY.create_strategy_config("baseline"),
+ STRATEGY_LIBRARY.create_strategy_config("op_fusion_greedy"),
+ STRATEGY_LIBRARY.create_strategy_config("adaptive_batch_size"),
+ ]
+
+ # Create A/B test
+ ab_config = ABTestConfig(
+ name="workload_suite_test",
+ baseline_strategy=strategies[0],
+ test_strategies=strategies[1:],
+ workload=workloads[0], # Will be overridden
+ iterations=2,
+ output_dir="benchmark_results/workload_suite",
+ )
+
+ ab_test = StrategyABTest(ab_config)
+ results = ab_test.run_workload_suite_ab_test(strategies[1:], workloads)
+
+ # Print results
+ for workload_name, workload_results in results.items():
+ print(f"\n{workload_name}:")
+ for strategy_name, comparison in workload_results.items():
+ print(f" {strategy_name}: {comparison.speedup:.2f}x speedup")
+
+
+def example_strategy_comparison():
+ """Example: Compare multiple strategies."""
+ print("\n=== Strategy Comparison Example ===")
+
+ # Get workload
+ workload = WORKLOAD_SUITE.get_workload("text_simple")
+ if not workload:
+ print("Workload not found!")
+ return
+
+ # Create strategy comparison
+ strategy_names = ["baseline", "op_fusion_greedy", "adaptive_batch_size"]
+ ab_test = StrategyABTest.create_strategy_comparison(strategy_names, workload)
+
+ # Run comparison
+ results = ab_test.run_ab_test()
+
+ # Print results
+ print("Strategy Comparison Results:")
+ for strategy_name, comparison in results.items():
+ print(f" {strategy_name}: {comparison.speedup:.2f}x speedup")
+ print(f" {comparison.summary}")
+
+
+def main():
+ """Run all examples."""
+ print("Data-Juicer Benchmark Framework Examples")
+ print("=" * 50)
+
+ # Ensure output directories exist
+ os.makedirs("benchmark_results", exist_ok=True)
+
+ try:
+ # Run examples
+ example_single_benchmark()
+ example_ab_test()
+ example_workload_suite()
+ example_strategy_comparison()
+
+ print("\n=== All Examples Completed ===")
+ print("Check the 'benchmark_results' directory for detailed reports.")
+
+ except Exception as e:
+ print(f"Error running examples: {e}")
+ import traceback
+
+ traceback.print_exc()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/data_juicer/benchmark/single_benchmark_example.py b/data_juicer/benchmark/single_benchmark_example.py
new file mode 100644
index 0000000000..c85213df04
--- /dev/null
+++ b/data_juicer/benchmark/single_benchmark_example.py
@@ -0,0 +1,111 @@
+#!/usr/bin/env python3
+"""
+Example: Single benchmark with custom config and dataset.
+"""
+
+from data_juicer.benchmark import BenchmarkConfig, BenchmarkRunner
+
+
+def run_single_benchmark():
+ """Run a single benchmark with your own config and dataset."""
+
+ # Create benchmark configuration
+ config = BenchmarkConfig(
+ dataset_path="your_dataset.jsonl", # Your dataset path
+ config_path="your_config.yaml", # Your config path
+ output_dir="benchmark_results/", # Output directory
+ iterations=3, # Number of runs
+ warmup_runs=1, # Warmup runs (not counted)
+ timeout_seconds=3600, # Timeout
+ strategy_name="op_fusion_greedy", # Strategy to test
+ strategy_config={"op_fusion": True, "fusion_strategy": "greedy"}, # Strategy-specific config
+ )
+
+ # Run benchmark
+ runner = BenchmarkRunner(config)
+ results = runner.run_benchmark_suite()
+
+ # Analyze results
+ print("=== Single Benchmark Results ===")
+ for i, metrics in enumerate(results):
+ print(f"\nRun {i+1}:")
+ print(f" Time: {metrics.total_wall_time:.2f}s")
+ print(f" Throughput: {metrics.samples_per_second:.1f} samples/sec")
+ print(f" Memory: {metrics.peak_memory_mb:.0f} MB")
+ print(f" CPU: {metrics.average_cpu_percent:.1f}%")
+ print(f" Retention: {metrics.data_retention_rate:.1%}")
+
+ # Calculate statistics
+ times = [m.total_wall_time for m in results]
+ throughputs = [m.samples_per_second for m in results]
+
+ print(f"\n=== Summary Statistics ===")
+ print(f"Average time: {sum(times)/len(times):.2f}s")
+ print(f"Average throughput: {sum(throughputs)/len(throughputs):.1f} samples/sec")
+ print(f"Time std dev: {((sum((t - sum(times)/len(times))**2 for t in times) / len(times))**0.5):.2f}s")
+
+ return results
+
+
+def run_multiple_strategies_same_dataset():
+ """Run multiple strategies on the same dataset for comparison."""
+
+ strategies = [
+ ("baseline", {}),
+ ("op_fusion_greedy", {"op_fusion": True, "fusion_strategy": "greedy"}),
+ ("adaptive_batch_size", {"adaptive_batch_size": True}),
+ ("memory_efficient", {"memory_efficient": True, "streaming": True}),
+ ]
+
+ all_results = {}
+
+ for strategy_name, strategy_config in strategies:
+ print(f"\n=== Running {strategy_name} ===")
+
+ config = BenchmarkConfig(
+ dataset_path="your_dataset.jsonl",
+ config_path="your_config.yaml",
+ output_dir=f"benchmark_results/{strategy_name}/",
+ iterations=3,
+ strategy_name=strategy_name,
+ strategy_config=strategy_config,
+ )
+
+ runner = BenchmarkRunner(config)
+ results = runner.run_benchmark_suite()
+ all_results[strategy_name] = results
+
+ # Print quick summary
+ avg_time = sum(m.total_wall_time for m in results) / len(results)
+ avg_throughput = sum(m.samples_per_second for m in results) / len(results)
+ print(f" Average time: {avg_time:.2f}s")
+ print(f" Average throughput: {avg_throughput:.1f} samples/sec")
+
+ # Compare results
+ print(f"\n=== Strategy Comparison ===")
+ baseline_time = sum(m.total_wall_time for m in all_results["baseline"]) / len(all_results["baseline"])
+
+ for strategy_name, results in all_results.items():
+ if strategy_name == "baseline":
+ continue
+
+ avg_time = sum(m.total_wall_time for m in results) / len(results)
+ speedup = baseline_time / avg_time
+ print(f"{strategy_name}: {speedup:.2f}x speedup")
+
+ return all_results
+
+
+if __name__ == "__main__":
+ print("Single Benchmark Examples")
+ print("=" * 50)
+
+ # Example 1: Single strategy
+ print("\n1. Single Strategy Benchmark:")
+ # results = run_single_benchmark()
+
+ # Example 2: Multiple strategies on same dataset
+ print("\n2. Multiple Strategies Comparison:")
+ # all_results = run_multiple_strategies_same_dataset()
+
+ print("\nNote: Update the dataset and config paths to your actual files!")
diff --git a/data_juicer/benchmark/strategies/__init__.py b/data_juicer/benchmark/strategies/__init__.py
new file mode 100644
index 0000000000..15bfb6b986
--- /dev/null
+++ b/data_juicer/benchmark/strategies/__init__.py
@@ -0,0 +1,14 @@
+"""Strategy definitions and A/B testing framework."""
+
+from .ab_test import ABTestConfig, StrategyABTest
+from .config_strategies import BaselineStrategy, CoreOptimizerStrategy
+from .strategy_library import STRATEGY_LIBRARY, OptimizationStrategy
+
+__all__ = [
+ "OptimizationStrategy",
+ "STRATEGY_LIBRARY",
+ "StrategyABTest",
+ "ABTestConfig",
+ "BaselineStrategy",
+ "CoreOptimizerStrategy",
+]
diff --git a/data_juicer/benchmark/strategies/ab_test.py b/data_juicer/benchmark/strategies/ab_test.py
new file mode 100644
index 0000000000..b0ae7487fa
--- /dev/null
+++ b/data_juicer/benchmark/strategies/ab_test.py
@@ -0,0 +1,202 @@
+#!/usr/bin/env python3
+"""
+A/B testing framework for optimization strategies.
+"""
+
+import os
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, List
+
+from loguru import logger
+
+from ..core.benchmark_runner import BenchmarkConfig, BenchmarkRunner
+from ..core.metrics_collector import BenchmarkMetrics
+from ..core.report_generator import ReportGenerator
+from ..core.result_analyzer import ComparisonResult, ResultAnalyzer
+from ..workloads.workload_suite import WorkloadDefinition
+from .strategy_library import STRATEGY_LIBRARY, StrategyConfig
+
+
+@dataclass
+class ABTestConfig:
+ """Configuration for an A/B test."""
+
+ name: str
+ baseline_strategy: StrategyConfig
+ test_strategies: List[StrategyConfig]
+ workload: WorkloadDefinition
+ iterations: int = 3
+ warmup_runs: int = 1
+ output_dir: str = "benchmark_results"
+ timeout_seconds: int = 3600
+
+
+class StrategyABTest:
+ """A/B testing framework for optimization strategies."""
+
+ def __init__(self, config: ABTestConfig):
+ self.config = config
+ self.analyzer = ResultAnalyzer()
+ self.report_generator = ReportGenerator(config.output_dir)
+
+ # Ensure output directory exists
+ Path(config.output_dir).mkdir(parents=True, exist_ok=True)
+
+ def run_ab_test(self) -> Dict[str, ComparisonResult]:
+ """Run the complete A/B test."""
+ logger.info(f"Starting A/B test: {self.config.name}")
+
+ # Run baseline
+ logger.info("Running baseline strategy...")
+ baseline_results = self._run_strategy(self.config.baseline_strategy)
+
+ # Run test strategies
+ all_results = {self.config.baseline_strategy.name: baseline_results}
+ comparisons = {}
+
+ for test_strategy in self.config.test_strategies:
+ logger.info(f"Running test strategy: {test_strategy.name}")
+ test_results = self._run_strategy(test_strategy)
+ all_results[test_strategy.name] = test_results
+
+ # Compare with baseline
+ comparison = self.analyzer.compare_runs(
+ baseline_results, test_results, self.config.baseline_strategy.name, test_strategy.name
+ )
+ comparisons[test_strategy.name] = comparison
+
+ # Generate report
+ report_path = self.report_generator.generate_ab_test_report(all_results, comparisons, self.config.name)
+
+ logger.info(f"A/B test completed. Report: {report_path}")
+ return comparisons
+
+ def run_workload_suite_ab_test(
+ self, strategies: List[StrategyConfig], workloads: List[WorkloadDefinition]
+ ) -> Dict[str, Dict[str, ComparisonResult]]:
+ """Run A/B test across multiple workloads."""
+ logger.info(f"Running workload suite A/B test with {len(workloads)} workloads")
+
+ all_results = {}
+
+ for workload in workloads:
+ logger.info(f"Testing workload: {workload.name}")
+
+ # Create workload-specific A/B test
+ workload_config = ABTestConfig(
+ name=f"{self.config.name}_{workload.name}",
+ baseline_strategy=self.config.baseline_strategy,
+ test_strategies=self.config.test_strategies,
+ workload=workload,
+ iterations=self.config.iterations,
+ warmup_runs=self.config.warmup_runs,
+ output_dir=os.path.join(self.config.output_dir, workload.name),
+ timeout_seconds=self.config.timeout_seconds,
+ )
+
+ # Run A/B test for this workload
+ workload_ab_test = StrategyABTest(workload_config)
+ workload_results = workload_ab_test.run_ab_test()
+
+ all_results[workload.name] = workload_results
+
+ # Generate comprehensive report
+ report_path = self.report_generator.generate_workload_report(all_results, f"{self.config.name}_workload_suite")
+
+ logger.info(f"Workload suite A/B test completed. Report: {report_path}")
+ return all_results
+
+ def _run_strategy(self, strategy: StrategyConfig) -> List[BenchmarkMetrics]:
+ """Run a single strategy and return metrics."""
+
+ # Create benchmark configuration
+ benchmark_config = BenchmarkConfig(
+ dataset_path=self.config.workload.dataset_path,
+ config_path=self.config.workload.config_path,
+ output_dir=os.path.join(self.config.output_dir, strategy.name),
+ iterations=self.config.iterations,
+ warmup_runs=self.config.warmup_runs,
+ timeout_seconds=self.config.timeout_seconds,
+ strategy_name=strategy.name,
+ strategy_config=self._strategy_to_config_dict(strategy),
+ )
+
+ # Run benchmark
+ runner = BenchmarkRunner(benchmark_config)
+ return runner.run_benchmark_suite()
+
+ def _strategy_to_config_dict(self, strategy: StrategyConfig) -> Dict[str, Any]:
+ """Convert strategy to configuration dictionary."""
+        config_dict = {}
+
+        # Prefer the strategy object registered in the library (this covers the
+        # core optimizer strategies such as mapper_fusion / filter_fusion /
+        # op_reorder, matching the single-benchmark CLI path).
+        strategy_obj = STRATEGY_LIBRARY.get_strategy(strategy.name)
+        if strategy_obj is not None:
+            config_dict = strategy_obj.apply_to_config({})
+
+ if strategy.name == "op_fusion_greedy":
+ config_dict["op_fusion"] = True
+ config_dict["fusion_strategy"] = "greedy"
+ elif strategy.name == "op_fusion_probe":
+ config_dict["op_fusion"] = True
+ config_dict["fusion_strategy"] = "probe"
+ elif strategy.name == "adaptive_batch_size":
+ config_dict["adaptive_batch_size"] = True
+ elif strategy.name == "large_batch_size":
+ config_dict["batch_size"] = 1000 # Large batch size
+ elif strategy.name == "memory_efficient":
+ config_dict["memory_efficient"] = True
+ elif strategy.name == "streaming_processing":
+ config_dict["streaming"] = True
+ elif strategy.name == "max_parallelism":
+ config_dict["num_processes"] = -1 # Use all available cores
+ elif strategy.name == "ray_optimized":
+ config_dict["executor"] = "ray"
+ config_dict["ray_config"] = {"num_cpus": -1}
+ elif strategy.name == "aggressive_caching":
+ config_dict["cache_intermediate"] = True
+ elif strategy.name == "fast_algorithms":
+ config_dict["use_fast_algorithms"] = True
+ elif strategy.name == "vectorized_ops":
+ config_dict["vectorized_operations"] = True
+
+ # Add strategy-specific parameters
+ config_dict.update(strategy.parameters)
+
+ return config_dict
+
+ def create_strategy_comparison(self, strategy_names: List[str], workload: WorkloadDefinition) -> "StrategyABTest":
+ """Create an A/B test comparing multiple strategies."""
+
+ if len(strategy_names) < 2:
+ raise ValueError("Need at least 2 strategies for comparison")
+
+ # Use first strategy as baseline
+ baseline = STRATEGY_LIBRARY.create_strategy_config(strategy_names[0])
+ test_strategies = [STRATEGY_LIBRARY.create_strategy_config(name) for name in strategy_names[1:]]
+
+ config = ABTestConfig(
+ name=f"comparison_{'_vs_'.join(strategy_names)}",
+ baseline_strategy=baseline,
+ test_strategies=test_strategies,
+ workload=workload,
+ )
+
+ return StrategyABTest(config)
+
+ def get_recommended_strategies(self, workload: WorkloadDefinition) -> List[str]:
+ """Get recommended strategies for a specific workload."""
+ recommendations = []
+
+ # Modality-specific recommendations
+ if workload.modality == "text":
+ recommendations.extend(["op_fusion_greedy", "adaptive_batch_size", "vectorized_ops"])
+ elif workload.modality == "image":
+ recommendations.extend(["op_fusion_probe", "memory_efficient", "ray_optimized"])
+ elif workload.modality == "video":
+ recommendations.extend(["streaming_processing", "ray_optimized", "aggressive_caching"])
+ elif workload.modality == "audio":
+ recommendations.extend(["op_fusion_greedy", "adaptive_batch_size"])
+
+ # Complexity-specific recommendations
+ if workload.complexity == "complex":
+ recommendations.extend(["memory_efficient", "ray_optimized"])
+
+ # Remove duplicates while preserving order
+ return list(dict.fromkeys(recommendations))
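+
+
+if __name__ == "__main__":
+    # Minimal usage sketch. The dataset/config paths are placeholders and the
+    # final call is left commented out, in the same spirit as the
+    # single-benchmark examples: point them at real files before running.
+    example_workload = WorkloadDefinition(
+        name="example_text",
+        description="Placeholder workload for illustration",
+        dataset_path="your_dataset.jsonl",
+        config_path="your_config.yaml",
+        expected_samples=1000,
+        modality="text",
+        complexity="simple",
+        estimated_duration_minutes=5,
+        resource_requirements={"memory_gb": 2, "cpu_cores": 2},
+    )
+    example_config = ABTestConfig(
+        name="example_ab_test",
+        baseline_strategy=STRATEGY_LIBRARY.create_strategy_config("baseline"),
+        test_strategies=[STRATEGY_LIBRARY.create_strategy_config("mapper_fusion")],
+        workload=example_workload,
+    )
+    # comparisons = StrategyABTest(example_config).run_ab_test()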
diff --git a/data_juicer/benchmark/strategies/config_strategies.py b/data_juicer/benchmark/strategies/config_strategies.py
new file mode 100644
index 0000000000..d041dc8bb4
--- /dev/null
+++ b/data_juicer/benchmark/strategies/config_strategies.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python3
+"""
+Concrete strategy implementations for the benchmark framework.
+"""
+
+from typing import Any, Dict, List
+
+from .strategy_library import OptimizationStrategy, StrategyType
+
+
+class CoreOptimizerStrategy(OptimizationStrategy):
+ """Strategy that configures the core optimizer to enable/disable specific strategies."""
+
+ def __init__(self, name: str, description: str, enabled_strategies: List[str]):
+ super().__init__(name, description)
+ self.strategy_type = StrategyType.FUSION # Core optimizer is primarily fusion-based
+ self.enabled_strategies = enabled_strategies
+
+ def apply_to_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
+ """Apply core optimizer configuration."""
+ # Enable core optimizer and specify which strategies to use
+ config = config.copy()
+ config["_benchmark_optimizer_enabled"] = True
+ config["_benchmark_optimizer_strategies"] = self.enabled_strategies
+ return config
+
+ def get_expected_impact(self) -> Dict[str, str]:
+ """Get expected impact description."""
+ return {
+ "performance": "Improved performance through core optimizer strategies",
+ "memory": "Optimized memory usage through operation fusion",
+ "complexity": "Moderate configuration complexity",
+ }
+
+
+class BaselineStrategy(OptimizationStrategy):
+ """Baseline strategy with no optimizations."""
+
+ def __init__(self, name: str = "baseline"):
+ super().__init__(name, "Baseline configuration with no optimizations")
+ self.strategy_type = StrategyType.ALGORITHM
+
+ def apply_to_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
+ """Apply baseline configuration (no changes)."""
+ # Return config unchanged for baseline
+ return config.copy()
+
+ def get_expected_impact(self) -> Dict[str, str]:
+ """Get expected impact description."""
+ return {
+ "performance": "Baseline performance",
+ "memory": "Standard memory usage",
+ "complexity": "Minimal configuration complexity",
+ }
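+
+
+if __name__ == "__main__":
+    # Illustrative sketch of how the strategies rewrite a (toy) config dict.
+    base_cfg = {"np": 4}
+    print(BaselineStrategy().apply_to_config(base_cfg))
+    # -> {'np': 4}
+    fusion = CoreOptimizerStrategy("mapper_fusion", "Enable mapper fusion", ["mapper_fusion"])
+    print(fusion.apply_to_config(base_cfg))
+    # -> {'np': 4, '_benchmark_optimizer_enabled': True,
+    #    '_benchmark_optimizer_strategies': ['mapper_fusion']}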
diff --git a/data_juicer/benchmark/strategies/strategy_library.py b/data_juicer/benchmark/strategies/strategy_library.py
new file mode 100644
index 0000000000..54b0cb17d6
--- /dev/null
+++ b/data_juicer/benchmark/strategies/strategy_library.py
@@ -0,0 +1,173 @@
+#!/usr/bin/env python3
+"""
+Library of optimization strategies for A/B testing.
+"""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Dict, List, Optional
+
+
+class StrategyType(Enum):
+ """Types of optimization strategies."""
+
+ FUSION = "fusion"
+ BATCHING = "batching"
+ MEMORY = "memory"
+ PARALLEL = "parallel"
+ CACHING = "caching"
+ ALGORITHM = "algorithm"
+
+
+@dataclass
+class StrategyConfig:
+ """Configuration for an optimization strategy."""
+
+ name: str
+ enabled: bool
+ parameters: Dict[str, Any]
+ description: str = ""
+ strategy_type: StrategyType = StrategyType.ALGORITHM
+
+
+class OptimizationStrategy(ABC):
+ """Base class for optimization strategies."""
+
+ def __init__(self, name: str, description: str = ""):
+ self.name = name
+ self.description = description
+ self.enabled = False
+ self.parameters = {}
+
+ @abstractmethod
+ def apply_to_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
+ """Apply this strategy to a configuration."""
+ pass
+
+ @abstractmethod
+ def get_expected_impact(self) -> Dict[str, str]:
+ """Get expected impact description."""
+ pass
+
+ def validate_parameters(self, parameters: Dict[str, Any]) -> List[str]:
+ """Validate strategy parameters."""
+ return []
+
+ def get_parameter_schema(self) -> Dict[str, Any]:
+ """Get parameter schema for this strategy."""
+ return {}
+
+
+class StrategyLibrary:
+ """Library of available optimization strategies."""
+
+ def __init__(self):
+ self.strategies = {}
+ self._initialize_strategies()
+
+ def _initialize_strategies(self):
+ """Initialize all available strategies."""
+ from .config_strategies import BaselineStrategy, CoreOptimizerStrategy
+
+ # Baseline Strategy (no optimizations)
+ self.strategies["baseline"] = BaselineStrategy()
+
+ # Core Optimizer Strategies - these configure the actual optimizer
+ self.strategies["mapper_fusion"] = CoreOptimizerStrategy(
+ "mapper_fusion", "Enable mapper operation fusion", ["mapper_fusion"]
+ )
+ self.strategies["filter_fusion"] = CoreOptimizerStrategy(
+ "filter_fusion", "Enable filter operation fusion", ["filter_fusion"]
+ )
+ self.strategies["full_optimization"] = CoreOptimizerStrategy(
+ "full_optimization", "Enable all core optimizations", ["mapper_fusion", "filter_fusion"]
+ )
+
+ # Additional configuration-based strategies (not core optimizer)
+ class SimpleStrategy(OptimizationStrategy):
+ def __init__(self, name, description, strategy_type, apply_func, impact_dict):
+ super().__init__(name, description)
+ self.strategy_type = strategy_type
+ self._apply_func = apply_func
+ self._impact_dict = impact_dict
+
+ def apply_to_config(self, config):
+ return self._apply_func(config)
+
+ def get_expected_impact(self):
+ return self._impact_dict
+
+ # Configuration-only strategies (not core optimizer)
+ self.strategies["large_batch_size"] = SimpleStrategy(
+ "large_batch_size",
+ "Use large batch sizes for better throughput",
+ StrategyType.BATCHING,
+ lambda config: {**config, "batch_size": 1000},
+ {
+ "performance": "Improved throughput with large batches",
+ "memory": "Higher memory usage",
+ "complexity": "Minimal configuration complexity",
+ },
+ )
+
+ self.strategies["streaming_processing"] = SimpleStrategy(
+ "streaming_processing",
+ "Enable streaming processing to reduce memory usage",
+ StrategyType.MEMORY,
+ lambda config: {**config, "streaming": True},
+ {
+ "performance": "May reduce throughput for memory savings",
+ "memory": "Significantly reduced memory usage",
+ "complexity": "Minimal configuration complexity",
+ },
+ )
+
+ self.strategies["ray_optimized"] = SimpleStrategy(
+ "ray_optimized",
+ "Optimize for Ray distributed processing",
+ StrategyType.PARALLEL,
+ lambda config: {**config, "executor": "ray", "ray_config": {"num_cpus": -1}},
+ {
+ "performance": "Improved throughput through distributed processing",
+ "memory": "Distributed memory usage across nodes",
+ "complexity": "Moderate configuration complexity",
+ },
+ )
+
+ # Operation Reordering Strategy (from core optimizer)
+ self.strategies["op_reorder"] = CoreOptimizerStrategy(
+ "op_reorder", "Enable operation reordering for optimal execution order", ["op_reorder"]
+ )
+
+ def get_strategy(self, name: str) -> Optional[OptimizationStrategy]:
+ """Get a strategy by name."""
+ return self.strategies.get(name)
+
+ def get_strategies_by_type(self, strategy_type: StrategyType) -> List[OptimizationStrategy]:
+ """Get all strategies of a specific type."""
+ return [s for s in self.strategies.values() if s.strategy_type == strategy_type]
+
+ def get_all_strategies(self) -> List[OptimizationStrategy]:
+ """Get all available strategies."""
+ return list(self.strategies.values())
+
+ def create_strategy_config(
+ self, name: str, enabled: bool = True, parameters: Dict[str, Any] = None
+ ) -> StrategyConfig:
+ """Create a strategy configuration."""
+ strategy = self.get_strategy(name)
+ if not strategy:
+ raise ValueError(f"Unknown strategy: {name}")
+
+ return StrategyConfig(
+ name=name,
+ enabled=enabled,
+ parameters=parameters or {},
+ description=strategy.description,
+ strategy_type=strategy.strategy_type,
+ )
+
+
+# Global strategy library instance
+STRATEGY_LIBRARY = StrategyLibrary()
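+
+
+if __name__ == "__main__":
+    # Quick sketch of how the library is typically queried.
+    print([s.name for s in STRATEGY_LIBRARY.get_all_strategies()])
+    print([s.name for s in STRATEGY_LIBRARY.get_strategies_by_type(StrategyType.MEMORY)])
+    cfg = STRATEGY_LIBRARY.create_strategy_config("op_reorder")
+    print(cfg.name, cfg.strategy_type, cfg.parameters)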
diff --git a/data_juicer/benchmark/utils/__init__.py b/data_juicer/benchmark/utils/__init__.py
new file mode 100644
index 0000000000..552514e0e7
--- /dev/null
+++ b/data_juicer/benchmark/utils/__init__.py
@@ -0,0 +1,6 @@
+"""Utility components for the benchmark framework."""
+
+from .benchmark_cli import BenchmarkCLI
+from .config_manager import ConfigManager
+
+__all__ = ["ConfigManager", "BenchmarkCLI"]
diff --git a/data_juicer/benchmark/utils/benchmark_cli.py b/data_juicer/benchmark/utils/benchmark_cli.py
new file mode 100644
index 0000000000..5c2146e8ef
--- /dev/null
+++ b/data_juicer/benchmark/utils/benchmark_cli.py
@@ -0,0 +1,450 @@
+#!/usr/bin/env python3
+"""
+Command-line interface for the benchmark framework.
+"""
+
+import argparse
+import sys
+from typing import List, Optional
+
+from loguru import logger
+
+# Import these inside functions to avoid circular imports
+# from ..core.benchmark_runner import BenchmarkRunner, BenchmarkConfig
+# from ..strategies.ab_test import StrategyABTest, ABTestConfig
+# from ..strategies.strategy_library import STRATEGY_LIBRARY
+# from ..workloads.workload_suite import WORKLOAD_SUITE
+
+
+class BenchmarkCLI:
+ """Command-line interface for benchmark framework."""
+
+ def __init__(self):
+ self.parser = self._create_parser()
+
+ def _create_parser(self) -> argparse.ArgumentParser:
+ """Create argument parser."""
+ parser = argparse.ArgumentParser(
+ description="Data-Juicer Performance Benchmark Framework",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog="""
+Examples:
+ # Run A/B test with specific strategies
+ python -m data_juicer.benchmark.utils.benchmark_cli ab-test \\
+ --baseline baseline \\
+ --target-strategies mapper_fusion,filter_fusion \\
+ --workload text_simple \\
+ --output-dir outputs/benchmark/
+
+ # Run workload suite test
+ python -m data_juicer.benchmark.utils.benchmark_cli workload-suite \\
+ --workloads text_simple,image_simple \\
+ --baseline baseline \\
+ --target-strategies mapper_fusion,full_optimization \\
+ --output-dir outputs/benchmark/
+
+ # Run single benchmark with custom dataset and config
+ python -m data_juicer.benchmark.utils.benchmark_cli single \\
+ --dataset /path/to/your/dataset.jsonl \\
+ --config /path/to/your/config.yaml \\
+ --strategy baseline \\
+ --output-dir outputs/benchmark/
+
+ # Run benchmark with production text dataset and simple config
+ python -m data_juicer.benchmark.utils.benchmark_cli single \\
+ --modality text \\
+ --config-type simple \\
+ --strategy baseline \\
+ --output-dir outputs/benchmark/
+
+ # Run benchmark with production text dataset and production config
+ python -m data_juicer.benchmark.utils.benchmark_cli single \\
+ --modality text \\
+ --config-type production \\
+ --strategy mapper_fusion \\
+ --output-dir outputs/benchmark/
+
+ # Run benchmark with 10% sampling
+ python -m data_juicer.benchmark.utils.benchmark_cli single \\
+ --modality text \\
+ --config-type production \\
+ --strategy baseline \\
+ --sample-ratio 0.1 \\
+ --sample-method random \\
+ --output-dir outputs/benchmark/
+ """,
+ )
+
+ subparsers = parser.add_subparsers(dest="command", help="Available commands")
+
+ # A/B test command
+ ab_parser = subparsers.add_parser("ab-test", help="Run A/B test between strategies")
+ ab_parser.add_argument("--baseline", required=True, help="Baseline strategy name")
+ ab_parser.add_argument(
+ "--target-strategies",
+ required=True,
+ help="Comma-separated list of target strategies to test against baseline",
+ )
+ ab_parser.add_argument("--workload", required=True, help="Workload to use for testing")
+ ab_parser.add_argument("--iterations", type=int, default=3, help="Number of iterations per strategy")
+ ab_parser.add_argument("--output-dir", default="outputs/benchmark", help="Output directory for results")
+ ab_parser.add_argument("--timeout", type=int, default=3600, help="Timeout in seconds")
+
+ # Workload suite command
+ suite_parser = subparsers.add_parser("workload-suite", help="Run tests across multiple workloads")
+ suite_parser.add_argument("--workloads", required=True, help="Comma-separated list of workloads to test")
+ suite_parser.add_argument("--baseline", required=True, help="Baseline strategy name")
+ suite_parser.add_argument(
+ "--target-strategies",
+ required=True,
+ help="Comma-separated list of target strategies to test against baseline",
+ )
+ suite_parser.add_argument("--iterations", type=int, default=3, help="Number of iterations per strategy")
+ suite_parser.add_argument("--output-dir", default="outputs/benchmark", help="Output directory for results")
+ suite_parser.add_argument("--timeout", type=int, default=3600, help="Timeout in seconds")
+
+ # Single benchmark command
+ single_parser = subparsers.add_parser("single", help="Run single benchmark")
+
+ # Dataset options
+ single_parser.add_argument("--dataset", help="Path to custom dataset")
+ single_parser.add_argument(
+ "--modality",
+ choices=["text", "image", "video", "audio"],
+ help="Use production dataset for specified modality",
+ )
+
+ # Config options
+ single_parser.add_argument("--config", help="Path to custom configuration file")
+ single_parser.add_argument(
+ "--config-type",
+ choices=["simple", "production"],
+ default="simple",
+ help="Use simple or production config for the modality",
+ )
+
+ # Sampling and other options
+ single_parser.add_argument("--strategy", required=True, help="Strategy to test")
+ single_parser.add_argument("--iterations", type=int, default=3, help="Number of iterations")
+ single_parser.add_argument("--output-dir", default="outputs/benchmark", help="Output directory for results")
+ single_parser.add_argument("--timeout", type=int, default=3600, help="Timeout in seconds")
+ single_parser.add_argument(
+ "--sample-ratio",
+ type=float,
+ default=1.0,
+ help="Sample ratio (0.1 = 10 percent of dataset, 1.0 = full dataset)",
+ )
+ single_parser.add_argument(
+ "--sample-method", choices=["random", "first", "last"], default="random", help="Sampling method"
+ )
+
+ # A/B test optimization command
+ ab_opt_parser = subparsers.add_parser(
+ "ab-optimization", help="Run A/B test comparing baseline vs optimized strategies"
+ )
+
+ # Dataset options for A/B test
+ ab_opt_parser.add_argument("--dataset", help="Path to custom dataset")
+ ab_opt_parser.add_argument(
+ "--modality",
+ choices=["text", "image", "video", "audio"],
+ help="Use production dataset for specified modality",
+ )
+
+ # Config options for A/B test
+ ab_opt_parser.add_argument("--config", help="Path to custom configuration file")
+ ab_opt_parser.add_argument(
+ "--config-type",
+ choices=["simple", "production"],
+ default="simple",
+ help="Use simple or production config for the modality",
+ )
+
+ # Optimization strategy options
+ ab_opt_parser.add_argument(
+ "--optimizations",
+ nargs="+",
+ choices=["mapper_fusion", "filter_fusion", "full_optimization"],
+ default=["mapper_fusion"],
+ help="Optimization strategies to test (default: mapper_fusion)",
+ )
+ ab_opt_parser.add_argument(
+ "--baseline-name",
+ default="baseline",
+ help="Name for baseline strategy (default: baseline)",
+ )
+ ab_opt_parser.add_argument(
+ "--optimized-name",
+ default="optimized",
+ help="Name for optimized strategy (default: optimized)",
+ )
+
+ # Sampling and other options
+ ab_opt_parser.add_argument("--iterations", type=int, default=3, help="Number of iterations per strategy")
+ ab_opt_parser.add_argument("--output-dir", default="outputs/benchmark", help="Output directory for results")
+ ab_opt_parser.add_argument("--timeout", type=int, default=3600, help="Timeout in seconds")
+ ab_opt_parser.add_argument(
+ "--sample-ratio",
+ type=float,
+ default=1.0,
+ help="Sample ratio (0.1 = 10 percent of dataset, 1.0 = full dataset)",
+ )
+ ab_opt_parser.add_argument(
+ "--sample-method", choices=["random", "first", "last"], default="random", help="Sampling method"
+ )
+
+ # List command
+ list_parser = subparsers.add_parser("list", help="List available options")
+ list_parser.add_argument("--workloads", action="store_true", help="List available workloads")
+ list_parser.add_argument("--strategies", action="store_true", help="List available strategies")
+
+ return parser
+
+ def run(self, args: Optional[List[str]] = None) -> int:
+ """Run the CLI with given arguments."""
+ if args is None:
+ args = sys.argv[1:]
+
+ parsed_args = self.parser.parse_args(args)
+
+ if not parsed_args.command:
+ self.parser.print_help()
+ return 1
+
+ try:
+ if parsed_args.command == "ab-test":
+ return self._run_ab_test(parsed_args)
+ elif parsed_args.command == "workload-suite":
+ return self._run_workload_suite(parsed_args)
+ elif parsed_args.command == "single":
+ return self._run_single_benchmark(parsed_args)
+ elif parsed_args.command == "ab-optimization":
+ return self._run_ab_optimization(parsed_args)
+ elif parsed_args.command == "list":
+ return self._list_options(parsed_args)
+ else:
+ logger.error(f"Unknown command: {parsed_args.command}")
+ return 1
+
+ except Exception as e:
+ logger.error(f"Error running command: {e}")
+ return 1
+
+ def _run_ab_test(self, args) -> int:
+ """Run A/B test."""
+ # Import here to avoid circular imports
+ from data_juicer.benchmark.strategies.ab_test import (
+ ABTestConfig,
+ StrategyABTest,
+ )
+ from data_juicer.benchmark.strategies.strategy_library import STRATEGY_LIBRARY
+ from data_juicer.benchmark.workloads.workload_suite import WORKLOAD_SUITE
+
+ # Parse baseline and target strategies
+ baseline_name = args.baseline.strip()
+ target_strategy_names = [s.strip() for s in args.target_strategies.split(",")]
+
+ # Get workload
+ workload = WORKLOAD_SUITE.get_workload(args.workload)
+ if not workload:
+ logger.error(f"Unknown workload: {args.workload}")
+ return 1
+
+ # Create strategy configs
+ baseline = STRATEGY_LIBRARY.create_strategy_config(baseline_name)
+ test_strategies = [STRATEGY_LIBRARY.create_strategy_config(name) for name in target_strategy_names]
+
+ # Create A/B test config
+ ab_config = ABTestConfig(
+ name=f"ab_test_{args.workload}",
+ baseline_strategy=baseline,
+ test_strategies=test_strategies,
+ workload=workload,
+ iterations=args.iterations,
+ output_dir=args.output_dir,
+ timeout_seconds=args.timeout,
+ )
+
+ # Run A/B test
+ ab_test = StrategyABTest(ab_config)
+ results = ab_test.run_ab_test()
+
+ # Print summary
+ print("\n=== A/B Test Results ===")
+ for strategy_name, comparison in results.items():
+ print(f"\n{strategy_name} vs {comparison.baseline_name}:")
+ print(f" Speedup: {comparison.speedup:.2f}x")
+ print(f" Throughput: {comparison.throughput_improvement:.2f}x")
+ print(f" Memory: {comparison.memory_efficiency:.2f}x")
+ print(f" Significant: {comparison.is_significant}")
+ print(f" Summary: {comparison.summary}")
+
+ return 0
+
+ def _run_workload_suite(self, args) -> int:
+ """Run workload suite test."""
+ # Import here to avoid circular imports
+ from data_juicer.benchmark.strategies.ab_test import (
+ ABTestConfig,
+ StrategyABTest,
+ )
+ from data_juicer.benchmark.strategies.strategy_library import STRATEGY_LIBRARY
+ from data_juicer.benchmark.workloads.workload_suite import WORKLOAD_SUITE
+
+ # Parse workloads, baseline and target strategies
+ workload_names = [w.strip() for w in args.workloads.split(",")]
+ baseline_name = args.baseline.strip()
+ target_strategy_names = [s.strip() for s in args.target_strategies.split(",")]
+
+ # Get workloads
+ workloads = []
+ for name in workload_names:
+ workload = WORKLOAD_SUITE.get_workload(name)
+ if not workload:
+ logger.error(f"Unknown workload: {name}")
+ return 1
+ workloads.append(workload)
+
+ # Create strategy configs
+ baseline = STRATEGY_LIBRARY.create_strategy_config(baseline_name)
+ test_strategies = [STRATEGY_LIBRARY.create_strategy_config(name) for name in target_strategy_names]
+
+ # Run workload suite test
+ ab_test = StrategyABTest(
+ ABTestConfig(
+ name="workload_suite_test",
+ baseline_strategy=baseline,
+ test_strategies=test_strategies,
+ workload=workloads[0], # Will be overridden
+ iterations=args.iterations,
+ output_dir=args.output_dir,
+ timeout_seconds=args.timeout,
+ )
+ )
+
+ results = ab_test.run_workload_suite_ab_test(test_strategies, workloads)
+
+ # Print summary
+ print("\n=== Workload Suite Results ===")
+ for workload_name, workload_results in results.items():
+ print(f"\n{workload_name}:")
+ for strategy_name, comparison in workload_results.items():
+ print(f" {strategy_name}: {comparison.speedup:.2f}x speedup")
+
+ return 0
+
+ def _run_single_benchmark(self, args) -> int:
+ """Run single benchmark."""
+ # Import here to avoid circular imports
+ from ..core.benchmark_runner import BenchmarkConfig, BenchmarkRunner
+ from ..strategies.strategy_library import STRATEGY_LIBRARY
+
+ # Determine dataset and config paths
+ dataset_path, config_path = self._resolve_dataset_and_config(args)
+
+ # Get the actual strategy and apply it to get the configuration changes
+ strategy_obj = STRATEGY_LIBRARY.get_strategy(args.strategy)
+ if strategy_obj:
+ # Apply the strategy to get the actual config changes
+ strategy_config = strategy_obj.apply_to_config({})
+ else:
+ # Fallback to basic config
+ strategy_config = {}
+
+ # Create benchmark config
+ benchmark_config = BenchmarkConfig(
+ dataset_path=dataset_path,
+ config_path=config_path,
+ output_dir=args.output_dir,
+ iterations=args.iterations,
+ timeout_seconds=args.timeout,
+ strategy_name=args.strategy,
+ strategy_config=strategy_config,
+ sample_ratio=args.sample_ratio,
+ sample_method=args.sample_method,
+ )
+
+ # Run benchmark
+ runner = BenchmarkRunner(benchmark_config)
+ results = runner.run_benchmark_suite()
+
+ # Print results
+ print("\n=== Benchmark Results ===")
+ for i, metrics in enumerate(results):
+ print(f"\nRun {i+1}:")
+ print(f" Time: {metrics.total_wall_time:.2f}s")
+ print(f" Throughput: {metrics.samples_per_second:.1f} samples/sec")
+ print(f" Memory: {metrics.peak_memory_mb:.0f} MB")
+ print(f" Retention: {metrics.data_retention_rate:.1%}")
+
+ return 0
+
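+    def _run_ab_optimization(self, args) -> int:
+        """Run A/B test comparing a baseline against core optimizer strategies.
+
+        Minimal sketch: the resolved dataset/config is wrapped in an ad-hoc
+        WorkloadDefinition and handed to the regular A/B test machinery;
+        --sample-ratio/--sample-method are not forwarded here.
+        """
+        # Import here to avoid circular imports
+        from ..strategies.ab_test import ABTestConfig, StrategyABTest
+        from ..strategies.strategy_library import STRATEGY_LIBRARY
+        from ..workloads.workload_suite import WorkloadDefinition
+
+        dataset_path, config_path = self._resolve_dataset_and_config(args)
+
+        workload = WorkloadDefinition(
+            name=f"ab_optimization_{args.modality or 'custom'}",
+            description="Ad-hoc workload for baseline vs optimized comparison",
+            dataset_path=dataset_path or "",
+            config_path=config_path,
+            expected_samples=0,
+            modality=args.modality or "text",
+            complexity="simple",
+            estimated_duration_minutes=0,
+            resource_requirements={},
+        )
+
+        baseline = STRATEGY_LIBRARY.create_strategy_config(args.baseline_name)
+        test_strategies = [STRATEGY_LIBRARY.create_strategy_config(name) for name in args.optimizations]
+
+        ab_config = ABTestConfig(
+            name=f"ab_optimization_{args.optimized_name}",
+            baseline_strategy=baseline,
+            test_strategies=test_strategies,
+            workload=workload,
+            iterations=args.iterations,
+            output_dir=args.output_dir,
+            timeout_seconds=args.timeout,
+        )
+
+        results = StrategyABTest(ab_config).run_ab_test()
+
+        print("\n=== A/B Optimization Results ===")
+        for strategy_name, comparison in results.items():
+            print(f"  {strategy_name} vs {args.baseline_name}: {comparison.speedup:.2f}x speedup")
+
+        return 0
+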
+ def _resolve_dataset_and_config(self, args):
+ """Resolve dataset and config paths based on arguments."""
+ # Validate arguments
+ if not args.config and not args.modality:
+ raise ValueError("Either --config or --modality must be specified")
+
+ # Determine dataset path
+ dataset_path = None
+ if args.dataset:
+ dataset_path = args.dataset
+ elif args.modality:
+ # Use production dataset for the specified modality
+ dataset_path = f"perf_bench_data/{args.modality}/"
+ if args.modality == "text":
+ dataset_path += "wiki-10k.jsonl"
+ elif args.modality == "image":
+ dataset_path += "10k.jsonl"
+ elif args.modality == "video":
+ dataset_path += "msr_vtt_train.jsonl"
+ elif args.modality == "audio":
+ dataset_path += "audio-10k.jsonl"
+ # If neither --dataset nor --modality is specified, dataset_path will be None
+ # and the config file should contain the dataset_path field
+
+ # Determine config path
+ if args.config:
+ config_path = args.config
+ else: # args.modality is specified
+ # Use production or simple config for the specified modality
+ if args.config_type == "production":
+ config_path = f"tests/benchmark_performance/configs/{args.modality}.yaml"
+ else: # simple
+ config_path = "configs/demo/process.yaml"
+
+ return dataset_path, config_path
+
+ def _list_options(self, args) -> int:
+ """List available options."""
+ # Import here to avoid circular imports
+ from ..strategies.strategy_library import STRATEGY_LIBRARY
+ from ..workloads.workload_suite import WORKLOAD_SUITE
+
+ if args.workloads:
+ print("\n=== Available Workloads ===")
+ for workload in WORKLOAD_SUITE.get_all_workloads():
+ print(f" {workload.name}: {workload.description}")
+ print(f" Modality: {workload.modality}, Complexity: {workload.complexity}")
+ print(f" Expected samples: {workload.expected_samples}")
+ print(f" Duration: {workload.estimated_duration_minutes} min")
+ print()
+
+ if args.strategies:
+ print("\n=== Available Strategies ===")
+ for strategy in STRATEGY_LIBRARY.get_all_strategies():
+ print(f" {strategy.name}: {strategy.description}")
+ print(f" Type: {strategy.strategy_type.value}")
+ print()
+
+ return 0
+
+
+def main():
+ """Main entry point for CLI."""
+ cli = BenchmarkCLI()
+ return cli.run()
+
+
+if __name__ == "__main__":
+ sys.exit(main())
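+
+# Programmatic usage sketch (equivalent to the shell examples in the epilog):
+#
+#     from data_juicer.benchmark.utils.benchmark_cli import BenchmarkCLI
+#     BenchmarkCLI().run(["list", "--strategies", "--workloads"])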
diff --git a/data_juicer/benchmark/utils/config_manager.py b/data_juicer/benchmark/utils/config_manager.py
new file mode 100644
index 0000000000..41c52bfcaf
--- /dev/null
+++ b/data_juicer/benchmark/utils/config_manager.py
@@ -0,0 +1,159 @@
+#!/usr/bin/env python3
+"""
+Configuration management for benchmark framework.
+"""
+
+import hashlib
+import json
+from pathlib import Path
+from typing import Any, Dict, List
+
+import yaml
+from loguru import logger
+
+
+class ConfigManager:
+ """Manages configuration files for benchmark testing."""
+
+ def __init__(self):
+ self.config_cache = {}
+
+ def load_config(self, config_path: str) -> Dict[str, Any]:
+ """Load configuration from file."""
+ config_path = Path(config_path)
+
+ # Check cache first
+ if str(config_path) in self.config_cache:
+ return self.config_cache[str(config_path)].copy()
+
+ try:
+ with open(config_path, "r") as f:
+ if config_path.suffix.lower() in [".yaml", ".yml"]:
+ config = yaml.safe_load(f)
+ elif config_path.suffix.lower() == ".json":
+ config = json.load(f)
+ else:
+ raise ValueError(f"Unsupported config format: {config_path.suffix}")
+
+ # Cache the config
+ self.config_cache[str(config_path)] = config.copy()
+
+ logger.debug(f"Loaded config from {config_path}")
+ return config
+
+ except Exception as e:
+ logger.error(f"Failed to load config from {config_path}: {e}")
+ raise
+
+ def save_config(self, config: Dict[str, Any], output_path: str):
+ """Save configuration to file."""
+ output_path = Path(output_path)
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+
+ try:
+ with open(output_path, "w") as f:
+ if output_path.suffix.lower() in [".yaml", ".yml"]:
+ yaml.dump(config, f, default_flow_style=False, indent=2)
+ elif output_path.suffix.lower() == ".json":
+ json.dump(config, f, indent=2)
+ else:
+ raise ValueError(f"Unsupported output format: {output_path.suffix}")
+
+ logger.debug(f"Saved config to {output_path}")
+
+ except Exception as e:
+ logger.error(f"Failed to save config to {output_path}: {e}")
+ raise
+
+ def apply_strategy_config(self, base_config: Dict[str, Any], strategy_config: Dict[str, Any]) -> Dict[str, Any]:
+ """Apply strategy-specific configuration to base config."""
+
+ # Deep copy to avoid modifying original
+ result_config = self._deep_copy_dict(base_config)
+
+ # Apply strategy modifications
+ for key, value in strategy_config.items():
+ self._set_nested_value(result_config, key, value)
+
+ logger.debug(f"Applied strategy config: {strategy_config}")
+ return result_config
+
+ def merge_configs(self, *configs: Dict[str, Any]) -> Dict[str, Any]:
+ """Merge multiple configurations."""
+ result = {}
+
+ for config in configs:
+ result = self._deep_merge_dicts(result, config)
+
+ return result
+
+ def validate_config(self, config: Dict[str, Any], schema: Dict[str, Any] = None) -> List[str]:
+ """Validate configuration against schema."""
+ issues = []
+
+ if schema is None:
+            # Basic validation: the config needs a dataset source and a process list
+            if "dataset" not in config and "dataset_path" not in config:
+                issues.append("Missing required field: dataset (or dataset_path)")
+            if "process" not in config:
+                issues.append("Missing required field: process")
+ else:
+ # Schema-based validation
+ issues.extend(self._validate_against_schema(config, schema))
+
+ return issues
+
+ def get_config_hash(self, config: Dict[str, Any]) -> str:
+ """Generate hash for configuration."""
+ config_str = json.dumps(config, sort_keys=True)
+ return hashlib.md5(config_str.encode()).hexdigest()[:8]
+
+ def _deep_copy_dict(self, d: Dict[str, Any]) -> Dict[str, Any]:
+ """Deep copy a dictionary."""
+ return json.loads(json.dumps(d))
+
+ def _deep_merge_dicts(self, dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]:
+ """Deep merge two dictionaries."""
+ result = dict1.copy()
+
+ for key, value in dict2.items():
+ if key in result and isinstance(result[key], dict) and isinstance(value, dict):
+ result[key] = self._deep_merge_dicts(result[key], value)
+ else:
+ result[key] = value
+
+ return result
+
+ def _set_nested_value(self, config: Dict[str, Any], key: str, value: Any):
+ """Set a nested value in configuration."""
+ keys = key.split(".")
+ current = config
+
+ for k in keys[:-1]:
+ if k not in current:
+ current[k] = {}
+ current = current[k]
+
+ current[keys[-1]] = value
+
+ def _validate_against_schema(self, config: Dict[str, Any], schema: Dict[str, Any]) -> List[str]:
+ """Validate configuration against schema."""
+ issues = []
+
+ # This is a simplified schema validation
+ # In practice, you might want to use a more robust schema validation library
+
+ for field, field_schema in schema.items():
+ if field_schema.get("required", False) and field not in config:
+ issues.append(f"Missing required field: {field}")
+
+ if field in config:
+ field_type = field_schema.get("type", "string")
+ if field_type == "boolean" and not isinstance(config[field], bool):
+ issues.append(f"Field {field} must be boolean")
+ elif field_type == "integer" and not isinstance(config[field], int):
+ issues.append(f"Field {field} must be integer")
+ elif field_type == "string" and not isinstance(config[field], str):
+ issues.append(f"Field {field} must be string")
+
+ return issues
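+
+
+if __name__ == "__main__":
+    # Illustrative sketch: dotted keys in a strategy config address nested fields.
+    cm = ConfigManager()
+    base = {"np": 4, "process": [{"clean_links_mapper": {}}]}
+    tuned = cm.apply_strategy_config(base, {"op_fusion": True, "executor.ray_address": "auto"})
+    print(tuned)
+    # -> {'np': 4, 'process': [...], 'op_fusion': True, 'executor': {'ray_address': 'auto'}}
+    print(cm.get_config_hash(base), cm.get_config_hash(tuned))  # short md5 digests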
diff --git a/data_juicer/benchmark/workloads/__init__.py b/data_juicer/benchmark/workloads/__init__.py
new file mode 100644
index 0000000000..7e4aa0780f
--- /dev/null
+++ b/data_juicer/benchmark/workloads/__init__.py
@@ -0,0 +1,5 @@
+"""Workload definitions for comprehensive benchmarking."""
+
+from .workload_suite import WORKLOAD_SUITE, WorkloadDefinition, WorkloadSuite
+
+__all__ = ["WorkloadSuite", "WorkloadDefinition", "WORKLOAD_SUITE"]
diff --git a/data_juicer/benchmark/workloads/workload_suite.py b/data_juicer/benchmark/workloads/workload_suite.py
new file mode 100644
index 0000000000..1e9c3f4a76
--- /dev/null
+++ b/data_juicer/benchmark/workloads/workload_suite.py
@@ -0,0 +1,229 @@
+#!/usr/bin/env python3
+"""
+Comprehensive workload suite for benchmarking different scenarios.
+"""
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from loguru import logger
+
+
+@dataclass
+class WorkloadDefinition:
+ """Definition of a benchmark workload."""
+
+ name: str
+ description: str
+ dataset_path: str
+ config_path: str
+ expected_samples: int
+ modality: str # text, image, video, audio, multimodal
+ complexity: str # simple, medium, complex
+ estimated_duration_minutes: int
+ resource_requirements: Dict[str, Any]
+
+ def __post_init__(self):
+ """Validate workload definition."""
+ if not Path(self.dataset_path).exists():
+ logger.warning(f"Dataset path does not exist: {self.dataset_path}")
+ if not Path(self.config_path).exists():
+ logger.warning(f"Config path does not exist: {self.config_path}")
+
+
+class WorkloadSuite:
+ """Comprehensive suite of benchmark workloads."""
+
+ def __init__(self):
+ self.workloads = {}
+ self._initialize_workloads()
+
+ def _initialize_workloads(self):
+ """Initialize all available workloads using production datasets and configs."""
+
+ # Text workloads - Production Wikipedia dataset
+ self.workloads["text_simple"] = WorkloadDefinition(
+ name="text_simple",
+ description="Simple text processing with basic filters",
+ dataset_path="perf_bench_data/text/wiki-10k.jsonl",
+ config_path="configs/demo/process.yaml",
+ expected_samples=10000,
+ modality="text",
+ complexity="simple",
+ estimated_duration_minutes=5,
+ resource_requirements={"memory_gb": 2, "cpu_cores": 2},
+ )
+
+ self.workloads["text_production"] = WorkloadDefinition(
+ name="text_production",
+ description="Production text processing with ML operations",
+ dataset_path="perf_bench_data/text/wiki-10k.jsonl",
+ config_path="tests/benchmark_performance/configs/text.yaml",
+ expected_samples=10000,
+ modality="text",
+ complexity="complex",
+ estimated_duration_minutes=40,
+ resource_requirements={"memory_gb": 8, "cpu_cores": 12, "gpu": True},
+ )
+
+ # Image workloads - Production image dataset
+ self.workloads["image_simple"] = WorkloadDefinition(
+ name="image_simple",
+ description="Simple image processing",
+ dataset_path="perf_bench_data/image/10k.jsonl",
+ config_path="configs/demo/process.yaml",
+ expected_samples=10000,
+ modality="image",
+ complexity="simple",
+ estimated_duration_minutes=10,
+ resource_requirements={"memory_gb": 4, "cpu_cores": 4, "gpu": True},
+ )
+
+ self.workloads["image_production"] = WorkloadDefinition(
+ name="image_production",
+ description="Production image processing with ML models",
+ dataset_path="perf_bench_data/image/10k.jsonl",
+ config_path="tests/benchmark_performance/configs/image.yaml",
+ expected_samples=10000,
+ modality="image",
+ complexity="complex",
+ estimated_duration_minutes=30,
+ resource_requirements={"memory_gb": 16, "cpu_cores": 12, "gpu": True},
+ )
+
+ # Video workloads - Production MSR-VTT dataset
+ self.workloads["video_simple"] = WorkloadDefinition(
+ name="video_simple",
+ description="Simple video processing",
+ dataset_path="perf_bench_data/video/msr_vtt_train.jsonl",
+ config_path="configs/demo/process.yaml",
+ expected_samples=1000,
+ modality="video",
+ complexity="simple",
+ estimated_duration_minutes=20,
+ resource_requirements={"memory_gb": 8, "cpu_cores": 8, "gpu": True},
+ )
+
+ self.workloads["video_production"] = WorkloadDefinition(
+ name="video_production",
+ description="Production video processing with frame analysis",
+ dataset_path="perf_bench_data/video/msr_vtt_train.jsonl",
+ config_path="tests/benchmark_performance/configs/video.yaml",
+ expected_samples=1000,
+ modality="video",
+ complexity="complex",
+ estimated_duration_minutes=60,
+ resource_requirements={"memory_gb": 32, "cpu_cores": 16, "gpu": True},
+ )
+
+ # Audio workloads - Production audio dataset
+ self.workloads["audio_simple"] = WorkloadDefinition(
+ name="audio_simple",
+ description="Simple audio processing",
+ dataset_path="perf_bench_data/audio/audio-10k.jsonl",
+ config_path="configs/demo/process.yaml",
+ expected_samples=10000,
+ modality="audio",
+ complexity="simple",
+ estimated_duration_minutes=15,
+ resource_requirements={"memory_gb": 4, "cpu_cores": 4},
+ )
+
+ self.workloads["audio_production"] = WorkloadDefinition(
+ name="audio_production",
+ description="Production audio processing with quality filters",
+ dataset_path="perf_bench_data/audio/audio-10k.jsonl",
+ config_path="tests/benchmark_performance/configs/audio.yaml",
+ expected_samples=10000,
+ modality="audio",
+ complexity="complex",
+ estimated_duration_minutes=25,
+ resource_requirements={"memory_gb": 8, "cpu_cores": 8},
+ )
+
+ # C4 dataset stress tests - Local vs Ray execution
+ self.workloads["stress_test_text_c4_local"] = WorkloadDefinition(
+ name="stress_test_text_c4_local",
+ description="C4 dataset stress test with local execution (16 processes)",
+ dataset_path="perf_bench_data/text/c4-train.00000-of-01024.jsonl",
+ config_path="tests/benchmark_performance/configs/text-c4-local.yaml",
+ expected_samples=100000, # C4 dataset is much larger
+ modality="text",
+ complexity="complex",
+ estimated_duration_minutes=120, # Longer due to C4 dataset size
+ resource_requirements={"memory_gb": 64, "cpu_cores": 16, "gpu": False},
+ )
+
+ self.workloads["stress_test_text_c4_ray"] = WorkloadDefinition(
+ name="stress_test_text_c4_ray",
+ description="C4 dataset stress test with Ray distributed execution",
+ dataset_path="perf_bench_data/text/c4-train.00000-of-01024.jsonl",
+ config_path="tests/benchmark_performance/configs/text-c4-ray.yaml",
+ expected_samples=100000, # C4 dataset is much larger
+ modality="text",
+ complexity="complex",
+ estimated_duration_minutes=90, # Ray should be faster than local
+ resource_requirements={"memory_gb": 64, "cpu_cores": 32, "gpu": False},
+ )
+
+ # Operation Reordering Showcase
+ self.workloads["op_reorder_showcase"] = WorkloadDefinition(
+ name="op_reorder_showcase",
+ description="Showcase workload designed to demonstrate operation reordering benefits",
+ dataset_path="perf_bench_data/text/c4-train.00000-of-01024.jsonl",
+ config_path="configs/optimization/op_reorder_showcase.yaml",
+ expected_samples=10000,
+ modality="text",
+ complexity="complex",
+ estimated_duration_minutes=15, # Moderate complexity
+ resource_requirements={"memory_gb": 8, "cpu_cores": 8, "gpu": False},
+ )
+
+ def get_workload(self, name: str) -> Optional[WorkloadDefinition]:
+ """Get a specific workload by name."""
+ return self.workloads.get(name)
+
+ def get_workloads_by_modality(self, modality: str) -> List[WorkloadDefinition]:
+ """Get all workloads for a specific modality."""
+ return [w for w in self.workloads.values() if w.modality == modality]
+
+ def get_workloads_by_complexity(self, complexity: str) -> List[WorkloadDefinition]:
+ """Get all workloads for a specific complexity level."""
+ return [w for w in self.workloads.values() if w.complexity == complexity]
+
+ def get_all_workloads(self) -> List[WorkloadDefinition]:
+ """Get all available workloads."""
+ return list(self.workloads.values())
+
+ def get_workload_names(self) -> List[str]:
+ """Get names of all available workloads."""
+ return list(self.workloads.keys())
+
+ def validate_workloads(self) -> Dict[str, List[str]]:
+ """Validate all workloads and return any issues."""
+ issues = {}
+
+ for name, workload in self.workloads.items():
+ workload_issues = []
+
+ if not Path(workload.dataset_path).exists():
+ workload_issues.append(f"Dataset not found: {workload.dataset_path}")
+
+ if not Path(workload.config_path).exists():
+ workload_issues.append(f"Config not found: {workload.config_path}")
+
+ if workload.expected_samples <= 0:
+ workload_issues.append("Expected samples must be positive")
+
+ if workload.estimated_duration_minutes <= 0:
+ workload_issues.append("Estimated duration must be positive")
+
+ if workload_issues:
+ issues[name] = workload_issues
+
+ return issues
+
+
+# Global workload suite instance
+WORKLOAD_SUITE = WorkloadSuite()
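+
+
+if __name__ == "__main__":
+    # Quick sketch: list the registered text workloads and check whether their
+    # datasets/configs are present on this machine.
+    for wl in WORKLOAD_SUITE.get_workloads_by_modality("text"):
+        print(wl.name, wl.complexity, f"~{wl.estimated_duration_minutes} min")
+    issues = WORKLOAD_SUITE.validate_workloads()
+    if issues:
+        print("Workloads with missing assets:", list(issues.keys()))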
diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py
index e368784288..d7e49aba0c 100644
--- a/data_juicer/config/config.py
+++ b/data_juicer/config/config.py
@@ -471,9 +471,10 @@ def init_configs(args: Optional[List[str]] = None, which_entry: object = None, l
"--op_fusion",
type=bool,
default=False,
- help="Whether to fuse operators that share the same intermediate " # noqa: E251
- "variables automatically. Op fusion might reduce the memory "
- "requirements slightly but speed up the whole process.",
+ help="Whether to fuse operators that share the same intermediate "
+ "variables automatically. Op fusion increases memory usage "
+ "but significantly speeds up the whole process by avoiding "
+ "re-computation of shared intermediate variables.",
)
parser.add_argument(
"--fusion_strategy",
@@ -520,6 +521,18 @@ def init_configs(args: Optional[List[str]] = None, which_entry: object = None, l
default=False,
help="Whether to save all stats to only one file. Only used in " "Analysis.",
)
+ parser.add_argument(
+ "--enable_optimizer",
+ type=bool,
+ default=False,
+ help="Enable/disable the core pipeline optimizer.",
+ )
+ parser.add_argument(
+ "--optimizer_strategies",
+ type=List[str],
+ default=["op_reorder"],
+ help="List of optimization strategies to apply when optimizer is enabled.",
+ )
parser.add_argument("--ray_address", type=str, default="auto", help="The address of the Ray cluster.")
parser.add_argument(
"--custom-operator-paths", nargs="+", help="Paths to custom operator scripts or directories."
@@ -736,6 +749,23 @@ def init_setup_from_cfg(cfg: Namespace, load_configs_only=False):
if cfg.get("auto", False):
cfg.process = load_ops_with_stats_meta()
+ # Handle optimization configuration
+ cfg.enable_optimizer = cfg.get("enable_optimizer", False)
+ cfg.optimizer_strategies = cfg.get("optimizer_strategies", ["op_reorder"])
+
+ # Ensure optimizer_strategies is a list
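+    # e.g. a comma-separated CLI value "op_reorder,filter_fusion" becomes
+    # ["op_reorder", "filter_fusion"]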
+ if isinstance(cfg.optimizer_strategies, str):
+ cfg.optimizer_strategies = cfg.optimizer_strategies.split(",")
+
+ if cfg.enable_optimizer:
+ from data_juicer.core.optimizer.strategy import StrategyRegistry
+
+ available_strategies = StrategyRegistry.get_available_strategies()
+ logger.info(f"๐ง Pipeline optimizer enabled with strategies: {cfg.optimizer_strategies}")
+ logger.debug(f"๐ง Available optimization strategies: {available_strategies}")
+ else:
+ logger.debug("๐ง Pipeline optimizer disabled")
+
# Apply text_key modification during initializing configs
# users can freely specify text_key for different ops using `text_key`
# otherwise, set arg text_key of each op to text_keys
diff --git a/data_juicer/core/__init__.py b/data_juicer/core/__init__.py
index 7261b3419c..88bb078bd2 100644
--- a/data_juicer/core/__init__.py
+++ b/data_juicer/core/__init__.py
@@ -4,6 +4,11 @@
from .executor import DefaultExecutor, ExecutorBase, ExecutorFactory
from .exporter import Exporter
from .monitor import Monitor
+from .optimization_manager import (
+ OptimizationManager,
+ apply_optimizations,
+ get_optimization_manager,
+)
from .ray_exporter import RayExporter
from .tracer import Tracer
@@ -17,5 +22,8 @@
"Exporter",
"RayExporter",
"Monitor",
+ "OptimizationManager",
+ "apply_optimizations",
+ "get_optimization_manager",
"Tracer",
]
diff --git a/data_juicer/core/executor/default_executor.py b/data_juicer/core/executor/default_executor.py
index 2a5cdc1a78..8d5af56240 100644
--- a/data_juicer/core/executor/default_executor.py
+++ b/data_juicer/core/executor/default_executor.py
@@ -14,7 +14,6 @@
from data_juicer.core.exporter import Exporter
from data_juicer.core.tracer import Tracer
from data_juicer.ops import load_ops
-from data_juicer.ops.op_fusion import fuse_operators
from data_juicer.ops.selector import (
FrequencySpecifiedFieldSelector,
TopkSpecifiedFieldSelector,
@@ -121,15 +120,10 @@ def run(
logger.info("Preparing process operators...")
ops = load_ops(self.cfg.process)
- # OP fusion
- if self.cfg.op_fusion:
- probe_res = None
- if self.cfg.fusion_strategy == "probe":
- logger.info("Probe the OP speed for OP reordering...")
- probe_res, _ = self.adapter.probe_small_batch(dataset, ops)
+ # Apply core optimizer if enabled
+ from data_juicer.core.optimization_manager import apply_optimizations
- logger.info(f"Start OP fusion and reordering with strategy " f"[{self.cfg.fusion_strategy}]...")
- ops = fuse_operators(ops, probe_res)
+ ops = apply_optimizations(ops, self.cfg)
# adaptive batch size
if self.cfg.adaptive_batch_size:
diff --git a/data_juicer/core/executor/ray_executor.py b/data_juicer/core/executor/ray_executor.py
index 32ef0570a2..24e96fddd0 100644
--- a/data_juicer/core/executor/ray_executor.py
+++ b/data_juicer/core/executor/ray_executor.py
@@ -11,7 +11,6 @@
from data_juicer.core.executor import ExecutorBase
from data_juicer.core.ray_exporter import RayExporter
from data_juicer.ops import load_ops
-from data_juicer.ops.op_fusion import fuse_operators
from data_juicer.utils.lazy_loader import LazyLoader
ray = LazyLoader("ray")
@@ -97,9 +96,10 @@ def run(self, load_data_np: Optional[PositiveInt] = None, skip_export: bool = Fa
logger.info("Preparing process operators...")
ops = load_ops(self.cfg.process)
- if self.cfg.op_fusion:
- logger.info(f"Start OP fusion and reordering with strategy " f"[{self.cfg.fusion_strategy}]...")
- ops = fuse_operators(ops)
+ # Apply core optimizer if enabled
+ from data_juicer.core.optimization_manager import apply_optimizations
+
+ ops = apply_optimizations(ops, self.cfg)
with TempDirManager(self.tmp_dir):
# 3. data process
diff --git a/data_juicer/core/optimization_manager.py b/data_juicer/core/optimization_manager.py
new file mode 100644
index 0000000000..273da42528
--- /dev/null
+++ b/data_juicer/core/optimization_manager.py
@@ -0,0 +1,349 @@
+#!/usr/bin/env python3
+"""
+Optimization Manager for Data-Juicer Pipeline Optimization.
+
+This module provides a centralized way to apply optimization strategies
+to data processing pipelines across different executors.
+"""
+
+from typing import Any, List
+
+from loguru import logger
+
+from data_juicer.core.optimizer.optimizer import PipelineOptimizer
+from data_juicer.core.optimizer.strategy import StrategyRegistry
+from data_juicer.core.pipeline_ast import OpNode, OpType, PipelineAST
+
+
+class OptimizationManager:
+ """
+ Centralized manager for applying optimization strategies to data processing pipelines.
+
+ This class provides a clean interface for executors to apply optimization
+ without duplicating logic across different executor implementations.
+ """
+
+ def __init__(self, cfg=None):
+ """Initialize the optimization manager."""
+ self.cfg = cfg
+ self._check_optimization_enabled()
+
+ def _check_optimization_enabled(self):
+ """Check if optimization is enabled via config."""
+ # Check config for optimization settings
+ if self.cfg and hasattr(self.cfg, "enable_optimizer"):
+ self.optimization_enabled = self.cfg.enable_optimizer
+ else:
+ self.optimization_enabled = False
+
+ if self.optimization_enabled:
+ # Get strategies from config
+ if self.cfg and hasattr(self.cfg, "optimizer_strategies"):
+ self.optimization_strategies = self.cfg.optimizer_strategies
+ else:
+ self.optimization_strategies = ["op_reorder"] # Default strategy
+
+ # Ensure strategies is a list
+ if isinstance(self.optimization_strategies, str):
+ self.optimization_strategies = self.optimization_strategies.split(",")
+
+ logger.info(f"๐ง Core optimizer enabled with strategies: {self.optimization_strategies}")
+ else:
+ self.optimization_strategies = []
+
+ def apply_optimizations(self, ops: List[Any]) -> List[Any]:
+ """
+ Apply optimization strategies to a list of operations.
+
+ Args:
+ ops: List of operations to optimize
+
+ Returns:
+ Optimized list of operations
+ """
+ if not self.optimization_enabled:
+ return ops
+
+ try:
+ logger.info("๐ง Applying core optimizer to operations...")
+
+ # Create AST from operations
+ ast = self._create_ast_from_ops(ops)
+
+ # Print original operation order
+ original_order = [getattr(op, "_name", getattr(op, "name", str(op))) for op in ops]
+ logger.info("๐ Original operation order:")
+ for i, op_name in enumerate(original_order, 1):
+ logger.info(f" {i}. {op_name}")
+
+ # Print original AST structure
+ logger.info("๐ Original AST structure:")
+ self._print_ast_structure(ast, "BEFORE")
+
+ # Apply core optimizer with properly initialized strategies
+ strategy_objects = self._initialize_strategies()
+ optimizer = PipelineOptimizer(strategy_objects)
+ optimized_ast = optimizer.optimize(ast)
+
+ # Print optimized AST structure
+ logger.info("๐ Optimized AST structure:")
+ self._print_ast_structure(optimized_ast, "AFTER")
+
+ # Extract optimized operations from the AST
+ optimized_ops = self._extract_ops_from_ast(optimized_ast, ops)
+
+ # Print final optimized operation order
+ optimized_order = [getattr(op, "_name", getattr(op, "name", str(op))) for op in optimized_ops]
+ logger.info("๐ Final optimized operation order:")
+ for i, op_name in enumerate(optimized_order, 1):
+ logger.info(f" {i}. {op_name}")
+
+ # Show the difference
+ if original_order != optimized_order:
+ logger.info("๐ Operation order has been optimized!")
+ logger.info("๐ Changes:")
+ for i, (orig, opt) in enumerate(zip(original_order, optimized_order)):
+ if orig != opt:
+ logger.info(f" Position {i+1}: {orig} โ {opt}")
+ else:
+ logger.info("โน๏ธ No changes to operation order")
+
+ logger.info("โ
Core optimizer applied successfully")
+ return optimized_ops
+
+ except Exception as e:
+ logger.error(f"โ Failed to apply core optimizer: {e}")
+ logger.warning("โ ๏ธ Continuing with original operation order")
+ return ops
+
+ def _create_ast_from_ops(self, ops: List[Any]) -> PipelineAST:
+ """Create a PipelineAST from operations."""
+ ast = PipelineAST()
+ ast.root = OpNode(name="root", op_type=OpType.ROOT, config={})
+
+ current_node = ast.root
+ for op in ops:
+ # Determine operation type and name
+ if hasattr(op, "_name"):
+ op_name = op._name
+ elif hasattr(op, "name"):
+ op_name = op.name
+ else:
+ op_name = str(op)
+
+ # Determine operation type based on name
+ if "filter" in op_name.lower():
+ op_type = OpType.FILTER
+ elif "mapper" in op_name.lower():
+ op_type = OpType.MAPPER
+ else:
+ op_type = OpType.MAPPER # Default to mapper
+
+ # Get operation config
+ op_config = {}
+ if hasattr(op, "config"):
+ op_config = op.config
+ elif hasattr(op, "__dict__"):
+ op_config = {k: v for k, v in op.__dict__.items() if not k.startswith("_")}
+
+ # Create operation node
+ op_node = OpNode(name=op_name, op_type=op_type, config=op_config)
+
+ # Add to AST
+ current_node.children = [op_node]
+ op_node.parent = current_node
+ current_node = op_node
+
+ return ast
+
+ def _extract_ops_from_ast(self, ast: PipelineAST, original_ops: List[Any]) -> List[Any]:
+ """Extract optimized operations from the AST."""
+ try:
+ logger.info(f"๐ Extracting operations from AST with {len(original_ops)} original operations")
+
+ # Get the optimized operation order from the AST
+ optimized_order = self._get_operation_order_from_ast(ast)
+
+ # Create a mapping from operation names to original operation objects
+ op_map = {}
+ for op in original_ops:
+ if hasattr(op, "_name"):
+ op_name = op._name
+ elif hasattr(op, "name"):
+ op_name = op.name
+ else:
+ op_name = str(op)
+ op_map[op_name] = op
+
+ logger.info(f"๐ Created operation map with {len(op_map)} operations")
+
+ # Reorder operations according to the optimized AST
+ optimized_ops = []
+ for op_name in optimized_order:
+ if op_name in op_map:
+ optimized_ops.append(op_map[op_name])
+ else:
+ logger.warning(f"โ ๏ธ Could not find operation '{op_name}' in original operations")
+
+ # Add any operations that weren't in the AST (shouldn't happen, but safety check)
+ for op in original_ops:
+ op_name = getattr(op, "_name", getattr(op, "name", str(op)))
+ if op_name not in optimized_order:
+ logger.warning(f"โ ๏ธ Operation '{op_name}' not found in optimized order, adding at end")
+ optimized_ops.append(op)
+
+ logger.info(f"๐ Reordered {len(optimized_ops)} operations based on optimization")
+ return optimized_ops
+
+ except Exception as e:
+ logger.error(f"โ Failed to extract optimized operations: {e}")
+ import traceback
+
+ logger.error(f"โ Traceback: {traceback.format_exc()}")
+ logger.warning("โ ๏ธ Returning original operations")
+ return original_ops
+
+ def _get_operation_order_from_ast(self, ast: PipelineAST) -> List[str]:
+ """Get the operation order from the AST using depth-first traversal."""
+ order = []
+
+        logger.info(f"Starting AST traversal from root: {ast.root}")
+
+ # Use depth-first traversal to get all operations
+ self._traverse_ast_dfs(ast.root, order)
+
+        logger.info(f"Extracted operation order from AST: {order}")
+ return order
+
+ def _traverse_ast_dfs(self, node: OpNode, order: List[str]):
+ """Depth-first traversal of AST nodes."""
+ if not node:
+ return
+
+ # Skip root node but process its children
+ if node.name != "root":
+ order.append(node.name)
+            logger.info(f"Added to order: {node.name}")
+
+ # Recursively traverse all children
+ if node.children:
+ for i, child in enumerate(node.children):
+                logger.info(f"Processing child {i+1}/{len(node.children)}: {child.name}")
+ self._traverse_ast_dfs(child, order)
+
+ def is_optimization_enabled(self) -> bool:
+ """Check if optimization is enabled."""
+ return self.optimization_enabled
+
+ def get_enabled_strategies(self) -> List[str]:
+ """Get list of enabled optimization strategies."""
+ return self.optimization_strategies
+
+ def _initialize_strategies(self) -> List[Any]:
+ """Initialize strategy objects from strategy names using the registry."""
+ strategy_objects = []
+
+ for strategy_name in self.optimization_strategies:
+ strategy_name = strategy_name.strip() # Remove any whitespace
+
+ # Use the registry to create strategy instances
+ strategy_obj = StrategyRegistry.create_strategy(strategy_name)
+
+ if strategy_obj is not None:
+ strategy_objects.append(strategy_obj)
+                logger.info(f"Initialized strategy: {strategy_name}")
+            else:
+                logger.warning(f"Failed to initialize strategy: {strategy_name}")
+
+ if not strategy_objects:
+            logger.warning("No valid strategies initialized, using default op_reorder strategy")
+            default_strategy = StrategyRegistry.create_strategy("op_reorder")
+            if default_strategy is not None:
+                strategy_objects = [default_strategy]
+            else:
+                logger.error("Failed to create default strategy")
+ strategy_objects = []
+
+        logger.info(f"Initialized {len(strategy_objects)} optimization strategies")
+ return strategy_objects
+
+ def _print_ast_structure(self, ast: PipelineAST, phase: str):
+ """Print the AST structure for debugging purposes."""
+        logger.info(f"{phase} optimization - AST structure:")
+
+ if not ast or not ast.root:
+ logger.info(" Empty AST")
+ return
+
+ # Print the AST tree structure
+ self._print_ast_node(ast.root, 0, phase)
+
+ def _print_ast_node(self, node: OpNode, depth: int, phase: str):
+ """Recursively print AST node structure."""
+ indent = " " * depth
+
+        if node.name == "root":
+            logger.info(f"{indent}ROOT")
+        else:
+            logger.info(f"{indent}{node.op_type.name}: {node.name}")
+
+ # Print key config parameters if available
+ if node.config:
+ important_configs = {}
+ for key, value in node.config.items():
+ if key in [
+ "text_key",
+ "image_key",
+ "audio_key",
+ "video_key",
+ "threshold",
+ "min_length",
+ "max_length",
+ ]:
+ important_configs[key] = value
+
+ if important_configs:
+ config_str = ", ".join([f"{k}={v}" for k, v in important_configs.items()])
+                    logger.info(f"{indent}   Config: {config_str}")
+
+ # Print children
+ if node.children:
+ for child in node.children:
+ self._print_ast_node(child, depth + 1, phase)
+
+
+# Global optimization manager instance
+_optimization_manager = None
+
+
+def get_optimization_manager(cfg=None) -> OptimizationManager:
+ """
+ Get the global optimization manager instance.
+
+ Args:
+ cfg: Configuration object (optional)
+
+ Returns:
+ OptimizationManager instance
+ """
+ global _optimization_manager
+ if _optimization_manager is None:
+ _optimization_manager = OptimizationManager(cfg)
+ return _optimization_manager
+
+
+def apply_optimizations(ops: List[Any], cfg=None) -> List[Any]:
+ """
+ Convenience function to apply optimizations to operations.
+
+ Args:
+ ops: List of operations to optimize
+ cfg: Configuration object (optional)
+
+ Returns:
+ Optimized list of operations
+ """
+ manager = get_optimization_manager(cfg)
+ return manager.apply_optimizations(ops)
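+
+
+# Illustrative usage (a sketch; `ops` is assumed to be the already-loaded
+# operator list and `cfg` the parsed config object that OptimizationManager
+# reads its optimizer settings from):
+#
+#     manager = get_optimization_manager(cfg)
+#     if manager.is_optimization_enabled():
+#         ops = manager.apply_optimizations(ops)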
diff --git a/data_juicer/core/optimizer/filter_fusion_strategy.py b/data_juicer/core/optimizer/filter_fusion_strategy.py
new file mode 100644
index 0000000000..4492d22247
--- /dev/null
+++ b/data_juicer/core/optimizer/filter_fusion_strategy.py
@@ -0,0 +1,642 @@
+from typing import Any, Dict, List, Optional
+
+from loguru import logger
+
+from data_juicer.core.optimizer.strategy import OptimizationStrategy, register_strategy
+from data_juicer.core.pipeline_ast import OpNode, OpType, PipelineAST
+from data_juicer.utils.constant import InterVars, StatsKeys
+from data_juicer.utils.registry import Registry
+
+# Type of intermediate vars
+INTER_LINES = Registry(InterVars.lines)
+INTER_WORDS = Registry(InterVars.words)
+LOADED_IMAGES = Registry(InterVars.loaded_images)
+LOADED_AUDIOS = Registry(InterVars.loaded_audios)
+LOADED_VIDEOS = Registry(InterVars.loaded_videos)
+INTER_SAMPLED_FRAMES = Registry(InterVars.sampled_frames)
+
+ALL_INTER_VARS = [INTER_LINES, INTER_WORDS, LOADED_AUDIOS, LOADED_IMAGES, LOADED_VIDEOS, INTER_SAMPLED_FRAMES]
+
+
+@register_strategy("filter_fusion")
+class FilterFusionStrategy(OptimizationStrategy):
+ """Strategy for fusing filter operations in the pipeline."""
+
+ def __init__(
+ self, probe_results: Optional[Dict[str, Any]] = None, analyzer_insights: Optional[Dict[str, Any]] = None
+ ):
+ """Initialize the filter fusion strategy.
+
+ Args:
+ probe_results: Optional dictionary containing operation speeds
+ analyzer_insights: Optional dictionary containing dataset analysis insights
+ """
+ super().__init__(name="filter_fusion")
+ self.probe_results = probe_results or {}
+ self.analyzer_insights = analyzer_insights or {}
+
+ def optimize(self, ast: PipelineAST) -> PipelineAST:
+ """Apply filter fusion to the pipeline AST.
+
+ Args:
+ ast: The pipeline AST to optimize
+
+ Returns:
+ Optimized pipeline AST
+ """
+ if not ast.root:
+ return ast
+
+ # Create a new AST
+ new_ast = PipelineAST()
+ new_ast.root = OpNode(name="root", op_type=OpType.ROOT, config={})
+
+ # Get all unique operation chains
+ op_chains = self._get_unique_op_chains(ast.root)
+
+ # Process each chain
+ current = new_ast.root
+ for chain in op_chains:
+ # Group filter operations with analyzer insights
+ filter_groups = self._group_filters_with_insights(chain)
+
+ for group in filter_groups:
+ if len(group) > 1:
+ # Create fused operation with clean naming
+ fused_name = "fused_filter"
+ detailed_ops = [n.name for n in group]
+ logger.info(f"Fusing filter operations into {fused_name}: {detailed_ops}")
+
+ # Create operation configs
+ op_configs = []
+ for op in group:
+ op_config = {op.name: op.config or {}}
+ op_configs.append(op_config)
+
+ # Create fused node
+ fused_node = OpNode(
+ name=fused_name,
+ op_type=OpType.FILTER,
+ config={
+ "general_fused_op": {
+ "fused_op_list": op_configs,
+ "detailed_ops": detailed_ops, # For display purposes
+ }
+ },
+ )
+ current.add_child(fused_node)
+ current = fused_node
+ else:
+ # Keep single operations as is
+ new_node = OpNode(name=group[0].name, op_type=group[0].op_type, config=group[0].config or {})
+ current.add_child(new_node)
+ current = new_node
+
+ return new_ast
+
+ def _get_unique_op_chains(self, node: OpNode) -> List[List[OpNode]]:
+ """Get unique chains of operations from the tree.
+
+ Args:
+ node: Root node of the tree
+
+ Returns:
+ List of unique operation chains
+ """
+ chains = []
+ seen_chains = set()
+
+ def traverse(current: OpNode, chain: List[OpNode]):
+ if not current.children:
+ # End of chain, check if we've seen this sequence before
+ chain_key = tuple(n.name for n in chain)
+ if chain_key not in seen_chains:
+ chains.append(chain.copy())
+ seen_chains.add(chain_key)
+ return
+
+ for child in current.children:
+ chain.append(child)
+ traverse(child, chain)
+ chain.pop()
+
+ traverse(node, [])
+ return chains
+
+ def _group_filters_with_insights(self, chain: List[OpNode]) -> List[List[OpNode]]:
+ """Group filter operations using analyzer insights for better decisions.
+
+ Args:
+ chain: List of operations in the pipeline
+
+ Returns:
+ List of filter operation groups
+ """
+ groups = []
+ current_group = []
+
+ for node in chain:
+ if not PipelineAST.is_filter_op(node):
+ # If we encounter a non-filter, finalize current group
+ if current_group:
+ groups.append(current_group)
+ current_group = []
+ # Add the non-filter node as a separate group
+ groups.append([node])
+ else:
+ # This is a filter node
+ if not current_group:
+ # Start a new group
+ current_group = [node]
+ else:
+ # Check if current filter can be fused with the group using insights
+ if self._can_fuse_with_group_insights(node, current_group):
+ current_group.append(node)
+ else:
+ # Finalize current group and start a new one
+ groups.append(current_group)
+ current_group = [node]
+
+ # Don't forget the last group
+ if current_group:
+ groups.append(current_group)
+
+ return groups
+
+ def _can_fuse_with_group_insights(self, node: OpNode, group: List[OpNode]) -> bool:
+ """Check if a filter can be fused with a group using analyzer insights.
+
+ Args:
+ node: Operation to check
+ group: Group of operations
+
+ Returns:
+ True if the operation can be fused with the group
+ """
+ # Basic dependency check
+ for op in group:
+ if self._has_dependency(node, op) or self._has_dependency(op, node):
+ return False
+
+ # Use smart complex filter fusion for better performance
+ if not self._smart_complex_filter_fusion(node, group):
+ return False
+
+ # Use analyzer insights for advanced decisions
+ if self.analyzer_insights:
+ return self._analyzer_based_fusion_decision(node, group)
+
+ return True
+
+ def _analyzer_based_fusion_decision(self, node: OpNode, group: List[OpNode]) -> bool:
+ """Make fusion decisions based on analyzer insights.
+
+ Args:
+ node: Operation to check
+ group: Group of operations
+
+ Returns:
+ True if fusion is recommended based on data characteristics
+ """
+ # Get dataset characteristics from analyzer
+ dataset_size = self.analyzer_insights.get("dataset_size", 0)
+ text_length_stats = self.analyzer_insights.get("text_length", {})
+ content_ratios = self.analyzer_insights.get("content_ratios", {})
+
+ # Decision 1: Large datasets benefit more from fusion
+ if dataset_size > 100000:
+ logger.debug(f"Large dataset ({dataset_size:,} samples) - favoring fusion")
+ return True
+
+ # Decision 2: High variance in text length suggests complex processing
+ if text_length_stats:
+ mean_length = text_length_stats.get("mean", 0)
+ std_length = text_length_stats.get("std", 0)
+ if mean_length > 0 and std_length / mean_length > 1.5:
+ logger.debug("High text length variance - favoring fusion for complex data")
+ return True
+
+ # Decision 3: Mixed content types suggest complex processing
+ multimodal_indicators = ["image_ratio", "audio_ratio", "video_ratio"]
+ multimodal_count = sum(1 for indicator in multimodal_indicators if content_ratios.get(indicator, 0) > 0.1)
+
+ if multimodal_count > 1:
+ logger.debug(f"Multimodal content detected ({multimodal_count} types) - favoring fusion")
+ return True
+
+ # Decision 4: Check if operations are computationally similar
+ return self._check_computational_similarity(node, group)
+
+ def _check_computational_similarity(self, node: OpNode, group: List[OpNode]) -> bool:
+ """Check if operations have similar computational characteristics.
+
+ Args:
+ node: Operation to check
+ group: Group of operations
+
+ Returns:
+ True if operations are computationally similar
+ """
+ node_complexity = self._get_operation_complexity(node.name)
+ group_complexities = [self._get_operation_complexity(op.name) for op in group]
+
+ # Prefer grouping operations of similar complexity
+ if node_complexity in group_complexities:
+ return True
+
+        # Allow mixing simple and medium operations
+        if node_complexity in ("simple", "medium") and all(
+            c in ("simple", "medium") for c in group_complexities
+        ):
+            return True
+
+ return False
+
+ def _get_operation_complexity(self, op_name: str) -> str:
+ """Get the computational complexity of an operation using dynamic analysis.
+
+ Args:
+ op_name: Name of the operation
+
+ Returns:
+ Complexity level: 'simple', 'medium', or 'complex'
+ """
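+        # Illustrative outcomes of the pattern-based fallback below:
+        #   "text_length_filter" -> "simple", "word_repetition_filter" -> "medium",
+        #   "perplexity_filter" -> "complex"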
+ # First, try to get complexity from operation metadata if available
+ complexity = self._get_complexity_from_metadata(op_name)
+ if complexity:
+ return complexity
+
+ # Fallback to pattern-based analysis
+ return self._analyze_complexity_by_pattern(op_name)
+
+ def _get_complexity_from_metadata(self, op_name: str) -> Optional[str]:
+ """Get complexity from operation metadata or runtime analysis."""
+ # Try to load the actual operation to analyze its characteristics
+ try:
+ from data_juicer.ops import load_ops
+
+ # Create a minimal config to load the operation
+ op_config = {op_name: {}}
+ loaded_ops = load_ops([op_config])
+
+ if loaded_ops:
+ op = loaded_ops[0]
+ return self._analyze_operation_complexity(op)
+ except Exception:
+ pass
+
+ return None
+
+ def _analyze_operation_complexity(self, op) -> str:
+ """Analyze operation complexity based on its actual characteristics."""
+ complexity_indicators = {"simple": 0, "medium": 0, "complex": 0}
+
+ # Check for model dependencies (indicates complexity)
+ if hasattr(op, "config") and op.config:
+ config = op.config
+ if any(key in config for key in ["model_key", "sp_model_key", "kl_model_key"]):
+ complexity_indicators["complex"] += 2
+ if "lang" in config:
+ complexity_indicators["medium"] += 1
+
+ # Check for external dependencies
+ if hasattr(op, "_name"):
+ op_name = op._name.lower()
+
+ # Language model dependencies
+ if any(keyword in op_name for keyword in ["perplexity", "language", "spacy", "nlp"]):
+ complexity_indicators["complex"] += 2
+
+ # Statistical analysis dependencies
+ if any(keyword in op_name for keyword in ["repetition", "ratio", "statistics"]):
+ complexity_indicators["medium"] += 1
+
+ # Simple text processing
+ if any(keyword in op_name for keyword in ["length", "words", "characters"]):
+ complexity_indicators["simple"] += 1
+
+ # Check method complexity
+ if hasattr(op, "compute_stats_batched"):
+ # Analyze the method signature and docstring for complexity hints
+ method = op.compute_stats_batched
+ if hasattr(method, "__doc__") and method.__doc__:
+ doc = method.__doc__.lower()
+ if any(keyword in doc for keyword in ["model", "spacy", "nlp", "language"]):
+ complexity_indicators["complex"] += 1
+ elif any(keyword in doc for keyword in ["statistics", "ratio", "analysis"]):
+ complexity_indicators["medium"] += 1
+
+ # Determine final complexity
+ max_complexity = max(complexity_indicators.items(), key=lambda x: x[1])
+ if max_complexity[1] == 0:
+ return "medium" # Default
+ return max_complexity[0]
+
+ def _analyze_complexity_by_pattern(self, op_name: str) -> str:
+ """Analyze complexity based on operation name patterns."""
+ op_name_lower = op_name.lower()
+
+ # Simple operations (basic text processing)
+ simple_patterns = [
+ "text_length",
+ "words_num",
+ "character_repetition",
+ "average_line_length",
+ "maximum_line_length",
+ ]
+
+ # Medium complexity operations (statistical analysis)
+ medium_patterns = ["word_repetition", "special_characters", "alphanumeric", "stopwords", "flagged_words"]
+
+ # Complex operations (language models, NLP)
+ complex_patterns = [
+ "perplexity",
+ "language_id",
+ "text_entity",
+ "text_action",
+ "spacy",
+ "nlp",
+ "dependency",
+ "pos_tag",
+ ]
+
+ # Check patterns
+ for pattern in simple_patterns:
+ if pattern in op_name_lower:
+ return "simple"
+
+ for pattern in complex_patterns:
+ if pattern in op_name_lower:
+ return "complex"
+
+ for pattern in medium_patterns:
+ if pattern in op_name_lower:
+ return "medium"
+
+ # Default to medium if no patterns match
+ return "medium"
+
+ def _get_adaptive_complexity(self, op_name: str, performance_data: Optional[Dict] = None) -> str:
+ """Get adaptive complexity based on performance data if available."""
+ if performance_data and op_name in performance_data:
+ # Use performance data to adjust complexity
+ avg_time = performance_data[op_name].get("avg_time", 0)
+
+ if avg_time < 0.001: # Very fast
+ return "simple"
+ elif avg_time < 0.01: # Fast
+ return "medium"
+ else: # Slow
+ return "complex"
+
+ # Fall back to static analysis
+ return self._get_operation_complexity(op_name)
+
+ def _has_dependency(self, op1: OpNode, op2: OpNode) -> bool:
+ """Check if op1 depends on op2.
+
+ Args:
+ op1: First operation
+ op2: Second operation
+
+ Returns:
+ True if op1 depends on op2
+ """
+ # Get operation configurations
+ config1 = op1.config or {}
+ config2 = op2.config or {}
+
+ # 1. Check intermediate variables
+ op1_vars = set(config1.get("inter_vars", []))
+ op2_vars = set(config2.get("inter_vars", []))
+ if op1_vars & op2_vars:
+ return True
+
+ # 2. Check stats dependencies
+ if self._check_stats_dependencies(op1, op2):
+ return True
+
+ # 3. Check model dependencies
+ if self._check_model_dependencies(op1, op2):
+ return True
+
+ # 4. Check data field dependencies
+ if self._check_field_dependencies(op1, op2):
+ return True
+
+ # 5. Check operation-specific dependencies
+ if self._check_operation_specific_dependencies(op1, op2):
+ return True
+
+ return False
+
+ def _check_stats_dependencies(self, op1: OpNode, op2: OpNode) -> bool:
+ """Check if operations depend on the same stats."""
+ # Get stats keys that each operation produces/consumes
+ op1_stats = self._get_stats_keys(op1)
+ op2_stats = self._get_stats_keys(op2)
+
+ # If they share any stats keys, they have a dependency
+ return bool(op1_stats & op2_stats)
+
+ def _get_stats_keys(self, op: OpNode) -> set:
+ """Get stats keys that an operation produces or consumes."""
+ # Map operation names to their stats keys
+ stats_mapping = {
+ "words_num_filter": {StatsKeys.num_words},
+ "text_length_filter": {StatsKeys.text_len},
+ "character_repetition_filter": {StatsKeys.char_rep_ratio},
+ "word_repetition_filter": {StatsKeys.word_rep_ratio},
+ "average_line_length_filter": {StatsKeys.avg_line_length},
+ "maximum_line_length_filter": {StatsKeys.max_line_length},
+ "alphanumeric_filter": {StatsKeys.alnum_ratio, StatsKeys.alpha_token_ratio},
+ "special_characters_filter": {StatsKeys.special_char_ratio},
+ "perplexity_filter": {StatsKeys.perplexity},
+ "stopwords_filter": {StatsKeys.stopwords_ratio},
+ "flagged_words_filter": {StatsKeys.flagged_words_ratio},
+ "text_entity_dependency_filter": {StatsKeys.num_dependency_edges},
+ "general_field_filter": {StatsKeys.general_field_filter_condition},
+ }
+
+ return stats_mapping.get(op.name, set())
+
+ def _check_model_dependencies(self, op1: OpNode, op2: OpNode) -> bool:
+ """Check if operations use the same models."""
+ config1 = op1.config or {}
+ config2 = op2.config or {}
+
+ # Get model keys
+ op1_models = set()
+ op2_models = set()
+
+ # Check for model keys in config
+ for key in ["model_key", "sp_model_key", "kl_model_key"]:
+ if key in config1:
+ op1_models.add(config1[key])
+ if key in config2:
+ op2_models.add(config2[key])
+
+ # If they share any models, they have a dependency
+ return bool(op1_models & op2_models)
+
+ def _check_field_dependencies(self, op1: OpNode, op2: OpNode) -> bool:
+ """Check if operations process the same data fields."""
+ config1 = op1.config or {}
+ config2 = op2.config or {}
+
+ # Get field keys
+ op1_fields = set()
+ op2_fields = set()
+
+ # Check for field keys in config
+ for key in ["text_key", "image_key", "audio_key", "video_key"]:
+ if key in config1:
+ op1_fields.add(config1[key])
+ if key in config2:
+ op2_fields.add(config2[key])
+
+ # If they share any fields, they might have a dependency
+ # (This is a conservative check - some operations can share fields safely)
+ shared_fields = op1_fields & op2_fields
+
+ # Only consider it a dependency if both operations are text processors
+ # and they share text_key (indicating they process the same text)
+ if shared_fields and "text_key" in shared_fields:
+ return True
+
+ return False
+
+ def _check_operation_specific_dependencies(self, op1: OpNode, op2: OpNode) -> bool:
+ """Check operation-specific dependencies."""
+ # Some operations have specific dependencies that can't be generalized
+
+ # Example: Operations that modify the same data structure
+ # This is a placeholder for future operation-specific checks
+ return False
+
+ def _smart_complex_filter_fusion(self, node: OpNode, group: List[OpNode]) -> bool:
+ """Smart fusion decision for complex filters that may cause slowdown.
+
+ Args:
+ node: Operation to check
+ group: Group of operations
+
+ Returns:
+ True if fusion is recommended, False if it would cause slowdown
+ """
+ # Get complexity of current operation and group
+ node_complexity = self._get_operation_complexity(node.name)
+ group_complexities = [self._get_operation_complexity(op.name) for op in group]
+
+ # Rule 1: Never fuse complex operations together (causes slowdown)
+ if node_complexity == "complex" and any(c == "complex" for c in group_complexities):
+ logger.debug(f"Rejecting fusion: complex operation {node.name} with complex group")
+ return False
+
+ # Rule 2: Limit group size for complex operations (max 2 complex filters per group)
+ complex_count_in_group = sum(1 for c in group_complexities if c == "complex")
+ if node_complexity == "complex" and complex_count_in_group >= 2:
+ logger.debug(
+ f"Rejecting fusion: complex operation {node.name} would exceed max 2 complex filters per group"
+ )
+ return False
+
+ # Rule 3: Check for model conflicts
+ if self._has_model_conflicts(node, group):
+ logger.debug(f"Rejecting fusion: model conflicts detected for {node.name}")
+ return False
+
+ # Rule 4: Check memory requirements
+ if self._would_exceed_memory_limit(node, group):
+ logger.debug(f"Rejecting fusion: would exceed memory limit for {node.name}")
+ return False
+
+ # Rule 5: Check for sequential dependencies
+ if self._has_sequential_dependencies(node, group):
+ logger.debug(f"Rejecting fusion: sequential dependencies for {node.name}")
+ return False
+
+ return True
+
+ def _has_model_conflicts(self, node: OpNode, group: List[OpNode]) -> bool:
+ """Check if operations have conflicting model requirements."""
+ # Get model requirements for current operation
+ node_models = self._get_model_requirements(node)
+
+ # Check against group models
+ for op in group:
+ group_models = self._get_model_requirements(op)
+ # If both operations require different models of the same type
+ for model_type in node_models:
+ if model_type in group_models and node_models[model_type] != group_models[model_type]:
+ return True
+ return False
+
+ def _get_model_requirements(self, node: OpNode) -> Dict[str, str]:
+ """Get model requirements for an operation."""
+ models = {}
+ config = node.config or {}
+
+ # Check for common model keys
+ for key in ["model_key", "sp_model_key", "kl_model_key"]:
+ if key in config:
+ models[key] = config[key]
+
+ # Check for language-specific models
+ if "lang" in config:
+ models["lang"] = config["lang"]
+
+ return models
+
+ def _would_exceed_memory_limit(self, node: OpNode, group: List[OpNode]) -> bool:
+ """Check if fusion would exceed memory limits."""
+ # Estimate memory usage for current operation
+ node_memory = self._estimate_operation_memory(node)
+
+ # Estimate memory for group
+ group_memory = sum(self._estimate_operation_memory(op) for op in group)
+
+ # Total estimated memory
+ total_memory = node_memory + group_memory
+
+ # Conservative memory limit (2GB)
+ memory_limit = 2 * 1024 * 1024 * 1024 # 2GB in bytes
+
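+        # e.g. one complex filter (~500 MB) joining a group of three medium
+        # filters (3 x ~200 MB) totals ~1.1 GB and stays under the 2 GB limit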
+ return total_memory > memory_limit
+
+ def _estimate_operation_memory(self, node: OpNode) -> int:
+ """Estimate memory usage for an operation in bytes."""
+ complexity = self._get_operation_complexity(node.name)
+
+ # Rough memory estimates based on complexity
+ if complexity == "simple":
+ return 50 * 1024 * 1024 # 50MB
+ elif complexity == "medium":
+ return 200 * 1024 * 1024 # 200MB
+ else: # complex
+ return 500 * 1024 * 1024 # 500MB
+
+ def _has_sequential_dependencies(self, node: OpNode, group: List[OpNode]) -> bool:
+ """Check if operations must be executed sequentially."""
+ # Check for data flow dependencies
+ for op in group:
+ if self._has_dependency(node, op) or self._has_dependency(op, node):
+ return True
+
+ # Check for operation-specific sequential requirements
+ node_name = node.name.lower()
+ group_names = [op.name.lower() for op in group]
+
+ # Some operations must be sequential
+ sequential_patterns = [
+ ("perplexity", "language_id"), # Language detection before perplexity
+ ("text_entity", "text_action"), # Entity detection before action analysis
+ ]
+
+ for pattern1, pattern2 in sequential_patterns:
+ if pattern1 in node_name and any(pattern2 in name for name in group_names):
+ return True
+ if pattern2 in node_name and any(pattern1 in name for name in group_names):
+ return True
+
+ return False
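+
+
+# Illustrative usage (a sketch; `ast` is assumed to be a PipelineAST built from
+# the configured process list, e.g. by the optimization manager):
+#
+#     strategy = FilterFusionStrategy(analyzer_insights={"dataset_size": 250_000})
+#     fused_ast = strategy.optimize(ast)
+#     # adjacent, independent filters are collapsed into "fused_filter" nodes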
diff --git a/data_juicer/core/optimizer/fused_op.py b/data_juicer/core/optimizer/fused_op.py
new file mode 100644
index 0000000000..0e1d1f5c87
--- /dev/null
+++ b/data_juicer/core/optimizer/fused_op.py
@@ -0,0 +1,714 @@
+import concurrent.futures
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+from loguru import logger
+
+from data_juicer.ops import load_ops
+from data_juicer.ops.base_op import OPERATORS, Filter, Mapper
+from data_juicer.utils.constant import Fields
+
+
+@OPERATORS.register_module("fused_filter")
+class FusedFilter(Filter):
+ """A fused operator for filters that can execute multiple filters in one pass."""
+
+ _batched_op = True
+
+ def __init__(
+ self, name: str, fused_filters: List[Filter], analyzer_insights: Optional[Dict[str, Any]] = None, **kwargs
+ ):
+ """Initialize the fused filter.
+
+ Args:
+ name: Name of the fused filter
+ fused_filters: List of filters to fuse
+ analyzer_insights: Optional dataset analysis insights for optimization
+ **kwargs: Extra config arguments (e.g., accelerator, batch_size, etc.)
+ """
+ super().__init__()
+ self._name = name
+ self.fused_filters = fused_filters
+ self.analyzer_insights = analyzer_insights or {}
+
+ # Store extra config arguments as attributes
+        self.accelerator = kwargs.get("accelerator", None)  # auto-detected below when not provided
+ self.batch_size = kwargs.get("batch_size", None)
+ self.cpu_required = kwargs.get("cpu_required", 1) # Default to 1 CPU
+ self.mem_required = kwargs.get("mem_required", 1) # Default to 1 GB
+ self.num_proc = kwargs.get("num_proc", None)
+ self.skip_op_error = kwargs.get("skip_op_error", False)
+ self.turbo = kwargs.get("turbo", False)
+ self.text_key = kwargs.get("text_key", None)
+ self.image_key = kwargs.get("image_key", None)
+ self.audio_key = kwargs.get("audio_key", None)
+ self.video_key = kwargs.get("video_key", None)
+ self.history_key = kwargs.get("history_key", None)
+ self.query_key = kwargs.get("query_key", None)
+ self.response_key = kwargs.get("response_key", None)
+ self.execution_strategy = kwargs.get("execution_strategy", None)
+ self.has_dependencies = kwargs.get("has_dependencies", None)
+
+ # Add recursion prevention flag
+ self._in_performance_test = False
+
+ # Set accelerator based on available methods (if not set by kwargs)
+ if self.accelerator is None:
+ if any(hasattr(op, "accelerator") and op.accelerator == "cuda" for op in self.fused_filters):
+ self.accelerator = "cuda"
+ else:
+ self.accelerator = "cpu"
+
+ # Update num_proc with the minimum of all fused filters if not set by kwargs
+ if self.num_proc is None:
+ self.num_proc = min([op.runtime_np() for op in self.fused_filters])
+
+ # Store original operation configs (create simple config if not available)
+ self._op_cfg = {}
+ for op in self.fused_filters:
+ op_name = getattr(op, "_name", None)
+ op_config = getattr(op, "config", None)
+ if op_name is not None and op_config:
+ self._op_cfg[op_name] = op_config
+ elif op_name is not None:
+ # Create a simple config for filters without explicit config
+ self._op_cfg[op_name] = {"inter_vars": [], "dependencies": []}
+
+        # Analyze dependencies and determine the execution strategy
+        # (_analyze_dependencies also selects self.execution_strategy)
+        self._analyze_dependencies()
+
+ # Pre-allocate result arrays
+ self._result_cache = {}
+
+ # Log the chosen strategy
+ logger.info(f"FusedFilter '{name}' using {self.execution_strategy} execution strategy")
+ if self.has_dependencies:
+ logger.info(" Reason: Filters have dependencies")
+ else:
+ simple_count = sum(
+ 1
+ for op in self.fused_filters
+ if getattr(op, "_name", None)
+ in {
+ "text_length_filter",
+ "words_num_filter",
+ "character_repetition_filter",
+ "word_repetition_filter",
+ "special_characters_filter",
+ "alphanumeric_filter",
+ "average_line_length_filter",
+ "maximum_line_length_filter",
+ }
+ )
+ complex_count = len(self.fused_filters) - simple_count
+ logger.info(f" Reason: {simple_count} simple filters, {complex_count} complex filters")
+
+ # Log analyzer-based insights if available
+ if self.analyzer_insights:
+ self._log_analyzer_insights()
+
+ def _log_analyzer_insights(self):
+ """Log insights from analyzer that influenced strategy decisions."""
+ dataset_size = self.analyzer_insights.get("dataset_size", 0)
+ text_length_stats = self.analyzer_insights.get("text_length", {})
+ content_ratios = self.analyzer_insights.get("content_ratios", {})
+
+ logger.info(" Analyzer Insights:")
+ if dataset_size > 0:
+ logger.info(f" Dataset size: {dataset_size:,} samples")
+
+ if text_length_stats:
+ mean_len = text_length_stats.get("mean", 0)
+ std_len = text_length_stats.get("std", 0)
+ if mean_len > 0:
+ cv = std_len / mean_len
+ logger.info(f" Text length CV: {cv:.2f} (mean: {mean_len:.1f}, std: {std_len:.1f})")
+
+ multimodal_count = sum(
+ 1 for indicator in ["image_ratio", "audio_ratio", "video_ratio"] if content_ratios.get(indicator, 0) > 0.1
+ )
+ if multimodal_count > 0:
+ logger.info(f" Multimodal content: {multimodal_count} types detected")
+
+ def _analyze_dependencies(self):
+ """Analyze dependencies between filters to optimize execution order."""
+ # Create dependency graph
+ self.dependency_graph = {}
+ self.independent_groups = []
+ self.has_dependencies = False
+
+ for i, op1 in enumerate(self.fused_filters):
+ self.dependency_graph[op1] = set()
+ for j, op2 in enumerate(self.fused_filters):
+ if i != j:
+ # Check if op2 depends on op1's output
+ if self._has_dependency(op1, op2):
+ self.dependency_graph[op1].add(op2)
+ self.has_dependencies = True
+
+ # Find independent groups
+ visited = set()
+ for op in self.fused_filters:
+ if op not in visited:
+ group = self._get_independent_group(op, visited)
+ if group:
+ self.independent_groups.append(group)
+
+ # Determine execution strategy
+ self.execution_strategy = self._determine_execution_strategy()
+
+ def _has_dependency(self, op1: Filter, op2: Filter) -> bool:
+ """Check if op2 depends on op1's output."""
+ # Get intermediate variables used by each operation from stored configs
+ op1_vars = set(self._op_cfg.get(getattr(op1, "_name", ""), {}).get("inter_vars", []))
+ op2_vars = set(self._op_cfg.get(getattr(op2, "_name", ""), {}).get("inter_vars", []))
+
+ # Check if op2 uses any variables produced by op1
+ return bool(op1_vars & op2_vars)
+
+ def _get_independent_group(self, start_op: Filter, visited: set) -> List[Filter]:
+ """Get a group of independent operations starting from start_op."""
+ group = []
+ to_visit = {start_op}
+
+ while to_visit:
+ op = to_visit.pop()
+ if op not in visited:
+ visited.add(op)
+ group.append(op)
+ # Add independent operations to visit
+ for other_op in self.fused_filters:
+ if (
+ other_op not in visited
+ and other_op not in self.dependency_graph[op]
+ and op not in self.dependency_graph[other_op]
+ ):
+ to_visit.add(other_op)
+
+ return group
+
+ def _determine_execution_strategy(self):
+ """Determine the best execution strategy based on filter characteristics and analyzer insights."""
+ if self.has_dependencies:
+ return "sequential"
+
+ # Use analyzer insights for better decisions
+ if self.analyzer_insights:
+ return self._analyzer_based_strategy_selection()
+
+ # Fallback to original logic
+ return self._fallback_strategy_selection()
+
+ def _analyzer_based_strategy_selection(self) -> str:
+ """Select execution strategy based on analyzer insights."""
+ dataset_size = self.analyzer_insights.get("dataset_size", 0)
+ text_length_stats = self.analyzer_insights.get("text_length", {})
+ content_ratios = self.analyzer_insights.get("content_ratios", {})
+
+ # Factor 1: Dataset size
+ if dataset_size > 500000: # Large datasets benefit from parallel
+ logger.debug("Large dataset detected - favoring parallel execution")
+ return "parallel"
+
+ # Factor 2: Text complexity
+ if text_length_stats:
+ mean_length = text_length_stats.get("mean", 0)
+ std_length = text_length_stats.get("std", 0)
+ if mean_length > 0 and std_length / mean_length > 2.0:
+ logger.debug("High text length variance - using sequential for complex data")
+ return "sequential"
+
+ # Factor 3: Multimodal content
+ multimodal_indicators = ["image_ratio", "audio_ratio", "video_ratio"]
+ multimodal_count = sum(1 for indicator in multimodal_indicators if content_ratios.get(indicator, 0) > 0.1)
+
+ if multimodal_count > 1:
+ logger.debug(f"Multimodal content ({multimodal_count} types) - using sequential")
+ return "sequential"
+
+ # Factor 4: Filter complexity distribution
+ return self._complexity_based_strategy()
+
+ def _complexity_based_strategy(self) -> str:
+ """Select strategy based on filter complexity distribution."""
+ simple_filters = 0
+ complex_filters = 0
+
+ for op in self.fused_filters:
+ simple_filter_names = {
+ "text_length_filter",
+ "words_num_filter",
+ "character_repetition_filter",
+ "word_repetition_filter",
+ "special_characters_filter",
+ "alphanumeric_filter",
+ "average_line_length_filter",
+ "maximum_line_length_filter",
+ }
+
+ if getattr(op, "_name", "") in simple_filter_names:
+ simple_filters += 1
+ else:
+ complex_filters += 1
+
+ # Use parallel if mostly simple filters, sequential if complex filters
+ if complex_filters > simple_filters:
+ return "sequential"
+ else:
+ return "parallel"
+
+ def _fallback_strategy_selection(self) -> str:
+ """Fallback strategy selection using original logic."""
+ # Check if filters are simple enough for parallel execution
+ simple_filters = 0
+ complex_filters = 0
+
+ for op in self.fused_filters:
+ # Simple filters: text_length, words_num, character_repetition
+ # Complex filters: perplexity, stopwords, flagged_words
+ simple_filter_names = {
+ "text_length_filter",
+ "words_num_filter",
+ "character_repetition_filter",
+ "word_repetition_filter",
+ "special_characters_filter",
+ "alphanumeric_filter",
+ "average_line_length_filter",
+ "maximum_line_length_filter",
+ }
+
+ if getattr(op, "_name", "") in simple_filter_names:
+ simple_filters += 1
+ else:
+ complex_filters += 1
+
+ # Use parallel if mostly simple filters, sequential if complex filters
+ if complex_filters > simple_filters:
+ return "sequential"
+ else:
+ return "parallel"
+
+ def _should_skip_fusion(self, sample_size: int = 1000) -> bool:
+ """Determine if fusion should be skipped based on performance analysis and analyzer insights.
+
+ Args:
+ sample_size: Number of samples being processed
+
+ Returns:
+ True if fusion should be skipped, False if fusion is beneficial
+ """
+ # Prevent recursion during performance testing
+ if self._in_performance_test:
+ return False
+
+ # Use analyzer insights for better decisions
+ if self.analyzer_insights:
+ return self._analyzer_based_fusion_decision(sample_size)
+
+ # Fallback to original logic
+ return self._fallback_fusion_decision(sample_size)
+
+ def _analyzer_based_fusion_decision(self, sample_size: int) -> bool:
+ """Make fusion decisions based on analyzer insights."""
+ dataset_size = self.analyzer_insights.get("dataset_size", 0)
+ text_length_stats = self.analyzer_insights.get("text_length", {})
+
+ # Decision 1: Always use fusion for large datasets
+ if dataset_size > 100000:
+ logger.debug(f"Large dataset ({dataset_size:,} samples) - always use fusion")
+ return False
+
+ # Decision 2: Skip fusion for very small datasets with simple filters
+ if sample_size < 100 and len(self.fused_filters) <= 2:
+ simple_count = sum(
+ 1
+ for op in self.fused_filters
+ if getattr(op, "_name", "")
+ in {"text_length_filter", "words_num_filter", "character_repetition_filter"}
+ )
+ if simple_count == len(self.fused_filters):
+ logger.debug("Small dataset with simple filters - skipping fusion")
+ return True
+
+ # Decision 3: Use fusion for complex data characteristics
+ if text_length_stats:
+ mean_length = text_length_stats.get("mean", 0)
+ std_length = text_length_stats.get("std", 0)
+ if mean_length > 0 and std_length / mean_length > 1.5:
+ logger.debug("Complex text characteristics - using fusion")
+ return False
+
+ # Decision 4: Run performance test for edge cases
+ return self._quick_performance_test(min(100, sample_size))
+
+ def _fallback_fusion_decision(self, sample_size: int) -> bool:
+ """Fallback fusion decision using original logic."""
+ # Skip performance test for very large datasets (fusion is always beneficial)
+ if sample_size > 10000:
+ return False
+
+ # Skip performance test for complex filters (always use fusion)
+ complex_filter_names = {
+ "perplexity_filter",
+ "stopwords_filter",
+ "flagged_words_filter",
+ "language_id_score_filter",
+ "word_repetition_filter",
+ }
+ complex_count = sum(1 for op in self.fused_filters if getattr(op, "_name", "") in complex_filter_names)
+
+ # Always use fusion for complex filters
+ if complex_count > 0:
+ return False
+
+ # Skip fusion for single filters
+ if len(self.fused_filters) == 1:
+ return True
+
+ # For simple filters, run a quick performance test
+ try:
+ return self._quick_performance_test(min(100, sample_size))
+ except Exception as e:
+ logger.warning(f"Performance test failed: {e}. Defaulting to fusion.")
+ return False
+
+ def _quick_performance_test(self, sample_size: int) -> bool:
+ """Run a quick performance test to determine if fusion is beneficial.
+
+ Args:
+ sample_size: Number of samples to test (small sample for speed)
+
+ Returns:
+ True if fusion should be skipped, False if fusion is beneficial
+ """
+ import random
+ import string
+ import time
+
+ from data_juicer.utils.constant import Fields
+
+ # Set recursion prevention flag
+ self._in_performance_test = True
+
+ try:
+ # Create minimal test data
+ test_data = {
+ "text": ["".join(random.choices(string.ascii_letters + " ", k=50)) for _ in range(sample_size)],
+ Fields.stats: [{} for _ in range(sample_size)],
+ }
+
+ # Measure individual execution time
+ individual_start = time.time()
+ for op in self.fused_filters:
+ op.compute_stats_batched(test_data.copy())
+ op.process_batched(test_data.copy())
+ individual_time = time.time() - individual_start
+
+ # Measure fused execution time
+ fused_start = time.time()
+ self.compute_stats_batched(test_data.copy())
+ self.process_batched(test_data.copy())
+ fused_time = time.time() - fused_start
+
+ # Calculate overhead ratio
+ overhead_ratio = fused_time / individual_time if individual_time > 0 else float("inf")
+
+ # Decision logic for simple filters only
+ if individual_time < 0.001:
+ # Very fast filters - overhead not worth it
+ should_skip = True
+ elif overhead_ratio > 3.0 and individual_time < 0.01:
+ # Simple filters with high overhead - skip fusion
+ should_skip = True
+ else:
+ # Default to fusion for most cases
+ should_skip = False
+
+ logger.info(
+ f"Performance test: individual={individual_time:.3f}s, "
+ f"fused={fused_time:.3f}s, ratio={overhead_ratio:.2f}x, "
+ f"skip={should_skip}"
+ )
+
+ return should_skip
+
+ finally:
+ # Always clear the recursion prevention flag
+ self._in_performance_test = False
+
+ def compute_stats_batched(self, samples, rank=None):
+ """Compute statistics for all fused filters using the best strategy."""
+ import av
+
+ # Check if we should skip fusion based on performance analysis
+ if self._should_skip_fusion(len(samples[Fields.stats])):
+ from loguru import logger
+
+ logger.debug(f"Skipping fusion for {self._name} - executing filters individually")
+
+ # Execute filters individually (no fusion)
+ for op in self.fused_filters:
+ if op.accelerator == "cuda":
+ samples = op.compute_stats_batched(samples, rank=rank)
+ else:
+ samples = op.compute_stats_batched(samples)
+ return samples
+
+ # Initialize context for intermediate variables
+ num_samples = len(samples[Fields.stats])
+ samples[Fields.context] = [{} for _ in range(num_samples)]
+
+ if self.execution_strategy == "parallel":
+ # Parallel execution for independent filters
+ with concurrent.futures.ThreadPoolExecutor(max_workers=self.num_proc) as executor:
+ futures = []
+ for group in self.independent_groups:
+ for op in group:
+ if op.accelerator == "cuda":
+ futures.append(executor.submit(op.compute_stats_batched, samples, rank=rank, context=True))
+ else:
+ futures.append(executor.submit(op.compute_stats_batched, samples, context=True))
+
+ # Wait for all operations to complete
+ concurrent.futures.wait(futures)
+ else:
+ # Sequential execution for dependent or complex filters
+ for op in self.fused_filters:
+ if op.accelerator == "cuda":
+ samples = op.compute_stats_batched(samples, rank=rank)
+ else:
+ samples = op.compute_stats_batched(samples)
+
+ # Clean up contexts
+ for ctx in samples[Fields.context]:
+ for context_key in ctx:
+ if isinstance(ctx[context_key], av.container.InputContainer):
+ ctx[context_key].streams.video[0].close()
+ ctx[context_key].close()
+
+ # Remove context from samples
+ _ = samples.pop(Fields.context)
+ return samples
+
+ def process_batched(self, samples):
+ """Process samples through all fused filters using the best strategy."""
+ # Check if we should skip fusion based on performance analysis
+ if self._should_skip_fusion(len(samples[Fields.stats])):
+ from loguru import logger
+
+ logger.debug(f"Skipping fusion for {self._name} - processing filters individually")
+
+ # Process filters individually (no fusion)
+ result = None
+ for op in self.fused_filters:
+ filter_result = list(op.process_batched(samples))
+
+ if result is None:
+ result = filter_result
+ else:
+ # Combine with logical AND (sample must pass all filters)
+ result = [r1 and r2 for r1, r2 in zip(result, filter_result)]
+
+ return result
+
+ if self.execution_strategy == "parallel":
+ # Parallel execution - all filters see original data
+ return self._process_batched_parallel(samples)
+ else:
+ # Sequential execution - each filter sees previous filter's output
+ return self._process_batched_sequential(samples)
+
+ def _process_batched_parallel(self, samples):
+ """Process filters in parallel (all see original data)."""
+ # Initialize result array
+ res = None
+
+ # Process independent groups in parallel
+ with concurrent.futures.ThreadPoolExecutor(max_workers=self.num_proc) as executor:
+ futures = []
+ for group in self.independent_groups:
+ group_futures = []
+ for op in group:
+ future = executor.submit(op.process_batched, samples)
+ group_futures.append(future)
+ futures.append(group_futures)
+
+ # Process results in dependency order
+ for group_futures in futures:
+ group_results = []
+ for future in group_futures:
+ this_res = np.array(list(future.result()))
+ group_results.append(this_res)
+
+ # Combine results within group
+ group_res = group_results[0]
+ for this_res in group_results[1:]:
+ group_res = np.logical_and(group_res, this_res)
+
+ # Combine with overall results
+ if res is not None:
+ res = np.logical_and(res, group_res)
+ else:
+ res = group_res
+
+ return res
+
+ def _process_batched_sequential(self, samples):
+ """Process filters sequentially (each sees previous output)."""
+ # Process filters sequentially to match individual execution behavior
+ result = None
+
+ for op in self.fused_filters:
+ filter_result = list(op.process_batched(samples))
+
+ if result is None:
+ result = filter_result
+ else:
+ # Combine with logical AND (sample must pass all filters)
+ result = [r1 and r2 for r1, r2 in zip(result, filter_result)]
+
+ return result
+
+ def run(self, dataset, *, exporter=None, tracer=None, reduce=True):
+ """Run the fused filter on a dataset.
+
+ Args:
+ dataset: Dataset to process
+ exporter: Optional exporter for results
+ tracer: Optional tracer for monitoring
+ reduce: Whether to apply filtering (True) or just compute stats (False)
+
+ Returns:
+ Processed dataset
+ """
+ # Prepare the dataset
+ from data_juicer.core.data import NestedDataset
+
+ if not isinstance(dataset, NestedDataset):
+ dataset = NestedDataset(dataset)
+
+ # Initialize each filter
+ for op in self.fused_filters:
+ dataset = Filter.run(op, dataset)
+
+ # Compute stats for all filters
+ new_dataset = dataset.map(
+ self.compute_stats,
+ num_proc=self.runtime_np(),
+ with_rank=self.use_cuda(),
+ batch_size=self.batch_size,
+ desc=self._name + "_compute_stats",
+ )
+
+ # Export stats if requested
+ if exporter and self.stats_export_path is not None:
+ exporter.export_compute_stats(new_dataset, self.stats_export_path)
+
+ # Apply filtering if reduce=True
+ if reduce:
+ new_dataset = new_dataset.filter(
+ self.process, num_proc=self.runtime_np(), batch_size=self.batch_size, desc=self._name + "_process"
+ )
+ if tracer:
+ tracer.trace_filter(self._name, dataset, new_dataset)
+
+ # Free models to save memory
+ from data_juicer.utils.model_utils import free_models
+
+ free_models()
+
+ return new_dataset
+
+
+@OPERATORS.register_module("fused_mapper")
+class FusedMapper(Mapper):
+ """A fused operator for mappers that can execute multiple mappers in one pass."""
+
+ _batched_op = True
+
+ def __init__(self, name: str, fused_mappers: List[str], batch_size: int = 32):
+ """Initialize the fused mapper.
+
+ Args:
+ name: Name of the fused mapper
+ fused_mappers: List of mapper names to be fused
+ batch_size: Batch size for processing
+ """
+        super().__init__()
+        self._name = name
+ self.batch_size = batch_size
+
+ # Load the mapper operations
+ self.fused_mappers = []
+ for mapper_name in fused_mappers:
+ # Skip if this is a fused_mapper to avoid recursive instantiation
+ if mapper_name == "fused_mapper":
+ logger.warning("Skipping recursive fused_mapper in FusedMapper initialization")
+ continue
+
+ mapper_config = {mapper_name: {}}
+ mapper = load_ops([mapper_config])[0]
+ self.fused_mappers.append(mapper)
+
+ # Set accelerator to 'cuda' if any of the fused mappers use CUDA
+ accelerator_methods = set([op.accelerator for op in self.fused_mappers])
+ if "cuda" in accelerator_methods:
+ self.accelerator = "cuda"
+
+ # Update num_proc with the minimum of all fused mappers
+ self.num_proc = min([op.runtime_np() for op in self.fused_mappers])
+
+ # Store original operation configs
+ self._op_cfg = {name: [op._op_cfg for op in self.fused_mappers]}
+
+ def process_batched(self, samples, rank=None):
+ """Process samples through all fused mappers.
+
+ Args:
+ samples: Batch of samples to process
+ rank: Rank for distributed processing
+
+ Returns:
+ Processed samples
+ """
+ # Process mappers sequentially
+ for op in self.fused_mappers:
+ process_args = {"rank": rank} if op.accelerator == "cuda" else {}
+ samples = op.process_batched(samples, **process_args)
+ return samples
+
+ def run(self, dataset, *, exporter=None, tracer=None):
+ """Run the fused mapper on a dataset.
+
+ Args:
+ dataset: Dataset to process
+ exporter: Optional exporter for results
+ tracer: Optional tracer for monitoring
+
+ Returns:
+ Processed dataset
+ """
+ # Prepare the dataset
+ from data_juicer.core.data import NestedDataset
+
+ if not isinstance(dataset, NestedDataset):
+ dataset = NestedDataset(dataset)
+
+ # Initialize each mapper
+ for op in self.fused_mappers:
+ dataset = Mapper.run(op, dataset)
+
+ # Process the dataset
+ new_dataset = dataset.map(
+ self.process_batched,
+ num_proc=self.num_proc,
+ with_rank=self.use_cuda(),
+ batch_size=self.batch_size,
+ desc=self._name + "_process",
+ )
+
+ return new_dataset
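+
+
+# Illustrative usage (a sketch; `samples` is assumed to be a batched dict that
+# already carries a Fields.stats entry per sample):
+#
+#     filters = load_ops([{"text_length_filter": {}}, {"words_num_filter": {}}])
+#     fused = FusedFilter("fused_filter", fused_filters=filters)
+#     samples = fused.compute_stats_batched(samples)
+#     keep_mask = fused.process_batched(samples)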
diff --git a/data_juicer/core/optimizer/mapper_fusion_strategy.py b/data_juicer/core/optimizer/mapper_fusion_strategy.py
new file mode 100644
index 0000000000..684ce12f9b
--- /dev/null
+++ b/data_juicer/core/optimizer/mapper_fusion_strategy.py
@@ -0,0 +1,261 @@
+from typing import List
+
+from loguru import logger
+
+from data_juicer.core.optimizer.strategy import OptimizationStrategy, register_strategy
+from data_juicer.core.pipeline_ast import OpNode, OpType, PipelineAST
+
+
+@register_strategy("mapper_fusion")
+class MapperFusionStrategy(OptimizationStrategy):
+ """Strategy for fusing mapper operations in the pipeline."""
+
+ def __init__(self):
+ """Initialize the mapper fusion strategy."""
+ super().__init__(name="mapper_fusion")
+
+ def optimize(self, ast: PipelineAST) -> PipelineAST:
+ """Apply mapper fusion to the pipeline AST.
+
+ Args:
+ ast: The pipeline AST to optimize
+
+ Returns:
+ Optimized pipeline AST
+ """
+ if not ast.root:
+ return ast
+
+ # Create a new AST
+ new_ast = PipelineAST()
+ new_ast.root = OpNode(name="root", op_type=OpType.ROOT, config={})
+
+ # Get all unique operation chains
+ op_chains = self._get_unique_op_chains(ast.root)
+
+ # Process each chain
+ current = new_ast.root
+ for chain in op_chains:
+ # Group mapper operations
+ mapper_groups = self._group_mappers(chain)
+
+ for group in mapper_groups:
+ if len(group) > 1:
+ # Create fused operation with clean naming
+ fused_name = "fused_mapper"
+ detailed_ops = [n.name for n in group]
+ logger.info(f"Fusing mapper operations into {fused_name}: {detailed_ops}")
+
+ # Create fused node using FusedMapper
+ fused_node = OpNode(
+ name=fused_name,
+ op_type=OpType.MAPPER,
+ config={
+ "fused_mapper": {
+ "name": fused_name,
+ "fused_mappers": detailed_ops,
+ "detailed_ops": detailed_ops, # For display purposes
+ }
+ },
+ )
+ current.add_child(fused_node)
+ current = fused_node
+ else:
+ # Keep single operations as is
+ new_node = OpNode(name=group[0].name, op_type=group[0].op_type, config=group[0].config or {})
+ current.add_child(new_node)
+ current = new_node
+
+ return new_ast
+
+ def _get_unique_op_chains(self, node: OpNode) -> List[List[OpNode]]:
+ """Get unique chains of operations from the tree.
+
+ Args:
+ node: Root node of the tree
+
+ Returns:
+ List of unique operation chains
+ """
+ chains = []
+ seen_chains = set()
+
+ def traverse(current: OpNode, chain: List[OpNode]):
+ if not current.children:
+ # End of chain, check if we've seen this sequence before
+ chain_key = tuple(n.name for n in chain)
+ if chain_key not in seen_chains:
+ chains.append(chain.copy())
+ seen_chains.add(chain_key)
+ return
+
+ for child in current.children:
+ chain.append(child)
+ traverse(child, chain)
+ chain.pop()
+
+ traverse(node, [])
+ return chains
+
+ def _group_mappers(self, chain: List[OpNode]) -> List[List[OpNode]]:
+ """Group mapper operations that can be fused together.
+
+ Args:
+ chain: List of operations in the pipeline
+
+ Returns:
+ List of mapper operation groups
+ """
+ groups = []
+ current_group = []
+
+ logger.info(f"Grouping mappers from chain: {[n.name for n in chain]}")
+
+ for node in chain:
+ if not PipelineAST.is_mapper_op(node):
+ # If we encounter a non-mapper, finalize current group
+ if current_group:
+ logger.info(f"Finalizing mapper group: {[n.name for n in current_group]}")
+ groups.append(current_group)
+ current_group = []
+ # Add the non-mapper node as a separate group
+ groups.append([node])
+ else:
+ # This is a mapper node
+ if not current_group:
+ # Start a new group
+ current_group = [node]
+ logger.info(f"Starting new mapper group with: {node.name}")
+ else:
+ # Check if current mapper can be fused with the group
+ if self._can_fuse_with_group(node, current_group):
+ current_group.append(node)
+ logger.info(f"Added {node.name} to current group: {[n.name for n in current_group]}")
+ else:
+ # Finalize current group and start a new one
+ logger.info(f"Finalizing mapper group due to dependency: {[n.name for n in current_group]}")
+ groups.append(current_group)
+ current_group = [node]
+ logger.info(f"Starting new mapper group with: {node.name}")
+
+ # Don't forget the last group
+ if current_group:
+ logger.info(f"Finalizing final mapper group: {[n.name for n in current_group]}")
+ groups.append(current_group)
+
+ logger.info(f"Final mapper groups: {[[n.name for n in group] for group in groups]}")
+ return groups
+
+ def _can_fuse_with_group(self, node: OpNode, group: List[OpNode]) -> bool:
+ """Check if a mapper can be fused with a group.
+
+ Args:
+ node: Operation to check
+ group: Group of operations
+
+ Returns:
+ True if the operation can be fused with the group
+ """
+ # Check dependencies
+ for op in group:
+ if self._has_dependency(node, op) or self._has_dependency(op, node):
+ logger.info(f"Cannot fuse {node.name} with group {[n.name for n in group]} due to dependency")
+ return False
+
+ logger.info(f"Can fuse {node.name} with group {[n.name for n in group]}")
+ return True
+
+ def _has_dependency(self, op1: OpNode, op2: OpNode) -> bool:
+ """Check if op1 depends on op2.
+
+ Args:
+ op1: First operation
+ op2: Second operation
+
+ Returns:
+ True if op1 depends on op2
+ """
+ # 1. Check intermediate variables (for mappers that produce/consume inter_vars)
+ op1_vars = set(op1.config.get("inter_vars", []))
+ op2_vars = set(op2.config.get("inter_vars", []))
+ if op1_vars & op2_vars:
+ logger.info(f"Dependency found via inter_vars: {op1.name} <-> {op2.name}")
+ return True
+
+ # 2. Check field dependencies (mappers that modify the same fields)
+ if self._check_field_dependencies(op1, op2):
+ logger.info(f"Dependency found via field dependencies: {op1.name} <-> {op2.name}")
+ return True
+
+ # 3. Check operation-specific dependencies
+ if self._check_operation_specific_dependencies(op1, op2):
+ logger.info(f"Dependency found via operation-specific dependencies: {op1.name} <-> {op2.name}")
+ return True
+
+ return False
+
+ def _check_field_dependencies(self, op1: OpNode, op2: OpNode) -> bool:
+ """Check if operations modify the same fields."""
+ # For mappers, we allow fusion even if they modify the same fields
+ # since they can be executed sequentially
+ # Only prevent fusion for specific logical dependencies
+ return False
+
+ def _get_modified_fields(self, op: OpNode) -> set:
+ """Get the fields that an operation modifies."""
+ # This is a simplified mapping - in practice, you'd want to analyze the actual operation logic
+ field_mapping = {
+ "clean_email_mapper": {"text"},
+ "clean_links_mapper": {"text"},
+ "fix_unicode_mapper": {"text"},
+ "punctuation_normalization_mapper": {"text"},
+ "whitespace_normalization_mapper": {"text"},
+ "text_lowercase_mapper": {"text"},
+ "text_uppercase_mapper": {"text"},
+ "remove_words_mapper": {"text"},
+ "remove_characters_mapper": {"text"},
+ "replace_words_mapper": {"text"},
+ "replace_characters_mapper": {"text"},
+ "split_text_mapper": {"text"},
+ "join_text_mapper": {"text"},
+ "text_length_mapper": {"text"},
+ "text_quality_mapper": {"text"},
+ }
+
+ return field_mapping.get(op.name, set())
+
+ def _check_operation_specific_dependencies(self, op1: OpNode, op2: OpNode) -> bool:
+ """Check operation-specific dependencies."""
+ # Define specific dependencies that prevent fusion
+
+ # Unicode fixing should come before punctuation normalization
+ if op1.name == "punctuation_normalization_mapper" and op2.name == "fix_unicode_mapper":
+ return True
+
+ if op1.name == "fix_unicode_mapper" and op2.name == "punctuation_normalization_mapper":
+ return True
+
+ # Email/links cleaning should come before punctuation normalization
+ if op1.name == "punctuation_normalization_mapper" and op2.name in ["clean_email_mapper", "clean_links_mapper"]:
+ return True
+
+ if op1.name in ["clean_email_mapper", "clean_links_mapper"] and op2.name == "punctuation_normalization_mapper":
+ return True
+
+ # Whitespace normalization should come after most other text operations
+ if op1.name == "whitespace_normalization_mapper" and op2.name in [
+ "clean_email_mapper",
+ "clean_links_mapper",
+ "fix_unicode_mapper",
+ "punctuation_normalization_mapper",
+ ]:
+ return True
+
+ if (
+ op1.name
+ in ["clean_email_mapper", "clean_links_mapper", "fix_unicode_mapper", "punctuation_normalization_mapper"]
+ and op2.name == "whitespace_normalization_mapper"
+ ):
+ return True
+
+ return False
diff --git a/data_juicer/core/optimizer/op_reorder_strategy.py b/data_juicer/core/optimizer/op_reorder_strategy.py
new file mode 100644
index 0000000000..a05890cbe0
--- /dev/null
+++ b/data_juicer/core/optimizer/op_reorder_strategy.py
@@ -0,0 +1,302 @@
+#!/usr/bin/env python3
+"""
+Operation Reordering Strategy for the core optimizer.
+
+This strategy analyzes dependencies between operations and reorders them
+for optimal performance, prioritizing filters over mappers when possible.
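+
+Example (a minimal usage sketch; the op names below are only illustrative):
+
+    from data_juicer.core.optimizer.op_reorder_strategy import OpReorderStrategy
+    from data_juicer.core.pipeline_ast import PipelineAST
+
+    ast = PipelineAST()
+    ast.build_from_config(
+        {"process": [{"clean_html_mapper": {}}, {"text_length_filter": {"min_len": 10}}]}
+    )
+    reordered = OpReorderStrategy().optimize(ast)
+    print(reordered.visualize())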
+"""
+
+from collections import defaultdict, deque
+from typing import Any, Dict, List, Set
+
+from loguru import logger
+
+from ..pipeline_ast import OpNode, OpType, PipelineAST
+from .strategy import OptimizationStrategy, register_strategy
+
+
+@register_strategy("op_reorder")
+class OpReorderStrategy(OptimizationStrategy):
+ """
+ Strategy that reorders operations based on dependency analysis and performance optimization.
+
+ Key features:
+ 1. Analyzes dependencies between operations
+ 2. Performs topological sorting
+ 3. Prioritizes filters over mappers when possible
+ 4. Optimizes for early filtering to reduce data volume
+ """
+
+ def __init__(self):
+ """Initialize the operation reordering strategy."""
+ super().__init__(name="op_reorder")
+
+ def optimize(self, ast: PipelineAST) -> PipelineAST:
+ """
+ Apply operation reordering to the pipeline AST.
+
+ Args:
+ ast: The pipeline AST to optimize
+
+ Returns:
+ Optimized pipeline AST with reordered operations
+ """
+ if not ast.root or not ast.root.children:
+ return ast
+
+ logger.info("๐ Applying operation reordering optimization...")
+
+ # Get all operations from the AST
+ operations = self._extract_operations(ast.root)
+
+ # Log original order
+ logger.info("๐ Original operation order:")
+ for i, op in enumerate(operations, 1):
+ op_type = "๐ง MAPPER" if op.op_type == OpType.MAPPER else "๐ FILTER"
+ logger.info(f" {i}. {op_type}: {op.name}")
+
+ # Analyze dependencies
+ dependencies = self._analyze_dependencies(operations)
+ logger.info(f"๐ Found {len(dependencies)} dependencies between operations")
+
+ # Log specific dependencies
+ if dependencies:
+ logger.info("๐ Operation dependencies:")
+ for op_name, deps in dependencies.items():
+ if deps:
+ logger.info(f" {op_name} depends on: {', '.join(deps)}")
+ else:
+ logger.info("๐ No dependencies found - operations can be freely reordered")
+
+ # Perform topological sort
+ optimal_order = self._topological_sort(operations, dependencies)
+
+ # Create a mapping from operation names to operation nodes
+ op_map = {op.name: op for op in operations}
+
+ # Log optimized order
+ logger.info("โก Optimized operation order:")
+ for i, op_name in enumerate(optimal_order, 1):
+ op_node = op_map.get(op_name)
+ if op_node:
+ op_type = "๐ง MAPPER" if op_node.op_type == OpType.MAPPER else "๐ FILTER"
+ logger.info(f" {i}. {op_type}: {op_name}")
+
+ # Reorder operations in the AST
+ new_ast = self._reorder_ast(ast, optimal_order)
+
+ logger.info(f"โ
Reordered {len(operations)} operations for optimal performance")
+ return new_ast
+
+ def _extract_operations(self, root: OpNode) -> List[OpNode]:
+ """Extract all operations from the AST."""
+ operations = []
+
+ def collect_ops(node: OpNode):
+ if node.op_type in [OpType.MAPPER, OpType.FILTER]:
+ operations.append(node)
+ for child in node.children:
+ collect_ops(child)
+
+ collect_ops(root)
+ return operations
+
+ def _analyze_dependencies(self, operations: List[OpNode]) -> Dict[str, Set[str]]:
+ """
+ Analyze dependencies between operations.
+
+ Args:
+ operations: List of operation nodes
+
+ Returns:
+ Dictionary mapping operation names to their dependencies
+ """
+ dependency_graph = defaultdict(set)
+
+ for i, op1 in enumerate(operations):
+ op1_name = op1.name
+ op1_vars = set(op1.config.get("inter_vars", []))
+ op1_fields = set(op1.config.get("fields", []))
+
+ for j, op2 in enumerate(operations):
+ if i == j:
+ continue
+
+ op2_name = op2.name
+ op2_vars = set(op2.config.get("inter_vars", []))
+ op2_fields = set(op2.config.get("fields", []))
+
+                # Shared intermediate variables or fields make the pair order-sensitive.
+                # Only add the edge in the direction of the original order (earlier op
+                # stays first); adding edges in both directions would create a cycle
+                # and break the topological sort below.
+                if i < j and (op1_vars & op2_vars or op1_fields & op2_fields):
+                    dependency_graph[op2_name].add(op1_name)
+
+ # Check for operation-specific dependencies
+ if self._has_operation_dependency(op1, op2):
+ dependency_graph[op2_name].add(op1_name)
+
+ return dict(dependency_graph)
+
+ def _has_operation_dependency(self, op1: OpNode, op2: OpNode) -> bool:
+ """Check if op2 depends on op1 based on operation types."""
+ op1_name = op1.name.lower()
+ op2_name = op2.name.lower()
+
+        # Language detection must run before perplexity computation
+        if "language_id" in op1_name and "perplexity" in op2_name:
+            return True
+
+        # Entity detection must run before action analysis
+        if "text_entity" in op1_name and "text_action" in op2_name:
+            return True
+
+        # Image preprocessing must run before image quality analysis
+        if "image_resize" in op1_name and "image_quality" in op2_name:
+            return True
+
+ return False
+
+ def _topological_sort(self, operations: List[OpNode], dependencies: Dict[str, Set[str]]) -> List[str]:
+ """
+ Perform topological sorting of operations.
+
+ Args:
+ operations: List of operation nodes
+ dependencies: Dependency graph
+
+ Returns:
+ List of operation names in optimal order
+ """
+ # Build in-degree count
+ in_degree = defaultdict(int)
+ all_ops = {op.name: op for op in operations}
+
+ for op_name in all_ops:
+ in_degree[op_name] = 0
+
+ for op_name, deps in dependencies.items():
+ for dep in deps:
+ in_degree[op_name] += 1
+
+ # Initialize queue with operations that have no dependencies
+ queue = deque([op for op in all_ops if in_degree[op] == 0])
+ result = []
+
+ while queue:
+ # Sort queue to prioritize filters over mappers
+ queue = deque(sorted(queue, key=lambda op: self._get_operation_priority(all_ops[op])))
+
+ current = queue.popleft()
+ result.append(current)
+
+ # Update in-degrees of dependent operations
+ for op_name, deps in dependencies.items():
+ if current in deps:
+ in_degree[op_name] -= 1
+ if in_degree[op_name] == 0:
+ queue.append(op_name)
+
+        # Safety net: a dependency cycle would otherwise silently drop the involved
+        # operations, so append any unreached operations in their original order
+        if len(result) < len(all_ops):
+            result.extend(op_name for op_name in all_ops if op_name not in result)
+
+        return result
+
+ def _get_operation_priority(self, operation: OpNode) -> int:
+ """Get priority for operation ordering (lower = higher priority)."""
+ op_name = operation.name.lower()
+ op_type = operation.op_type
+
+        # Filters have the highest priority (lowest number); within filters,
+        # text filters run first, followed by image, audio and video filters
+        if op_type == OpType.FILTER:
+            if "text" in op_name:
+                return 1
+            if "image" in op_name:
+                return 2
+            if "audio" in op_name:
+                return 3
+            if "video" in op_name:
+                return 4
+            return 5
+
+ # Mappers have lower priority
+ if op_type == OpType.MAPPER:
+ return 10
+
+ # Other operations
+ return 20
+
+ def _reorder_ast(self, ast: PipelineAST, optimal_order: List[str]) -> PipelineAST:
+ """
+ Reorder operations in the AST according to optimal order.
+
+ Args:
+ ast: Original pipeline AST
+ optimal_order: List of operation names in optimal order
+
+ Returns:
+ New AST with reordered operations
+ """
+ # Create new AST
+ new_ast = PipelineAST()
+ new_ast.root = OpNode(name="root", op_type=OpType.ROOT, config={})
+
+ # Get all operations in optimal order
+ operations = self._extract_operations(ast.root)
+ op_dict = {op.name: op for op in operations}
+
+ # Build new chain in optimal order
+ current = new_ast.root
+ for op_name in optimal_order:
+ if op_name in op_dict:
+ # Create a copy of the operation
+ op = op_dict[op_name]
+ new_op = OpNode(name=op.name, op_type=op.op_type, config=op.config.copy())
+ current.children = [new_op]
+ new_op.parent = current
+ current = new_op
+
+ return new_ast
+
+ def get_reorder_benefits(self, operations: List[OpNode]) -> Dict[str, Any]:
+ """
+ Analyze the potential benefits of reordering operations.
+
+ Args:
+ operations: List of operation nodes
+
+ Returns:
+ Dictionary with reordering benefits analysis
+ """
+ # Count filters vs mappers
+ filter_count = sum(1 for op in operations if op.op_type == OpType.FILTER)
+ mapper_count = sum(1 for op in operations if op.op_type == OpType.MAPPER)
+
+ # Analyze potential early filtering
+ early_filter_ops = []
+ for i, op in enumerate(operations):
+ if op.op_type == OpType.FILTER and i > 0:
+ early_filter_ops.append(op.name)
+
+ return {
+ "total_operations": len(operations),
+ "filter_count": filter_count,
+ "mapper_count": mapper_count,
+ "early_filter_opportunities": len(early_filter_ops),
+ "early_filter_ops": early_filter_ops,
+ "potential_speedup": "High" if len(early_filter_ops) > 0 else "Medium",
+ "memory_reduction": "High" if filter_count > mapper_count else "Medium",
+ }
diff --git a/data_juicer/core/optimizer/optimizer.py b/data_juicer/core/optimizer/optimizer.py
new file mode 100644
index 0000000000..9f08182587
--- /dev/null
+++ b/data_juicer/core/optimizer/optimizer.py
@@ -0,0 +1,160 @@
+from typing import Any, Dict, List, Optional
+
+from loguru import logger
+
+from data_juicer.core.optimizer.filter_fusion_strategy import FilterFusionStrategy
+from data_juicer.core.optimizer.mapper_fusion_strategy import MapperFusionStrategy
+from data_juicer.core.optimizer.op_reorder_strategy import OpReorderStrategy
+from data_juicer.core.optimizer.strategy import OptimizationStrategy
+from data_juicer.core.pipeline_ast import PipelineAST
+
+
+class PipelineOptimizer:
+ """Main optimizer class that manages multiple optimization strategies."""
+
+ def __init__(
+ self,
+ strategies: Optional[List[OptimizationStrategy]] = None,
+ analyzer_insights: Optional[Dict[str, Any]] = None,
+ ):
+ """Initialize the optimizer with a list of strategies.
+
+ Args:
+ strategies: List of optimization strategies to apply. If None,
+ default strategies will be used.
+ analyzer_insights: Optional dataset analysis insights for optimization
+ """
+ self.analyzer_insights = analyzer_insights or {}
+
+ if strategies is None:
+ # Create strategies with analyzer insights
+ self.strategies = [
+ OpReorderStrategy(), # Apply reordering first
+ MapperFusionStrategy(),
+ FilterFusionStrategy(analyzer_insights=self.analyzer_insights),
+ ]
+ else:
+ self.strategies = strategies
+
+ def add_strategy(self, strategy: OptimizationStrategy) -> None:
+ """Add a new optimization strategy.
+
+ Args:
+ strategy: The optimization strategy to add.
+ """
+ self.strategies.append(strategy)
+
+ def remove_strategy(self, strategy_name: str) -> None:
+ """Remove an optimization strategy by name.
+
+ Args:
+ strategy_name: Name of the strategy to remove.
+ """
+ self.strategies = [s for s in self.strategies if s.name != strategy_name]
+
+ def optimize(self, ast: PipelineAST) -> PipelineAST:
+ """Apply all optimization strategies to the pipeline AST.
+
+ Args:
+ ast: The pipeline AST to optimize
+
+ Returns:
+ Optimized pipeline AST
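+
+        Example (a minimal sketch; assumes ``ast`` was built with
+        ``PipelineAST.build_from_config``):
+
+            optimizer = PipelineOptimizer()
+            optimized_ast = optimizer.optimize(ast)
+            print(optimizer.get_optimization_summary())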
+ """
+ logger.info(f"Starting pipeline optimization with {len(self.strategies)} strategies")
+
+ if self.analyzer_insights:
+ logger.info("Using analyzer insights for optimization:")
+ dataset_size = self.analyzer_insights.get("dataset_size", 0)
+ if dataset_size > 0:
+ logger.info(f" Dataset size: {dataset_size:,} samples")
+
+ text_stats = self.analyzer_insights.get("text_length", {})
+ if text_stats:
+ mean_len = text_stats.get("mean", 0)
+ std_len = text_stats.get("std", 0)
+ if mean_len > 0:
+ cv = std_len / mean_len
+ logger.info(f" Text length CV: {cv:.2f} (mean: {mean_len:.1f}, std: {std_len:.1f})")
+
+ optimized_ast = ast
+ for strategy in self.strategies:
+ logger.info(f"Applying {strategy.name} strategy...")
+ optimized_ast = strategy.optimize(optimized_ast)
+
+ logger.info("Pipeline optimization completed")
+ return optimized_ast
+
+ def set_analyzer_insights(self, insights: Dict[str, Any]) -> None:
+ """Set analyzer insights for optimization strategies.
+
+ Args:
+ insights: Dictionary containing dataset analysis insights
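+
+        Example insights dict (a sketch; only the keys this optimizer reads
+        are shown):
+
+            {
+                "dataset_size": 100000,
+                "text_length": {"mean": 512.0, "std": 256.0},
+                "content_ratios": {"image_ratio": 0.0, "audio_ratio": 0.0, "video_ratio": 0.0},
+            }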
+ """
+ self.analyzer_insights = insights
+
+ # Update existing strategies that support analyzer insights
+ for strategy in self.strategies:
+ if hasattr(strategy, "analyzer_insights"):
+ strategy.analyzer_insights = insights
+ logger.info(f"Updated {strategy.name} with analyzer insights")
+
+ def get_optimization_summary(self) -> Dict[str, Any]:
+ """Get a summary of the optimization configuration.
+
+ Returns:
+ Dictionary containing optimization summary
+ """
+ summary = {
+ "strategies": [s.name for s in self.strategies],
+ "analyzer_insights_available": bool(self.analyzer_insights),
+ "insights_keys": list(self.analyzer_insights.keys()) if self.analyzer_insights else [],
+ }
+
+ if self.analyzer_insights:
+ dataset_size = self.analyzer_insights.get("dataset_size", 0)
+ summary["dataset_size"] = dataset_size
+
+ # Add data complexity indicators
+ text_stats = self.analyzer_insights.get("text_length", {})
+ if text_stats:
+ mean_len = text_stats.get("mean", 0)
+ std_len = text_stats.get("std", 0)
+ if mean_len > 0:
+ summary["text_complexity"] = std_len / mean_len
+
+ content_ratios = self.analyzer_insights.get("content_ratios", {})
+ multimodal_count = sum(
+ 1
+ for indicator in ["image_ratio", "audio_ratio", "video_ratio"]
+ if content_ratios.get(indicator, 0) > 0.1
+ )
+ summary["multimodal_types"] = multimodal_count
+
+ return summary
+
+ def get_strategy(self, strategy_name: str) -> Optional[OptimizationStrategy]:
+ """Get a strategy by name.
+
+ Args:
+ strategy_name: Name of the strategy to get.
+
+ Returns:
+ The strategy if found, None otherwise.
+ """
+ for strategy in self.strategies:
+ if strategy.name == strategy_name:
+ return strategy
+ return None
+
+ def get_strategy_names(self) -> List[str]:
+ """Get names of all registered strategies.
+
+ Returns:
+ List of strategy names.
+ """
+ return [strategy.name for strategy in self.strategies]
+
+ def clear_strategies(self) -> None:
+ """Remove all optimization strategies."""
+ self.strategies = []
diff --git a/data_juicer/core/optimizer/strategy.py b/data_juicer/core/optimizer/strategy.py
new file mode 100644
index 0000000000..648dccf62c
--- /dev/null
+++ b/data_juicer/core/optimizer/strategy.py
@@ -0,0 +1,138 @@
+from abc import ABC, abstractmethod
+from typing import Dict, List, Optional, Type
+
+from loguru import logger
+
+from data_juicer.core.pipeline_ast import OpNode, PipelineAST
+
+
+class OptimizationStrategy(ABC):
+ """Base class for pipeline optimization strategies."""
+
+ def __init__(self, name: str):
+ """Initialize the optimization strategy.
+
+ Args:
+ name: Name of the optimization strategy
+ """
+ self.name = name
+
+ @abstractmethod
+ def optimize(self, ast: PipelineAST) -> PipelineAST:
+ """Apply the optimization strategy to the pipeline AST.
+
+ Args:
+ ast: The pipeline AST to optimize
+
+ Returns:
+ Optimized pipeline AST
+ """
+ pass
+
+ def _get_operation_chain(self, node: OpNode) -> List[OpNode]:
+ """Get the linear chain of operations from a node.
+
+ Args:
+ node: The node to start from
+
+ Returns:
+ List of operations in the chain
+ """
+ chain = []
+ current = node
+ while current.children:
+ current = current.children[0]
+ chain.append(current)
+ return chain
+
+ def _rebuild_chain(self, root: OpNode, chain: List[OpNode]) -> None:
+ """Rebuild the operation chain from a list of nodes.
+
+ Args:
+ root: The root node
+ chain: List of operations to chain
+ """
+ current = root
+ for node in chain:
+ current.children = [node]
+ node.parent = current
+ current = node
+
+
+class StrategyRegistry:
+ """Registry for optimization strategies."""
+
+ _strategies: Dict[str, Type[OptimizationStrategy]] = {}
+
+ @classmethod
+ def register(cls, name: str, strategy_class: Type[OptimizationStrategy]) -> None:
+ """Register a strategy class with a name.
+
+ Args:
+ name: Name to register the strategy under
+ strategy_class: The strategy class to register
+ """
+ cls._strategies[name] = strategy_class
+ logger.debug(f"๐ง Registered optimization strategy: {name} -> {strategy_class.__name__}")
+
+ @classmethod
+ def get_strategy_class(cls, name: str) -> Optional[Type[OptimizationStrategy]]:
+ """Get a strategy class by name.
+
+ Args:
+ name: Name of the strategy
+
+ Returns:
+ The strategy class if found, None otherwise
+ """
+ return cls._strategies.get(name)
+
+ @classmethod
+ def get_available_strategies(cls) -> List[str]:
+ """Get list of available strategy names.
+
+ Returns:
+ List of registered strategy names
+ """
+ return list(cls._strategies.keys())
+
+ @classmethod
+ def create_strategy(cls, name: str, **kwargs) -> Optional[OptimizationStrategy]:
+ """Create a strategy instance by name.
+
+ Args:
+ name: Name of the strategy
+ **kwargs: Additional arguments to pass to strategy constructor
+
+ Returns:
+ Strategy instance if found and created successfully, None otherwise
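+
+        Example (a minimal sketch, with ``ast`` being an already-built
+        ``PipelineAST``):
+
+            strategy = StrategyRegistry.create_strategy("op_reorder")
+            if strategy is not None:
+                ast = strategy.optimize(ast)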
+ """
+ strategy_class = cls.get_strategy_class(name)
+ if strategy_class is None:
+ logger.warning(f"โ ๏ธ Unknown strategy '{name}'. Available strategies: {cls.get_available_strategies()}")
+ return None
+
+ try:
+ return strategy_class(**kwargs)
+ except Exception as e:
+ logger.error(f"โ Failed to create strategy '{name}': {e}")
+ return None
+
+
+def register_strategy(name: str):
+ """Decorator to register a strategy class.
+
+ Args:
+ name: Name to register the strategy under
+
+ Example:
+ @register_strategy('op_reorder')
+ class OpReorderStrategy(OptimizationStrategy):
+ pass
+ """
+
+ def decorator(strategy_class: Type[OptimizationStrategy]) -> Type[OptimizationStrategy]:
+ StrategyRegistry.register(name, strategy_class)
+ return strategy_class
+
+ return decorator
diff --git a/data_juicer/core/pipeline_ast.py b/data_juicer/core/pipeline_ast.py
new file mode 100644
index 0000000000..11708ec42a
--- /dev/null
+++ b/data_juicer/core/pipeline_ast.py
@@ -0,0 +1,230 @@
+# standard library imports
+import argparse
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Dict, List, Optional
+
+# third party imports
+import yaml
+
+
+class OpType(Enum):
+ """Types of operations in the pipeline."""
+
+ ROOT = "root"
+ MAPPER = "mapper"
+ FILTER = "filter"
+ DEDUPLICATOR = "deduplicator"
+ SELECTOR = "selector"
+ GROUPER = "grouper"
+ AGGREGATOR = "aggregator"
+
+
+@dataclass
+class OpNode:
+ """Node in the pipeline AST representing an operation."""
+
+ name: str
+ op_type: OpType
+ config: Dict[str, Any]
+ children: List["OpNode"] = None
+ parent: Optional["OpNode"] = None
+
+ def __post_init__(self):
+ if self.children is None:
+ self.children = []
+
+ def add_child(self, child: "OpNode"):
+ """Add a child node to this operation."""
+ child.parent = self
+ self.children.append(child)
+
+ def to_dict(self) -> Dict[str, Any]:
+ """Convert the node to a dictionary representation."""
+ return {
+ "name": self.name,
+ "type": self.op_type.value,
+ "config": self.config,
+ "children": [child.to_dict() for child in self.children],
+ }
+
+
+class PipelineAST:
+ """Abstract Syntax Tree for a Data-Juicer pipeline."""
+
+ def __init__(self):
+ self.root = None
+ self._op_type_map = {
+ "mapper": OpType.MAPPER,
+ "filter": OpType.FILTER,
+ "deduplicator": OpType.DEDUPLICATOR,
+ "selector": OpType.SELECTOR,
+ "grouper": OpType.GROUPER,
+ "aggregator": OpType.AGGREGATOR,
+ }
+
+ # Operation dependencies and optimization rules
+ self._op_dependencies = {
+ OpType.FILTER: {OpType.MAPPER}, # Filters can depend on mappers
+ OpType.DEDUPLICATOR: {OpType.MAPPER, OpType.FILTER}, # Deduplicators can depend on mappers and filters
+ OpType.SELECTOR: {
+ OpType.MAPPER,
+ OpType.FILTER,
+ OpType.DEDUPLICATOR,
+ }, # Selectors can depend on all previous ops
+ OpType.GROUPER: {
+ OpType.MAPPER,
+ OpType.FILTER,
+ OpType.DEDUPLICATOR,
+ OpType.SELECTOR,
+ }, # Groupers can depend on all previous ops
+ OpType.AGGREGATOR: {OpType.GROUPER}, # Aggregators can only depend on groupers
+ }
+
+ def _get_op_type(self, op_name: str) -> OpType:
+ """Determine the operation type from its name."""
+ for suffix, op_type in self._op_type_map.items():
+ if op_name.endswith(f"_{suffix}"):
+ return op_type
+ return OpType.MAPPER # Default to mapper if type cannot be determined
+
+ def build_from_config(self, config: Dict[str, Any]) -> None:
+ """Build the AST from a configuration dictionary."""
+ if "process" not in config:
+ raise ValueError("Configuration must contain a 'process' field")
+
+ process_list = config["process"]
+ if not process_list:
+ return
+
+ # Create root node
+ self.root = OpNode(name="root", op_type=OpType.ROOT, config={}) # Root is a special type
+
+ # Build tree following the order in process_list
+ current_node = self.root
+ for op_config in process_list:
+ op_name, op_args = list(op_config.items())[0]
+ op_type = self._get_op_type(op_name)
+
+ new_node = OpNode(name=op_name, op_type=op_type, config=op_args)
+ current_node.add_child(new_node)
+ current_node = new_node
+
+ def build_from_yaml(self, yaml_path: str) -> None:
+ """Build the AST from a YAML configuration file."""
+ with open(yaml_path, "r") as f:
+ config = yaml.safe_load(f)
+ self.build_from_config(config)
+
+ def to_dict(self) -> Dict[str, Any]:
+ """Convert the AST to a dictionary representation."""
+ if not self.root:
+ return {}
+ return self.root.to_dict()
+
+ def visualize(self) -> str:
+ """Generate a string representation of the AST for visualization."""
+ if not self.root:
+ return "Empty pipeline"
+
+ def _visualize_node(node: OpNode, level: int = 0, is_last: bool = True) -> str:
+ indent = " " * level
+ prefix = "โโโ " if is_last else "โโโ "
+
+ # Check if this is a fused operation and get detailed ops
+ detailed_ops = None
+ if node.name == "fused_mapper" and "fused_mapper" in node.config:
+ detailed_ops = node.config["fused_mapper"].get("detailed_ops", [])
+ elif node.name == "fused_filter" and "general_fused_op" in node.config:
+ detailed_ops = node.config["general_fused_op"].get("detailed_ops", [])
+
+ # Format the node name with detailed operations if available
+ if detailed_ops:
+ ops_str = ", ".join(detailed_ops)
+ result = f"{indent}{prefix}{node.name} ({node.op_type.value}) [{ops_str}]\n"
+ else:
+ result = f"{indent}{prefix}{node.name} ({node.op_type.value})\n"
+
+ for i, child in enumerate(node.children):
+ is_last_child = i == len(node.children) - 1
+ result += _visualize_node(child, level + 1, is_last_child)
+ return result
+
+ return "Pipeline:\n" + _visualize_node(self.root, 0, True)
+
+ @staticmethod
+ def is_mapper_op(node_or_type) -> bool:
+ """Check if node or op_type is a mapper operation using value comparison."""
+ if hasattr(node_or_type, "op_type"):
+ return getattr(node_or_type, "op_type").value == "mapper"
+ return node_or_type.value == "mapper"
+
+ @staticmethod
+ def is_filter_op(node_or_type) -> bool:
+ """Check if node or op_type is a filter operation using value comparison."""
+ if hasattr(node_or_type, "op_type"):
+ return getattr(node_or_type, "op_type").value == "filter"
+ return node_or_type.value == "filter"
+
+ @staticmethod
+ def op_type_equals(a, b) -> bool:
+ """Compare OpType values safely to handle module import issues."""
+ return getattr(a, "value", a) == getattr(b, "value", b)
+
+
+if __name__ == "__main__":
+ import os
+
+ # Set up argument parser
+ parser = argparse.ArgumentParser(description="Build and visualize pipeline AST from config file")
+ parser.add_argument(
+ "--config",
+ type=str,
+ default="configs/data_juicer_recipes/pile-philpaper-refine.yaml",
+ help="Path to the pipeline configuration file (YAML)",
+ )
+ parser.add_argument(
+ "--probe-results", type=str, help="Path to probe results file (YAML) containing operation speeds"
+ )
+ parser.add_argument("--optimize", action="store_true", help="Apply optimization strategies to the pipeline")
+
+ args = parser.parse_args()
+
+ # Get absolute path to config file
+ config_path = os.path.abspath(args.config)
+ print(f"Using config file: {config_path}")
+
+ # Load and process config
+    with open(config_path, "r") as f:
+        config = yaml.safe_load(f)
+
+ # Build initial AST
+ ast = PipelineAST()
+ ast.build_from_config(config)
+ print("\nOriginal Pipeline:")
+ print(ast.visualize())
+
+ # Apply optimization if requested
+ if args.optimize:
+ from data_juicer.core.optimizer.filter_fusion_strategy import (
+ FilterFusionStrategy,
+ )
+ from data_juicer.core.optimizer.mapper_fusion_strategy import (
+ MapperFusionStrategy,
+ )
+ from data_juicer.core.optimizer.optimizer import PipelineOptimizer
+
+ # Load probe results if provided
+ probe_results = None
+ if args.probe_results:
+ probe_path = os.path.abspath(args.probe_results)
+ print(f"\nUsing probe results from: {probe_path}")
+        with open(probe_path, "r") as f:
+            probe_results = yaml.safe_load(f)
+
+ # Create optimizer with filter fusion strategy
+ optimizer = PipelineOptimizer([FilterFusionStrategy(probe_results=probe_results), MapperFusionStrategy()])
+
+ # Apply optimization
+ optimized_ast = optimizer.optimize(ast)
+
+ print("\nOptimized Pipeline:")
+ print(optimized_ast.visualize())
diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py
index f350f21741..1f1f87e569 100644
--- a/data_juicer/ops/base_op.py
+++ b/data_juicer/ops/base_op.py
@@ -497,14 +497,14 @@ def compute_stats_batched(self, samples, *args, **kwargs):
def process_batched(self, samples):
return map(lambda stat: self.process_single({Fields.stats: stat}), samples[Fields.stats])
- def compute_stats_single(self, sample, context=False):
+ def compute_stats_single(self, sample, *args, **kwargs):
"""
Compute stats for the sample which is used as a metric to decide
whether to filter this sample.
:param sample: input sample.
- :param context: whether to store context information of intermediate
- vars in the sample temporarily.
+ :param args: additional positional arguments
+ :param kwargs: additional keyword arguments (e.g., context=False, rank=None)
:return: sample with computed stats
"""
raise NotImplementedError
diff --git a/data_juicer/ops/filter/character_repetition_filter.py b/data_juicer/ops/filter/character_repetition_filter.py
index 80c2dc0c78..bc77c4caa6 100644
--- a/data_juicer/ops/filter/character_repetition_filter.py
+++ b/data_juicer/ops/filter/character_repetition_filter.py
@@ -43,7 +43,27 @@ def __init__(self, rep_len: PositiveInt = 10, min_ratio: float = 0.0, max_ratio:
self.min_ratio = min_ratio
self.max_ratio = max_ratio
- def compute_stats_batched(self, samples):
+ def _compute_char_rep_ratio(self, text):
+ """Compute character repetition ratio for a given text."""
+ char_ngrams = [text[i : i + self.n] for i in range(len(text) - self.n + 1)]
+ freq_char_ngrams = {}
+ for char_ngram in char_ngrams:
+ freq_char_ngrams[char_ngram] = freq_char_ngrams.get(char_ngram, 0) + 1
+
+ if len(freq_char_ngrams) == 0:
+ return 0.0
+
+ freq_char_ngrams = sorted(list(freq_char_ngrams.values()), reverse=True)
+ num_no_rep_char_ngrams = len([el for el in freq_char_ngrams if el == 1])
+ num_rep_char_ngrams = min(
+ int(np.sqrt(len(freq_char_ngrams))),
+ len(freq_char_ngrams) - num_no_rep_char_ngrams,
+ )
+ return (
+ (sum(freq_char_ngrams[:num_rep_char_ngrams]) / sum(freq_char_ngrams)) if sum(freq_char_ngrams) != 0 else 0.0
+ )
+
+ def compute_stats_batched(self, samples, *args, **kwargs):
samples_list = samples[self.text_key]
samples_stats = samples[Fields.stats]
diff --git a/data_juicer/ops/filter/language_id_score_filter.py b/data_juicer/ops/filter/language_id_score_filter.py
index b41d705b5b..3a6ab262fb 100644
--- a/data_juicer/ops/filter/language_id_score_filter.py
+++ b/data_juicer/ops/filter/language_id_score_filter.py
@@ -42,7 +42,7 @@ def __init__(self, lang: Union[str, List[str]] = "", min_score: float = 0.8, *ar
self.min_score = min_score
self.model_key = prepare_model(model_type="fasttext")
- def compute_stats_single(self, sample):
+ def compute_stats_single(self, sample, *args, **kwargs):
# check if it's computed already
if StatsKeys.lang in sample[Fields.stats] and StatsKeys.lang_score in sample[Fields.stats]:
return sample
diff --git a/data_juicer/ops/filter/text_length_filter.py b/data_juicer/ops/filter/text_length_filter.py
index 5e7c218da8..0e78786d1b 100644
--- a/data_juicer/ops/filter/text_length_filter.py
+++ b/data_juicer/ops/filter/text_length_filter.py
@@ -34,7 +34,7 @@ def __init__(self, min_len: int = 10, max_len: int = sys.maxsize, *args, **kwarg
self.min_len = min_len
self.max_len = max_len
- def compute_stats_batched(self, samples):
+ def compute_stats_batched(self, samples, *args, **kwargs):
samples_list = samples[self.text_key]
samples_stats = samples[Fields.stats]
for i, stat in enumerate(samples_stats):
diff --git a/data_juicer/ops/filter/words_num_filter.py b/data_juicer/ops/filter/words_num_filter.py
index b6b7fe7f36..f8aa1cfe81 100644
--- a/data_juicer/ops/filter/words_num_filter.py
+++ b/data_juicer/ops/filter/words_num_filter.py
@@ -56,7 +56,7 @@ def __init__(
if tokenization:
self.model_key = prepare_model(model_type="sentencepiece", lang=lang)
- def compute_stats_batched(self, samples, context=False):
+ def compute_stats_batched(self, samples, *args, **kwargs):
samples_list = samples[self.text_key]
samples_stats = samples[Fields.stats]
words_key = f"{InterVars.words}-{self.model_key}"
@@ -65,6 +65,7 @@ def compute_stats_batched(self, samples, context=False):
# check if it's computed already
if StatsKeys.num_words in stat:
continue
+ context = kwargs.get("context", False)
if context and words_key in samples[Fields.context][idx]:
words = samples[Fields.context][idx][words_key]
else:
diff --git a/data_juicer/ops/op_fusion.py b/data_juicer/ops/op_fusion.py
index a4a8967f3f..49fe59036d 100644
--- a/data_juicer/ops/op_fusion.py
+++ b/data_juicer/ops/op_fusion.py
@@ -263,6 +263,27 @@ def process_batched(self, samples, rank=None):
_ = tmp_samples.pop(Fields.context)
return tmp_samples
+ def process_batched_for_validation(self, samples, rank=None):
+ """Process samples and return boolean masks for validation purposes."""
+ # Initialize mask as all True
+ mask = [True] * len(samples.get("text", []))
+
+ for op in self.fused_ops:
+ process_args = {"rank": rank} if op.accelerator == "cuda" else {}
+ if isinstance(op, Mapper):
+ samples = op.process_batched(samples, **process_args)
+ elif isinstance(op, Filter):
+ samples = op.compute_stats_batched(samples, **process_args)
+ indicators = list(op.process_batched(samples))
+ # Apply AND logic with the current mask
+ mask = [m and i for m, i in zip(mask, indicators)]
+ else:
+ raise NotImplementedError(
+ f"FusedOP does not support OP {op._name} of type "
+ f"{type(op)} and only supports Mapper and Filter now."
+ )
+ return mask
+
def run(self, dataset, *, exporter=None, tracer=None):
# prepare the dataset
from data_juicer.core.data import NestedDataset
diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py
index 9f38ac3ba0..c7e65146be 100644
--- a/data_juicer/utils/model_utils.py
+++ b/data_juicer/utils/model_utils.py
@@ -117,6 +117,9 @@ def check_model(model_name, force=False):
except: # noqa: E722
backup_model_link = get_backup_model_link(model_name)
if backup_model_link is not None:
+ # Ensure backup_model_link is a string, not bytes
+ if isinstance(backup_model_link, bytes):
+ backup_model_link = backup_model_link.decode("utf-8")
backup_model_link = os.path.join(backup_model_link, model_name)
try:
wget.download(backup_model_link, cached_model_path)
@@ -494,7 +497,13 @@ def prepare_kenlm_model(lang, name_pattern="{}.arpa.bin", **model_params):
model_name = name_pattern.format(lang)
- logger.info("Loading kenlm language model...")
+ # Add stack trace to see where this is being called from
+ import traceback
+
+ stack_trace = traceback.format_stack()[-3:] # Last 3 frames
+ logger.info(f"Loading kenlm language model... (lang={lang}, model_name={model_name})")
+ logger.debug(f"Call stack:\n{''.join(stack_trace)}")
+
try:
kenlm_model = kenlm.Model(check_model(model_name), **model_params)
except: # noqa: E722
@@ -638,7 +647,13 @@ def prepare_sentencepiece_model(model_path, **model_params):
:param model_path: input model path
:return: model instance
"""
- logger.info("Loading sentencepiece model...")
+ # Add stack trace to see where this is being called from
+ import traceback
+
+ stack_trace = traceback.format_stack()[-3:] # Last 3 frames
+ logger.info(f"Loading sentencepiece model... (model_path={model_path})")
+ logger.debug(f"Call stack:\n{''.join(stack_trace)}")
+
sentencepiece_model = sentencepiece.SentencePieceProcessor()
try:
sentencepiece_model.load(check_model(model_path))
diff --git a/pyproject.toml b/pyproject.toml
index e2b5439ad1..302daa4917 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -202,8 +202,10 @@ extend-ignore = [
"E203", # whitespace before ':' (black handles this)
"E501", # line too long (black handles this)
"BLK100", # black would make changes (black handles this)
+ "F541",
]
+
[tool.black]
line-length = 120
target-version = ['py310']
diff --git a/tests/benchmark_performance/configs/text-c4-local.yaml b/tests/benchmark_performance/configs/text-c4-local.yaml
new file mode 100644
index 0000000000..6db0c31ef6
--- /dev/null
+++ b/tests/benchmark_performance/configs/text-c4-local.yaml
@@ -0,0 +1,58 @@
+project_name: 'performance-benchmark-text-c4-local'
+dataset_path: 'perf_bench_data/text/c4-train.00000-of-01024.jsonl'
+export_path: 'outputs/performance_benchmark_text_c4_local/res.jsonl'
+np: 16
+use_cache: false
+
+# Process pipeline with real DataJuicer operations
+process:
+ # Text cleaning operations
+ - clean_links_mapper:
+ text_key: "text"
+ min_links: 0
+ max_links: 10
+
+ - clean_email_mapper:
+ text_key: "text"
+ min_emails: 0
+ max_emails: 5
+
+ - whitespace_normalization_mapper:
+ text_key: "text"
+
+ - fix_unicode_mapper:
+ text_key: "text"
+
+ # Text filtering operations
+ - text_length_filter:
+ text_key: "text"
+ min_len: 5
+ max_len: 10000
+
+ - alphanumeric_filter:
+ text_key: "text"
+ min_ratio: 0.1
+
+ # Quality filtering
+ - character_repetition_filter:
+ text_key: "text"
+ min_ratio: 0.0
+ max_ratio: 0.5
+
+ - word_repetition_filter:
+ text_key: "text"
+ min_ratio: 0.0
+ max_ratio: 0.5
+
+ # Deduplication
+ - document_deduplicator:
+ text_key: "text"
+ lowercase: false
+ ignore_non_character: false
+
+
+
+# Export configuration
+export_in_parallel: true
+keep_stats_in_res_ds: true
+keep_hashes_in_res_ds: true
diff --git a/tests/benchmark_performance/configs/text-c4-ray.yaml b/tests/benchmark_performance/configs/text-c4-ray.yaml
new file mode 100644
index 0000000000..ec4c09aabb
--- /dev/null
+++ b/tests/benchmark_performance/configs/text-c4-ray.yaml
@@ -0,0 +1,63 @@
+project_name: 'performance-benchmark-text-c4-ray'
+dataset_path: 'perf_bench_data/text/c4-train.00000-of-01024.jsonl'
+export_path: 'outputs/performance_benchmark_text_c4_ray/res.jsonl'
+use_cache: false
+
+executor_type: ray
+ray_address: "auto"
+
+# Process pipeline with real DataJuicer operations
+process:
+ # Text cleaning operations
+ - clean_links_mapper:
+ text_key: "text"
+ min_links: 0
+ max_links: 10
+
+ - clean_email_mapper:
+ text_key: "text"
+ min_emails: 0
+ max_emails: 5
+
+ - whitespace_normalization_mapper:
+ text_key: "text"
+
+ - fix_unicode_mapper:
+ text_key: "text"
+
+ # Text filtering operations
+ - text_length_filter:
+ text_key: "text"
+ min_len: 5
+ max_len: 10000
+
+ - alphanumeric_filter:
+ text_key: "text"
+ min_ratio: 0.1
+
+ # Quality filtering
+ - character_repetition_filter:
+ text_key: "text"
+ min_ratio: 0.0
+ max_ratio: 0.5
+
+ - word_repetition_filter:
+ text_key: "text"
+ min_ratio: 0.0
+ max_ratio: 0.5
+
+ # Deduplication
+ - ray_bts_minhash_deduplicator:
+ text_key: "text"
+ tokenization: "space"
+ window_size: 5
+ lowercase: false
+ jaccard_threshold: 0.7
+ num_permutations: 256
+
+
+
+# Export configuration
+export_in_parallel: true
+keep_stats_in_res_ds: true
+keep_hashes_in_res_ds: true
diff --git a/tests/core/optimizer/test_op_fusion_strategy.py b/tests/core/optimizer/test_op_fusion_strategy.py
new file mode 100644
index 0000000000..1cfae7b67c
--- /dev/null
+++ b/tests/core/optimizer/test_op_fusion_strategy.py
@@ -0,0 +1,120 @@
+import unittest
+from unittest.mock import Mock, patch
+
+from data_juicer.core.optimizer.filter_fusion_strategy import FilterFusionStrategy
+from data_juicer.core.pipeline_ast import PipelineAST, OpNode, OpType
+
+class TestFilterFusionStrategy(unittest.TestCase):
+ def setUp(self):
+ self.strategy = FilterFusionStrategy()
+ self.ast = PipelineAST()
+
+ # Sample probe results
+ self.probe_results = {
+ 'language_id_score_filter': {'speed': 0.5},
+ 'clean_copyright_mapper': {'speed': 0.3},
+ 'alphanumeric_filter': {'speed': 0.4}
+ }
+
+ # Create a sample pipeline configuration
+ self.config = {
+ 'process': [
+ {
+ 'name': 'clean_copyright_mapper',
+ 'type': 'mapper',
+ 'config': {'key': 'value1'}
+ },
+ {
+ 'name': 'language_id_score_filter',
+ 'type': 'filter',
+ 'config': {'key': 'value2'}
+ },
+ {
+ 'name': 'alphanumeric_filter',
+ 'type': 'filter',
+ 'config': {'key': 'value3'}
+ }
+ ]
+ }
+
+ def test_optimize_single_filter(self):
+ """Test optimization with a single filter."""
+ # Build AST with single filter
+ config = {
+ 'process': [
+ {
+ 'name': 'language_id_score_filter',
+ 'type': 'filter',
+ 'config': {'key': 'value'}
+ }
+ ]
+ }
+ self.ast.build_from_config(config)
+
+ # Apply optimization
+ optimized_ast = self.strategy.optimize(self.ast)
+
+ # Verify the filter remains unchanged
+ chain = self.strategy._get_operation_chain(optimized_ast.root)
+ self.assertEqual(len(chain), 1)
+ self.assertEqual(chain[0].name, 'language_id_score_filter')
+
+ def test_optimize_multiple_filters(self):
+ """Test optimization with multiple filters."""
+ # Build AST with multiple filters
+ self.ast.build_from_config(self.config)
+
+ # Apply optimization
+ optimized_ast = self.strategy.optimize(self.ast)
+
+ # Verify filters are fused
+ chain = self.strategy._get_operation_chain(optimized_ast.root)
+ self.assertEqual(len(chain), 2) # mapper + fused filters
+
+ # Check that the fused node contains both filters
+ fused_node = chain[1]
+ self.assertTrue(fused_node.name.startswith('fused_'))
+ self.assertEqual(len(fused_node.original_ops), 2)
+
+ def test_optimize_with_probe_results(self):
+ """Test optimization with probe results for speed-based sorting."""
+ strategy = FilterFusionStrategy(probe_results=self.probe_results)
+ self.ast.build_from_config(self.config)
+
+ # Apply optimization
+ optimized_ast = strategy.optimize(self.ast)
+
+ # Verify filters are fused and sorted by speed
+ chain = strategy._get_operation_chain(optimized_ast.root)
+ fused_node = chain[1]
+
+ # Check that filters are sorted by speed
+ original_ops = fused_node.original_ops
+ self.assertEqual(original_ops[0].name, 'language_id_score_filter') # speed: 0.5
+ self.assertEqual(original_ops[1].name, 'alphanumeric_filter') # speed: 0.4
+
+ def test_optimize_empty_pipeline(self):
+ """Test optimization with an empty pipeline."""
+ optimized_ast = self.strategy.optimize(self.ast)
+ self.assertIsNone(optimized_ast.root)
+
+ def test_create_fused_filter_node(self):
+ """Test creation of a fused filter node."""
+ # Create sample filter nodes
+ filter1 = OpNode('filter1', OpType.FILTER, {'key1': 'value1'})
+ filter2 = OpNode('filter2', OpType.FILTER, {'key2': 'value2'})
+
+ # Create fused node
+ fused_node = self.strategy._create_fused_filter_node(
+ 'fused_filters',
+ [filter1, filter2]
+ )
+
+ # Verify fused node properties
+ self.assertEqual(fused_node.name, 'fused_filters')
+ self.assertEqual(fused_node.op_type, OpType.FILTER)
+ self.assertEqual(fused_node.config, {'key1': 'value1', 'key2': 'value2'})
+ self.assertEqual(fused_node.original_ops, [filter1, filter2])
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/core/optimizer/test_optimizer.py b/tests/core/optimizer/test_optimizer.py
new file mode 100644
index 0000000000..03b69652bd
--- /dev/null
+++ b/tests/core/optimizer/test_optimizer.py
@@ -0,0 +1,123 @@
+import unittest
+from typing import List, Dict, Any
+from unittest.mock import Mock, patch
+
+from data_juicer.core.pipeline_ast import PipelineAST, OpNode, OpType
+from data_juicer.core.optimizer.optimizer import PipelineOptimizer
+from data_juicer.core.optimizer.strategy import OptimizationStrategy
+from data_juicer.core.optimizer.filter_fusion_strategy import FilterFusionStrategy
+
+class MockStrategy(OptimizationStrategy):
+ """Mock strategy for testing."""
+
+ def __init__(self, name: str = "mock_strategy"):
+ super().__init__(name)
+ self.optimize_called = False
+
+ def optimize(self, ast: PipelineAST) -> PipelineAST:
+ self.optimize_called = True
+ return ast
+
+class TestPipelineOptimizer(unittest.TestCase):
+ def setUp(self):
+ self.ast = PipelineAST()
+ self.config = {
+ 'process': [
+ {
+ 'name': 'clean_copyright_mapper',
+ 'type': 'mapper',
+ 'config': {'key': 'value1'}
+ },
+ {
+ 'name': 'language_id_score_filter',
+ 'type': 'filter',
+ 'config': {'key': 'value2'}
+ },
+ {
+ 'name': 'alphanumeric_filter',
+ 'type': 'filter',
+ 'config': {'key': 'value3'}
+ }
+ ]
+ }
+ self.ast.build_from_config(self.config)
+
+ def test_init_default_strategies(self):
+ """Test initialization with default strategies."""
+ optimizer = PipelineOptimizer()
+ self.assertEqual(len(optimizer.strategies), 1)
+ self.assertIsInstance(optimizer.strategies[0], FilterFusionStrategy)
+
+ def test_init_custom_strategies(self):
+ """Test initialization with custom strategies."""
+ strategies = [MockStrategy("strategy1"), MockStrategy("strategy2")]
+ optimizer = PipelineOptimizer(strategies)
+ self.assertEqual(len(optimizer.strategies), 2)
+ self.assertEqual(optimizer.strategies, strategies)
+
+ def test_add_strategy(self):
+ """Test adding a new strategy."""
+ optimizer = PipelineOptimizer()
+ strategy = MockStrategy()
+ optimizer.add_strategy(strategy)
+ self.assertEqual(len(optimizer.strategies), 2)
+ self.assertIn(strategy, optimizer.strategies)
+
+ def test_remove_strategy(self):
+ """Test removing a strategy by name."""
+ optimizer = PipelineOptimizer()
+ strategy = MockStrategy("test_strategy")
+ optimizer.add_strategy(strategy)
+ optimizer.remove_strategy("test_strategy")
+ self.assertEqual(len(optimizer.strategies), 1)
+ self.assertNotIn(strategy, optimizer.strategies)
+
+ def test_optimize_empty_pipeline(self):
+ """Test optimization of an empty pipeline."""
+ optimizer = PipelineOptimizer()
+ empty_ast = PipelineAST()
+ optimized_ast = optimizer.optimize(empty_ast)
+ self.assertIsNone(optimized_ast.root)
+
+ def test_optimize_with_multiple_strategies(self):
+ """Test optimization with multiple strategies."""
+ strategy1 = MockStrategy("strategy1")
+ strategy2 = MockStrategy("strategy2")
+ optimizer = PipelineOptimizer([strategy1, strategy2])
+
+ optimized_ast = optimizer.optimize(self.ast)
+
+ self.assertTrue(strategy1.optimize_called)
+ self.assertTrue(strategy2.optimize_called)
+ self.assertIsNotNone(optimized_ast.root)
+
+ def test_get_strategy(self):
+ """Test getting a strategy by name."""
+ strategy = MockStrategy("test_strategy")
+ optimizer = PipelineOptimizer([strategy])
+
+ found_strategy = optimizer.get_strategy("test_strategy")
+ self.assertEqual(found_strategy, strategy)
+
+ not_found = optimizer.get_strategy("nonexistent")
+ self.assertIsNone(not_found)
+
+ def test_get_strategy_names(self):
+ """Test getting names of all strategies."""
+ strategies = [
+ MockStrategy("strategy1"),
+ MockStrategy("strategy2")
+ ]
+ optimizer = PipelineOptimizer(strategies)
+
+ names = optimizer.get_strategy_names()
+ self.assertEqual(names, ["strategy1", "strategy2"])
+
+ def test_clear_strategies(self):
+ """Test clearing all strategies."""
+ optimizer = PipelineOptimizer()
+ optimizer.clear_strategies()
+ self.assertEqual(len(optimizer.strategies), 0)
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/core/test_pipeline_ast.py b/tests/core/test_pipeline_ast.py
new file mode 100644
index 0000000000..2d4f0ed5dd
--- /dev/null
+++ b/tests/core/test_pipeline_ast.py
@@ -0,0 +1,213 @@
+import unittest
+from data_juicer.core.pipeline_ast import PipelineAST, OpType
+
+class TestPipelineAST(unittest.TestCase):
+ def setUp(self):
+ self.ast = PipelineAST()
+
+ def test_build_from_config(self):
+ config = {
+ 'process': [
+ {'language_id_score_filter': {'lang': 'zh', 'min_score': 0.8}},
+ {'clean_copyright_mapper': {}},
+ {'alphanumeric_filter': {'min_ratio': 0.25, 'max_ratio': 1.0}}
+ ]
+ }
+
+ self.ast.build_from_config(config)
+
+ # Test root node
+ self.assertIsNotNone(self.ast.root)
+ self.assertEqual(self.ast.root.name, "root")
+ self.assertEqual(self.ast.root.op_type, OpType.MAPPER)
+
+ # Test operation chain
+ chain = self.ast.get_operation_chain()
+ self.assertEqual(len(chain), 3)
+
+ # Test first operation
+ self.assertEqual(chain[0].name, "language_id_score_filter")
+ self.assertEqual(chain[0].op_type, OpType.FILTER)
+ self.assertEqual(chain[0].config, {'lang': 'zh', 'min_score': 0.8})
+
+ # Test second operation
+ self.assertEqual(chain[1].name, "clean_copyright_mapper")
+ self.assertEqual(chain[1].op_type, OpType.MAPPER)
+ self.assertEqual(chain[1].config, {})
+
+ # Test third operation
+ self.assertEqual(chain[2].name, "alphanumeric_filter")
+ self.assertEqual(chain[2].op_type, OpType.FILTER)
+ self.assertEqual(chain[2].config, {'min_ratio': 0.25, 'max_ratio': 1.0})
+
+ def test_visualize(self):
+ config = {
+ 'process': [
+ {'language_id_score_filter': {'lang': 'zh', 'min_score': 0.8}},
+ {'clean_copyright_mapper': {}}
+ ]
+ }
+
+ self.ast.build_from_config(config)
+ visualization = self.ast.visualize()
+
+ expected = """Pipeline:
+โโโ root (mapper)
+ โโโ clean_copyright_mapper (mapper)
+ โโโ language_id_score_filter (filter)
+"""
+ self.assertEqual(visualization, expected)
+
+ def test_to_dict(self):
+ config = {
+ 'process': [
+ {'language_id_score_filter': {'lang': 'zh', 'min_score': 0.8}},
+ {'clean_copyright_mapper': {}}
+ ]
+ }
+
+ self.ast.build_from_config(config)
+ ast_dict = self.ast.to_dict()
+
+ expected = {
+ 'name': 'root',
+ 'type': 'mapper',
+ 'config': {},
+ 'children': [{
+ 'name': 'language_id_score_filter',
+ 'type': 'filter',
+ 'config': {'lang': 'zh', 'min_score': 0.8},
+ 'children': [{
+ 'name': 'clean_copyright_mapper',
+ 'type': 'mapper',
+ 'config': {},
+ 'children': []
+ }]
+ }]
+ }
+ self.assertEqual(ast_dict, expected)
+
+ def test_empty_config(self):
+ config = {'process': []}
+ self.ast.build_from_config(config)
+ self.assertIsNone(self.ast.root)
+ self.assertEqual(self.ast.visualize(), "Empty pipeline")
+
+ def test_invalid_config(self):
+ config = {}
+ with self.assertRaises(ValueError):
+ self.ast.build_from_config(config)
+
+ def test_validate_dependencies(self):
+ config = {
+ 'process': [
+ {'clean_copyright_mapper': {}},
+ {'language_id_score_filter': {'lang': 'zh', 'min_score': 0.8}},
+ {'document_deduplicator': {}},
+ {'text_length_filter': {'min_len': 10, 'max_len': 1000}}
+ ]
+ }
+
+ self.ast.build_from_config(config)
+ invalid_deps = self.ast._validate_dependencies()
+ self.assertEqual(len(invalid_deps), 0) # This pipeline should be valid
+
+ # Test invalid pipeline
+ invalid_config = {
+ 'process': [
+ {'document_deduplicator': {}}, # Deduplicator before mapper
+ {'clean_copyright_mapper': {}}
+ ]
+ }
+
+ self.ast.build_from_config(invalid_config)
+ invalid_deps = self.ast._validate_dependencies()
+ self.assertGreater(len(invalid_deps), 0) # Should have invalid dependencies
+
+ def test_optimize_operation_order(self):
+ config = {
+ 'process': [
+ {'text_length_filter': {'min_len': 10, 'max_len': 1000}},
+ {'clean_copyright_mapper': {}},
+ {'language_id_score_filter': {'lang': 'zh', 'min_score': 0.8}},
+ {'document_deduplicator': {}}
+ ]
+ }
+
+ self.ast.build_from_config(config)
+ self.ast._optimize_operation_order()
+
+ chain = self.ast.get_operation_chain()
+ self.assertEqual(len(chain), 4)
+
+ # Check order: Mapper -> Filter -> Deduplicator
+ self.assertEqual(chain[0].name, "clean_copyright_mapper")
+ self.assertEqual(chain[0].op_type, OpType.MAPPER)
+
+ self.assertEqual(chain[1].name, "text_length_filter")
+ self.assertEqual(chain[1].op_type, OpType.FILTER)
+
+ self.assertEqual(chain[2].name, "language_id_score_filter")
+ self.assertEqual(chain[2].op_type, OpType.FILTER)
+
+ self.assertEqual(chain[3].name, "document_deduplicator")
+ self.assertEqual(chain[3].op_type, OpType.DEDUPLICATOR)
+
+ def test_merge_compatible_operations(self):
+ config = {
+ 'process': [
+ {'text_length_filter': {'min_len': 10, 'max_len': 1000}},
+ {'language_id_score_filter': {'lang': 'zh', 'min_score': 0.8}},
+ {'clean_copyright_mapper': {}},
+ {'text_clean_mapper': {}}
+ ]
+ }
+
+ self.ast.build_from_config(config)
+ self.ast._merge_compatible_operations()
+
+ chain = self.ast.get_operation_chain()
+ self.assertEqual(len(chain), 3) # Two filters should be merged
+
+ # Check merged operations
+ self.assertEqual(chain[0].name, "merged_text_length_filter_language_id_score_filter")
+ self.assertEqual(chain[0].op_type, OpType.FILTER)
+ self.assertEqual(chain[0].config, {
+ 'min_len': 10,
+ 'max_len': 1000,
+ 'lang': 'zh',
+ 'min_score': 0.8
+ })
+
+ self.assertEqual(chain[1].name, "merged_clean_copyright_mapper_text_clean_mapper")
+ self.assertEqual(chain[1].op_type, OpType.MAPPER)
+
+ def test_full_optimization(self):
+ config = {
+ 'process': [
+ {'text_length_filter': {'min_len': 10, 'max_len': 1000}},
+ {'language_id_score_filter': {'lang': 'zh', 'min_score': 0.8}},
+ {'clean_copyright_mapper': {}},
+ {'text_clean_mapper': {}},
+ {'document_deduplicator': {}}
+ ]
+ }
+
+ self.ast.build_from_config(config)
+ self.ast.optimize()
+
+ chain = self.ast.get_operation_chain()
+ self.assertEqual(len(chain), 3) # Two filters and two mappers should be merged
+
+ # Check final order and merged operations
+ self.assertEqual(chain[0].name, "merged_clean_copyright_mapper_text_clean_mapper")
+ self.assertEqual(chain[0].op_type, OpType.MAPPER)
+
+ self.assertEqual(chain[1].name, "merged_text_length_filter_language_id_score_filter")
+ self.assertEqual(chain[1].op_type, OpType.FILTER)
+
+ self.assertEqual(chain[2].name, "document_deduplicator")
+ self.assertEqual(chain[2].op_type, OpType.DEDUPLICATOR)
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tools/clear_cache.py b/tools/clear_cache.py
new file mode 100644
index 0000000000..bf5df9de65
--- /dev/null
+++ b/tools/clear_cache.py
@@ -0,0 +1,182 @@
+#!/usr/bin/env python3
+"""
+Cache clearing script for Data-Juicer.
+Clears all types of caches to ensure fresh model loading.
+"""
+
+import gc
+import os
+import shutil
+
+
+def clear_data_juicer_cache():
+ """Clear all Data-Juicer related caches."""
+ print("๐งน Clearing Data-Juicer caches...")
+
+ # Clear model cache from memory
+ try:
+ from data_juicer.utils.model_utils import free_models
+
+ free_models(clear_model_zoo=True)
+ print("โ
Cleared model cache from memory")
+ except Exception as e:
+ print(f"โ ๏ธ Could not clear model cache from memory: {e}")
+
+ # Clear downloaded model files
+ try:
+ from data_juicer.utils.cache_utils import DATA_JUICER_MODELS_CACHE
+
+        if os.path.exists(DATA_JUICER_MODELS_CACHE):
+            shutil.rmtree(DATA_JUICER_MODELS_CACHE)
+            print(f"Cleared downloaded models: {DATA_JUICER_MODELS_CACHE}")
+        else:
+            print("No downloaded models cache found")
+    except Exception as e:
+        print(f"Could not clear downloaded models: {e}")
+
+    # Clear assets cache
+    try:
+        from data_juicer.utils.cache_utils import DATA_JUICER_ASSETS_CACHE
+
+        if os.path.exists(DATA_JUICER_ASSETS_CACHE):
+            shutil.rmtree(DATA_JUICER_ASSETS_CACHE)
+            print(f"Cleared assets cache: {DATA_JUICER_ASSETS_CACHE}")
+        else:
+            print("No assets cache found")
+    except Exception as e:
+        print(f"Could not clear assets cache: {e}")
+
+
+def clear_huggingface_cache():
+ """Clear HuggingFace cache."""
+ print("๐ค Clearing HuggingFace cache...")
+
+ try:
+ from transformers import TRANSFORMERS_CACHE
+
+ if os.path.exists(TRANSFORMERS_CACHE):
+ shutil.rmtree(TRANSFORMERS_CACHE)
+ print(f"โ
Cleared HuggingFace cache: {TRANSFORMERS_CACHE}")
+ else:
+ print("โน๏ธ No HuggingFace cache found")
+ except Exception as e:
+ print(f"โ ๏ธ Could not clear HuggingFace cache: {e}")
+
+
+def clear_nltk_cache():
+ """Clear NLTK cache."""
+ print("๐ Clearing NLTK cache...")
+
+ try:
+ from data_juicer.utils.nltk_utils import clean_nltk_cache
+
+ clean_nltk_cache(complete_reset=True)
+ print("โ
Cleared NLTK cache")
+ except Exception as e:
+ print(f"โ ๏ธ Could not clear NLTK cache: {e}")
+
+
+def clear_python_cache():
+ """Clear Python cache files."""
+ print("๐ Clearing Python cache...")
+
+ # Clear __pycache__ directories
+ cache_dirs = []
+ for root, dirs, files in os.walk("."):
+ for dir_name in dirs:
+ if dir_name == "__pycache__":
+ cache_path = os.path.join(root, dir_name)
+ cache_dirs.append(cache_path)
+ try:
+ shutil.rmtree(cache_path)
+ except Exception as e:
+ print(f"โ ๏ธ Could not clear {cache_path}: {e}")
+
+ if cache_dirs:
+ print(f"โ
Cleared {len(cache_dirs)} Python cache directories")
+ else:
+ print("โน๏ธ No Python cache directories found")
+
+
+def clear_system_cache():
+ """Clear system-level caches."""
+ print("๐ป Clearing system caches...")
+
+    # Clear macOS system caches (only when running on macOS)
+    if hasattr(os, "uname") and os.uname().sysname == "Darwin":
+ try:
+ # Clear various macOS caches
+ cache_paths = [
+ os.path.expanduser("~/Library/Caches"),
+ "/System/Library/Caches",
+ ]
+
+ for cache_path in cache_paths:
+ if os.path.exists(cache_path):
+ # Only clear specific subdirectories to avoid system issues
+ for item in os.listdir(cache_path):
+ item_path = os.path.join(cache_path, item)
+ if os.path.isdir(item_path) and "python" in item.lower():
+ try:
+ shutil.rmtree(item_path)
+ print(f"โ
Cleared system cache: {item_path}")
+ except Exception:
+ pass # Skip if we can't clear it
+ except Exception as e:
+ print(f"โ ๏ธ Could not clear system cache: {e}")
+
+
+def force_garbage_collection():
+ """Force garbage collection to free memory."""
+ print("๐๏ธ Running garbage collection...")
+
+ # Force garbage collection
+ gc.collect()
+
+ # Clear any remaining references
+ import sys
+
+ for module_name in list(sys.modules.keys()):
+ if module_name.startswith("data_juicer") or "transformers" in module_name:
+ try:
+ del sys.modules[module_name]
+ except Exception:
+ pass
+
+ # Force another garbage collection
+ gc.collect()
+ print("โ
Garbage collection completed")
+
+
+def main():
+ """Main function to clear all caches."""
+ print("๐ Starting comprehensive cache clearing...")
+ print("=" * 50)
+
+ # Clear all types of caches
+ clear_data_juicer_cache()
+ print()
+
+ clear_huggingface_cache()
+ print()
+
+ clear_nltk_cache()
+ print()
+
+ clear_python_cache()
+ print()
+
+ clear_system_cache()
+ print()
+
+ force_garbage_collection()
+ print()
+
+ print("=" * 50)
+ print("โ
Cache clearing completed!")
+ print("\n๐ก Next time you run the benchmark, models will be loaded fresh from disk.")
+ print(" This should eliminate the caching speed difference between runs.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/optimizer_perf_test/README_pipeline_perf_test.md b/tools/optimizer_perf_test/README_pipeline_perf_test.md
new file mode 100644
index 0000000000..c7274ddd03
--- /dev/null
+++ b/tools/optimizer_perf_test/README_pipeline_perf_test.md
@@ -0,0 +1,312 @@
+# Pipeline Performance Test
+
+This directory contains a comprehensive performance testing framework for comparing individual pipeline execution vs optimized pipeline execution in Data-Juicer.
+
+## Overview
+
+The pipeline performance test compares two execution modes:
+1. **Individual Pipeline**: Operations are executed one by one without optimization
+2. **Optimized Pipeline**: Operations are fused and optimized using the pipeline optimizer
+
+## Files
+
+- `perf_test_pipeline_comparison.py`: Main performance test script
+- `run_pipeline_perf_test.sh`: Convenient shell script wrapper
+- `README_pipeline_perf_test.md`: This documentation file
+
+## Features
+
+- **Separate Process Execution**: Each mode runs in its own process for fair comparison
+- **Recipe Support**: Works with any Data-Juicer recipe YAML file
+- **Dataset Support**: Supports various dataset formats
+- **Comprehensive Metrics**: Collects execution time, memory usage, throughput
+- **Result Validation**: Ensures both modes produce equivalent results
+- **Detailed Reporting**: Generates markdown reports with performance analysis
+- **Two Execution Modes**:
+ - Separate process execution (default)
+ - Comprehensive benchmark framework
+
+## Usage
+
+### Using the Shell Script (Recommended)
+
+```bash
+# Basic usage
+./tools/optimizer_perf_test/run_pipeline_perf_test.sh \
+ -r configs/demo/analyzer.yaml \
+ -d demos/data/demo-dataset.jsonl
+
+# With verbose logging
+./tools/optimizer_perf_test/run_pipeline_perf_test.sh \
+ -r configs/demo/analyzer.yaml \
+ -d demos/data/demo-dataset.jsonl \
+ -v
+
+# Using comprehensive benchmark framework
+./tools/optimizer_perf_test/run_pipeline_perf_test.sh \
+ -r configs/demo/analyzer.yaml \
+ -d demos/data/demo-dataset.jsonl \
+ -b
+
+# Custom output directory
+./tools/optimizer_perf_test/run_pipeline_perf_test.sh \
+ -r configs/demo/analyzer.yaml \
+ -d demos/data/demo-dataset.jsonl \
+ -o ./my_results
+```
+
+### Using Python Script Directly
+
+```bash
+python tools/optimizer_perf_test/perf_test_pipeline_comparison.py \
+ --recipe-path configs/demo/analyzer.yaml \
+ --dataset-path demos/data/demo-dataset.jsonl \
+ --output-dir ./outputs/pipeline_perf_test \
+ --verbose
+```
+
+## Command Line Options
+
+### Shell Script Options
+
+- `-r, --recipe-path PATH`: Path to the recipe YAML file (required)
+- `-d, --dataset-path PATH`: Path to the dataset file (required)
+- `-o, --output-dir PATH`: Output directory (default: ./outputs/pipeline_perf_test)
+- `-b, --benchmark-framework`: Use comprehensive benchmark framework
+- `-v, --verbose`: Enable verbose logging
+- `-h, --help`: Show help message
+
+### Python Script Options
+
+- `--recipe-path`: Path to the recipe YAML file (required)
+- `--dataset-path`: Path to the dataset file (required)
+- `--output-dir`: Output directory for results and reports
+- `--use-benchmark-framework`: Use the comprehensive benchmark framework
+- `--verbose`: Enable verbose logging
+
+## Output
+
+The test generates several output files:
+
+### Results JSON (`results.json`)
+Contains detailed performance metrics and comparison data:
+```json
+{
+ "individual": {
+ "wall_time": 10.5,
+ "output_samples": 1000,
+ "success": true,
+ "error": null
+ },
+ "optimized": {
+ "wall_time": 8.2,
+ "output_samples": 1000,
+ "success": true,
+ "error": null
+ },
+ "comparison": {
+ "individual_time": 10.5,
+ "optimized_time": 8.2,
+ "speedup": 1.28,
+ "improvement_percent": 21.9,
+ "faster_mode": "optimized"
+ },
+ "validation": {
+ "samples_match": true,
+ "individual_samples": 1000,
+ "optimized_samples": 1000,
+ "sample_difference": 0,
+ "validation_passed": true
+ },
+ "metadata": {
+ "recipe_path": "configs/demo/analyzer.yaml",
+ "dataset_path": "demos/data/demo-dataset.jsonl",
+ "test_timestamp": "2024-01-15 10:30:00",
+ "use_benchmark_framework": false
+ }
+}
+```
+
+### Performance Report (`performance_report.md`)
+A comprehensive markdown report with:
+- Executive summary
+- Detailed results for each mode
+- Performance comparison
+- Validation results
+- Recommendations
+
+### Log File (`perf_test.log`)
+Detailed execution logs for debugging and analysis.
+
+## Execution Modes
+
+### 1. Separate Process Execution (Default)
+
+This mode runs each pipeline in a completely separate process (see the sketch after the pros and cons) to ensure:
+- No interference between executions
+- Fair resource allocation
+- Clean memory state for each run
+
+**Pros:**
+- Completely isolated execution
+- Fair comparison
+- No memory leaks between runs
+
+**Cons:**
+- Higher overhead due to process creation
+- Slower startup time
+
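+The isolation mechanism is plain `multiprocessing`; the snippet below is a condensed and
+slightly simplified version of what `_run_in_process` does (the `worker` argument stands in
+for the module-level `_worker_process` function used by the real script):
+
+```python
+import multiprocessing as mp
+import time
+
+
+def run_isolated(worker, config_path, mode, timeout=3600):
+    """Run one pipeline mode in its own process and collect its result."""
+    queue = mp.Queue()
+    proc = mp.Process(target=worker, args=(config_path, mode, queue))
+    start = time.time()
+    proc.start()
+    proc.join(timeout)                 # 1-hour timeout, as in the real script
+    if proc.is_alive():                # enforce the timeout
+        proc.terminate()
+        proc.join()
+        return {"success": False, "error": "Process timeout"}
+    result = queue.get() if not queue.empty() else {"success": False, "error": "No result from process"}
+    result["wall_time"] = time.time() - start
+    return result
+```
+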
+### 2. Comprehensive Benchmark Framework
+
+This mode uses the existing performance benchmark framework (sketched after the pros and cons), which provides:
+- More detailed metrics
+- Memory usage tracking
+- Throughput analysis
+- Resource utilization
+
+**Pros:**
+- More comprehensive metrics
+- Better integration with existing tools
+- Detailed resource analysis
+
+**Cons:**
+- Runs in the same process
+- May have interference between runs
+
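+For reference, the framework path boils down to roughly the following calls (condensed from
+`run_benchmark_comparison` in `perf_test_pipeline_comparison.py`; the recipe and dataset
+paths are placeholders):
+
+```python
+from argparse import Namespace
+
+from data_juicer.core.data.dataset_builder import DatasetBuilder
+from data_juicer.core.optimizer.filter_fusion_strategy import FilterFusionStrategy
+from data_juicer.core.optimizer.mapper_fusion_strategy import MapperFusionStrategy
+from data_juicer.core.optimizer.optimizer import PipelineOptimizer
+from data_juicer.core.optimizer.performance_benchmark import PerformanceBenchmark
+from data_juicer.core.pipeline_ast import PipelineAST
+
+dataset = DatasetBuilder(Namespace(dataset_path="demos/data/demo-dataset.jsonl")).load_dataset()
+
+benchmark = PerformanceBenchmark()
+ast = PipelineAST()
+ast.build_from_yaml("configs/demo/analyzer.yaml")
+
+insights = benchmark.get_analyzer_insights(dataset)
+optimizer = PipelineOptimizer([FilterFusionStrategy(), MapperFusionStrategy()], analyzer_insights=insights)
+optimized_ast = optimizer.optimize(ast)
+# the original and optimized operator lists are then executed and timed on the same dataset
+```
+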
+## Example Recipes
+
+### Simple Analyzer Recipe
+```yaml
+project_name: 'demo-analyzer'
+dataset_path: 'demos/data/demo-dataset.jsonl'
+export_path: 'outputs/demo-analyzer/res.jsonl'
+np: 1
+use_cache: false
+
+process:
+ - whitespace_normalization_mapper:
+ - token_num_filter:
+ hf_tokenizer: 'EleutherAI/pythia-6.9b-deduped'
+ min_num: 0
+ - document_deduplicator:
+ lowercase: false
+ ignore_non_character: false
+```
+
+### Complex Pipeline Recipe
+```yaml
+project_name: 'complex-pipeline'
+dataset_path: 'data/large-dataset.jsonl'
+export_path: 'outputs/complex-pipeline/res.jsonl'
+np: 4
+use_cache: true
+
+process:
+ - whitespace_normalization_mapper:
+ - token_num_filter:
+ hf_tokenizer: 'EleutherAI/pythia-6.9b-deduped'
+ min_num: 10
+ max_num: 1000
+ - document_deduplicator:
+ lowercase: true
+ ignore_non_character: true
+ - language_id_score_filter:
+ lang: 'en'
+ min_score: 0.8
+ - text_length_filter:
+ min_len: 50
+ max_len: 2000
+ - topk_specified_field_selector:
+ field_key: '__dj__stats__.num_token'
+ topk: 10000
+```
+
+## Interpreting Results
+
+### Performance Metrics
+
+- **Execution Time**: Wall clock time for each mode
+- **Speedup**: Ratio of individual time to optimized time
+- **Improvement**: Percentage reduction in execution time relative to the individual mode (see the sketch below)
+- **Throughput**: Samples processed per second
+
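+These figures follow directly from the two wall times; using the example numbers from
+`results.json` above:
+
+```python
+individual_time, optimized_time = 10.5, 8.2          # wall-clock seconds per mode
+speedup = individual_time / optimized_time            # 1.28x
+improvement = (individual_time - optimized_time) / individual_time * 100   # 21.9%
+throughput = 1000 / optimized_time                     # ~122 samples/s for the optimized run
+```
+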
+### Validation
+
+- **Samples Match**: Whether both modes produce the same number of output samples
+- **Sample Difference**: Absolute difference in output sample counts
+- **Validation Passed**: Overall validation status
+
+### Recommendations
+
+Based on the results, the test provides recommendations:
+
+- **Use Optimized Pipeline**: When optimized mode is faster and produces correct results
+- **Consider Individual Pipeline**: When individual mode is faster (optimization may still pay off on larger datasets)
+- **Both Modes Similar**: When performance is similar between modes
+- **Investigation Required**: When results don't match between modes
+
+## Troubleshooting
+
+### Common Issues
+
+1. **Import Errors**: Make sure you're running from the project root directory
+2. **File Not Found**: Verify recipe and dataset paths are correct
+3. **Permission Errors**: Ensure the shell script is executable (`chmod +x`)
+4. **Timeout Errors**: Very large datasets may exceed the 1-hour per-run timeout hard-coded in `_run_in_process`
+
+### Debug Mode
+
+Use the `--verbose` flag for detailed logging:
+```bash
+./tools/optimizer_perf_test/run_pipeline_perf_test.sh \
+ -r configs/demo/analyzer.yaml \
+ -d demos/data/demo-dataset.jsonl \
+ -v
+```
+
+### Manual Testing
+
+For debugging, you can run the Python script directly:
+```bash
+cd /path/to/data-juicer
+python tools/optimizer_perf_test/perf_test_pipeline_comparison.py \
+ --recipe-path configs/demo/analyzer.yaml \
+ --dataset-path demos/data/demo-dataset.jsonl \
+ --verbose
+```
+
+## Integration with CI/CD
+
+The test can be integrated into CI/CD pipelines:
+
+```yaml
+# Example GitHub Actions workflow
+- name: Run Pipeline Performance Test
+ run: |
+    ./tools/optimizer_perf_test/run_pipeline_perf_test.sh \
+ -r configs/demo/analyzer.yaml \
+ -d demos/data/demo-dataset.jsonl \
+ -o ./test-results
+
+- name: Upload Results
+ uses: actions/upload-artifact@v2
+ with:
+ name: performance-test-results
+ path: ./test-results/
+```
+
+## Contributing
+
+When adding new features to the performance test:
+
+1. Update this README with new options and examples
+2. Add appropriate error handling
+3. Include new metrics in the results JSON
+4. Update the markdown report template
+5. Add tests for new functionality
+
+## Related Documentation
+
+- [Data-Juicer Operators](https://github.com/modelscope/data-juicer/blob/main/docs/Operators.md)
+- [Performance Benchmark Framework](../../data_juicer/core/optimizer/performance_benchmark.py)
+- [Pipeline Optimizer](../../data_juicer/core/optimizer/optimizer.py)
\ No newline at end of file
diff --git a/tools/optimizer_perf_test/perf_test_pipeline_comparison.py b/tools/optimizer_perf_test/perf_test_pipeline_comparison.py
new file mode 100644
index 0000000000..3c9aa2a856
--- /dev/null
+++ b/tools/optimizer_perf_test/perf_test_pipeline_comparison.py
@@ -0,0 +1,527 @@
+#!/usr/bin/env python3
+"""
+Performance Test: Individual Pipeline vs Optimized Mode Comparison
+
+This script runs performance benchmarks comparing individual pipeline execution
+vs optimized pipeline execution using separate processes to ensure fair comparison.
+
+Features:
+- Separate process execution for isolation
+- Support for recipe path and dataset path
+- Comprehensive metrics collection
+- Result validation and comparison
+- Detailed reporting
+"""
+
+import argparse
+import json
+import multiprocessing as mp
+import os
+import sys
+import time
+from argparse import Namespace
+from pathlib import Path
+from typing import Any, Dict, List
+
+import yaml
+from loguru import logger
+
+# Add the project root to the path before importing data_juicer modules
+project_root = Path(__file__).parent.parent.parent
+sys.path.insert(0, str(project_root))
+
+from data_juicer.config import init_configs  # noqa: E402
+from data_juicer.core import DefaultExecutor  # noqa: E402
+from data_juicer.core.data.dataset_builder import DatasetBuilder  # noqa: E402
+from data_juicer.core.optimizer.filter_fusion_strategy import FilterFusionStrategy  # noqa: E402
+from data_juicer.core.optimizer.mapper_fusion_strategy import MapperFusionStrategy  # noqa: E402
+from data_juicer.core.optimizer.optimizer import PipelineOptimizer  # noqa: E402
+from data_juicer.core.optimizer.performance_benchmark import PerformanceBenchmark  # noqa: E402
+from data_juicer.core.pipeline_ast import PipelineAST  # noqa: E402
+from data_juicer.ops import load_ops  # noqa: E402
+
+
+class PipelinePerformanceTester:
+ """Performance tester for comparing individual vs optimized pipeline execution."""
+
+ def __init__(self, output_dir: str = "./outputs/pipeline_perf_test"):
+ self.output_dir = Path(output_dir)
+ self.output_dir.mkdir(parents=True, exist_ok=True)
+
+ # Setup logging
+ log_file = self.output_dir / "perf_test.log"
+ logger.add(log_file, rotation="10 MB", level="INFO")
+
+ self.results = {"individual": {}, "optimized": {}, "comparison": {}, "metadata": {}}
+
+ def load_dataset(self, dataset_path: str) -> Any:
+ """Load dataset from path using DatasetBuilder."""
+ logger.info(f"Loading dataset from: {dataset_path}")
+ if not os.path.exists(dataset_path):
+ raise FileNotFoundError(f"Dataset not found: {dataset_path}")
+ # Build a minimal config Namespace for DatasetBuilder
+ cfg = Namespace(dataset_path=dataset_path)
+ builder = DatasetBuilder(cfg)
+ dataset = builder.load_dataset()
+ dataset_length = len(dataset.to_list()) if dataset is not None else 0
+ logger.info(f"Loaded dataset with {dataset_length} samples")
+ return dataset
+
+ def create_temp_config(self, recipe_path: str, dataset_path: str, mode: str) -> str:
+ """Create a temporary config file for execution."""
+ # Load the original recipe
+ with open(recipe_path, "r") as f:
+ recipe_config = yaml.safe_load(f)
+
+ # Create temp config
+ temp_config = {
+ "project_name": f"perf-test-{mode}",
+ "dataset_path": dataset_path,
+ "export_path": str(self.output_dir / f"result_{mode}.jsonl"),
+ "np": 1, # Single process for fair comparison
+ "use_cache": False,
+ "op_fusion": mode == "optimized", # Enable fusion only for optimized mode
+ "process": recipe_config.get("process", []),
+ }
+
+ # Write temp config
+ temp_config_path = self.output_dir / f"temp_config_{mode}.yaml"
+ with open(temp_config_path, "w") as f:
+ yaml.dump(temp_config, f, default_flow_style=False)
+
+ return str(temp_config_path)
+
+ def run_individual_pipeline(self, recipe_path: str, dataset_path: str) -> Dict[str, Any]:
+ """Run individual pipeline execution in separate process."""
+ logger.info("Running individual pipeline execution...")
+
+ temp_config_path = self.create_temp_config(recipe_path, dataset_path, "individual")
+
+ # Run in separate process
+ start_time = time.time()
+ result = self._run_in_process(temp_config_path, "individual")
+ end_time = time.time()
+
+ result["wall_time"] = end_time - start_time
+ result["config_path"] = temp_config_path
+
+ return result
+
+ def run_optimized_pipeline(self, recipe_path: str, dataset_path: str) -> Dict[str, Any]:
+ """Run optimized pipeline execution in separate process."""
+ logger.info("Running optimized pipeline execution...")
+
+ temp_config_path = self.create_temp_config(recipe_path, dataset_path, "optimized")
+
+ # Run in separate process
+ start_time = time.time()
+ result = self._run_in_process(temp_config_path, "optimized")
+ end_time = time.time()
+
+ result["wall_time"] = end_time - start_time
+ result["config_path"] = temp_config_path
+
+ return result
+
+ def _run_in_process(self, config_path: str, mode: str) -> Dict[str, Any]:
+ """Run pipeline execution in a separate process."""
+ # Create process and run
+ result_queue = mp.Queue()
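+        # Note: the worker runs as a module-level function so that multiprocessing can
+        # pickle it by name under the "spawn" start method (e.g. on macOS).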
+ process = mp.Process(target=_worker_process, args=(config_path, mode, result_queue))
+
+ process.start()
+ process.join(timeout=3600) # 1 hour timeout
+
+ if process.is_alive():
+ process.terminate()
+ process.join()
+ return {"execution_time": 0, "output_samples": 0, "success": False, "error": "Process timeout"}
+
+ if not result_queue.empty():
+ return result_queue.get()
+ else:
+ return {"execution_time": 0, "output_samples": 0, "success": False, "error": "No result from process"}
+
+ def run_benchmark_comparison(self, recipe_path: str, dataset_path: str) -> Dict[str, Any]:
+ """Run comprehensive benchmark comparison using the performance benchmark framework."""
+ logger.info("Running comprehensive benchmark comparison...")
+
+ # Load dataset
+ dataset = self.load_dataset(dataset_path)
+
+ # Create benchmark instance
+ benchmark = PerformanceBenchmark()
+
+ # Load recipe and create AST
+ ast = PipelineAST()
+ ast.build_from_yaml(recipe_path)
+
+ # Get analyzer insights
+ analyzer_insights = benchmark.get_analyzer_insights(dataset)
+
+ # Create optimizer
+ optimizer = PipelineOptimizer(
+ [FilterFusionStrategy(), MapperFusionStrategy()], analyzer_insights=analyzer_insights
+ )
+
+ # Optimize pipeline
+ optimized_ast = optimizer.optimize(ast)
+
+ # Convert to operations
+ original_operations = benchmark._convert_ast_to_operations(ast)
+ optimized_operations = benchmark._convert_ast_to_operations(optimized_ast)
+
+ # Load operations
+ loaded_original_ops = self._load_operations(original_operations)
+ loaded_optimized_ops = self._load_operations(optimized_operations)
+
+ # Run benchmark
+ results = benchmark.run_mixed_operations_benchmark_with_original_ops(
+ loaded_original_ops, loaded_optimized_ops, dataset, "recipe"
+ )
+
+ return results
+
+ def _load_operations(self, operation_configs: List[Dict]) -> List:
+ """Load operations from configs."""
+ loaded_ops = []
+
+ for op_config in operation_configs:
+ op_name = list(op_config.keys())[0]
+ op_args = op_config[op_name]
+
+ if op_name == "fused_filter":
+ # Handle fused filter
+ fused_op_list = op_args.get("fused_op_list", [])
+ individual_filters = []
+
+ for filter_config in fused_op_list:
+ filter_name = list(filter_config.keys())[0]
+ filter_args = filter_config[filter_name]
+ loaded_filters = load_ops([{filter_name: filter_args}])
+ if loaded_filters:
+ individual_filters.append(loaded_filters[0])
+
+ if individual_filters:
+ from data_juicer.core.optimizer.fused_op import FusedFilter
+
+ fused_filter = FusedFilter(name="fused_filter", fused_filters=individual_filters)
+ fused_filter.execution_strategy = "sequential"
+ loaded_ops.append(fused_filter)
+
+ elif op_name == "fused_mapper":
+ # Handle fused mapper
+ from data_juicer.core.optimizer.fused_op import FusedMapper
+
+ name = op_args.get("name", "fused_mapper")
+ fused_mappers = op_args.get("fused_mappers", [])
+ fused_mapper = FusedMapper(name=name, fused_mappers=fused_mappers)
+ loaded_ops.append(fused_mapper)
+
+ else:
+ # Load regular operation
+ loaded_ops_list = load_ops([op_config])
+ if loaded_ops_list:
+ loaded_ops.append(loaded_ops_list[0])
+
+ return loaded_ops
+
+ def validate_results(self, individual_result: Dict, optimized_result: Dict) -> Dict[str, Any]:
+ """Validate that both executions produced similar results."""
+ logger.info("Validating results...")
+
+ validation = {
+ "samples_match": False,
+ "individual_samples": individual_result.get("output_samples", 0),
+ "optimized_samples": optimized_result.get("output_samples", 0),
+ "sample_difference": 0,
+ "validation_passed": False,
+ }
+
+ if individual_result.get("success") and optimized_result.get("success"):
+ individual_samples = individual_result["output_samples"]
+ optimized_samples = optimized_result["output_samples"]
+
+ validation["samples_match"] = individual_samples == optimized_samples
+ validation["sample_difference"] = abs(individual_samples - optimized_samples)
+ validation["validation_passed"] = validation["samples_match"]
+
+ if validation["validation_passed"]:
+                logger.info("✅ Validation passed: Both executions produced the same number of samples")
+ else:
+ logger.warning(
+                    f"❌ Validation failed: Sample count mismatch "
+ f"(individual: {individual_samples}, optimized: {optimized_samples})"
+ )
+ else:
+            logger.error("❌ Validation failed: One or both executions failed")
+
+ return validation
+
+ def compare_performance(self, individual_result: Dict, optimized_result: Dict) -> Dict[str, Any]:
+ """Compare performance metrics between individual and optimized execution."""
+ logger.info("Comparing performance metrics...")
+
+ comparison = {
+ "individual_time": individual_result.get("wall_time", 0),
+ "optimized_time": optimized_result.get("wall_time", 0),
+ "speedup": 0,
+ "improvement_percent": 0,
+ "faster_mode": "none",
+ }
+
+ if individual_result.get("success") and optimized_result.get("success"):
+ individual_time = individual_result["wall_time"]
+ optimized_time = optimized_result["wall_time"]
+
+            if individual_time > 0 and optimized_time > 0:
+ comparison["speedup"] = individual_time / optimized_time
+ comparison["improvement_percent"] = ((individual_time - optimized_time) / individual_time) * 100
+
+ if optimized_time < individual_time:
+ comparison["faster_mode"] = "optimized"
+ elif individual_time < optimized_time:
+ comparison["faster_mode"] = "individual"
+ else:
+ comparison["faster_mode"] = "equal"
+
+ return comparison
+
+ def generate_report(self, results: Dict[str, Any]) -> str:
+ """Generate a comprehensive performance report."""
+ logger.info("Generating performance report...")
+
+ report_path = self.output_dir / "performance_report.md"
+
+ with open(report_path, "w") as f:
+ f.write("# Pipeline Performance Test Report\n\n")
+
+ # Summary
+ f.write("## Summary\n\n")
+ comparison = results["comparison"]
+ validation = results["validation"]
+
+ f.write(f"- **Test Date**: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
+ f.write(f"- **Recipe**: {results['metadata']['recipe_path']}\n")
+ f.write(f"- **Dataset**: {results['metadata']['dataset_path']}\n")
+            f.write(f"- **Validation**: {'✅ PASSED' if validation['validation_passed'] else '❌ FAILED'}\n")
+ f.write(f"- **Faster Mode**: {comparison['faster_mode'].title()}\n")
+ f.write(f"- **Speedup**: {comparison['speedup']:.2f}x\n")
+ f.write(f"- **Improvement**: {comparison['improvement_percent']:.1f}%\n\n")
+
+ # Detailed Results
+ f.write("## Detailed Results\n\n")
+
+ f.write("### Individual Pipeline\n")
+ f.write(f"- Execution Time: {results['individual']['wall_time']:.2f}s\n")
+ f.write(f"- Output Samples: {results['individual']['output_samples']:,}\n")
+ f.write(f"- Success: {results['individual']['success']}\n")
+ if not results["individual"]["success"]:
+ f.write(f"- Error: {results['individual']['error']}\n")
+ f.write("\n")
+
+ f.write("### Optimized Pipeline\n")
+ f.write(f"- Execution Time: {results['optimized']['wall_time']:.2f}s\n")
+ f.write(f"- Output Samples: {results['optimized']['output_samples']:,}\n")
+ f.write(f"- Success: {results['optimized']['success']}\n")
+ if not results["optimized"]["success"]:
+ f.write(f"- Error: {results['optimized']['error']}\n")
+ f.write("\n")
+
+ # Performance Comparison
+ f.write("### Performance Comparison\n")
+ f.write(f"- Individual Time: {comparison['individual_time']:.2f}s\n")
+ f.write(f"- Optimized Time: {comparison['optimized_time']:.2f}s\n")
+ f.write(f"- Speedup: {comparison['speedup']:.2f}x\n")
+ f.write(f"- Improvement: {comparison['improvement_percent']:.1f}%\n")
+ f.write(f"- Faster Mode: {comparison['faster_mode'].title()}\n\n")
+
+ # Validation Results
+ f.write("### Validation Results\n")
+ f.write(f"- Samples Match: {validation['samples_match']}\n")
+ f.write(f"- Individual Samples: {validation['individual_samples']:,}\n")
+ f.write(f"- Optimized Samples: {validation['optimized_samples']:,}\n")
+ f.write(f"- Sample Difference: {validation['sample_difference']}\n")
+ f.write(f"- Validation Passed: {validation['validation_passed']}\n\n")
+
+ # Recommendations
+ f.write("## Recommendations\n\n")
+ if validation["validation_passed"]:
+ if comparison["faster_mode"] == "optimized":
+ f.write(
+                        "✅ **Use Optimized Pipeline**: The optimized pipeline is faster and produces correct results.\n"
+ )
+ elif comparison["faster_mode"] == "individual":
+ f.write(
+                        "⚠️ **Consider Individual Pipeline**: The individual pipeline is faster, but optimization may still be beneficial for larger datasets.\n"
+ )
+ else:
+ f.write(
+                        "ℹ️ **Both Modes Similar**: Performance is similar between individual and optimized modes.\n"
+ )
+ else:
+                f.write("❌ **Investigation Required**: Results don't match between individual and optimized modes.\n")
+
+ return str(report_path)
+
+ def save_results(self, results: Dict[str, Any]) -> str:
+ """Save results to JSON file."""
+ results_path = self.output_dir / "results.json"
+
+ with open(results_path, "w") as f:
+ json.dump(results, f, indent=2, default=str)
+
+ return str(results_path)
+
+ def run_test(self, recipe_path: str, dataset_path: str, use_benchmark_framework: bool = False) -> Dict[str, Any]:
+ """Run the complete performance test."""
+ logger.info("Starting pipeline performance test...")
+ logger.info(f"Recipe: {recipe_path}")
+ logger.info(f"Dataset: {dataset_path}")
+
+ # Store metadata
+ self.results["metadata"] = {
+ "recipe_path": recipe_path,
+ "dataset_path": dataset_path,
+ "test_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
+ "use_benchmark_framework": use_benchmark_framework,
+ }
+
+ if use_benchmark_framework:
+ # Use the comprehensive benchmark framework
+ logger.info("Using comprehensive benchmark framework...")
+ benchmark_results = self.run_benchmark_comparison(recipe_path, dataset_path)
+ self.results.update(benchmark_results)
+ else:
+ # Use separate process execution
+ logger.info("Using separate process execution...")
+
+ # Run individual pipeline
+ individual_result = self.run_individual_pipeline(recipe_path, dataset_path)
+ self.results["individual"] = individual_result
+
+ # Run optimized pipeline
+ optimized_result = self.run_optimized_pipeline(recipe_path, dataset_path)
+ self.results["optimized"] = optimized_result
+
+ # Validate results
+ validation = self.validate_results(individual_result, optimized_result)
+ self.results["validation"] = validation
+
+ # Compare performance
+ comparison = self.compare_performance(individual_result, optimized_result)
+ self.results["comparison"] = comparison
+
+ # Save results
+ results_path = self.save_results(self.results)
+ logger.info(f"Results saved to: {results_path}")
+
+ # Generate report
+ report_path = self.generate_report(self.results)
+ logger.info(f"Report generated: {report_path}")
+
+ # Print summary
+ self._print_summary()
+
+ return self.results
+
+ def _print_summary(self):
+ """Print a summary of the test results."""
+ logger.info("\n" + "=" * 60)
+ logger.info("PERFORMANCE TEST SUMMARY")
+ logger.info("=" * 60)
+
+ comparison = self.results.get("comparison", {})
+ validation = self.results.get("validation", {})
+
+ logger.info(f"Individual Time: {comparison.get('individual_time', 0):.2f}s")
+ logger.info(f"Optimized Time: {comparison.get('optimized_time', 0):.2f}s")
+ logger.info(f"Speedup: {comparison.get('speedup', 0):.2f}x")
+ logger.info(f"Improvement: {comparison.get('improvement_percent', 0):.1f}%")
+        logger.info(f"Validation: {'✅ PASSED' if validation.get('validation_passed') else '❌ FAILED'}")
+ logger.info(f"Faster Mode: {comparison.get('faster_mode', 'none').title()}")
+ logger.info("=" * 60)
+
+
+def _worker_process(config_path: str, mode: str, result_queue: mp.Queue):
+    """Worker function for running pipeline execution in separate process."""
+    try:
+        # Add the project root to the path (needed when the child process is spawned)
+        project_root = Path(__file__).parent.parent.parent
+        sys.path.insert(0, str(project_root))
+
+        # Initialize config
+        args = ["--config", config_path]
+        cfg = init_configs(args=args)
+
+        # Create executor
+        executor = DefaultExecutor(cfg)
+
+        # Run and collect metrics
+        start_time = time.time()
+        dataset = executor.run()
+        end_time = time.time()
+
+        # Collect results
+        dataset_length = len(dataset) if dataset is not None else 0
+        result = {
+            "execution_time": end_time - start_time,
+            "output_samples": dataset_length,
+            "success": True,
+            "error": None,
+        }
+
+        result_queue.put(result)
+
+    except Exception as e:
+        result = {"execution_time": 0, "output_samples": 0, "success": False, "error": str(e)}
+        result_queue.put(result)
+
+
+def main():
+ """Main entry point."""
+ parser = argparse.ArgumentParser(description="Pipeline Performance Test: Compare individual vs optimized execution")
+ parser.add_argument("--recipe-path", type=str, required=True, help="Path to the recipe YAML file")
+ parser.add_argument("--dataset-path", type=str, required=True, help="Path to the dataset file")
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default="./outputs/pipeline_perf_test",
+ help="Output directory for results and reports",
+ )
+ parser.add_argument(
+ "--use-benchmark-framework",
+ action="store_true",
+ help="Use the comprehensive benchmark framework instead of separate processes",
+ )
+ parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
+
+ args = parser.parse_args()
+
+ # Setup logging
+ log_level = "DEBUG" if args.verbose else "INFO"
+ logger.remove()
+ logger.add(sys.stderr, level=log_level)
+
+ # Create tester and run test
+ tester = PipelinePerformanceTester(args.output_dir)
+
+ try:
+ results = tester.run_test(args.recipe_path, args.dataset_path, args.use_benchmark_framework)
+
+ # Exit with appropriate code
+ validation = results.get("validation", {})
+ if validation.get("validation_passed"):
+            logger.info("✅ Test completed successfully")
+ sys.exit(0)
+ else:
+            logger.error("❌ Test failed validation")
+ sys.exit(1)
+
+ except Exception as e:
+ logger.error(f"Test failed with error: {e}")
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/optimizer_perf_test/run_pipeline_perf_test.sh b/tools/optimizer_perf_test/run_pipeline_perf_test.sh
new file mode 100755
index 0000000000..2f99bd4695
--- /dev/null
+++ b/tools/optimizer_perf_test/run_pipeline_perf_test.sh
@@ -0,0 +1,172 @@
+#!/bin/bash
+
+# Pipeline Performance Test Runner
+# This script runs the pipeline performance test with different configurations
+
+set -e
+
+# Colors for output
+RED='\033[0;31m'
+GREEN='\033[0;32m'
+YELLOW='\033[1;33m'
+BLUE='\033[0;34m'
+NC='\033[0m' # No Color
+
+# Default values
+RECIPE_PATH=""
+DATASET_PATH=""
+OUTPUT_DIR="./outputs/pipeline_perf_test"
+USE_BENCHMARK_FRAMEWORK=false
+VERBOSE=false
+
+# Function to print usage
+print_usage() {
+ echo "Usage: $0 [OPTIONS]"
+ echo ""
+ echo "Options:"
+ echo " -r, --recipe-path PATH Path to the recipe YAML file (required)"
+ echo " -d, --dataset-path PATH Path to the dataset file (required)"
+ echo " -o, --output-dir PATH Output directory (default: ./outputs/pipeline_perf_test)"
+ echo " -b, --benchmark-framework Use comprehensive benchmark framework"
+ echo " -v, --verbose Enable verbose logging"
+ echo " -h, --help Show this help message"
+ echo ""
+ echo "Examples:"
+ echo " $0 -r configs/data_juicer_recipes/alpaca-cot-en-refine.yaml -d data/sample.jsonl"
+ echo " $0 --recipe-path configs/demo/analyzer.yaml --dataset-path demos/data/demo-dataset.jsonl --verbose"
+ echo " $0 -r configs/demo/analyzer.yaml -d demos/data/demo-dataset.jsonl -b"
+}
+
+# Function to print colored output
+print_info() {
+ echo -e "${BLUE}[INFO]${NC} $1"
+}
+
+print_success() {
+ echo -e "${GREEN}[SUCCESS]${NC} $1"
+}
+
+print_warning() {
+ echo -e "${YELLOW}[WARNING]${NC} $1"
+}
+
+print_error() {
+ echo -e "${RED}[ERROR]${NC} $1"
+}
+
+# Parse command line arguments
+while [[ $# -gt 0 ]]; do
+ case $1 in
+ -r|--recipe-path)
+ RECIPE_PATH="$2"
+ shift 2
+ ;;
+ -d|--dataset-path)
+ DATASET_PATH="$2"
+ shift 2
+ ;;
+ -o|--output-dir)
+ OUTPUT_DIR="$2"
+ shift 2
+ ;;
+ -b|--benchmark-framework)
+ USE_BENCHMARK_FRAMEWORK=true
+ shift
+ ;;
+ -v|--verbose)
+ VERBOSE=true
+ shift
+ ;;
+ -h|--help)
+ print_usage
+ exit 0
+ ;;
+ *)
+ print_error "Unknown option: $1"
+ print_usage
+ exit 1
+ ;;
+ esac
+done
+
+# Validate required arguments
+if [[ -z "$RECIPE_PATH" ]]; then
+ print_error "Recipe path is required"
+ print_usage
+ exit 1
+fi
+
+if [[ -z "$DATASET_PATH" ]]; then
+ print_error "Dataset path is required"
+ print_usage
+ exit 1
+fi
+
+# Check if files exist
+if [[ ! -f "$RECIPE_PATH" ]]; then
+ print_error "Recipe file not found: $RECIPE_PATH"
+ exit 1
+fi
+
+if [[ ! -f "$DATASET_PATH" ]]; then
+ print_error "Dataset file not found: $DATASET_PATH"
+ exit 1
+fi
+
+# Get script directory
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"
+
+# Change to project root
+cd "$PROJECT_ROOT"
+
+print_info "Starting Pipeline Performance Test"
+print_info "Project root: $PROJECT_ROOT"
+print_info "Recipe path: $RECIPE_PATH"
+print_info "Dataset path: $DATASET_PATH"
+print_info "Output directory: $OUTPUT_DIR"
+print_info "Use benchmark framework: $USE_BENCHMARK_FRAMEWORK"
+print_info "Verbose: $VERBOSE"
+
+# Build command
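+# Quotes inside CMD are escaped and the command is executed via eval below,
+# so recipe/dataset paths containing spaces are passed through intact.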
+CMD="python tools/optimizer_perf_test/perf_test_pipeline_comparison.py"
+CMD="$CMD --recipe-path \"$RECIPE_PATH\""
+CMD="$CMD --dataset-path \"$DATASET_PATH\""
+CMD="$CMD --output-dir \"$OUTPUT_DIR\""
+
+if [[ "$USE_BENCHMARK_FRAMEWORK" == true ]]; then
+ CMD="$CMD --use-benchmark-framework"
+fi
+
+if [[ "$VERBOSE" == true ]]; then
+ CMD="$CMD --verbose"
+fi
+
+print_info "Running command: $CMD"
+
+# Run the test
+if eval $CMD; then
+ print_success "Pipeline performance test completed successfully!"
+
+ # Show results summary
+ RESULTS_FILE="$OUTPUT_DIR/results.json"
+ REPORT_FILE="$OUTPUT_DIR/performance_report.md"
+
+ if [[ -f "$RESULTS_FILE" ]]; then
+ print_info "Results saved to: $RESULTS_FILE"
+ fi
+
+ if [[ -f "$REPORT_FILE" ]]; then
+ print_info "Report generated: $REPORT_FILE"
+ echo ""
+ print_info "Report preview:"
+ echo "=================="
+ head -20 "$REPORT_FILE"
+ echo "..."
+ echo "=================="
+ fi
+
+else
+ print_error "Pipeline performance test failed!"
+ exit 1
+fi