@@ -224,6 +224,42 @@ def build_query(
224
224
self ._query = from_folder (folder , max_epochs , seeds , devices , include_static_info , validate_data )
225
225
return self ._query
226
226
227
+ def is_dataset_ready (self ):
228
+ flag1 = os .path .exists (self .NASBenchASR_DATAPATH )
229
+ if flag1 :
230
+ flag2 = len (os .listdir (self .NASBenchASR_DATAPATH )) > 0
231
+ return flag1 and flag2
232
+ return flag1
233
+
234
+ def prepare_dataset (self ):
235
+ if not os .path .exists (self .NASBenchASR_DATAPATH ):
236
+ os .makedirs (self .NASBenchASR_DATAPATH )
237
+ if len (os .listdir (self .NASBenchASR_DATAPATH )) <= 0 :
238
+ crt_file_path = os .path .abspath (__file__ )
239
+ crt_folder_path = os .path .dirname (crt_file_path )
240
+ download_shell = os .path .join (crt_folder_path , 'download_nasbenchasr.sh' )
241
+ print ('Downloading NASBenchASR dataset...' )
242
+ os .system (f"bash { download_shell } " )
243
+ print ('Done' )
244
+
245
+ def query_by_key (self , key : str , ** kwargs ):
246
+ if not self .is_dataset_ready ():
247
+ self .prepare_dataset ()
248
+ if key == 'full' :
249
+ return self .query_full_info (** kwargs )
250
+ elif key == 'test_acc' :
251
+ return self .query_test_acc (** kwargs )
252
+ elif key == 'val_acc' :
253
+ return self .query_val_acc (** kwargs )
254
+ elif key == 'latency' :
255
+ return self .query_latency (** kwargs )
256
+ elif key == 'params' :
257
+ return self .query_params (** kwargs )
258
+ elif key == 'flops' :
259
+ return self .query_flops (** kwargs )
260
+ else :
261
+ raise NotImplementedError (f'{ key } not supported.' )
262
+
227
263
def query_full_info (self , ** kwargs ):
228
264
default_kwargs = {
229
265
"arch" : self .arch ,
@@ -377,9 +413,12 @@ def dict_mask_to_list_desc(self, mask: dict):
377
413
y = model (x )
378
414
print (NASBenchASR .dict_mask_to_list_desc (mask ))
379
415
380
- print (model2 .query_full_info (max_epochs = 5 ))
381
- print (model2 .query_flops ())
382
- print (model2 .query_latency ())
383
- print (model2 .query_params ())
384
- print (model2 .query_test_acc ())
385
- print (model2 .query_val_acc ())
416
+ # print(model2.query_full_info(max_epochs=5))
417
+ # print(model2.query_flops())
418
+ # print(model2.query_latency())
419
+ # print(model2.query_params())
420
+ # print(model2.query_test_acc())
421
+ # print(model2.query_val_acc())
422
+
423
+ for key in ['full' , 'flops' , 'test_acc' , 'params' , 'val_acc' , 'latency' ]:
424
+ print (key , model2 .query_by_key (key ))
0 commit comments