Skip to content

Commit e84860b

Browse files
Merge pull request #45 from camlab-bioml/cklamann/bugfixes
Issues 42, 43, 44
2 parents a8e9266 + 1407286 commit e84860b

File tree

5 files changed

+71
-14
lines changed

5 files changed

+71
-14
lines changed

Dockerfile

+5-1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ RUN python3 -m venv $VIRTUAL_ENV && \
3535

3636
USER $USERNAME
3737

38-
COPY --chown=${USER_UID}:${USER_GID} . .
38+
# prevent full rebuilds every time code changes
39+
COPY --chown=${USER_UID}:${USER_GID} pyproject.toml poetry.lock README.md /code/
40+
COPY --chown=${USER_UID}:${USER_GID} starling/__init__.py /code/starling/__init__.py
3941

4042
RUN poetry install --with docs,dev
43+
44+
COPY . .

starling/starling.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,8 @@ def prepare_data(self) -> None:
160160
self.adata.uns["init_cell_size_variances"] = np.array(init_sv)
161161
else:
162162
# init_cell_size_centroids = None; init_cell_size_variances = None
163-
self.adata.varm["init_cell_size_centroids"] = None
164-
self.adata.varm["init_cell_size_variances"] = None
163+
self.adata.uns["init_cell_size_centroids"] = None
164+
self.adata.uns["init_cell_size_variances"] = None
165165
self.train_df = utility.ConcatDataset([self.X, tr_fy, tr_fl])
166166

167167
# model_params = utility.model_paramters(self.init_e, self.init_v, self.init_s, self.init_sv)

starling/utility.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __len__(self):
3838
def init_clustering(
3939
initial_clustering_method: Literal["User", "KM", "GMM", "FS", "PG"],
4040
adata: AnnData,
41-
k: Union[int, None],
41+
k: Union[int, None] = None,
4242
labels: Optional[np.ndarray] = None,
4343
) -> AnnData:
4444
"""Compute initial cluster centroids, variances & labels
@@ -49,7 +49,8 @@ def init_clustering(
4949
``FS`` (FlowSOM), ``User`` (user-provided), or ``PG`` (PhenoGraph).
5050
:param k: The number of clusters, must be ``n_components`` when ``initial_clustering_method`` is ``GMM`` (required),
5151
``k`` when ``initial_clustering_method`` is ``KM`` (required), ``k`` when ``initial_clustering_method``
52-
is ``FS`` (required), ``?`` when ``initial_clustering_method`` is ``PG`` (optional)
52+
is ``FS`` (required), ``?`` when ``initial_clustering_method`` is ``PG`` (optional), and can be omitted when
53+
``initial_clustering_method`` is ``User``, because the user will be passing in their own labels.
5354
:param labels: optional, user-provided labels
5455
5556
:raises: ValueError
@@ -67,6 +68,11 @@ def init_clustering(
6768
"k cannot be ommitted for KMeans, FlowSOM, or Gaussian Mixture"
6869
)
6970

71+
if initial_clustering_method == "User" and labels is None:
72+
raise ValueError(
73+
"labels must be provided when initial_clustering_method is set to 'User'"
74+
)
75+
7076
if initial_clustering_method == "KM":
7177
kms = KMeans(k).fit(adata.X)
7278
init_l = kms.labels_
@@ -90,12 +96,13 @@ def init_clustering(
9096
else:
9197
init_l = labels
9298

93-
k = len(np.unique(init_l))
99+
classes = np.unique(init_l)
100+
k = len(classes)
94101
init_e = np.zeros((k, adata.X.shape[1]))
95102
init_ev = np.zeros((k, adata.X.shape[1]))
96-
for c in range(k):
97-
init_e[c, :] = adata.X[init_l == c].mean(0)
98-
init_ev[c, :] = adata.X[init_l == c].var(0)
103+
for i, c in enumerate(classes):
104+
init_e[i, :] = adata.X[init_l == c].mean(0)
105+
init_ev[i, :] = adata.X[init_l == c].var(0)
99106

100107
elif initial_clustering_method == "FS":
101108
## needs to output to csv first

tests/test_sanity.py

+29-5
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
from os.path import dirname, join
22

33
import anndata as ad
4+
import numpy as np
45
import pandas as pd
5-
import pytorch_lightning as pl
66
from lightning_lite import seed_everything
77
from pytorch_lightning.callbacks import EarlyStopping
88

99
from starling import starling, utility
1010

1111

1212
def test_can_run_km(tmpdir):
13-
"""Temporary sanity check"""
13+
"""Test that we can run with the KM setting in init_clustering"""
1414
seed_everything(10, workers=True)
1515

1616
raw_adata = ad.read_h5ad(join(dirname(__file__), "fixtures", "sample_input.h5ad"))
@@ -51,7 +51,7 @@ def test_can_run_km(tmpdir):
5151

5252

5353
def test_can_run_gmm(tmpdir):
54-
"""Temporary sanity check"""
54+
"""Test that we can run with the GMM setting in init_clustering"""
5555
seed_everything(10, workers=True)
5656
adata = utility.init_clustering(
5757
"GMM",
@@ -89,7 +89,7 @@ def test_can_run_gmm(tmpdir):
8989

9090

9191
def test_can_run_pg(tmpdir):
92-
"""Temporary sanity check"""
92+
"""Test that we can run with the PG setting in init_clustering"""
9393
seed_everything(10, workers=True)
9494
adata = utility.init_clustering(
9595
"PG",
@@ -111,7 +111,6 @@ def test_can_run_pg(tmpdir):
111111
## initial expression centroids (p x c) matrix
112112
init_cent = pd.DataFrame(result.varm["init_exp_centroids"], index=result.var_names)
113113

114-
# j seems to vary here
115114
assert init_cent.shape[0] == 24
116115

117116
## starling expression centroids (p x c) matrix
@@ -125,3 +124,28 @@ def test_can_run_pg(tmpdir):
125124
)
126125

127126
assert prom_mat.shape[0] == 13685
127+
128+
129+
def test_can_run_pg_without_cell_size(tmpdir):
130+
"""Test that we can run the model with model_cell_size=False in ST"""
131+
seed_everything(10, workers=True)
132+
adata = utility.init_clustering(
133+
"PG",
134+
ad.read_h5ad(join(dirname(__file__), "fixtures", "sample_input.h5ad")),
135+
k=10,
136+
)
137+
st = starling.ST(adata, model_cell_size=False)
138+
cb_early_stopping = EarlyStopping(monitor="train_loss", mode="min", verbose=False)
139+
140+
## train ST
141+
st.train_and_fit(
142+
max_epochs=2,
143+
callbacks=[cb_early_stopping],
144+
default_root_dir=tmpdir,
145+
)
146+
147+
result = st.result()
148+
149+
exp_cent = pd.DataFrame(result.varm["st_exp_centroids"], index=result.var_names)
150+
151+
assert exp_cent.shape[0] == 24

tests/test_utility.py

+22
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
from anndata import AnnData
23

34
from starling.utility import init_clustering, validate_starling_arguments
@@ -6,9 +7,11 @@
67
def assert_annotated(adata: AnnData, k):
78
assert "init_exp_centroids" in adata.varm
89
assert adata.varm["init_exp_centroids"].shape == (adata.X.shape[1], k)
10+
assert not np.any(np.isnan(adata.varm["init_exp_centroids"]))
911

1012
assert "init_exp_centroids" in adata.varm
1113
assert adata.varm["init_exp_variances"].shape == (adata.X.shape[1], k)
14+
assert not np.any(np.isnan(adata.varm["init_exp_variances"]))
1215

1316
assert "init_label" in adata.obs
1417
assert adata.obs["init_label"].shape == (adata.X.shape[0],)
@@ -32,6 +35,25 @@ def test_init_clustering_pg(simple_adata):
3235
assert_annotated(initialized, k)
3336

3437

38+
def test_init_clustering_user(simple_adata):
39+
k = 3
40+
initialized = init_clustering(
41+
"User", simple_adata, labels=np.random.randint(k, size=32)
42+
)
43+
assert_annotated(initialized, k)
44+
45+
46+
def test_init_clustering_user_string(simple_adata):
47+
k = 3
48+
initialized = init_clustering(
49+
"User",
50+
simple_adata,
51+
labels=np.random.choice(np.array(["a", "b", "c"]), size=32),
52+
)
53+
54+
assert_annotated(initialized, k)
55+
56+
3557
def test_validation_passes_with_no_size(simple_adata):
3658
validate_starling_arguments(
3759
adata=simple_adata,

0 commit comments

Comments
 (0)