 from data_io import read_lab_fea, open_or_fd, write_mat
 from utils import shift
 
-def run_nn_refac01(data_name, data_set, data_end_index, fea_dict, lab_dict, arch_dict, cfg_file, processed_first, next_config_file, dry_run=False):
+def read_next_chunk_into_shared_list_with_subprocess(read_lab_fea, shared_list, cfg_file, is_production, output_folder, wait_for_process):
+    p = threading.Thread(target=read_lab_fea, args=(cfg_file, is_production, shared_list, output_folder,))
+    p.start()
+    if wait_for_process:
+        p.join()
+        return None
+    else:
+        return p
+def extract_data_from_shared_list(shared_list):
+    data_name = shared_list[0]
+    data_end_index_fea = shared_list[1]
+    data_end_index_lab = shared_list[2]
+    fea_dict = shared_list[3]
+    lab_dict = shared_list[4]
+    arch_dict = shared_list[5]
+    data_set = shared_list[6]
+    return data_name, data_end_index_fea, data_end_index_lab, fea_dict, lab_dict, arch_dict, data_set
+def convert_numpy_to_torch(data_set_dict, save_gpumem, use_cuda):
+    if not (save_gpumem) and use_cuda:
+        data_set_inp = torch.from_numpy(data_set_dict['input']).float().cuda()
+        data_set_ref = torch.from_numpy(data_set_dict['ref']).float().cuda()
+    else:
+        data_set_inp = torch.from_numpy(data_set_dict['input']).float()
+        data_set_ref = torch.from_numpy(data_set_dict['ref']).float()
+    data_set_ref = data_set_ref.view((data_set_ref.shape[0], 1))
+    return data_set_inp, data_set_ref
+def run_nn_refac01(data_name, data_set, data_end_index, fea_dict, lab_dict, arch_dict, cfg_file, processed_first, next_config_file):
     def _read_chunk_specific_config(cfg_file):
         if not (os.path.exists(cfg_file)):
             sys.stderr.write('ERROR: The config file %s does not exist!\n' % (cfg_file))
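
The hunk above promotes the chunk-prefetch helper from a nested private function (removed further down) to a module-level one, with unchanged behaviour: a background thread fills `shared_list` while the caller keeps working, and `wait_for_process` selects blocking versus overlapped loading. A minimal self-contained sketch of the same pattern; `load_chunk` and `prefetch` are hypothetical stand-ins for `read_lab_fea` and the helper:

    import threading
    import time

    def load_chunk(cfg_file, shared_list):
        # Stand-in for read_lab_fea: pretend to parse cfg_file and read features.
        time.sleep(0.1)
        shared_list.append({'cfg': cfg_file, 'data': [1, 2, 3]})

    def prefetch(target, shared_list, cfg_file, wait_for_process):
        p = threading.Thread(target=target, args=(cfg_file, shared_list))
        p.start()
        if wait_for_process:
            p.join()      # blocking: data is ready on return
            return None
        return p          # overlapped: caller joins later

    shared = []
    handle = prefetch(load_chunk, shared, 'chunk_01.cfg', wait_for_process=False)
    # ... process the current chunk here while the next one loads ...
    handle.join()
    print(shared[0]['cfg'])   # chunk_01.cfg

Despite the `_with_subprocess` name, the helper uses a `threading.Thread`, so the overlap mainly hides I/O-bound loading rather than CPU-bound work.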
@@ -30,23 +56,6 @@ def _read_chunk_specific_config(cfg_file):
             config = configparser.ConfigParser()
             config.read(cfg_file)
         return config
-    def _read_next_chunk_into_shared_list_with_subprocess(read_lab_fea, shared_list, cfg_file, is_production, output_folder, wait_for_process):
-        p = threading.Thread(target=read_lab_fea, args=(cfg_file, is_production, shared_list, output_folder,))
-        p.start()
-        if wait_for_process:
-            p.join()
-            return None
-        else:
-            return p
-    def _extract_data_from_shared_list(shared_list):
-        data_name = shared_list[0]
-        data_end_index_fea = shared_list[1]
-        data_end_index_lab = shared_list[2]
-        fea_dict = shared_list[3]
-        lab_dict = shared_list[4]
-        arch_dict = shared_list[5]
-        data_set = shared_list[6]
-        return data_name, data_end_index_fea, data_end_index_lab, fea_dict, lab_dict, arch_dict, data_set
     def _get_batch_size_from_config(config, to_do):
         if to_do == 'train':
             batch_size = int(config['batches']['batch_size_train'])
@@ -60,15 +69,6 @@ def _initialize_random_seed(config):
         torch.manual_seed(seed)
         random.seed(seed)
         np.random.seed(seed)
-    def _convert_numpy_to_torch(data_set_dict, save_gpumem, use_cuda):
-        if not (save_gpumem) and use_cuda:
-            data_set_inp = torch.from_numpy(data_set_dict['input']).float().cuda()
-            data_set_ref = torch.from_numpy(data_set_dict['ref']).float().cuda()
-        else:
-            data_set_inp = torch.from_numpy(data_set_dict['input']).float()
-            data_set_ref = torch.from_numpy(data_set_dict['ref']).float()
-        data_set_ref = data_set_ref.view((data_set_ref.shape[0], 1))
-        return data_set_inp, data_set_ref
     def _load_model_and_optimizer(fea_dict, model, config, arch_dict, use_cuda, multi_gpu, to_do):
         inp_out_dict = fea_dict
         nns, costs = model_init(inp_out_dict, model, config, arch_dict, use_cuda, multi_gpu, to_do)
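
`convert_numpy_to_torch` (also promoted to module level above) decides once per chunk whether the whole chunk lives on the GPU: with `save_gpumem` unset and CUDA available, it moves both tensors up front; otherwise it leaves them on the CPU. A condensed sketch of the same logic with toy data (array shapes are illustrative only):

    import numpy as np
    import torch

    def convert(chunk, save_gpumem, use_cuda):
        # Move the whole chunk to the GPU up front only when memory allows;
        # otherwise keep it on the CPU and transfer per batch later.
        inp = torch.from_numpy(chunk['input']).float()
        ref = torch.from_numpy(chunk['ref']).float()
        if not save_gpumem and use_cuda:
            inp, ref = inp.cuda(), ref.cuda()
        return inp, ref.view(ref.shape[0], 1)   # labels become a column vector

    chunk = {'input': np.random.rand(8, 40).astype(np.float32),
             'ref': np.random.rand(8).astype(np.float32)}
    inp, ref = convert(chunk, save_gpumem=True, use_cuda=False)
    print(inp.shape, ref.shape)   # torch.Size([8, 40]) torch.Size([8, 1])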
@@ -221,16 +221,18 @@ def _get_dim_from_data_set(data_set_inp, data_set_ref):
     if processed_first:
         shared_list = list()
-        p = _read_next_chunk_into_shared_list_with_subprocess(read_lab_fea, shared_list, cfg_file, is_production, output_folder, wait_for_process=True)
-        data_name, data_end_index_fea, data_end_index_lab, fea_dict, lab_dict, arch_dict, data_set_dict = _extract_data_from_shared_list(shared_list)
-        data_set_inp, data_set_ref = _convert_numpy_to_torch(data_set_dict, save_gpumem, use_cuda)
+        p = read_next_chunk_into_shared_list_with_subprocess(read_lab_fea, shared_list, cfg_file, is_production, output_folder, wait_for_process=True)
+        data_name, data_end_index_fea, data_end_index_lab, fea_dict, lab_dict, arch_dict, data_set_dict = extract_data_from_shared_list(shared_list)
+        data_set_inp, data_set_ref = convert_numpy_to_torch(data_set_dict, save_gpumem, use_cuda)
     else:
         data_set_inp = data_set['input']
         data_set_ref = data_set['ref']
         data_end_index_fea = data_end_index['fea']
         data_end_index_lab = data_end_index['lab']
     shared_list = list()
-    data_loading_process = _read_next_chunk_into_shared_list_with_subprocess(read_lab_fea, shared_list, next_config_file, is_production, output_folder, wait_for_process=False)
+    data_loading_process = None
+    if not next_config_file is None:
+        data_loading_process = read_next_chunk_into_shared_list_with_subprocess(read_lab_fea, shared_list, next_config_file, is_production, output_folder, wait_for_process=False)
     nns, costs, optimizers, inp_out_dict = _load_model_and_optimizer(fea_dict, model, config, arch_dict, use_cuda, multi_gpu, to_do)
     if to_do == 'forward':
         post_file = _open_forward_output_files_and_get_file_handles(forward_outs, require_decodings, info_file, output_folder)
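
The second change in this hunk replaces the unconditional prefetch with a `None` guard: on the last chunk there is no `next_config_file`, so no loader thread is started and `data_loading_process` stays `None` (this is picked up again at the end of the function). A hypothetical driver loop showing how the final chunk naturally passes `None` (in pytorch-kaldi this role is played by run_exp.py):

    configs = ['chunk_01.cfg', 'chunk_02.cfg', 'chunk_03.cfg']
    for i, cfg in enumerate(configs):
        next_cfg = configs[i + 1] if i + 1 < len(configs) else None  # last chunk: None
        print(cfg, '->', next_cfg)
        # run_nn_refac01(..., cfg_file=cfg, processed_first=(i == 0), next_config_file=next_cfg)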
@@ -253,25 +255,22 @@ def _get_dim_from_data_set(data_set_inp, data_set_ref):
     start_time = time.time()
     for i in range(N_batches):
         inp, ref, max_len_fea, max_len_lab, snt_len_fea, snt_len_lab, beg_snt_fea, beg_snt_lab, snt_index = _prepare_input(snt_index, batch_size, data_set_inp_dim, data_set_ref_dim, beg_snt_fea, beg_snt_lab, data_end_index_fea, data_end_index_lab, beg_batch, end_batch, seq_model, arr_snt_len_fea, arr_snt_len_lab, data_set_inp, data_set_ref, use_cuda)
-        if dry_run:
-            outs_dict = dict()
+        if to_do == 'train':
+            outs_dict = forward_model(fea_dict, lab_dict, arch_dict, model, nns, costs, inp, ref, inp_out_dict, max_len_fea, max_len_lab, batch_size, to_do, forward_outs)
+            _optimization_step(optimizers, outs_dict, config, arch_dict)
         else:
-            if to_do == 'train':
+            with torch.no_grad():
                 outs_dict = forward_model(fea_dict, lab_dict, arch_dict, model, nns, costs, inp, ref, inp_out_dict, max_len_fea, max_len_lab, batch_size, to_do, forward_outs)
-                _optimization_step(optimizers, outs_dict, config, arch_dict)
-            else:
-                with torch.no_grad():
-                    outs_dict = forward_model(fea_dict, lab_dict, arch_dict, model, nns, costs, inp, ref, inp_out_dict, max_len_fea, max_len_lab, batch_size, to_do, forward_outs)
-            if to_do == 'forward':
-                for out_id in range(len(forward_outs)):
-                    out_save = outs_dict[forward_outs[out_id]].data.cpu().numpy()
-                    if forward_normalize_post[out_id]:
-                        counts = load_counts(forward_count_files[out_id])
-                        out_save = out_save - np.log(counts / np.sum(counts))
-                    write_mat(output_folder, post_file[forward_outs[out_id]], out_save, data_name[i])
-            else:
-                loss_sum = loss_sum + outs_dict['loss_final'].detach()
-                err_sum = err_sum + outs_dict['err_final'].detach()
+        if to_do == 'forward':
+            for out_id in range(len(forward_outs)):
+                out_save = outs_dict[forward_outs[out_id]].data.cpu().numpy()
+                if forward_normalize_post[out_id]:
+                    counts = load_counts(forward_count_files[out_id])
+                    out_save = out_save - np.log(counts / np.sum(counts))
+                write_mat(output_folder, post_file[forward_outs[out_id]], out_save, data_name[i])
+        else:
+            loss_sum = loss_sum + outs_dict['loss_final'].detach()
+            err_sum = err_sum + outs_dict['err_final'].detach()
         beg_batch = end_batch
         end_batch = beg_batch + batch_size
         _update_progress_bar(to_do, i, N_batches, loss_sum)
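
With `dry_run` gone, the batch loop reduces to the standard two-branch pattern: training runs the forward pass with autograd active and then an optimization step; every other mode (validation, forward decoding) wraps the forward pass in `torch.no_grad()` to skip graph construction and save memory. The pattern in isolation, with a toy model standing in for `forward_model`:

    import torch

    model = torch.nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    x, y = torch.randn(8, 4), torch.randn(8, 2)

    def process_batch(to_do):
        if to_do == 'train':
            loss = torch.nn.functional.mse_loss(model(x), y)
            opt.zero_grad()
            loss.backward()     # autograd graph is needed here
            opt.step()
        else:
            with torch.no_grad():   # no graph: cheaper forward-only pass
                loss = torch.nn.functional.mse_loss(model(x), y)
        return loss.detach()

    print(process_batch('train'), process_batch('valid'))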
@@ -286,13 +285,16 @@ def _get_dim_from_data_set(data_set_inp, data_set_ref):
     _write_info_file(info_file, to_do, loss_tot, err_tot, elapsed_time_chunk)
     if not data_loading_process is None:
         data_loading_process.join()
-        data_name, data_end_index_fea, data_end_index_lab, fea_dict, lab_dict, arch_dict, data_set_dict = _extract_data_from_shared_list(shared_list)
-        data_set_inp, data_set_ref = _convert_numpy_to_torch(data_set_dict, save_gpumem, use_cuda)
-        data_set = {'input': data_set_inp, 'ref': data_set_ref}
-        data_end_index = {'fea': data_end_index_fea, 'lab': data_end_index_lab}
-        return [data_name, data_set, data_end_index, fea_dict, lab_dict, arch_dict]
+        data_name, data_end_index_fea, data_end_index_lab, fea_dict, lab_dict, arch_dict, data_set_dict = extract_data_from_shared_list(shared_list)
+        data_set_inp, data_set_ref = convert_numpy_to_torch(data_set_dict, save_gpumem, use_cuda)
+        data_set = {'input': data_set_inp, 'ref': data_set_ref}
+        data_end_index = {'fea': data_end_index_fea, 'lab': data_end_index_lab}
+        return [data_name, data_set, data_end_index, fea_dict, lab_dict, arch_dict]
+    else:
+        return [None, None, None, None, None, None]
+
 
-def run_nn(data_name, data_set, data_end_index, fea_dict, lab_dict, arch_dict, cfg_file, processed_first, next_config_file, dry_run=False):
+def run_nn(data_name, data_set, data_end_index, fea_dict, lab_dict, arch_dict, cfg_file, processed_first, next_config_file):
 
     # This function processes the current chunk using the information in cfg_file. In parallel, the next chunk is loaded into the CPU memory
@@ -479,17 +481,13 @@ def run_nn(data_name,data_set,data_end_index,fea_dict,lab_dict,arch_dict,cfg_fil
 
 
         if to_do == 'train':
             # Forward input, with autograd graph active
-            if not dry_run:
-                outs_dict = forward_model(fea_dict, lab_dict, arch_dict, model, nns, costs, inp, inp_out_dict, max_len, batch_size, to_do, forward_outs)
-            else:
-                outs_dict = dict()
+            outs_dict = forward_model(fea_dict, lab_dict, arch_dict, model, nns, costs, inp, inp_out_dict, max_len, batch_size, to_do, forward_outs)
 
             for opt in optimizers.keys():
                 optimizers[opt].zero_grad()
 
-            if not dry_run:
-                outs_dict['loss_final'].backward()
+            outs_dict['loss_final'].backward()
 
             # Gradient Clipping (th 0.1)
             #for net in nns.keys():
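
The gradient-clipping block referenced by the comments above is still commented out in the source; this diff does not re-enable it. If one did want clipping at the 0.1 threshold the comment mentions, the usual PyTorch form (an assumption, not part of this PR) would be:

    import torch

    net = torch.nn.Linear(4, 2)
    net(torch.randn(8, 4)).sum().backward()
    # Rescale gradients so their global norm is at most 0.1 (the "th 0.1" above).
    torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.1)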
@@ -501,28 +499,24 @@ def run_nn(data_name,data_set,data_end_index,fea_dict,lab_dict,arch_dict,cfg_fil
                 optimizers[opt].step()
         else:
             with torch.no_grad(): # Forward input without autograd graph (save memory)
-                if not dry_run:
-                    outs_dict = forward_model(fea_dict, lab_dict, arch_dict, model, nns, costs, inp, inp_out_dict, max_len, batch_size, to_do, forward_outs)
-                else:
-                    outs_dict = dict()
+                outs_dict = forward_model(fea_dict, lab_dict, arch_dict, model, nns, costs, inp, inp_out_dict, max_len, batch_size, to_do, forward_outs)
 
-        if not dry_run:
-            if to_do == 'forward':
-                for out_id in range(len(forward_outs)):
-
-                    out_save = outs_dict[forward_outs[out_id]].data.cpu().numpy()
+        if to_do == 'forward':
+            for out_id in range(len(forward_outs)):
+
+                out_save = outs_dict[forward_outs[out_id]].data.cpu().numpy()
+
+                if forward_normalize_post[out_id]:
+                    # read the config file
+                    counts = load_counts(forward_count_files[out_id])
+                    out_save = out_save - np.log(counts / np.sum(counts))
 
-                    if forward_normalize_post[out_id]:
-                        # read the config file
-                        counts = load_counts(forward_count_files[out_id])
-                        out_save = out_save - np.log(counts / np.sum(counts))
-
-                    # save the output
-                    write_mat(output_folder, post_file[forward_outs[out_id]], out_save, data_name[i])
-            else:
-                loss_sum = loss_sum + outs_dict['loss_final'].detach()
-                err_sum = err_sum + outs_dict['err_final'].detach()
+                # save the output
+                write_mat(output_folder, post_file[forward_outs[out_id]], out_save, data_name[i])
+        else:
+            loss_sum = loss_sum + outs_dict['loss_final'].detach()
+            err_sum = err_sum + outs_dict['err_final'].detach()
 
 
         # update it to the next batch
         beg_batch = end_batch