-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
193 lines (138 loc) · 6.05 KB
/
utils.py
File metadata and controls
193 lines (138 loc) · 6.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors
import importlib
def get_dataset_features(dataset, device='cuda', autoencoder=None):
    """Encode every image in ``dataset`` into a flat latent-feature matrix.

    Args:
        dataset: Iterable of ``(image, label)`` pairs; labels are ignored.
        device: Torch device string that the autoencoder and each image are
            moved to before encoding.
        autoencoder: Model exposing ``encode(x)`` returning a posterior with
            a ``mode()`` method (VAE-style). Required.

    Returns:
        ``np.ndarray`` of shape ``(num_images, latent_dim)`` with one
        flattened latent mode per image.
    """
    assert autoencoder is not None, "autoencoder must be provided"
    autoencoder.eval()
    autoencoder.to(device)
    dataset_features = []
    for images, _labels in dataset:  # labels are not used for features
        images = images.to(device)
        with torch.no_grad():
            # unsqueeze(0): add a batch dimension for the single image
            posterior = autoencoder.encode(images.unsqueeze(0))
            z = posterior.mode()
            z = z.view(z.shape[0], -1)
        dataset_features.append(z.cpu().numpy())
    return np.concatenate(dataset_features, axis=0)
def get_k_nearest_neighbor(image, k=5, device='cuda', dataset=None, dataset_features=None, autoencoder=None):
    """Return the ``k`` dataset images nearest to ``image`` in latent space.

    Args:
        image: Query image tensor, either ``(c, h, w)`` or ``(b, c, h, w)``.
        k: Number of neighbours to retrieve.
        device: Device passed to ``get_dataset_features`` when features must
            be computed from ``dataset``.
        dataset: Indexable dataset of ``(image, label)`` pairs. Required for
            the final lookup of neighbour images.
        dataset_features: Precomputed ``(n, latent_dim)`` feature matrix; if
            ``None`` it is computed from ``dataset``.
        autoencoder: Model exposing ``encode(x)`` returning a posterior with
            a ``mode()`` method.

    Returns:
        The nearest-neighbour images selected from ``dataset``.
    """
    if dataset_features is None:
        assert dataset is not None, "Either dataset_features or dataset must be provided"
        dataset_features = get_dataset_features(dataset, device, autoencoder)
    knn = NearestNeighbors(n_neighbors=k, metric='cosine')
    knn.fit(dataset_features)
    # Encode the query without tracking gradients; a grad-tracking tensor
    # would make .numpy() below raise without .detach().
    with torch.no_grad():
        posterior = autoencoder.encode(image)
        z = posterior.mode()
    if z.dim() == 3:
        z = z[None, :, :, :]  # promote single latent to a batch of one
    z = z.view(z.shape[0], -1)
    z = z.detach().cpu().numpy()
    _distances, index = knn.kneighbors(z)
    # NOTE(review): fancy-indexing with the (1, k) ``index`` array assumes
    # ``dataset`` supports array indexing, and this line fails when only
    # ``dataset_features`` was supplied (dataset is None) — confirm callers
    # always pass ``dataset``.
    nearest_neighbors = dataset[index][0]
    return nearest_neighbors
def image_translation(
    image: np.ndarray,
    dims: int = 64,
    translation: int = -1,
    direction: tuple = ('down', 'up', 'left', 'right'),
    **kwargs
):
    """Centre-crop ``image`` to ``dims`` x ``dims`` and optionally produce
    translated crops shifted by ``translation`` pixels.

    Args:
        image: Array of shape ``(c, h, w)`` with ``h >= dims`` and ``w >= dims``.
        dims: Side length of the square crop.
        translation: Pixel shift for the translated crops; ``-1`` disables
            translation and only the centre crop is returned.
        direction: Which shifted crops to produce; any subset of
            ``'down'``, ``'up'``, ``'left'``, ``'right'``. (Immutable default
            avoids the shared-mutable-default pitfall.)

    Returns:
        ``(centre_crop, None)`` when ``translation == -1``, otherwise
        ``(centre_crop, [shifted crops in direction order down/up/left/right])``.
    """
    assert image.shape[-1] >= dims and image.shape[-2] >= dims, "image must be of shape (c, h, w) where w >= dims and h >= dims"
    h, w = image.shape[-2], image.shape[-1]
    h_offset = (h - dims) // 2
    w_offset = (w - dims) // 2
    # ((top, left), (bottom, right)) of the centre window
    centre_dims = ((h_offset, w_offset), (h_offset + dims, w_offset + dims))
    centre_img = image[:, centre_dims[0][0]:centre_dims[1][0], centre_dims[0][1]:centre_dims[1][1]]
    if translation == -1:
        return centre_img, None
    # NOTE(review): if ``translation`` exceeds the border margin
    # ((h - dims) // 2 or (w - dims) // 2), the slice bounds go negative and
    # silently wrap, producing wrong crops — confirm callers keep the shift
    # within the margin.
    translated_images = []
    if 'down' in direction:
        down = image[:, centre_dims[0][0] + translation:centre_dims[1][0] + translation, centre_dims[0][1]:centre_dims[1][1]]
        translated_images.append(down)
    if 'up' in direction:
        up = image[:, centre_dims[0][0] - translation:centre_dims[1][0] - translation, centre_dims[0][1]:centre_dims[1][1]]
        translated_images.append(up)
    if 'left' in direction:
        left = image[:, centre_dims[0][0]:centre_dims[1][0], centre_dims[0][1] - translation:centre_dims[1][1] - translation]
        translated_images.append(left)
    if 'right' in direction:
        right = image[:, centre_dims[0][0]:centre_dims[1][0], centre_dims[0][1] + translation:centre_dims[1][1] + translation]
        translated_images.append(right)
    return centre_img, translated_images
def image_rotation(
    image: np.ndarray,
    dims: int = 64,
    angles: tuple = (90, 180, 270),
    **kwargs
):
    """Rotate ``image`` by each angle and centre-crop each result to
    ``dims`` x ``dims``.

    Args:
        image: Array of shape ``(c, h, w)`` with ``h >= dims`` and ``w >= dims``.
        dims: Side length of the square centre crop.
        angles: Rotation angles in degrees. (Immutable default avoids the
            shared-mutable-default pitfall.)

    Returns:
        List of rotated, centre-cropped arrays of shape ``(c, dims, dims)``,
        one per angle.
    """
    from scipy.ndimage import rotate
    assert image.shape[-1] >= dims and image.shape[-2] >= dims, "image must be of shape (c, h, w) where w >= dims and h >= dims"
    h, w = image.shape[-2], image.shape[-1]
    h_offset = (h - dims) // 2
    w_offset = (w - dims) // 2
    centre_dims = ((h_offset, w_offset), (h_offset + dims, w_offset + dims))
    rotated_images = []
    for angle in angles:
        # Rotate each channel separately (scipy rotates 2-D planes);
        # reshape=False keeps the original h x w so the centre crop aligns.
        rotated_channels = []
        for c in range(image.shape[0]):
            rotated_channel = rotate(image[c], angle, reshape=False, order=1, mode='constant', cval=0)
            rotated_channels.append(rotated_channel)
        # Stack channels back into (c, h, w)
        rotated_img = np.stack(rotated_channels, axis=0)
        # Crop the centre region from the rotated image
        rotated_centre = rotated_img[:, centre_dims[0][0]:centre_dims[1][0], centre_dims[0][1]:centre_dims[1][1]]
        rotated_images.append(rotated_centre)
    return rotated_images
def butter_bandpass_filter(data, low, high, order):
    """Apply a zero-phase Butterworth band-pass filter along the last axis.

    Args:
        data: Signal array; filtering runs along ``axis=-1``.
        low: Lower cutoff as a normalized frequency (0..1, Nyquist = 1).
        high: Upper cutoff as a normalized frequency.
        order: Order of the Butterworth design.

    Returns:
        Array of the same shape as ``data`` with the filter applied
        forward and backward (``filtfilt``) for zero phase distortion.
    """
    from scipy.signal import butter, filtfilt
    numerator, denominator = butter(order, [low, high], btype='band')
    return filtfilt(numerator, denominator, data, axis=-1)
def band_pass_filter(
    image,
    low=0.01,
    high=0.25,
    order=2,
    normalize=True,
    dims=64,
    **kwargs,
):
    """Band-pass filter an image in [-1, 1] and centre-crop the result.

    Args:
        image: ``np.ndarray`` of shape ``(c, h, w)`` with values in [-1, 1].
        low: Lower normalized cutoff frequency for the Butterworth filter.
        high: Upper normalized cutoff frequency.
        order: Butterworth filter order.
        normalize: Currently unused — kept for interface compatibility;
            presumably intended to toggle min/max renormalization (see the
            commented-out option below). TODO confirm.
        dims: Side length of the square centre crop.

    Returns:
        Tuple ``([cropped_filtered], residual)`` where ``cropped_filtered``
        is ``(c, dims, dims)`` in [-1, 1] and ``residual`` is the full-frame
        ``original - filtered`` difference.
    """
    original_image = image.copy()
    # Map from [-1, 1] to [0, 1] before filtering
    image = (image + 1) / 2
    # Apply band-pass filter along the last axis
    filtered_image = butter_bandpass_filter(image, low, high, order)
    filtered_image = np.clip(filtered_image, 0, 1)
    # Option 2: Normalize to [0,1] range (uncomment if you prefer this)
    # filtered_image = (filtered_image - filtered_image.min()) / (filtered_image.max() - filtered_image.min())
    # Map back to [-1, 1]
    filtered_image = (filtered_image - 0.5) * 2
    # Compute the residual on the FULL frame before cropping: the original
    # code subtracted the (c, dims, dims) crop from the (c, h, w) original,
    # which raised a broadcast error whenever the input was larger than
    # ``dims``. For dims-sized inputs this is identical to the old result.
    residual = original_image - filtered_image
    # Centre-crop the filtered image to dims x dims
    h, w = image.shape[-2], image.shape[-1]
    h_offset = (h - dims) // 2
    w_offset = (w - dims) // 2
    centre_dims = ((h_offset, w_offset), (h_offset + dims, w_offset + dims))
    filtered_image = filtered_image[:, centre_dims[0][0]:centre_dims[1][0], centre_dims[0][1]:centre_dims[1][1]]
    return [filtered_image], residual
def visualize_image(image: np.ndarray, augmented_image: np.ndarray, save_path: str = "augmented_image.png"):
    """Save a side-by-side plot of an image and its augmented version.

    Args:
        image: Original image of shape ``(c, h, w)`` with values in [-1, 1].
        augmented_image: Augmented image of shape ``(c, h, w)`` in [-1, 1].
        save_path: Destination file for the figure.
    """
    fig, axes = plt.subplots(1, 2, figsize=(10, 10))
    # CHW -> HWC layout for matplotlib
    image = image.transpose(1, 2, 0)
    augmented_image = augmented_image.transpose(1, 2, 0)
    # Rescale [-1, 1] -> [0, 1] for display
    image = (image + 1) / 2
    augmented_image = (augmented_image + 1) / 2
    for ax, img, title in zip(axes, (image, augmented_image), ("Original Image", "Augmented Image")):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis("off")
    plt.savefig(save_path)
    # Close the figure explicitly; without this, repeated calls leak
    # matplotlib figures and memory.
    plt.close(fig)