diff --git a/docs/source/developer_guides/datasets.md b/docs/source/developer_guides/datasets.md index 79df5fd6..ffa37c64 100644 --- a/docs/source/developer_guides/datasets.md +++ b/docs/source/developer_guides/datasets.md @@ -91,7 +91,7 @@ labels_series = dataset.labels # Labels from the specified obs column ```python control_cells_ids = dataset.control_cells_ids # List of control cell IDs -target_conditions_to_save = dataset.target_conditions_to_save # Conditions to be saved for benchmarking +target_conditions_dict = dataset.target_conditions_dict # Conditions to be saved for benchmarking de_results = dataset.de_results # Differential expression results control_matched_adata = dataset.control_matched_adata # AnnData object for matched controls ``` diff --git a/docs/source/how_to_guides/guide_to_perturbation_expression_prediction_benchmark.md b/docs/source/how_to_guides/guide_to_perturbation_expression_prediction_benchmark.md index 7d7f1b22..7e0db3cc 100644 --- a/docs/source/how_to_guides/guide_to_perturbation_expression_prediction_benchmark.md +++ b/docs/source/how_to_guides/guide_to_perturbation_expression_prediction_benchmark.md @@ -52,10 +52,10 @@ This task evaluates perturbation-induced expression predictions against their gr The following parameters are used by the task, via the `PerturbationExpressionPredictionTask` class: - `de_results`: DE results used by the dataset class (`SingleCellPerturbationDataset`). -- `masked_data_obs`: The `obs` column from the control-matched and masked data. -- `row_index`: Sequence of cell barcodes vertically aligned with `cell_representation` matrix. -- `var_index`: Sequence of gene names horizontally aligned with `cell_representation` matrix. -- `target_conditions_to_save`: Dictionary of target conditions whose genes were randomly selected for masking. +- `adata`: The complete AnnData object containing control-matched and masked data. +- `cell_index`: Sequence of cell barcodes vertically aligned with `cell_representation` matrix. +- `gene_index`: Sequence of gene names horizontally aligned with `cell_representation` matrix. +- `target_conditions_dict`: Dictionary of target conditions whose genes were randomly selected for masking. ## Metrics diff --git a/examples/example_perturbation_expression_prediction.py b/examples/example_perturbation_expression_prediction.py index 4d05de53..2888d136 100644 --- a/examples/example_perturbation_expression_prediction.py +++ b/examples/example_perturbation_expression_prediction.py @@ -1,21 +1,48 @@ import logging import sys import argparse -from czbenchmarks.datasets import load_dataset +import tempfile +import yaml +from pathlib import Path +import anndata as ad +import numpy as np +from czbenchmarks.datasets import SingleCellPerturbationDataset, load_dataset from czbenchmarks.tasks.single_cell import ( PerturbationExpressionPredictionTask, PerturbationExpressionPredictionTaskInput, ) + from czbenchmarks.tasks.single_cell.perturbation_expression_prediction import ( load_perturbation_task_input_from_saved_files, ) from czbenchmarks.tasks.utils import print_metrics_summary -import numpy as np -from czbenchmarks.datasets import SingleCellPerturbationDataset from czbenchmarks.tasks.types import CellRepresentation -import tempfile -import yaml -from pathlib import Path + +logging.basicConfig(level=logging.INFO, stream=sys.stdout) +logger = logging.getLogger(__name__) + + +def generate_random_model_predictions(n_cells, n_genes): + """This demonstrates the expected format for the model predictions. 
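+    (Illustrative only: a random matrix stands in for real model output here.)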
+    This should be an anndata file where the obs.index contains the cell
+    barcodes and the var.index contains the genes. These should be the same or a
+    subset of the genes and cells in the dataset. The X matrix should be the
+    model predictions.
+    """
+
+    model_predictions: CellRepresentation = np.random.rand(n_cells, n_genes)
+    # Put the predictions in an anndata object
+    model_adata = ad.AnnData(X=model_predictions)
+
+    # The same genes and cells (or a subset of them) should be in the model adata.
+    # Note: this helper reads the module-level `dataset` created in `__main__`
+    # below; the indices are shuffled to show that the task realigns them.
+    model_adata.obs.index = (
+        dataset.adata.obs.index.to_series().sample(frac=1, random_state=42).values
+    )
+    model_adata.var.index = (
+        dataset.adata.var.index.to_series().sample(frac=1, random_state=42).values
+    )
+    return model_adata
+

 if __name__ == "__main__":
     """Runs a task to calculate perturbation metrics.
@@ -23,7 +50,7 @@
     As input, this uses a SingleCellPerturbationDataset. Currently, this assumes
     data from the Replogle et al. 2022 dataset. Additionally, this contains
     differentially expressed genes for each perturbation. The extent of the
-    perturbation is merged with the willcoxon test or t-test.
+    perturbation is measured with the Wilcoxon test.

     The dataset is filtered based on the type of statistical test, along with the
     minimum number of differentially expressed genes, maximum p-value, and the minimum
@@ -39,6 +66,9 @@
     differentially expressed genes between perturbed and non-targeting groups. It
     then calculates the correlation between ground truth and predicted log fold
     change for each condition using a variety of metrics.
+
+    Before running this example, make sure the replogle_k562_essential_perturbpredict_path
+    is set to the path where the replogle_k562_essential_perturbpredict.h5ad file is saved.
     """

     # Parse command line arguments
@@ -50,23 +80,17 @@
         action="store_true",
         help="Save dataset task inputs to disk and load them back (demonstrates save/load functionality)",
     )
-    parser.add_argument(
-        "--metric",
-        type=str,
-        default="wilcoxon",  # Set this to correspond to the type of statistical test used to determine differentially expressed genes
-        help="Metric to use for DE analysis",
-    )
     parser.add_argument(
         "--percent_genes_to_mask",
         type=float,
-        default=1.0,
+        default=0.5,
         help="Fraction of genes to mask (a value between 0 and 1)",
     )
     parser.add_argument(
         "--min_logfoldchange",
         type=float,
         default=1.0,
-        help="Minimum absolute log-fold change for DE filtering (used when --metric=wilcoxon)",
+        help="Minimum absolute log-fold change for DE filtering",
     )
     parser.add_argument(
         "--pval_threshold",
@@ -80,16 +104,9 @@
         default=5,
         help="Minimum number of DE genes required to mask a condition",
     )
-    parser.add_argument(
-        "--min_smd",
-        type=float,
-        default=0.55,
-        help="Minimum standardized mean difference for DE filtering (used when --metric=t-test)",
-    )
-
     args = parser.parse_args()

-    logging.basicConfig(level=logging.INFO, stream=sys.stdout)
+

     # Instantiate a config and load the input data
     cfg = {
@@ -103,38 +120,62 @@
         }
     }

-    with tempfile.TemporaryDirectory() as d:
-        cfg_path = Path(d) / "config.yaml"
-        cfg_path.write_text(yaml.safe_dump(cfg))
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as tmp_cfg:
+        yaml.dump(cfg, tmp_cfg)
+        tmp_cfg_path = tmp_cfg.name
+
     dataset: SingleCellPerturbationDataset = load_dataset(
-        "replogle_k562_essential_perturbpredict", config_path=str(cfg_path)
-    )  # TODO: Once PR 381 is merged, use the new load_local_dataset function
+        "replogle_k562_essential_perturbpredict",
+        config_path=Path(tmp_cfg_path),
+    )
+
+    # Optional: validate the dataset
+    
dataset.validate() + # This generates a sample model anndata file. In applications, + # this should contain the model predictions and should be provided by the user. + model_adata = generate_random_model_predictions( + dataset.adata.shape[0], dataset.adata.shape[1] + ) # Choose approach based on flag if args.save_inputs: - print("Using save/load approach...") + logger.info("Using save/load approach...") # Save and load dataset task inputs task_inputs_dir = dataset.store_task_inputs() - print(f"Task inputs saved to: {task_inputs_dir}") + logger.info(f"Task inputs saved to: {task_inputs_dir}") task_input = load_perturbation_task_input_from_saved_files(task_inputs_dir) - print("Task inputs loaded from saved files") + logger.info("Task inputs loaded from saved files") + + # Update with the model ordering of the genes and of the cells + task_input.gene_index = model_adata.var.index + task_input.cell_index = model_adata.obs.index else: - print("Creating task input directly from dataset...") - # Create task input directly from dataset + logger.info("Creating task input directly from dataset...") + # Create task input directly from dataset with separate fields task_input = PerturbationExpressionPredictionTaskInput( + adata=dataset.control_matched_adata, + target_conditions_dict=dataset.target_conditions_dict, de_results=dataset.de_results, - var_index=dataset.control_matched_adata.var.index, - masked_adata_obs=dataset.control_matched_adata.obs, - target_conditions_to_save=dataset.target_conditions_to_save, - row_index=dataset.adata.obs.index, + gene_index=model_adata.var.index, + cell_index=model_adata.obs.index, ) - - # Generate random model output - model_output: CellRepresentation = np.random.rand( - dataset.adata.shape[0], dataset.adata.shape[1] - ) + # Convert model adata to cell representation + model_output = model_adata.X # Run task - task = PerturbationExpressionPredictionTask(metric=args.metric) - metrics_dict = task.run(model_output, task_input) + task = PerturbationExpressionPredictionTask( + condition_key=dataset.condition_key, control_name=dataset.control_name + ) + metrics_dict = task.run(cell_representation=model_output, task_input=task_input) + logger.info("Model metrics:") print_metrics_summary(metrics_dict) + + # # Compute baseline -- pseudocode -- this throws an error because of non-log-normalized data + # baseline_model = task.compute_baseline( + # cell_representation=dataset.adata.X, baseline_type="median" + # ) + # baseline_metrics_dict = task.run( + # cell_representation=baseline_model, task_input=task_input + # ) + # logger.info("Baseline metrics:") + # print_metrics_summary(baseline_metrics_dict) diff --git a/pyproject.toml b/pyproject.toml index 3fb4790c..4862e94b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dependencies = [ "hnswlib>=0.8.0", "tomli>=2.2.1", "boto3-stubs-lite[s3]>=1.38.0", + "pyarrow>=17.0.0", ] [project.optional-dependencies] diff --git a/src/czbenchmarks/datasets/single_cell_perturbation.py b/src/czbenchmarks/datasets/single_cell_perturbation.py index 28c1deec..d3ab94f2 100644 --- a/src/czbenchmarks/datasets/single_cell_perturbation.py +++ b/src/czbenchmarks/datasets/single_cell_perturbation.py @@ -1,5 +1,3 @@ -import io -import json import logging from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -7,9 +5,8 @@ import anndata as ad import numpy as np import pandas as pd -import scipy.sparse as sparse from tqdm import tqdm - +import json from czbenchmarks.constants import RANDOM_SEED from 
czbenchmarks.datasets.single_cell import SingleCellDataset from czbenchmarks.datasets.utils_single_cell import create_adata_for_condition @@ -78,20 +75,20 @@ class SingleCellPerturbationDataset(SingleCellDataset): control_cells_ids (dict): Dictionary mapping each condition to a dictionary of treatment cell barcodes (keys) to matched control cell barcodes (values). de_results (pd.DataFrame): Differential expression results calculated on ground truth data using matched controls. - target_conditions_to_save (dict): Dictionary of target conditions for each cell. + target_conditions_dict (dict): Dictionary of masked genes for each condition. """ control_matched_adata: ad.AnnData control_cells_ids: dict de_results: pd.DataFrame - target_conditions_to_save: dict + target_conditions_dict: dict def __init__( self, path: Path, organism: Organism, condition_key: str = "condition", - control_name: str = "ctrl", + control_name: str = "non-targeting", de_gene_col: str = "gene", percent_genes_to_mask: float = 0.5, min_de_genes_to_mask: int = 5, @@ -108,7 +105,7 @@ def __init__( condition_key (str): Key for the column in `adata.obs` specifying conditions. Defaults to "condition". control_name (str): Name of the control condition. Defaults to - "ctrl". + "non-targeting". de_gene_col (str): Column name for the names of genes which are differentially expressed in the differential expression results. Defaults to "gene". @@ -197,7 +194,7 @@ def _create_adata(self) -> Tuple[ad.AnnData, dict]: percent_genes_to_mask=self.percent_genes_to_mask, min_de_genes_to_mask=self.min_de_genes_to_mask, condition_col=self.condition_key, - gene_col=self.de_gene_col, + gene_col="gene_id", # Column was renamed to gene_id during optimization ) target_conditions = list(target_condition_dict.keys()) @@ -229,7 +226,6 @@ def _create_adata(self) -> Tuple[ad.AnnData, dict]: } all_merged_data = [] - target_conditions_to_save = {} with tqdm( total=total_conditions, desc="Processing conditions", unit="item" @@ -245,24 +241,32 @@ def _create_adata(self) -> Tuple[ad.AnnData, dict]: ) all_merged_data.append(adata_merged) - # target_conditions_to_save.update(result[1]) - for idx in adata_merged.obs.index: - target_conditions_to_save[idx] = target_condition_dict[ - selected_condition - ] pbar.set_postfix_str(f"Completed {pbar.n + 1}/{total_conditions}") pbar.update(1) # Combine all adata objects logger.info( - f"Collected {len(all_merged_data)} datasets for the sampled control-matched conditions." + f"Merged datasets for {len(all_merged_data)} control-matched conditions." ) adata_final = ad.concat(all_merged_data, index_unique=None) adata_final.obs[self.condition_key] = pd.Categorical( adata_final.obs[self.condition_key] ) - return adata_final, target_conditions_to_save + # Optimize: Keep only necessary columns in obs (only condition_key is used in task) + adata_final.obs = adata_final.obs[[self.condition_key]] + + # Add task-related data to uns for easy access + adata_final.uns["target_conditions_dict"] = target_condition_dict + adata_final.uns["de_results"] = { + col: self.de_results[col].values for col in self.de_results.columns + } + adata_final.uns["cell_barcode_condition_index"] = self.adata.obs.index.astype( + str + ).values + adata_final.uns["control_cells_ids"] = self.control_cells_ids + + return adata_final, target_condition_dict def load_data( self, @@ -278,7 +282,6 @@ def load_data( ValueErrors or FileNotFoundErrors based on required data structure. 
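
        Example (illustrative sketch; the file name below is an assumption,
        not part of this change)::

            dataset = SingleCellPerturbationDataset(
                path=Path("replogle_k562_essential_perturbpredict.h5ad"),
                organism=Organism.HUMAN,
            )
            dataset.load_data()          # builds control_matched_adata and de_results
            dataset.store_task_inputs()  # writes the h5ad/json/parquet task inputs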
""" super().load_data() - if self.condition_key not in self.adata.obs.columns: raise ValueError( f"Condition key '{self.condition_key}' not found in adata.obs" @@ -307,6 +310,24 @@ def load_data( self.de_results = self.load_and_filter_deg_results() logger.info(f"Using {len(self.de_results)} differential expression values") + # Optimize: Keep only necessary columns in de_results + # Task only uses: condition_key, "gene_id", and metric_column (logfoldchange or standardized_mean_diff) + metric_column = ( + "logfoldchange" + if self.deg_test_name == "wilcoxon" + else "standardized_mean_diff" + ) + necessary_columns = [self.condition_key, self.de_gene_col, metric_column] + + # Ensure we have gene_id column for compatibility with task + if self.de_gene_col != "gene_id": + self.de_results = self.de_results.rename( + columns={self.de_gene_col: "gene_id"} + ) + necessary_columns = [self.condition_key, "gene_id", metric_column] + + self.de_results = self.de_results[necessary_columns] + # Compare conditions and throw warning or error for unmatched conditions unique_conditions_adata = set(self.adata.obs[self.condition_key]) unique_conditions_control_cells_ids = set(self.control_cells_ids.keys()) @@ -352,59 +373,42 @@ def load_data( logger.info( f"Creating control-matched adata for {len(self.control_cells_ids)} conditions" ) - adata_final, target_conditions_to_save = self._create_adata() + adata_final, target_conditions_dict = self._create_adata() self.control_matched_adata = adata_final - self.target_conditions_to_save = target_conditions_to_save + self.target_conditions_dict = target_conditions_dict def store_task_inputs(self) -> Path: """ - Store auxiliary data files. + Store all task inputs as separate files. - This method saves the IDs of the control cells and the target conditions dictionary - to JSON files. + This method saves all task-related data as separate files: + - control_matched_adata.h5ad: The main AnnData object (includes cell_barcode_condition_index, control_cells_ids, target_conditions_dict, and de_results in uns) + - target_conditions_dict.json: Target conditions dictionary + - de_results.parquet: Differential expression results Returns: - Path: Path to the directory storing the task input files. + Path: Path to the task inputs directory. """ - # TODO: Might be better as a single adata, pending future design on how - # Task instantiation is performed by benchmarking pipelines - inputs_to_store = { - "control_cells_ids": self.control_cells_ids, - "target_conditions_to_save": self.target_conditions_to_save, - "de_results": self.de_results, - "control_matched_adata/obs": self.control_matched_adata.obs, - "control_matched_adata/var": self.control_matched_adata.var, - "control_matched_adata/X": self.control_matched_adata.X, - "original_adata/obs/index": self.adata.obs.index.astype(str).to_numpy(), - } - - for key, item in inputs_to_store.items(): - if hasattr(item, "to_json"): - # For pandas DataFrames. Preserve index for obs/var by using orient="split". 
- buffer = io.StringIO() - if key in {"control_matched_adata/obs", "control_matched_adata/var"}: - item.to_json(buffer, orient="split") - else: - item.to_json(buffer) - self._store_task_input(f"{key}.json", buffer.getvalue()) - - elif isinstance(item, np.ndarray): - output_dir = self.task_inputs_dir / Path(key).parent - output_dir.mkdir(parents=True, exist_ok=True) - output_file = self.task_inputs_dir / (key + ".npy") - np.save(output_file, item) - - elif isinstance(item, sparse.csr_matrix): - output_dir = self.task_inputs_dir / Path(key).parent - output_dir.mkdir(parents=True, exist_ok=True) - output_file = self.task_inputs_dir / (key + ".npz") - sparse.save_npz(output_file, item) - - else: - # For dictionaries and other JSON-serializable objects - json_string = json.dumps(item) - self._store_task_input(f"{key}.json", json_string) + # Ensure the task inputs directory exists + self.task_inputs_dir.mkdir(parents=True, exist_ok=True) + adata_to_save = self.control_matched_adata.copy() + adata_to_save.uns["cell_barcode_condition_index"] = self.adata.obs.index.astype( + str + ).values + + # Save the main AnnData object + adata_file = self.task_inputs_dir / "control_matched_adata.h5ad" + adata_to_save.write_h5ad(adata_file) + + # Save target conditions dict as JSON + target_conditions_file = self.task_inputs_dir / "target_conditions_dict.json" + with open(target_conditions_file, "w") as f: + json.dump(self.target_conditions_dict, f) + + # Save DE results as Parquet using PyArrow + de_results_file = self.task_inputs_dir / "de_results.parquet" + self.de_results.to_parquet(de_results_file, engine="pyarrow", index=False) return self.task_inputs_dir @@ -414,7 +418,7 @@ def _validate(self) -> None: Validates the following: - Condition format must be one of: - - ``{control_name}`` or ``{control_name}_{perturb}`` for matched control samples. + - ``{control_name}`` or ``{control_name}_{perturb}`` for unmatched or matched control samples. - ``{perturb}`` for single perturbations. - Combinatorial perturbations are not currently supported. 
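To make the accepted condition formats concrete, here is a small standalone sketch of the same classification rules (illustrative only, not part of this diff; the helper name and `control_name="non-targeting"` are assumptions):

```python
# Illustrative sketch of the condition formats accepted by _validate.
def classify_condition(condition: str, control_name: str, targets: set[str]) -> str:
    if condition in targets:
        return "perturbation"  # e.g. "gene_A"
    if condition == control_name:
        return "unmatched control"  # e.g. "non-targeting"
    prefix = f"{control_name}_"
    if condition.startswith(prefix) and condition[len(prefix):] in targets:
        return "matched control"  # e.g. "non-targeting_gene_A"
    raise ValueError(f"Invalid condition format: {condition}")


targets = {"gene_A", "gene_B"}
assert classify_condition("gene_A", "non-targeting", targets) == "perturbation"
assert classify_condition("non-targeting", "non-targeting", targets) == "unmatched control"
assert classify_condition("non-targeting_gene_B", "non-targeting", targets) == "matched control"
```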
@@ -423,26 +427,43 @@ def _validate(self) -> None: """ super()._validate() - # Validate condition format - conditions = set(self.control_matched_adata.obs[self.condition_key]) - target_conditions = set( - x.split("_")[1] for x in self.target_conditions_to_save.keys() - ) # Update for multiple perturbations + if self.condition_key not in self.adata.obs.columns: + raise ValueError( + f"Condition key '{self.condition_key}' not found in adata.obs" + ) + if self.condition_key not in self.control_matched_adata.obs.columns: + raise ValueError( + f"Condition key '{self.condition_key}' not found in control_matched_adata.obs" + ) - for condition in conditions: + # Validate matched_adata condition format by checking the ORIGINAL conditions before processing + original_conditions = set(self.control_matched_adata.obs[self.condition_key]) + target_conditions = set(self.target_conditions_dict.keys()) + for condition in original_conditions: + # Check if it's a valid perturbation condition (just the perturbation name) if condition in target_conditions: continue - elif condition.startswith(self.control_name): - control_matched_condition = condition.split("_")[1] - if control_matched_condition not in target_conditions: + # Check if it's a control condition: just control_name + elif condition == self.control_name: + continue + # Check if it's a matched control condition: control_name_perturb + elif condition.startswith(f"{self.control_name}_"): + # Extract the perturbation part after control_name_ + perturb_part = condition[len(f"{self.control_name}_") :] + if perturb_part in target_conditions: + continue + else: raise ValueError( - f"Invalid control matched condition format: {condition}. " - f"Must be ``{self.control_name}`` or ``{self.control_name}_{{perturb}}``" + f"Invalid matched control condition format: {condition}. " + f"The perturbation part '{perturb_part}' is not a valid target condition. " + f"Valid target conditions: {list(target_conditions)}" ) else: - # Update for multiple perturbations + # Invalid condition format raise ValueError( - f"Invalid perturbation condition format: {condition}. " - f"Must be ``{self.control_name}`` or ``{self.control_name}_{{perturb}}`` for control samples," - "or ``{perturb}`` for perturbations." + f"Invalid condition format: {condition}. 
" + f"Must be one of:\n" + f" - ``{self.control_name}`` for unmatched control samples\n" + f" - ``{self.control_name}_{{perturb}}`` for matched control samples\n" + f" - ``{{perturb}}`` for single perturbations (where perturb is one of {list(target_conditions)})" ) diff --git a/src/czbenchmarks/tasks/single_cell/perturbation_expression_prediction.py b/src/czbenchmarks/tasks/single_cell/perturbation_expression_prediction.py index f98c2fad..e1e3a3f7 100644 --- a/src/czbenchmarks/tasks/single_cell/perturbation_expression_prediction.py +++ b/src/czbenchmarks/tasks/single_cell/perturbation_expression_prediction.py @@ -1,18 +1,19 @@ -import json import logging from pathlib import Path from typing import Dict, List, Literal +import anndata as ad import numpy as np import pandas as pd from scipy import sparse as sp_sparse +import json from ...constants import RANDOM_SEED from ...metrics import metrics_registry from ...metrics.types import MetricResult, MetricType from ...tasks.types import CellRepresentation from ..task import Task, TaskInput, TaskOutput -from ..utils import binarize_values +from ..utils import binarize_values, looks_like_lognorm logger = logging.getLogger(__name__) @@ -20,11 +21,11 @@ class PerturbationExpressionPredictionTaskInput(TaskInput): """Pydantic model for PerturbationTask inputs.""" + adata: ad.AnnData + target_conditions_dict: Dict[str, List[str]] de_results: pd.DataFrame - masked_adata_obs: pd.DataFrame - var_index: pd.Index - target_conditions_to_save: Dict[str, List[str]] - row_index: pd.Index + gene_index: pd.Index + cell_index: pd.Index def load_perturbation_task_input_from_saved_files( @@ -46,29 +47,29 @@ def load_perturbation_task_input_from_saved_files( inputs_dir = Path(task_inputs_dir) - # Load DE results - de_results_path = inputs_dir / "de_results.json" - de_results = pd.read_json(de_results_path) - - # Load target conditions to save - target_genes_path = inputs_dir / "target_conditions_to_save.json" - with target_genes_path.open("r") as f: - target_conditions_to_save = json.load(f) - - # Rebuild AnnData obs and var - adata_dir = inputs_dir / "control_matched_adata" - obs = pd.read_json(adata_dir / "obs.json", orient="split") - var = pd.read_json(adata_dir / "var.json", orient="split") - row_index = pd.Index( - np.load(inputs_dir / "original_adata/obs/index.npy", allow_pickle=True) - ) + # Load DE results from parquet + de_results_path = inputs_dir / "de_results.parquet" + de_results = pd.read_parquet(de_results_path) + + # Load target conditions dict + target_conditions_path = inputs_dir / "target_conditions_dict.json" + with target_conditions_path.open("r") as f: + target_conditions_dict = json.load(f) + + # Load the main AnnData object + adata_path = inputs_dir / "control_matched_adata.h5ad" + adata = ad.read_h5ad(adata_path) + + # Extract gene_index and cell_index + gene_index = adata.var.index + cell_index = pd.Index(adata.uns["cell_barcode_condition_index"]) return PerturbationExpressionPredictionTaskInput( + adata=adata, + target_conditions_dict=target_conditions_dict, de_results=de_results, - masked_adata_obs=obs, - var_index=var.index, - target_conditions_to_save=target_conditions_to_save, - row_index=row_index, + gene_index=gene_index, + cell_index=cell_index, ) @@ -86,24 +87,22 @@ class PerturbationExpressionPredictionTask(Task): def __init__( self, - metric: str = "wilcoxon", - control_prefix: str = "non-targeting", + condition_key: str = "condition", + control_name: str = "non-targeting", *, random_seed: int = RANDOM_SEED, ): """ Args: - 
control_prefix (str): Prefix for control conditions. + condition_key (str): Key for the column in `adata.obs` specifying + conditions. Defaults to "condition". + control_name (str): Prefix for control conditions. Defaults to "non-targeting". random_seed (int): Random seed for reproducibility. """ super().__init__(random_seed=random_seed) - if metric == "wilcoxon": - self.metric_column = "logfoldchange" - elif metric == "t-test": - self.metric_column = "standardized_mean_diff" - else: - raise ValueError(f"Metric {metric} not supported") - self.control_prefix = control_prefix + self.metric_column = "logfoldchange" # TODO: logfoldchange only for now + self.control_name = control_name + self.condition_key = condition_key def _run_task( self, @@ -113,78 +112,99 @@ def _run_task( """ Runs the perturbation evaluation task. - This method computes predicted and ground truth log fold changes for each perturbation - condition in the dataset, using the provided cell representations and differential - expression results. It aligns predictions and ground truth values for masked genes, - and prepares data for downstream metric computation. - Args: - cell_representation (CellRepresentation): A numpy matrix of shape (n_cells, n_genes) - task_input (PerturbationExpressionPredictionTaskInput): Input object containing: - - de_results (pd.DataFrame): DataFrame with differential expression results, - including log fold changes/standard mean deviation and gene names. + cell_representation: Cell expression matrix of shape (n_cells, n_genes) + task_input: Task input containing AnnData with all necessary data Returns: - PerturbationExpressionPredictionOutput: Output object containing dictionaries of predicted and true log fold changes - for each perturbation condition. + PerturbationExpressionPredictionOutput: Predicted and true log fold changes """ - + self._validate(task_input, cell_representation) pred_log_fc_dict = {} true_log_fc_dict = {} + adata = task_input.adata + + # Extract data from AnnData + obs = adata.obs de_results = task_input.de_results + target_conditions_dict = task_input.target_conditions_dict - condition_series = task_input.masked_adata_obs["condition"].astype(str) - condition_list = np.unique( - condition_series[~condition_series.str.startswith(self.control_prefix)] + # Get perturbation conditions (non-control) + conditions = obs["condition"].astype(str) + perturbation_conditions = np.unique( + conditions[~conditions.str.startswith(self.control_name)] ) - row_index = task_input.row_index.str.split("_").str[0] - - for condition in condition_list: - condition_de_df = de_results[de_results["condition"] == condition] - masked_genes = np.array( - task_input.target_conditions_to_save[ - task_input.masked_adata_obs.index[ - task_input.masked_adata_obs["condition"] == condition - ][0] - ] - ) - # Filter masked_genes to only those present in var.index - masked_genes = np.array( - [g for g in masked_genes if g in task_input.var_index] - ) + # Extract base cell IDs for matching + base_cell_ids = task_input.cell_index.str.split("_").str[0] - if len(masked_genes) == 0: - print("Skipping condition because it has no masked genes.") + for condition in perturbation_conditions: + # Get target genes for this condition + target_genes = target_conditions_dict.get(condition, []) + valid_genes = [g for g in target_genes if g in task_input.gene_index] + if not valid_genes: + logger.warning( + "Skipping condition %s - no valid target genes", condition + ) continue - true_log_fc = ( - 
condition_de_df.set_index("gene_id")
-                .reindex(masked_genes)[self.metric_column]
+            # For each condition, compute one true and one predicted log fold
+            # change per masked gene: true values come from the DE results,
+            # predicted values from the cell representation.
+            # Get the true DE results for this condition
+            condition_de = de_results[de_results[self.condition_key] == condition]
+
+            # Get true log fold changes from DE results
+            true_lfc = (
+                condition_de.set_index("gene_id")
+                .reindex(valid_genes)[self.metric_column]
                 .values
             )
-            valid = ~np.isnan(true_log_fc)
-            masked_genes = masked_genes[valid]
-            true_log_fc = true_log_fc[valid]
-            col_indices = task_input.var_index.get_indexer(masked_genes)
-            condition_adata = task_input.masked_adata_obs[
-                task_input.masked_adata_obs["condition"] == condition
-            ].index
-            condition_col_ids = condition_adata.to_series().str.split("_").str[0]
-            condition_idx = np.where(row_index.isin(condition_col_ids))[0]
-            control_adata = task_input.masked_adata_obs[
-                task_input.masked_adata_obs["condition"]
-                == f"{self.control_prefix}_{condition}"
-            ].index
-            control_col_ids = control_adata.to_series().str.split("_").str[0]
-
-            control_idx = np.where(row_index.isin(control_col_ids))[0]
-            condition_vals = cell_representation[np.ix_(condition_idx, col_indices)]
-            control_vals = cell_representation[np.ix_(control_idx, col_indices)]
-            ctrl_mean = np.mean(control_vals, axis=0)
-            cond_mean = np.mean(condition_vals, axis=0)
-            pred_log_fc = cond_mean - ctrl_mean
-            pred_log_fc_dict[condition] = pred_log_fc
-            true_log_fc_dict[condition] = true_log_fc
+            # Mask out genes with NaN true log fold change values
+            valid_mask = ~np.isnan(true_lfc)
+            n_filtered = (~valid_mask).sum()
+            if n_filtered:
+                logger.warning(
+                    f"Filtered out {n_filtered} NaN true log fold changes for {condition}"
+                )
+            # Only keep genes with valid (non-NaN) true log fold change values
+            final_genes = np.array(valid_genes)[valid_mask]
+            true_lfc = true_lfc[valid_mask]
+
+            # If no valid genes remain for this condition, skip to next
+            if len(final_genes) == 0:
+                logger.warning(
+                    f"Skipping condition {condition} - no valid genes remain after filtering"
+                )
+                continue
+
+            # Get indices of the valid genes in task_input.gene_index for slicing the cell_representation matrix
+            gene_indices = task_input.gene_index.get_indexer(final_genes)
+            # Find cell barcodes for the current perturbation condition
+            # This extracts the base cell IDs (before the underscore) for all cells in the current condition
+            condition_cells = (
+                obs[obs[self.condition_key] == condition].index.str.split("_").str[0]
+            )
+            # Find cell barcodes for the corresponding control cells
+            # Control cells are expected to have a condition label of the form
+            # f"{control_name}_{condition}"
+            control_cells = (
+                obs[obs[self.condition_key] == f"{self.control_name}_{condition}"]
+                .index.str.split("_")
+                .str[0]
+            )
+            # Get indices of the condition and control cells in the cell_representation matrix
+            condition_idx = np.where(base_cell_ids.isin(condition_cells))[0]
+            control_idx = np.where(base_cell_ids.isin(control_cells))[0]
+            # Compute predicted log fold change for each gene:
+            # - Take the mean expression of each gene across all condition cells
+            # - Subtract the mean expression of the same gene across all control cells
+            pred_lfc = cell_representation[np.ix_(condition_idx, gene_indices)].mean(
+                axis=0
+            ) - 
cell_representation[np.ix_(control_idx, gene_indices)].mean(axis=0)
+            # Store the predicted and true log fold changes for this condition
+            pred_log_fc_dict[condition] = pred_lfc
+            true_log_fc_dict[condition] = true_lfc
+
         return PerturbationExpressionPredictionOutput(
             pred_log_fc_dict=pred_log_fc_dict,
             true_log_fc_dict=true_log_fc_dict,
@@ -340,3 +360,35 @@

         # Store the baseline prediction in the dataset for evaluation
         return perturb_baseline_pred
+
+    def _validate(
+        self,
+        task_input: PerturbationExpressionPredictionTaskInput,
+        cell_representation: CellRepresentation,
+    ) -> None:
+        if not looks_like_lognorm(cell_representation):
+            raise ValueError(
+                "Cell representation appears to contain non-log-normalized data. Please provide a log-normalized cell representation."
+            )
+
+        if "cell_barcode_condition_index" not in task_input.adata.uns:
+            raise ValueError("Task input contains no cell barcode index.")
+        # Check that the model's gene and cell indices are subsets of the task input before aligning
+        if not set(task_input.gene_index).issubset(set(task_input.adata.var.index)):
+            raise ValueError(
+                "Model data contains genes that are not in the task input."
+            )
+        if not set(task_input.cell_index).issubset(
+            set(task_input.adata.uns["cell_barcode_condition_index"])
+        ):
+            raise ValueError(
+                "Model data contains cells that are not in the task input."
+            )
+
+        if set(task_input.gene_index) != set(task_input.adata.var.index):
+            logger.warning("Task input contains genes that are not in the model input.")
+
+        if set(task_input.cell_index) != set(
+            task_input.adata.uns["cell_barcode_condition_index"]
+        ):
+            logger.warning("Task input contains cells that are not in the model input.")
diff --git a/src/czbenchmarks/tasks/utils.py b/src/czbenchmarks/tasks/utils.py
index 196009ae..9112eec5 100644
--- a/src/czbenchmarks/tasks/utils.py
+++ b/src/czbenchmarks/tasks/utils.py
@@ -313,3 +313,31 @@
     sc.pp.pca(adata, n_comps=n_pcs, key_added=obsm_key, random_state=random_state)

     return adata.obsm[obsm_key]
+
+
+def looks_like_lognorm(
+    matrix: CellRepresentation,
+    sample_size: int | float = 1_000,
+    tol: float = 1e-2,
+) -> bool:
+    """
+    Guess if a matrix contains log-normalized (non-integer) values by inspecting random cell sums.
+
+    This function randomly picks a subset of rows (cells), sums their values, and checks if any
+    of those sums are not close to integers, which would indicate the data is not raw counts.
+
+    Args:
+        matrix: Expression matrix (cells x genes).
+        sample_size: How many cells to check (default: 1,000, or all if fewer).
+        tol: Allowed deviation from integer for sum to be considered integer-like.
+
+    Returns:
+        bool: True if at least one sampled cell sum is non-integer (suggesting log-normalized data).
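+
+    Example (illustrative; raw integer counts sum to integers, log1p data do not):
+
+        >>> import numpy as np
+        >>> counts = np.random.default_rng(0).poisson(2.0, size=(100, 50)).astype(float)
+        >>> looks_like_lognorm(counts)
+        False
+        >>> looks_like_lognorm(np.log1p(counts))
+        True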
+ """ + total_cells = matrix.shape[0] + n = int(min(sample_size, total_cells)) + indices = np.random.choice(total_cells, n, replace=False) + row_totals = matrix[indices].sum(axis=1) + if np.any(np.abs(row_totals - np.round(row_totals)) > tol): + return True + return False diff --git a/tests/conftest.py b/tests/conftest.py index 78d8622d..97c7e8bf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,35 +1,77 @@ import pytest +import pandas as pd +import anndata as ad +from czbenchmarks.tasks.types import CellRepresentation +from czbenchmarks.datasets.types import Organism +from tests.utils import create_dummy_anndata -def pytest_addoption(parser): - parser.addoption( - "--tolerance-percent", - type=float, - default=0.2, - help="Percentage tolerance for metric comparison (default: 0.2 = 20%)", - ) - parser.addoption( - "--run-model-tests", - action="store_true", - default=False, - help="Run model regression tests", +@pytest.fixture +def dummy_anndata(): + n_cells: int = 500 + n_genes: int = 200 + organism: Organism = Organism.HUMAN + obs_columns: list[str] = ["cell_type", "batch"] + var_columns: list[str] = ["feature_name"] + anndata: ad.AnnData = create_dummy_anndata( + n_cells=n_cells, + n_genes=n_genes, + organism=organism, + obs_columns=obs_columns, + var_columns=var_columns, ) + expression_matrix: CellRepresentation = anndata.X.copy() + obs: pd.DataFrame = anndata.obs.copy() + var: pd.DataFrame = anndata.var.copy() -def pytest_configure(config): - # Register custom markers - config.addinivalue_line( - "markers", - "integration: marks tests as integration tests (deselect with '-m \"not integration\"')", - ) - pytest.run_model_tests = config.getoption("--run-model-tests") + # TODO perform PCA on expression matrix to get true embedding + embedding_matrix: CellRepresentation = expression_matrix.toarray() + + return { + "anndata": anndata, + "expression_matrix": expression_matrix, + "obs": obs, + "var": var, + "embedding_matrix": embedding_matrix, + } + + +@pytest.fixture +def expression_matrix(dummy_anndata): + return dummy_anndata["expression_matrix"] + + +@pytest.fixture +def embedding_matrix(dummy_anndata): + return dummy_anndata["embedding_matrix"] + + +@pytest.fixture +def obs(dummy_anndata): + return dummy_anndata["obs"] @pytest.fixture -def tolerance_percent(request): - return request.config.getoption("--tolerance-percent", default=0.2) +def var(dummy_anndata): + return dummy_anndata["var"] @pytest.fixture -def run_model_tests(request): - return request.config.getoption("--run-model-tests", default=False) +def fixture_data(request): + # Enables lazy generation of fixture data so fixtures can be used as + # parameters + valid_fixture_names = ["expression_matrix", "embedding_matrix", "obs", "var"] + fixture_name, other_data = request.param + if isinstance(fixture_name, str): + fixture_data = ( + request.getfixturevalue(fixture_name) + if fixture_name in valid_fixture_names + else fixture_name + ) + else: + fixture_data = [ + request.getfixturevalue(f) if f in valid_fixture_names else f + for f in fixture_name + ] + return fixture_data, other_data diff --git a/tests/datasets/test_single_cell_perturbation_dataset.py b/tests/datasets/test_single_cell_perturbation_dataset.py index c7f16c5c..e0c4bebf 100644 --- a/tests/datasets/test_single_cell_perturbation_dataset.py +++ b/tests/datasets/test_single_cell_perturbation_dataset.py @@ -1,9 +1,15 @@ +import json +import numpy as np from pathlib import Path import pandas as pd import pytest +import anndata as ad from 
czbenchmarks.datasets.single_cell_perturbation import SingleCellPerturbationDataset from czbenchmarks.datasets.types import Organism +from czbenchmarks.tasks.single_cell.perturbation_expression_prediction import ( + PerturbationExpressionPredictionTaskInput, +) from tests.datasets.test_single_cell_dataset import SingleCellDatasetTests from tests.utils import create_dummy_anndata @@ -18,7 +24,7 @@ def valid_dataset(self, tmp_path) -> SingleCellPerturbationDataset: path=self.valid_dataset_file(tmp_path), organism=Organism.HUMAN, condition_key="condition", - control_name="ctrl", + control_name="non-targeting", percent_genes_to_mask=0.5, min_de_genes_to_mask=5, pval_threshold=1e-4, @@ -35,8 +41,8 @@ def valid_dataset_file(self, tmp_path) -> Path: organism=Organism.HUMAN, ) adata.obs["condition"] = [ - "ctrl", - "ctrl", + "non-targeting", + "non-targeting", "test1", "test1", "test2", @@ -44,8 +50,8 @@ def valid_dataset_file(self, tmp_path) -> Path: ] # Set indices so that splitting on '_' and taking token [1] yields the condition adata.obs_names = [ - "ctrl_test1_a", # control cell 1 - "ctrl_test2_b", # control cell 2 + "non-targeting_test1_a", # control cell 1 + "non-targeting_test2_b", # control cell 2 "cond_test1_a", "cond_test1_b", "cond_test2_a", @@ -55,12 +61,12 @@ def valid_dataset_file(self, tmp_path) -> Path: # treatment cell ids -> matched control cell id adata.uns["control_cells_ids"] = { "test1": { - "cond_test1_a": "ctrl_test1_a", - "cond_test1_b": "ctrl_test2_b", + "cell_2": "cell_0", + "cell_3": "cell_1", }, "test2": { - "cond_test2_a": "ctrl_test1_a", - "cond_test2_b": "ctrl_test2_b", + "cell_4": "cell_0", + "cell_5": "cell_1", }, } # Provide sufficient DE results to pass internal filtering and sampling @@ -103,8 +109,8 @@ def perturbation_invalid_condition_h5ad(self, tmp_path) -> Path: organism=Organism.HUMAN, ) adata.obs["condition"] = [ - "BADctrl", - "BADctrl", + "BADnon-targeting", + "BADnon-targeting", "test1", "test1", "test2", @@ -114,14 +120,8 @@ def perturbation_invalid_condition_h5ad(self, tmp_path) -> Path: # Map treatment cells to control cells using default obs_names from create_dummy_anndata # BADctrl cells correspond to cell_0 and cell_1; test1 -> cell_2, cell_3; test2 -> cell_4, cell_5 adata.uns["control_cells_ids"] = { - "test1": { - "cell_2": "cell_0", - "cell_3": "cell_1", - }, - "test2": { - "cell_4": "cell_0", - "cell_5": "cell_1", - }, + "test1": ["non-targeting_cell_0", "non-targeting_cell_1"], + "test2": ["non-targeting_cell_0", "non-targeting_cell_1"], } de_conditions = ["test1"] * 10 + ["test2"] * 10 de_genes = [f"ENSG000000000{str(i).zfill(2)}" for i in range(20)] @@ -148,12 +148,12 @@ def test_perturbation_dataset_load_data( pval_threshold, ): """Tests the loading of perturbation dataset data across parameter combinations.""" - + condition_key = "condition" dataset = SingleCellPerturbationDataset( path=self.valid_dataset_file(tmp_path), organism=Organism.HUMAN, - condition_key="condition", - control_name="ctrl", + condition_key=condition_key, + control_name="non-targeting", percent_genes_to_mask=percent_genes_to_mask, min_de_genes_to_mask=min_de_genes_to_mask, pval_threshold=pval_threshold, @@ -166,18 +166,28 @@ def test_perturbation_dataset_load_data( # Expect 2 conditions (test1, test2), each with 2 perturbed + 2 control cells -> 8 total assert dataset.control_matched_adata.shape == (8, 3) # Target genes should be stored per cell (for each unique cell index) - assert hasattr(dataset, "target_conditions_to_save") - unique_obs_count = 
len(set(dataset.control_matched_adata.obs.index.tolist())) - assert len(dataset.target_conditions_to_save) == unique_obs_count + assert hasattr(dataset, "target_conditions_dict") + unique_condition_count = len( + np.unique( + dataset.control_matched_adata.obs[condition_key][ + ~dataset.control_matched_adata.obs[condition_key].str.startswith( + "non-targeting" + ) + ] + ) + ) + + assert len(dataset.target_conditions_dict) == unique_condition_count # With 10 DE genes per condition in fixtures expected_sampled = int(10 * percent_genes_to_mask) - sampled_lengths = {len(v) for v in dataset.target_conditions_to_save.values()} + sampled_lengths = {len(v) for v in dataset.target_conditions_dict.values()} assert sampled_lengths == {expected_sampled} def test_perturbation_dataset_load_data_missing_condition_key( self, perturbation_missing_condition_column_h5ad, ): + condition_key = "condition" """Tests that loading data fails when the condition column is missing.""" invalid_dataset = SingleCellPerturbationDataset( perturbation_missing_condition_column_h5ad, @@ -190,17 +200,19 @@ def test_perturbation_dataset_load_data_missing_condition_key( ) with pytest.raises( - ValueError, match="Condition key 'condition' not found in adata.obs" + ValueError, match=f"Condition key '{condition_key}' not found in adata.obs" ): invalid_dataset.load_data() - def test_perturbation_dataset_validate_invalid_condition( + def test_perturbation_dataset_store_task_inputs( self, - perturbation_invalid_condition_h5ad, + tmp_path, ): - """Test that validation fails with invalid condition format.""" + """Tests that the store_task_inputs method writes expected separate files.""" + condition_key = "condition" + dataset = SingleCellPerturbationDataset( - perturbation_invalid_condition_h5ad, + path=self.valid_dataset_file(tmp_path), organism=Organism.HUMAN, condition_key="condition", percent_genes_to_mask=0.5, @@ -209,41 +221,98 @@ def test_perturbation_dataset_validate_invalid_condition( min_logfoldchange=1.0, ) dataset.load_data() - with pytest.raises(ValueError): - dataset.validate() - def test_perturbation_dataset_store_task_inputs( - self, - tmp_path, - ): - """Tests that the store_task_inputs method writes expected files.""" + task_inputs_dir = dataset.store_task_inputs() + assert task_inputs_dir.exists() + assert task_inputs_dir.is_dir() + + # Check that all required files exist + expected_files = [ + "control_matched_adata.h5ad", + "target_conditions_dict.json", + "de_results.parquet", + ] + + for filename in expected_files: + filepath = task_inputs_dir / filename + assert filepath.exists(), f"Expected file {filename} not found" + + # Load and validate the main AnnData file + + task_adata = ad.read_h5ad(task_inputs_dir / "control_matched_adata.h5ad") + assert isinstance(task_adata, ad.AnnData) + + # Load and validate JSON files + with open(task_inputs_dir / "target_conditions_dict.json", "r") as f: + target_conditions_dict = json.load(f) + assert isinstance(target_conditions_dict, dict) + + # Load and validate DE results Parquet (should only have optimized columns) + de_df = pd.read_parquet( + task_inputs_dir / "de_results.parquet", engine="pyarrow" + ) + assert not de_df.empty + # Only the necessary columns should be present + expected_cols = {condition_key, "gene_id"} + expected_cols.add("logfoldchange") + assert set(de_df.columns) == expected_cols + + # Load and validate cell barcode index + cell_barcode_condition_index = task_adata.uns["cell_barcode_condition_index"] + assert isinstance(cell_barcode_condition_index, 
np.ndarray) + assert len(cell_barcode_condition_index) == dataset.adata.shape[0] + + def test_control_matched_adata_contains_task_data(self, tmp_path): + """Test that control_matched_adata contains all required task data in uns.""" dataset = SingleCellPerturbationDataset( path=self.valid_dataset_file(tmp_path), organism=Organism.HUMAN, condition_key="condition", - control_name="ctrl", + control_name="non-targeting", percent_genes_to_mask=0.5, - min_de_genes_to_mask=5, + min_de_genes_to_mask=2, pval_threshold=1e-4, min_logfoldchange=1.0, ) dataset.load_data() - out_dir = dataset.store_task_inputs() - control_file = out_dir / "control_cells_ids.json" - target_conditions_file = out_dir / "target_conditions_to_save.json" - de_results_file = out_dir / "de_results.json" + # Verify that control_matched_adata exists and has the required keys in uns + assert hasattr(dataset, "control_matched_adata") + assert dataset.control_matched_adata is not None + + required_uns_keys = { + "target_conditions_dict", + "de_results", + "cell_barcode_condition_index", + "control_cells_ids", + } + + actual_uns_keys = set(dataset.control_matched_adata.uns.keys()) + assert required_uns_keys.issubset(actual_uns_keys), ( + f"Missing required keys in control_matched_adata.uns. " + f"Required: {required_uns_keys}, Found: {actual_uns_keys}" + ) + + # Verify the contents of each key + uns = dataset.control_matched_adata.uns + + # Check target_conditions_dict + assert isinstance(uns["target_conditions_dict"], dict) + assert len(uns["target_conditions_dict"]) > 0 + assert uns["target_conditions_dict"] == dataset.target_conditions_dict - assert control_file.exists() - assert target_conditions_file.exists() - assert de_results_file.exists() + # Check control_cells_ids + assert isinstance(uns["control_cells_ids"], dict) + assert len(uns["control_cells_ids"]) > 0 + assert uns["control_cells_ids"] == dataset.control_cells_ids - # Validate that DE results JSON is readable and has expected columns - de_df = pd.read_json(de_results_file) + # Check de_results can be reconstructed as DataFrame (should only have optimized columns) + assert isinstance(uns["de_results"], dict) + de_df = pd.DataFrame(uns["de_results"]) assert not de_df.empty - base_cols = {"condition", "gene", "pval_adj"} - assert base_cols.issubset(set(de_df.columns)) - assert "logfoldchange" in de_df.columns + # The optimized DE results should only contain the necessary columns for the task + expected_cols = {"condition", "gene_id", "logfoldchange"} + assert set(de_df.columns) == expected_cols @pytest.mark.parametrize("percent_genes_to_mask", [0.5, 1.0]) @pytest.mark.parametrize("min_de_genes_to_mask", [1, 5]) @@ -257,7 +326,7 @@ def test_perturbation_dataset_load_de_results_from_csv( ): """Tests loading DE results from an external CSV via de_results_path.""" # Create the base AnnData file using existing helper to ensure obs/uns layout - h5ad_path = self.valid_dataset_file(tmp_path) + self.valid_dataset_file(tmp_path) # Create a DE results CSV with required columns for both tests # Include two conditions that match the AnnData: test1 and test2, 10 genes each @@ -276,27 +345,56 @@ def test_perturbation_dataset_load_de_results_from_csv( # Construct dataset pointing to the CSV and with permissive thresholds dataset = SingleCellPerturbationDataset( - path=h5ad_path, + path=self.valid_dataset_file(tmp_path), organism=Organism.HUMAN, condition_key="condition", - control_name="ctrl", + control_name="non-targeting", percent_genes_to_mask=percent_genes_to_mask, 
min_de_genes_to_mask=min_de_genes_to_mask, pval_threshold=pval_threshold, min_logfoldchange=0.0, ) - dataset.load_data() - # Expect 2 conditions (test1, test2), each with 2 perturbed + 2 control cells -> 8 total - assert dataset.control_matched_adata.shape == (8, 3) + # This should work without any errors since control_matched_adata contains all required data + task_input = PerturbationExpressionPredictionTaskInput( + adata=dataset.control_matched_adata, + target_conditions_dict=dataset.target_conditions_dict, + de_results=dataset.de_results, + gene_index=dataset.control_matched_adata.var.index, + cell_index=pd.Index( + dataset.control_matched_adata.uns["cell_barcode_condition_index"] + ), + ) - # Target genes should be stored per cell (for each unique cell index) - assert hasattr(dataset, "target_conditions_to_save") - unique_obs_count = len(set(dataset.control_matched_adata.obs.index.tolist())) - assert len(dataset.target_conditions_to_save) == unique_obs_count + # Verify the task input was created successfully + assert task_input is not None + assert hasattr(task_input, "adata") + assert hasattr(task_input, "target_conditions_dict") + assert hasattr(task_input, "de_results") + assert task_input.adata is not None - # With 10 genes per condition and percent as parameter - expected_sampled = int(10 * percent_genes_to_mask) - sampled_lengths = {len(v) for v in dataset.target_conditions_to_save.values()} - assert sampled_lengths == {expected_sampled} + # Verify that required data in uns is accessible + required_uns_keys = { + "cell_barcode_condition_index", + "control_cells_ids", + } + actual_uns_keys = set(task_input.adata.uns.keys()) + assert required_uns_keys.issubset(actual_uns_keys) + + # Verify data integrity - the data should match the original dataset + assert ( + task_input.adata.uns["target_conditions_dict"] + == dataset.target_conditions_dict + ) + assert task_input.adata.uns["control_cells_ids"] == dataset.control_cells_ids + + # Check that DE results can be reconstructed + de_df = pd.DataFrame(task_input.adata.uns["de_results"]) + assert len(de_df) == len(dataset.de_results) + + # Check cell barcode index + np.testing.assert_array_equal( + task_input.adata.uns["cell_barcode_condition_index"], + dataset.adata.obs.index.astype(str).values, + ) diff --git a/tests/tasks/test_cross_species_integration.py b/tests/tasks/test_cross_species_integration.py new file mode 100644 index 00000000..f5f50049 --- /dev/null +++ b/tests/tasks/test_cross_species_integration.py @@ -0,0 +1,37 @@ +import pytest +from czbenchmarks.tasks.single_cell import ( + CrossSpeciesIntegrationTask, + CrossSpeciesIntegrationTaskInput, +) +from czbenchmarks.datasets.types import Organism +from czbenchmarks.metrics.types import MetricResult + + +def test_cross_species_task(embedding_matrix, obs): + """Test that CrossSpeciesIntegrationTask executes without errors.""" + task = CrossSpeciesIntegrationTask() + embedding_list = [embedding_matrix, embedding_matrix] + labels = obs["cell_type"] + labels_list = [labels, labels] + organism_list = [Organism.HUMAN, Organism.MOUSE] + task_input = CrossSpeciesIntegrationTaskInput( + labels=labels_list, organism_list=organism_list + ) + + try: + # Test regular task execution + results = task.run( + cell_representation=embedding_list, + task_input=task_input, + ) + + # Verify results structure + assert isinstance(results, list) + assert all(isinstance(r, MetricResult) for r in results) + + # Test that baseline raises NotImplementedError + with pytest.raises(NotImplementedError): + 
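+                # The cross-species task does not implement compute_baseline,
+                # so this call is expected to raise NotImplementedError.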
task.compute_baseline() + + except Exception as e: + pytest.fail(f"CrossSpeciesIntegrationTask failed unexpectedly: {e}") diff --git a/tests/tasks/test_general_tasks.py b/tests/tasks/test_general_tasks.py new file mode 100644 index 00000000..ebcff582 --- /dev/null +++ b/tests/tasks/test_general_tasks.py @@ -0,0 +1,144 @@ +import pytest +from czbenchmarks.tasks import ( + ClusteringTask, + ClusteringTaskInput, + EmbeddingTask, + EmbeddingTaskInput, + BatchIntegrationTask, + BatchIntegrationTaskInput, + MetadataLabelPredictionTask, + MetadataLabelPredictionTaskInput, +) +from czbenchmarks.metrics.types import MetricResult +from tests.utils import DummyTask, DummyTaskInput + + +@pytest.mark.parametrize( + "fixture_data", + [ + ("expression_matrix", False), + (["expression_matrix", "expression_matrix"], True), + ], + indirect=True, +) +def test_embedding_valid_input_output(fixture_data): + """Test that embedding is accepted and List[MetricResult] is returned.""" + embedding, requires_multiple_datasets = fixture_data + task = DummyTask(requires_multiple_datasets=requires_multiple_datasets) + results = task.run( + cell_representation=embedding, + task_input=DummyTaskInput(), + ) + + assert isinstance(results, list) + assert all(isinstance(r, MetricResult) for r in results) + + +@pytest.mark.parametrize( + "fixture_data", + [ + ( + "abcd", + [False, "This task requires a single cell representation for input"], + ), + ( + ["embedding_matrix"], + [False, "This task requires a single cell representation for input"], + ), + ( + ["embedding_matrix", "embedding_matrix"], + [False, "This task requires a single cell representation for input"], + ), + ( + "embedding_matrix", + [True, "This task requires a list of cell representations"], + ), + ( + ["abcd", "embedding_matrix"], + [True, "This task requires a list of cell representations"], + ), + ( + ["embedding_matrix"], + [ + True, + "This task requires a list of cell representations", + ], + ), + ], + indirect=True, +) +def test_embedding_invalid_input(fixture_data): + """Test ValueError for mismatch with requires_multiple_datasets.""" + embedding_list, (requires_multiple_datasets, error_message) = fixture_data + task = DummyTask(requires_multiple_datasets=requires_multiple_datasets) + with pytest.raises(ValueError, match=error_message): + task.run( + cell_representation=embedding_list, + task_input=DummyTaskInput(), + ) + + +@pytest.mark.parametrize( + "task_class,task_input_builder", + [ + ( + ClusteringTask, + lambda obs: ClusteringTaskInput(obs=obs, input_labels=obs["cell_type"]), + ), + ( + EmbeddingTask, + lambda obs: EmbeddingTaskInput(input_labels=obs["cell_type"]), + ), + ( + BatchIntegrationTask, + lambda obs: BatchIntegrationTaskInput( + labels=obs["cell_type"], batch_labels=obs["batch"] + ), + ), + ( + MetadataLabelPredictionTask, + lambda obs: MetadataLabelPredictionTaskInput(labels=obs["cell_type"]), + ), + ], +) +def test_task_execution( + task_class, + task_input_builder, + embedding_matrix, + expression_matrix, + obs, +): + """Test that each task executes without errors on compatible data.""" + + task_input = task_input_builder(obs) + + task = task_class() + + try: + # Test regular task execution + results = task.run( + cell_representation=embedding_matrix, + task_input=task_input, + ) + assert isinstance(results, list) + assert all(isinstance(r, MetricResult) for r in results) + + # Test baseline execution if implemented + try: + n_pcs = min(50, expression_matrix.shape[1] - 1) + baseline_embedding = 
task.compute_baseline(expression_matrix, n_pcs=n_pcs) + if hasattr(task_input, "var"): + task_input.var = task_input.var.iloc[:n_pcs] + + baseline_results = task.run( + cell_representation=baseline_embedding, + task_input=task_input, + ) + assert isinstance(baseline_results, list) + assert all(isinstance(r, MetricResult) for r in baseline_results) + except NotImplementedError: + # Some tasks may not implement compute_baseline + pass + + except Exception as e: + pytest.fail(f"Task {task_class.__name__} failed unexpectedly: {e}") diff --git a/tests/tasks/test_perturbation_expression_prediction.py b/tests/tasks/test_perturbation_expression_prediction.py new file mode 100644 index 00000000..852d21fe --- /dev/null +++ b/tests/tasks/test_perturbation_expression_prediction.py @@ -0,0 +1,686 @@ +import pytest +import pandas as pd +import numpy as np +import tempfile +from pathlib import Path +import anndata as ad +from czbenchmarks.tasks.single_cell import ( + PerturbationExpressionPredictionTask, + PerturbationExpressionPredictionTaskInput, +) +from czbenchmarks.tasks.single_cell.perturbation_expression_prediction import ( + load_perturbation_task_input_from_saved_files, +) +from czbenchmarks.datasets.types import Organism +from czbenchmarks.datasets.single_cell_perturbation import ( + SingleCellPerturbationDataset, +) +from czbenchmarks.metrics.types import MetricResult, MetricType +from tests.utils import create_dummy_anndata, create_dummy_perturbation_anndata + + +def test_perturbation_task(): + """Test that PerturbationExpressionPredictionTask executes without errors.""" + # Create dummy perturbation data + perturbation_data: dict = create_dummy_perturbation_anndata( + n_cells=500, + n_genes=200, + organism=Organism.HUMAN, + condition_column="condition", + split_column="split", + ) + gene_pert = perturbation_data["gene_pert"] + # Convert sparse matrix to dense array to avoid matrix object issues + cell_representation = perturbation_data["adata"].X.toarray() + # Log-normalize the data to pass validation (add small constant to avoid log(0)) + cell_representation = np.log1p(cell_representation) + var_names = perturbation_data["adata"].var_names + + # Task and argument setup + task = PerturbationExpressionPredictionTask() + + # Create DE results DataFrame matching expected structure + de_results = pd.DataFrame( + { + "condition": [gene_pert] * len(var_names), + "gene_id": var_names, + "logfoldchange": np.random.randn(len(var_names)), + "pval_adj": np.random.uniform(0, 0.01, len(var_names)), + } + ) + + # Create masked_adata_obs DataFrame and fix condition naming + adata = perturbation_data["adata"] + masked_adata_obs = adata.obs.copy() + + # Fix condition naming to match task expectations + # Task expects control cells to be named: {control_prefix}_{condition} + control_condition_name = f"non-targeting_{gene_pert}" + masked_adata_obs.loc[masked_adata_obs["condition"] == "ctrl", "condition"] = ( + control_condition_name + ) + + # Create target_conditions_dict dict - map condition names to lists of genes to mask + target_conditions_dict = {} + # Map each perturbation condition to genes to mask + unique_conditions = np.unique( + masked_adata_obs["condition"][ + ~masked_adata_obs["condition"].str.startswith("non-targeting") + ] + ) + for condition in unique_conditions: + # Sample some genes to mask for each condition + n_genes_to_mask = min(10, len(var_names) // 2) + target_conditions_dict[condition] = list( + np.random.choice(var_names, n_genes_to_mask, replace=False) + ) + + # Create AnnData with 
required data + test_adata = ad.AnnData( + X=adata.X, obs=masked_adata_obs, var=pd.DataFrame(index=var_names) + ) + test_adata.uns["cell_barcode_condition_index"] = adata.obs.index.astype(str).values + + task_input = PerturbationExpressionPredictionTaskInput( + adata=test_adata, + target_conditions_dict=target_conditions_dict, + de_results=de_results, + gene_index=test_adata.var.index, + cell_index=pd.Index(test_adata.uns["cell_barcode_condition_index"]), + ) + + # Five metrics per condition: accuracy, precision, recall, f1, correlation + # We have one perturbed condition, so 5 metrics total + num_metrics = 5 + + try: + # Test regular task execution + results = task.run( + cell_representation, + task_input, + ) + + # Verify results structure - the method returns a list of MetricResult + assert isinstance(results, list) + assert all(isinstance(r, MetricResult) for r in results) + assert len(results) == num_metrics + + # Test baseline with both mean and median + for baseline_type in ["mean", "median"]: + baseline_embedding = task.compute_baseline( + cell_representation=cell_representation, + baseline_type=baseline_type, + ) + # Create a new task input with the baseline embedding + baseline_results = task.run(baseline_embedding, task_input) + assert isinstance(baseline_results, list) + assert all(isinstance(r, MetricResult) for r in baseline_results) + assert len(baseline_results) == num_metrics + except Exception as e: + pytest.fail(f"Test failed with exception: {e}") + + +def test_perturbation_expression_prediction_task_wilcoxon(): + """Test Wilcoxon path computes correct vectors and metrics.""" + # Deterministic gene set and per-condition true/predicted effects + gene_names = ["G0", "G1", "G2", "G3"] + true_lfc_gene_A = np.array([1.0, 0.5, -0.5, -1.0]) + true_lfc_gene_B = np.array([2.0, 1.0, -1.0, -2.0]) + + # Build DE results matching the designed true effects for both conditions + de_res_wilcoxon_df = pd.concat( + [ + pd.DataFrame( + { + "logfoldchange": true_lfc_gene_A, + "target_gene": ["gene_A"] * len(gene_names), + "names": gene_names, + "pval": [0.001] * len(gene_names), + "pval_adj": [0.001] * len(gene_names), + "condition": ["gene_A"] * len(gene_names), + "condition_ensembl_id": ["ENSG_A"] * len(gene_names), + } + ), + pd.DataFrame( + { + "logfoldchange": true_lfc_gene_B, + "target_gene": ["gene_B"] * len(gene_names), + "names": gene_names, + "pval": [0.001] * len(gene_names), + "pval_adj": [0.001] * len(gene_names), + "condition": ["gene_B"] * len(gene_names), + "condition_ensembl_id": ["ENSG_B"] * len(gene_names), + } + ), + ], + ignore_index=True, + ) + + # Build AnnData with 4 groups: condition/control for A and B + n_per_group = 4 + conditions = ( + ["gene_A"] * n_per_group + + ["ctrl_gene_A"] * n_per_group + + ["gene_B"] * n_per_group + + ["ctrl_gene_B"] * n_per_group + ) + # Create base cell names without underscores (so str.split("_").str[0] works correctly) + base_cell_names = [f"cellbarcode{i}" for i in range(len(conditions))] + # Create extended obs names like real dataset: base_name + "_" + condition + obs_names = [ + f"{base_name}_{cond}" for base_name, cond in zip(base_cell_names, conditions) + ] + + X = np.zeros((len(conditions), len(gene_names)), dtype=float) + # Set group means so that pred_log_fc equals the designed true_lfc + X[0:n_per_group, :] = true_lfc_gene_A # gene_A group + X[n_per_group : 2 * n_per_group, :] = 0.0 # ctrl_gene_A group + X[2 * n_per_group : 3 * n_per_group, :] = true_lfc_gene_B # gene_B group + X[3 * n_per_group : 4 * n_per_group, :] = 0.0 
# ctrl_gene_B group
+
+    adata = ad.AnnData(
+        X=X,
+        obs=pd.DataFrame({"condition": conditions}, index=obs_names),
+        var=pd.DataFrame(index=gene_names),
+    )
+    target_conditions_dict = {"gene_A": list(gene_names), "gene_B": list(gene_names)}
+
+    # Create log-normalized-looking data that passes validation and gives the expected differences
+    # Since the task computes: mean(treatment) - mean(control), we create data where this equals true_lfc
+    base_value = 1.0  # Shared baseline value; it cancels in mean(treatment) - mean(control)
+    cell_representation = np.zeros_like(X, dtype=float)
+
+    # For treatment groups: set to base_value + true_lfc so the group-mean difference equals true_lfc
+    cell_representation[0:n_per_group, :] = base_value + true_lfc_gene_A  # gene_A group
+    cell_representation[n_per_group : 2 * n_per_group, :] = (
+        base_value  # ctrl_gene_A group
+    )
+    cell_representation[2 * n_per_group : 3 * n_per_group, :] = (
+        base_value + true_lfc_gene_B
+    )  # gene_B group
+    cell_representation[3 * n_per_group : 4 * n_per_group, :] = (
+        base_value  # ctrl_gene_B group
+    )
+
+    # Add a small constant so cell sums are fractional and pass validation:
+    # 0.003 x 4 genes = 0.012 per cell, which exceeds the 1e-2 tolerance
+    cell_representation += 0.003
+
+    # Ensure de_results has the expected gene identifier column
+    de_res_wilcoxon_df["gene_id"] = de_res_wilcoxon_df["names"]
+
+    task = PerturbationExpressionPredictionTask(
+        control_name="ctrl",
+    )
+    # Create AnnData with required data
+    test_adata = adata.copy()
+    test_adata.uns["cell_barcode_condition_index"] = (
+        pd.Index(base_cell_names).astype(str).values
+    )
+
+    task_input = PerturbationExpressionPredictionTaskInput(
+        adata=test_adata,
+        target_conditions_dict=target_conditions_dict,
+        de_results=de_res_wilcoxon_df,
+        gene_index=test_adata.var.index,
+        cell_index=pd.Index(test_adata.uns["cell_barcode_condition_index"]),
+    )
+
+    # First, check that true/pred vectors produced by _run_task match expectations
+    task_output = task._run_task(cell_representation, task_input)
+    assert set(task_output.pred_log_fc_dict.keys()) == {"gene_A", "gene_B"}
+    assert set(task_output.true_log_fc_dict.keys()) == {"gene_A", "gene_B"}
+    assert np.allclose(task_output.pred_log_fc_dict["gene_A"], true_lfc_gene_A)
+    assert np.allclose(task_output.pred_log_fc_dict["gene_B"], true_lfc_gene_B)
+    assert np.allclose(task_output.true_log_fc_dict["gene_A"], true_lfc_gene_A)
+    assert np.allclose(task_output.true_log_fc_dict["gene_B"], true_lfc_gene_B)
+
+    # Then, run the full task and validate metrics are perfect
+    results = task.run(cell_representation, task_input)
+
+    assert isinstance(results, list)
+    assert all(isinstance(r, MetricResult) for r in results)
+    # Expect results for both conditions x 5 metric types = 10 total results
+    assert len(results) == 10
+
+    # Each result should have perfect scores
+    for r in results:
+        assert np.isclose(r.value, 1.0)
+
+    # Verify metric types present
+    metric_types = {result.metric_type for result in results}
+    expected_types = {
+        MetricType.ACCURACY_CALCULATION,
+        MetricType.PRECISION_CALCULATION,
+        MetricType.RECALL_CALCULATION,
+        MetricType.F1_CALCULATION,
+        MetricType.SPEARMAN_CORRELATION_CALCULATION,
+    }
+    assert expected_types.issubset(metric_types)
+
+
+def test_perturbation_expression_prediction_task_load_from_task_inputs(tmp_path):
+    """Test that the task can load inputs from stored task files."""
+
+    # Create a dummy dataset and store its task inputs
+    file_path = tmp_path / "dummy_perturbation.h5ad"
+    adata = create_dummy_anndata(
+        n_cells=6,
+        n_genes=3,
+        obs_columns=["condition"],
+        
organism=Organism.HUMAN, + ) + adata.obs["condition"] = ["ctrl", "ctrl", "test1", "test1", "test2", "test2"] + adata.obs_names = [ + "ctrl_test1_a", + "ctrl_test2_b", + "cond_test1_a", + "cond_test1_b", + "cond_test2_a", + "cond_test2_b", + ] + # Provide matched control cell IDs and DE results + adata.uns["control_cells_ids"] = { + "test1": { + "cond_test1_a": "non-targeting_test1_a", + "cond_test1_b": "non-targeting_test2_b", + }, + "test2": { + "cond_test2_a": "non-targeting_test1_a", + "cond_test2_b": "non-targeting_test2_b", + }, + } + de_conditions = ["test1"] * 10 + ["test2"] * 10 + de_genes = [f"ENSG000000000{str(i).zfill(2)}" for i in range(20)] + adata.uns["de_results_wilcoxon"] = pd.DataFrame( + { + "condition": de_conditions, + "gene": de_genes, + "pval_adj": [1e-6] * 20, + "logfoldchange": [2.0] * 20, + } + ) + adata.write_h5ad(file_path) + + # Create dataset and store task inputs + dataset = SingleCellPerturbationDataset( + path=file_path, + organism=Organism.HUMAN, + condition_key="condition", + control_name="ctrl", + ) + dataset.load_data() + task_inputs_dir = dataset.store_task_inputs() + + # Test loading task inputs using the standalone function + + task_input = load_perturbation_task_input_from_saved_files(task_inputs_dir) + + # Verify the loaded task input has the expected structure + assert isinstance(task_input, PerturbationExpressionPredictionTaskInput) + assert hasattr(task_input, "adata") + assert hasattr(task_input, "target_conditions_dict") + assert hasattr(task_input, "de_results") + assert isinstance(task_input.adata, ad.AnnData) + assert isinstance(task_input.target_conditions_dict, dict) + assert isinstance(task_input.de_results, pd.DataFrame) + + # Verify AnnData contains required data in uns + assert "cell_barcode_condition_index" in task_input.adata.uns + assert "control_cells_ids" in task_input.adata.uns + + # Verify data integrity + assert task_input.de_results.shape[0] > 0 + assert task_input.adata.obs.shape[0] > 0 + assert len(task_input.adata.var.index) > 0 + assert len(task_input.target_conditions_dict) > 0 + + # Verify cell barcode index matches adata size + assert ( + len(task_input.adata.uns["cell_barcode_condition_index"]) + == dataset.adata.shape[0] + ) + + +def test_perturbation_expression_prediction_task_with_shuffled_input(): + """Test that the perturbation task works with shuffled AnnData input.""" + + # Create a dummy dataset + with tempfile.TemporaryDirectory() as tmp_dir: + file_path = Path(tmp_dir) / "dummy_perturbation.h5ad" + adata = create_dummy_anndata( + n_cells=6, + n_genes=4, # Use 4 genes for easier testing + obs_columns=["condition"], + organism=Organism.HUMAN, + ) + adata.obs["condition"] = ["ctrl", "ctrl", "test1", "test1", "test2", "test2"] + adata.obs_names = [ + "ctrl_test1_a", + "ctrl_test2_b", + "cond_test1_a", + "cond_test1_b", + "cond_test2_a", + "cond_test2_b", + ] + # Provide matched control cell IDs and DE results + adata.uns["control_cells_ids"] = { + "test1": { + "cond_test1_a": "non-targeting_test1_a", + "cond_test1_b": "non-targeting_test2_b", + }, + "test2": { + "cond_test2_a": "non-targeting_test1_a", + "cond_test2_b": "non-targeting_test2_b", + }, + } + + # Create DE results that match the actual gene names in the dataset + # The create_dummy_anndata creates genes with names like "ENSG000000000{i:02d}" + gene_names = ( + adata.var.index.tolist() + ) # Get actual gene names from the created dataset + + # Create sufficient DE results for both conditions using actual gene names + de_data = [] + for condition in 
["test1", "test2"]: + for gene in gene_names: + de_data.append( + { + "condition": condition, + "gene_id": gene, # Use gene_id instead of gene + "pval_adj": 1e-6, + "logfoldchange": 2.0, + } + ) + + adata.uns["de_results_wilcoxon"] = pd.DataFrame(de_data) + adata.write_h5ad(file_path) + + # Create dataset and load data with relaxed parameters + dataset = SingleCellPerturbationDataset( + path=file_path, + organism=Organism.HUMAN, + condition_key="condition", + control_name="ctrl", + percent_genes_to_mask=1.0, # Use all genes to avoid filtering + min_de_genes_to_mask=1, # Minimum threshold + pval_threshold=1.0, # Accept all p-values + min_logfoldchange=0.0, # Accept all log fold changes + ) + dataset.load_data() + + # Create task input with original ordering + original_task_input = PerturbationExpressionPredictionTaskInput( + adata=dataset.control_matched_adata, + target_conditions_dict=dataset.target_conditions_dict, + de_results=dataset.de_results, + gene_index=dataset.control_matched_adata.var.index, + cell_index=pd.Index( + dataset.control_matched_adata.uns["cell_barcode_condition_index"] + ), + ) + + # Create identical model output for both (using original dimensions) + model_output = np.random.rand( + dataset.control_matched_adata.shape[0], + dataset.control_matched_adata.shape[1], + ) + # Shuffle obs and var with fixed random seed for reproducibility + np.random.seed(42) + obs_shuffled = dataset.adata.obs.sample(frac=1, random_state=42) + var_shuffled = dataset.adata.var.sample(frac=1, random_state=42) + + # Get the indices for shuffling + obs_order = [dataset.adata.obs.index.get_loc(i) for i in obs_shuffled.index] + var_order = [dataset.adata.var.index.get_loc(i) for i in var_shuffled.index] + + # For shuffled input, we need to reorder the model output to match + model_output_shuffled = model_output[np.ix_(obs_order, var_order)] + + # Initialize task + task = PerturbationExpressionPredictionTask() + + # Run task with both inputs + original_result = task._run_task(model_output, original_task_input) + shuffled_result = task._run_task(model_output_shuffled, original_task_input) + + # Results should be identical (same conditions, same relative data) + assert set(original_result.pred_log_fc_dict.keys()) == set( + shuffled_result.pred_log_fc_dict.keys() + ) + assert set(original_result.true_log_fc_dict.keys()) == set( + shuffled_result.true_log_fc_dict.keys() + ) + + # The predicted values should be the same since we used correspondingly shuffled model output + for condition in original_result.pred_log_fc_dict.keys(): + np.testing.assert_allclose( + original_result.pred_log_fc_dict[condition], + shuffled_result.pred_log_fc_dict[condition], + rtol=1e-10, + err_msg=f"Predicted log fold changes differ for condition {condition}", + ) + np.testing.assert_allclose( + original_result.true_log_fc_dict[condition], + shuffled_result.true_log_fc_dict[condition], + rtol=1e-10, + err_msg=f"True log fold changes differ for condition {condition}", + ) + + +def test_perturbation_task_apply_model_ordering(): + """Test the apply_model_ordering method for PerturbationExpressionPredictionTaskInput.""" + + # Create a dummy dataset + with tempfile.TemporaryDirectory() as tmp_dir: + file_path = Path(tmp_dir) / "dummy_perturbation.h5ad" + adata = create_dummy_anndata( + n_cells=6, + n_genes=4, + obs_columns=["condition"], + organism=Organism.HUMAN, + ) + adata.obs["condition"] = ["ctrl", "ctrl", "test1", "test1", "test2", "test2"] + adata.obs_names = [ + "ctrl_test1_a", + "ctrl_test2_b", + "cond_test1_a", + 
"cond_test1_b", + "cond_test2_a", + "cond_test2_b", + ] + + # Provide matched control cell IDs and DE results + adata.uns["control_cells_ids"] = { + "test1": { + "cond_test1_a": "non-targeting_test1_a", + "cond_test1_b": "non-targeting_test2_b", + }, + "test2": { + "cond_test2_a": "non-targeting_test1_a", + "cond_test2_b": "non-targeting_test2_b", + }, + } + + # Create DE results + gene_names = adata.var.index.tolist() + de_data = [] + for condition in ["test1", "test2"]: + for gene in gene_names: + de_data.append( + { + "condition": condition, + "gene_id": gene, + "pval_adj": 1e-6, + "logfoldchange": 2.0, + } + ) + + adata.uns["de_results_wilcoxon"] = pd.DataFrame(de_data) + adata.write_h5ad(file_path) + + # Create dataset and get task input + dataset = SingleCellPerturbationDataset( + path=file_path, + organism=Organism.HUMAN, + condition_key="condition", + control_name="ctrl", + min_de_genes_to_mask=1, # Lower threshold so genes get sampled + percent_genes_to_mask=1.0, # Use all genes + ) + dataset.load_data() + + # Create task input + task_input = PerturbationExpressionPredictionTaskInput( + adata=dataset.control_matched_adata, + target_conditions_dict=dataset.target_conditions_dict, + de_results=dataset.de_results, + gene_index=dataset.control_matched_adata.var.index, + cell_index=pd.Index( + dataset.control_matched_adata.uns["cell_barcode_condition_index"] + ), + ) + + # Create model data with shuffled ordering (same content, different order) + # We need to create model data that has the same cells as in cell_barcode_condition_index + np.random.seed(42) # For reproducible test + + # Get the cell barcodes that should match between model and task input + task_cell_barcodes = task_input.adata.uns["cell_barcode_condition_index"] + task_genes = task_input.adata.var.index + + # Create model AnnData with the same cells and genes but in shuffled order + gene_order = np.random.permutation(task_genes) + cell_order = np.random.permutation(task_cell_barcodes) + + # Create a model AnnData with the shuffled ordering + model_adata = ad.AnnData( + X=np.random.rand(len(cell_order), len(gene_order)), + obs=pd.DataFrame(index=cell_order), + var=pd.DataFrame(index=gene_order), + ) + + # Store original orderings + original_gene_order = task_input.adata.var.index.copy() + original_cell_barcode_index = task_input.adata.uns[ + "cell_barcode_condition_index" + ].copy() + + # Apply model ordering + task_input.gene_index = model_adata.var.index + task_input.cell_index = model_adata.obs.index + + # Verify that orderings have changed to match model data + pd.testing.assert_index_equal(task_input.gene_index, model_adata.var.index) + np.testing.assert_array_equal( + task_input.cell_index, + model_adata.obs.index.astype(str).values, + ) + + # Verify orderings are different from original (unless by chance they're the same) + assert not task_input.gene_index.equals(original_gene_order) + assert not np.array_equal(task_input.cell_index, original_cell_barcode_index) + + +def test_perturbation_task_apply_model_ordering_validation(): + """Test that apply_model_ordering validates matching gene and cell sets.""" + + # Create a dummy dataset + with tempfile.TemporaryDirectory() as tmp_dir: + file_path = Path(tmp_dir) / "dummy_perturbation.h5ad" + adata = create_dummy_anndata( + n_cells=4, + n_genes=3, + obs_columns=["condition"], + organism=Organism.HUMAN, + ) + adata.obs["condition"] = ["ctrl", "ctrl", "test1", "test1"] + adata.obs_names = ["ctrl_a", "ctrl_b", "cond_a", "cond_b"] + + # Provide matched control cell IDs and 
DE results
+
+        adata.uns["control_cells_ids"] = {
+            "test1": {"cond_a": "ctrl_a", "cond_b": "ctrl_b"}
+        }
+
+        # Create DE results
+        gene_names = adata.var.index.tolist()
+        de_data = []
+        for gene in gene_names:
+            de_data.append(
+                {
+                    "condition": "test1",
+                    "gene_id": gene,
+                    "pval_adj": 1e-6,
+                    "logfoldchange": 2.0,
+                }
+            )
+
+        adata.uns["de_results_wilcoxon"] = pd.DataFrame(de_data)
+        adata.write_h5ad(file_path)
+
+        # Create dataset and get task input
+        dataset = SingleCellPerturbationDataset(
+            path=file_path,
+            organism=Organism.HUMAN,
+            condition_key="condition",
+            control_name="ctrl",
+            min_de_genes_to_mask=1,  # Lower threshold so genes get sampled
+            percent_genes_to_mask=1.0,  # Use all genes
+        )
+        dataset.load_data()
+
+        # Create task input
+        task_input = PerturbationExpressionPredictionTaskInput(
+            adata=dataset.control_matched_adata,
+            target_conditions_dict=dataset.target_conditions_dict,
+            de_results=dataset.de_results,
+            gene_index=dataset.control_matched_adata.var.index,
+            cell_index=pd.Index(
+                dataset.control_matched_adata.uns["cell_barcode_condition_index"]
+            ),
+        )
+
+        # Test with mismatched genes
+        model_adata_bad_genes = create_dummy_anndata(
+            n_cells=4,
+            n_genes=3,
+            obs_columns=["condition"],
+            organism=Organism.HUMAN,
+        )
+        model_adata_bad_genes.obs_names = task_input.adata.obs.index  # Same cells
+        model_adata_bad_genes.var_names = [
+            "different_gene1",
+            "different_gene2",
+            "different_gene3",
+        ]  # Different genes
+        task = PerturbationExpressionPredictionTask(
+            control_name="ctrl",
+        )
+
+        with pytest.raises(
+            ValueError, match="Model data contains genes that are not in the task input"
+        ):
+            task_input.gene_index = model_adata_bad_genes.var.index
+            task_input.cell_index = model_adata_bad_genes.obs.index
+            task._run_task(
+                np.random.rand(len(task_input.cell_index), len(task_input.gene_index)),
+                task_input,
+            )
+
+        # Test with mismatched cells
+        model_adata_bad_cells = task_input.adata.copy()
+        model_adata_bad_cells.obs_names = [
+            "different_cell1",
+            "different_cell2",
+            "different_cell3",
+            "different_cell4",
+        ]  # Different cells
+        task_input.cell_index = model_adata_bad_cells.obs.index
+        # Note: gene_index still holds the mismatched genes assigned in the block
+        # above, so the gene-set validation fails first and raises the same message
+        with pytest.raises(
+            ValueError, match="Model data contains genes that are not in the task input"
+        ):
+            task._run_task(
+                np.random.rand(len(task_input.cell_index), len(task_input.gene_index)),
+                task_input,
+            )
diff --git a/tests/tasks/test_task_utils.py b/tests/tasks/test_task_utils.py
new file mode 100644
index 00000000..3fd650be
--- /dev/null
+++ b/tests/tasks/test_task_utils.py
@@ -0,0 +1,190 @@
+import numpy as np
+import scipy.sparse as sp
+from czbenchmarks.tasks.utils import looks_like_lognorm
+
+
+class TestLooksLikeLognorm:
+    """Test suite for the looks_like_lognorm function."""
+
+    def test_raw_count_data_returns_false(self):
+        """Test that raw count data (integers) returns False."""
+        # Create mock raw count data (integers)
+        raw_data = np.random.randint(0, 1000, size=(100, 50))
+
+        result = looks_like_lognorm(raw_data)
+
+        assert result is False
+
+    def test_log_normalized_data_returns_true(self):
+        """Test that log-normalized data (with fractional values) returns True."""
+        # Create mock log-normalized data with fractional values
+        log_data = np.random.lognormal(0, 1, size=(100, 50))
+
+        result = looks_like_lognorm(log_data)
+
+        assert result is True
+
+    def test_normalized_non_integer_data_returns_true(self):
+        """Test that any non-integer data returns True."""
+        # Create data with fractional values (simulating normalized but not necessarily 
log-transformed)
+        normalized_data = np.random.rand(100, 50) * 10  # Random floats between 0-10
+
+        result = looks_like_lognorm(normalized_data)
+
+        assert result is True
+
+    def test_sparse_raw_count_data_returns_false(self):
+        """Test that sparse raw count data returns False."""
+        # Create sparse raw count data
+        raw_data = np.random.randint(0, 100, size=(100, 50))
+        sparse_data = sp.csr_matrix(raw_data)
+
+        result = looks_like_lognorm(sparse_data)
+
+        assert result is False
+
+    def test_sparse_log_normalized_data_returns_true(self):
+        """Test that sparse log-normalized data returns True."""
+        # Create sparse log-normalized data
+        log_data = np.random.lognormal(0, 1, size=(100, 50))
+        sparse_log_data = sp.csr_matrix(log_data)
+
+        result = looks_like_lognorm(sparse_log_data)
+
+        assert result is True
+
+    def test_custom_sample_size_parameter(self):
+        """Test that the sample_size parameter works correctly."""
+        # Create log-normalized data
+        log_data = np.random.lognormal(0, 1, size=(1000, 50))
+
+        # Test with different sample_size values
+        result_50 = looks_like_lognorm(log_data, sample_size=50)
+        result_100 = looks_like_lognorm(log_data, sample_size=100)
+
+        # Both should return True for log-normalized data
+        assert result_50 is True
+        assert result_100 is True
+
+    def test_sample_size_larger_than_data_size(self):
+        """Test behavior when sample_size is larger than the actual number of cells."""
+        # Create small dataset
+        log_data = np.random.lognormal(0, 1, size=(10, 50))
+
+        # Request more cells than available
+        result = looks_like_lognorm(log_data, sample_size=100)
+
+        # Should still work by using all available cells
+        assert result is True
+
+    def test_custom_tol_parameter(self):
+        """Test that the tol parameter affects detection sensitivity."""
+        # Create data that's almost integer but with tiny fractional parts
+        almost_integer_data = np.random.randint(0, 100, size=(100, 50)) + 1e-4
+
+        # With the default tol (1e-2), should return False
+        result_default = looks_like_lognorm(almost_integer_data)
+        assert result_default is False
+
+        # With a very small tol (1e-5), should return True
+        result_small_tol = looks_like_lognorm(almost_integer_data, tol=1e-5)
+        assert result_small_tol is True
+
+    def test_all_zero_data(self):
+        """Test behavior with all-zero data."""
+        zero_data = np.zeros((100, 50))
+
+        result = looks_like_lognorm(zero_data)
+
+        # All zeros should be considered as integer data (raw counts)
+        assert result is False
+
+    def test_mixed_integer_and_float_data(self):
+        """Test data that's mostly integer but has some fractional values."""
+        # Create mostly integer data
+        mixed_data = np.random.randint(0, 100, size=(100, 50)).astype(float)
+        # Add a fractional value to ensure the sum is not an integer
+        mixed_data[0, 0] += 0.3  # Make first cell have fractional sum
+
+        result = looks_like_lognorm(mixed_data)
+
+        # Should return True since some cells have fractional sums
+        assert result is True
+
+    def test_single_cell_data(self):
+        """Test behavior with single cell data."""
+        # Single cell with integer values
+        single_cell_int = np.array([[1, 2, 3, 4, 5]])
+        result_int = looks_like_lognorm(single_cell_int)
+        assert result_int is False
+
+        # Single cell with fractional values
+        single_cell_float = np.array([[1.1, 2.2, 3.3, 4.4, 5.5]])
+        result_float = looks_like_lognorm(single_cell_float)
+        assert result_float is True
+
+    def test_deterministic_behavior_with_seed(self):
+        """Test that results are consistent when sampling the same data."""
+        # Create data
+        log_data = np.random.lognormal(0, 1, 
size=(1000, 50)) + + # Set seed for reproducible sampling + np.random.seed(42) + result1 = looks_like_lognorm(log_data, sample_size=100) + + np.random.seed(42) + result2 = looks_like_lognorm(log_data, sample_size=100) + + # Results should be the same with same seed + assert result1 == result2 + + def test_edge_case_very_small_dataset(self): + """Test with very small dataset (fewer cells than default sampling).""" + # Create tiny dataset with sums that are actually fractional + tiny_data = np.array( + [[1.3, 2.5], [3.2, 4.1]] + ) # 2 cells, 2 genes (sums: 3.8, 7.3) + + result = looks_like_lognorm( + tiny_data, sample_size=500 + ) # Request more cells than available + + # Should still work and return True for fractional data + assert result is True + + def test_fractional_values_integer_sums(self): + """Test that fractional values with integer sums return False.""" + # Create data where individual values are fractional but sums are integers + # For example: [0.5, 0.5] sums to 1.0 (integer) + data_with_integer_sums = np.array([[0.5, 0.5], [1.5, 2.5]]) # sums: [1.0, 4.0] + + result = looks_like_lognorm(data_with_integer_sums) + + # Should return False because cell sums are integers (within epsilon) + assert result is False + + def test_explains_function_behavior(self): + """Test that demonstrates the function checks cell sums, not individual values.""" + # Create data where individual values are fractional but cell sums are integers + integer_sum_data = np.array( + [ + [0.25, 0.25, 0.25, 0.25], # sum = 1.0 + [0.5, 0.5, 1.0, 1.0], # sum = 3.0 + [1.1, 1.9, 2.0, 3.0], # sum = 8.0 + ] + ) + + result_integer_sums = looks_like_lognorm(integer_sum_data) + assert result_integer_sums is False # Integer sums + + # Create data where cell sums are fractional + fractional_sum_data = np.array( + [ + [0.3, 0.4], # sum = 0.7 + [1.1, 2.2], # sum = 3.3 + [0.9, 1.5], # sum = 2.4 + ] + ) + + result_fractional_sums = looks_like_lognorm(fractional_sum_data) + assert result_fractional_sums is True # Fractional sums diff --git a/tests/tasks/test_tasks.py b/tests/tasks/test_tasks.py deleted file mode 100644 index 3711e313..00000000 --- a/tests/tasks/test_tasks.py +++ /dev/null @@ -1,667 +0,0 @@ -import pytest -import pandas as pd -import numpy as np - -import anndata as ad -from czbenchmarks.tasks import ( - ClusteringTask, - ClusteringTaskInput, - EmbeddingTask, - EmbeddingTaskInput, - BatchIntegrationTask, - BatchIntegrationTaskInput, - MetadataLabelPredictionTask, - MetadataLabelPredictionTaskInput, -) -from czbenchmarks.tasks.single_cell import ( - CrossSpeciesIntegrationTask, - CrossSpeciesIntegrationTaskInput, - PerturbationExpressionPredictionTask, - PerturbationExpressionPredictionTaskInput, -) -from czbenchmarks.tasks.types import CellRepresentation -from czbenchmarks.datasets.types import Organism -from czbenchmarks.metrics.types import MetricResult, MetricType - -from tests.utils import ( - create_dummy_anndata, - DummyTask, - create_dummy_perturbation_anndata, - DummyTaskInput, -) - - -# FIXME these tests could be split into multiple files and fixtures moved to -# conftest.py - - -@pytest.fixture -def dummy_anndata(): - n_cells: int = 500 - n_genes: int = 200 - organism: Organism = Organism.HUMAN - obs_columns: list[str] = ["cell_type", "batch"] - var_columns: list[str] = ["feature_name"] - anndata: ad.AnnData = create_dummy_anndata( - n_cells=n_cells, - n_genes=n_genes, - organism=organism, - obs_columns=obs_columns, - var_columns=var_columns, - ) - - expression_matrix: CellRepresentation = 
anndata.X.copy() - obs: pd.DataFrame = anndata.obs.copy() - var: pd.DataFrame = anndata.var.copy() - - # TODO perform PCA on expression matrix to get true embedding - embedding_matrix: CellRepresentation = expression_matrix.toarray() - - return { - "anndata": anndata, - "expression_matrix": expression_matrix, - "obs": obs, - "var": var, - "embedding_matrix": embedding_matrix, - } - - -@pytest.fixture -def expression_matrix(dummy_anndata): - return dummy_anndata["expression_matrix"] - - -@pytest.fixture -def embedding_matrix(dummy_anndata): - return dummy_anndata["embedding_matrix"] - - -@pytest.fixture -def obs(dummy_anndata): - return dummy_anndata["obs"] - - -@pytest.fixture -def var(dummy_anndata): - return dummy_anndata["var"] - - -@pytest.fixture -def fixture_data(request): - # Enables lazy generation of fixture data so fixtures can be used as - # parameters - valid_fixture_names = ["expression_matrix", "embedding_matrix", "obs", "var"] - fixture_name, other_data = request.param - if isinstance(fixture_name, str): - fixture_data = ( - request.getfixturevalue(fixture_name) - if fixture_name in valid_fixture_names - else fixture_name - ) - else: - fixture_data = [ - request.getfixturevalue(f) if f in valid_fixture_names else f - for f in fixture_name - ] - return fixture_data, other_data - - -@pytest.mark.parametrize( - "fixture_data", - [ - ("expression_matrix", False), - (["expression_matrix", "expression_matrix"], True), - ], - indirect=True, -) -def test_embedding_valid_input_output(fixture_data): - """Test that embedding is accepted and List[MetricResult] is returned.""" - embedding, requires_multiple_datasets = fixture_data - task = DummyTask(requires_multiple_datasets=requires_multiple_datasets) - results = task.run( - cell_representation=embedding, - task_input=DummyTaskInput(), - ) - - assert isinstance(results, list) - assert all(isinstance(r, MetricResult) for r in results) - - -@pytest.mark.parametrize( - "fixture_data", - [ - ( - "abcd", - [False, "This task requires a single cell representation for input"], - ), - ( - ["embedding_matrix"], - [False, "This task requires a single cell representation for input"], - ), - ( - ["embedding_matrix", "embedding_matrix"], - [False, "This task requires a single cell representation for input"], - ), - ( - "embedding_matrix", - [True, "This task requires a list of cell representations"], - ), - ( - ["abcd", "embedding_matrix"], - [True, "This task requires a list of cell representations"], - ), - ( - ["embedding_matrix"], - [ - True, - "This task requires a list of cell representations but only one " - "was provided", - ], - ), - ], - indirect=True, -) -def test_embedding_invalid_input(fixture_data): - """Test ValueError for mismatch with requires_multiple_datasets.""" - embedding_list, (requires_multiple_datasets, error_message) = fixture_data - task = DummyTask(requires_multiple_datasets=requires_multiple_datasets) - with pytest.raises(ValueError, match=error_message): - task.run( - cell_representation=embedding_list, - task_input=DummyTaskInput(), - ) - - -@pytest.mark.parametrize( - "task_class,task_input_builder", - [ - ( - ClusteringTask, - lambda obs: ClusteringTaskInput(obs=obs, input_labels=obs["cell_type"]), - ), - ( - EmbeddingTask, - lambda obs: EmbeddingTaskInput(input_labels=obs["cell_type"]), - ), - ( - BatchIntegrationTask, - lambda obs: BatchIntegrationTaskInput( - labels=obs["cell_type"], batch_labels=obs["batch"] - ), - ), - ( - MetadataLabelPredictionTask, - lambda obs: 
MetadataLabelPredictionTaskInput(labels=obs["cell_type"]), - ), - ], -) -def test_task_execution( - task_class, - task_input_builder, - embedding_matrix, - expression_matrix, - obs, -): - """Test that each task executes without errors on compatible data.""" - - task_input = task_input_builder(obs) - - task = task_class() - - try: - # Test regular task execution - results = task.run( - cell_representation=embedding_matrix, - task_input=task_input, - ) - assert isinstance(results, list) - assert all(isinstance(r, MetricResult) for r in results) - - # Test baseline execution if implemented - try: - n_pcs = min(50, expression_matrix.shape[1] - 1) - baseline_embedding = task.compute_baseline(expression_matrix, n_pcs=n_pcs) - if hasattr(task_input, "var"): - task_input.var = task_input.var.iloc[:n_pcs] - - baseline_results = task.run( - cell_representation=baseline_embedding, - task_input=task_input, - ) - assert isinstance(baseline_results, list) - assert all(isinstance(r, MetricResult) for r in baseline_results) - except NotImplementedError: - # Some tasks may not implement compute_baseline - pass - - except Exception as e: - pytest.fail(f"Task {task_class.__name__} failed unexpectedly: {e}") - - -def test_cross_species_task(embedding_matrix, obs): - """Test that CrossSpeciesIntegrationTask executes without errors.""" - task = CrossSpeciesIntegrationTask() - embedding_list = [embedding_matrix, embedding_matrix] - labels = obs["cell_type"] - labels_list = [labels, labels] - organism_list = [Organism.HUMAN, Organism.MOUSE] - task_input = CrossSpeciesIntegrationTaskInput( - labels=labels_list, organism_list=organism_list - ) - - try: - # Test regular task execution - results = task.run( - cell_representation=embedding_list, - task_input=task_input, - ) - - # Verify results structure - assert isinstance(results, list) - assert all(isinstance(r, MetricResult) for r in results) - - # Test that baseline raises NotImplementedError - with pytest.raises(NotImplementedError): - task.compute_baseline() - - except Exception as e: - pytest.fail(f"CrossSpeciesIntegrationTask failed unexpectedly: {e}") - - -def test_perturbation_task(): - """Test that PerturbationExpressionPredictionTask executes without errors.""" - # Create dummy perturbation data - perturbation_data: dict = create_dummy_perturbation_anndata( - n_cells=500, - n_genes=200, - organism=Organism.HUMAN, - condition_column="condition", - split_column="split", - ) - gene_pert = perturbation_data["gene_pert"] - # Convert sparse matrix to dense array to avoid matrix object issues - cell_representation = perturbation_data["adata"].X.toarray() - var_names = perturbation_data["adata"].var_names - - # Task and argument setup - task = PerturbationExpressionPredictionTask() - - # Create DE results DataFrame matching expected structure - de_results = pd.DataFrame( - { - "condition": [gene_pert] * len(var_names), - "gene_id": var_names, - "logfoldchange": np.random.randn(len(var_names)), - "pval_adj": np.random.uniform(0, 0.01, len(var_names)), - } - ) - - # Create masked_adata_obs DataFrame and fix condition naming - adata = perturbation_data["adata"] - masked_adata_obs = adata.obs.copy() - - # Fix condition naming to match task expectations - # Task expects control cells to be named: {control_prefix}_{condition} - control_condition_name = f"non-targeting_{gene_pert}" - masked_adata_obs.loc[masked_adata_obs["condition"] == "ctrl", "condition"] = ( - control_condition_name - ) - - # Create target_conditions_to_save dict - map cell IDs to lists of genes to 
mask - target_conditions_to_save = {} - for cell_id in adata.obs_names: - # Sample some genes to mask for each cell - n_genes_to_mask = min(10, len(var_names) // 2) - target_conditions_to_save[cell_id] = list( - np.random.choice(var_names, n_genes_to_mask, replace=False) - ) - - task_input = PerturbationExpressionPredictionTaskInput( - de_results=de_results, - masked_adata_obs=masked_adata_obs, - var_index=var_names, - target_conditions_to_save=target_conditions_to_save, - row_index=adata.obs.index, - ) - - # Five metrics per condition: accuracy, precision, recall, f1, correlation - # We have one perturbed condition, so 5 metrics total - num_metrics = 5 - - try: - # Test regular task execution - results = task.run( - cell_representation, - task_input, - ) - - # Verify results structure - the method returns a list of MetricResult - assert isinstance(results, list) - assert all(isinstance(r, MetricResult) for r in results) - assert len(results) == num_metrics - - # Test baseline with both mean and median - for baseline_type in ["mean", "median"]: - baseline_embedding = task.compute_baseline( - cell_representation=cell_representation, - baseline_type=baseline_type, - ) - # Create a new task input with the baseline embedding - baseline_results = task.run(baseline_embedding, task_input) - assert isinstance(baseline_results, list) - assert all(isinstance(r, MetricResult) for r in baseline_results) - assert len(baseline_results) == num_metrics - except Exception as e: - pytest.fail(f"Test failed with exception: {e}") - - -def test_perturbation_expression_prediction_task_wilcoxon(): - """Test Wilcoxon path computes correct vectors and metrics.""" - # Deterministic gene set and per-condition true/predicted effects - gene_names = ["G0", "G1", "G2", "G3"] - true_lfc_gene_A = np.array([1.0, 0.5, -0.5, -1.0]) - true_lfc_gene_B = np.array([2.0, 1.0, -1.0, -2.0]) - - # Build DE results matching the designed true effects for both conditions - de_res_wilcoxon_df = pd.concat( - [ - pd.DataFrame( - { - "logfoldchange": true_lfc_gene_A, - "target_gene": ["gene_A"] * len(gene_names), - "names": gene_names, - "pval": [0.001] * len(gene_names), - "pval_adj": [0.001] * len(gene_names), - "condition": ["gene_A"] * len(gene_names), - "condition_ensembl_id": ["ENSG_A"] * len(gene_names), - } - ), - pd.DataFrame( - { - "logfoldchange": true_lfc_gene_B, - "target_gene": ["gene_B"] * len(gene_names), - "names": gene_names, - "pval": [0.001] * len(gene_names), - "pval_adj": [0.001] * len(gene_names), - "condition": ["gene_B"] * len(gene_names), - "condition_ensembl_id": ["ENSG_B"] * len(gene_names), - } - ), - ], - ignore_index=True, - ) - - # Build AnnData with 4 groups: condition/control for A and B - n_per_group = 4 - conditions = ( - ["gene_A"] * n_per_group - + ["ctrl_gene_A"] * n_per_group - + ["gene_B"] * n_per_group - + ["ctrl_gene_B"] * n_per_group - ) - # Create base cell names without underscores (so str.split("_").str[0] works correctly) - base_cell_names = [f"cellbarcode{i}" for i in range(len(conditions))] - # Create extended obs names like real dataset: base_name + "_" + condition - obs_names = [ - f"{base_name}_{cond}" for base_name, cond in zip(base_cell_names, conditions) - ] - - X = np.zeros((len(conditions), len(gene_names)), dtype=float) - # Set group means so that pred_log_fc equals the designed true_lfc - X[0:n_per_group, :] = true_lfc_gene_A # gene_A group - X[n_per_group : 2 * n_per_group, :] = 0.0 # ctrl_gene_A group - X[2 * n_per_group : 3 * n_per_group, :] = true_lfc_gene_B # gene_B group - 
X[3 * n_per_group : 4 * n_per_group, :] = 0.0 # ctrl_gene_B group - - adata = ad.AnnData( - X=X, - obs=pd.DataFrame({"condition": conditions}, index=obs_names), - var=pd.DataFrame(index=gene_names), - ) - target_conditions_to_save = {obs_name: list(gene_names) for obs_name in obs_names} - cell_representation = X - - # Ensure de_results has the expected gene identifier column - de_res_wilcoxon_df["gene_id"] = de_res_wilcoxon_df["names"] - - task = PerturbationExpressionPredictionTask( - control_prefix="ctrl", - ) - task_input = PerturbationExpressionPredictionTaskInput( - de_results=de_res_wilcoxon_df, - masked_adata_obs=adata.obs, - var_index=adata.var_names, - target_conditions_to_save=target_conditions_to_save, - row_index=pd.Index(base_cell_names), # Full dataset uses base names - ) - - # First, check that true/pred vectors produced by _run_task match expectations - task_output = task._run_task(cell_representation, task_input) - assert set(task_output.pred_log_fc_dict.keys()) == {"gene_A", "gene_B"} - assert set(task_output.true_log_fc_dict.keys()) == {"gene_A", "gene_B"} - assert np.allclose(task_output.pred_log_fc_dict["gene_A"], true_lfc_gene_A) - assert np.allclose(task_output.pred_log_fc_dict["gene_B"], true_lfc_gene_B) - assert np.allclose(task_output.true_log_fc_dict["gene_A"], true_lfc_gene_A) - assert np.allclose(task_output.true_log_fc_dict["gene_B"], true_lfc_gene_B) - - # Then, run the full task and validate metrics are perfect - results = task.run(cell_representation, task_input) - - assert isinstance(results, list) - assert all(isinstance(r, MetricResult) for r in results) - # Expect results for both conditions x 5 metric types = 10 total results - assert len(results) == 10 - - # Each result should have perfect scores - for r in results: - assert np.isclose(r.value, 1.0) - - # Verify metric types present - metric_types = {result.metric_type for result in results} - expected_types = { - MetricType.ACCURACY_CALCULATION, - MetricType.PRECISION_CALCULATION, - MetricType.RECALL_CALCULATION, - MetricType.F1_CALCULATION, - MetricType.SPEARMAN_CORRELATION_CALCULATION, - } - assert expected_types.issubset(metric_types) - - -def test_perturbation_expression_prediction_task_ttest(): - """Test t-test path computes correct vectors and metrics.""" - # Deterministic gene set and per-condition true/predicted effects - gene_names = ["G0", "G1", "G2", "G3"] - true_smd_gene_A = np.array([0.2, 0.5, -0.5, -0.2]) - true_smd_gene_B = np.array([1.0, 0.7, -0.7, -1.0]) - - # Build DE results matching the designed true effects for both conditions - de_res_ttest_df = pd.concat( - [ - pd.DataFrame( - { - "standardized_mean_diff": true_smd_gene_A, - "logfoldchange": true_smd_gene_A, # Add logfoldchange column for compatibility - "target_gene": ["gene_A"] * len(gene_names), - "names": gene_names, - "pval": [0.001] * len(gene_names), - "pval_adj": [0.001] * len(gene_names), - "condition": ["gene_A"] * len(gene_names), - "condition_ensembl_id": ["ENSG_A"] * len(gene_names), - } - ), - pd.DataFrame( - { - "standardized_mean_diff": true_smd_gene_B, - "logfoldchange": true_smd_gene_B, # Add logfoldchange column for compatibility - "target_gene": ["gene_B"] * len(gene_names), - "names": gene_names, - "pval": [0.001] * len(gene_names), - "pval_adj": [0.001] * len(gene_names), - "condition": ["gene_B"] * len(gene_names), - "condition_ensembl_id": ["ENSG_B"] * len(gene_names), - } - ), - ], - ignore_index=True, - ) - - # Build AnnData with 4 groups: condition/control for A and B - n_per_group = 4 - conditions = 
( - ["gene_A"] * n_per_group - + ["ctrl_gene_A"] * n_per_group - + ["gene_B"] * n_per_group - + ["ctrl_gene_B"] * n_per_group - ) - # Create base cell names without underscores (so str.split("_").str[0] works correctly) - base_cell_names = [f"cellbarcode{i}" for i in range(len(conditions))] - # Create extended obs names like real dataset: base_name + "_" + condition - obs_names = [ - f"{base_name}_{cond}" for base_name, cond in zip(base_cell_names, conditions) - ] - - X = np.zeros((len(conditions), len(gene_names)), dtype=float) - # Set group means so that pred_log_fc equals the designed "true" standardized_mean_diff - X[0:n_per_group, :] = true_smd_gene_A # gene_A group - X[n_per_group : 2 * n_per_group, :] = 0.0 # ctrl_gene_A group - X[2 * n_per_group : 3 * n_per_group, :] = true_smd_gene_B # gene_B group - X[3 * n_per_group : 4 * n_per_group, :] = 0.0 # ctrl_gene_B group - - adata = ad.AnnData( - X=X, - obs=pd.DataFrame({"condition": conditions}, index=obs_names), - var=pd.DataFrame(index=gene_names), - ) - target_conditions_to_save = {obs_name: list(gene_names) for obs_name in obs_names} - cell_representation = X - - # Ensure de_results has the expected gene identifier column - de_res_ttest_df["gene_id"] = de_res_ttest_df["names"] - - task = PerturbationExpressionPredictionTask( - control_prefix="ctrl", - ) - task_input = PerturbationExpressionPredictionTaskInput( - de_results=de_res_ttest_df, - masked_adata_obs=adata.obs, - var_index=adata.var_names, - target_conditions_to_save=target_conditions_to_save, - row_index=pd.Index(base_cell_names), # Full dataset uses base names - ) - - # First, check that true/pred vectors produced by _run_task match expectations - task_output = task._run_task(cell_representation, task_input) - assert set(task_output.pred_log_fc_dict.keys()) == {"gene_A", "gene_B"} - assert set(task_output.true_log_fc_dict.keys()) == {"gene_A", "gene_B"} - assert np.allclose(task_output.pred_log_fc_dict["gene_A"], true_smd_gene_A) - assert np.allclose(task_output.pred_log_fc_dict["gene_B"], true_smd_gene_B) - assert np.allclose(task_output.true_log_fc_dict["gene_A"], true_smd_gene_A) - assert np.allclose(task_output.true_log_fc_dict["gene_B"], true_smd_gene_B) - - # Then, run the full task and validate metrics are perfect - try: - results = task.run(cell_representation, task_input) - - assert isinstance(results, list) - assert all(isinstance(r, MetricResult) for r in results) - # Expect results for both conditions x 5 metric types = 10 total results - assert len(results) == 10 - - # Each result should have perfect scores - for r in results: - assert np.isclose(r.value, 1.0) - - metric_types = {result.metric_type for result in results} - expected_types = { - MetricType.ACCURACY_CALCULATION, - MetricType.PRECISION_CALCULATION, - MetricType.RECALL_CALCULATION, - MetricType.F1_CALCULATION, - MetricType.SPEARMAN_CORRELATION_CALCULATION, - } - assert expected_types.issubset(metric_types) - except Exception as e: - pytest.fail( - f"PerturbationExpressionPredictionTask (t-test) failed unexpectedly: {e}" - ) - - -def test_perturbation_expression_prediction_task_load_from_task_inputs(tmp_path): - """Test that the task can load inputs from stored task files.""" - from czbenchmarks.datasets.single_cell_perturbation import ( - SingleCellPerturbationDataset, - ) - from czbenchmarks.datasets.types import Organism - from tests.utils import create_dummy_anndata - - # Create a dummy dataset and store its task inputs - file_path = tmp_path / "dummy_perturbation.h5ad" - adata = 
create_dummy_anndata( - n_cells=6, - n_genes=3, - obs_columns=["condition"], - organism=Organism.HUMAN, - ) - adata.obs["condition"] = ["ctrl", "ctrl", "test1", "test1", "test2", "test2"] - adata.obs_names = [ - "ctrl_test1_a", - "ctrl_test2_b", - "cond_test1_a", - "cond_test1_b", - "cond_test2_a", - "cond_test2_b", - ] - # Provide matched control cell IDs and DE results - adata.uns["control_cells_ids"] = { - "test1": { - "cond_test1_a": "ctrl_test1_a", - "cond_test1_b": "ctrl_test2_b", - }, - "test2": { - "cond_test2_a": "ctrl_test1_a", - "cond_test2_b": "ctrl_test2_b", - }, - } - de_conditions = ["test1"] * 10 + ["test2"] * 10 - de_genes = [f"ENSG000000000{str(i).zfill(2)}" for i in range(20)] - adata.uns["de_results_wilcoxon"] = pd.DataFrame( - { - "condition": de_conditions, - "gene": de_genes, - "pval_adj": [1e-6] * 20, - "logfoldchange": [2.0] * 20, - } - ) - adata.write_h5ad(file_path) - - # Create dataset and store task inputs - dataset = SingleCellPerturbationDataset( - path=file_path, - organism=Organism.HUMAN, - condition_key="condition", - control_name="ctrl", - ) - dataset.load_data() - stored_dir = dataset.store_task_inputs() - - # Test loading task inputs using the standalone function - from czbenchmarks.tasks.single_cell.perturbation_expression_prediction import ( - load_perturbation_task_input_from_saved_files, - ) - - task_input = load_perturbation_task_input_from_saved_files(stored_dir) - - # Verify the loaded task input has the expected structure - assert isinstance(task_input, PerturbationExpressionPredictionTaskInput) - assert isinstance(task_input.de_results, pd.DataFrame) - assert isinstance(task_input.masked_adata_obs, pd.DataFrame) - assert len(task_input.target_conditions_to_save) > 0 - assert all( - isinstance(v, list) for v in task_input.target_conditions_to_save.values() - ) - - # Verify data integrity - assert task_input.de_results.shape[0] > 0 - assert task_input.masked_adata_obs.shape[0] > 0 - assert len(task_input.var_index) > 0 diff --git a/tests/test_integration_end_to_end.py b/tests/test_integration_end_to_end.py index 2b36316f..32afe815 100644 --- a/tests/test_integration_end_to_end.py +++ b/tests/test_integration_end_to_end.py @@ -1,7 +1,10 @@ import json import numpy as np +import os +import pandas as pd import pytest - +import tempfile +import anndata as ad from czbenchmarks.constants import RANDOM_SEED from czbenchmarks.datasets.single_cell_labeled import SingleCellLabeledDataset from czbenchmarks.datasets import SingleCellPerturbationDataset @@ -199,12 +202,16 @@ def test_end_to_end_perturbation_expression_prediction(): ) # Build task input directly from dataset + # Create AnnData with required data in uns + adata = dataset.control_matched_adata.copy() + adata.uns["cell_barcode_index"] = dataset.control_matched_adata.obs.index.astype( + str + ).values + task_input = PerturbationExpressionPredictionTaskInput( + adata=adata, + target_conditions_dict=dataset.target_conditions_dict, de_results=dataset.de_results, - var_index=dataset.control_matched_adata.var.index, - masked_adata_obs=dataset.control_matched_adata.obs, - target_conditions_to_save=dataset.target_conditions_to_save, - row_index=dataset.adata.obs.index, ) # Create random model output matching dataset dimensions @@ -278,3 +285,113 @@ def test_end_to_end_perturbation_expression_prediction(): assert "perturbation" in parsed assert "model" in parsed["perturbation"] assert "baseline" in parsed["perturbation"] + + +@pytest.mark.integration +def 
test_end_to_end_perturbation_with_model_anndata_file(): + """Integration test for perturbation task with model data from AnnData file. + + This test demonstrates the workflow where a user provides model predictions + in an AnnData file and uses apply_model_ordering to align the data. + """ + + # Load dataset (requires cloud access) + dataset: SingleCellPerturbationDataset = load_dataset( + "replogle_k562_essential_perturbpredict" + ) + + # Create task input from dataset + task_input = PerturbationExpressionPredictionTaskInput( + adata=dataset.control_matched_adata, + target_conditions_dict=dataset.target_conditions_dict, + de_results=dataset.de_results, + ) + + # Simulate a model AnnData file with different ordering + # This mimics the case where a user has predictions from a model in AnnData format + model_adata = dataset.control_matched_adata.copy() + + # Shuffle the ordering to simulate different gene/cell order from model + + np.random.seed(123) # For reproducible test + + # Shuffle gene order + gene_order = np.random.permutation(model_adata.var.index) + model_adata = model_adata[:, gene_order] + + # Shuffle cell order + cell_order = np.random.permutation(model_adata.obs.index) + model_adata = model_adata[cell_order, :] + + # Save model data to temporary file (simulating user providing model file) + with tempfile.NamedTemporaryFile(suffix=".h5ad", delete=False) as tmp_file: + model_file_path = tmp_file.name + model_adata.write_h5ad(model_file_path) + + try: + # Load model data from file (as user would do) + loaded_model_adata = ad.read_h5ad(model_file_path) + + # Apply model ordering to align task input with model data + # This is the key functionality we're testing + task_input.apply_model_ordering(loaded_model_adata) + + # Verify that task input now matches model ordering + pd.testing.assert_index_equal( + task_input.adata.var.index, loaded_model_adata.var.index + ) + np.testing.assert_array_equal( + task_input.adata.uns["cell_barcode_index"], + loaded_model_adata.obs.index.astype(str).values, + ) + + # Create model output matching the new ordering + aligned_model_output = np.random.rand( + task_input.adata.shape[0], task_input.adata.shape[1] + ) + + # Initialize and run task + task = PerturbationExpressionPredictionTask() + results = task.run(aligned_model_output, task_input) + + # Validate results structure + assert isinstance(results, list) + assert len(results) > 0 + + for result in results: + assert hasattr(result, "metric_type") + assert hasattr(result, "value") + assert hasattr(result, "params") + assert not np.isnan(result.value), f"Got NaN for {result.metric_type}" + + # Verify we have the expected metric types + metric_types = {r.metric_type.value for r in results} + expected_metrics = { + "accuracy_calculation", + "precision_calculation", + "recall_calculation", + "f1_calculation", + "spearman_correlation_calculation", + } + assert expected_metrics.issubset(metric_types) + + # Test JSON serialization + results_dict = [ + { + "metric_type": result.metric_type.value, + "value": float(result.value), + "params": result.params, + } + for result in results + ] + + json_output = json.dumps(results_dict, indent=2) + assert isinstance(json_output, str) + parsed = json.loads(json_output) + assert len(parsed) == len(results) + + finally: + # Clean up temporary file + + if os.path.exists(model_file_path): + os.unlink(model_file_path)
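
A note on `looks_like_lognorm`, since `TestLooksLikeLognorm` above pins down its observable behavior rather than its implementation: the tests only require that the function sample up to `sample_size` cells and report log-normalized data whenever a sampled cell's total deviates from an integer by more than `tol`. The sketch below is a minimal re-statement consistent with those assertions, not a copy of the shipped function in `czbenchmarks.tasks.utils`; the `_sketch` name and the `sample_size=500` default are assumptions, while the `tol=1e-2` default matches what the tests assert.

```python
import numpy as np


def looks_like_lognorm_sketch(x, sample_size: int = 500, tol: float = 1e-2) -> bool:
    """Illustrative restatement of the behavior asserted in TestLooksLikeLognorm."""
    n_cells = x.shape[0]
    # Sample without replacement, capped at the number of available cells
    idx = np.random.choice(n_cells, size=min(sample_size, n_cells), replace=False)
    rows = x[idx]
    # Per-cell totals; np.asarray(...).ravel() handles both dense arrays and
    # scipy sparse matrices (whose .sum returns a 2-D np.matrix)
    sums = np.asarray(rows.sum(axis=1)).ravel()
    # Raw counts give (near-)integer cell sums; log-normalized data does not
    frac = np.abs(sums - np.round(sums))
    return bool(np.any(frac > tol))
```

Read against this sketch, the tolerance test above is easy to verify by hand: `np.random.randint(0, 100, size=(100, 50)) + 1e-4` shifts every 50-gene cell sum by only 5e-3, below the 1e-2 default, so the data still classifies as raw counts until `tol` is tightened to 1e-5.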