Skip to content

Commit cb77ff1

Browse files
committed
error logging and better format of the example
1 parent eb7f135 commit cb77ff1

7 files changed

+434
-33
lines changed

examples/example_perturbation_expression_prediction.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@
130130
"If not provided, random data will be generated for testing. "
131131
"The file should have: cell representations in .X, gene names in .var.index, "
132132
"and cell identifiers in .obs.index. "
133-
"Example: adata.write_h5ad('my_model_data.h5ad')",
133+
"The gene names and cell identifiers should match the task input, although the ordering does not need to be the same.",
134134
)
135135

136136
args = parser.parse_args()
@@ -178,20 +178,15 @@
178178
# Load model data or generate random data
179179
if args.model_data_file:
180180
model_adata = ad.read_h5ad(args.model_data_file)
181-
# Validate dimensions
182-
assert model_adata.shape == dataset.adata.shape, (
183-
f"Model data shape {model_adata.shape} does not match dataset shape {dataset.adata.shape}"
184-
)
185181

186182
# Use the cell representation data from the file
187-
model_output: CellRepresentation = model_adata.X
188-
189-
# Apply the gene and cell ordering from the model data to the task input
190-
task_input.adata.var.index = model_adata.var.index
191-
task_input.adata.uns["cell_barcode_index"] = model_adata.obs.index.astype(
192-
str
193-
).values
194-
183+
# Handle both dense and sparse model_adata.X
184+
if hasattr(model_adata.X, "toarray"):
185+
model_output: CellRepresentation = model_adata.X.toarray()
186+
else:
187+
model_output: CellRepresentation = model_adata.X
188+
# Apply the gene and cell ordering from the model data to the task input and validate dimensions
189+
task_input.apply_model_ordering(model_adata)
195190
else:
196191
print("No model data file provided - generating random data for testing")
197192

examples/test_equivalency_perturbation_dataset.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,6 @@ def run_new_code(
329329
filter &= df_csv["standardized_mean_diff"].abs() >= args.min_smd
330330

331331
df_csv = df_csv[filter]
332-
333332
assert_de_results_equivalent(df_csv, new_dataset.de_results, col_map)
334333
logger.info("DE results matched")
335334

examples/test_equivalency_perturbation_task.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,9 +255,30 @@ def generate_model_predictions(masked_notebook_adata, args):
255255
) as f:
256256
target_genes_to_save = json.load(f)
257257
print("Generating model predictions matrix...")
258+
# Generate random model data
259+
print("Generating random model data for testing")
258260
model_output: CellRepresentation = np.random.rand(
259261
masked_notebook_adata.shape[0], masked_notebook_adata.shape[1]
260262
)
263+
264+
# Create and save h5ad file if save_model_data is specified
265+
if args.save_model_data:
266+
# Auto-generate filename based on test parameters
267+
filename = (
268+
f"generated_model_data_{args.metric_type}_{args.percent_genes_to_mask}.h5ad"
269+
)
270+
print(f"Creating and saving model data to {filename}")
271+
272+
# Create AnnData object with the generated model data
273+
model_adata = ad.AnnData(
274+
X=model_output,
275+
obs=masked_notebook_adata.obs.copy(),
276+
var=masked_notebook_adata.var.copy(),
277+
)
278+
279+
# Save to h5ad file
280+
model_adata.write_h5ad(filename)
281+
print(f"Saved model data with shape {model_output.shape} to {filename}")
261282
obs_index = masked_notebook_adata.obs.index
262283

263284
# Speed up by using numpy and pandas vectorized lookups instead of repeated .index() calls
@@ -355,6 +376,14 @@ def generate_model_predictions(masked_notebook_adata, args):
355376
type=str,
356377
default="notebook_task_inputs_{metric_type}_{percent_genes_to_mask}",
357378
)
379+
parser.add_argument(
380+
"--save_model_data",
381+
action="store_true",
382+
help="[OPTIONAL] Save the generated random model data as an AnnData file (.h5ad). "
383+
"The file will be automatically named based on the test parameters and saved in the current directory. "
384+
"The saved file will contain: cell representations in .X, gene names in .var.index, "
385+
"and cell identifiers in .obs.index.",
386+
)
358387

359388
args = parser.parse_args()
360389
mask_portion = (

src/czbenchmarks/datasets/single_cell_perturbation.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def _create_adata(self) -> Tuple[ad.AnnData, dict]:
218218
percent_genes_to_mask=self.percent_genes_to_mask,
219219
min_de_genes_to_mask=self.min_de_genes_to_mask,
220220
condition_col=self.condition_key,
221-
gene_col=self.de_gene_col,
221+
gene_col="gene_id", # Column was renamed to gene_id during optimization
222222
)
223223

224224
target_conditions = list(target_condition_dict.keys())
@@ -270,6 +270,9 @@ def _create_adata(self) -> Tuple[ad.AnnData, dict]:
270270
adata_final.obs[self.condition_key]
271271
)
272272

273+
# Optimize: Keep only necessary columns in obs (only condition_key is used in task)
274+
adata_final.obs = adata_final.obs[[self.condition_key]]
275+
273276
# Add task-related data to uns for easy access
274277
adata_final.uns["target_conditions_dict"] = target_condition_dict
275278
adata_final.uns["de_results"] = {
@@ -341,6 +344,24 @@ def load_data(
341344
self.de_results = self.load_and_filter_deg_results()
342345
logger.info(f"Using {len(self.de_results)} differential expression values")
343346

347+
# Optimize: Keep only necessary columns in de_results
348+
# Task only uses: condition_key, "gene_id", and metric_column (logfoldchange or standardized_mean_diff)
349+
metric_column = (
350+
"logfoldchange"
351+
if self.deg_test_name == "wilcoxon"
352+
else "standardized_mean_diff"
353+
)
354+
necessary_columns = [self.condition_key, self.de_gene_col, metric_column]
355+
356+
# Ensure we have gene_id column for compatibility with task
357+
if self.de_gene_col != "gene_id":
358+
self.de_results = self.de_results.rename(
359+
columns={self.de_gene_col: "gene_id"}
360+
)
361+
necessary_columns = [self.condition_key, "gene_id", metric_column]
362+
363+
self.de_results = self.de_results[necessary_columns]
364+
344365
# Compare conditions and throw warning or error for unmatched conditions
345366
unique_conditions_adata = set(self.adata.obs[self.condition_key])
346367
unique_conditions_control_cells_ids = set(self.control_cells_ids.keys())
@@ -392,11 +413,10 @@ def store_task_inputs(self) -> Path:
392413
Store all task inputs as separate files.
393414
394415
This method saves all task-related data as separate files:
395-
- control_matched_adata.h5ad: The main AnnData object
416+
- control_matched_adata.h5ad: The main AnnData object (includes cell_barcode_index in uns)
396417
- control_cells_ids.json: Control cell IDs mapping
397418
- target_conditions_dict.json: Target conditions dictionary
398419
- de_results.csv: Differential expression results
399-
- cell_barcode_index.npy: Original cell barcode indices
400420
401421
Returns:
402422
Path: Path to the task inputs directory.

src/czbenchmarks/tasks/single_cell/perturbation_expression_prediction.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,24 @@ class PerturbationExpressionPredictionTaskInput(TaskInput):
2525
target_conditions_dict: dict
2626
de_results: pd.DataFrame
2727

28-
class Config:
29-
arbitrary_types_allowed = True
28+
def apply_model_ordering(self, model_adata: ad.AnnData) -> None:
29+
"""
30+
Apply gene and cell ordering from model data to match the task input.
31+
32+
Args:
33+
model_adata: AnnData object containing the desired gene and cell ordering
34+
"""
35+
36+
# Apply gene ordering
37+
# Assert that the same values are in both gene and cell indices before re-assigning
38+
if set(self.adata.var.index) != set(model_adata.var.index):
39+
raise ValueError("Gene indices in task input and model data do not match.")
40+
if set(self.adata.obs.index) != set(model_adata.obs.index):
41+
raise ValueError("Cell indices in task input and model data do not match.")
42+
self.adata.var.index = model_adata.var.index
43+
44+
# Apply cell barcode ordering
45+
self.adata.uns["cell_barcode_index"] = model_adata.obs.index.astype(str).values
3046

3147

3248
def load_perturbation_task_input_from_saved_files(
@@ -42,7 +58,7 @@ def load_perturbation_task_input_from_saved_files(
4258
PerturbationExpressionPredictionTaskInput: The loaded task input.
4359
"""
4460

45-
# Load the main AnnData object
61+
# Load the main AnnData object (contains cell_barcode_index in uns)
4662
adata_file = task_inputs_dir / "control_matched_adata.h5ad"
4763
task_adata = ad.read_h5ad(adata_file)
4864

@@ -133,6 +149,7 @@ def _run_task(
133149

134150
for condition in perturbation_conditions:
135151
# Get target genes for this condition
152+
136153
target_genes = target_conditions_dict.get(condition, [])
137154
valid_genes = [g for g in target_genes if g in adata.var.index]
138155

@@ -170,7 +187,6 @@ def _run_task(
170187
.index.str.split("_")
171188
.str[0]
172189
)
173-
174190
condition_idx = np.where(base_cell_ids.isin(condition_cells))[0]
175191
control_idx = np.where(base_cell_ids.isin(control_cells))[0]
176192

tests/datasets/test_single_cell_perturbation_dataset.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ def test_perturbation_dataset_store_task_inputs(
259259
# Check that all required files exist
260260
expected_files = [
261261
"control_matched_adata.h5ad",
262+
"control_cells_ids.json",
262263
"target_conditions_dict.json",
263264
"de_results.csv",
264265
]
@@ -284,15 +285,16 @@ def test_perturbation_dataset_store_task_inputs(
284285
target_conditions_dict = json.load(f)
285286
assert isinstance(target_conditions_dict, dict)
286287

287-
# Load and validate DE results CSV
288+
# Load and validate DE results CSV (should only have optimized columns)
288289
de_df = pd.read_csv(task_inputs_dir / "de_results.csv")
289290
assert not de_df.empty
290-
base_cols = {"condition", "gene", "pval_adj"}
291-
assert base_cols.issubset(set(de_df.columns))
291+
# Only the necessary columns should be present
292+
expected_cols = {"condition", "gene_id"}
292293
if deg_test_name == "wilcoxon":
293-
assert "logfoldchange" in de_df.columns
294+
expected_cols.add("logfoldchange")
294295
else:
295-
assert "standardized_mean_diff" in de_df.columns
296+
expected_cols.add("standardized_mean_diff")
297+
assert set(de_df.columns) == expected_cols
296298

297299
# Load and validate cell barcode index
298300
cell_barcode_index = task_adata.uns["cell_barcode_index"]
@@ -415,16 +417,17 @@ def test_control_matched_adata_contains_task_data(self, deg_test_name, tmp_path)
415417
assert len(uns["control_cells_ids"]) > 0
416418
assert uns["control_cells_ids"] == dataset.control_cells_ids
417419

418-
# Check de_results can be reconstructed as DataFrame
420+
# Check de_results can be reconstructed as DataFrame (should only have optimized columns)
419421
assert isinstance(uns["de_results"], dict)
420422
de_df = pd.DataFrame(uns["de_results"])
421423
assert not de_df.empty
422-
base_cols = {"condition", "gene", "pval_adj"}
423-
assert base_cols.issubset(set(de_df.columns))
424+
# Only the necessary columns should be present
425+
expected_cols = {"condition", "gene_id"}
424426
if deg_test_name == "wilcoxon":
425-
assert "logfoldchange" in de_df.columns
427+
expected_cols.add("logfoldchange")
426428
else:
427-
assert "standardized_mean_diff" in de_df.columns
429+
expected_cols.add("standardized_mean_diff")
430+
assert set(de_df.columns) == expected_cols
428431

429432
# Check cell_barcode_index
430433
assert isinstance(uns["cell_barcode_index"], np.ndarray)

0 commit comments

Comments
 (0)