forked from NaJaeMin92/pytorch-DANN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtransformations.py
109 lines (85 loc) · 4.01 KB
/
transformations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
class RandomSingleColorReplaceBlack:
    """Replace near-black pixels with one random color.

    With probability ``p``, draw a single color with ``torch.rand(3)`` and
    write it into every pixel whose channels are all ``<= 0.01``; every other
    pixel is left as is.

    Expects a ``(C, H, W)`` tensor -- presumably float in ``[0, 1]`` as
    produced by ``ToTensor`` (the 0.01 threshold suggests this; confirm with
    callers). A single-channel input is always promoted to 3 channels, as in
    ``RandomSingleColorReplaceAll``, so the output channel count does not
    depend on the random draw. The caller's tensor is never modified; a new
    tensor is returned.

    Args:
        p: Probability of applying the replacement (``1.0`` = always,
            ``0.0`` = never).
    """

    def __init__(self, p=1.0):
        self.probability = p

    def __call__(self, img_tensor):
        """Return a recolored copy of ``img_tensor`` (see class docstring)."""
        # repeat() already allocates; otherwise clone so that a cached dataset
        # sample passed in by the caller is not overwritten.
        if img_tensor.size(0) == 1:
            img_tensor = img_tensor.repeat(3, 1, 1)
        else:
            img_tensor = img_tensor.clone()
        # Apply iff u < p with u ~ U[0, 1): exactly probability p, so p=0 never fires.
        if torch.rand(1).item() >= self.probability:
            return img_tensor
        random_color = torch.rand(3, device=img_tensor.device)
        black_pixels_mask = torch.all(img_tensor <= 0.01, dim=0)
        # Per-channel scalar assignment goes through masked_fill, which
        # tolerates a dtype mismatch between the color and the image.
        for c in range(3):
            img_tensor[c][black_pixels_mask] = random_color[c]
        return img_tensor
class RandomSingleColorReplaceNonBlack:
    """Replaces non-black pixels with a single random color.

    With probability ``p``, draw one color with ``torch.rand(3)`` and write it
    into every pixel that has at least one channel ``> 0.01``; near-black
    pixels are left as is.

    Expects a ``(C, H, W)`` tensor -- presumably float in ``[0, 1]`` as
    produced by ``ToTensor`` (confirm with callers). A single-channel input is
    always promoted to 3 channels, as in ``RandomSingleColorReplaceAll``, so
    the output channel count does not depend on the random draw. The caller's
    tensor is never modified; a new tensor is returned.

    Args:
        p: Probability of applying the replacement (``1.0`` = always,
            ``0.0`` = never).
    """

    def __init__(self, p=1.0):
        self.probability = p

    def __call__(self, img_tensor):
        """Return a recolored copy of ``img_tensor`` (see class docstring)."""
        # Work on a private copy; repeat() already allocates for grayscale input.
        if img_tensor.size(0) == 1:
            img_tensor = img_tensor.repeat(3, 1, 1)
        else:
            img_tensor = img_tensor.clone()
        # Apply iff u < p with u ~ U[0, 1): exactly probability p, so p=0 never fires.
        if torch.rand(1).item() >= self.probability:
            return img_tensor
        random_color = torch.rand(3, device=img_tensor.device)
        non_black_pixels_mask = torch.any(img_tensor > 0.01, dim=0)
        # Scalar-per-channel writes keep dtype handling identical to before.
        for c in range(3):
            img_tensor[c][non_black_pixels_mask] = random_color[c]
        return img_tensor
class RandomSingleColorReplaceAll:
    """Replaces all pixels with a single random color, ensuring output is always 3 channels.

    With probability ``p``, draw two colors with ``torch.rand(3)``:
    - the first is written into every pixel whose channels are all ``<= 0.01``;
    - the second is written into every other pixel.

    The result is a two-tone image.

    A single-channel input is always promoted to 3 channels (even when the
    replacement is skipped). Presumably the input is a float ``(C, H, W)``
    tensor in ``[0, 1]``, e.g. ``[1, 28, 28]`` grayscale digits -- confirm
    with callers. The caller's tensor is never modified; a new tensor is
    returned.

    Args:
        p: Probability of applying the replacement (``1.0`` = always,
            ``0.0`` = never).
    """

    def __init__(self, p=1.0):
        self.probability = p

    def __call__(self, img_tensor):
        """Return a 3-channel, possibly recolored copy of ``img_tensor``."""
        # repeat() already yields a fresh tensor; clone otherwise so a 3-channel
        # input is not modified in place (previously only 1-channel was safe).
        if img_tensor.size(0) == 1:
            img_tensor = img_tensor.repeat(3, 1, 1)
        else:
            img_tensor = img_tensor.clone()
        # Apply iff u < p with u ~ U[0, 1): exactly probability p, so p=0 never fires.
        if torch.rand(1).item() >= self.probability:
            return img_tensor
        # Draw order (black color, then non-black color) matches the original,
        # so seeded runs reproduce the same colors.
        black_replacement_color = torch.rand(3, device=img_tensor.device)
        non_black_replacement_color = torch.rand(3, device=img_tensor.device)
        black_pixels_mask = torch.all(img_tensor <= 0.01, dim=0)
        # Slice assignment uses copy_, which casts to img_tensor's dtype.
        img_tensor[:3] = torch.where(
            black_pixels_mask,
            black_replacement_color.view(3, 1, 1),
            non_black_replacement_color.view(3, 1, 1),
        )
        return img_tensor
class RandomColorsReplaceBlack:
    """Replaces each black pixel with a unique random color.

    With probability ``p``, every pixel whose channels are all ``<= 0.01``
    receives independent ``U[0, 1)`` values in each channel; other pixels are
    kept. Works for any channel count, since the noise is drawn with
    ``torch.rand_like`` over the whole tensor.

    When applied, a new tensor is returned and the caller's tensor is left
    untouched. When skipped, the input object itself is returned.

    Args:
        p: Probability of applying the replacement (``1.0`` = always,
            ``0.0`` = never).
    """

    def __init__(self, p=1.0):
        self.probability = p

    def __call__(self, img_tensor):
        """Return ``img_tensor`` with its black pixels replaced by noise."""
        # Apply iff u < p with u ~ U[0, 1): exactly probability p, so p=0 never fires.
        if torch.rand(1).item() >= self.probability:
            return img_tensor
        black_pixels_mask = torch.all(img_tensor <= 0.01, dim=0)
        random_colors = torch.rand_like(img_tensor)
        # torch.where allocates the result, so the input is never written to;
        # the (H, W) mask broadcasts across channels.
        return torch.where(black_pixels_mask, random_colors, img_tensor)
class RandomColorsReplaceNonBlack:
    """Replaces each non-black pixel with a unique random color.

    With probability ``p``, every pixel with at least one channel ``> 0.01``
    receives independent ``U[0, 1)`` values in each channel; near-black pixels
    are kept.

    When applied, a new tensor is returned and the caller's tensor is left
    untouched. When skipped, the input object itself is returned.

    Args:
        p: Probability of applying the replacement (``1.0`` = always,
            ``0.0`` = never).
    """

    def __init__(self, p=1.0):
        self.probability = p

    def __call__(self, img_tensor):
        """Return ``img_tensor`` with its non-black pixels replaced by noise."""
        # Apply iff u < p with u ~ U[0, 1): exactly probability p, so p=0 never fires.
        if torch.rand(1).item() >= self.probability:
            return img_tensor
        non_black_pixels_mask = torch.any(img_tensor > 0.01, dim=0)
        random_colors = torch.rand_like(img_tensor)
        # torch.where allocates the result, so the input is never written to.
        return torch.where(non_black_pixels_mask, random_colors, img_tensor)

    def _normalize_tensor(self, img_tensor):
        """Convert a uint8 tensor to float scaled into [0, 1]; return others unchanged.

        NOTE(review): not called by ``__call__``; kept for callers that may use it.
        """
        if img_tensor.dtype == torch.uint8:
            img_tensor = img_tensor.float() / 255.0
        return img_tensor
class Identity:
    """No-op transform: returns its input unchanged.

    Accepts ``p`` so it can be constructed with the same signature as the
    recoloring transforms in this module; the value is stored but has no
    effect.
    """

    def __init__(self, p=1.0):
        # Kept for signature compatibility with the other transforms; unused.
        self.probability = p

    def __call__(self, img_tensor):
        """Return ``img_tensor`` itself (same object, not a copy)."""
        # No random draw: the former coin flip chose between two identical
        # returns and only advanced the global torch RNG.
        return img_tensor

    def _normalize_tensor(self, img_tensor):
        """Convert a uint8 tensor to float scaled into [0, 1]; return others unchanged.

        NOTE(review): not called by ``__call__``; kept for callers that may use it.
        """
        if img_tensor.dtype == torch.uint8:
            img_tensor = img_tensor.float() / 255.0
        return img_tensor