# generated_midi_distance_aimc.py
import copy
import itertools
import pickle
from itertools import combinations

import numpy as np
import seaborn as sns
import torch
from matplotlib import pyplot as plt
from scipy.spatial.distance import pdist, cdist
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

from DDPM.latent_diffusion import LatentDiffusion
from DDPM.main_latent_space import load_or_process_dataset, load_config
from DDPM.model import ConditionalUNet
from Midi_Encoder.model import EncoderDecoder


def get_models(ddpm_model_path, ae_model_path):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Load the conditional diffusion model, mapping weights to CPU when no GPU is available
    ddpm_model = ConditionalUNet(time_encoding_dim=16).to(device)
    ddpm_model.load_state_dict(torch.load(ddpm_model_path, map_location=torch.device(device)))
    ddpm_model.eval()
    # Load the MIDI autoencoder used to move between piano-roll and latent space
    autoencoder_config_path = "Midi_Encoder/config.yaml"
    midi_encoder_decoder = EncoderDecoder(autoencoder_config_path).to(device)
    midi_encoder_decoder.load_state_dict(torch.load(ae_model_path, map_location=torch.device(device)))
    midi_encoder_decoder.eval()
    return ddpm_model, midi_encoder_decoder


def compute_stats(distances):
    return np.mean(distances), np.std(distances), np.min(distances), np.max(distances)


def compute_midi_distance(batch_of_midi_files):
    distance_between_midis = []
    # Flatten the last two dimensions (128x9) of each MIDI file for distance calculation
    flattened_midis = batch_of_midi_files.reshape(batch_of_midi_files.shape[0], -1)
    # Iterate over all unique pairs of MIDI files
    for i, j in combinations(range(len(flattened_midis)), 2):
        # Euclidean distance between the flattened representations,
        # normalized by the flattened dimensionality
        distance = np.linalg.norm(flattened_midis[i] - flattened_midis[j]) / flattened_midis.shape[-1]
        distance_between_midis.append(distance)
    return distance_between_midis
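

# A minimal sanity-check sketch (hypothetical helper, never called in this script):
# scipy's pdist enumerates the same i < j pairs in the same order as
# itertools.combinations, so the loop above should agree with a vectorized
# equivalent, assuming the input batch is a NumPy array.
def _check_midi_distance_vectorized(batch_of_midi_files):
    flattened = batch_of_midi_files.reshape(batch_of_midi_files.shape[0], -1)
    vectorized = pdist(flattened, 'euclidean') / flattened.shape[-1]
    return np.allclose(vectorized, compute_midi_distance(batch_of_midi_files))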


def compute_midi_hamming_distance(batch_of_midi_files):
    distance_between_midis = []
    # Flatten the last two dimensions (128x9) of each MIDI file for distance calculation
    flattened_midis = batch_of_midi_files.reshape(batch_of_midi_files.shape[0], -1)
    # Iterate over all unique pairs of MIDI files
    for i, j in combinations(range(len(flattened_midis)), 2):
        # Hamming distance: the number of cells where the two binarized piano rolls differ
        distance = torch.sum(torch.logical_xor(flattened_midis[i], flattened_midis[j]))
        distance_between_midis.append(distance)
    return distance_between_midis


def compute_midi_distance_between(midi_pianorolls_one, midi_pianorolls_two):
    all_distances = []
    # Flatten the last two dimensions (128x9) of each MIDI file for distance calculation
    flattened_one = midi_pianorolls_one.reshape(midi_pianorolls_one.shape[0], -1)
    flattened_two = midi_pianorolls_two.reshape(midi_pianorolls_two.shape[0], -1)
    # Compute pairwise distances between every MIDI in set one and every MIDI in set two
    for midi_one in flattened_one:
        for midi_two in flattened_two:
            distance = np.linalg.norm(midi_one - midi_two) / midi_one.shape[-1]
            all_distances.append(distance)
    # Sort distances and keep only the 1% smallest ones
    all_distances.sort()
    one_percent_index = int(len(all_distances) * 0.01)
    return all_distances[:max(1, one_percent_index)]  # Ensure at least one distance is included


def compute_midi_hamming_distance_between(midi_pianorolls_one, midi_pianorolls_two):
    all_distances = []
    # Set one arrives as a torch tensor (the generated samples), so detach it to NumPy;
    # set two (the dataset) is already a NumPy array
    flattened_one = midi_pianorolls_one.reshape(midi_pianorolls_one.shape[0], -1).detach().numpy()
    flattened_two = midi_pianorolls_two.reshape(midi_pianorolls_two.shape[0], -1)
    # Compute pairwise Hamming distances between every MIDI in set one and every MIDI in set two
    for midi_one in flattened_one:
        for midi_two in flattened_two:
            distance = np.sum(np.logical_xor(midi_one, midi_two))
            all_distances.append(distance)
    # Sort distances and keep only the 1% smallest ones
    all_distances.sort()
    one_percent_index = int(len(all_distances) * 0.01)
    return all_distances[:max(1, one_percent_index)]  # Ensure at least one distance is included
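

# A vectorized sketch of the same computation (reference only, never called).
# Assumes both inputs are already binary NumPy arrays; scipy's 'hamming' metric
# returns the *fraction* of differing entries, so rescale by the flattened
# dimensionality to match the raw XOR counts used above.
def _hamming_between_vectorized(binary_one, binary_two, top_fraction=0.01):
    flat_one = binary_one.reshape(binary_one.shape[0], -1)
    flat_two = binary_two.reshape(binary_two.shape[0], -1)
    counts = cdist(flat_one, flat_two, 'hamming') * flat_one.shape[-1]
    sorted_counts = np.sort(counts.flatten())
    return sorted_counts[:max(1, int(len(sorted_counts) * top_fraction))]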


def compute_1_percent_least_distances(generated_midi_embeddings, dataset_midi_embeddings):
    # Compute the pairwise Euclidean distances between generated MIDIs and dataset MIDIs
    distances = cdist(generated_midi_embeddings, dataset_midi_embeddings, 'euclidean')
    # Flatten the distance matrix and sort all distances together
    sorted_distances = np.sort(distances.flatten())
    # Select the 1% smallest distances (e.g., 1134 values if there are 113400 distances in total)
    one_percent_size = int(len(sorted_distances) * 0.01)
    return sorted_distances[:max(one_percent_size, 1)]  # Ensure at least one distance is included
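

# Illustrative call on random data (hypothetical shapes, not part of the pipeline):
# _generated = np.random.rand(10, 128)   # 10 generated latent vectors
# _dataset = np.random.rand(100, 128)    # 100 dataset latent vectors
# compute_1_percent_least_distances(_generated, _dataset)  # keeps 10 of the 1000 distances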


def create_and_save_combined_plot(distances_dict, space):
    """
    Creates a combined KDE plot for distances from multiple models, indicating the distance measure,
    and saves the figure.

    Parameters:
    - distances_dict: A dictionary where keys are model names and values are the distance arrays.
    - space: The name of the space (e.g., 'midi', 'latent') for labeling and saving the plot.
    """
    plt.figure(figsize=(10, 6))
    # MIDI-space distances are Hamming counts; latent-space distances are Euclidean
    distance_measure = "Hamming" if "midi" in space else "Euclidean"
    for model_name, distances in distances_dict.items():
        sns.kdeplot(distances, bw_adjust=0.5, label=model_name)
    plt.title(f'{space.capitalize()} Space - Distance Distribution\n(Distance Measure: {distance_measure})')
    plt.xlabel('Distance')
    plt.ylabel('Density')
    plt.legend(title='Model')
    plt.savefig(f"AIMC results/Combined_{space}_distance_distribution.png")
    plt.show()
    plt.close()
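

# Example call (hypothetical distances; the real plots are produced at the end of this script):
# create_and_save_combined_plot({"Low Noise": [0.1, 0.2, 0.3], "High Noise": [0.2, 0.4, 0.5]}, "latent")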


diffusion = LatentDiffusion(latent_dimension=128)
prompts = ['latin triplet', '4-4 electronic', 'funky 16th', 'rock fill 8th',
           'blues shuffle', 'pop ride', 'funky blues', 'latin rock']

# Load the three model variants to compare
low_noise_ddpm, low_noise_ae = get_models("AIMC results/Base Model Results/ddpm_model/model_final.pth",
                                          "AIMC results/Base Model Results/enc_dec_model/final_model.pt")
high_noise_ddpm, high_noise_ae = get_models("AIMC results/High Noise/ddpm_model/model_final.pth",
                                            "AIMC results/High Noise/enc_dec_model/final_model.pt")
no_noise_ddpm, no_noise_ae = get_models("AIMC results/No Noise/ddpm_model/model_final.pth",
                                        "AIMC results/No Noise/enc_dec_model/final_model.pt")
models = [(low_noise_ddpm, low_noise_ae), (high_noise_ddpm, high_noise_ae), (no_noise_ddpm, no_noise_ae)]

config_path = 'DDPM/config.yaml'
config = load_config(config_path)
train_dataset = load_or_process_dataset(dataset_dir=config['dataset_dir'])
# subset_size = 100  # For example, to use only 100 samples from the dataset
# train_dataset = Subset(train_dataset, list(range(subset_size)))
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)
total_stats = {}

# Collect the full training set as binarized piano rolls: velocities are stored as
# floats in [0, 1], so rescale them back to 0-255 and treat anything above 5 as a hit
dataset_midi_space = []
for drum_beats, _ in tqdm(train_loader):
    dataset_midi_space.extend(drum_beats.numpy())
dataset_midi_space = (np.array(dataset_midi_space) * 255).astype(int)
dataset_midi_space[dataset_midi_space <= 5] = 0
dataset_midi_space[dataset_midi_space > 5] = 1

for model_name, (ddpm_model, enc_dec_model) in zip(["Low Noise", "High Noise", "No Noise"], models):
    model_stats = {
        "midi": [],
        "latent": [],
        "dataset_midi": [],
        "dataset_latent": []
    }
    # Embed the full training set in latent space once per model, so the
    # generated samples can be compared against the whole dataset below
    dataset_latent_embeddings = []
    for drum_beats, _ in tqdm(train_loader):
        dataset_batch_z = enc_dec_model.encode(drum_beats)
        dataset_latent_embeddings.extend(dataset_batch_z.detach().numpy())
    dataset_latent_embeddings = np.array(dataset_latent_embeddings)
    for prompt in prompts:
        prompt_repeated = [prompt] * 10  # Repeat the same prompt 10 times
        # Generate 10 MIDI files conditioned on the prompt
        sampled_midi = diffusion.sample_conditional(ddpm_model, n=10, text_keywords=prompt_repeated, midi_decoder=enc_dec_model)
        # Binarize the generated piano rolls with the same velocity threshold as the dataset
        sampled_midi_binary = copy.deepcopy(sampled_midi)
        sampled_midi_binary[sampled_midi_binary <= 5] = 0
        sampled_midi_binary[sampled_midi_binary > 5] = 1
        # Pairwise Hamming distances among the 10 generated MIDIs
        distances_midi = compute_midi_hamming_distance(sampled_midi_binary)
        model_stats["midi"].append(distances_midi)
        # Encode the generated MIDIs to get 10 z vectors, then take pairwise
        # Euclidean distances in the latent space
        z = enc_dec_model.encode(sampled_midi.permute(0, 2, 1) / 255).detach().numpy()
        distances_latent = pdist(z, 'euclidean')
        model_stats["latent"].append(distances_latent)
        # Distances from the full dataset, keeping only the closest 1% of pairs
        distances_from_dataset_midi = compute_midi_hamming_distance_between(sampled_midi_binary, dataset_midi_space)
        distances_from_dataset_latent = compute_1_percent_least_distances(z, dataset_latent_embeddings)
        model_stats["dataset_midi"].append(distances_from_dataset_midi)
        model_stats["dataset_latent"].append(distances_from_dataset_latent)
    # Compute summary stats for the model, printing them and appending them to a results file
    with open("AIMC results/distance_results_hamming.txt", "a") as file:
        for space in ["midi", "latent", "dataset_midi", "dataset_latent"]:
            distances = np.concatenate(model_stats[space])
            mean, std, min_dis, max_dis = compute_stats(distances)
            output_text = (f"{model_name} {space.capitalize()} Space - Mean distance: {mean:.2f}, "
                           f"Std: {std:.2f}, Min: {min_dis:.2f}, Max: {max_dis:.2f}\n")
            print(output_text.strip())
            file.write(output_text)
    # Keep the per-model stats for the combined plots below
    total_stats[model_name] = model_stats

# Regroup the collected distances by space so each plot compares all three models
stats_by_space = {
    "midi": {},
    "latent": {},
    "dataset_midi": {},
    "dataset_latent": {}
}
for _k_model, _v_model in total_stats.items():
    for _k_space, _v_space in _v_model.items():
        stats_by_space[_k_space][_k_model] = _v_space
# Flatten the per-prompt lists and convert any tensor scalars to Python numbers
for k in stats_by_space.keys():
    for _key_space, _value_space in stats_by_space[k].items():
        flattened = itertools.chain.from_iterable(_value_space)
        stats_by_space[k][_key_space] = [x if isinstance(x, (int, float)) else x.item() for x in flattened]

for space_name in stats_by_space.keys():
    create_and_save_combined_plot(stats_by_space[space_name], space_name)

# Save the total_stats in a file for later analysis
with open("AIMC results/distance_stats_hamming.pickle", "wb") as f:
    pickle.dump(total_stats, f)