Skip to content

Commit cf163b7

Browse files
committedOct 8, 2023
refactor nasbenchasr
1 parent 56b5beb commit cf163b7

File tree

1 file changed

+45
-6
lines changed

1 file changed

+45
-6
lines changed
 

‎hyperbox/networks/nasbenchasr/model.py

+45-6
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,42 @@ def build_query(
224224
self._query = from_folder(folder, max_epochs, seeds, devices, include_static_info, validate_data)
225225
return self._query
226226

227+
def is_dataset_ready(self):
228+
flag1 = os.path.exists(self.NASBenchASR_DATAPATH)
229+
if flag1:
230+
flag2 = len(os.listdir(self.NASBenchASR_DATAPATH)) > 0
231+
return flag1 and flag2
232+
return flag1
233+
234+
def prepare_dataset(self):
235+
if not os.path.exists(self.NASBenchASR_DATAPATH):
236+
os.makedirs(self.NASBenchASR_DATAPATH)
237+
if len(os.listdir(self.NASBenchASR_DATAPATH)) <= 0:
238+
crt_file_path = os.path.abspath(__file__)
239+
crt_folder_path = os.path.dirname(crt_file_path)
240+
download_shell = os.path.join(crt_folder_path, 'download_nasbenchasr.sh')
241+
print('Downloading NASBenchASR dataset...')
242+
os.system(f"bash {download_shell}")
243+
print('Done')
244+
245+
def query_by_key(self, key: str, **kwargs):
246+
if not self.is_dataset_ready():
247+
self.prepare_dataset()
248+
if key == 'full':
249+
return self.query_full_info(**kwargs)
250+
elif key == 'test_acc':
251+
return self.query_test_acc(**kwargs)
252+
elif key == 'val_acc':
253+
return self.query_val_acc(**kwargs)
254+
elif key == 'latency':
255+
return self.query_latency(**kwargs)
256+
elif key == 'params':
257+
return self.query_params(**kwargs)
258+
elif key == 'flops':
259+
return self.query_flops(**kwargs)
260+
else:
261+
raise NotImplementedError(f'{key} not supported.')
262+
227263
def query_full_info(self, **kwargs):
228264
default_kwargs = {
229265
"arch": self.arch,
@@ -377,9 +413,12 @@ def dict_mask_to_list_desc(self, mask: dict):
377413
y = model(x)
378414
print(NASBenchASR.dict_mask_to_list_desc(mask))
379415

380-
print(model2.query_full_info(max_epochs=5))
381-
print(model2.query_flops())
382-
print(model2.query_latency())
383-
print(model2.query_params())
384-
print(model2.query_test_acc())
385-
print(model2.query_val_acc())
416+
# print(model2.query_full_info(max_epochs=5))
417+
# print(model2.query_flops())
418+
# print(model2.query_latency())
419+
# print(model2.query_params())
420+
# print(model2.query_test_acc())
421+
# print(model2.query_val_acc())
422+
423+
for key in ['full', 'flops', 'test_acc', 'params', 'val_acc', 'latency']:
424+
print(key, model2.query_by_key(key))

0 commit comments

Comments
 (0)
Please sign in to comment.