@@ -37,7 +37,7 @@ def mbrl_entry_setup(
37
37
auto = True ,
38
38
create_cfg = create_cfg ,
39
39
save_cfg = True ,
40
- renew_dir = not cfg .policy .learn .resume_training
40
+ renew_dir = not cfg .policy .learn .get ( 'resume_training' , False )
41
41
)
42
42
43
43
if env_setting is None :
@@ -79,8 +79,7 @@ def mbrl_entry_setup(
79
79
)
80
80
81
81
return (
82
- cfg , policy , world_model , env_buffer , learner , collector , collector_env , evaluator , commander , tb_logger ,
83
- resume_training
82
+ cfg , policy , world_model , env_buffer , learner , collector , collector_env , evaluator , commander , tb_logger
84
83
)
85
84
86
85
@@ -125,13 +124,13 @@ def serial_pipeline_dyna(
125
124
Returns:
126
125
- policy (:obj:`Policy`): Converged policy.
127
126
"""
128
- cfg , policy , world_model , env_buffer , learner , collector , collector_env , evaluator , commander , tb_logger , resume_training = \
127
+ cfg , policy , world_model , env_buffer , learner , collector , collector_env , evaluator , commander , tb_logger = \
129
128
mbrl_entry_setup (input_cfg , seed , env_setting , model )
130
129
131
130
img_buffer = create_img_buffer (cfg , input_cfg , world_model , tb_logger )
132
131
133
132
learner .call_hook ('before_run' )
134
- if cfg .policy .learn .resume_training :
133
+ if cfg .policy .learn .get ( 'resume_training' , False ) :
135
134
collector .envstep = learner .collector_envstep
136
135
137
136
if cfg .policy .get ('random_collect_size' , 0 ) > 0 :
@@ -200,11 +199,11 @@ def serial_pipeline_dream(
200
199
Returns:
201
200
- policy (:obj:`Policy`): Converged policy.
202
201
"""
203
- cfg , policy , world_model , env_buffer , learner , collector , collector_env , evaluator , commander , tb_logger , resume_training = \
202
+ cfg , policy , world_model , env_buffer , learner , collector , collector_env , evaluator , commander , tb_logger = \
204
203
mbrl_entry_setup (input_cfg , seed , env_setting , model )
205
204
206
205
learner .call_hook ('before_run' )
207
- if cfg .policy .learn .resume_training :
206
+ if cfg .policy .learn .get ( 'resume_training' , False ) :
208
207
collector .envstep = learner .collector_envstep
209
208
210
209
if cfg .policy .get ('random_collect_size' , 0 ) > 0 :
0 commit comments