🚧 add scikit-learn style interface with examples
- minor updates in formatting in "sklearn" folder
Henry committed Aug 17, 2023
1 parent d267f10 commit b377a45
Showing 6 changed files with 760 additions and 11 deletions.
289 changes: 289 additions & 0 deletions project/04_1_train_pimms_models.ipynb
@@ -0,0 +1,289 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "eae0a078",
"metadata": {},
"source": [
"# Scikit-learn style transformers of the data\n",
"\n",
"1. Load data into a pandas DataFrame\n",
"2. Fit transformer on training data\n",
"3. Impute only missing values with predictions from model\n",
"\n",
"Autoencoders need wide training data, i.e. a sample with all its features' intensities, whereas\n",
"Collaborative Filtering needs long training data, i.e. a sample identifier, a feature identifier,\n",
"and the intensity. Both data formats can be transformed into each other, but models using the long\n",
"data format do not need to handle missing values explicitly."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a3edbdd",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import vaep.sampling\n",
"from vaep.sklearn.ae_transformer import AETransformer\n",
"from vaep.sklearn.cf_transformer import CollaborativeFilteringTransformer\n",
"\n",
"fn_intensities: str = 'data/dev_datasets/HeLa_6070/protein_groups_wide_N50.csv'\n",
"df = pd.read_csv(fn_intensities, index_col=0)\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"id": "727b3ace",
"metadata": {},
"source": [
"We will need the data in long format. Naming both the row and column index ensures\n",
"that the data can easily be transformed into long format:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4fde25e9",
"metadata": {},
"outputs": [],
"source": [
"df.index.name = 'Sample ID' # already set\n",
"df.columns.name = 'protein group' # not stored in the csv file format\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "646ea5bb",
"metadata": {},
"outputs": [],
"source": [
"df = df.stack().to_frame('intensity')\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "554d4fa7",
"metadata": {},
"outputs": [],
"source": [
"df = np.log2(df)\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"id": "7792ce6e",
"metadata": {},
"source": [
"The resulting DataFrame with one column has a `MultiIndex` with the sample and feature identifiers."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "567854c0",
"metadata": {},
"outputs": [],
"source": [
"CollaborativeFilteringTransformer?"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b547a77",
"metadata": {},
"outputs": [],
"source": [
"cf_model = CollaborativeFilteringTransformer(\n",
" target_column='intensity',\n",
" sample_column='Sample ID',\n",
" item_column='protein group',\n",
" out_folder='runs/scikit_interface')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cb5ac432",
"metadata": {},
"outputs": [],
"source": [
"cf_model.fit(df,\n",
" cuda=True,\n",
" epochs_max=5,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b3dac537",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"df_imputed = cf_model.transform(df).unstack()\n",
"assert df_imputed.isna().sum().sum() == 0\n",
"df_imputed.head()"
]
},
{
"cell_type": "markdown",
"id": "a6d6552c",
"metadata": {
"lines_to_next_cell": 2
},
"source": [
"## AutoEncoder architectures"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7184c2e",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"# Reload data (for demonstration)\n",
"\n",
"fn_intensities: str = 'data/dev_datasets/HeLa_6070/protein_groups_wide_N50.csv'\n",
"df = pd.read_csv(fn_intensities, index_col=0)\n",
"df.index.name = 'Sample ID' # already set\n",
"df.columns.name = 'protein group' # not set due to csv disk file format\n",
"df = np.log2(df) # log transform\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8bbd0017",
"metadata": {},
"outputs": [],
"source": [
"freq_feat = df.notna().sum()\n",
"freq_feat.head() # training data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "99e690f9",
"metadata": {},
"outputs": [],
"source": [
"val_X, train_X = vaep.sampling.sample_data(df.stack(),\n",
" sample_index_to_drop=0,\n",
" weights=freq_feat,\n",
" frac=0.1,\n",
" random_state=42,)\n",
"val_X, train_X = val_X.unstack(), train_X.unstack()\n",
"val_X = pd.DataFrame(pd.NA, index=train_X.index,\n",
" columns=train_X.columns).fillna(val_X)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "873f3668",
"metadata": {},
"outputs": [],
"source": [
"val_X.shape, train_X.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "62cd0721",
"metadata": {},
"outputs": [],
"source": [
"train_X.notna().sum().sum(), val_X.notna().sum().sum(),"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "26a12a3e",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"model = AETransformer(\n",
" model='VAE',\n",
" # model='DAE',\n",
" hidden_layers=[512,],\n",
" latent_dim=50,\n",
" out_folder='runs/scikit_interface',\n",
" batch_size=10,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4d3c7922",
"metadata": {},
"outputs": [],
"source": [
"model.fit(train_X, val_X,\n",
" epochs_max=5,\n",
" cuda=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "24ca6c2c",
"metadata": {},
"outputs": [],
"source": [
"df_imputed = model.transform(train_X)\n",
"df_imputed"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f89799e8",
"metadata": {},
"outputs": [],
"source": [
"# overwrite predicted values with the held-out val_X values\n",
"df_imputed.update(val_X)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a235f133",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
"main_language": "python",
"notebook_metadata_filter": "-all"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
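The wide-to-long round trip the notebook relies on (named row/column indices, `stack` to long format, `unstack` back to wide) can be sketched with plain pandas on toy data; the sample and protein-group identifiers below are illustrative, not from the HeLa dataset:

```python
import numpy as np
import pandas as pd

# Illustrative wide-format intensity table with one missing value
wide = pd.DataFrame(
    {"P1": [4.0, np.nan], "P2": [8.0, 16.0]},
    index=["S1", "S2"],
)
wide.index.name = "Sample ID"       # naming both indices makes the
wide.columns.name = "protein group" # long format self-describing

# stack() drops missing values, yielding the long format CF models expect
long = np.log2(wide).stack().to_frame("intensity")

# unstack() restores the wide format autoencoders expect (NaN reappears)
restored = long["intensity"].unstack()
assert restored.isna().sum().sum() == 1
```

This is why the long format needs no missing-value handling: absent intensities simply have no row, while the wide format must represent them explicitly as NaN.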