Skip to content

Commit

Permalink
Fix TF hanging issue in multiprocessing pools.
Browse files Browse the repository at this point in the history
  • Loading branch information
notadamking committed Jul 8, 2019
1 parent e8c8eab commit a2b8610
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 29 deletions.
32 changes: 15 additions & 17 deletions cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import multiprocessing

from multiprocessing.pool import ThreadPool

from lib.RLTrader import RLTrader
from lib.cli.RLTraderCLI import RLTraderCLI
Expand All @@ -11,32 +12,29 @@
args = trader_cli.get_args()


def run_concurrent_optimize():
trader = RLTrader(**vars(args))
trader.optimize(args.trials)

def run_optimize(params):
trader_args, logger = params

def concurrent_optimize():
processes = []
for i in range(args.parallel_jobs):
processes.append(multiprocessing.Process(target=run_concurrent_optimize, args=()))
trader = RLTrader(**vars(trader_args), logger=logger)
trader.optimize(trader_args.trials)

print(processes)

for p in processes:
p.start()
def optimize_concurrent(trader_args, logger):
n_processes = trader_args.parallel_jobs

for p in processes:
p.join()
opt_pool = ThreadPool(processes=n_processes)
opt_pool.map(run_optimize, [((trader_args, logger)) for _ in range(n_processes)])


if __name__ == '__main__':
logger = init_logger(__name__, show_debug=args.debug)
trader = RLTrader(**vars(args), logger=logger)

if args.command == 'optimize':
concurrent_optimize()
elif args.command == 'train':
optimize_concurrent(args, logger)

trader = RLTrader(**vars(args), logger=logger)

if args.command == 'train':
trader.train(n_epochs=args.epochs)
elif args.command == 'test':
trader.test(model_epoch=args.model_epoch, should_render=args.no_render)
Expand Down
19 changes: 7 additions & 12 deletions optimize.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import multiprocessing
import os
import numpy as np

from multiprocessing.pool import ThreadPool

from lib.RLTrader import RLTrader

np.warnings.filterwarnings('ignore')
Expand All @@ -12,18 +14,11 @@ def optimize_code(params):


if __name__ == '__main__':
n_process = multiprocessing.cpu_count()
params = {'n_envs': n_process}

processes = []
for i in range(n_process):
processes.append(multiprocessing.Process(target=optimize_code, args=(params,)))

for p in processes:
p.start()
n_processes = 6 # os.cpu_count()
params = {'n_envs': n_processes}

for p in processes:
p.join()
opt_pool = ThreadPool(processes=n_processes)
opt_pool.map(optimize_code, [params for _ in range(n_processes)])

trader = RLTrader(**params)
trader.train(test_trained_model=True, render_trained_model=True)

0 comments on commit a2b8610

Please sign in to comment.