diff --git a/mjrl/samplers/core.py b/mjrl/samplers/core.py index be4a988..b8d5408 100644 --- a/mjrl/samplers/core.py +++ b/mjrl/samplers/core.py @@ -134,7 +134,7 @@ def sample_paths( start_time = timer.time() print("####### Gathering Samples #######") - results = _try_multiprocess(do_rollout, input_dict_list, + results = _try_multiprocess_cf(do_rollout, input_dict_list, num_cpu, max_process_time, max_timeouts) paths = [] # result is a paths type and results is list of paths @@ -186,7 +186,7 @@ def sample_data_batch( return paths -def _try_multiprocess(func, input_dict_list, num_cpu, max_process_time, max_timeouts): +def _try_multiprocess_mp(func, input_dict_list, num_cpu, max_process_time, max_timeouts): # Base case if max_timeouts == 0: @@ -202,9 +202,29 @@ def _try_multiprocess(func, input_dict_list, num_cpu, max_process_time, max_time pool.close() pool.terminate() pool.join() - return _try_multiprocess(func, input_dict_list, num_cpu, max_process_time, max_timeouts-1) + return _try_multiprocess_mp(func, input_dict_list, num_cpu, max_process_time, max_timeouts-1) pool.close() pool.terminate() pool.join() return results + +def _try_multiprocess_cf(func, input_dict_list, num_cpu, max_process_time, max_timeouts): + import concurrent.futures + results = None + if max_timeouts != 0: + with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpu) as executor: + submit_futures = [executor.submit(func, **input_dict) for input_dict in input_dict_list] + try: + results = [f.result() for f in submit_futures] + except TimeoutError as e: + print(str(e)) + print("Timeout Error raised...") + except concurrent.futures.CancelledError as e: + print(str(e)) + print("Future Cancelled Error raised...") + except Exception as e: + print(str(e)) + print("Error raised...") + raise e + return results diff --git a/mjrl/utils/tensor_utils.py b/mjrl/utils/tensor_utils.py index 8b0002a..1372fca 100644 --- a/mjrl/utils/tensor_utils.py +++ b/mjrl/utils/tensor_utils.py @@ -61,7 +61,7 @@ def high_res_normalize(probs): def stack_tensor_list(tensor_list): - return np.array(tensor_list) + return np.array(tensor_list, dtype='object') # tensor_shape = np.array(tensor_list[0]).shape # if tensor_shape is tuple(): # return np.array(tensor_list)