@@ -80,15 +80,15 @@ def generator(self, utt_ids):
8080 for i , utt_id in enumerate (utt_ids ):
8181 audio_file = self .audio_files [i ]
8282 mel_file = self .mel_files [i ]
83-
83+
8484 items = {
8585 "utt_ids" : utt_id ,
8686 "audio_files" : audio_file ,
87- "mel_files" : mel_file
87+ "mel_files" : mel_file ,
8888 }
8989
9090 yield items
91-
91+
9292 @tf .function
9393 def _load_data (self , items ):
9494 audio = tf .numpy_function (np .load , [items ["audio_files" ]], tf .float32 )
@@ -101,7 +101,7 @@ def _load_data(self, items):
101101 "mel_lengths" : len (mel ),
102102 "audio_lengths" : len (audio ),
103103 }
104-
104+
105105 return items
106106
107107 def create (
@@ -120,8 +120,7 @@ def create(
120120
121121 # load dataset
122122 datasets = datasets .map (
123- lambda items : self ._load_data (items ),
124- tf .data .experimental .AUTOTUNE
123+ lambda items : self ._load_data (items ), tf .data .experimental .AUTOTUNE
125124 )
126125
127126 datasets = datasets .filter (
@@ -165,17 +164,19 @@ def create(
165164 }
166165
167166 datasets = datasets .padded_batch (
168- batch_size , padded_shapes = padded_shapes , padding_values = padding_values
167+ batch_size ,
168+ padded_shapes = padded_shapes ,
169+ padding_values = padding_values ,
170+ drop_remainder = True ,
169171 )
170172 datasets = datasets .prefetch (tf .data .experimental .AUTOTUNE )
171-
172173 return datasets
173174
174175 def get_output_dtypes (self ):
175176 output_types = {
176177 "utt_ids" : tf .string ,
177178 "audio_files" : tf .string ,
178- "mel_files" : tf .string
179+ "mel_files" : tf .string ,
179180 }
180181 return output_types
181182
0 commit comments