Skip to content

Commit

Permalink
Replaced ConvertImageDtype by ToDtype in reference scripts (#7862)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicolas Hug <[email protected]>
  • Loading branch information
vfdev-5 and NicolasHug authored Aug 24, 2023
1 parent 4491ca2 commit 9f0afd5
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 11 deletions.
4 changes: 2 additions & 2 deletions references/classification/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(

transforms.extend(
[
T.ConvertImageDtype(torch.float),
T.ToDtype(torch.float, scale=True) if use_v2 else T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]
)
Expand Down Expand Up @@ -106,7 +106,7 @@ def __init__(
transforms.append(T.PILToTensor())

transforms += [
T.ConvertImageDtype(torch.float),
T.ToDtype(torch.float, scale=True) if use_v2 else T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]

Expand Down
4 changes: 2 additions & 2 deletions references/detection/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(
# Note: we could just convert to pure tensors even in v2.
transforms += [T.ToImage() if use_v2 else T.PILToTensor()]

transforms += [T.ConvertImageDtype(torch.float)]
transforms += [T.ToDtype(torch.float, scale=True)]

if use_v2:
transforms += [
Expand Down Expand Up @@ -103,7 +103,7 @@ def __init__(self, backend="pil", use_v2=False):
else:
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")

transforms += [T.ConvertImageDtype(torch.float)]
transforms += [T.ToDtype(torch.float, scale=True)]

if use_v2:
transforms += [T.ToPureTensor()]
Expand Down
7 changes: 5 additions & 2 deletions references/detection/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,17 @@ def forward(
return image, target


class ConvertImageDtype(nn.Module):
def __init__(self, dtype: torch.dtype) -> None:
class ToDtype(nn.Module):
def __init__(self, dtype: torch.dtype, scale: bool = False) -> None:
super().__init__()
self.dtype = dtype
self.scale = scale

def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if not self.scale:
return image.to(dtype=self.dtype), target
image = F.convert_image_dtype(image, self.dtype)
return image, target

Expand Down
4 changes: 2 additions & 2 deletions references/segmentation/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
]
else:
# No need to explicitly convert masks as they're magically int64 already
transforms += [T.ConvertImageDtype(torch.float)]
transforms += [T.ToDtype(torch.float, scale=True)]

transforms += [T.Normalize(mean=mean, std=std)]
if use_v2:
Expand Down Expand Up @@ -97,7 +97,7 @@ def __init__(
transforms += [T.ToImage() if use_v2 else T.PILToTensor()]

transforms += [
T.ConvertImageDtype(torch.float),
T.ToDtype(torch.float, scale=True),
T.Normalize(mean=mean, std=std),
]
if use_v2:
Expand Down
7 changes: 5 additions & 2 deletions references/segmentation/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,14 @@ def __call__(self, image, target):
return image, target


class ConvertImageDtype:
def __init__(self, dtype):
class ToDtype:
def __init__(self, dtype, scale=False):
self.dtype = dtype
self.scale = scale

def __call__(self, image, target):
if not self.scale:
return image.to(dtype=self.dtype), target
image = F.convert_image_dtype(image, self.dtype)
return image, target

Expand Down
2 changes: 1 addition & 1 deletion references/segmentation/v2_extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,6 @@ def _coco_detection_masks_to_voc_segmentation_mask(self, target):
def forward(self, image, target):
segmentation_mask = self._coco_detection_masks_to_voc_segmentation_mask(target)
if segmentation_mask is None:
segmentation_mask = torch.zeros(v2.functional.get_spatial_size(image), dtype=torch.uint8)
segmentation_mask = torch.zeros(v2.functional.get_size(image), dtype=torch.uint8)

return image, datapoints.Mask(segmentation_mask)

0 comments on commit 9f0afd5

Please sign in to comment.