forked from chendaichao/Hilbert-Huang-transform
-
Notifications
You must be signed in to change notification settings - Fork 1
/
visualization.py
170 lines (148 loc) · 7.02 KB
/
visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import torch
import matplotlib.pyplot as plt
def plot_IMFs(x, imfs, fs, index = None, align_yticks = True, time_scale = 1, time_range = None, save_fig = None, title = None):
'''
Visualize the IMFs extracted from the original signal by empirical model decomposition.
Parameters:
------------
x (Tensor) :
Signal data. The last dimension of `x` is considered as time.
imfs (Tensor) :
IMFs obtained from `emd`.
fs (real) :
Sampling frequencies in Hz.
index (list of integers, Optional) :
The indices of the IMF component to be shown.
align_yticks (bool, optional) :
Whether to align the y-axises of all IMFs.
Default: True.
time_scale (int, optional) :
The scale for the time axis for ploting.
time_range (an (L, R)-tuple, Optional) :
The range of time axis to be shown.
save_fig (list of string or string, Optional) :
Save the image as `save_fig` if specified.
title (str, optional) :
Specifying the title of the figure.
Returns: None
'''
assert(type(time_scale) is int)
assert( (save_fig is None) or (type(save_fig) in [list, str]) )
all_x = torch.as_tensor(x).view(-1, x.shape[-1])
all_imfs = torch.as_tensor(imfs).view(-1, imfs.shape[-2], imfs.shape[-1])
for batch_idx in range(all_x.shape[0]):
imfs, x = all_imfs[batch_idx], all_x[batch_idx]
# scaling the time axis
if (time_scale > 1):
N = x.shape[0]
fs /= time_scale
M = N // time_scale
x = x[:M*time_scale].view(M, time_scale).mean(dim = 1)
imfs = imfs[:, :M*time_scale].view(imfs.shape[0], M, time_scale).mean(dim = 2)
# the IMF indexes to be shown
num_imfs = len(imfs)
if (index is None):
index = list(range(num_imfs))
else:
num_imfs = len(index)
# the time range of interest
t = torch.arange(x.shape[0], dtype = torch.double) / fs
if (time_range):
L, R = time_range
L, R = min(int(L * fs), len(t)-1), min(int(R * fs)+1, len(t))
t = t[L:R]
# initialize pyplot
plt.figure(figsize=(10, num_imfs+2))
plt.subplots_adjust(hspace = 0.3, left = 0.3)
# plot signals
ax = plt.subplot(num_imfs+2, 1, 1)
scale = max(torch.abs(imfs[:, L:R] if time_range else imfs).max(), (x[L:R] if time_range else x).max())
scale = scale.cpu()
ax.plot(t, x[L:R].cpu() if time_range else x.cpu(), linewidth = 1)
ax.set_title("Empirical Mode Decomposition" if title is None else title)
ax.set_ylabel("signal")
ax.set_ylim(-scale, scale)
ax.set_xticks([])
# plot IMFs
for i_ in range(len(index)):
i = index[i_]
imf = imfs[i]
if not align_yticks:
scale = torch.abs(imf[L:R] if time_range else imf).max().cpu()
ax = plt.subplot(num_imfs+2, 1, i_+2)
ax.plot(t, imf[L:R].cpu() if time_range else imf.cpu(), linewidth = 1)
ax.set_ylim(-scale, scale)
ax.set_ylabel("IMF %d" % (i))
ax.set_xticks([])
# plot the residual
x = x - imfs.sum(axis = 0)
if not align_yticks:
scale = torch.abs(x[L:R] if time_range else x).max().cpu()
ax = plt.subplot(num_imfs+2, 1, num_imfs+2)
ax.plot(t, x[L:R].cpu() if time_range else x.cpu(), linewidth = 1)
ax.set_ylabel("residual")
ax.set_ylim(-scale, scale)
ax.set_xlabel("time")
if (save_fig is not None):
plt.savefig(save_fig if type(save_fig) is str else save_fig[batch_idx])
plt.show()
def plot_HilbertSpectrum(spectrum, time_axis, freq_axis, energy_scale = "log", save_spectrum = None,
save_marginal = None, title = None):
'''
Visualize the Hilbert spectrum, by plotting all the IMFs on a time-frequency plane.
Parameters:
------------
spectrum : Tensor, of shape ( ..., # time_bins, # freq_bins ).
A pytorch tensor, representing the Hilbert spectrum H(t, f).
The tensor will be on the same device as `imfs_env` and `imfs_freq`.
time_axis : Tensor, 1D, of shape ( # time_bins )
The label for the time axis of the spectrum.
freq_axis : Tensor, 1D, of shape ( # freq_bins )
The label for the frequency axis (in `freq_unit`) of the spectrum.
energy_scale : string ('linear' or 'log')
Specifying whether to visualize the energy in linear/log scale.
save_spectrum : string or list of string, Optional.
If specified, the Hilbert spectrum will be saved as a file named `save_spectrum`.
save_maeginal : string or list of string, Optional.
If specified, the Hilbert marginal spectrum will be saved as a file named `save_marginal`.
title : string, optional.
Specifying the title of the figure.
'''
spectrum = spectrum.view(-1, spectrum.shape[-2], spectrum.shape[-1])
time_axis = time_axis.cpu()
freq_axis = freq_axis.cpu()
for batch_idx in range(spectrum.shape[0]):
if (energy_scale == "log"):
eps = spectrum[batch_idx, :, :].max() * (1e-5)
coutour = plt.pcolormesh(time_axis.cpu(), freq_axis.cpu(),
10*torch.log10(spectrum[batch_idx, :, :].T + eps).cpu(),
shading='auto',
cmap = plt.cm.jet)
plt.colorbar(coutour, label = "energy (dB)")
else:
coutour = plt.pcolormesh(time_axis.cpu(), freq_axis.cpu(),
spectrum[batch_idx, :, :].T.cpu(),
shading='auto',
cmap = plt.cm.jet)
plt.colorbar(coutour, label = "energy")
plt.xlabel("time")
plt.ylabel("frequency (Hz)")
plt.title("Hilbert spectrum" if title is None else title)
if (save_spectrum is not None):
plt.savefig(save_spectrum if type(save_spectrum) is str else save_spectrum[batch_idx], dpi = 600)
plt.show()
if (save_marginal is None) :
return
marginal = spectrum.sum(dim = -2)
for batch_idx in range(spectrum.shape[0]):
if (energy_scale == "log"):
eps = marginal[batch_idx, :].max() * (1e-5)
plt.plot(freq_axis.cpu(), 10*torch.log10(marginal[batch_idx, :].cpu() + eps) )
plt.ylabel("energy (dB)")
else:
plt.plot(freq_axis.cpu(), marginal[batch_idx, :].cpu())
plt.ylabel("energy")
plt.xlabel("frequency (Hz)")
plt.title("Marginal spectrum" if title is None else title + " (marginal)")
plt.savefig(save_marginal if type(save_marginal) is str else save_marginal[batch_idx], dpi = 600)
plt.show()