Skip to content

Commit b09ffda

Browse files
committed
feature(pu): add load pretrained ckpt in serial_entry_onpolicy and serial_entry
1 parent d88ebe2 commit b09ffda

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

ding/entry/serial_entry.py

+9
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,15 @@ def serial_pipeline(
60 60
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
61 61
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
62 62

63+
# Load pretrained model if specified
64+
if cfg.policy.load_path is not None:
65+
logging.info(f'Loading model from {cfg.policy.load_path} begin...')
66+
if cfg.policy.cuda and torch.cuda.is_available():
67+
policy.learn_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cuda'))
68+
else:
69+
policy.learn_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cpu'))
70+
logging.info(f'Loading model from {cfg.policy.load_path} end!')
71+
63 72
# Create worker components: learner, collector, evaluator, replay buffer, commander.
64 73
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
65 74
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)

ding/entry/serial_entry_onpolicy.py

+9
Original file line number | Diff line number | Diff line change
@@ -58,6 +58,15 @@ def serial_pipeline_onpolicy(
58 58
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
59 59
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
60 60

61+
# Load pretrained model if specified
62+
if cfg.policy.load_path is not None:
63+
logging.info(f'Loading model from {cfg.policy.load_path} begin...')
64+
if cfg.policy.cuda and torch.cuda.is_available():
65+
policy.learn_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cuda'))
66+
else:
67+
policy.learn_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cpu'))
68+
logging.info(f'Loading model from {cfg.policy.load_path} end!')
69+
61 70
# Create worker components: learner, collector, evaluator, replay buffer, commander.
62 71
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
63 72
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)

0 commit comments

Comments
 (0)