diff --git a/lzero/model/unizero_world_models/utils.py b/lzero/model/unizero_world_models/utils.py
index d6f529971..0cdd8072f 100644
--- a/lzero/model/unizero_world_models/utils.py
+++ b/lzero/model/unizero_world_models/utils.py
@@ -142,14 +142,23 @@ class LossWithIntermediateLosses:
     Returns:
         - None
     """
-    def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, **kwargs):
+    def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, harmony_s_dict=None, **kwargs):
         # Ensure that kwargs is not empty
         if not kwargs:
             raise ValueError("At least one loss must be provided")

         # Get a reference device from one of the provided losses
         device = next(iter(kwargs.values())).device
-
+
+        if harmony_s_dict is not None:
+            loss_obs_harmony_s = harmony_s_dict.get("loss_obs_s", None)
+            loss_rewards_harmony_s = harmony_s_dict.get("loss_rewards_s", None)
+            loss_value_harmony_s = harmony_s_dict.get("loss_value_s", None)
+            loss_policy_harmony_s = harmony_s_dict.get("loss_policy_s", None)
+            loss_ends_harmony_s = harmony_s_dict.get("loss_ends_s", None)
+            latent_recon_loss_harmony_s = harmony_s_dict.get("latent_recon_loss_s", None)
+            perceptual_loss_harmony_s = harmony_s_dict.get("perceptual_loss_s", None)
+
         # Define the weights for each loss type
         self.obs_loss_weight = 10
         self.reward_loss_weight = 1.
@@ -162,21 +171,27 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, **kwarg
         # Initialize the total loss tensor on the correct device
         self.loss_total = torch.tensor(0., device=device)

+        # Define a dictionary for loss weights and harmony_s variables
+        loss_weights = {
+            'loss_obs': (self.obs_loss_weight, loss_obs_harmony_s),
+            'loss_rewards': (self.reward_loss_weight, loss_rewards_harmony_s),
+            'loss_policy': (self.policy_loss_weight, loss_policy_harmony_s),
+            'loss_value': (self.value_loss_weight, loss_value_harmony_s),
+            'loss_ends': (self.ends_loss_weight, loss_ends_harmony_s),
+            'latent_recon_loss': (self.latent_recon_loss_weight, latent_recon_loss_harmony_s),
+            'perceptual_loss': (self.perceptual_loss_weight, perceptual_loss_harmony_s)
+        }
+
+        # Iterate through kwargs to process the losses
         for k, v in kwargs.items():
-            if k == 'loss_obs':
-                self.loss_total += self.obs_loss_weight * v
-            elif k == 'loss_rewards':
-                self.loss_total += self.reward_loss_weight * v
-            elif k == 'loss_policy':
-                self.loss_total += self.policy_loss_weight * v
-            elif k == 'loss_value':
-                self.loss_total += self.value_loss_weight * v
-            elif k == 'loss_ends':
-                self.loss_total += self.ends_loss_weight * v
-            elif k == 'latent_recon_loss':
-                self.loss_total += self.latent_recon_loss_weight * v
-            elif k == 'perceptual_loss':
-                self.loss_total += self.perceptual_loss_weight * v
+            if k in loss_weights:
+                weight, harmony_weight = loss_weights[k]
+                if harmony_s_dict is None:
+                    self.loss_total += weight * v
+                elif harmony_weight is not None:
+                    self.loss_total += (v / torch.exp(harmony_weight)) + torch.log(torch.exp(harmony_weight) + 1)
+                else:
+                    self.loss_total += weight * v

         self.intermediate_losses = {
             k: v if isinstance(v, dict) else (v if isinstance(v, float) else v.item())
diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py
index ef31d951c..758926dda 100644
--- a/lzero/model/unizero_world_models/world_model.py
+++ b/lzero/model/unizero_world_models/world_model.py
@@ -850,7 +850,8 @@ def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int,
         return self.keys_values_wm_size_list

     def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar_transform_handle=None, **kwargs: Any) -> LossWithIntermediateLosses:
-        # Encode observations into latent state representations
+        harmony_s_dict = kwargs.get("harmony_s_dict", None)
+        # Encode observations into latent state representations
         obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'])

         # ========= for visual analysis =========
@@ -1069,6 +1070,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
         return LossWithIntermediateLosses(
             latent_recon_loss_weight=self.latent_recon_loss_weight,
             perceptual_loss_weight=self.perceptual_loss_weight,
+            harmony_s_dict=harmony_s_dict,
             loss_obs=discounted_loss_obs,
             loss_rewards=discounted_loss_rewards,
             loss_value=discounted_loss_value,
diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py
index 73af201c3..557ea9da7 100644
--- a/lzero/policy/unizero.py
+++ b/lzero/policy/unizero.py
@@ -173,7 +173,7 @@ class UniZeroPolicy(MuZeroPolicy):
         ignore_done=False,
         # (int) How many updates(iterations) to train after collector's one collection.
         # Bigger "update_per_collect" means bigger off-policy.
-        # collect data -> update policy-> collect data -> ...
+        # collect data -> update policy -> collect data -> ...
         # For different env, we have different episode_length,
         # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor.
         # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically.
@@ -268,6 +268,14 @@ class UniZeroPolicy(MuZeroPolicy):
             # (int) The decay steps from start to end eps.
             decay=int(1e5),
         ),
+
+        # ****** Harmony Dream for balancing ******
+        # (bool) Whether to use Harmony Dream to dynamically balance the weights of the different loss terms.
+        harmony_balance=False,
+        # (List) Losses whose weights are dynamically balanced when harmony_balance is True.
+        # Losses not listed here keep the fixed weights defined above.
+        harmony_loss_names=['loss_obs', 'loss_rewards', 'loss_value', 'loss_policy']
+        # ['loss_obs', 'loss_rewards', 'loss_value', 'loss_policy', 'orig_policy_loss', 'policy_entropy']
     )

     def default_model(self) -> Tuple[str, List[str]]:
@@ -290,12 +298,24 @@ def _init_learn(self) -> None:
         Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils.
""" # NOTE: nanoGPT optimizer + self.harmony_balance = self._cfg.harmony_balance + if self.harmony_balance: + self.harmony_loss_names = self._cfg.harmony_loss_names + assert self.harmony_loss_names != None + harmony_s_names = [f"{name}_s" for name in self.harmony_loss_names] + for name in harmony_s_names: + param = torch.nn.Parameter(-torch.log(torch.tensor(1.0))) + setattr(self, name, param) + + self.harmony_s_dict = {name: getattr(self, name) for name in harmony_s_names} if self.harmony_balance else None + self._optimizer_world_model = configure_optimizers_nanogpt( model=self._model.world_model, learning_rate=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay, device_type=self._cfg.device, betas=(0.9, 0.95), + additional_params=self.harmony_s_dict ) # use model_wrapper for specialized demands of different modes @@ -409,10 +429,11 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # Update world model losses = self._learn_model.world_model.compute_loss( - batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle + batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle, harmony_s_dict=self.harmony_s_dict ) - + weighted_total_loss = losses.loss_total + for loss_name, loss_value in losses.intermediate_losses.items(): self.intermediate_losses[f"{loss_name}"] = loss_value @@ -430,7 +451,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in dormant_ratio_encoder = self.intermediate_losses['dormant_ratio_encoder'] dormant_ratio_world_model = self.intermediate_losses['dormant_ratio_world_model'] latent_state_l2_norms = self.intermediate_losses['latent_state_l2_norms'] - + + assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values" assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values" @@ -516,9 +538,13 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in 'analysis/grad_norm_before': self.grad_norm_before, 'analysis/grad_norm_after': self.grad_norm_after, } - + if self.harmony_balance: + harmony_s_dict_monitor = {k: v.item() for k, v in self.harmony_s_dict.items()} + harmony_s_exp_recip_dict_monitor = {f"{k}_exp_recip": (torch.reciprocal(torch.exp(v))).item() for k, v in self.harmony_s_dict.items()} + return_loss_dict.update(harmony_s_dict_monitor) + return_loss_dict.update(harmony_s_exp_recip_dict_monitor) return return_loss_dict - + def monitor_weights_and_grads(self, model): for name, param in model.named_parameters(): if param.requires_grad: @@ -862,7 +888,7 @@ def _monitor_vars_learn(self) -> List[str]: Register the variables to be monitored in learn mode. The registered variables will be logged in tensorboard according to the return value ``_forward_learn``. 
""" - return [ + return_list = [ 'analysis/dormant_ratio_encoder', 'analysis/dormant_ratio_world_model', 'analysis/latent_state_l2_norms', @@ -912,6 +938,12 @@ def _monitor_vars_learn(self) -> List[str]: 'reconstruction_loss', 'perceptual_loss', ] + if self.harmony_balance: + harmony_s_monitor_list = self.harmony_s_dict.keys() + harmony_s_exp_recip_monitor_list = [f"{k}_exp_recip" for k in self.harmony_s_dict.keys()] + return_list.extend(harmony_s_monitor_list) + return_list.extend(harmony_s_exp_recip_monitor_list) + return return_list def _state_dict_learn(self) -> Dict[str, Any]: """ diff --git a/lzero/policy/utils.py b/lzero/policy/utils.py index aab2743f8..6c3cfd6fb 100644 --- a/lzero/policy/utils.py +++ b/lzero/policy/utils.py @@ -200,23 +200,33 @@ def forward(self, input): # modified from https://github.com/karpathy/nanoGPT/blob/master/model.py#L263 -def configure_optimizers_nanogpt(model, weight_decay, learning_rate, betas, device_type): +def configure_optimizers_nanogpt(model, weight_decay, learning_rate, betas, device_type, additional_params=None): # start with all of the candidate parameters param_dict = {pn: p for pn, p in model.named_parameters()} # filter out those that do not require grad param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} + + # Add additional parameters to the param_dict if provided + if additional_params is not None: + for name, param in additional_params.items(): + if param.requires_grad: + param_dict[name] = param + # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] + optim_groups = [ {'params': decay_params, 'weight_decay': weight_decay}, {'params': nodecay_params, 'weight_decay': 0.0} ] + num_decay_params = sum(p.numel() for p in decay_params) num_nodecay_params = sum(p.numel() for p in nodecay_params) print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") + # Create AdamW optimizer and use the fused version if it is available fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters if torch.cuda.is_available(): diff --git a/zoo/atari/config/atari_unizero_config.py b/zoo/atari/config/atari_unizero_config.py index 1c549010f..7233eae96 100644 --- a/zoo/atari/config/atari_unizero_config.py +++ b/zoo/atari/config/atari_unizero_config.py @@ -1,13 +1,12 @@ from easydict import EasyDict from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map - -env_id = 'PongNoFrameskip-v4' # You can specify any Atari game here +env_id = 'QbertNoFrameskip-v4' # You can specify any Atari game here action_space_size = atari_env_action_space_map[env_id] # ============================================================== # begin of the most frequently changed config specified by the user # ============================================================== -update_per_collect = None +update_per_collect = 1000 replay_ratio = 0.25 collector_env_num = 8 n_episode = 8