Commit 9fede9b

polish(pu): polish resume_training in entry
1 parent f165c59 commit 9fede9b

11 files changed: +15 -21 lines
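The change is the same in all five entry files: resume_training used to be read with attribute access, which forced every user config to declare the field (hence the resume_training=False lines removed from the dizoo configs below). Reading it with .get and a False default keeps the behaviour identical while letting the key be omitted. A minimal sketch of the difference, assuming the compiled config is an easydict.EasyDict as elsewhere in DI-engine (the example values here are made up):

from easydict import EasyDict

learn_cfg = EasyDict(dict(batch_size=64, learning_rate=0.001))  # no 'resume_training' key

# Attribute access fails when the key is absent:
#     learn_cfg.resume_training  ->  AttributeError
# dict-style .get falls back to the default instead:
resume_training = learn_cfg.get('resume_training', False)  # -> False
renew_dir = not resume_training  # -> True, i.e. start with a fresh experiment dir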

ding/entry/serial_entry.py (+2 -2)

@@ -54,7 +54,7 @@ def serial_pipeline(
         auto=True,
         create_cfg=create_cfg,
         save_cfg=True,
-        renew_dir=not cfg.policy.learn.resume_training
+        renew_dir=not cfg.policy.learn.get('resume_training', False)
     )
     # Create main components: env, policy
     if env_setting is None:
@@ -94,7 +94,7 @@ def serial_pipeline(
     # ==========
     # Learner's before_run hook.
     learner.call_hook('before_run')
-    if cfg.policy.learn.resume_training:
+    if cfg.policy.learn.get('resume_training', False):
         collector.envstep = learner.collector_envstep

     # Accumulate plenty of data at the beginning of training.

ding/entry/serial_entry_mbrl.py (+6 -7)

@@ -37,7 +37,7 @@ def mbrl_entry_setup(
         auto=True,
         create_cfg=create_cfg,
         save_cfg=True,
-        renew_dir=not cfg.policy.learn.resume_training
+        renew_dir=not cfg.policy.learn.get('resume_training', False)
     )

     if env_setting is None:
@@ -79,8 +79,7 @@ def mbrl_entry_setup(
     )

     return (
-        cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger,
-        resume_training
+        cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger
     )


@@ -125,13 +124,13 @@ def serial_pipeline_dyna(
     Returns:
         - policy (:obj:`Policy`): Converged policy.
     """
-    cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger, resume_training = \
+    cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger = \
         mbrl_entry_setup(input_cfg, seed, env_setting, model)

     img_buffer = create_img_buffer(cfg, input_cfg, world_model, tb_logger)

     learner.call_hook('before_run')
-    if cfg.policy.learn.resume_training:
+    if cfg.policy.learn.get('resume_training', False):
         collector.envstep = learner.collector_envstep

     if cfg.policy.get('random_collect_size', 0) > 0:
@@ -200,11 +199,11 @@ def serial_pipeline_dream(
     Returns:
         - policy (:obj:`Policy`): Converged policy.
     """
-    cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger, resume_training = \
+    cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger = \
         mbrl_entry_setup(input_cfg, seed, env_setting, model)

     learner.call_hook('before_run')
-    if cfg.policy.learn.resume_training:
+    if cfg.policy.learn.get('resume_training', False):
         collector.envstep = learner.collector_envstep

     if cfg.policy.get('random_collect_size', 0) > 0:
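Besides the .get lookups, mbrl_entry_setup no longer returns a separate resume_training value, so serial_pipeline_dyna and serial_pipeline_dream unpack ten items instead of eleven and read the flag from the compiled config. A condensed sketch of how both callers now consume the helper (taken from the hunks above, not the full pipeline bodies):

cfg, policy, world_model, env_buffer, learner, collector, collector_env, \
    evaluator, commander, tb_logger = mbrl_entry_setup(input_cfg, seed, env_setting, model)

learner.call_hook('before_run')
if cfg.policy.learn.get('resume_training', False):
    # Carry over the env step count tracked by the learner (e.g. restored from a checkpoint).
    collector.envstep = learner.collector_envstep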

ding/entry/serial_entry_ngu.py (+2 -2)

@@ -54,7 +54,7 @@ def serial_pipeline_ngu(
         auto=True,
         create_cfg=create_cfg,
         save_cfg=True,
-        renew_dir=not cfg.policy.learn.resume_training
+        renew_dir=not cfg.policy.learn.get('resume_training', False)
     )
     # Create main components: env, policy
     if env_setting is None:
@@ -97,7 +97,7 @@ def serial_pipeline_ngu(
     # ==========
     # Learner's before_run hook.
     learner.call_hook('before_run')
-    if cfg.policy.learn.resume_training:
+    if cfg.policy.learn.get('resume_training', False):
         collector.envstep = learner.collector_envstep

     # Accumulate plenty of data at the beginning of training.

ding/entry/serial_entry_onpolicy.py (+3 -2)

@@ -52,8 +52,9 @@ def serial_pipeline_onpolicy(
         auto=True,
         create_cfg=create_cfg,
         save_cfg=True,
-        renew_dir=not cfg.policy.learn.resume_training
+        renew_dir=not cfg.policy.learn.get('resume_training', False)
     )
+
     # Create main components: env, policy
     if env_setting is None:
         env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
@@ -88,7 +89,7 @@ def serial_pipeline_onpolicy(
     # ==========
     # Learner's before_run hook.
     learner.call_hook('before_run')
-    if cfg.policy.learn.resume_training:
+    if cfg.policy.learn.get('resume_training', False):
         collector.envstep = learner.collector_envstep

     while True:

ding/entry/serial_entry_onpolicy_ppg.py (+2 -2)

@@ -52,7 +52,7 @@ def serial_pipeline_onpolicy_ppg(
         auto=True,
         create_cfg=create_cfg,
         save_cfg=True,
-        renew_dir=not cfg.policy.learn.resume_training
+        renew_dir=not cfg.policy.learn.get('resume_training', False)
     )
     # Create main components: env, policy
     if env_setting is None:
@@ -88,7 +88,7 @@ def serial_pipeline_onpolicy_ppg(
     # ==========
     # Learner's before_run hook.
     learner.call_hook('before_run')
-    if cfg.policy.learn.resume_training:
+    if cfg.policy.learn.get('resume_training', False):
         collector.envstep = learner.collector_envstep

     while True:

dizoo/classic_control/cartpole/config/cartpole_a2c_config.py (-1)

@@ -21,7 +21,6 @@
             learning_rate=0.001,
             # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
             entropy_weight=0.01,
-            resume_training=False,
         ),
         collect=dict(
             # (int) collect n_sample data, train model n_iteration times

dizoo/classic_control/cartpole/config/cartpole_pg_config.py (-1)

@@ -18,7 +18,6 @@
             batch_size=64,
             learning_rate=0.001,
             entropy_weight=0.001,
-            resume_training=False,
         ),
         collect=dict(n_episode=80, unroll_len=1, discount_factor=0.9),
         eval=dict(evaluator=dict(eval_freq=100, ), ),

dizoo/classic_control/cartpole/config/cartpole_ppo_stdim_config.py (-1)

@@ -37,7 +37,6 @@
             entropy_weight=0.01,
             clip_ratio=0.2,
             learner=dict(hook=dict(save_ckpt_after_iter=100)),
-            resume_training=False,
         ),
         collect=dict(
             n_sample=256,

dizoo/classic_control/cartpole/config/cartpole_ppopg_config.py (-1)

@@ -19,7 +19,6 @@
             batch_size=64,
             learning_rate=0.001,
             entropy_weight=0.001,
-            resume_training=False,
         ),
         collect=dict(n_episode=80, unroll_len=1, discount_factor=0.9, collector=dict(get_train_sample=True)),
         eval=dict(evaluator=dict(eval_freq=100, ), ),

dizoo/classic_control/pendulum/config/pendulum_ppo_config.py (-1)

@@ -35,7 +35,6 @@
             adv_norm=True,
             value_norm=True,
             ignore_done=True,
-            resume_training=False,
         ),
         collect=dict(
             n_sample=5000,

dizoo/petting_zoo/config/ptz_simple_spread_mappo_config.py (-1)

@@ -53,7 +53,6 @@
             grad_clip_type='clip_norm',
             grad_clip_value=10,
             ignore_done=False,
-            resume_training=False,
         ),
         collect=dict(
             n_sample=3200,
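With the default handled in the entry functions, the explicit resume_training=False entries in these example configs are redundant and are removed. A run that does want to resume can still opt in by setting the flag on its config; a hypothetical snippet (assuming the module exports cartpole_a2c_config wrapped in EasyDict, as is the usual dizoo convention):

from dizoo.classic_control.cartpole.config.cartpole_a2c_config import cartpole_a2c_config

# Opt-in resume: only runs that actually resume need the flag; omitting it means False.
cartpole_a2c_config.policy.learn.resume_training = True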
