Skip to content

Commit

Permalink
Make transforms illutration example use v2 instead of v1 (#7886)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored Aug 25, 2023
1 parent 7ebc3ad commit 47cd5ea
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 90 deletions.
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(self, src_dir):

transforms_subsection_order = [
"plot_transforms_getting_started.py",
"plot_transforms_illustrations.py",
"plot_transforms_e2e.py",
"plot_cutmix_mixup.py",
"plot_custom_transforms.py",
Expand Down
2 changes: 1 addition & 1 deletion gallery/others/plot_scripted_tensor_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def show(imgs):
# --------------------------
# Most transforms natively support tensors on top of PIL images (to visualize
# the effect of the transforms, you may refer to see
# :ref:`sphx_glr_auto_examples_others_plot_transforms.py`).
# :ref:`sphx_glr_auto_examples_transforms_plot_transforms_illustrations.py`).
# Using tensor images, we can run the transforms on GPUs if cuda is available!

import torch.nn as nn
Expand Down
8 changes: 6 additions & 2 deletions gallery/transforms/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torchvision.transforms.v2 import functional as F


def plot(imgs):
def plot(imgs, row_title=None, **imshow_kwargs):
if not isinstance(imgs[0], list):
# Make a 2d grid even if there's just 1 row
imgs = [imgs]
Expand Down Expand Up @@ -40,7 +40,11 @@ def plot(imgs):
img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65)

ax = axs[row_idx, col_idx]
ax.imshow(img.permute(1, 2, 0).numpy())
ax.imshow(img.permute(1, 2, 0).numpy(), **imshow_kwargs)
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

if row_title is not None:
for row_idx in range(num_rows):
axs[row_idx, 0].set(ylabel=row_title[row_idx])

plt.tight_layout()
Loading

0 comments on commit 47cd5ea

Please sign in to comment.