Skip to content

Commit

Permalink
Merge pull request #53 from FrancescaDr/update_basic
Browse files Browse the repository at this point in the history
obs_names and var_names added to fields
  • Loading branch information
selmanozleyen committed Jun 2, 2024
2 parents f2a58ed + ee3a68a commit 22ce1fd
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/geome/ann2data/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def _convert_to_tensor(self, obj):
if obj.dtype.name == "category":
return torch.from_numpy(pd.get_dummies(obj).to_numpy()).to(torch.float)
if not np.issubdtype(obj.dtype, np.number):
return torch.from_numpy(obj.astype(np.float)).to(torch.float)
return torch.from_numpy(obj.astype(np.float64)).to(torch.float)
if isinstance(obj, np.ndarray):
return torch.from_numpy(obj).to(torch.float)
else:
Expand Down
5 changes: 5 additions & 0 deletions src/geome/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ def get_from_loc(adata: AnnData, location: str) -> Any:
"""
if location == "X":
return adata.X
elif location == "obs_names":
return adata.obs_names.to_numpy()
elif location == "var_names":
return adata.var_names.to_numpy()

assert len(location.split("/")) == 2, f"Location must have only one delimiter {location}"
axis, key = location.split("/")

Expand Down
40 changes: 28 additions & 12 deletions tests/ann2data/test_ann2data_by_category.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,28 @@ def test_sample_case_ann2data_basic():
# so that the resulting splits number of edges will be the same
# as the sum of the number of edges in each cluster
func_args = {"radius": 4.0, "coord_type": "generic"}
coordinates[:25, 0] += 100
cell_type = ["a"] * 20 + ["b"] * 20 + ["c"] * 5 + ["d"] * 5
image_id = list("xy" * 20) + ["z"] * 10
# make clusters for each cell type
for i, ct in enumerate(set(cell_type)):
idx = np.where(np.array(cell_type) == ct)[0]
coordinates[idx, 0] += 100 * i
coordinates[idx, 1] += 100 * i

adata_gt = ad.AnnData(
np.random.rand(50, 2),
obs={"cell_type": ["a"] * 25 + ["b"] * 25, "image_id": list("cd" * 25)},
obs={"cell_type": cell_type, "image_id": image_id},
obsm={"spatial_init": coordinates},
)
a2d = ann2data.Ann2DataByCategory(
fields={"x": ["X"], "edge_index": ["uns/edge_index"], "edge_weight": ["uns/edge_weight"]},
fields={
"x": ["X"],
"obs_names": ["obs_names"],
"var_names": ["var_names"],
"edge_index": ["uns/edge_index"],
"edge_weight": ["uns/edge_weight"],
"y": ["obs/cell_type"],
},
category="cell_type",
preprocess=transforms.Categorize(keys=["cell_type", "image_id"]),
transform=transforms.AddEdgeIndex(
Expand All @@ -30,7 +44,7 @@ def test_sample_case_ann2data_basic():
),
)
datas = list(a2d(adata_gt.copy()))
assert len(datas) == 2
assert len(datas) == 4
big_adata_tf = transforms.Compose(
[
transforms.Categorize(keys=["cell_type", "image_id"]),
Expand All @@ -48,11 +62,13 @@ def test_sample_case_ann2data_basic():
assert torch.allclose(torch.cat([d.x for d in datas]), torch.from_numpy(big_adata.X).to(torch.float))
assert sum([d.edge_index.shape[1] for d in datas]) == big_adata.uns["edge_index"].shape[1]
adatas = list(iterables.ToCategoryIterator(category="cell_type")(big_adata))
assert np.allclose(
np.array(adatas[0].obsp["graph_distances"].todense()),
np.array(big_adata.obsp["graph_distances"][0:25, 0:25].todense()),
)
assert np.allclose(
np.array(adatas[1].obsp["graph_distances"].todense()),
np.array(big_adata.obsp["graph_distances"][25:, 25:].todense()),
)
assert len(adatas) == 4
# Per-category check: the graph distances of each category's AnnData must
# equal the matching sub-block of the full graph's distance matrix.
cell_type_arr = np.asarray(cell_type)  # hoisted: invariant across iterations
for a in adatas:
    ct = a.obs["cell_type"].values[0]
    ct_idx = np.where(cell_type_arr == ct)[0]
    # np.allclose only returns a bool; without `assert` a mismatch would go
    # unnoticed and the loop would verify nothing.
    assert np.allclose(
        a.obsp["graph_distances"].todense(),
        big_adata.obsp["graph_distances"][ct_idx, :][:, ct_idx].todense(),
    ), f"graph_distances mismatch for cell_type {ct!r}"
2 changes: 1 addition & 1 deletion tests/transforms/test_add_edge_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_add_edge_index():
tf = transforms.AddEdgeIndex(
spatial_key="spatial_init",
key_added="pred",
func_args={"radius": median_dist, "n_neighs": 4},
func_args={"radius": median_dist, "n_neighs": 4, "coord_type": "generic"},
edge_index_key="edge_index",
edge_weight_key="edge_weight",
gets_connectivities=False, # gets distances
Expand Down

0 comments on commit 22ce1fd

Please sign in to comment.