1
1
import logging
2
+ import re
2
3
from typing import Union , Tuple
3
4
4
5
import librosa
5
6
import numpy as np
6
7
import tensorflow as tf
8
+ import warnings
7
9
8
10
from src .feature_extractor .base_extractor import BaseExtractor
9
11
from src .input_pipeline .base_dataset import BaseDataset , DatasetType
@@ -24,6 +26,8 @@ def __init__(self,
24
26
:param params: parameters of the current experiment.
25
27
:param log: if the pipeline should log details about the data.
26
28
"""
29
+ # ignore warnings, such as the librosa warnings
30
+ warnings .filterwarnings ('ignore' )
27
31
28
32
self .dataset_path = Utils .check_if_path_exists (params .dcase_dataset_path )
29
33
self .dataset_name = params .dataset
@@ -32,7 +36,6 @@ def __init__(self,
32
36
33
37
self .num_parallel_calls = params .num_parallel_calls
34
38
self .gen_count = params .gen_count
35
- self .gen_index = 0
36
39
37
40
self .sample_rate = params .sample_rate
38
41
self .sample_size = params .sample_size
@@ -66,25 +69,25 @@ def __init__(self,
66
69
67
70
def reinitialise (self ):
68
71
self .logger .info ("Reinitialising the input pipeline" )
69
- self .gen_index = 0
70
72
self .dataset .initialise ()
71
73
72
74
def generate_samples (self , gen_name : str , trim : bool , return_labels : bool ) -> Tuple [np .ndarray , np .ndarray ,
73
75
np .ndarray , np .ndarray ]:
74
76
75
77
gen_name = gen_name .decode ("utf-8" )
78
+ gen_index = int (re .findall ('[0-9]+' , gen_name )[0 ])
76
79
77
- self .dataset .current_index = self . gen_index
78
- for index , anchor in enumerate ( self .dataset ) :
79
- self .dataset .current_index = self . gen_index
80
- if self .log and False :
81
- self .logger .debug ("{0}, index:{1 }" .format (gen_name , self . gen_index ))
80
+ self .dataset .current_index = gen_index
81
+ for anchor in self .dataset :
82
+ current_index = self .dataset .current_index - 1
83
+ if self .log :
84
+ self .logger .debug ("{0}, {1}, index:{2 }" .format (gen_name , gen_index , current_index ))
82
85
83
86
# fill the opposite sample buffer
84
- opposite_audios = self .dataset .fill_opposite_selection (index )
87
+ opposite_audios = self .dataset .fill_opposite_selection (current_index )
85
88
86
89
# load audio files from anchor
87
- anchor = self .dataset .df_train .iloc [index ]
90
+ anchor = self .dataset .df .iloc [current_index ]
88
91
if self .dataset_name == "MusicDataset" :
89
92
anchor_audio , _ = librosa .load (anchor .file_name , self .sample_rate )
90
93
anchor_audio , _ = librosa .effects .trim (anchor_audio )
@@ -96,7 +99,7 @@ def generate_samples(self, gen_name: str, trim: bool, return_labels: bool) -> Tu
96
99
anchor_audio_length = int (len (anchor_audio ) / self .sample_rate )
97
100
98
101
try :
99
- triplets = self .dataset .get_triplets (index , anchor_audio_length , trim = trim ,
102
+ triplets = self .dataset .get_triplets (current_index , anchor_audio_length , trim = trim ,
100
103
opposite_choices = opposite_audios )
101
104
except ValueError as err :
102
105
self .logger .debug ("Error during triplet creation: {}" .format (err ))
@@ -122,17 +125,15 @@ def generate_samples(self, gen_name: str, trim: bool, return_labels: bool) -> Tu
122
125
labels = [- 1 , - 1 , - 1 ]
123
126
labels = np .asarray (labels )
124
127
125
- if self . gen_index % 1000 == 0 and self . gen_index is not 0 and self .log :
128
+ if current_index % 1000 == 0 and current_index is not 0 and self .log :
126
129
self .logger .debug ("{0} yields sound segments {1}, a: {2}, n: {3}, o: {4}" .format (gen_name ,
127
- self . dataset . current_index ,
130
+ current_index ,
128
131
anchor_seg ,
129
132
neighbour_seg ,
130
133
opposite_seg ))
131
134
132
135
yield anchor_audio_seg , neighbour_audio_seg , opposite_audio_seg , labels
133
136
134
- self .gen_index += 1
135
-
136
137
def get_dataset (self , feature_extractor : Union [BaseExtractor , None ], dataset_type : DatasetType = DatasetType .TRAIN ,
137
138
shuffle : bool = True , trim : bool = True , return_labels : bool = False ):
138
139
0 commit comments