Pbinder/task feat changes #404
Changes from 9 commits
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,50 @@ | ||
""" | ||
Example script for running the Perturbation Expression Prediction Task. | ||
|
||
CELL REPRESENTATION DATA (OPTIONAL): | ||
By default, this script generates random cell representation data for testing. | ||
You can provide your own cell representation data as an AnnData file: | ||
|
||
For model_data_file (OPTIONAL): | ||
# Create an AnnData object with your cell representations | ||
import anndata as ad | ||
import numpy as np | ||
|
||
# Your cell representation matrix (cells x genes) | ||
cell_representations = np.random.rand(1000, 500) # Replace with your actual data | ||
|
||
# Gene names should be in .var.index | ||
gene_names = ['GENE1', 'GENE2', 'GENE3', ...] # Your gene identifiers | ||
|
||
# Cell/perturbation identifiers should be in .obs.index | ||
cell_ids = ['CELL_1', 'CELL_2', 'CELL_3', ...] # Your cell identifiers | ||
|
||
# Create and save AnnData | ||
adata = ad.AnnData(X=cell_representations) | ||
adata.var.index = gene_names | ||
adata.obs.index = cell_ids | ||
adata.write_h5ad('my_model_data.h5ad') | ||
|
||
# Then run the script with: | ||
# python example_perturbation_expression_prediction.py --model_data_file my_model_data.h5ad | ||
|
||
IMPORTANT: | ||
- If no model_data_file is provided, random data will be generated for testing | ||
- The gene ordering (.var.index) and cell ordering (.obs.index) from your file will be used | ||
- Your data dimensions must match the dataset dimensions | ||
""" | ||
|
||
import logging | ||
import sys | ||
import argparse | ||
from czbenchmarks.datasets import load_dataset | ||
import tempfile | ||
import yaml | ||
from pathlib import Path | ||
|
||
import anndata as ad | ||
import numpy as np | ||
|
||
from czbenchmarks.datasets import load_dataset, SingleCellPerturbationDataset | ||
from czbenchmarks.tasks.single_cell import ( | ||
PerturbationExpressionPredictionTask, | ||
PerturbationExpressionPredictionTaskInput, | ||
|
@@ -10,12 +53,7 @@ | |
load_perturbation_task_input_from_saved_files, | ||
) | ||
from czbenchmarks.tasks.utils import print_metrics_summary | ||
import numpy as np | ||
from czbenchmarks.datasets import SingleCellPerturbationDataset | ||
from czbenchmarks.tasks.types import CellRepresentation | ||
import tempfile | ||
import yaml | ||
from pathlib import Path | ||
|
||
if __name__ == "__main__": | ||
"""Runs a task to calculate perturbation metrics. | ||
|
@@ -86,6 +124,14 @@ | |
default=0.55, | ||
help="Minimum standardized mean difference for DE filtering (used when --metric=t-test)", | ||
) | ||
parser.add_argument( | ||
"--model_data_file", | ||
help="[OPTIONAL] Path to AnnData file (.h5ad) containing your cell representation data. " | ||
"If not provided, random data will be generated for testing. " | ||
"The file should have: cell representations in .X, gene names in .var.index, " | ||
"and cell identifiers in .obs.index. " | ||
"The gene names and cell identifiers should match the task input, although the ordering does not need to be the same.", | ||
) | ||
|
||
|
||
args = parser.parse_args() | ||
|
||
|
@@ -122,19 +168,32 @@ | |
print("Task inputs loaded from saved files") | ||
|
||
else: | ||
print("Creating task input directly from dataset...") | ||
# Create task input directly from dataset | ||
# Create task input directly from dataset with separate fields | ||
task_input = PerturbationExpressionPredictionTaskInput( | ||
adata=dataset.control_matched_adata, | ||
target_conditions_dict=dataset.target_conditions_dict, | ||
de_results=dataset.de_results, | ||
var_index=dataset.control_matched_adata.var.index, | ||
masked_adata_obs=dataset.control_matched_adata.obs, | ||
target_conditions_to_save=dataset.target_conditions_to_save, | ||
row_index=dataset.adata.obs.index, | ||
) | ||
|
||
# Generate random model output | ||
model_output: CellRepresentation = np.random.rand( | ||
dataset.adata.shape[0], dataset.adata.shape[1] | ||
) | ||
# Load model data or generate random data | ||
if args.model_data_file: | ||
|
||
model_adata = ad.read_h5ad(args.model_data_file) | ||
|
||
# Use the cell representation data from the file | ||
# Handle both dense and sparse model_adata.X | ||
if hasattr(model_adata.X, "toarray"): | ||
model_output: CellRepresentation = model_adata.X.toarray() | ||
else: | ||
model_output: CellRepresentation = model_adata.X | ||
# Apply the gene and cell ordering from the model data to the task input and validate dimensions | ||
task_input.apply_model_ordering(model_adata) | ||
else: | ||
print("No model data file provided - generating random data for testing") | ||
|
||
# Generate random model output for testing | ||
model_output: CellRepresentation = np.random.rand( | ||
dataset.adata.shape[0], dataset.adata.shape[1] | ||
) | ||
|
||
# Run task | ||
task = PerturbationExpressionPredictionTask(metric=args.metric) | ||
|
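The new `--model_data_file` help text says the gene names and cell identifiers must match the task input while the ordering may differ. A minimal sketch of that kind of check, using only the standard library, is below. The helper name `permutation_to_match` is hypothetical and is not the PR's `apply_model_ordering`, which goes the other way and applies the model file's ordering to the task input; this only illustrates the "same names, any order" rule.

```python
from typing import List, Sequence


def permutation_to_match(
    reference_ids: Sequence[str], model_ids: Sequence[str]
) -> List[int]:
    """Return indices that reorder model_ids into the order of reference_ids.

    Hypothetical helper for illustration only; not part of czbenchmarks.
    """
    # Duplicate identifiers would make the mapping ambiguous.
    if len(set(model_ids)) != len(model_ids):
        raise ValueError("model identifiers must be unique")

    # Both sides must contain exactly the same identifiers.
    if len(reference_ids) != len(model_ids) or set(reference_ids) != set(model_ids):
        missing = sorted(set(reference_ids) - set(model_ids))
        unexpected = sorted(set(model_ids) - set(reference_ids))
        raise ValueError(
            f"identifier mismatch: missing={missing}, unexpected={unexpected}"
        )

    # Map each identifier to its position in the model data.
    position = {name: i for i, name in enumerate(model_ids)}
    return [position[name] for name in reference_ids]


if __name__ == "__main__":
    order = permutation_to_match(
        ["GENE1", "GENE2", "GENE3"], ["GENE3", "GENE1", "GENE2"]
    )
    print(order)
```

With a NumPy matrix `X` whose columns follow `model_ids`, the returned list could be used as `X[:, order]` to bring the columns into the reference order.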
For the help under "metric", could we list the two possibilities?
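One way to list the possibilities is to let `argparse` do it through `choices` and the `%(choices)s` help placeholder, so the help text cannot drift from the accepted values. A rough sketch follows, assuming the flag is still wanted. Only `t-test` appears in this diff; `placeholder-metric` stands in for the second value, which this thread does not name, and the default shown is an assumption too.

```python
import argparse

# "t-test" is taken from the diff; the second entry is a placeholder
# because the other metric name is not shown in this thread.
METRIC_CHOICES = ("t-test", "placeholder-metric")


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Sketch of a --metric flag with listed choices"
    )
    parser.add_argument(
        "--metric",
        choices=METRIC_CHOICES,
        default=METRIC_CHOICES[0],  # assumed default, for illustration only
        help="Metric to compute. One of: %(choices)s (default: %(default)s)",
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    print(args.metric)
```

Passing a value outside `METRIC_CHOICES` makes `argparse` exit with a usage error that names the valid options.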
Also, all default values should match what the default is set to in the respective method.
added
I'm pretty sure the values match the defaults
Percent genes to mask does not match (that's the one I was looking at when I wrote this). The rest are indeed identical. In the dataset class:
percent_genes_to_mask: float = 0.5
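To keep the script and the class from drifting apart, the argparse default could be read from the class instead of being typed twice. A minimal sketch, assuming the default is exposed as a dataclass field: `StandInDataset` is a hypothetical stand-in for the real dataset class (only the `percent_genes_to_mask: float = 0.5` line is quoted from above), and the flag name `--percent_genes_to_mask` is assumed.

```python
import argparse
from dataclasses import dataclass, fields
from typing import Any


@dataclass
class StandInDataset:
    """Hypothetical stand-in; only this field is quoted from the review."""

    percent_genes_to_mask: float = 0.5


def field_default(cls: type, name: str) -> Any:
    """Look up the declared default of a dataclass field by name."""
    for field in fields(cls):
        if field.name == name:
            return field.default
    raise KeyError(name)


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Sketch: CLI default derived from the class default"
    )
    parser.add_argument(
        "--percent_genes_to_mask",  # assumed flag name
        type=float,
        default=field_default(StandInDataset, "percent_genes_to_mask"),
        help="Fraction of genes to mask (default: %(default)s)",
    )
    return parser


if __name__ == "__main__":
    print(build_parser().parse_args().percent_genes_to_mask)
```

If the real class sets its defaults in an `__init__` signature rather than as dataclass fields, `inspect.signature` would be the analogous lookup.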
Also, metric should be removed as an arg from the script -- it's been deleted from the dataset/task since there is only one possibility right now.
done!