diff --git a/optimize.py b/optimize.py index 97bd86b..fb5634b 100644 --- a/optimize.py +++ b/optimize.py @@ -1,6 +1,6 @@ +import multiprocessing import numpy as np -import multiprocessing from lib.RLTrader import RLTrader np.warnings.filterwarnings('ignore') @@ -13,19 +13,17 @@ def optimize_code(params): if __name__ == '__main__': n_process = multiprocessing.cpu_count() - params = {'n_cpu': n_process} + params = {'n_envs': n_process} - # processes = [] - # for i in range(n_process): - # processes.append(multiprocessing.Process(target=optimize_code, args=(params,))) + processes = [] + for i in range(n_process): + processes.append(multiprocessing.Process(target=optimize_code, args=(params,))) - # for p in processes: - # p.start() + for p in processes: + p.start() - # for p in processes: - # p.join() + for p in processes: + p.join() trader = RLTrader(**params) - # trader.train(test_trained_model=True, render_trained_model=True) - - trader.test(model_epoch=10) + trader.train(test_trained_model=True, render_trained_model=True)