Skip to content

Commit

Permalink
Merge pull request #42 from theislab/refactor
Browse files Browse the repository at this point in the history
API Update, Transform and ToIterable classes.
  • Loading branch information
selmanozleyen authored Nov 30, 2023
2 parents 4ce1212 + 485e76d commit 0dda80a
Show file tree
Hide file tree
Showing 27 changed files with 1,909 additions and 801 deletions.
322 changes: 322 additions & 0 deletions docs/notebooks/add_edge_weights_example.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,322 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Add Edge Weights Example"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"import anndata as ad\n",
"from geome.ann2data import Ann2DataByCategory\n",
"from utils.datasets import DatasetHartmann\n",
"from geome import transforms"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setting up the experiment"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"x:\n",
" -----------------------------------\n",
" obs/Cluster_preprocessed ('?',)\n",
" obs/donor ('?',)\n",
" -----------------------------------\n",
"edge_index:\n",
" -----------------------------------\n",
" uns/edge_index ('?',)\n",
" -----------------------------------\n",
"edge_weight:\n",
" -----------------------------------\n",
" uns/edge_weight ('?',)\n",
" -----------------------------------\n",
"y:\n",
" -----------------------------------\n",
" X ('?',)\n",
" -----------------------------------"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fields = {\n",
" 'x':['obs/Cluster_preprocessed','obs/donor'],\n",
" 'edge_index': ['uns/edge_index'],\n",
" 'edge_weight': ['uns/edge_weight'],\n",
" 'y':['X']\n",
"}\n",
"\n",
"\n",
"adj_matrix_loc = 'obsp/adjacency_matrix_connectivities'\n",
"\n",
"\n",
"preprocess = transforms.Compose([\n",
" transforms.Categorize(['Cluster_preprocessed', 'donor'],axis='obs'),\n",
"])\n",
"\n",
"transform = transforms.Compose([\n",
" transforms.AddAdjMatrix(location=adj_matrix_loc),\n",
" transforms.AddEdgeIndex(adj_matrix_loc=adj_matrix_loc,edge_index_key='edge_index'),\n",
" transforms.AddEdgeWeight(edge_index_key='edge_index',edge_weight_key='edge_weight', weight_matrix_loc=adj_matrix_loc),\n",
"])\n",
"\n",
"\n",
"category_to_iterate = 'point'\n",
"\n",
"a2d = Ann2DataByCategory(\n",
" fields=fields,\n",
" category=category_to_iterate,\n",
" preprocess=preprocess,\n",
" transform=transform,\n",
")\n",
"\n",
"a2d # won't show any sizes until we call next(a2d(adata)) at least once"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load data"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading data from raw files\n",
"registering celldata\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/sel/mambaforge/envs/gnn/lib/python3.10/site-packages/anndata/_core/anndata.py:117: ImplicitModificationWarning: Transforming to str index.\n",
" warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"collecting image-wise celldata\n",
"adding graph-level covariates\n",
"Loaded 58 images with complete data from 4 patients over 63747 cells with 36 cell features and 8 distinct celltypes.\n"
]
}
],
"source": [
"#Mibitof\n",
"dataset = DatasetHartmann(data_path='./example_data/hartmann/')\n",
"adatas = list(dataset.img_celldata.values())\n",
"\n",
"# Merge the list of adatas and convert some string to categories as they should be\n",
"adata = ad.concat(adatas)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Checking some dimensions\n",
"This will show the sizes of the tensors it last see. If one field has more than one locations it will only show the last dimension"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"x:\n",
" -----------------------------------\n",
" obs/Cluster_preprocessed ('-', 8)\n",
" obs/donor ('-', 4)\n",
" -----------------------------------\n",
"edge_index:\n",
" -----------------------------------\n",
" uns/edge_index (2, 8028)\n",
" -----------------------------------\n",
"edge_weight:\n",
" -----------------------------------\n",
" uns/edge_weight (8028,)\n",
" -----------------------------------\n",
"y:\n",
" -----------------------------------\n",
" X (1338, 36)\n",
" -----------------------------------"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"next(a2d(adata)) # will show the sizes of the fields\n",
"a2d"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"### Getting the whole list of data"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Data(x=[1338, 12], edge_index=[2, 8028], y=[1338, 36], edge_weight=[8028]),\n",
" Data(x=[311, 12], edge_index=[2, 1866], y=[311, 36], edge_weight=[1866]),\n",
" Data(x=[768, 12], edge_index=[2, 4608], y=[768, 36], edge_weight=[4608]),\n",
" Data(x=[1020, 12], edge_index=[2, 6120], y=[1020, 36], edge_weight=[6120]),\n",
" Data(x=[2100, 12], edge_index=[2, 12600], y=[2100, 36], edge_weight=[12600]),\n",
" Data(x=[1325, 12], edge_index=[2, 7950], y=[1325, 36], edge_weight=[7950]),\n",
" Data(x=[1091, 12], edge_index=[2, 6546], y=[1091, 36], edge_weight=[6546]),\n",
" Data(x=[1046, 12], edge_index=[2, 6276], y=[1046, 36], edge_weight=[6276]),\n",
" Data(x=[618, 12], edge_index=[2, 3708], y=[618, 36], edge_weight=[3708]),\n",
" Data(x=[61, 12], edge_index=[2, 366], y=[61, 36], edge_weight=[366]),\n",
" Data(x=[1316, 12], edge_index=[2, 7896], y=[1316, 36], edge_weight=[7896]),\n",
" Data(x=[1540, 12], edge_index=[2, 9240], y=[1540, 36], edge_weight=[9240]),\n",
" Data(x=[1822, 12], edge_index=[2, 10932], y=[1822, 36], edge_weight=[10932]),\n",
" Data(x=[863, 12], edge_index=[2, 5178], y=[863, 36], edge_weight=[5178]),\n",
" Data(x=[564, 12], edge_index=[2, 3384], y=[564, 36], edge_weight=[3384]),\n",
" Data(x=[1023, 12], edge_index=[2, 6138], y=[1023, 36], edge_weight=[6138]),\n",
" Data(x=[324, 12], edge_index=[2, 1944], y=[324, 36], edge_weight=[1944]),\n",
" Data(x=[287, 12], edge_index=[2, 1722], y=[287, 36], edge_weight=[1722]),\n",
" Data(x=[636, 12], edge_index=[2, 3816], y=[636, 36], edge_weight=[3816]),\n",
" Data(x=[890, 12], edge_index=[2, 5340], y=[890, 36], edge_weight=[5340]),\n",
" Data(x=[1235, 12], edge_index=[2, 7410], y=[1235, 36], edge_weight=[7410]),\n",
" Data(x=[1020, 12], edge_index=[2, 6120], y=[1020, 36], edge_weight=[6120]),\n",
" Data(x=[1241, 12], edge_index=[2, 7446], y=[1241, 36], edge_weight=[7446]),\n",
" Data(x=[1438, 12], edge_index=[2, 8628], y=[1438, 36], edge_weight=[8628]),\n",
" Data(x=[1021, 12], edge_index=[2, 6126], y=[1021, 36], edge_weight=[6126]),\n",
" Data(x=[1632, 12], edge_index=[2, 9792], y=[1632, 36], edge_weight=[9792]),\n",
" Data(x=[780, 12], edge_index=[2, 4680], y=[780, 36], edge_weight=[4680]),\n",
" Data(x=[524, 12], edge_index=[2, 3144], y=[524, 36], edge_weight=[3144]),\n",
" Data(x=[669, 12], edge_index=[2, 4014], y=[669, 36], edge_weight=[4014]),\n",
" Data(x=[241, 12], edge_index=[2, 1446], y=[241, 36], edge_weight=[1446]),\n",
" Data(x=[935, 12], edge_index=[2, 5610], y=[935, 36], edge_weight=[5610]),\n",
" Data(x=[347, 12], edge_index=[2, 2082], y=[347, 36], edge_weight=[2082]),\n",
" Data(x=[1499, 12], edge_index=[2, 8994], y=[1499, 36], edge_weight=[8994]),\n",
" Data(x=[601, 12], edge_index=[2, 3606], y=[601, 36], edge_weight=[3606]),\n",
" Data(x=[2268, 12], edge_index=[2, 13608], y=[2268, 36], edge_weight=[13608]),\n",
" Data(x=[1912, 12], edge_index=[2, 11472], y=[1912, 36], edge_weight=[11472]),\n",
" Data(x=[1678, 12], edge_index=[2, 10068], y=[1678, 36], edge_weight=[10068]),\n",
" Data(x=[1025, 12], edge_index=[2, 6150], y=[1025, 36], edge_weight=[6150]),\n",
" Data(x=[1306, 12], edge_index=[2, 7836], y=[1306, 36], edge_weight=[7836]),\n",
" Data(x=[852, 12], edge_index=[2, 5112], y=[852, 36], edge_weight=[5112]),\n",
" Data(x=[1664, 12], edge_index=[2, 9984], y=[1664, 36], edge_weight=[9984]),\n",
" Data(x=[1698, 12], edge_index=[2, 10188], y=[1698, 36], edge_weight=[10188]),\n",
" Data(x=[1672, 12], edge_index=[2, 10032], y=[1672, 36], edge_weight=[10032]),\n",
" Data(x=[777, 12], edge_index=[2, 4662], y=[777, 36], edge_weight=[4662]),\n",
" Data(x=[556, 12], edge_index=[2, 3336], y=[556, 36], edge_weight=[3336]),\n",
" Data(x=[554, 12], edge_index=[2, 3324], y=[554, 36], edge_weight=[3324]),\n",
" Data(x=[937, 12], edge_index=[2, 5622], y=[937, 36], edge_weight=[5622]),\n",
" Data(x=[1524, 12], edge_index=[2, 9144], y=[1524, 36], edge_weight=[9144]),\n",
" Data(x=[1528, 12], edge_index=[2, 9168], y=[1528, 36], edge_weight=[9168]),\n",
" Data(x=[721, 12], edge_index=[2, 4326], y=[721, 36], edge_weight=[4326]),\n",
" Data(x=[1395, 12], edge_index=[2, 8370], y=[1395, 36], edge_weight=[8370]),\n",
" Data(x=[611, 12], edge_index=[2, 3666], y=[611, 36], edge_weight=[3666]),\n",
" Data(x=[1872, 12], edge_index=[2, 11232], y=[1872, 36], edge_weight=[11232]),\n",
" Data(x=[1217, 12], edge_index=[2, 7302], y=[1217, 36], edge_weight=[7302]),\n",
" Data(x=[1531, 12], edge_index=[2, 9186], y=[1531, 36], edge_weight=[9186]),\n",
" Data(x=[1927, 12], edge_index=[2, 11562], y=[1927, 36], edge_weight=[11562]),\n",
" Data(x=[690, 12], edge_index=[2, 4140], y=[690, 36], edge_weight=[4140]),\n",
" Data(x=[1706, 12], edge_index=[2, 10236], y=[1706, 36], edge_weight=[10236])]"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"datas = a2d(adata)\n",
"list(datas)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"vscode": {
"interpreter": {
"hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Loading

0 comments on commit 0dda80a

Please sign in to comment.