Skip to content

Commit

Permalink
Redo torchscript example (#7889)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored Aug 29, 2023
1 parent 6472a5c commit 655ebdb
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 62 deletions.
10 changes: 7 additions & 3 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,8 @@ Torchscript support
-------------------

Most transform classes and functionals support torchscript. For composing
transforms, use :class:`torch.nn.Sequential` instead of ``Compose``:
transforms, use :class:`torch.nn.Sequential` instead of
:class:`~torchvision.transforms.v2.Compose`:

.. code:: python
Expand All @@ -232,7 +233,7 @@ transforms, use :class:`torch.nn.Sequential` instead of ``Compose``:
scripted and eager executions due to implementation differences between v1
and v2.

If you really need torchscript support for the v2 tranforms, we recommend
If you really need torchscript support for the v2 transforms, we recommend
scripting the **functionals** from the
``torchvision.transforms.v2.functional`` namespace to avoid surprises.

Expand All @@ -242,7 +243,10 @@ are always treated as images. If you need torchscript support for other types
like bounding boxes or masks, you can rely on the :ref:`low-level kernels
<functional_transforms>`.

For any custom transformations to be used with ``torch.jit.script``, they should be derived from ``torch.nn.Module``.
For any custom transformations to be used with ``torch.jit.script``, they should
be derived from ``torch.nn.Module``.

See also: :ref:`sphx_glr_auto_examples_others_plot_scripted_tensor_transforms.py`.

V2 API reference - Recommended
------------------------------
Expand Down
109 changes: 50 additions & 59 deletions gallery/others/plot_scripted_tensor_transforms.py
Original file line number Diff line number Diff line change
@@ -1,90 +1,77 @@
"""
=========================
Tensor transforms and JIT
=========================
===================
Torchscript support
===================
.. note::
Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_scripted_tensor_transforms.ipynb>`_
or :ref:`go to the end <sphx_glr_download_auto_examples_others_plot_scripted_tensor_transforms.py>` to download the full example code.
This example illustrates various features that are now supported by the
:ref:`image transformations <transforms>` on Tensor images. In particular, we
show how image transforms can be performed on GPU, and how one can also script
them using JIT compilation.
Prior to v0.8.0, transforms in torchvision have traditionally been PIL-centric
and presented multiple limitations due to that. Now, since v0.8.0, transforms
implementations are Tensor and PIL compatible, and we can achieve the following
new features:
- transform multi-band torch tensor images (with more than 3-4 channels)
- torchscript transforms together with your model for deployment
- support for GPU acceleration
- batched transformation such as for videos
- read and decode data directly as torch tensor with torchscript support (for PNG and JPEG image formats)
.. note::
These features are only possible with **Tensor** images.
This example illustrates `torchscript
<https://pytorch.org/docs/stable/jit.html>`_ support of the torchvision
:ref:`transforms <transforms>` on Tensor images.
"""

# %%
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

import torch
import torchvision.transforms as T
from torchvision.io import read_image
import torch.nn as nn

import torchvision.transforms as v1
from torchvision.io import read_image

plt.rcParams["savefig.bbox"] = 'tight'
torch.manual_seed(1)


def show(imgs):
fix, axs = plt.subplots(ncols=len(imgs), squeeze=False)
for i, img in enumerate(imgs):
img = T.ToPILImage()(img.to('cpu'))
axs[0, i].imshow(np.asarray(img))
axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
# If you're trying to run that on collab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
import sys
sys.path += ["../transforms"]
from helpers import plot
ASSETS_PATH = Path('../assets')


# %%
# The :func:`~torchvision.io.read_image` function allows to read an image and
# directly load it as a tensor
# Most transforms support torchscript. For composing transforms, we use
# :class:`torch.nn.Sequential` instead of
# :class:`~torchvision.transforms.v2.Compose`:

dog1 = read_image(str(Path('../assets') / 'dog1.jpg'))
dog2 = read_image(str(Path('../assets') / 'dog2.jpg'))
show([dog1, dog2])

# %%
# Transforming images on GPU
# --------------------------
# Most transforms natively support tensors on top of PIL images (to visualize
# the effect of the transforms, you may refer to see
# :ref:`sphx_glr_auto_examples_transforms_plot_transforms_illustrations.py`).
# Using tensor images, we can run the transforms on GPUs if cuda is available!

import torch.nn as nn
dog1 = read_image(str(ASSETS_PATH / 'dog1.jpg'))
dog2 = read_image(str(ASSETS_PATH / 'dog2.jpg'))

transforms = torch.nn.Sequential(
T.RandomCrop(224),
T.RandomHorizontalFlip(p=0.3),
v1.RandomCrop(224),
v1.RandomHorizontalFlip(p=0.3),
)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
dog1 = dog1.to(device)
dog2 = dog2.to(device)
scripted_transforms = torch.jit.script(transforms)

plot([dog1, scripted_transforms(dog1), dog2, scripted_transforms(dog2)])

transformed_dog1 = transforms(dog1)
transformed_dog2 = transforms(dog2)
show([transformed_dog1, transformed_dog2])

# %%
# Scriptable transforms for easier deployment via torchscript
# -----------------------------------------------------------
# We now show how to combine image transformations and a model forward pass,
# while using ``torch.jit.script`` to obtain a single scripted module.
# .. warning::
#
# Above we have used transforms from the ``torchvision.transforms``
# namespace, i.e. the "v1" transforms. The v2 transforms from the
# ``torchvision.transforms.v2`` namespace are the :ref:`recommended
# <v1_or_v2>` way to use transforms in your code.
#
# The v2 transforms also support torchscript, but if you call
# ``torch.jit.script()`` on a v2 **class** transform, you'll actually end up
# with its (scripted) v1 equivalent. This may lead to slightly different
# results between the scripted and eager executions due to implementation
# differences between v1 and v2.
#
# If you really need torchscript support for the v2 transforms, **we
# recommend scripting the functionals** from the
# ``torchvision.transforms.v2.functional`` namespace to avoid surprises.
#
# Below we now show how to combine image transformations and a model forward
# pass, while using ``torch.jit.script`` to obtain a single scripted module.
#
# Let's define a ``Predictor`` module that transforms the input tensor and then
# applies an ImageNet model on it.
Expand All @@ -98,7 +85,7 @@ def __init__(self):
super().__init__()
weights = ResNet18_Weights.DEFAULT
self.resnet18 = resnet18(weights=weights, progress=False).eval()
self.transforms = weights.transforms()
self.transforms = weights.transforms(antialias=True)

def forward(self, x: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
Expand All @@ -111,6 +98,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# Now, let's define scripted and non-scripted instances of ``Predictor`` and
# apply it on multiple tensor images of the same size

device = "cuda" if torch.cuda.is_available() else "cpu"

predictor = Predictor().to(device)
scripted_predictor = torch.jit.script(predictor).to(device)

Expand Down Expand Up @@ -143,3 +132,5 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
dumped_scripted_predictor = torch.jit.load(f.name)
res_scripted_dumped = dumped_scripted_predictor(batch)
assert (res_scripted_dumped == res_scripted).all()

# %%
1 change: 1 addition & 0 deletions gallery/transforms/plot_transforms_getting_started.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@
# Re-using the transforms and definitions from above.
out_img, out_target = transforms(img, target)

# sphinx_gallery_thumbnail_number = 4
plot([(img, target["boxes"]), (out_img, out_target["boxes"])])
print(f"{out_target['this_is_ignored']}")

Expand Down

0 comments on commit 655ebdb

Please sign in to comment.