from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.worker import MuZeroCollector as Collector
from .utils import random_collect, calculate_update_per_collect
+import torch.distributed as dist
+from ding.utils import set_pkg_seed, get_rank, get_world_size


def train_unizero(
@@ -33,167 +35,201 @@ def train_unizero(
) -> 'Policy':
    """
    Overview:
-        The train entry for UniZero, proposed in our paper UniZero: Generalized and Efficient Planning with Scalable Latent World Models.
+        This function serves as the training entry point for UniZero, as proposed in our paper "UniZero: Generalized and Efficient Planning with Scalable Latent World Models".
        UniZero aims to enhance the planning capabilities of reinforcement learning agents by addressing the limitations found in MuZero-style algorithms,
-        particularly in environments requiring the capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667.
+        particularly in environments that require capturing long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667.
+
    Arguments:
-        - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
-            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
-        - seed (:obj:`int`): Random seed.
-        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
-        - model_path (:obj:`Optional[str]`): The pretrained model path, which should
-            point to the ckpt file of the pretrained model, and an absolute path is recommended.
-            In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
-        - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
-        - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+        - input_cfg (:obj:`Tuple[dict, dict]`): Configuration in dictionary format.
+            ``Tuple[dict, dict]`` indicates [user_config, create_cfg].
+        - seed (:obj:`int`): Random seed for reproducibility.
+        - model (:obj:`Optional[torch.nn.Module]`): Instance of a PyTorch model.
+        - model_path (:obj:`Optional[str]`): Path to the pretrained model, which should
+            point to the checkpoint file of the pretrained model. An absolute path is recommended.
+            In LightZero, the path typically resembles ``exp_name/ckpt/ckpt_best.pth.tar``.
+        - max_train_iter (:obj:`Optional[int]`): Maximum number of policy update iterations during training.
+        - max_env_step (:obj:`Optional[int]`): Maximum number of environment interaction steps to collect.
+
    Returns:
-        - policy (:obj:`Policy`): Converged policy.
+        - policy (:obj:`Policy`): The converged policy after training.
    """

    cfg, create_cfg = input_cfg

    # Ensure the specified policy type is supported
-    assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "train_unizero entry now only supports the following algo.: 'unizero', 'sampled_unizero'"
+    assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "train_unizero only supports the following algorithms: 'unizero', 'sampled_unizero'"
+    logging.info(f"Using policy type: {create_cfg.policy.type}")

-    # Import the correct GameBuffer class based on the policy type
+    # Import the appropriate GameBuffer class based on the policy type
    game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'}
-
    GameBuffer = getattr(__import__('lzero.mcts', fromlist=[game_buffer_classes[create_cfg.policy.type]]),
                         game_buffer_classes[create_cfg.policy.type])

-    # Set device based on CUDA availability
+    # Check for GPU availability and set the device accordingly
    cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu'
-    logging.info(f'cfg.policy.device: {cfg.policy.device}')
+    logging.info(f"Device set to: {cfg.policy.device}")

-    # Compile the configuration
+    # Compile the configuration file
    cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)

-    # Create main components: env, policy
+    # Create environment manager
    env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
    collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])

+    # Initialize environment and random seed
    collector_env.seed(cfg.seed)
    evaluator_env.seed(cfg.seed, dynamic_seed=False)
    set_pkg_seed(cfg.seed, use_cuda=torch.cuda.is_available())

+    # Initialize wandb if specified
    if cfg.policy.use_wandb:
-        # Initialize wandb
+        logging.info("Initializing wandb...")
        wandb.init(
            project="LightZero",
            config=cfg,
            sync_tensorboard=False,
            monitor_gym=False,
            save_code=True,
        )
+        logging.info("wandb initialization completed!")

+    # Create policy
+    logging.info("Creating policy...")
    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
+    logging.info("Policy created successfully!")

    # Load pretrained model if specified
    if model_path is not None:
-        logging.info(f'Loading model from {model_path} begin...')
+        logging.info(f"Loading pretrained model from {model_path}...")
        policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
-        logging.info(f'Loading model from {model_path} end!')
+        logging.info("Pretrained model loaded successfully!")

-    # Create worker components: learner, collector, evaluator, replay buffer, commander
+    # Create core components for training
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
-
-    # MCTS+RL algorithms related core code
-    policy_config = cfg.policy
-    replay_buffer = GameBuffer(policy_config)
+    replay_buffer = GameBuffer(cfg.policy)
    collector = Collector(env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, exp_name=cfg.exp_name,
-                          policy_config=policy_config)
+                          policy_config=cfg.policy)
    evaluator = Evaluator(eval_freq=cfg.policy.eval_freq, n_evaluator_episode=cfg.env.n_evaluator_episode,
                          stop_value=cfg.env.stop_value, env=evaluator_env, policy=policy.eval_mode,
-                          tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=policy_config)
+                          tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=cfg.policy)

-    # Learner's before_run hook
+    # Execute the learner's before_run hook
    learner.call_hook('before_run')
-    if policy_config.use_wandb:
+
+    if cfg.policy.use_wandb:
        policy.set_train_iter_env_step(learner.train_iter, collector.envstep)

-    # Collect random data before training
+    # Randomly collect data if specified
    if cfg.policy.random_collect_episode_num > 0:
+        logging.info("Collecting random data...")
        random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
+        logging.info("Random data collection completed!")

    batch_size = policy._cfg.batch_size

+    if cfg.policy.multi_gpu:
+        # Get current world size and rank
+        world_size = get_world_size()
+        rank = get_rank()
+    else:
+        world_size = 1
+        rank = 0
+
    while True:
-        # Log buffer memory usage
+        # Log memory usage of the replay buffer
        log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)

-        # Set temperature for visit count distributions
+        # Set temperature parameter for data collection
        collect_kwargs = {
            'temperature': visit_count_temperature(
-                policy_config.manual_temperature_decay,
-                policy_config.fixed_temperature_value,
-                policy_config.threshold_training_steps_for_final_temperature,
+                cfg.policy.manual_temperature_decay,
+                cfg.policy.fixed_temperature_value,
+                cfg.policy.threshold_training_steps_for_final_temperature,
                trained_steps=learner.train_iter
            ),
            'epsilon': 0.0  # Default epsilon value
        }

-        # Configure epsilon for epsilon-greedy exploration
-        if policy_config.eps.eps_greedy_exploration_in_collect:
+        # Configure epsilon-greedy exploration
+        if cfg.policy.eps.eps_greedy_exploration_in_collect:
            epsilon_greedy_fn = get_epsilon_greedy_fn(
-                start=policy_config.eps.start,
-                end=policy_config.eps.end,
-                decay=policy_config.eps.decay,
-                type_=policy_config.eps.type
+                start=cfg.policy.eps.start,
+                end=cfg.policy.eps.end,
+                decay=cfg.policy.eps.decay,
+                type_=cfg.policy.eps.type
            )
            collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)

        # Evaluate policy performance
        if evaluator.should_eval(learner.train_iter):
+            logging.info(f"Training iteration {learner.train_iter}: Starting evaluation...")
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+            logging.info(f"Training iteration {learner.train_iter}: Evaluation completed, stop condition: {stop}, current reward: {reward}")
            if stop:
+                logging.info("Stopping condition met, training ends!")
                break

        # Collect new data
        new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+        logging.info(f"Rank {rank}, Training iteration {learner.train_iter}: New data collection completed!")

        # Determine updates per collection
-        update_per_collect = calculate_update_per_collect(cfg, new_data)
+        update_per_collect = cfg.policy.update_per_collect
+        if update_per_collect is None:
+            update_per_collect = calculate_update_per_collect(cfg, new_data, world_size)

        # Update replay buffer
        replay_buffer.push_game_segments(new_data)
        replay_buffer.remove_oldest_data_to_fit()

-        # Train the policy if sufficient data is available
+        if world_size > 1:
+            # Synchronize all ranks before training
+            try:
+                dist.barrier()
+            except Exception as e:
+                logging.error(f'Rank {rank}: Synchronization barrier failed, error: {e}')
+                break
+
+        # Check if there is sufficient data for training
        if collector.envstep > cfg.policy.train_start_after_envsteps:
            if cfg.policy.sample_type == 'episode':
                data_sufficient = replay_buffer.get_num_of_game_segments() > batch_size
            else:
                data_sufficient = replay_buffer.get_num_of_transitions() > batch_size
+
            if not data_sufficient:
                logging.warning(
-                    f'The data in replay_buffer is not sufficient to sample a mini-batch: '
+                    f'Rank {rank}: The data in replay_buffer is not sufficient to sample a mini-batch: '
                    f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continue to collect now ....'
                )
                continue

+        # Execute multiple training rounds
        for i in range(update_per_collect):
            train_data = replay_buffer.sample(batch_size, policy)
            if cfg.policy.reanalyze_ratio > 0 and i % 20 == 0:
-                # Clear caches and precompute positional embedding matrices
-                policy.recompute_pos_emb_diff_and_clear_cache()  # TODO
-
-            if policy_config.use_wandb:
+                policy.recompute_pos_emb_diff_and_clear_cache()
+
+            if cfg.policy.use_wandb:
                policy.set_train_iter_env_step(learner.train_iter, collector.envstep)

-            train_data.append({'train_which_component': 'transformer'})
-            log_vars = learner.train(train_data, collector.envstep)
+            train_data.append(learner.train_iter)

+            log_vars = learner.train(train_data, collector.envstep)
            if cfg.policy.use_priority:
                replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])

        policy.recompute_pos_emb_diff_and_clear_cache()

        # Check stopping criteria
        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+            logging.info("Stopping condition met, training ends!")
            break

    learner.call_hook('after_run')
-    wandb.finish()
-    return policy
+    if cfg.policy.use_wandb:
+        wandb.finish()
+    logging.info("===== Training Completed =====")
+    return policy
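
For context on how this entry point is consumed, the sketch below shows a typical invocation from a config script. It is a minimal illustration only: the import path `zoo.atari.config.atari_unizero_config` and the step budget are assumptions about a standard LightZero zoo setup, not part of this commit.

```python
# Minimal usage sketch (assumed zoo config path; adjust to your own config module).
from zoo.atari.config.atari_unizero_config import main_config, create_config
from lzero.entry import train_unizero

if __name__ == "__main__":
    # input_cfg is the (user_config, create_cfg) pair described in the docstring.
    policy = train_unizero(
        [main_config, create_config],
        seed=0,
        model_path=None,          # or e.g. 'exp_name/ckpt/ckpt_best.pth.tar' to warm-start
        max_env_step=int(1e6),    # assumed budget; max_train_iter can bound iterations instead
    )
```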
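
The revised loop only calls `calculate_update_per_collect(cfg, new_data, world_size)` when `update_per_collect` is not fixed in the config. As a hedged illustration of the kind of bookkeeping such a helper typically performs (not LightZero's actual implementation), the snippet below derives the number of gradient steps from the freshly collected transition count and a replay ratio, summing counts across DDP ranks; the function name `estimate_update_per_collect` and its arguments are hypothetical.

```python
import torch
import torch.distributed as dist


def estimate_update_per_collect(num_new_transitions: int,
                                replay_ratio: float,
                                world_size: int = 1) -> int:
    """Illustrative only: gradient steps per collection ~ collected transitions * replay_ratio."""
    count = torch.tensor(float(num_new_transitions))
    if world_size > 1 and dist.is_available() and dist.is_initialized():
        # Sum per-rank counts so every rank derives the same update schedule.
        dist.all_reduce(count, op=dist.ReduceOp.SUM)
    return max(1, int(count.item() * replay_ratio))


# Example: 400 new transitions at replay_ratio=0.25 -> 100 updates this cycle.
print(estimate_update_per_collect(400, 0.25))
```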