
Commit 6351027

Merge pull request #138 from mennetob/dev_master
Cleaning up multiprocessing forwarding
2 parents 329f2c9 + a70f1dc commit 6351027

File tree

    core.py
    run_exp.py

2 files changed: +112 -106 lines

core.py (+73 -79)
@@ -21,7 +21,33 @@
 from data_io import read_lab_fea,open_or_fd,write_mat
 from utils import shift
 
-def run_nn_refac01(data_name,data_set,data_end_index,fea_dict,lab_dict,arch_dict,cfg_file,processed_first,next_config_file,dry_run=False):
+def read_next_chunk_into_shared_list_with_subprocess(read_lab_fea, shared_list, cfg_file, is_production, output_folder, wait_for_process):
+    p=threading.Thread(target=read_lab_fea, args=(cfg_file,is_production,shared_list,output_folder,))
+    p.start()
+    if wait_for_process:
+        p.join()
+        return None
+    else:
+        return p
+def extract_data_from_shared_list(shared_list):
+    data_name = shared_list[0]
+    data_end_index_fea = shared_list[1]
+    data_end_index_lab = shared_list[2]
+    fea_dict = shared_list[3]
+    lab_dict = shared_list[4]
+    arch_dict = shared_list[5]
+    data_set = shared_list[6]
+    return data_name, data_end_index_fea, data_end_index_lab, fea_dict, lab_dict, arch_dict, data_set
+def convert_numpy_to_torch(data_set_dict, save_gpumem, use_cuda):
+    if not(save_gpumem) and use_cuda:
+        data_set_inp=torch.from_numpy(data_set_dict['input']).float().cuda()
+        data_set_ref=torch.from_numpy(data_set_dict['ref']).float().cuda()
+    else:
+        data_set_inp=torch.from_numpy(data_set_dict['input']).float()
+        data_set_ref=torch.from_numpy(data_set_dict['ref']).float()
+    data_set_ref = data_set_ref.view((data_set_ref.shape[0], 1))
+    return data_set_inp, data_set_ref
+def run_nn_refac01(data_name,data_set,data_end_index,fea_dict,lab_dict,arch_dict,cfg_file,processed_first,next_config_file):
     def _read_chunk_specific_config(cfg_file):
         if not(os.path.exists(cfg_file)):
             sys.stderr.write('ERROR: The config file %s does not exist!\n'%(cfg_file))
@@ -30,23 +56,6 @@ def _read_chunk_specific_config(cfg_file):
         config = configparser.ConfigParser()
         config.read(cfg_file)
         return config
-    def _read_next_chunk_into_shared_list_with_subprocess(read_lab_fea, shared_list, cfg_file, is_production, output_folder, wait_for_process):
-        p=threading.Thread(target=read_lab_fea, args=(cfg_file,is_production,shared_list,output_folder,))
-        p.start()
-        if wait_for_process:
-            p.join()
-            return None
-        else:
-            return p
-    def _extract_data_from_shared_list(shared_list):
-        data_name = shared_list[0]
-        data_end_index_fea = shared_list[1]
-        data_end_index_lab = shared_list[2]
-        fea_dict = shared_list[3]
-        lab_dict = shared_list[4]
-        arch_dict = shared_list[5]
-        data_set = shared_list[6]
-        return data_name, data_end_index_fea, data_end_index_lab, fea_dict, lab_dict, arch_dict, data_set
     def _get_batch_size_from_config(config, to_do):
         if to_do=='train':
             batch_size=int(config['batches']['batch_size_train'])
@@ -60,15 +69,6 @@ def _initialize_random_seed(config):
         torch.manual_seed(seed)
         random.seed(seed)
         np.random.seed(seed)
-    def _convert_numpy_to_torch(data_set_dict, save_gpumem, use_cuda):
-        if not(save_gpumem) and use_cuda:
-            data_set_inp=torch.from_numpy(data_set_dict['input']).float().cuda()
-            data_set_ref=torch.from_numpy(data_set_dict['ref']).float().cuda()
-        else:
-            data_set_inp=torch.from_numpy(data_set_dict['input']).float()
-            data_set_ref=torch.from_numpy(data_set_dict['ref']).float()
-        data_set_ref = data_set_ref.view((data_set_ref.shape[0], 1))
-        return data_set_inp, data_set_ref
     def _load_model_and_optimizer(fea_dict,model,config,arch_dict,use_cuda,multi_gpu,to_do):
         inp_out_dict = fea_dict
         nns, costs = model_init(inp_out_dict,model,config,arch_dict,use_cuda,multi_gpu,to_do)
@@ -221,16 +221,18 @@ def _get_dim_from_data_set(data_set_inp, data_set_ref):
 
     if processed_first:
         shared_list = list()
-        p = _read_next_chunk_into_shared_list_with_subprocess(read_lab_fea, shared_list, cfg_file, is_production, output_folder, wait_for_process=True)
-        data_name, data_end_index_fea, data_end_index_lab, fea_dict, lab_dict, arch_dict, data_set_dict = _extract_data_from_shared_list(shared_list)
-        data_set_inp, data_set_ref = _convert_numpy_to_torch(data_set_dict, save_gpumem, use_cuda)
+        p = read_next_chunk_into_shared_list_with_subprocess(read_lab_fea, shared_list, cfg_file, is_production, output_folder, wait_for_process=True)
+        data_name, data_end_index_fea, data_end_index_lab, fea_dict, lab_dict, arch_dict, data_set_dict = extract_data_from_shared_list(shared_list)
+        data_set_inp, data_set_ref = convert_numpy_to_torch(data_set_dict, save_gpumem, use_cuda)
     else:
         data_set_inp = data_set['input']
         data_set_ref = data_set['ref']
         data_end_index_fea = data_end_index['fea']
         data_end_index_lab = data_end_index['lab']
     shared_list = list()
-    data_loading_process = _read_next_chunk_into_shared_list_with_subprocess(read_lab_fea, shared_list, next_config_file, is_production, output_folder, wait_for_process=False)
+    data_loading_process = None
+    if not next_config_file is None:
+        data_loading_process = read_next_chunk_into_shared_list_with_subprocess(read_lab_fea, shared_list, next_config_file, is_production, output_folder, wait_for_process=False)
     nns, costs, optimizers, inp_out_dict = _load_model_and_optimizer(fea_dict,model,config,arch_dict,use_cuda,multi_gpu,to_do)
     if to_do=='forward':
         post_file = _open_forward_output_files_and_get_file_handles(forward_outs, require_decodings, info_file, output_folder)
@@ -253,25 +255,22 @@ def _get_dim_from_data_set(data_set_inp, data_set_ref):
     start_time = time.time()
     for i in range(N_batches):
         inp, ref, max_len_fea, max_len_lab, snt_len_fea, snt_len_lab, beg_snt_fea, beg_snt_lab, snt_index = _prepare_input(snt_index, batch_size, data_set_inp_dim, data_set_ref_dim, beg_snt_fea, beg_snt_lab, data_end_index_fea, data_end_index_lab, beg_batch, end_batch, seq_model, arr_snt_len_fea, arr_snt_len_lab, data_set_inp, data_set_ref, use_cuda)
-        if dry_run:
-            outs_dict = dict()
+        if to_do=='train':
+            outs_dict = forward_model(fea_dict, lab_dict, arch_dict, model, nns, costs, inp, ref, inp_out_dict, max_len_fea, max_len_lab, batch_size, to_do, forward_outs)
+            _optimization_step(optimizers, outs_dict, config, arch_dict)
         else:
-            if to_do=='train':
+            with torch.no_grad():
                 outs_dict = forward_model(fea_dict, lab_dict, arch_dict, model, nns, costs, inp, ref, inp_out_dict, max_len_fea, max_len_lab, batch_size, to_do, forward_outs)
-                _optimization_step(optimizers, outs_dict, config, arch_dict)
-            else:
-                with torch.no_grad():
-                    outs_dict = forward_model(fea_dict, lab_dict, arch_dict, model, nns, costs, inp, ref, inp_out_dict, max_len_fea, max_len_lab, batch_size, to_do, forward_outs)
-            if to_do == 'forward':
-                for out_id in range(len(forward_outs)):
-                    out_save = outs_dict[forward_outs[out_id]].data.cpu().numpy()
-                    if forward_normalize_post[out_id]:
-                        counts = load_counts(forward_count_files[out_id])
-                        out_save=out_save-np.log(counts/np.sum(counts))
-                    write_mat(output_folder,post_file[forward_outs[out_id]], out_save, data_name[i])
-            else:
-                loss_sum=loss_sum+outs_dict['loss_final'].detach()
-                err_sum=err_sum+outs_dict['err_final'].detach()
+        if to_do == 'forward':
+            for out_id in range(len(forward_outs)):
+                out_save = outs_dict[forward_outs[out_id]].data.cpu().numpy()
+                if forward_normalize_post[out_id]:
+                    counts = load_counts(forward_count_files[out_id])
+                    out_save=out_save-np.log(counts/np.sum(counts))
+                write_mat(output_folder,post_file[forward_outs[out_id]], out_save, data_name[i])
+        else:
+            loss_sum=loss_sum+outs_dict['loss_final'].detach()
+            err_sum=err_sum+outs_dict['err_final'].detach()
         beg_batch=end_batch
         end_batch=beg_batch+batch_size
         _update_progress_bar(to_do, i, N_batches, loss_sum)
@@ -286,13 +285,16 @@ def _get_dim_from_data_set(data_set_inp, data_set_ref):
     _write_info_file(info_file, to_do, loss_tot, err_tot, elapsed_time_chunk)
     if not data_loading_process is None:
         data_loading_process.join()
-        data_name, data_end_index_fea, data_end_index_lab, fea_dict, lab_dict, arch_dict, data_set_dict = _extract_data_from_shared_list(shared_list)
-        data_set_inp, data_set_ref = _convert_numpy_to_torch(data_set_dict, save_gpumem, use_cuda)
-        data_set = {'input': data_set_inp, 'ref': data_set_ref}
-        data_end_index = {'fea': data_end_index_fea,'lab': data_end_index_lab}
-    return [data_name,data_set,data_end_index,fea_dict,lab_dict,arch_dict]
+        data_name, data_end_index_fea, data_end_index_lab, fea_dict, lab_dict, arch_dict, data_set_dict = extract_data_from_shared_list(shared_list)
+        data_set_inp, data_set_ref = convert_numpy_to_torch(data_set_dict, save_gpumem, use_cuda)
+        data_set = {'input': data_set_inp, 'ref': data_set_ref}
+        data_end_index = {'fea': data_end_index_fea,'lab': data_end_index_lab}
+        return [data_name,data_set,data_end_index,fea_dict,lab_dict,arch_dict]
+    else:
+        return [None,None,None,None,None,None]
+
 
-def run_nn(data_name,data_set,data_end_index,fea_dict,lab_dict,arch_dict,cfg_file,processed_first,next_config_file,dry_run=False):
+def run_nn(data_name,data_set,data_end_index,fea_dict,lab_dict,arch_dict,cfg_file,processed_first,next_config_file):
 
     # This function processes the current chunk using the information in cfg_file. In parallel, the next chunk is load into the CPU memory
 
@@ -479,17 +481,13 @@ def run_nn(data_name,data_set,data_end_index,fea_dict,lab_dict,arch_dict,cfg_fil
 
         if to_do=='train':
             # Forward input, with autograd graph active
-            if not dry_run:
-                outs_dict=forward_model(fea_dict,lab_dict,arch_dict,model,nns,costs,inp,inp_out_dict,max_len,batch_size,to_do,forward_outs)
-            else:
-                outs_dict = dict()
+            outs_dict=forward_model(fea_dict,lab_dict,arch_dict,model,nns,costs,inp,inp_out_dict,max_len,batch_size,to_do,forward_outs)
 
             for opt in optimizers.keys():
                 optimizers[opt].zero_grad()
 
 
-            if not dry_run:
-                outs_dict['loss_final'].backward()
+            outs_dict['loss_final'].backward()
 
             # Gradient Clipping (th 0.1)
             #for net in nns.keys():
@@ -501,28 +499,24 @@ def run_nn(data_name,data_set,data_end_index,fea_dict,lab_dict,arch_dict,cfg_fil
                 optimizers[opt].step()
         else:
             with torch.no_grad(): # Forward input without autograd graph (save memory)
-                if not dry_run:
-                    outs_dict=forward_model(fea_dict,lab_dict,arch_dict,model,nns,costs,inp,inp_out_dict,max_len,batch_size,to_do,forward_outs)
-                else:
-                    outs_dict = dict()
+                outs_dict=forward_model(fea_dict,lab_dict,arch_dict,model,nns,costs,inp,inp_out_dict,max_len,batch_size,to_do,forward_outs)
 
 
-        if not dry_run:
-            if to_do=='forward':
-                for out_id in range(len(forward_outs)):
-
-                    out_save=outs_dict[forward_outs[out_id]].data.cpu().numpy()
+        if to_do=='forward':
+            for out_id in range(len(forward_outs)):
+
+                out_save=outs_dict[forward_outs[out_id]].data.cpu().numpy()
+
+                if forward_normalize_post[out_id]:
+                    # read the config file
+                    counts = load_counts(forward_count_files[out_id])
+                    out_save=out_save-np.log(counts/np.sum(counts))
 
-                    if forward_normalize_post[out_id]:
-                        # read the config file
-                        counts = load_counts(forward_count_files[out_id])
-                        out_save=out_save-np.log(counts/np.sum(counts))
-
-                    # save the output
-                    write_mat(output_folder,post_file[forward_outs[out_id]], out_save, data_name[i])
-            else:
-                loss_sum=loss_sum+outs_dict['loss_final'].detach()
-                err_sum=err_sum+outs_dict['err_final'].detach()
+                # save the output
+                write_mat(output_folder,post_file[forward_outs[out_id]], out_save, data_name[i])
+        else:
+            loss_sum=loss_sum+outs_dict['loss_final'].detach()
+            err_sum=err_sum+outs_dict['err_final'].detach()
 
         # update it to the next batch
         beg_batch=end_batch

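For orientation, the three helpers promoted to module level in core.py are meant to be used in sequence: load a chunk in a background thread, unpack the shared list it fills, and convert the numpy arrays to torch tensors. The sketch below illustrates that sequence; it is not code from this commit, the cfg_file, output_folder, and boolean flags are placeholders, and the imports assume the pytorch-kaldi layout visible in these diffs.

# Illustrative sketch only (not part of the commit); placeholder paths and flags.
from data_io import read_lab_fea_refac01 as read_lab_fea
from core import read_next_chunk_into_shared_list_with_subprocess, extract_data_from_shared_list, convert_numpy_to_torch

cfg_file = 'exp_files/forward_chunk0.cfg'   # placeholder chunk-specific config
output_folder = 'exp/out'                   # placeholder output folder
is_production = False
save_gpumem = False
use_cuda = False

# Read the chunk in a background thread and block until it is available.
shared_list = []
read_next_chunk_into_shared_list_with_subprocess(read_lab_fea, shared_list, cfg_file, is_production, output_folder, wait_for_process=True)

# Unpack the list filled by read_lab_fea and turn the numpy chunk into torch tensors.
data_name, data_end_index_fea, data_end_index_lab, fea_dict, lab_dict, arch_dict, data_set_dict = extract_data_from_shared_list(shared_list)
data_set_inp, data_set_ref = convert_numpy_to_torch(data_set_dict, save_gpumem, use_cuda)

data_set = {'input': data_set_inp, 'ref': data_set_ref}
data_end_index = {'fea': data_end_index_fea, 'lab': data_end_index_lab}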
run_exp.py (+39 -27)
@@ -17,13 +17,22 @@
 from utils import check_cfg,create_lists,create_configs, compute_avg_performance, \
     read_args_command_line, run_shell,compute_n_chunks, get_all_archs,cfg_item2sec, \
     dump_epoch_results, create_curves,change_lr_cfg,expand_str_ep
+from data_io import read_lab_fea_refac01 as read_lab_fea
 from shutil import copyfile
+from core import read_next_chunk_into_shared_list_with_subprocess, extract_data_from_shared_list, convert_numpy_to_torch
 import re
 from distutils.util import strtobool
 import importlib
 import math
 import multiprocessing
 
+def _run_forwarding_in_subprocesses(config):
+    use_cuda=strtobool(config['exp']['use_cuda'])
+    if use_cuda:
+        return False
+    else:
+        return True
+
 # Reading global cfg file (first argument-mandatory file)
 cfg_file=sys.argv[1]
 if not(os.path.exists(cfg_file)):
@@ -309,7 +318,8 @@
     N_ck_forward=compute_n_chunks(out_folder,forward_data,ep,N_ep_str_format,'forward')
     N_ck_str_format='0'+str(max(math.ceil(np.log10(N_ck_forward)),1))+'d'
 
-    kwargs_list = list()
+    processes = list()
+    info_files = list()
     for ck in range(N_ck_forward):
 
         if not is_production:
@@ -331,36 +341,38 @@
         next_config_file=cfg_file_list[op_counter]
 
         # run chunk processing
-        #[data_name,data_set,data_end_index,fea_dict,lab_dict,arch_dict]=run_nn(data_name,data_set,data_end_index,fea_dict,lab_dict,arch_dict,config_chunk_file,processed_first,next_config_file)
-        kwargs = dict()
-        for e in ['data_name','data_set','data_end_index','fea_dict','lab_dict','arch_dict','config_chunk_file','processed_first','next_config_file']:
-            if e == "config_chunk_file":
-                kwargs['cfg_file'] = eval(e)
-            else:
-                kwargs[e] = eval(e)
-        kwargs_list.append(kwargs)
-        [data_name,data_set,data_end_index,fea_dict,lab_dict,arch_dict]=run_nn(data_name,data_set,data_end_index,fea_dict,lab_dict,arch_dict,config_chunk_file,processed_first,next_config_file,dry_run=True)
-
+        if _run_forwarding_in_subprocesses(config):
+            shared_list = list()
+            output_folder = config['exp']['out_folder']
+            save_gpumem = strtobool(config['exp']['save_gpumem'])
+            use_cuda=strtobool(config['exp']['use_cuda'])
+            p = read_next_chunk_into_shared_list_with_subprocess(read_lab_fea, shared_list, config_chunk_file, is_production, output_folder, wait_for_process=True)
+            data_name, data_end_index_fea, data_end_index_lab, fea_dict, lab_dict, arch_dict, data_set_dict = extract_data_from_shared_list(shared_list)
+            data_set_inp, data_set_ref = convert_numpy_to_torch(data_set_dict, save_gpumem, use_cuda)
+            data_set = {'input': data_set_inp, 'ref': data_set_ref}
+            data_end_index = {'fea': data_end_index_fea,'lab': data_end_index_lab}
+            p = multiprocessing.Process(target=run_nn, kwargs={'data_name': data_name, 'data_set': data_set, 'data_end_index': data_end_index, 'fea_dict': fea_dict, 'lab_dict': lab_dict, 'arch_dict': arch_dict, 'cfg_file': config_chunk_file, 'processed_first': False, 'next_config_file': None})
+            processes.append(p)
+            p.start()
+        else:
+            [data_name,data_set,data_end_index,fea_dict,lab_dict,arch_dict]=run_nn(data_name,data_set,data_end_index,fea_dict,lab_dict,arch_dict,config_chunk_file,processed_first,next_config_file)
+            processed_first=False
+            if not(os.path.exists(info_file)):
+                sys.stderr.write("ERROR: forward chunk %i of dataset %s not done! File %s does not exist.\nSee %s \n" % (ck,forward_data,info_file,log_file))
+                sys.exit(0)
 
-        # update the first_processed variable
-        processed_first=False
-
-        if not(os.path.exists(info_file)):
-            sys.stderr.write("ERROR: forward chunk %i of dataset %s not done! File %s does not exist.\nSee %s \n" % (ck,forward_data,info_file,log_file))
-            sys.exit(0)
-
+        info_files.append(info_file)
 
         # update the operation counter
         op_counter+=1
-    processes = list()
-    for kwargs in kwargs_list:
-        p = multiprocessing.Process(target=run_nn, kwargs=kwargs)
-        processes.append(p)
-        p.start()
-    for process in processes:
-        process.join()
-
-
+    if _run_forwarding_in_subprocesses(config):
+        for process in processes:
+            process.join()
+        for info_file in info_files:
+            if not(os.path.exists(info_file)):
+                sys.stderr.write("ERROR: File %s does not exist. Forwarding did not suceed.\nSee %s \n" % (info_file,log_file))
+                sys.exit(0)
+
 
 
     # --------DECODING--------#

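The dispatch rule introduced in run_exp.py can be summarized as: forward chunks are farmed out to multiprocessing workers only when CUDA is disabled; with use_cuda=True, run_nn is still called sequentially in the parent process. The snippet below is a minimal, self-contained illustration of that pattern, not code from this repository; do_forward_chunk is a stand-in for the run_nn call, and the toy config dict mimics the parsed [exp] section.

# Minimal illustration of the CPU-only subprocess dispatch (not repository code).
# do_forward_chunk stands in for run_nn(..., processed_first=False, next_config_file=None).
import multiprocessing
from distutils.util import strtobool

def _run_forwarding_in_subprocesses(config):
    # Worker processes are used only when CUDA is disabled.
    return not strtobool(config['exp']['use_cuda'])

def do_forward_chunk(ck):
    print('forwarding chunk %i' % ck)

if __name__ == '__main__':
    config = {'exp': {'use_cuda': 'False'}}   # toy stand-in for the parsed config
    n_chunks = 4
    if _run_forwarding_in_subprocesses(config):
        processes = []
        for ck in range(n_chunks):
            p = multiprocessing.Process(target=do_forward_chunk, args=(ck,))
            processes.append(p)
            p.start()
        for p in processes:
            p.join()
    else:
        for ck in range(n_chunks):
            do_forward_chunk(ck)

Presumably the GPU path stays in-process to avoid sharing a CUDA context across forked workers; after the joins, the script checks that each chunk's info file exists before moving on to decoding.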