Skip to content

Commit

Permalink
Merge pull request #46 from theislab/update
Browse files Browse the repository at this point in the history
refactoring, updated the notebooks and tested them
  • Loading branch information
selmanozleyen committed Apr 25, 2024
2 parents 675dcc3 + 67fce3f commit 74b5795
Show file tree
Hide file tree
Showing 27 changed files with 1,403 additions and 1,864 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,5 @@ __pycache__/
# IDEs
/.idea/
/.vscode/

**/lightning_logs/**
326 changes: 326 additions & 0 deletions docs/notebooks/1_iterables_and_iterators.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,326 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Basics: Iterables and Iterators\n",
"\n",
"\n",
"## Creating Iterables From AnnData\n",
"\n",
"The `Ann2DataBasic` class expects an anndata iterable (e.g., `Iterable[AnnData]`) when called upon. This in your case could be a list of anndata objects. However, if you have a single anndata object and want to split it into multiple anndata objects, you can use the implementations of `geome.iterables.ToIterable` classes. The signature of these classes is as follows:\n",
"\n",
"```python\n",
"class ToIterable(ABC):\n",
" @abstractmethod\n",
" def __call__(self, adata: AnnData) -> Iterable[AnnData]:\n",
" pass\n",
"\n",
"class Ann2DatAbstract(ABC):\n",
" \"\"\"Abstract class that transforms an iterable of AnnData to Pytorch Geometric Data objects.\"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" fields: dict[str, list[str]],\n",
" adata2iterable: Callable[[AnnData], Iterable[AnnData]] | None = None,\n",
" ...\n",
" ) -> None:\n",
" pass\n",
"\n",
" def __call__(self, adata: AnnData | Iterable[AnnData]) -> Iterable[Data]:\n",
" pass\n",
"```\n",
"\n",
"You can give an instance of this class to the `anndata2iter` parameter of the `geome.ann2data.Ann2Data` constructor or you can just call the instance with an anndata object to get an iterable of anndata objects. The advantage of giving the instance to the `anndata2iter` parameter is that you can use the same instance to split multiple anndata objects and specify preprocessing strategies before the split happens."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from geome import iterables, transforms, ann2data\n",
"import squidpy as sq\n",
"import numpy as np\n",
"from anndata import AnnData"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load data\n",
"First, let's load the data and see what it looks like. In this example assume that we want to split by these categories specified in `adata.obs[\"Cluster\"]`.\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['Endothelial', 'Epithelial', 'Fibroblast', 'Imm_other', 'Myeloid_CD11c',\n",
" 'Myeloid_CD68', 'Tcell_CD4', 'Tcell_CD8'],\n",
" dtype='object')"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load squidpy dataset\n",
"adata = sq.datasets.mibitof()\n",
"adata.obs[\"Cluster\"].cat.categories"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create ToIterable instance: ToCategoryIterator\n",
"We will create an instance of `ToCategoryIterator` class which will split the anndata object by the categories specified in `adata.obs[\"Cluster\"]`. The signature of this class is as follows:\n",
"\n",
"```python\n",
"class ToCategoryIterator(ToIterable):\n",
" \"\"\"Iterates over `adata` by category on the given axis (either obs(0) or var(1)).\n",
"\n",
" Preserves the categories in the resulting AnnData obs and var Series.\n",
" \"\"\"\n",
"\n",
" def __init__(self, category: str, axis: Literal[0, 1, \"obs\", \"var\"] = \"obs\", preserve_categories: bool = True):\n",
" pass\n",
"\n",
" def __call__(self, adata: AnnData) -> Iterator[AnnData]:\n",
" pass\n",
"```\n",
"\n",
"#### Reminder: Iterator vs Iterable\n",
"- An `iterable` is an object that can be iterated over. It returns an iterator when `iter()` is called on it.\n",
"- An `iterator` is an object that produces the next value when `next()` is called on it.\n",
"- An `iterator` is an `iterable` but an `iterable` is not an `iterator`.\n",
"\n",
"The class `ToCategoryIterator` returns an `iterator` of anndata objects since an `iterator` is an `iterable` this complies with the `Ann2Data` class's requirements. Being able to return an `iterator` is useful because it allows us to lazily load the data and not store all the data in memory at once.\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"to_iterable: iterables.ToIterable = iterables.ToCategoryIterator(\"Cluster\", axis=\"obs\", preserve_categories=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[AnnData object with n_obs × n_vars = 115 × 36\n",
" obs: 'row_num', 'point', 'cell_id', 'X1', 'center_rowcoord', 'center_colcoord', 'cell_size', 'category', 'donor', 'Cluster', 'batch', 'library_id'\n",
" var: 'mean-0', 'std-0', 'mean-1', 'std-1', 'mean-2', 'std-2'\n",
" uns: 'Cluster_colors', 'batch_colors', 'neighbors', 'spatial', 'umap'\n",
" obsm: 'X_scanorama', 'X_umap', 'spatial'\n",
" obsp: 'connectivities', 'distances',\n",
" AnnData object with n_obs × n_vars = 746 × 36\n",
" obs: 'row_num', 'point', 'cell_id', 'X1', 'center_rowcoord', 'center_colcoord', 'cell_size', 'category', 'donor', 'Cluster', 'batch', 'library_id'\n",
" var: 'mean-0', 'std-0', 'mean-1', 'std-1', 'mean-2', 'std-2'\n",
" uns: 'Cluster_colors', 'batch_colors', 'neighbors', 'spatial', 'umap'\n",
" obsm: 'X_scanorama', 'X_umap', 'spatial'\n",
" obsp: 'connectivities', 'distances',\n",
" AnnData object with n_obs × n_vars = 270 × 36\n",
" obs: 'row_num', 'point', 'cell_id', 'X1', 'center_rowcoord', 'center_colcoord', 'cell_size', 'category', 'donor', 'Cluster', 'batch', 'library_id'\n",
" var: 'mean-0', 'std-0', 'mean-1', 'std-1', 'mean-2', 'std-2'\n",
" uns: 'Cluster_colors', 'batch_colors', 'neighbors', 'spatial', 'umap'\n",
" obsm: 'X_scanorama', 'X_umap', 'spatial'\n",
" obsp: 'connectivities', 'distances']"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"split_adatas = list(to_iterable(adata)) # split by cluster\n",
"assert len(split_adatas) == len(adata.obs[\"Cluster\"].cat.categories) # ensure all clusters have their own adata\n",
"split_adatas[:3] # show first 3"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Important Note about `preserve_categories` parameter\n",
"\n",
"If preserve_categories is set to True, the categories of the original anndata object will be preserved in the split anndata objects. This means that if a category is not present in a split anndata object, it will still be present in the `obs` attribute of the split anndata object but with all values set to 0. This is useful when you want to keep track of the categories that were present in the original anndata object. However, if you want to remove the categories that are not present in a split anndata object, you can set preserve_categories to False. This will remove the categories that are not present in the split anndata object from the `obs` attribute of the split anndata object.\n",
"\n",
"This is important when you want to use one-hot encoding for the categories."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"assert all(len(ad.obs[\"Cluster\"].cat.categories) == len(adata.obs[\"Cluster\"].cat.categories) for ad in split_adatas) # ensure all splits have the same category"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The case for `preserve_categories=False` is shown in the example below."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[1, 1, 1, 1, 1, 1, 1, 1]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"unpreserved_example = list(iterables.ToCategoryIterator(\"Cluster\", axis=\"obs\", preserve_categories=False)(adata))\n",
"assert all(len(ad.obs[\"Cluster\"].cat.categories) == 1 for ad in unpreserved_example)\n",
"# you can see that the categories are not preserved and each split has only one category\n",
"[len(ad.obs[\"Cluster\"].cat.categories) for ad in unpreserved_example]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### The role of AnnData2Iterable in Ann2Data\n",
"\n",
"In the following example, we will show three different ways of creating data objects that result in the same data objects.\n",
"1. `Ann2DataBasic` and giving the `ToCategoryIterator` instance to the `anndata2iter` parameter.\n",
"2. `Ann2DataBasic` and splitting the anndata object using the `ToCategoryIterator` instance and then passing the resulting iterable to the `Ann2DataBase` call.\n",
"3. `Ann2DataByCategory` which is a subclass of `Ann2DataBasic`, takes the category as a parameter, and uses the `ToCategoryIterator` instance internally."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. `Ann2DataBasic` and giving the `ToCategoryIterator` instance to the `anndata2iter` parameter."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"result1 = ann2data.Ann2DataBasic(\n",
" fields={\"x\": [\"X\"]},\n",
" adata2iter=iterables.ToCategoryIterator(\"Cluster\", axis=\"obs\"),\n",
").to_list(adata)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"2. `Ann2DataBasic` and splitting the anndata object using the `ToCategoryIterator` instance and then passing the resulting iterable to the `Ann2DataBase` call."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"result2 = ann2data.Ann2DataBasic(\n",
" fields={\"x\": [\"X\"]},\n",
").to_list(iterables.ToCategoryIterator(\"Cluster\", axis=\"obs\")(adata))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"3. `Ann2DataByCategory` which is a subclass of `Ann2DataBasic`, takes the category as a parameter, and uses the `ToCategoryIterator` instance internally."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"result3 = ann2data.Ann2DataByCategory(\n",
" fields={\"x\": [\"X\"]},\n",
" category=\"Cluster\",\n",
").to_list(adata)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below is the code that demonstrates the three ways of creating data objects that result in the same data objects."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"assert all(np.allclose(r1.x, r2.x) for r1, r2 in zip(result1, result2)) and all(np.allclose(r1.x, r3.x) for r1, r3 in zip(result1, result3))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "gnn",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit 74b5795

Please sign in to comment.