Skip to content

Commit

Permalink
Merge branch 'main' into hmdb51-output-format
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Aug 24, 2023
2 parents fc07a1c + b82d883 commit af781f9
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 7 deletions.
7 changes: 3 additions & 4 deletions gallery/v2_transforms/plot_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,9 @@
format="XYXY", canvas_size=img.shape[-2:])

transforms = v2.Compose([
v2.RandomPhotometricDistort(),
v2.RandomIoUCrop(),
v2.RandomHorizontalFlip(p=0.5),
v2.SanitizeBoundingBoxes(),
v2.RandomResizedCrop(size=(224, 224), antialias=True),
v2.RandomPhotometricDistort(p=1),
v2.RandomHorizontalFlip(p=1),
])
out_img, out_bboxes = transforms(img, bboxes)

Expand Down
14 changes: 14 additions & 0 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,6 +1256,20 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
assert out_labels.tolist() == valid_indices


def test_sanitize_bounding_boxes_no_label():
# Non-regression test for https://github.com/pytorch/vision/issues/7878

img = make_image()
boxes = make_bounding_boxes()

with pytest.raises(ValueError, match="or a two-tuple whose second item is a dict"):
transforms.SanitizeBoundingBoxes()(img, boxes)

out_img, out_boxes = transforms.SanitizeBoundingBoxes(labels_getter=None)(img, boxes)
assert isinstance(out_img, datapoints.Image)
assert isinstance(out_boxes, datapoints.BoundingBoxes)


def test_sanitize_bounding_boxes_errors():

good_bbox = datapoints.BoundingBoxes(
Expand Down
4 changes: 2 additions & 2 deletions torchvision/transforms/v2/_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ class MixUp(_BaseMixUpCutMix):
alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
num_classes (int): number of classes in the batch. Used for one-hot-encoding.
labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
By default, this will pick the second parameter a the labels if it's a tensor. This covers the most
By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
It can also be a callable that takes the same input as the transform, and returns the labels.
"""
Expand Down Expand Up @@ -279,7 +279,7 @@ class CutMix(_BaseMixUpCutMix):
alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
num_classes (int): number of classes in the batch. Used for one-hot-encoding.
labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
By default, this will pick the second parameter a the labels if it's a tensor. This covers the most
By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.
It can also be a callable that takes the same input as the transform, and returns the labels.
"""
Expand Down
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor:
inputs = inputs[1]

# MixUp, CutMix
if isinstance(inputs, torch.Tensor):
if is_pure_tensor(inputs):
return inputs

if not isinstance(inputs, collections.abc.Mapping):
Expand Down

0 comments on commit af781f9

Please sign in to comment.