diff --git a/src/geome/ann2data/basic.py b/src/geome/ann2data/basic.py index 0e169e0..efe86a9 100644 --- a/src/geome/ann2data/basic.py +++ b/src/geome/ann2data/basic.py @@ -87,7 +87,7 @@ def _convert_to_tensor(self, obj): if obj.dtype.name == "category": return torch.from_numpy(pd.get_dummies(obj).to_numpy()).to(torch.float) if not np.issubdtype(obj.dtype, np.number): - return torch.from_numpy(obj.astype(np.float)).to(torch.float) + return torch.from_numpy(obj.astype(np.float64)).to(torch.float) if isinstance(obj, np.ndarray): return torch.from_numpy(obj).to(torch.float) else: diff --git a/src/geome/utils.py b/src/geome/utils.py index fb953ec..d09a624 100644 --- a/src/geome/utils.py +++ b/src/geome/utils.py @@ -21,6 +21,11 @@ def get_from_loc(adata: AnnData, location: str) -> Any: """ if location == "X": return adata.X + elif location == "obs_names": + return adata.obs_names.to_numpy() + elif location == "var_names": + return adata.var_names.to_numpy() + assert len(location.split("/")) == 2, f"Location must have only one delimiter {location}" axis, key = location.split("/") diff --git a/tests/ann2data/test_ann2data_by_category.py b/tests/ann2data/test_ann2data_by_category.py index f06c8fa..b085397 100644 --- a/tests/ann2data/test_ann2data_by_category.py +++ b/tests/ann2data/test_ann2data_by_category.py @@ -11,14 +11,28 @@ def test_sample_case_ann2data_basic(): # so that the resulting splits number of edges will be the same # as the sum of the number of edges in each cluster func_args = {"radius": 4.0, "coord_type": "generic"} - coordinates[:25, 0] += 100 + cell_type = ["a"] * 20 + ["b"] * 20 + ["c"] * 5 + ["d"] * 5 + image_id = list("xy" * 20) + ["z"] * 10 + # make clusters for each cell type + for i, ct in enumerate(set(cell_type)): + idx = np.where(np.array(cell_type) == ct)[0] + coordinates[idx, 0] += 100 * i + coordinates[idx, 1] += 100 * i + adata_gt = ad.AnnData( np.random.rand(50, 2), - obs={"cell_type": ["a"] * 25 + ["b"] * 25, "image_id": list("cd" * 25)}, 
+ obs={"cell_type": cell_type, "image_id": image_id}, obsm={"spatial_init": coordinates}, ) a2d = ann2data.Ann2DataByCategory( - fields={"x": ["X"], "edge_index": ["uns/edge_index"], "edge_weight": ["uns/edge_weight"]}, + fields={ + "x": ["X"], + "obs_names": ["obs_names"], + "var_names": ["var_names"], + "edge_index": ["uns/edge_index"], + "edge_weight": ["uns/edge_weight"], + "y": ["obs/cell_type"], + }, category="cell_type", preprocess=transforms.Categorize(keys=["cell_type", "image_id"]), transform=transforms.AddEdgeIndex( @@ -30,7 +44,7 @@ def test_sample_case_ann2data_basic(): ), ) datas = list(a2d(adata_gt.copy())) - assert len(datas) == 2 + assert len(datas) == 4 big_adata_tf = transforms.Compose( [ transforms.Categorize(keys=["cell_type", "image_id"]), @@ -48,11 +62,13 @@ def test_sample_case_ann2data_basic(): assert torch.allclose(torch.cat([d.x for d in datas]), torch.from_numpy(big_adata.X).to(torch.float)) assert sum([d.edge_index.shape[1] for d in datas]) == big_adata.uns["edge_index"].shape[1] adatas = list(iterables.ToCategoryIterator(category="cell_type")(big_adata)) - assert np.allclose( - np.array(adatas[0].obsp["graph_distances"].todense()), - np.array(big_adata.obsp["graph_distances"][0:25, 0:25].todense()), - ) - assert np.allclose( - np.array(adatas[1].obsp["graph_distances"].todense()), - np.array(big_adata.obsp["graph_distances"][25:, 25:].todense()), - ) + assert len(adatas) == 4 + # loop form of the two removed assertions: each category's graph distances must equal the matching sub-block of the full graph + + for a in adatas: + ct = a.obs["cell_type"].values[0] + ct_idx = np.where(np.array(cell_type) == ct)[0] + assert np.allclose( + a.obsp["graph_distances"].todense(), + big_adata.obsp["graph_distances"][ct_idx, :][:, ct_idx].todense(), + ) diff --git a/tests/transforms/test_add_edge_index.py b/tests/transforms/test_add_edge_index.py index c33b9da..0993dc5 100644 --- a/tests/transforms/test_add_edge_index.py +++ b/tests/transforms/test_add_edge_index.py @@ -62,7 +62,7 @@ def test_add_edge_index(): tf = 
transforms.AddEdgeIndex( spatial_key="spatial_init", key_added="pred", - func_args={"radius": median_dist, "n_neighs": 4}, + func_args={"radius": median_dist, "n_neighs": 4, "coord_type": "generic"}, edge_index_key="edge_index", edge_weight_key="edge_weight", gets_connectivities=False, # gets distances