Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
selmanozleyen committed Jun 2, 2024
1 parent f111a43 commit ee3a68a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
29 changes: 19 additions & 10 deletions tests/ann2data/test_ann2data_by_category.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,17 @@ def test_sample_case_ann2data_basic():
# so that the resulting splits number of edges will be the same
# as the sum of the number of edges in each cluster
func_args = {"radius": 4.0, "coord_type": "generic"}
coordinates[:25, 0] += 100
cell_type = ["a"] * 20 + ["b"] * 20 + ["c"] * 5 + ["d"] * 5
image_id = list("xy" * 20) + ["z"] * 10
# make clusters for each cell type
for i, ct in enumerate(set(cell_type)):
idx = np.where(np.array(cell_type) == ct)[0]
coordinates[idx, 0] += 100 * i
coordinates[idx, 1] += 100 * i

adata_gt = ad.AnnData(
np.random.rand(50, 2),
obs={"cell_type": ["a"] * 20 + ["b"] * 20 + ["c"] * 5 + ["d"] * 5, "image_id": list("xy" * 20) + ["z"] * 10},
obs={"cell_type": cell_type, "image_id": image_id},
obsm={"spatial_init": coordinates},
)
a2d = ann2data.Ann2DataByCategory(
Expand Down Expand Up @@ -55,11 +62,13 @@ def test_sample_case_ann2data_basic():
assert torch.allclose(torch.cat([d.x for d in datas]), torch.from_numpy(big_adata.X).to(torch.float))
assert sum([d.edge_index.shape[1] for d in datas]) == big_adata.uns["edge_index"].shape[1]
adatas = list(iterables.ToCategoryIterator(category="cell_type")(big_adata))
assert np.allclose(
np.array(adatas[0].obsp["graph_distances"].todense()),
np.array(big_adata.obsp["graph_distances"][0:25, 0:25].todense()),
)
assert np.allclose(
np.array(adatas[1].obsp["graph_distances"].todense()),
np.array(big_adata.obsp["graph_distances"][25:, 25:].todense()),
)
assert len(adatas) == 4
# this line is the for loop version of the last two assertions

for a in adatas:
ct = a.obs["cell_type"].values[0]
ct_idx = np.where(np.array(cell_type) == ct)[0]
np.allclose(
a.obsp["graph_distances"].todense(),
big_adata.obsp["graph_distances"][ct_idx, :][:, ct_idx].todense(),
)
2 changes: 1 addition & 1 deletion tests/transforms/test_add_edge_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_add_edge_index():
tf = transforms.AddEdgeIndex(
spatial_key="spatial_init",
key_added="pred",
func_args={"radius": median_dist, "n_neighs": 4},
func_args={"radius": median_dist, "n_neighs": 4, "coord_type": "generic"},
edge_index_key="edge_index",
edge_weight_key="edge_weight",
gets_connectivities=False, # gets distances
Expand Down

0 comments on commit ee3a68a

Please sign in to comment.