Skip to content

Commit 8d3bbf5

Browse files
fixed generator bug
1 parent 2fb4c21 commit 8d3bbf5

6 files changed

+23
-19
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -223,3 +223,5 @@ data/DCASE18-Task5-development/*
223223
data/DCASE18-Task5-evaluation/*
224224
data/dj-set/*
225225
data/dj-set/MusicDataset/*
226+
*.mp3
227+
*.wav

src/input_pipeline/music_dataset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def fill_opposite_selection(self, audio_id):
9797
def get_triplets(self, audio_id, audio_length, opposite_choices, trim: bool = True) -> np.ndarray:
9898
try:
9999
triplets = []
100-
for anchor_id in range(0, audio_length, self.sample_tile_size):
100+
for anchor_id in range(0, audio_length - self.sample_tile_size, self.sample_tile_size):
101101
a_seg = [audio_id, anchor_id]
102102
n_seg = self.get_neighbour(audio_id, anchor_sample_id=anchor_id, audio_length=audio_length)
103103
o_seg = self.get_opposite(audio_id, anchor_sample_id=anchor_id, audio_length=audio_length,

src/input_pipeline/triplet_input_pipeline.py

+15-14
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import logging
2+
import re
23
from typing import Union, Tuple
34

45
import librosa
56
import numpy as np
67
import tensorflow as tf
8+
import warnings
79

810
from src.feature_extractor.base_extractor import BaseExtractor
911
from src.input_pipeline.base_dataset import BaseDataset, DatasetType
@@ -24,6 +26,8 @@ def __init__(self,
2426
:param params: parameters of the current experiment.
2527
:param log: if the pipeline should log details about the data.
2628
"""
29+
# ignore warnings, such as the librosa warnings
30+
warnings.filterwarnings('ignore')
2731

2832
self.dataset_path = Utils.check_if_path_exists(params.dcase_dataset_path)
2933
self.dataset_name = params.dataset
@@ -32,7 +36,6 @@ def __init__(self,
3236

3337
self.num_parallel_calls = params.num_parallel_calls
3438
self.gen_count = params.gen_count
35-
self.gen_index = 0
3639

3740
self.sample_rate = params.sample_rate
3841
self.sample_size = params.sample_size
@@ -66,25 +69,25 @@ def __init__(self,
6669

6770
def reinitialise(self):
6871
self.logger.info("Reinitialising the input pipeline")
69-
self.gen_index = 0
7072
self.dataset.initialise()
7173

7274
def generate_samples(self, gen_name: str, trim: bool, return_labels: bool) -> Tuple[np.ndarray, np.ndarray,
7375
np.ndarray, np.ndarray]:
7476

7577
gen_name = gen_name.decode("utf-8")
78+
gen_index = int(re.findall('[0-9]+', gen_name)[0])
7679

77-
self.dataset.current_index = self.gen_index
78-
for index, anchor in enumerate(self.dataset):
79-
self.dataset.current_index = self.gen_index
80-
if self.log and False:
81-
self.logger.debug("{0}, index:{1}".format(gen_name, self.gen_index))
80+
self.dataset.current_index = gen_index
81+
for anchor in self.dataset:
82+
current_index = self.dataset.current_index - 1
83+
if self.log:
84+
self.logger.debug("{0}, {1}, index:{2}".format(gen_name, gen_index, current_index))
8285

8386
# fill the opposite sample buffer
84-
opposite_audios = self.dataset.fill_opposite_selection(index)
87+
opposite_audios = self.dataset.fill_opposite_selection(current_index)
8588

8689
# load audio files from anchor
87-
anchor = self.dataset.df_train.iloc[index]
90+
anchor = self.dataset.df.iloc[current_index]
8891
if self.dataset_name == "MusicDataset":
8992
anchor_audio, _ = librosa.load(anchor.file_name, self.sample_rate)
9093
anchor_audio, _ = librosa.effects.trim(anchor_audio)
@@ -96,7 +99,7 @@ def generate_samples(self, gen_name: str, trim: bool, return_labels: bool) -> Tu
9699
anchor_audio_length = int(len(anchor_audio) / self.sample_rate)
97100

98101
try:
99-
triplets = self.dataset.get_triplets(index, anchor_audio_length, trim=trim,
102+
triplets = self.dataset.get_triplets(current_index, anchor_audio_length, trim=trim,
100103
opposite_choices=opposite_audios)
101104
except ValueError as err:
102105
self.logger.debug("Error during triplet creation: {}".format(err))
@@ -122,17 +125,15 @@ def generate_samples(self, gen_name: str, trim: bool, return_labels: bool) -> Tu
122125
labels = [-1, -1, -1]
123126
labels = np.asarray(labels)
124127

125-
if self.gen_index % 1000 == 0 and self.gen_index is not 0 and self.log:
128+
if current_index % 1000 == 0 and current_index is not 0 and self.log:
126129
self.logger.debug("{0} yields sound segments {1}, a: {2}, n: {3}, o: {4}".format(gen_name,
127-
self.dataset.current_index,
130+
current_index,
128131
anchor_seg,
129132
neighbour_seg,
130133
opposite_seg))
131134

132135
yield anchor_audio_seg, neighbour_audio_seg, opposite_audio_seg, labels
133136

134-
self.gen_index += 1
135-
136137
def get_dataset(self, feature_extractor: Union[BaseExtractor, None], dataset_type: DatasetType = DatasetType.TRAIN,
137138
shuffle: bool = True, trim: bool = True, return_labels: bool = False):
138139

src/train_classifier.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def evaluate():
160160
model_name=experiment_name)
161161

162162
# set logger
163-
Utils.set_logger(log_path, params.log_level)
163+
Utils.set_logger(__name__, log_path, params.log_level)
164164
logger = logging.getLogger("Main ({})".format(params.experiment_name))
165165

166166
# set the folder for the summary writer

src/train_triplet_loss.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def main():
9595
saved_model_path=params.saved_model_path)
9696

9797
# set logger
98-
Utils.set_logger(log_path, params.log_level)
98+
Utils.set_logger(__name__, log_path, params.log_level)
9999
logger = logging.getLogger("Main ({})".format(params.experiment_name))
100100

101101
# print params

src/utils/utils.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -47,17 +47,18 @@ def check_if_path_exists(path: Union[str, pathlib.Path]):
4747
return path
4848

4949
@staticmethod
50-
def set_logger(log_path, log_level: str = "INFO"):
50+
def set_logger(logger_name, log_path, log_level: str = "INFO"):
5151
"""
5252
Sets the logger to log info in terminal and file `log_path`.
5353
5454
In general, it is useful to have a logger so that every output to the terminal is saved
5555
in a permanent file. Here we save it to `log_path/experiment.log`.
5656
57+
:param logger_name: (string) name of the default logger
5758
:param log_path: (string) where to log
5859
:param log_level: sets the log level
5960
"""
60-
logger = logging.getLogger()
61+
logger = logging.getLogger(logger_name)
6162
logger.setLevel(log_level)
6263

6364
if not logger.handlers:

0 commit comments

Comments
 (0)