Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 74 additions & 15 deletions examples/example_perturbation_expression_prediction.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,50 @@
"""
Example script for running the Perturbation Expression Prediction Task.

CELL REPRESENTATION DATA (OPTIONAL):
By default, this script generates random cell representation data for testing.
You can provide your own cell representation data as an AnnData file:

For model_data_file (OPTIONAL):
# Create an AnnData object with your cell representations
import anndata as ad
import numpy as np

# Your cell representation matrix (cells x genes)
cell_representations = np.random.rand(1000, 500) # Replace with your actual data

# Gene names should be in .var.index
gene_names = ['GENE1', 'GENE2', 'GENE3', ...] # Your gene identifiers

# Cell/perturbation identifiers should be in .obs.index
cell_ids = ['CELL_1', 'CELL_2', 'CELL_3', ...] # Your cell identifiers

# Create and save AnnData
adata = ad.AnnData(X=cell_representations)
adata.var.index = gene_names
adata.obs.index = cell_ids
adata.write_h5ad('my_model_data.h5ad')

# Then run the script with:
# python example_perturbation_expression_prediction.py --model_data_file my_model_data.h5ad

IMPORTANT:
- If no model_data_file is provided, random data will be generated for testing
- The gene ordering (.var.index) and cell ordering (.obs.index) from your file will be used
- Your data dimensions must match the dataset dimensions
"""

import logging
import sys
import argparse
from czbenchmarks.datasets import load_dataset
import tempfile
import yaml
from pathlib import Path

import anndata as ad
import numpy as np

from czbenchmarks.datasets import load_dataset, SingleCellPerturbationDataset
from czbenchmarks.tasks.single_cell import (
PerturbationExpressionPredictionTask,
PerturbationExpressionPredictionTaskInput,
Expand All @@ -10,12 +53,7 @@
load_perturbation_task_input_from_saved_files,
)
from czbenchmarks.tasks.utils import print_metrics_summary
import numpy as np
from czbenchmarks.datasets import SingleCellPerturbationDataset
from czbenchmarks.tasks.types import CellRepresentation
import tempfile
import yaml
from pathlib import Path

if __name__ == "__main__":
"""Runs a task to calculate perturbation metrics.
Expand Down Expand Up @@ -86,6 +124,14 @@
default=0.55,
help="Minimum standardized mean difference for DE filtering (used when --metric=t-test)",
)
parser.add_argument(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the help under "metric", could we list the two possibilities?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, all default values should match what the default is set to in the respective method.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm pretty sure the values match the defaults

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Percent genes to mask does not mask (that's the one I was looking at when I wrote this). The rest are indeed idential. In the dataset class:
percent_genes_to_mask: float = 0.5

Copy link
Collaborator

@mlgill mlgill Sep 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, metric should be removed as an arg from the script -- it's been deleted from the dataset/task since there is only one possibility right now.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

"--model_data_file",
help="[OPTIONAL] Path to AnnData file (.h5ad) containing your cell representation data. "
"If not provided, random data will be generated for testing. "
"The file should have: cell representations in .X, gene names in .var.index, "
"and cell identifiers in .obs.index. "
"The gene names and cell identifiers should match the task input, although the ordering does not need to be the same.",
)
Copy link
Collaborator

@mlgill mlgill Sep 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line 113 notes

TODO: Once PR 381 is merged, use the new load_local_dataset function

PR 381 has been merged. Can this be done or should this comment be removed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed, thanks for catching it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What was the resolution here? Is it not possible to use the new function?


args = parser.parse_args()

Expand Down Expand Up @@ -122,19 +168,32 @@
print("Task inputs loaded from saved files")
else:
print("Creating task input directly from dataset...")
# Create task input directly from dataset
# Create task input directly from dataset with separate fields
task_input = PerturbationExpressionPredictionTaskInput(
adata=dataset.control_matched_adata,
target_conditions_dict=dataset.target_conditions_dict,
de_results=dataset.de_results,
var_index=dataset.control_matched_adata.var.index,
masked_adata_obs=dataset.control_matched_adata.obs,
target_conditions_to_save=dataset.target_conditions_to_save,
row_index=dataset.adata.obs.index,
)

# Generate random model output
model_output: CellRepresentation = np.random.rand(
dataset.adata.shape[0], dataset.adata.shape[1]
)
# Load model data or generate random data
if args.model_data_file:
model_adata = ad.read_h5ad(args.model_data_file)

# Use the cell representation data from the file
# Handle both dense and sparse model_adata.X
if hasattr(model_adata.X, "toarray"):
model_output: CellRepresentation = model_adata.X.toarray()
else:
model_output: CellRepresentation = model_adata.X
# Apply the gene and cell ordering from the model data to the task input and validate dimensions
task_input.apply_model_ordering(model_adata)
else:
print("No model data file provided - generating random data for testing")

# Generate random model output for testing
model_output: CellRepresentation = np.random.rand(
dataset.adata.shape[0], dataset.adata.shape[1]
)

# Run task
task = PerturbationExpressionPredictionTask(metric=args.metric)
Expand Down
Loading