fix: update pytorch-lightning examples #430

Draft: wants to merge 1 commit into master
@@ -18,16 +18,16 @@
"\n",
"# Transfer Learning Using PyTorch Lightning ⚡️\n",
"\n",
"In this colab, we will extend the pipeline [here](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/pytorch-lightning/Image_Classification_using_PyTorch_Lightning.ipynb) to perform transfer learning with PyTorch Lightning. \n",
"In this colab, we will extend the pipeline [here](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/pytorch-lightning/Image_Classification_using_PyTorch_Lightning.ipynb) to perform transfer learning with PyTorch Lightning.\n",
"\n",
"Transfer Learning is a technique where the knowledge learned while training a model for \"task\" A and can be used for \"task\" B. Here A and B can be the same deep learning tasks but on a different dataset. \n"
"Transfer Learning is a technique where the knowledge learned while training a model for \"task\" A and can be used for \"task\" B. Here A and B can be the same deep learning tasks but on a different dataset.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setting up PyTorch Lightning and W&B \n",
"## Setting up PyTorch Lightning and W&B\n",
"\n",
"For this tutorial, we need PyTorch Lightning and Weights and Biases."
]
@@ -38,7 +38,7 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install wandb pytorch-lightning -qqq"
"!pip install -U wandb pytorch-lightning -qqq"
]
},
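{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before building the full pipeline, here is a minimal, illustrative sketch of the transfer-learning idea from the introduction. This cell is optional and not part of the training pipeline; the 102-class head simply matches the Flowers102 dataset used below. We load a backbone pre-trained on \"task\" A (ImageNet), freeze its parameters, and attach a fresh output layer for \"task\" B."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal, optional sketch of transfer learning: freeze a pre-trained backbone\n",
"# and attach a new classification head. The full LitModel below does the same\n",
"# thing inside a LightningModule.\n",
"import torch.nn as nn\n",
"import torchvision.models as models\n",
"\n",
"backbone = models.resnet18(pretrained=True)  # knowledge learned on \"task\" A (ImageNet)\n",
"for param in backbone.parameters():\n",
"    param.requires_grad = False  # freeze the pre-trained features\n",
"backbone.fc = nn.Linear(backbone.fc.in_features, 102)  # new head for \"task\" B (Flowers102)"
]
},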
{
@@ -68,7 +68,7 @@
"from torchmetrics import Accuracy\n",
"\n",
"from torchvision import transforms\n",
"from torchvision.datasets import StanfordCars\n",
"from torchvision.datasets import Flowers102\n",
"from torchvision.datasets.utils import download_url\n",
"import torchvision.models as models\n",
"\n",
@@ -98,7 +98,9 @@
"source": [
"## The Dataset 💿\n",
"\n",
"We will be using the StanfordCars dataset to train our image classifier. It contains 16,185 images of 196 classes of cars. The data is split into 8,144 training images and 8,041 testing images, where each class has been split roughly in a 50-50 split. Classes are typically at the level of Make, Model, Year, e.g. 2012 Tesla Model S or 2012 BMW M3 coupe."
"We will be using the Oxford 102 Flower dataset to train our image classifier. The dataset consists of 102 flower categories. The flowers were chosen to be flowers commonly occurring in the United Kingdom. Each class consists of between 40 and 258 images.\n",
"\n",
"The images have large scale, pose and light variations. In addition, there are categories that have large variations within the category, and several very similar categories."
]
},
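{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional sanity check (not required for training), we can download the dataset once and look at the split sizes before wrapping it in a `LightningDataModule`. The expected sizes in the comments assume the standard Flowers102 splits."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: download Flowers102 and inspect its splits.\n",
"train_peek = Flowers102(root='./', split='train', download=True)\n",
"test_peek = Flowers102(root='./', split='test', download=True)\n",
"print(len(train_peek), len(test_peek))  # expected: 1020 training / 6149 test images\n",
"image, label = train_peek[0]\n",
"print(image.size, label)  # PIL image size and integer class label"
]
},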
{
@@ -107,7 +109,7 @@
"metadata": {},
"outputs": [],
"source": [
"class StanfordCarsDataModule(pl.LightningDataModule):\n",
"class Flowers102DataModule(pl.LightningDataModule):\n",
" def __init__(self, batch_size, data_dir: str = './'):\n",
" super().__init__()\n",
" self.data_dir = data_dir\n",
@@ -129,26 +131,26 @@
" transforms.ToTensor(),\n",
" transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])\n",
" ])\n",
" \n",
"\n",
" self.num_classes = 196\n",
"\n",
" def prepare_data(self):\n",
" pass\n",
"\n",
" def setup(self, stage=None):\n",
" # build dataset\n",
" dataset = StanfordCars(root=self.data_dir, download=True, split=\"train\")\n",
" dataset = Flowers102(root=self.data_dir, download=True, split=\"train\")\n",
" # split dataset\n",
" self.train, self.val = random_split(dataset, [6500, 1644])\n",
" self.train, self.val = random_split(dataset, [816, 204])\n",
"\n",
" self.test = Flowers102(root=self.data_dir, download=True, split=\"test\")\n",
"\n",
" self.test = StanfordCars(root=self.data_dir, download=True, split=\"test\")\n",
" \n",
" self.test = random_split(self.test, [len(self.test)])[0]\n",
"\n",
" self.train.dataset.transform = self.augmentation\n",
" self.val.dataset.transform = self.transform\n",
" self.test.dataset.transform = self.transform\n",
" \n",
"\n",
" def train_dataloader(self):\n",
" return DataLoader(self.train, batch_size=self.batch_size, shuffle=True, num_workers=2)\n",
"\n",
@@ -166,9 +168,9 @@
"## LightingModule - Define the System\n",
"\n",
"Let us look at the model definition to see how transfer learning can be used with PyTorch Lightning.\n",
"In the `LitModel` class, we can use the pre-trained model provided by Torchvision as a feature extractor for our classification model. Here we are using ResNet-18. A list of pre-trained models provided by PyTorch Lightning can be found here.\n",
"In the `LitModel` class, we can use the pre-trained model provided by Torchvision as a feature extractor for our classification model. Here we are using ResNet-18. A list of pre-trained models provided by PyTorch Lightning can be found [here](https://pytorch.org/vision/stable/models.html#classification).\n",
"- When `pretrained=True`, we use the pre-trained weights; otherwise, the weights are initialized randomly.\n",
"- If `.eval()` is used, then the layers are frozen. \n",
"- If `.eval()` is used, then the layers are frozen.\n",
"- A single `Linear` layer is used as the output layer. We can have multiple layers stacked over the `feature_extractor`.\n",
"\n",
"Setting the `transfer` argument to `True` will enable transfer learning."
@@ -183,13 +185,13 @@
"class LitModel(pl.LightningModule):\n",
" def __init__(self, input_shape, num_classes, learning_rate=2e-4, transfer=False):\n",
" super().__init__()\n",
" \n",
"\n",
" # log hyperparameters\n",
" self.save_hyperparameters()\n",
" self.learning_rate = learning_rate\n",
" self.dim = input_shape\n",
" self.num_classes = num_classes\n",
" \n",
"\n",
" # transfer learning if pretrained=True\n",
" self.feature_extractor = models.resnet18(pretrained=transfer)\n",
"\n",
@@ -199,36 +201,39 @@
" # freeze params\n",
" for param in self.feature_extractor.parameters():\n",
" param.requires_grad = False\n",
" \n",
"\n",
" n_sizes = self._get_conv_output(input_shape)\n",
"\n",
" self.dropout = nn.Dropout(0.5)\n",
" self.classifier = nn.Linear(n_sizes, num_classes)\n",
"\n",
" self.criterion = nn.CrossEntropyLoss()\n",
" self.accuracy = Accuracy()\n",
" \n",
" self.accuracy = Accuracy(task=\"multiclass\", num_classes=num_classes)\n",
" self.test_step_outputs = []\n",
"\n",
" # returns the size of the output tensor going into the Linear layer from the conv block.\n",
" def _get_conv_output(self, shape):\n",
" batch_size = 1\n",
" tmp_input = torch.autograd.Variable(torch.rand(batch_size, *shape))\n",
"\n",
" output_feat = self._forward_features(tmp_input) \n",
" output_feat = self._forward_features(tmp_input)\n",
" n_size = output_feat.data.view(batch_size, -1).size(1)\n",
" return n_size\n",
" \n",
"\n",
" # returns the feature tensor from the conv block\n",
" def _forward_features(self, x):\n",
" x = self.feature_extractor(x)\n",
" return x\n",
" \n",
"\n",
" # will be used during inference\n",
" def forward(self, x):\n",
" x = self._forward_features(x)\n",
" x = x.view(x.size(0), -1)\n",
" x = self.dropout(x)\n",
" x = self.classifier(x)\n",
" \n",
"\n",
" return x\n",
" \n",
"\n",
" def training_step(self, batch):\n",
" batch, gt = batch[0], batch[1]\n",
" out = self.forward(batch)\n",
@@ -240,7 +245,7 @@
" self.log(\"train/acc\", acc)\n",
"\n",
" return loss\n",
" \n",
"\n",
" def validation_step(self, batch, batch_idx):\n",
" batch, gt = batch[0], batch[1]\n",
" out = self.forward(batch)\n",
@@ -252,27 +257,29 @@
" self.log(\"val/acc\", acc)\n",
"\n",
" return loss\n",
" \n",
"\n",
" def test_step(self, batch, batch_idx):\n",
" batch, gt = batch[0], batch[1]\n",
" out = self.forward(batch)\n",
" loss = self.criterion(out, gt)\n",
" \n",
" return {\"loss\": loss, \"outputs\": out, \"gt\": gt}\n",
" \n",
" def test_epoch_end(self, outputs):\n",
" self.test_step_outputs.append({\"loss\": loss, \"outputs\": out, \"gt\": gt})\n",
" return out\n",
"\n",
" def on_test_epoch_end(self):\n",
" outputs = self.test_step_outputs\n",
" loss = torch.stack([x['loss'] for x in outputs]).mean()\n",
" output = torch.cat([x['outputs'] for x in outputs], dim=0)\n",
" \n",
"\n",
" gts = torch.cat([x['gt'] for x in outputs], dim=0)\n",
" \n",
"\n",
" self.log(\"test/loss\", loss)\n",
" acc = self.accuracy(output, gts)\n",
" self.log(\"test/acc\", acc)\n",
" \n",
"\n",
" self.test_gts = gts\n",
" self.test_output = output\n",
" \n",
" self.test_step_outputs.clear()\n",
"\n",
" def configure_optimizers(self):\n",
" return torch.optim.Adam(self.parameters(), lr=self.learning_rate)"
]
@@ -288,7 +295,16 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"To train the model, we instantiate the `StanfordCarsDataModule` and the `LitModel` along with the PyTorch Lightning Trainer. To the `Trainer`, we will pass the `WandbLogger` as the logger to use W&B to track the metrics during model training!"
"To train the model, we instantiate the `StanfordCarsDataModule` and the `LitModel` along with the PyTorch Lightning Trainer. To the `Trainer`, we will pass the `WandbLogger` as the logger to use W&B to track the metrics during model training!."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Model logging checkpointing\n",
"\n",
"We will use the `ModelCheckpoint` callback to save the best model. We can easily save our model alongside our wandb run as a model artifact by passing `log_model=True` to our `WandbLogger`. This will not only facilitate easy reproducibility of our models from the experiments but also allow us to reuse the model from [Weights & Biases Models](https://docs.wandb.ai/guides/models) during inference."
]
},
{
@@ -297,9 +313,17 @@
"metadata": {},
"outputs": [],
"source": [
"dm = StanfordCarsDataModule(batch_size=32)\n",
"dm = Flowers102DataModule(batch_size=32)\n",
"model = LitModel((3, 300, 300), 196, transfer=True)\n",
"trainer = pl.Trainer(logger=WandbLogger(project=\"TransferLearning\"), max_epochs=10, accelerator=\"gpu\")"
"checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor=\"val/acc\", mode=\"max\", save_top_k=3)\n",
"wandb_logger = WandbLogger(project=\"TransferLearning\", log_model=\"all\")\n",
"trainer = pl.Trainer(\n",
" max_epochs=25,\n",
" accelerator=\"gpu\",\n",
" log_every_n_steps=26,\n",
" callbacks=[checkpoint_callback],\n",
" logger=wandb_logger\n",
" )"
]
},
{
Expand All @@ -322,7 +346,35 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that the model is trained, let's see how it performs on the test set"
"## Model Evaluation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load the best model using Weights & Biases Artifacts\n",
"\n",
"Although we ran our training for 15 epochs, we want to test against the best model. We can use the model artifact stored in wandb using the `wandb_logger` to retrieve the checkpoints for the best model and reload the model from it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"artifact_path = f'{wandb.run.entity}/{wandb.run.project}/model-{wandb.run.id}:best'\n",
"artifact = wandb_logger.use_artifact(artifact_path, artifact_type='model',)\n",
"artifact_dir = artifact.download()\n",
"model = LitModel.load_from_checkpoint(f\"{artifact_dir}/model.ckpt\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have our best model, let's see how it performs on the test set by evaluating the model on the test set."
]
},
{
@@ -347,14 +399,15 @@
"metadata": {},
"outputs": [],
"source": [
"display(wandb.run)\n",
"wandb.finish()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The workspace generated to compare training the model from scratch vs using transfer learning is available [here](https://wandb.ai/manan-goel/StanfordCars). The conclusions that can be drawn from this are explained in detail in [this report](https://wandb.ai/wandb/wandb-lightning/reports/Transfer-Learning-Using-PyTorch-Lightning--VmlldzoyMzMxMzk4/edit)."
"The workspace generated to compare training the model from scratch vs using transfer learning is available [here](https://wandb.ai/parambharat/TransferLearning)"
]
},
{
@@ -363,7 +416,7 @@
"source": [
"## Conclusion\n",
"\n",
"I will encourage you to play with the code and train an image classifier with a dataset of your choice from scratch and using transfer learning. \n"
"We encourage you to play with the code and train an image classifier with a dataset of your choice from scratch and using transfer learning.\n"
]
},
{
@@ -385,9 +438,7 @@
"metadata": {
"accelerator": "GPU",
"colab": {
"include_colab_link": true,
"provenance": [],
"toc_visible": true
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",