
Commit 370506f

Authored by puyuan1996, with ‘whl’, dyyoungg, puyuan, and PaParaZz1
feature(whl/pu): add jericho environments and related configs (#307)
* init commit
* add bert encoding
* debug
* debug and polish env
* polish config
* update
* polish(pu): polish jericho config and pipeline
* polish(pu): polish HFLanguageRepresentationNetwork
* fix(pu): fix padded action_mask
* sync code
* feature(pu): vllm server for embedding model inference
* fix(pu): fix update_per_collect in ddp train
* fix(pu): fix train_entry_time to the same in ddp train
* polish(pu): polish jericho uz config
* feature(pu): add remove_stuck_actions option
* polish(pu): set projection layer train after language embed model
* polish(pu): polish jericho_env and add jericho_ppo_config
* sync code
* fix(pu): fix gradient accumulation_steps option
* fix(pu): fix action_mask all-zero bug
* fix(pu): fix lr target_model_update bug when accumulation_steps>1
* polish(pu): polish jericho env and its unizero/ppo config
* polish(pu): polish jericho unizero config
* polish(pu): polish jericho unizero config
* polish(pu): polish comments
* polish(pu): polish jericho configs
* polish(pu): polish jericho configs

Co-authored-by: ‘whl’ <‘[email protected]’>
Co-authored-by: dyyoungg <[email protected]>
Co-authored-by: puyuan <[email protected]>
Co-authored-by: PaParaZz1 <[email protected]>
1 parent a24a54d commit 370506f

18 files changed: +1199 −87 lines
File renamed without changes.

Diff for: lzero/entry/train_unizero.py (+90 −54)
@@ -21,6 +21,8 @@
 from lzero.worker import MuZeroEvaluator as Evaluator
 from lzero.worker import MuZeroCollector as Collector
 from .utils import random_collect, calculate_update_per_collect
+import torch.distributed as dist
+from ding.utils import set_pkg_seed, get_rank, get_world_size
 
 
 def train_unizero(
@@ -33,167 +35,201 @@ def train_unizero(
 ) -> 'Policy':
     """
     Overview:
-        The train entry for UniZero, proposed in our paper UniZero: Generalized and Efficient Planning with Scalable Latent World Models.
+        This function serves as the training entry point for UniZero, as proposed in our paper "UniZero: Generalized and Efficient Planning with Scalable Latent World Models".
         UniZero aims to enhance the planning capabilities of reinforcement learning agents by addressing the limitations found in MuZero-style algorithms,
-        particularly in environments requiring the capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667.
+        particularly in environments that require capturing long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667.
+
     Arguments:
-        - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
-            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
-        - seed (:obj:`int`): Random seed.
-        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
-        - model_path (:obj:`Optional[str]`): The pretrained model path, which should
-            point to the ckpt file of the pretrained model, and an absolute path is recommended.
-            In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
-        - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
-        - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+        - input_cfg (:obj:`Tuple[dict, dict]`): Configuration in dictionary format.
+            ``Tuple[dict, dict]`` indicates [user_config, create_cfg].
+        - seed (:obj:`int`): Random seed for reproducibility.
+        - model (:obj:`Optional[torch.nn.Module]`): Instance of a PyTorch model.
+        - model_path (:obj:`Optional[str]`): Path to the pretrained model, which should
+            point to the checkpoint file of the pretrained model. An absolute path is recommended.
+            In LightZero, the path typically resembles ``exp_name/ckpt/ckpt_best.pth.tar``.
+        - max_train_iter (:obj:`Optional[int]`): Maximum number of policy update iterations during training.
+        - max_env_step (:obj:`Optional[int]`): Maximum number of environment interaction steps to collect.
+
     Returns:
-        - policy (:obj:`Policy`): Converged policy.
+        - policy (:obj:`Policy`): The converged policy after training.
     """
 
     cfg, create_cfg = input_cfg
 
     # Ensure the specified policy type is supported
-    assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "train_unizero entry now only supports the following algo.: 'unizero', 'sampled_unizero'"
+    assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "train_unizero only supports the following algorithms: 'unizero', 'sampled_unizero'"
+    logging.info(f"Using policy type: {create_cfg.policy.type}")
 
-    # Import the correct GameBuffer class based on the policy type
+    # Import the appropriate GameBuffer class based on the policy type
     game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'}
-
     GameBuffer = getattr(__import__('lzero.mcts', fromlist=[game_buffer_classes[create_cfg.policy.type]]),
                          game_buffer_classes[create_cfg.policy.type])
 
-    # Set device based on CUDA availability
+    # Check for GPU availability and set the device accordingly
     cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu'
-    logging.info(f'cfg.policy.device: {cfg.policy.device}')
+    logging.info(f"Device set to: {cfg.policy.device}")
 
-    # Compile the configuration
+    # Compile the configuration file
     cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
 
-    # Create main components: env, policy
+    # Create environment manager
     env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
     collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
     evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
 
+    # Initialize environment and random seed
     collector_env.seed(cfg.seed)
     evaluator_env.seed(cfg.seed, dynamic_seed=False)
     set_pkg_seed(cfg.seed, use_cuda=torch.cuda.is_available())
 
+    # Initialize wandb if specified
     if cfg.policy.use_wandb:
-        # Initialize wandb
+        logging.info("Initializing wandb...")
         wandb.init(
             project="LightZero",
             config=cfg,
             sync_tensorboard=False,
             monitor_gym=False,
             save_code=True,
         )
+        logging.info("wandb initialization completed!")
 
+    # Create policy
+    logging.info("Creating policy...")
     policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
+    logging.info("Policy created successfully!")
 
     # Load pretrained model if specified
     if model_path is not None:
-        logging.info(f'Loading model from {model_path} begin...')
+        logging.info(f"Loading pretrained model from {model_path}...")
         policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
-        logging.info(f'Loading model from {model_path} end!')
+        logging.info("Pretrained model loaded successfully!")
 
-    # Create worker components: learner, collector, evaluator, replay buffer, commander
+    # Create core components for training
     tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
     learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
-
-    # MCTS+RL algorithms related core code
-    policy_config = cfg.policy
-    replay_buffer = GameBuffer(policy_config)
+    replay_buffer = GameBuffer(cfg.policy)
     collector = Collector(env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, exp_name=cfg.exp_name,
-                          policy_config=policy_config)
+                          policy_config=cfg.policy)
     evaluator = Evaluator(eval_freq=cfg.policy.eval_freq, n_evaluator_episode=cfg.env.n_evaluator_episode,
                           stop_value=cfg.env.stop_value, env=evaluator_env, policy=policy.eval_mode,
-                          tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=policy_config)
+                          tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=cfg.policy)
 
-    # Learner's before_run hook
+    # Execute the learner's before_run hook
     learner.call_hook('before_run')
-    if policy_config.use_wandb:
+
+    if cfg.policy.use_wandb:
         policy.set_train_iter_env_step(learner.train_iter, collector.envstep)
 
-    # Collect random data before training
+    # Randomly collect data if specified
    if cfg.policy.random_collect_episode_num > 0:
+        logging.info("Collecting random data...")
        random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
+        logging.info("Random data collection completed!")
 
     batch_size = policy._cfg.batch_size
 
+    if cfg.policy.multi_gpu:
+        # Get current world size and rank
+        world_size = get_world_size()
+        rank = get_rank()
+    else:
+        world_size = 1
+        rank = 0
+
     while True:
-        # Log buffer memory usage
+        # Log memory usage of the replay buffer
         log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
 
-        # Set temperature for visit count distributions
+        # Set temperature parameter for data collection
         collect_kwargs = {
             'temperature': visit_count_temperature(
-                policy_config.manual_temperature_decay,
-                policy_config.fixed_temperature_value,
-                policy_config.threshold_training_steps_for_final_temperature,
+                cfg.policy.manual_temperature_decay,
+                cfg.policy.fixed_temperature_value,
+                cfg.policy.threshold_training_steps_for_final_temperature,
                 trained_steps=learner.train_iter
             ),
             'epsilon': 0.0 # Default epsilon value
         }
 
-        # Configure epsilon for epsilon-greedy exploration
-        if policy_config.eps.eps_greedy_exploration_in_collect:
+        # Configure epsilon-greedy exploration
+        if cfg.policy.eps.eps_greedy_exploration_in_collect:
             epsilon_greedy_fn = get_epsilon_greedy_fn(
-                start=policy_config.eps.start,
-                end=policy_config.eps.end,
-                decay=policy_config.eps.decay,
-                type_=policy_config.eps.type
+                start=cfg.policy.eps.start,
+                end=cfg.policy.eps.end,
+                decay=cfg.policy.eps.decay,
+                type_=cfg.policy.eps.type
            )
             collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
 
         # Evaluate policy performance
         if evaluator.should_eval(learner.train_iter):
+            logging.info(f"Training iteration {learner.train_iter}: Starting evaluation...")
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+            logging.info(f"Training iteration {learner.train_iter}: Evaluation completed, stop condition: {stop}, current reward: {reward}")
            if stop:
+                logging.info("Stopping condition met, training ends!")
                break
 
         # Collect new data
         new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+        logging.info(f"Rank {rank}, Training iteration {learner.train_iter}: New data collection completed!")
 
         # Determine updates per collection
-        update_per_collect = calculate_update_per_collect(cfg, new_data)
+        update_per_collect = cfg.policy.update_per_collect
+        if update_per_collect is None:
+            update_per_collect = calculate_update_per_collect(cfg, new_data, world_size)
 
         # Update replay buffer
         replay_buffer.push_game_segments(new_data)
         replay_buffer.remove_oldest_data_to_fit()
 
-        # Train the policy if sufficient data is available
+        if world_size > 1:
+            # Synchronize all ranks before training
+            try:
+                dist.barrier()
+            except Exception as e:
+                logging.error(f'Rank {rank}: Synchronization barrier failed, error: {e}')
+                break
+
+        # Check if there is sufficient data for training
        if collector.envstep > cfg.policy.train_start_after_envsteps:
            if cfg.policy.sample_type == 'episode':
                data_sufficient = replay_buffer.get_num_of_game_segments() > batch_size
            else:
                data_sufficient = replay_buffer.get_num_of_transitions() > batch_size
+
            if not data_sufficient:
                logging.warning(
-                    f'The data in replay_buffer is not sufficient to sample a mini-batch: '
+                    f'Rank {rank}: The data in replay_buffer is not sufficient to sample a mini-batch: '
                    f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continue to collect now ....'
                )
                continue
 
+            # Execute multiple training rounds
            for i in range(update_per_collect):
                train_data = replay_buffer.sample(batch_size, policy)
                if cfg.policy.reanalyze_ratio > 0 and i % 20 == 0:
-                    # Clear caches and precompute positional embedding matrices
-                    policy.recompute_pos_emb_diff_and_clear_cache() # TODO
-
-                if policy_config.use_wandb:
+                    policy.recompute_pos_emb_diff_and_clear_cache()
+
+                if cfg.policy.use_wandb:
                    policy.set_train_iter_env_step(learner.train_iter, collector.envstep)
 
-                train_data.append({'train_which_component': 'transformer'})
-                log_vars = learner.train(train_data, collector.envstep)
+                train_data.append(learner.train_iter)
 
+                log_vars = learner.train(train_data, collector.envstep)
                if cfg.policy.use_priority:
                    replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
 
        policy.recompute_pos_emb_diff_and_clear_cache()
 
        # Check stopping criteria
        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+            logging.info("Stopping condition met, training ends!")
            break
 
     learner.call_hook('after_run')
-    wandb.finish()
-    return policy
+    if cfg.policy.use_wandb:
+        wandb.finish()
+    logging.info("===== Training Completed =====")
+    return policy
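
For context, this entry point is normally driven from a zoo config script. The following is a hedged usage sketch, not part of this commit: the config import path is an assumption for illustration, and any zoo config that exposes main_config and create_config (such as the Jericho UniZero configs added in this PR) is used the same way.

from lzero.entry import train_unizero
# Assumed config module path, shown only for illustration.
from zoo.atari.config.atari_unizero_config import main_config, create_config

if __name__ == "__main__":
    train_unizero(
        (main_config, create_config),  # [user_config, create_cfg]
        seed=0,
        model_path=None,  # or an absolute path such as 'exp_name/ckpt/ckpt_best.pth.tar'
        max_env_step=int(1e6),
    )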

Diff for: lzero/entry/train_unizero_segment.py (+19 −3)
@@ -4,13 +4,14 @@
 from typing import Tuple, Optional
 
 import torch
+import wandb
 from ding.config import compile_config
 from ding.envs import create_env_manager
 from ding.envs import get_vec_env_setting
 from ding.policy import create_policy
 from ding.rl_utils import get_epsilon_greedy_fn
 from ding.utils import EasyTimer
-from ding.utils import set_pkg_seed, get_rank
+from ding.utils import set_pkg_seed, get_rank, get_world_size
 from ding.worker import BaseLearner
 from tensorboardX import SummaryWriter
 from torch.utils.tensorboard import SummaryWriter
@@ -103,6 +104,9 @@ def train_unizero_segment(
     # Learner's before_run hook
     learner.call_hook('before_run')
 
+    if cfg.policy.use_wandb:
+        policy.set_train_iter_env_step(learner.train_iter, collector.envstep)
+
     # Collect random data before training
     if cfg.policy.random_collect_episode_num > 0:
         random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
@@ -116,6 +120,14 @@ def train_unizero_segment(
     train_epoch = 0
     reanalyze_batch_size = cfg.policy.reanalyze_batch_size
 
+    if cfg.policy.multi_gpu:
+        # Get current world size and rank
+        world_size = get_world_size()
+        rank = get_rank()
+    else:
+        world_size = 1
+        rank = 0
+
     while True:
         # Log buffer memory usage
         log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
@@ -151,7 +163,7 @@ def train_unizero_segment(
         new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
 
         # Determine updates per collection
-        update_per_collect = calculate_update_per_collect(cfg, new_data)
+        update_per_collect = calculate_update_per_collect(cfg, new_data, world_size)
 
         # Update replay buffer
         replay_buffer.push_game_segments(new_data)
@@ -196,8 +208,10 @@ def train_unizero_segment(
                 logging.info(f'Buffer reanalyze time: {timer.value}')
 
             train_data = replay_buffer.sample(batch_size, policy)
+            if cfg.policy.use_wandb:
+                policy.set_train_iter_env_step(learner.train_iter, collector.envstep)
 
-            train_data.append({'train_which_component': 'transformer'})
+            train_data.append(learner.train_iter)
             log_vars = learner.train(train_data, collector.envstep)
 
             if cfg.policy.use_priority:
@@ -211,4 +225,6 @@ def train_unizero_segment(
             break
 
     learner.call_hook('after_run')
+    if cfg.policy.use_wandb:
+        wandb.finish()
     return policy
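
Both entry points now resolve world_size and rank only when cfg.policy.multi_gpu is set, falling back to single-process defaults otherwise. Below is a minimal standalone sketch of the same guard written directly against torch.distributed; the helper name is illustrative, not part of the patch.

import torch.distributed as dist

def get_dist_info(multi_gpu: bool) -> tuple:
    # Return (world_size, rank); fall back to single-process values when DDP
    # is disabled or the default process group has not been initialized.
    if multi_gpu and dist.is_available() and dist.is_initialized():
        return dist.get_world_size(), dist.get_rank()
    return 1, 0

# Example: without an initialized process group this yields (1, 0).
print(get_dist_info(multi_gpu=False))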

Diff for: lzero/entry/utils.py (+5 −4)
@@ -26,7 +26,7 @@ def ddp_synchronize():
     if is_ddp_enabled():
         dist.barrier()
 
-def ddp_all_reduce_sum(tensor):
+def ddp_all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
     """
     Perform an all-reduce operation (sum) on the given tensor across
     all processes in DDP mode. Returns the reduced tensor.
@@ -41,15 +41,16 @@ def ddp_all_reduce_sum(tensor):
         dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
     return tensor
 
-def calculate_update_per_collect(cfg, new_data):
+def calculate_update_per_collect(cfg: 'EasyDict', new_data: List[List[torch.Tensor]], world_size: int = 1) -> int:
     """
     Calculate the number of updates to perform per data collection in a
     Distributed Data Parallel (DDP) setting. This ensures that all GPUs
     compute the same `update_per_collect` value, synchronized across processes.
 
     Arguments:
         - cfg: Configuration object containing policy settings.
-        - new_data (list): The newly collected data segments.
+        - new_data (List[List[torch.Tensor]]): The newly collected data segments.
+        - world_size (int): The total number of processes.
 
     Returns:
         - int: The number of updates to perform per collection.
@@ -68,7 +69,7 @@ def calculate_update_per_collect(cfg, new_data):
         for game_segment in new_data[0]
     )
 
-    if torch.cuda.is_available():
+    if torch.cuda.is_available() and world_size > 1:
         # Convert the collected transitions count to a GPU tensor for DDP operations.
         collected_transitions_tensor = torch.tensor(
             collected_transitions_num, dtype=torch.int64, device='cuda'
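
The hunks above show only part of calculate_update_per_collect. The sketch below outlines the rough shape of the synchronized computation; treat the final formula (scaling the summed transition count by a replay_ratio setting) as an assumption, since it lies outside the visible diff.

import torch
import torch.distributed as dist

def synced_update_per_collect(collected_transitions_num: int, replay_ratio: float,
                              world_size: int = 1) -> int:
    # In multi-GPU runs, all-reduce the per-rank transition counts so that every
    # rank derives the same update_per_collect value.
    if torch.cuda.is_available() and world_size > 1:
        t = torch.tensor(collected_transitions_num, dtype=torch.int64, device='cuda')
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
        collected_transitions_num = int(t.item())
    return int(collected_transitions_num * replay_ratio)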

Diff for: lzero/mcts/buffer/game_buffer_muzero.py (+2 −2)
@@ -711,7 +711,7 @@ def _compute_target_policy_non_reanalyzed(
                 ]
             else:
                 legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
-
+
         with torch.no_grad():
             policy_index = 0
             # 0 -> Invalid target policy for padding outside of game segments,
@@ -730,7 +730,7 @@ def _compute_target_policy_non_reanalyzed(
                     # for atari/classic_control/box2d environments that only have one player.
                     target_policies.append(distributions)
                 else:
-                    # for board games that have two players.
+                    # for board games that have two players or envs that have varied action space.
                     policy_tmp = [0 for _ in range(policy_shape)]
                     for index, legal_action in enumerate(legal_actions[policy_index]):
                         # only the action in ``legal_action`` the policy logits is nonzero
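
The comment change above reflects that environments with a padded, variable-size action space (such as the Jericho text games added in this PR) are handled like board games: the visit-count distribution over legal actions is scattered back onto the full action space. A minimal illustrative sketch (helper name is hypothetical):

from typing import List

def expand_to_full_action_space(distribution: List[float], action_mask: List[int]) -> List[float]:
    # Indices where the (possibly padded) action mask is 1 are the legal actions.
    legal_actions = [i for i, x in enumerate(action_mask) if x == 1]
    policy_tmp = [0.0] * len(action_mask)
    for index, legal_action in enumerate(legal_actions):
        # Only legal actions receive probability mass; padded actions stay at zero.
        policy_tmp[legal_action] = distribution[index]
    return policy_tmp

# Example: two legal actions inside a padded action space of size four.
print(expand_to_full_action_space([0.7, 0.3], [1, 0, 1, 0]))  # [0.7, 0.0, 0.3, 0.0]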
