Commit 49fc489

fix(nyz): fix evaluator return episode_info compatibility bug
1 parent 521284b · commit 49fc489

File tree

ding/entry/application_entry.py
ding/entry/serial_entry_reward_model_offpolicy.py
ding/entry/serial_entry_reward_model_onpolicy.py
ding/worker/collector/interaction_serial_evaluator.py
dizoo/gym_anytrading/worker/trading_serial_evaluator.py

5 files changed: +19 -24 lines changed

ding/entry/application_entry.py

+1 -2
@@ -72,8 +72,7 @@ def eval(

     # Evaluate
     _, episode_info = evaluator.eval()
-    reward = [e['eval_episode_return'] for e in episode_info]
-    episode_return = np.mean(to_ndarray(reward))
+    episode_return = np.mean(episode_info['eval_episode_return'])
     print('Eval is over! The performance of your RL policy is {}'.format(episode_return))
     return episode_return

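
The patched entry assumes evaluator.eval() now returns episode_info as a single dict keyed by metric name, with each value a list over evaluated episodes, rather than a list of per-episode dicts. A minimal sketch of the shape change, with made-up return values:

import numpy as np

# Old shape: one dict per evaluated episode.
old_episode_info = [{'eval_episode_return': 195.0}, {'eval_episode_return': 201.5}]
old_mean = np.mean([e['eval_episode_return'] for e in old_episode_info])

# New shape: one dict keyed by metric, each value a list over episodes.
new_episode_info = {'eval_episode_return': [195.0, 201.5]}
new_mean = np.mean(new_episode_info['eval_episode_return'])

assert old_mean == new_mean  # same statistic, simpler access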

ding/entry/serial_entry_reward_model_offpolicy.py

+5 -5
@@ -89,16 +89,16 @@ def serial_pipeline_reward_model_offpolicy(
     if cfg.policy.get('random_collect_size', 0) > 0:
         random_collect(cfg.policy, policy, collector, collector_env, commander, replay_buffer)
     count = 0
-    best_reward = -np.inf
+    best_return = -np.inf
     while True:
         collect_kwargs = commander.step()
         # Evaluate policy performance
         if evaluator.should_eval(learner.train_iter):
-            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
-            reward_mean = np.array([r['eval_episode_return'] for r in reward]).mean()
-            if reward_mean >= best_reward:
+            stop, eval_info = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+            eval_return_mean = np.mean(eval_info['eval_episode_return'])
+            if eval_return_mean >= best_return:
                 reward_model.save(path=cfg.exp_name, name='best')
-                best_reward = reward_mean
+                best_return = eval_return_mean
             if stop:
                 break
         new_data_count, target_new_data_count = 0, cfg.reward_model.get('target_new_data_count', 1)
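
Both reward-model entries (off-policy here and on-policy below) keep the same best-checkpoint pattern, now reading the mean return directly from the dict-shaped eval info. A standalone sketch of that pattern; evaluate() and save_best() are placeholders for illustration, not DI-engine APIs:

import numpy as np

def evaluate(step: int) -> dict:
    # Placeholder for evaluator.eval(); returns dict-shaped episode info.
    return {'eval_episode_return': [float(step + i) for i in range(5)]}

def save_best(step: int) -> None:
    # Placeholder for reward_model.save(path=..., name='best').
    print('saved best model at iteration {}'.format(step))

best_return = -np.inf
for step in range(3):
    eval_info = evaluate(step)
    eval_return_mean = np.mean(eval_info['eval_episode_return'])
    # Save whenever the mean evaluation return matches or beats the best so far.
    if eval_return_mean >= best_return:
        save_best(step)
        best_return = eval_return_mean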

ding/entry/serial_entry_reward_model_onpolicy.py

+5 -5
@@ -89,16 +89,16 @@ def serial_pipeline_reward_model_onpolicy(
     if cfg.policy.get('random_collect_size', 0) > 0:
         random_collect(cfg.policy, policy, collector, collector_env, commander, replay_buffer)
     count = 0
-    best_reward = -np.inf
+    best_return = -np.inf
     while True:
         collect_kwargs = commander.step()
         # Evaluate policy performance
         if evaluator.should_eval(learner.train_iter):
-            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
-            reward_mean = np.array([r['eval_episode_return'] for r in reward]).mean()
-            if reward_mean >= best_reward:
+            stop, eval_info = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+            eval_return_mean = np.mean(eval_info['eval_episode_return'])
+            if eval_return_mean >= best_return:
                 reward_model.save(path=cfg.exp_name, name='best')
-                best_reward = reward_mean
+                best_return = eval_return_mean
             if stop:
                 break
         new_data_count, target_new_data_count = 0, cfg.reward_model.get('target_new_data_count', 1)

ding/worker/collector/interaction_serial_evaluator.py

+3 -5
@@ -31,8 +31,6 @@ class InteractionSerialEvaluator(ISerialEvaluator):
         ),
         # (str) File path for visualize environment information.
         figure_path=None,
-        # (bool) Whether to return env info in termination step.
-        return_env_info=True,
     )

     def __init__(

@@ -253,10 +251,10 @@ def eval(
                         self._env.enable_save_figure(env_id, self._cfg.figure_path)
                     self._policy.reset([env_id])
                     reward = t.info['eval_episode_return']
+                    saved_info = {'eval_episode_return': t.info['eval_episode_return']}
                     if 'episode_info' in t.info:
-                        eval_monitor.update_info(env_id, t.info['episode_info'])
-                    elif self._cfg.return_env_info:
-                        eval_monitor.update_info(env_id, t.info)
+                        saved_info.update(t.info['episode_info'])
+                    eval_monitor.update_info(env_id, saved_info)
                     eval_monitor.update_reward(env_id, reward)
                     self._logger.info(
                         "[EVALUATOR]env {} finish episode, final reward: {:.4f}, current episode: {}".format(

dizoo/gym_anytrading/worker/trading_serial_evaluator.py

+5 -7
@@ -6,7 +6,7 @@

 from ding.envs import BaseEnvManager
 from ding.worker import VectorEvalMonitor, InteractionSerialEvaluator
-from ding.torch_utils import to_tensor, to_ndarray
+from ding.torch_utils import to_tensor, to_ndarray, to_item
 from ding.utils import SERIAL_EVALUATOR_REGISTRY, import_module


@@ -66,15 +66,14 @@ def eval(
             - n_episode (:obj:`int`): Number of evaluation episodes.
         Returns:
             - stop_flag (:obj:`bool`): Whether this training program can be ended.
-            - return_info (:obj:`dict`): Current evaluation return information.
+            - episode_info (:obj:`dict`): Current evaluation return information.
         '''

         if n_episode is None:
             n_episode = self._default_n_episode
         assert n_episode is not None, "please indicate eval n_episode"
         envstep_count = 0
         info = {}
-        return_info = []
         eval_monitor = TradingEvalMonitor(self._env.env_num, n_episode)
         self._env.reset()
         self._policy.reset()

@@ -105,10 +104,8 @@ def eval(
                     # Env reset is done by env_manager automatically.
                     self._policy.reset([env_id])
                     reward = t.info['eval_episode_return']
-                    if 'episode_info' in t.info:
-                        eval_monitor.update_info(env_id, t.info['episode_info'])
+                    eval_monitor.update_info(env_id, t.info)
                     eval_monitor.update_reward(env_id, reward)
-                    return_info.append(t.info)

                     #========== only used by anytrading =======
                     if 'max_possible_profit' in t.info:

@@ -185,7 +182,8 @@ def eval(
                 "Current episode_return: {} is greater than stop_value: {}".format(episode_return, self._stop_value) +
                 ", so your RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details."
             )
-        return stop_flag, return_info
+        episode_info = to_item(episode_info)
+        return stop_flag, episode_info


 class TradingEvalMonitor(VectorEvalMonitor):
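
The new to_item import ensures the returned episode_info holds plain Python numbers rather than tensors or ndarrays before it leaves the evaluator. A rough sketch of what such a recursive conversion can look like (an illustration of the idea only, not the actual ding.torch_utils.to_item implementation):

import numpy as np
import torch

def to_item_sketch(data):
    # Recursively turn single-element tensors/ndarrays into Python scalars.
    if isinstance(data, dict):
        return {k: to_item_sketch(v) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(to_item_sketch(v) for v in data)
    if isinstance(data, (torch.Tensor, np.ndarray)):
        return data.item()  # assumes one element per tensor/array
    return data

episode_info = {'eval_episode_return': [torch.tensor(1.5), torch.tensor(2.0)]}
print(to_item_sketch(episode_info))  # {'eval_episode_return': [1.5, 2.0]}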
