Datapoint -> TVTensor; datapoint[s] -> tv_tensor[s] #7894

Merged: 3 commits, Aug 30, 2023
Changes from 1 commit
4 changes: 2 additions & 2 deletions docs/source/conf.py
@@ -88,8 +88,8 @@ def __init__(self, src_dir):
"plot_transforms_e2e.py",
"plot_cutmix_mixup.py",
"plot_custom_transforms.py",
"plot_datapoints.py",
"plot_custom_datapoints.py",
"plot_tv_tensors.py",
"plot_custom_tv_tensors.py",
]

def __call__(self, filename):
2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -32,7 +32,7 @@ architectures, and common image transformations for computer vision.
:caption: Package Reference

transforms
datapoints
tv_tensors
models
datasets
utils
8 changes: 4 additions & 4 deletions docs/source/transforms.rst
@@ -30,12 +30,12 @@ tasks (image classification, detection, segmentation, video classification).
.. code:: python

# Detection (re-using imports and transforms from above)
from torchvision import datapoints
from torchvision import tv_tensors

img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
bboxes = torch.randint(0, H // 2, size=(3, 4))
bboxes[:, 2:] += bboxes[:, :2]
bboxes = datapoints.BoundingBoxes(bboxes, format="XYXY", canvas_size=(H, W))
bboxes = tv_tensors.BoundingBoxes(bboxes, format="XYXY", canvas_size=(H, W))

# The same transforms can be used!
img, bboxes = transforms(img, bboxes)
@@ -183,8 +183,8 @@ Transforms are available as classes like
This is very much like the :mod:`torch.nn` package which defines both classes
and functional equivalents in :mod:`torch.nn.functional`.

The functionals support PIL images, pure tensors, or :ref:`datapoints
<datapoints>`, e.g. both ``resize(image_tensor)`` and ``resize(bboxes)`` are
The functionals support PIL images, pure tensors, or :ref:`tv_tensors
<tv_tensors>`, e.g. both ``resize(image_tensor)`` and ``resize(bboxes)`` are
valid.

.. note::
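
To illustrate the dispatch behaviour described in this hunk, here is a minimal sketch (not part of the diff; it assumes the torchvision.tv_tensors namespace introduced by this PR):

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms.v2 import functional as F

    H, W = 256, 256
    image_tensor = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
    bboxes = tv_tensors.BoundingBoxes(
        torch.tensor([[10, 10, 50, 50]]), format="XYXY", canvas_size=(H, W)
    )

    # Both calls are valid: the functional dispatches on the class of its input.
    out_image = F.resize(image_tensor, size=(128, 128))  # pure tensor in, pure tensor out
    out_bboxes = F.resize(bboxes, size=(128, 128))       # coordinates rescaled, canvas_size updated
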
12 changes: 6 additions & 6 deletions docs/source/datapoints.rst → docs/source/tv_tensors.rst
@@ -1,13 +1,13 @@
.. _datapoints:
.. _tv_tensors:

Datapoints
TVTensors
==========

.. currentmodule:: torchvision.datapoints
.. currentmodule:: torchvision.tv_tensors

Datapoints are tensor subclasses which the :mod:`~torchvision.transforms.v2` v2 transforms use under the hood to
TVTensors are tensor subclasses which the :mod:`~torchvision.transforms.v2` v2 transforms use under the hood to
dispatch their inputs to the appropriate lower-level kernels. Most users do not
need to manipulate datapoints directly and can simply rely on dataset wrapping -
need to manipulate tv_tensors directly and can simply rely on dataset wrapping -
see e.g. :ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`.
Member Author: Occurrences like this "tv_tensor" one should probably become "TVTensor". But this can be done in a follow-up PR that cleans the docs.


.. autosummary::
@@ -19,6 +19,6 @@ see e.g. :ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`.
BoundingBoxFormat
BoundingBoxes
Mask
Datapoint
TVTensor
set_return_type
wrap
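
As a quick illustration of what this page documents, TVTensors being plain tensor subclasses plus the wrap helper listed above, here is a small sketch (illustrative only, assuming the post-rename torchvision.tv_tensors API):

    import torch
    from torchvision import tv_tensors

    # A TVTensor such as Image is just a torch.Tensor subclass carrying extra semantics.
    img = tv_tensors.Image(torch.rand(3, 32, 32))
    print(isinstance(img, torch.Tensor), type(img).__name__)   # True Image

    # Plain tensor operations return plain tensors by default ("unwrapping");
    # tv_tensors.wrap re-attaches the TVTensor class and metadata when needed.
    flipped = img.flip(-1)
    print(type(flipped).__name__)                               # Tensor
    print(type(tv_tensors.wrap(flipped, like=img)).__name__)    # Image
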
4 changes: 2 additions & 2 deletions gallery/others/plot_video_api.py
@@ -86,7 +86,7 @@
print("PTS for first five frames ", ptss[:5])
print("Total number of frames: ", len(frames))
approx_nf = metadata['audio']['duration'][0] * metadata['audio']['framerate'][0]
print("Approx total number of datapoints we can expect: ", approx_nf)
print("Approx total number of tv_tensors we can expect: ", approx_nf)
Member Author: will revert

print("Read data size: ", frames[0].size(0) * len(frames))

# %%
@@ -170,7 +170,7 @@ def example_read_video(video_object, start=0, end=None, read_video=True, read_au
return video_frames, audio_frames, (video_pts, audio_pts), video_object.get_metadata()


# Total number of frames should be 327 for video and 523264 datapoints for audio
# Total number of frames should be 327 for video and 523264 tv_tensors for audio
vf, af, info, meta = example_read_video(video)
print(vf.size(), af.size())

4 changes: 2 additions & 2 deletions gallery/transforms/helpers.py
@@ -1,7 +1,7 @@
import matplotlib.pyplot as plt
import torch
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F


@@ -22,7 +22,7 @@ def plot(imgs, row_title=None, **imshow_kwargs):
if isinstance(target, dict):
boxes = target.get("boxes")
masks = target.get("masks")
elif isinstance(target, datapoints.BoundingBoxes):
elif isinstance(target, tv_tensors.BoundingBoxes):
boxes = target
else:
raise ValueError(f"Unexpected target type: {type(target)}")
10 changes: 5 additions & 5 deletions gallery/transforms/plot_custom_transforms.py
@@ -13,7 +13,7 @@

# %%
import torch
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms import v2


@@ -62,7 +62,7 @@ def forward(self, img, bboxes, label): # we assume inputs are always structured

H, W = 256, 256
img = torch.rand(3, H, W)
bboxes = datapoints.BoundingBoxes(
bboxes = tv_tensors.BoundingBoxes(
torch.tensor([[0, 10, 10, 20], [50, 50, 70, 70]]),
format="XYXY",
canvas_size=(H, W)
@@ -74,9 +74,9 @@ def forward(self, img, bboxes, label): # we assume inputs are always structured
print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
# %%
# .. note::
# While working with datapoint classes in your code, make sure to
# While working with tv_tensor classes in your code, make sure to
# familiarize yourself with this section:
# :ref:`datapoint_unwrapping_behaviour`
# :ref:`tv_tensor_unwrapping_behaviour`
#
# Supporting arbitrary input structures
# =====================================
@@ -111,7 +111,7 @@ def forward(self, img, bboxes, label): # we assume inputs are always structured
# In brief, the core logic is to unpack the input into a flat list using `pytree
# <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>`_, and
# then transform only the entries that can be transformed (the decision is made
# based on the **class** of the entries, as all datapoints are
# based on the **class** of the entries, as all tv_tensors are
# tensor-subclasses) plus some custom logic that is out of scope here - check the
# code for details. The (potentially transformed) entries are then repacked and
# returned, in the same structure as the input.
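
A rough sketch of that flatten, transform, repack flow (illustrative only; torch.utils._pytree is a private module whose names may change, and tv_tensors assumes this PR's rename):

    import torch
    import torch.utils._pytree as pytree
    from torchvision import tv_tensors

    sample = {"image": tv_tensors.Image(torch.rand(3, 8, 8)), "label": 3, "meta": ["anything", 42]}

    flat, spec = pytree.tree_flatten(sample)    # arbitrary structure -> flat list of leaves
    flat = [
        # transform only the entries we know how to handle, based on their class
        tv_tensors.wrap(leaf.flip(-1), like=leaf) if isinstance(leaf, tv_tensors.TVTensor) else leaf
        for leaf in flat
    ]
    out = pytree.tree_unflatten(flat, spec)     # repack into the original structure
    print(type(out["image"]).__name__, out["label"], out["meta"])
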
gallery/transforms/plot_custom_datapoints.py → gallery/transforms/plot_custom_tv_tensors.py
@@ -1,62 +1,62 @@
"""
=====================================
How to write your own Datapoint class
How to write your own TVTensor class
=====================================

.. note::
Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_custom_datapoints.ipynb>`_
or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_custom_datapoints.py>` to download the full example code.
Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_custom_tv_tensors.ipynb>`_
or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_custom_tv_tensors.py>` to download the full example code.

This guide is intended for advanced users and downstream library maintainers. We explain how to
write your own datapoint class, and how to make it compatible with the built-in
write your own tv_tensor class, and how to make it compatible with the built-in
Torchvision v2 transforms. Before continuing, make sure you have read
:ref:`sphx_glr_auto_examples_transforms_plot_datapoints.py`.
:ref:`sphx_glr_auto_examples_transforms_plot_tv_tensors.py`.
"""

# %%
import torch
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms import v2

# %%
# We will create a very simple class that just inherits from the base
# :class:`~torchvision.datapoints.Datapoint` class. It will be enough to cover
# :class:`~torchvision.tv_tensors.TVTensor` class. It will be enough to cover
# what you need to know to implement your more elaborate use cases. If you need
# to create a class that carries meta-data, take a look at how the
# :class:`~torchvision.datapoints.BoundingBoxes` class is `implemented
# <https://github.com/pytorch/vision/blob/main/torchvision/datapoints/_bounding_box.py>`_.
# :class:`~torchvision.tv_tensors.BoundingBoxes` class is `implemented
# <https://github.com/pytorch/vision/blob/main/torchvision/tv_tensors/_bounding_box.py>`_.


class MyDatapoint(datapoints.Datapoint):
class MyTVTensor(tv_tensors.TVTensor):
pass


my_dp = MyDatapoint([1, 2, 3])
my_dp = MyTVTensor([1, 2, 3])
my_dp

# %%
# Now that we have defined our custom Datapoint class, we want it to be
# Now that we have defined our custom TVTensor class, we want it to be
# compatible with the built-in torchvision transforms, and the functional API.
# For that, we need to implement a kernel which performs the core of the
# transformation, and then "hook" it to the functional that we want to support
# via :func:`~torchvision.transforms.v2.functional.register_kernel`.
#
# We illustrate this process below: we create a kernel for the "horizontal flip"
# operation of our MyDatapoint class, and register it to the functional API.
# operation of our MyTVTensor class, and register it to the functional API.

from torchvision.transforms.v2 import functional as F


@F.register_kernel(functional="hflip", datapoint_cls=MyDatapoint)
def hflip_my_datapoint(my_dp, *args, **kwargs):
@F.register_kernel(functional="hflip", tv_tensor_cls=MyTVTensor)
def hflip_my_tv_tensor(my_dp, *args, **kwargs):
print("Flipping!")
out = my_dp.flip(-1)
return datapoints.wrap(out, like=my_dp)
return tv_tensors.wrap(out, like=my_dp)


# %%
# To understand why :func:`~torchvision.datapoints.wrap` is used, see
# :ref:`datapoint_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now,
# To understand why :func:`~torchvision.tv_tensors.wrap` is used, see
# :ref:`tv_tensor_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now,
# we will explain it below in :ref:`param_forwarding`.
#
# .. note::
@@ -67,9 +67,9 @@ def hflip_my_datapoint(my_dp, *args, **kwargs):
# ``@register_kernel(functional=F.hflip, ...)``.
#
# Now that we have registered our kernel, we can call the functional API on a
# ``MyDatapoint`` instance:
# ``MyTVTensor`` instance:

my_dp = MyDatapoint(torch.rand(3, 256, 256))
my_dp = MyTVTensor(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)

# %%
@@ -102,10 +102,10 @@ def hflip_my_datapoint(my_dp, *args, **kwargs):
# to its :func:`~torchvision.transforms.v2.functional.hflip` functional. If you
# already defined and registered your own kernel as

def hflip_my_datapoint(my_dp): # noqa
def hflip_my_tv_tensor(my_dp): # noqa
print("Flipping!")
out = my_dp.flip(-1)
return datapoints.wrap(out, like=my_dp)
return tv_tensors.wrap(out, like=my_dp)


# %%
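
A hedged follow-up to the hunk above (not part of the diff): because, as the tutorial notes, a transform like RandomHorizontalFlip calls the hflip functional internally, the class-based v2 transforms also dispatch to the registered kernel. MyTVTensor and hflip_my_tv_tensor are the illustrative names defined in this example file, not part of torchvision:

    from torchvision.transforms import v2

    t = v2.RandomHorizontalFlip(p=1)
    _ = t(MyTVTensor(torch.rand(3, 16, 16)))  # prints "Flipping!", i.e. our registered kernel ran
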
6 changes: 3 additions & 3 deletions gallery/transforms/plot_transforms_e2e.py
@@ -23,7 +23,7 @@
import torch
import torch.utils.data

from torchvision import models, datasets, datapoints
from torchvision import models, datasets, tv_tensors
from torchvision.transforms import v2

torch.manual_seed(0)
Expand Down Expand Up @@ -72,7 +72,7 @@
# %%
# We used the ``target_keys`` parameter to specify the kind of output we're
# interested in. Our dataset now returns a target which is a dict where the values
# are :ref:`Datapoints <what_are_datapoints>` (all are :class:`torch.Tensor`
# are :ref:`TVTensors <what_are_tv_tensors>` (all are :class:`torch.Tensor`
# subclasses). We dropped all unnecessary keys from the previous output, but
# if you need any of the original keys e.g. "image_id", you can still ask for
# it.
@@ -103,7 +103,7 @@
[
v2.ToImage(),
v2.RandomPhotometricDistort(p=1),
v2.RandomZoomOut(fill={datapoints.Image: (123, 117, 104), "others": 0}),
v2.RandomZoomOut(fill={tv_tensors.Image: (123, 117, 104), "others": 0}),
v2.RandomIoUCrop(),
v2.RandomHorizontalFlip(p=1),
v2.SanitizeBoundingBoxes(),
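
For readers of this hunk, a small sketch of the per-type fill argument shown above (illustrative only; assumes the renamed tv_tensors namespace): tv_tensor classes can be used as dict keys so that images get one padding value and every other input type another.

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms import v2

    img = tv_tensors.Image(torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8))
    boxes = tv_tensors.BoundingBoxes(
        torch.tensor([[8, 8, 32, 32]]), format="XYXY", canvas_size=(64, 64)
    )

    # Images are padded with the dataset mean colour, everything else with 0.
    zoom_out = v2.RandomZoomOut(fill={tv_tensors.Image: (123, 117, 104), "others": 0}, p=1)
    out_img, out_boxes = zoom_out(img, boxes)
    print(out_img.shape, out_boxes.canvas_size)  # both enlarged consistently by the zoom-out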