Skip to content

Commit

Permalink
only check for size column if model_cell_size is True
Browse files Browse the repository at this point in the history
  • Loading branch information
cklamann committed Apr 4, 2024
1 parent 48d90f5 commit 5f79c5b
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 3 deletions.
3 changes: 2 additions & 1 deletion starling/starling.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ class ST(pl.LightningModule):
:param dist_option: The distribution to use, one of 'T' for Student-T (df=2) or 'N' for Normal (Gaussian), defaults to T
:param singlet_prop: The proportion of anticipated segmentation error free cells
:param model_cell_size: Whether STARLING should incoporate cell size in the model
:param cell_size_col_name: The column name in ``AnnData`` (anndata.obs)
:param cell_size_col_name: The column name in ``AnnData`` (anndata.obs). Required only if ``model_cell_size`` is ``True``,
otherwise ignored.
:param model_zplane_overlap: If cell size is modelled, should STARLING model z-plane overlap
:param model_regularizer: Regularizer term impose on synethic doublet loss (BCE)
:param learning_rate: Learning rate of ADAM optimizer for STARLING
Expand Down
2 changes: 1 addition & 1 deletion starling/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def validate_starling_arguments(
f"Argument `model_cell_size` must be boolean, received {type(model_cell_size)}"
)

if cell_size_col_name not in adata.obs:
if model_cell_size and cell_size_col_name not in adata.obs:
raise ValueError(
f"Argument `cell_size_col_name` must be a valid column in `adata.obs`"
)
Expand Down
15 changes: 14 additions & 1 deletion tests/test_utility.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from anndata import AnnData

from starling.utility import init_clustering
from starling.utility import init_clustering, validate_starling_arguments


def assert_annotated(adata: AnnData, k):
Expand Down Expand Up @@ -30,3 +30,16 @@ def test_init_clustering_pg(simple_adata):
k = 2
initialized = init_clustering("PG", simple_adata, k)
assert_annotated(initialized, k)


def test_validation_passes_with_no_size(simple_adata):
validate_starling_arguments(
adata=simple_adata,
cell_size_col_name="nonexistent",
dist_option="T",
singlet_prop=0.5,
model_cell_size=False,
model_zplane_overlap=False,
model_regularizer=0.1,
learning_rate=1e-3,
)

0 comments on commit 5f79c5b

Please sign in to comment.