diff --git a/clocks/notebooks/dnamfitage.ipynb b/clocks/notebooks/dnamfitage.ipynb new file mode 100644 index 0000000..6321d71 --- /dev/null +++ b/clocks/notebooks/dnamfitage.ipynb @@ -0,0 +1,737 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "2f04eee0-5928-4e74-a754-6dc2e528810c", + "metadata": {}, + "source": [ + "# DNAmFitAge" + ] + }, + { + "cell_type": "markdown", + "id": "a3f514a3-772c-4a14-afdf-5a8376851ff4", + "metadata": {}, + "source": [ + "## Index\n", + "1. [Instantiate model class](#Instantiate-model-class)\n", + "2. [Define clock metadata](#Define-clock-metadata)\n", + "3. [Download clock dependencies](#Download-clock-dependencies)\n", + "4. [Load features](#Load-features)\n", + "5. [Load weights into base model](#Load-weights-into-base-model)\n", + "6. [Load reference values](#Load-reference-values)\n", + "7. [Load preprocess and postprocess objects](#Load-preprocess-and-postprocess-objects)\n", + "8. [Check all clock parameters](#Check-all-clock-parameters)\n", + "9. [Basic test](#Basic-test)\n", + "10. [Save torch model](#Save-torch-model)\n", + "11. 
[Clear directory](#Clear-directory)" + ] + }, + { + "cell_type": "markdown", + "id": "d95fafdc-643a-40ea-a689-200bd132e90c", + "metadata": {}, + "source": [ + "Let's first import some packages:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "4adfb4de-cd79-4913-a1af-9e23e9e236c9", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import inspect\n", + "import shutil\n", + "import json\n", + "import torch\n", + "import pandas as pd\n", + "import numpy as np\n", + "import pyaging as pya" + ] + }, + { + "cell_type": "markdown", + "id": "145082e5-ced4-47ae-88c0-cb69773e3c5a", + "metadata": {}, + "source": [ + "## Instantiate model class" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8aa77372-7ed3-4da7-abc9-d30372106139", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "class DNAmFitAge(pyagingModel):\n", + " def __init__(self):\n", + " super().__init__()\n", + "\n", + " self.GaitF = None\n", + " self.GripF = None\n", + " self.GaitM = None\n", + " self.GripM = None\n", + " self.VO2Max = None\n", + "\n", + " self.features_GaitF = None\n", + " self.features_GripF = None\n", + " self.features_GaitM = None\n", + " self.features_GripM = None\n", + " self.features_VO2Max = None\n", + "\n", + " def forward(self, x):\n", + " \n", + " Female = x[:, -3]#.unsqueeze(1)\n", + " Age = x[:, -2]#.unsqueeze(1)\n", + " GrimAge = x[:, -1].unsqueeze(1)\n", + "\n", + " is_female = Female == 1\n", + " is_male = Female == 0\n", + "\n", + " x_f = x[is_female]\n", + " x_m = x[is_male]\n", + "\n", + " GaitF = self.GaitF(x_f[:, self.features_GaitF])\n", + " GripF = self.GripF(x_f[:, self.features_GripF])\n", + " VO2MaxF = self.VO2Max(x_f[:, self.features_VO2Max])\n", + " GrimAgeF = GrimAge[is_female, :]\n", + "\n", + " GaitM = self.GaitM(x_m[:, self.features_GaitM])\n", + " GripM = self.GripM(x_m[:, self.features_GripM])\n", + " VO2MaxM = self.VO2Max(x_m[:, self.features_VO2Max])\n", + 
" GrimAgeM = GrimAge[is_male, :]\n", + "\n", + " x_f = torch.concat(\n", + " [\n", + " (VO2MaxF - 46.825091)/(-0.13620215),\n", + " (GripF - 39.857718) / (-0.22074456),\n", + " (GaitF - 2.508547) / (-0.01245682),\n", + " (GrimAgeF - 7.978487) / (0.80928530)\n", + " ],\n", + " dim=1,\n", + " )\n", + "\n", + " x_m = torch.concat(\n", + " [\n", + " (VO2MaxM - 49.836389) / (-0.141862925),\n", + " (GripM - 57.514016) / (-0.253179827),\n", + " (GaitM - 2.349080) / (-0.009380061),\n", + " (GrimAgeM - 9.549733) / (0.835120557) \n", + " ],\n", + " dim=1,\n", + " )\n", + "\n", + " y_f = self.base_model_f(x_f)\n", + " y_m = self.base_model_m(x_m)\n", + "\n", + " y = torch.zeros((x.size(0), 1), dtype=x.dtype, device=x.device)\n", + " y[is_female] = y_f\n", + " y[is_male] = y_m\n", + "\n", + " return y\n", + " \n", + " def preprocess(self, x):\n", + " return x\n", + "\n", + " def postprocess(self, x):\n", + " return x\n", + "\n" + ] + } + ], + "source": [ + "def print_entire_class(cls):\n", + " source = inspect.getsource(cls)\n", + " print(source)\n", + "\n", + "print_entire_class(pya.models.DNAmFitAge)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "914a94cf-bf6c-4b9d-862a-a2787842e07e", + "metadata": {}, + "outputs": [], + "source": [ + "model = pya.models.DNAmFitAge()" + ] + }, + { + "cell_type": "markdown", + "id": "51f8615e-01fa-4aa5-b196-3ee2b35d261c", + "metadata": {}, + "source": [ + "## Define clock metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "6609d6dc-c0a0-4137-bdf5-9fb31ea85281", + "metadata": {}, + "outputs": [], + "source": [ + "model.metadata[\"clock_name\"] = 'dnamfitage'\n", + "model.metadata[\"data_type\"] = 'methylation'\n", + "model.metadata[\"species\"] = 'Homo sapiens'\n", + "model.metadata[\"year\"] = 2023\n", + "model.metadata[\"approved_by_author\"] = '⌛'\n", + "model.metadata[\"citation\"] = \"McGreevy, Kristen M., et al. 
\\\"DNAmFitAge: biological age indicator incorporating physical fitness.\\\" Aging (Albany NY) 15.10 (2023): 3904.\"\n", + "model.metadata[\"doi\"] = 'https://doi.org/10.18632/aging.204538'\n", + "model.metadata[\"notes\"] = None" + ] + }, + { + "cell_type": "markdown", + "id": "74492239-5aae-4026-9d90-6bc9c574c110", + "metadata": {}, + "source": [ + "## Download clock dependencies" + ] + }, + { + "cell_type": "markdown", + "id": "7bec474f-80ce-4884-9472-30c193327117", + "metadata": {}, + "source": [ + "#### Download GitHub repository" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "aa4a1b59-dda3-4ea8-8f34-b3c53ecbc310", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "github_url = \"https://github.com/kristenmcgreevy/DNAmFitAge.git\"\n", + "github_folder_name = github_url.split('/')[-1].split('.')[0]\n", + "os.system(f\"git clone {github_url}\")" + ] + }, + { + "cell_type": "markdown", + "id": "6bd15521-363f-4029-99ff-9f0b2ae0ed2e", + "metadata": {}, + "source": [ + "#### Download from R package" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "f1f9bbe4-cfc8-494c-b910-c96da88afb2b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Writing download.r\n" + ] + } + ], + "source": [ + "%%writefile download.r\n", + "\n", + "options(repos = c(CRAN = \"https://cloud.r-project.org/\"))\n", + "library(jsonlite)\n", + "\n", + "DNAmFitnessModels <- readRDS(\"DNAmFitAge/DNAmFitnessModelsandFitAge_Oct2022.rds\")\n", + "\n", + "AllCpGs <- DNAmFitnessModels$AllCpGs\n", + "write_json(AllCpGs, \"AllCpGs.json\")\n", + "\n", + "MaleMedians <- DNAmFitnessModels$Male_Medians_All\n", + "write.csv(MaleMedians, \"MaleMedians.csv\")\n", + "FemaleMedians <- DNAmFitnessModels$Female_Medians_All\n", + "write.csv(FemaleMedians, \"FemaleMedians.csv\")\n", + "\n", + 
"Gait_noAge_Females <- DNAmFitnessModels$Gait_noAge_Females\n", + "Gait_noAge_Males <- DNAmFitnessModels$Gait_noAge_Males\n", + "Grip_noAge_Females <- DNAmFitnessModels$Grip_noAge_Females\n", + "Grip_noAge_Males <- DNAmFitnessModels$Grip_noAge_Males\n", + "VO2maxModel <- DNAmFitnessModels$VO2maxModel\n", + "write.csv(Gait_noAge_Females, \"Gait_noAge_Females.csv\")\n", + "write.csv(Gait_noAge_Males, \"Gait_noAge_Males.csv\")\n", + "write.csv(Grip_noAge_Females, \"Grip_noAge_Females.csv\")\n", + "write.csv(Grip_noAge_Males, \"Grip_noAge_Males.csv\")\n", + "write.csv(VO2maxModel, \"VO2maxModel.csv\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "f1965587-a6ac-47ce-bd7a-bb98ca1d91b5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "os.system(\"Rscript download.r\")" + ] + }, + { + "cell_type": "markdown", + "id": "5035b180-3d1b-4432-8ebe-b9c92bd93a7f", + "metadata": {}, + "source": [ + "## Load features" + ] + }, + { + "cell_type": "markdown", + "id": "d8025ed7-0013-419b-8cb5-1a2db98f9eba", + "metadata": {}, + "source": [ + "#### From JSON file" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "a1d6ee3c-c028-43fa-97aa-2631ada774e8", + "metadata": {}, + "outputs": [], + "source": [ + "with open('AllCpGs.json', 'r') as f:\n", + " features_list = json.load(f)\n", + "model.features = features_list + ['female'] + ['age'] + ['grimage']" + ] + }, + { + "cell_type": "markdown", + "id": "ee6d8fa0-4767-4c45-9717-eb1c95e2ddc0", + "metadata": {}, + "source": [ + "## Load weights into base model" + ] + }, + { + "cell_type": "markdown", + "id": "d79e5690-e284-4de6-8460-d3545a8192af", + "metadata": {}, + "source": [ + "#### From CSV file" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "7f6187ed-fcff-4ff2-bcb1-b5bcef8190e8", + "metadata": {}, + "outputs": [], + "source": [ + "gaitf_df = 
pd.read_csv('Gait_noAge_Females.csv', index_col=0)\n", + "gaitm_df = pd.read_csv('Gait_noAge_Males.csv', index_col=0)\n", + "gripf_df = pd.read_csv('Grip_noAge_Females.csv', index_col=0)\n", + "gripm_df = pd.read_csv('Grip_noAge_Males.csv', index_col=0)\n", + "vo2max_df = pd.read_csv('VO2maxModel.csv', index_col=0)" + ] + }, + { + "cell_type": "markdown", + "id": "69901c2b-9584-4de3-a642-ddb6b43d923a", + "metadata": {}, + "source": [ + "#### Linear model" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "5fb10110-a89a-4caa-a62a-59899ebccd23", + "metadata": {}, + "outputs": [], + "source": [ + "all_features = features_list + ['Female'] + ['Age'] + ['GrimAge']\n", + "\n", + "model.GaitF = pya.models.LinearModel(input_dim=len(gaitf_df))\n", + "model.GaitF.linear.weight.data = torch.tensor(np.array(gaitf_df['estimate'][1:])).unsqueeze(0).float()\n", + "model.GaitF.linear.bias.data = torch.tensor(np.array(gaitf_df['estimate'].iloc[0])).float()\n", + "model.features_GaitF = torch.tensor([all_features.index(item) for item in np.array(gaitf_df['term'][1:]) if item in all_features]).long()\n", + "\n", + "model.GaitM = pya.models.LinearModel(input_dim=len(gaitm_df))\n", + "model.GaitM.linear.weight.data = torch.tensor(np.array(gaitm_df['estimate'][1:])).unsqueeze(0).float()\n", + "model.GaitM.linear.bias.data = torch.tensor(np.array(gaitm_df['estimate'].iloc[0])).float()\n", + "model.features_GaitM = torch.tensor([all_features.index(item) for item in np.array(gaitm_df['term'][1:]) if item in all_features]).long()\n", + "\n", + "model.GripF = pya.models.LinearModel(input_dim=len(gripf_df))\n", + "model.GripF.linear.weight.data = torch.tensor(np.array(gripf_df['estimate'][1:])).unsqueeze(0).float()\n", + "model.GripF.linear.bias.data = torch.tensor(np.array(gripf_df['estimate'].iloc[0])).float()\n", + "model.features_GripF = torch.tensor([all_features.index(item) for item in np.array(gripf_df['term'][1:]) if item in all_features]).long()\n", + "\n", + 
"model.GaitM = pya.models.LinearModel(input_dim=len(gaitm_df))\n", + "model.GaitM.linear.weight.data = torch.tensor(np.array(gaitm_df['estimate'][1:])).unsqueeze(0).float()\n", + "model.GaitM.linear.bias.data = torch.tensor(np.array(gaitm_df['estimate'].iloc[0])).float()\n", + "model.features_GaitM = torch.tensor([all_features.index(item) for item in np.array(gaitm_df['term'][1:]) if item in all_features]).long()\n", + "\n", + "model.GripM = pya.models.LinearModel(input_dim=len(gripm_df))\n", + "model.GripM.linear.weight.data = torch.tensor(np.array(gripm_df['estimate'][1:])).unsqueeze(0).float()\n", + "model.GripM.linear.bias.data = torch.tensor(np.array(gripm_df['estimate'].iloc[0])).float()\n", + "model.features_GripM = torch.tensor([all_features.index(item) for item in np.array(gripm_df['term'][1:]) if item in all_features]).long()\n", + "\n", + "model.VO2Max = pya.models.LinearModel(input_dim=len(vo2max_df))\n", + "model.VO2Max.linear.weight.data = torch.tensor(np.array(vo2max_df['estimate'][1:])).unsqueeze(0).float()\n", + "model.VO2Max.linear.bias.data = torch.tensor(np.array(vo2max_df['estimate'].iloc[0])).float()\n", + "model.features_VO2Max = torch.tensor([all_features.index(item) for item in np.array(vo2max_df['term'][1:]) if item in all_features]).long()" + ] + }, + { + "cell_type": "markdown", + "id": "ad261636-5b00-4979-bb1d-67a851f7aa19", + "metadata": {}, + "source": [ + "#### Linear model" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "d7f43b99-26f2-4622-9a76-316712058877", + "metadata": {}, + "outputs": [], + "source": [ + "base_model_m = pya.models.LinearModel(input_dim=4)\n", + "\n", + "base_model_m.linear.weight.data = torch.tensor(np.array([0.1390346, 0.1787371, 0.1593873, 0.5228411])).unsqueeze(0).float()\n", + "base_model_m.linear.bias.data = torch.tensor(np.array([0.0])).float()\n", + "\n", + "model.base_model_m = base_model_m\n", + "\n", + "base_model_f = pya.models.LinearModel(input_dim=4)\n", + "\n", + 
"base_model_f.linear.weight.data = torch.tensor(np.array([0.1044232, 0.1742083, 0.2278776, 0.4934908])).unsqueeze(0).float()\n", + "base_model_f.linear.bias.data = torch.tensor(np.array([0.0])).float()\n", + "\n", + "model.base_model_f = base_model_f" + ] + }, + { + "cell_type": "markdown", + "id": "ad8b4c1d-9d57-48b7-9a30-bcfea7b747b1", + "metadata": {}, + "source": [ + "## Load reference values" + ] + }, + { + "cell_type": "markdown", + "id": "f7fdae64-096a-4640-ade7-6a17b78a01d5", + "metadata": {}, + "source": [ + "#### From CSV file" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "e1dc004f-06b7-4e24-a937-00736e93765f", + "metadata": {}, + "outputs": [], + "source": [ + "reference_df_f = pd.read_csv('FemaleMedians.csv', index_col=0)\n", + "reference_f = reference_df_f.loc[1, model.features[:-3]]\n", + "reference_df_m = pd.read_csv('MaleMedians.csv', index_col=0)\n", + "reference_m = reference_df_m.loc[1, model.features[:-3]]\n", + "reference = (reference_f + reference_m)/2\n", + "model.reference_values = list(reference) + [1] + [65] + [65] #65yo F with 65GrimAge" + ] + }, + { + "cell_type": "markdown", + "id": "af3bcf7b-74a8-4d21-9ccb-4de0c2b0516b", + "metadata": {}, + "source": [ + "## Load preprocess and postprocess objects" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "79a1b3a2-00f1-42b1-9fcd-f919343391d7", + "metadata": {}, + "outputs": [], + "source": [ + "model.preprocess_name = None\n", + "model.preprocess_dependencies = None" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "ff4a21cb-cf41-44dc-9ed1-95cf8aa15772", + "metadata": {}, + "outputs": [], + "source": [ + "model.postprocess_name = None\n", + "model.postprocess_dependencies = None" + ] + }, + { + "cell_type": "markdown", + "id": "86e3d6b1-e67e-4f3d-bd39-0ebec5726c3c", + "metadata": {}, + "source": [ + "## Check all clock parameters" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": 
"2168355c-47d9-475d-b816-49f65e74887c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "%==================================== Model Details ====================================%\n", + "Model Attributes:\n", + "\n", + "training: True\n", + "metadata: {'approved_by_author': '⌛',\n", + " 'citation': 'McGreevy, Kristen M., et al. \"DNAmFitAge: biological age '\n", + " 'indicator incorporating physical fitness.\" Aging (Albany NY) '\n", + " '15.10 (2023): 3904.',\n", + " 'clock_name': 'dnamfitage',\n", + " 'data_type': 'methylation',\n", + " 'doi': 'https://doi.org/10.18632/aging.204538',\n", + " 'notes': None,\n", + " 'species': 'Homo sapiens',\n", + " 'version': None,\n", + " 'year': 2023}\n", + "reference_values: [0.521913219528255, 0.28125954210819004, 0.9275230085489266, 0.01455467410155745, 0.041014116925727054, 0.12647568639954998, 0.7148617994059816, 0.6786637301838809, 0.909376031310397, 0.1136806555747305, 0.45398237911395245, 0.0544492346267719, 0.7738429377348031, 0.8480746411296824, 0.7667083937960659, 0.0159858833215953, 0.7183128068669931, 0.06813828137044395, 0.939547714031041, 0.8290646522154059, 0.01727972597475225, 0.0697125677059708, 0.366626793673691, 0.588925514102081, 0.02786566159606685, 0.8252930680510391, 0.211681997462417, 0.01269953071843695, 0.7886011964686286, 0.871255311509148]... [Total elements: 630]\n", + "preprocess_name: None\n", + "preprocess_dependencies: None\n", + "postprocess_name: None\n", + "postprocess_dependencies: None\n", + "features: ['cg25137787', 'cg08911391', 'cg24685778', 'cg25551287', 'cg15322207', 'cg24604749', 'cg03890680', 'cg14601038', 'cg25587481', 'cg23922134', 'cg18561976', 'cg08175029', 'cg23202468', 'cg06181470', 'cg14422932', 'cg15128470', 'cg13587180', 'cg25440680', 'cg16995193', 'cg12864235', 'cg23715435', 'cg02805890', 'cg06381959', 'cg00779476', 'cg25489467', 'cg02376916', 'cg17741339', 'cg20219159', 'cg19629631', 'cg25577212']... 
[Total elements: 630]\n", + "base_model_features: None\n", + "base_model: None\n", + "features_GaitF: [272, 504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 281, 514, 515, 516, 517, 518, 301, 519, 305, 520, 521, 522, 523, 524, 525, 526, 314, 318, 527]... [Tensor of shape torch.Size([53])]\n", + "features_GripF: [272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301]... [Tensor of shape torch.Size([91])]\n", + "features_GaitM: [545, 546, 481, 547, 548, 483, 549, 550, 485, 551, 486, 148, 552, 553, 554, 488, 149, 555, 556, 557, 558, 559, 234, 560, 561, 562, 563, 564, 465, 565]... [Tensor of shape torch.Size([59])]\n", + "features_GripM: [361, 362, 363, 209, 364, 365, 366, 367, 211, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 225]... [Tensor of shape torch.Size([93])]\n", + "features_VO2Max: [587, 588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616]... 
[Tensor of shape torch.Size([41])]\n", + "\n", + "%==================================== Model Details ====================================%\n", + "Model Structure:\n", + "\n", + "GaitF: LinearModel(\n", + " (linear): Linear(in_features=54, out_features=1, bias=True)\n", + ")\n", + "GaitM: LinearModel(\n", + " (linear): Linear(in_features=60, out_features=1, bias=True)\n", + ")\n", + "GripF: LinearModel(\n", + " (linear): Linear(in_features=92, out_features=1, bias=True)\n", + ")\n", + "GripM: LinearModel(\n", + " (linear): Linear(in_features=94, out_features=1, bias=True)\n", + ")\n", + "VO2Max: LinearModel(\n", + " (linear): Linear(in_features=42, out_features=1, bias=True)\n", + ")\n", + "base_model_m: LinearModel(\n", + " (linear): Linear(in_features=4, out_features=1, bias=True)\n", + ")\n", + "base_model_f: LinearModel(\n", + " (linear): Linear(in_features=4, out_features=1, bias=True)\n", + ")\n", + "\n", + "%==================================== Model Details ====================================%\n", + "Model Parameters and Weights:\n", + "\n", + "GaitF.linear.weight: [-0.05867812782526016, -0.05637867748737335, -0.10371068120002747, 0.01624305173754692, -0.053210534155368805, -0.07633326947689056, -0.01514248363673687, -0.049918416887521744, -0.013779371976852417, 0.14702405035495758, 0.22061608731746674, -0.6326411366462708, -0.40431228280067444, 0.06633924692869186, -0.2228449434041977, -0.03177845478057861, -0.35903501510620117, 0.4153103232383728, -0.837234616279602, 0.056484829634428024, -0.13299566507339478, -0.058516617864370346, 0.04777200520038605, 0.13982263207435608, -0.1280703842639923, -0.03444225341081619, -0.05433110147714615, -0.4258767366409302, 0.0011224570916965604, 0.01846371404826641]... 
[Tensor of shape torch.Size([1, 53])]\n", + "GaitF.linear.bias: tensor(3.9701)\n", + "GaitM.linear.weight: [0.10235995799303055, 0.08753516525030136, 0.3124901354312897, -0.28120002150535583, -0.3208324611186981, 0.24479524791240692, 0.05682919919490814, 0.21363066136837006, -0.3853186368942261, -0.038501303642988205, -0.0023554968647658825, -0.17415688931941986, 0.05159717798233032, -0.5185700058937073, -0.04655730724334717, -0.19074112176895142, -0.21096128225326538, 0.011959427036345005, 0.1078566312789917, 0.0770212784409523, 0.18820391595363617, 0.43347951769828796, -0.13240143656730652, 0.021351546049118042, -0.12319610267877579, -0.010150707326829433, -0.007736711762845516, 0.13240836560726166, -1.1829639673233032, -0.10984379798173904]... [Tensor of shape torch.Size([1, 59])]\n", + "GaitM.linear.bias: tensor(3.1825)\n", + "GripF.linear.weight: [-2.3457672595977783, -2.94146728515625, 2.8132119178771973, -0.9427763223648071, 3.2280826568603516, -0.5796002745628357, -5.075088977813721, -1.3516250848770142, 3.539742946624756, -6.468724727630615, -3.5424692630767822, -6.332897663116455, 4.4002580642700195, 10.170988082885742, -0.5222252011299133, -2.993544101715088, -0.7089398503303528, -3.3968186378479004, 0.9145923852920532, 1.0081183910369873, -2.5558736324310303, -1.6970638036727905, 2.0081098079681396, 0.2233070731163025, -3.5272421836853027, -4.740792274475098, -2.4629898071289062, 0.7111413478851318, -11.599475860595703, 3.976231575012207]... 
[Tensor of shape torch.Size([1, 91])]\n", + "GripF.linear.bias: tensor(53.8206)\n", + "GripM.linear.weight: [-0.892844021320343, 1.6379542350769043, 1.5265462398529053, -0.3347032964229584, -1.9029316902160645, -0.2647155225276947, -6.30814266204834, 14.954833984375, 1.178484320640564, 3.5211784839630127, -0.1861504316329956, -1.6255935430526733, 4.550158977508545, -1.587499976158142, -0.449662446975708, -8.599822998046875, 25.895660400390625, 4.368823051452637, 3.992393970489502, 1.3252184391021729, -2.2360410690307617, -0.6896253228187561, 1.5932470560073853, 1.5443568229675293, -0.7052236795425415, -3.0787854194641113, -0.2242996096611023, -0.23673297464847565, 2.1442930698394775, -0.3954241871833801]... [Tensor of shape torch.Size([1, 93])]\n", + "GripM.linear.bias: tensor(43.0198)\n", + "VO2Max.linear.weight: [5.249129772186279, 3.0901758670806885, -7.551166534423828, -5.796545028686523, -1.094834804534912, -2.3806116580963135, -0.0022889631800353527, 1.0938740968704224, -1.4775551557540894, 1.4427802562713623, 1.268430471420288, 5.4764933586120605, -8.934550285339355, -1.9918478727340698, -5.6620774269104, -6.2174201011657715, -0.6082701086997986, -7.513339996337891, -1.4299590587615967, -3.6723220348358154, 14.669830322265625, 0.5884844660758972, -0.9597266912460327, -1.0253041982650757, -1.802089810371399, -4.9922356605529785, -0.6746888160705566, -10.973499298095703, -0.6614307761192322, 2.365175247192383]... 
[Tensor of shape torch.Size([1, 41])]\n", + "VO2Max.linear.bias: tensor(69.6523)\n", + "base_model_m.linear.weight: tensor([[0.1390, 0.1787, 0.1594, 0.5228]])\n", + "base_model_m.linear.bias: tensor([0.])\n", + "base_model_f.linear.weight: tensor([[0.1044, 0.1742, 0.2279, 0.4935]])\n", + "base_model_f.linear.bias: tensor([0.])\n", + "\n", + "%==================================== Model Details ====================================%\n", + "\n" + ] + } + ], + "source": [ + "pya.utils.print_model_details(model)" + ] + }, + { + "cell_type": "markdown", + "id": "986d0262-e0c7-4036-b687-dee53ba392fb", + "metadata": {}, + "source": [ + "## Basic test" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "352cffb0-c5a8-4c82-8f61-fce35baf5a22", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[0.],\n", + " [0.],\n", + " [0.],\n", + " [0.],\n", + " [0.],\n", + " [0.],\n", + " [0.],\n", + " [0.],\n", + " [0.],\n", + " [0.]], dtype=torch.float64, grad_fn=)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.manual_seed(42)\n", + "input = torch.randn(10, len(model.features), dtype=float)\n", + "model.eval()\n", + "model.to(float)\n", + "pred = model(input)\n", + "pred" + ] + }, + { + "cell_type": "markdown", + "id": "fe8299d7-9285-4e22-82fd-b664434b4369", + "metadata": {}, + "source": [ + "## Save torch model" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "0c3a2d80-1b5f-458a-926c-cbc0aa9416e1", + "metadata": {}, + "outputs": [], + "source": [ + "torch.save(model, f\"../weights/{model.metadata['clock_name']}.pt\")" + ] + }, + { + "cell_type": "markdown", + "id": "bac6257b-8d08-4a90-8d0b-7f745dc11ac1", + "metadata": {}, + "source": [ + "## Clear directory\n", + "" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "11aeaa70-44c0-42f9-86d7-740e3849a7a6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": 
"stream", + "text": [ + "Deleted file: Grip_noAge_Females.csv\n", + "Deleted file: Grip_noAge_Males.csv\n", + "Deleted file: Gait_noAge_Females.csv\n", + "Deleted file: VO2maxModel.csv\n", + "Deleted file: AllCpGs.json\n", + "Deleted file: Gait_noAge_Males.csv\n", + "Deleted folder: DNAmFitAge\n", + "Deleted file: download.r\n", + "Deleted file: FemaleMedians.csv\n", + "Deleted file: MaleMedians.csv\n" + ] + } + ], + "source": [ + "# Function to remove a folder and all its contents\n", + "def remove_folder(path):\n", + " try:\n", + " shutil.rmtree(path)\n", + " print(f\"Deleted folder: {path}\")\n", + " except Exception as e:\n", + " print(f\"Error deleting folder {path}: {e}\")\n", + "\n", + "# Get a list of all files and folders in the current directory\n", + "all_items = os.listdir('.')\n", + "\n", + "# Loop through the items\n", + "for item in all_items:\n", + " # Check if it's a file and does not end with .ipynb\n", + " if os.path.isfile(item) and not item.endswith('.ipynb'):\n", + " os.remove(item)\n", + " print(f\"Deleted file: {item}\")\n", + " # Check if it's a folder\n", + " elif os.path.isdir(item):\n", + " remove_folder(item)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.17" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/clock_implementation.rst b/docs/source/clock_implementation.rst index bee4f49..9771e1c 100644 --- a/docs/source/clock_implementation.rst +++ b/docs/source/clock_implementation.rst @@ -16,6 +16,7 @@ The following collection of Jupyter Notebooks provides a comprehensive guide to clock_notebooks/camilloh3k9ac clock_notebooks/camilloh3k9me3 clock_notebooks/camillopanhistone + 
class DNAmFitAge(pyagingModel):
    """Sex-stratified DNAmFitAge clock.

    The model combines three methylation-based fitness estimates (gait
    speed, grip strength and VO2max) with a GrimAge value supplied as an
    input column.  Gait and grip use separate female/male linear
    sub-models; VO2max uses one sub-model shared by both sexes.  Each of
    the four quantities is rescaled with hard-coded, sex-specific
    ``(value - offset) / scale`` constants and the results are combined by
    a sex-specific final linear model.

    All sub-models and feature-index tensors are initialised to ``None``
    and must be assigned before ``forward`` is called.
    """

    def __init__(self):
        super().__init__()

        # Linear sub-models producing the fitness estimates.
        self.GaitF = None
        self.GripF = None
        self.GaitM = None
        self.GripM = None
        self.VO2Max = None

        # Column indices (into the input matrix) consumed by each sub-model.
        self.features_GaitF = None
        self.features_GripF = None
        self.features_GaitM = None
        self.features_GripM = None
        self.features_VO2Max = None

        # Final sex-specific models applied to the four rescaled inputs.
        # Declared here, like the other sub-models, so that the attributes
        # `forward` depends on are all visible in one place.
        self.base_model_f = None
        self.base_model_m = None

    def forward(self, x):
        """Compute DNAmFitAge for every row of ``x``.

        Args:
            x: 2-D tensor of samples by features.  The last three columns
                are read positionally as the sex flag (1 = female,
                0 = male), age, and GrimAge.  The age column is not used
                by the computation.

        Returns:
            Tensor of shape ``(x.size(0), 1)`` with the same dtype and
            device as ``x``.  Rows whose sex flag is neither exactly 0 nor
            exactly 1 are not scored and keep the initial value 0.
        """
        Female = x[:, -3]
        GrimAge = x[:, -1].unsqueeze(1)

        # Boolean row masks; a row with any other flag value matches neither.
        is_female = Female == 1
        is_male = Female == 0

        x_f = x[is_female]
        x_m = x[is_male]

        GaitF = self.GaitF(x_f[:, self.features_GaitF])
        GripF = self.GripF(x_f[:, self.features_GripF])
        VO2MaxF = self.VO2Max(x_f[:, self.features_VO2Max])
        GrimAgeF = GrimAge[is_female, :]

        GaitM = self.GaitM(x_m[:, self.features_GaitM])
        GripM = self.GripM(x_m[:, self.features_GripM])
        VO2MaxM = self.VO2Max(x_m[:, self.features_VO2Max])
        GrimAgeM = GrimAge[is_male, :]

        # Rescale each quantity with its sex-specific constants.  Column
        # order (VO2max, grip, gait, GrimAge) must match the weight order
        # of the final models.
        # NOTE(review): presumably these constants map each biomarker onto
        # an age scale -- confirm against the reference implementation.
        x_f = torch.concat(
            [
                (VO2MaxF - 46.825091) / (-0.13620215),
                (GripF - 39.857718) / (-0.22074456),
                (GaitF - 2.508547) / (-0.01245682),
                (GrimAgeF - 7.978487) / (0.80928530),
            ],
            dim=1,
        )

        x_m = torch.concat(
            [
                (VO2MaxM - 49.836389) / (-0.141862925),
                (GripM - 57.514016) / (-0.253179827),
                (GaitM - 2.349080) / (-0.009380061),
                (GrimAgeM - 9.549733) / (0.835120557),
            ],
            dim=1,
        )

        y_f = self.base_model_f(x_f)
        y_m = self.base_model_m(x_m)

        # Scatter the per-sex predictions back into the original row order.
        y = torch.zeros((x.size(0), 1), dtype=x.dtype, device=x.device)
        y[is_female] = y_f
        y[is_male] = y_m

        return y

    def preprocess(self, x):
        """Return ``x`` unchanged; this clock has no preprocessing step."""
        return x

    def postprocess(self, x):
        """Return ``x`` unchanged; this clock has no postprocessing step."""
        return x
b/pyproject.toml index 3367448..a5e7071 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pyaging" -version = "v0.1.2" +version = "v0.1.3" description = "A Python-based compendium of GPU-optimized aging clocks." authors = ["Lucas Paulo de Lima Camillo "] license = "BSD" diff --git a/tests/predict/test_gold_standard.py b/tests/predict/test_gold_standard.py index 9c59830..df5acc8 100644 --- a/tests/predict/test_gold_standard.py +++ b/tests/predict/test_gold_standard.py @@ -57,10 +57,13 @@ 'thompson': 164.57995856164365, 'zhangblup': 78.76779185124363, 'zhangen': 37.404900683228966, - 'zhangmortality': 2.8135717975793475 + 'zhangmortality': 2.8135717975793475, + 'dnamfitage': 91.03008383895092, + 'yingcausage': 195.3013578758023, + 'yingadaptage': 173.48314231920278, + 'yingdamage': -53.509282005508, } - def test_all_clocks(): all_clocks = list(gold_standard_dict.keys())