@@ -6,7 +6,7 @@
 
 from ding.envs import BaseEnvManager
 from ding.worker import VectorEvalMonitor, InteractionSerialEvaluator
-from ding.torch_utils import to_tensor, to_ndarray
+from ding.torch_utils import to_tensor, to_ndarray, to_item
 from ding.utils import SERIAL_EVALUATOR_REGISTRY, import_module
 
 
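The newly imported `to_item` is applied to `episode_info` right before `eval` returns (last hunk below). A minimal sketch of what such a helper does, assuming it recursively converts single-element tensors and ndarrays into plain Python scalars; the actual `ding.torch_utils.to_item` may cover more cases:

```python
import numpy as np
import torch

def to_item_sketch(data):
    """Hypothetical stand-in for ding.torch_utils.to_item.

    Recursively walks dicts/lists/tuples and converts single-element
    torch.Tensor / np.ndarray values into plain Python scalars, so the
    result can be logged or JSON-serialized safely.
    """
    if isinstance(data, dict):
        return {k: to_item_sketch(v) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        # Note: plain list/tuple only; subclasses like namedtuple would need care.
        return type(data)(to_item_sketch(v) for v in data)
    if isinstance(data, (torch.Tensor, np.ndarray)):
        return data.item()  # raises if the tensor/array is not single-element
    return data
```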
@@ -66,15 +66,14 @@ def eval(
             - n_episode (:obj:`int`): Number of evaluation episodes.
         Returns:
             - stop_flag (:obj:`bool`): Whether this training program can be ended.
-            - return_info (:obj:`dict`): Current evaluation return information.
+            - episode_info (:obj:`dict`): Current evaluation return information.
         '''
 
         if n_episode is None:
             n_episode = self._default_n_episode
         assert n_episode is not None, "please indicate eval n_episode"
         envstep_count = 0
         info = {}
-        return_info = []
         eval_monitor = TradingEvalMonitor(self._env.env_num, n_episode)
         self._env.reset()
         self._policy.reset()
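This hunk drops the hand-rolled `return_info` list, leaving `TradingEvalMonitor` as the single collector of per-episode results. A minimal sketch of that monitor pattern, assuming per-env reward/info lists and an episode-count stop condition (a hypothetical simplification, not the actual `VectorEvalMonitor` API):

```python
from collections import defaultdict

class EvalMonitorSketch:
    """Hypothetical simplification of the vectorized eval-monitor pattern:
    one monitor aggregates rewards and info dicts from several envs until
    n_episode episodes have finished overall."""

    def __init__(self, env_num: int, n_episode: int) -> None:
        self._env_num = env_num        # kept only to mirror the signature
        self._n_episode = n_episode
        self._rewards = defaultdict(list)  # env_id -> list of episode returns
        self._infos = defaultdict(list)    # env_id -> list of episode info dicts

    def update_reward(self, env_id: int, reward: float) -> None:
        self._rewards[env_id].append(reward)

    def update_info(self, env_id: int, info: dict) -> None:
        self._infos[env_id].append(info)

    def is_finished(self) -> bool:
        return sum(len(v) for v in self._rewards.values()) >= self._n_episode
```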
@@ -105,10 +104,8 @@ def eval(
                         # Env reset is done by env_manager automatically.
                         self._policy.reset([env_id])
                         reward = t.info['eval_episode_return']
-                        if 'episode_info' in t.info:
-                            eval_monitor.update_info(env_id, t.info['episode_info'])
+                        eval_monitor.update_info(env_id, t.info)
                         eval_monitor.update_reward(env_id, reward)
-                        return_info.append(t.info)
 
                         #========== only used by anytrading =======
                         if 'max_possible_profit' in t.info:
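With this hunk, `update_info` receives the complete `t.info` dict instead of only its optional `'episode_info'` sub-dict, so trading-specific keys such as `max_possible_profit` now reach the monitor as well. A toy usage of the sketch class from the previous example (values are made up; only the two keys visible in this diff are certain):

```python
# Hypothetical per-episode info dict from the anytrading env.
t_info = {'eval_episode_return': 12.5, 'max_possible_profit': 30.0}

monitor = EvalMonitorSketch(env_num=1, n_episode=1)
monitor.update_info(0, t_info)  # new behavior: the whole dict is forwarded
monitor.update_reward(0, t_info['eval_episode_return'])
assert monitor.is_finished()
```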
@@ -185,7 +182,8 @@ def eval(
                 "Current episode_return: {} is greater than stop_value: {}".format(episode_return, self._stop_value) +
                 ", so your RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details."
             )
-        return stop_flag, return_info
+        episode_info = to_item(episode_info)
+        return stop_flag, episode_info
 
 
 class TradingEvalMonitor(VectorEvalMonitor):
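Because the returned `episode_info` has passed through `to_item` (it is presumably gathered from the eval monitor earlier in `eval`, outside the hunks shown), callers receive only built-in Python types. A hypothetical downstream use; the stand-in dict below only illustrates the shape:

```python
import json

# episode_info as returned by eval() after to_item: plain scalars only,
# so JSON serialization and text logging work without tensors leaking out.
episode_info = {'train_iter': 1000, 'eval_episode_return': 12.5}  # stand-in
print(json.dumps(episode_info))
```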