From 78967ed5ae232d7bec7ab4a96e08c278173f466a Mon Sep 17 00:00:00 2001 From: wrh12345 Date: Wed, 31 Jul 2024 09:00:36 +0000 Subject: [PATCH 1/6] feature(wrh): add harmony dream in unizero --- lzero/model/unizero_world_models/utils.py | 71 ++++++++++++++++--- .../model/unizero_world_models/world_model.py | 16 ++++- lzero/policy/unizero.py | 57 +++++++++++++-- lzero/policy/utils.py | 12 +++- zoo/atari/config/atari_unizero_config.py | 6 +- 5 files changed, 141 insertions(+), 21 deletions(-) diff --git a/lzero/model/unizero_world_models/utils.py b/lzero/model/unizero_world_models/utils.py index d6f529971..e8cea47ad 100644 --- a/lzero/model/unizero_world_models/utils.py +++ b/lzero/model/unizero_world_models/utils.py @@ -142,14 +142,26 @@ class LossWithIntermediateLosses: Returns: - None """ - def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, **kwargs): + def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, harmony_s_dict=None, **kwargs): # Ensure that kwargs is not empty if not kwargs: raise ValueError("At least one loss must be provided") # Get a reference device from one of the provided losses device = next(iter(kwargs.values())).device - + + if harmony_s_dict is not None: + for k, v in harmony_s_dict.items(): + print(f"{k} has s {v}") + + loss_obs_harmony_s = harmony_s_dict.get("loss_obs_s", None) + loss_rewards_harmony_s = harmony_s_dict.get("loss_rewards_s", None) + loss_value_harmony_s = harmony_s_dict.get("loss_value_s", None) + loss_policy_harmony_s = harmony_s_dict.get("loss_policy_s", None) + loss_ends_harmony_s = harmony_s_dict.get("loss_ends_s", None) + latent_recon_loss_harmony_s = harmony_s_dict.get("latent_recon_loss_s", None) + perceptual_loss_harmony_s = harmony_s_dict.get("perceptual_loss_s", None) + # Define the weights for each loss type self.obs_loss_weight = 10 self.reward_loss_weight = 1. 
@@ -164,19 +176,60 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, **kwarg self.loss_total = torch.tensor(0., device=device) for k, v in kwargs.items(): if k == 'loss_obs': - self.loss_total += self.obs_loss_weight * v + if harmony_s_dict is None: + self.loss_total += self.obs_loss_weight * v + elif loss_obs_harmony_s is not None: + self.loss_total += ((v / torch.exp(loss_obs_harmony_s)) + torch.log(torch.exp(loss_obs_harmony_s) + 1)) + else: + self.loss_total += self.obs_loss_weight * v + elif k == 'loss_rewards': - self.loss_total += self.reward_loss_weight * v + if harmony_s_dict is None: + self.loss_total += self.reward_loss_weight * v + elif loss_rewards_harmony_s is not None: + self.loss_total += ((v / torch.exp(loss_rewards_harmony_s)) + torch.log(torch.exp(loss_rewards_harmony_s) + 1)) + else: + self.loss_total += self.reward_loss_weight * v + elif k == 'loss_policy': - self.loss_total += self.policy_loss_weight * v + if harmony_s_dict is None: + self.loss_total += self.policy_loss_weight * v + elif loss_policy_harmony_s is not None: + self.loss_total += ((v / torch.exp(loss_policy_harmony_s)) + torch.log(torch.exp(loss_policy_harmony_s) + 1)) + else: + self.loss_total += self.policy_loss_weight * v + elif k == 'loss_value': - self.loss_total += self.value_loss_weight * v + if harmony_s_dict is None: + self.loss_total += self.value_loss_weight * v + elif loss_value_harmony_s is not None: + self.loss_total += ((v / torch.exp(loss_value_harmony_s)) + torch.log(torch.exp(loss_value_harmony_s) + 1)) + else: + self.loss_total += self.value_loss_weight * v + elif k == 'loss_ends': - self.loss_total += self.ends_loss_weight * v + if harmony_s_dict is None: + self.loss_total += self.ends_loss_weight * v + elif loss_ends_harmony_s is not None: + self.loss_total += ((v / torch.exp(loss_ends_harmony_s)) + torch.log(torch.exp(loss_ends_harmony_s) + 1)) + else: + self.loss_total += self.ends_loss_weight * v + elif k == 'latent_recon_loss': - self.loss_total += self.latent_recon_loss_weight * v + if harmony_s_dict is None: + self.loss_total += self.latent_recon_loss_weight * v + elif latent_recon_loss_harmony_s is not None: + self.loss_total += ((v / torch.exp(latent_recon_loss_harmony_s)) + torch.log(torch.exp(latent_recon_loss_harmony_s) + 1)) + else: + self.loss_total += self.latent_recon_loss_weight * v + elif k == 'perceptual_loss': - self.loss_total += self.perceptual_loss_weight * v + if harmony_s_dict is None: + self.loss_total += self.perceptual_loss_weight * v + elif perceptual_loss_harmony_s is not None: + self.loss_total += ((v / torch.exp(perceptual_loss_harmony_s)) + torch.log(torch.exp(perceptual_loss_harmony_s) + 1)) + else: + self.loss_total += self.perceptual_loss_weight * v self.intermediate_losses = { k: v if isinstance(v, dict) else (v if isinstance(v, float) else v.item()) diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py index ef31d951c..f7677cf1b 100644 --- a/lzero/model/unizero_world_models/world_model.py +++ b/lzero/model/unizero_world_models/world_model.py @@ -850,7 +850,8 @@ def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, return self.keys_values_wm_size_list def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar_transform_handle=None, **kwargs: Any) -> LossWithIntermediateLosses: - # Encode observations into latent state representations + + harmony_s_dict = kwargs.get("harmony_s_dict", None) obs_embeddings = 
self.tokenizer.encode_to_obs_embeddings(batch['observations']) # ========= for visual analysis ========= @@ -1053,7 +1054,17 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # Last step loss last_step_mask = mask_padding[:, -1] last_step_losses[loss_name] = loss_tmp[:, -1][last_step_mask].mean() - + # if harmony_s_dict is not None: + # print(harmony_s_dict.keys()) + # for loss_name, loss_tmp in zip( + # ['loss_obs', 'loss_rewards', 'loss_value', 'loss_policy', 'orig_policy_loss', 'policy_entropy'], + # [loss_obs, loss_rewards, loss_value, loss_policy, orig_policy_loss, policy_entropy] + # ): + # if f"{loss_name}_s" in harmony_s_dict: + # harmony_tmp_val = harmony_s_dict.get(f"{loss_name}_s") + # loss_tmp = loss_tmp / torch.exp(harmony_tmp_val) + # loss_tmp = loss_tmp + torch.log(torch.exp(harmony_tmp_val) + 1) + # print(f"{loss_name}_s in dict: {f'{loss_name}_s' in harmony_s_dict}, harmony_s: {harmony_tmp_val}") # Discount reconstruction loss and perceptual loss discounted_latent_recon_loss = latent_recon_loss discounted_perceptual_loss = perceptual_loss @@ -1069,6 +1080,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar return LossWithIntermediateLosses( latent_recon_loss_weight=self.latent_recon_loss_weight, perceptual_loss_weight=self.perceptual_loss_weight, + harmony_s_dict=harmony_s_dict, loss_obs=discounted_loss_obs, loss_rewards=discounted_loss_rewards, loss_value=discounted_loss_value, diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py index 73af201c3..96b311028 100644 --- a/lzero/policy/unizero.py +++ b/lzero/policy/unizero.py @@ -268,6 +268,11 @@ class UniZeroPolicy(MuZeroPolicy): # (int) The decay steps from start to end eps. decay=int(1e5), ), + # ****** Harmony Dream for balancing ****** + harmony_balance=False, + harmony_loss_names=['loss_obs', 'loss_rewards', 'loss_value', 'loss_policy'] + + #^ ['loss_obs', 'loss_rewards', 'loss_value', 'loss_policy', 'orig_policy_loss', 'policy_entropy'] ) def default_model(self) -> Tuple[str, List[str]]: @@ -290,12 +295,25 @@ def _init_learn(self) -> None: Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. 
""" # NOTE: nanoGPT optimizer + self.harmony_balance = self._cfg.harmony_balance + if self.harmony_balance: + self.harmony_loss_names = self._cfg.harmony_loss_names + assert self.harmony_loss_names != None + harmony_s_names = [f"{name}_s" for name in self.harmony_loss_names] + for name in harmony_s_names: + param = torch.nn.Parameter(-torch.log(torch.tensor(1.0))) + setattr(self, name, param) + + self.harmony_s_dict = {name: getattr(self, name) for name in harmony_s_names} if self.harmony_balance else None + + print(self.harmony_s_dict) self._optimizer_world_model = configure_optimizers_nanogpt( model=self._model.world_model, learning_rate=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay, device_type=self._cfg.device, betas=(0.9, 0.95), + additional_params=self.harmony_s_dict ) # use model_wrapper for specialized demands of different modes @@ -409,10 +427,26 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # Update world model losses = self._learn_model.world_model.compute_loss( - batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle + batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle, harmony_s_dict=self.harmony_s_dict ) - + # if self.harmony_balance: + # weighted_total_loss = torch.tensor(0., device=self._cfg.device) + # print(f"harmony_loss_names is {self.harmony_loss_names}") + # for loss_name in self.harmony_loss_names: + # if f"{loss_name}_s" in self.harmony_s_dict: + # harmony_s = self.harmony_s_dict.get(f"{loss_name}_s") + # loss_tmp = losses.intermediate_losses.get(loss_name) + # loss_tmp = loss_tmp / torch.exp(harmony_s) + # loss_tmp = loss_tmp + torch.log(torch.exp(harmony_s) + 1) + # print(f"{f'{loss_name}_s' in self.harmony_s_dict}", end=", ") + # print(f"{loss_name}_s is {harmony_s}") + # weighted_total_loss += loss_tmp + + # weighted_total_loss += self.intermediate_losses['first_step_losses'] + + # else: weighted_total_loss = losses.loss_total + for loss_name, loss_value in losses.intermediate_losses.items(): self.intermediate_losses[f"{loss_name}"] = loss_value @@ -430,7 +464,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in dormant_ratio_encoder = self.intermediate_losses['dormant_ratio_encoder'] dormant_ratio_world_model = self.intermediate_losses['dormant_ratio_world_model'] latent_state_l2_norms = self.intermediate_losses['latent_state_l2_norms'] - + + assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values" assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values" @@ -516,9 +551,13 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in 'analysis/grad_norm_before': self.grad_norm_before, 'analysis/grad_norm_after': self.grad_norm_after, } - + if self.harmony_balance: + harmony_s_dict_monitor = {k: v.item() for k, v in self.harmony_s_dict.items()} + harmony_s_exp_recip_dict_monitor = {f"{k}_exp_recip": (torch.reciprocal(torch.exp(v))).item() for k, v in self.harmony_s_dict.items()} + return_loss_dict.update(harmony_s_dict_monitor) + return_loss_dict.update(harmony_s_exp_recip_dict_monitor) return return_loss_dict - + def monitor_weights_and_grads(self, model): for name, param in model.named_parameters(): if param.requires_grad: @@ -862,7 +901,7 @@ def _monitor_vars_learn(self) -> List[str]: Register the variables to be monitored in learn mode. 
The registered variables will be logged in tensorboard according to the return value ``_forward_learn``. """ - return [ + return_list = [ 'analysis/dormant_ratio_encoder', 'analysis/dormant_ratio_world_model', 'analysis/latent_state_l2_norms', @@ -912,6 +951,12 @@ def _monitor_vars_learn(self) -> List[str]: 'reconstruction_loss', 'perceptual_loss', ] + if self.harmony_balance: + harmony_s_monitor_list = self.harmony_s_dict.keys() + harmony_s_exp_recip_monitor_list = [f"{k}_exp_recip" for k in self.harmony_s_dict.keys()] + return_list.extend(harmony_s_monitor_list) + return_list.extend(harmony_s_exp_recip_monitor_list) + return return_list def _state_dict_learn(self) -> Dict[str, Any]: """ diff --git a/lzero/policy/utils.py b/lzero/policy/utils.py index aab2743f8..6c3cfd6fb 100644 --- a/lzero/policy/utils.py +++ b/lzero/policy/utils.py @@ -200,23 +200,33 @@ def forward(self, input): # modified from https://github.com/karpathy/nanoGPT/blob/master/model.py#L263 -def configure_optimizers_nanogpt(model, weight_decay, learning_rate, betas, device_type): +def configure_optimizers_nanogpt(model, weight_decay, learning_rate, betas, device_type, additional_params=None): # start with all of the candidate parameters param_dict = {pn: p for pn, p in model.named_parameters()} # filter out those that do not require grad param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} + + # Add additional parameters to the param_dict if provided + if additional_params is not None: + for name, param in additional_params.items(): + if param.requires_grad: + param_dict[name] = param + # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] + optim_groups = [ {'params': decay_params, 'weight_decay': weight_decay}, {'params': nodecay_params, 'weight_decay': 0.0} ] + num_decay_params = sum(p.numel() for p in decay_params) num_nodecay_params = sum(p.numel() for p in nodecay_params) print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") + # Create AdamW optimizer and use the fused version if it is available fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters if torch.cuda.is_available(): diff --git a/zoo/atari/config/atari_unizero_config.py b/zoo/atari/config/atari_unizero_config.py index 1c549010f..8bd0cb6eb 100644 --- a/zoo/atari/config/atari_unizero_config.py +++ b/zoo/atari/config/atari_unizero_config.py @@ -1,13 +1,13 @@ from easydict import EasyDict from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map - -env_id = 'PongNoFrameskip-v4' # You can specify any Atari game here +env_id = 'QbertNoFrameskip-v4' # You can specify any Atari game here +# 'QbertNoFrameskip-v4' action_space_size = atari_env_action_space_map[env_id] # ============================================================== # begin of the most frequently changed config specified by the user # ============================================================== -update_per_collect = None +update_per_collect = 1000 replay_ratio = 0.25 collector_env_num = 8 n_episode = 8 From 98cd24a42b6675e919d46ec3ab500cbaa10fd337 Mon Sep 17 00:00:00 2001 From: wrh12345 Date: Mon, 5 Aug 2024 12:39:08 +0000 Subject: 
[PATCH 2/6] feature(wrh): add harmony dream in unizero --- lzero/policy/unizero.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py index 96b311028..03d22b84a 100644 --- a/lzero/policy/unizero.py +++ b/lzero/policy/unizero.py @@ -173,7 +173,7 @@ class UniZeroPolicy(MuZeroPolicy): ignore_done=False, # (int) How many updates(iterations) to train after collector's one collection. # Bigger "update_per_collect" means bigger off-policy. - # collect data -> update policy-> collect data -> ... + # collect data -> update policy -> collect data -> ... # For different env, we have different episode_length, # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically. @@ -268,11 +268,14 @@ class UniZeroPolicy(MuZeroPolicy): # (int) The decay steps from start to end eps. decay=int(1e5), ), + # ****** Harmony Dream for balancing ****** + # (bool) Whether to use Harmony Dream to dynamically balance the weights of the different loss terms. harmony_balance=False, + # (List[str]) Names of the losses whose weights are dynamically balanced when harmony_balance is True. + # Losses not listed here keep their fixed weights as defined. harmony_loss_names=['loss_obs', 'loss_rewards', 'loss_value', 'loss_policy'] - - #^ ['loss_obs', 'loss_rewards', 'loss_value', 'loss_policy', 'orig_policy_loss', 'policy_entropy'] + # ['loss_obs', 'loss_rewards', 'loss_value', 'loss_policy', 'orig_policy_loss', 'policy_entropy'] ) def default_model(self) -> Tuple[str, List[str]]: From f46efd1865fa7c6321d7650d591b0ba088230417 Mon Sep 17 00:00:00 2001 From: wrh12345 Date: Mon, 5 Aug 2024 12:41:37 +0000 Subject: [PATCH 3/6] feature(wrh): add harmony dream in unizero --- lzero/model/unizero_world_models/world_model.py | 12 +----------- lzero/policy/unizero.py | 16 +--------------- 2 files changed, 2 insertions(+), 26 deletions(-) diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py index f7677cf1b..758926dda 100644 --- a/lzero/model/unizero_world_models/world_model.py +++ b/lzero/model/unizero_world_models/world_model.py @@ -1054,17 +1054,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # Last step loss last_step_mask = mask_padding[:, -1] last_step_losses[loss_name] = loss_tmp[:, -1][last_step_mask].mean() - # if harmony_s_dict is not None: - # print(harmony_s_dict.keys()) - # for loss_name, loss_tmp in zip( - # ['loss_obs', 'loss_rewards', 'loss_value', 'loss_policy', 'orig_policy_loss', 'policy_entropy'], - # [loss_obs, loss_rewards, loss_value, loss_policy, orig_policy_loss, policy_entropy] - # ): - # if f"{loss_name}_s" in harmony_s_dict: - # harmony_tmp_val = harmony_s_dict.get(f"{loss_name}_s") - # loss_tmp = loss_tmp / torch.exp(harmony_tmp_val) - # loss_tmp = loss_tmp + torch.log(torch.exp(harmony_tmp_val) + 1) - # print(f"{loss_name}_s in dict: {f'{loss_name}_s' in harmony_s_dict}, harmony_s: {harmony_tmp_val}") + # Discount reconstruction loss and perceptual loss discounted_latent_recon_loss = latent_recon_loss discounted_perceptual_loss = perceptual_loss diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py index 03d22b84a..4de2d6b26 100644 --- a/lzero/policy/unizero.py +++ b/lzero/policy/unizero.py @@ -432,21 +432,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
losses = self._learn_model.world_model.compute_loss( batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle, harmony_s_dict=self.harmony_s_dict ) - # if self.harmony_balance: - # weighted_total_loss = torch.tensor(0., device=self._cfg.device) - # print(f"harmony_loss_names is {self.harmony_loss_names}") - # for loss_name in self.harmony_loss_names: - # if f"{loss_name}_s" in self.harmony_s_dict: - # harmony_s = self.harmony_s_dict.get(f"{loss_name}_s") - # loss_tmp = losses.intermediate_losses.get(loss_name) - # loss_tmp = loss_tmp / torch.exp(harmony_s) - # loss_tmp = loss_tmp + torch.log(torch.exp(harmony_s) + 1) - # print(f"{f'{loss_name}_s' in self.harmony_s_dict}", end=", ") - # print(f"{loss_name}_s is {harmony_s}") - # weighted_total_loss += loss_tmp - - # weighted_total_loss += self.intermediate_losses['first_step_losses'] - + # else: weighted_total_loss = losses.loss_total From ad1d2153204543621d9fce71f1261d117b5c6216 Mon Sep 17 00:00:00 2001 From: wrh12345 Date: Mon, 5 Aug 2024 13:08:57 +0000 Subject: [PATCH 4/6] feature(wrh): add harmony dream in unizero --- lzero/policy/unizero.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py index 4de2d6b26..c91054501 100644 --- a/lzero/policy/unizero.py +++ b/lzero/policy/unizero.py @@ -309,7 +309,6 @@ def _init_learn(self) -> None: self.harmony_s_dict = {name: getattr(self, name) for name in harmony_s_names} if self.harmony_balance else None - print(self.harmony_s_dict) self._optimizer_world_model = configure_optimizers_nanogpt( model=self._model.world_model, learning_rate=self._cfg.learning_rate, From 6e27ddda6e5c2dadf6c2223e6937afe33007512c Mon Sep 17 00:00:00 2001 From: wrh12345 Date: Mon, 5 Aug 2024 13:11:54 +0000 Subject: [PATCH 5/6] feature(wrh): add harmony dream in unizero --- zoo/atari/config/atari_unizero_config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/zoo/atari/config/atari_unizero_config.py b/zoo/atari/config/atari_unizero_config.py index 8bd0cb6eb..7233eae96 100644 --- a/zoo/atari/config/atari_unizero_config.py +++ b/zoo/atari/config/atari_unizero_config.py @@ -1,7 +1,6 @@ from easydict import EasyDict from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map env_id = 'QbertNoFrameskip-v4' # You can specify any Atari game here -# 'QbertNoFrameskip-v4' action_space_size = atari_env_action_space_map[env_id] # ============================================================== From 2f751abc07132af4f72f9ac5f59a231579969408 Mon Sep 17 00:00:00 2001 From: wrh12345 Date: Mon, 5 Aug 2024 13:49:38 +0000 Subject: [PATCH 6/6] feature(wrh): add harmony dream in unizero --- lzero/model/unizero_world_models/utils.py | 76 ++++++----------------- lzero/policy/unizero.py | 1 - 2 files changed, 19 insertions(+), 58 deletions(-) diff --git a/lzero/model/unizero_world_models/utils.py b/lzero/model/unizero_world_models/utils.py index e8cea47ad..0cdd8072f 100644 --- a/lzero/model/unizero_world_models/utils.py +++ b/lzero/model/unizero_world_models/utils.py @@ -150,10 +150,7 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, harmony # Get a reference device from one of the provided losses device = next(iter(kwargs.values())).device - if harmony_s_dict is not None: - for k, v in harmony_s_dict.items(): - print(f"{k} has s {v}") - + if harmony_s_dict is not None: loss_obs_harmony_s = harmony_s_dict.get("loss_obs_s", None) loss_rewards_harmony_s = harmony_s_dict.get("loss_rewards_s", None) 
loss_value_harmony_s = harmony_s_dict.get("loss_value_s", None) @@ -174,62 +171,27 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, harmony # Initialize the total loss tensor on the correct device self.loss_total = torch.tensor(0., device=device) + # Define a dictionary for loss weights and harmony_s variables + loss_weights = { + 'loss_obs': (self.obs_loss_weight, loss_obs_harmony_s), + 'loss_rewards': (self.reward_loss_weight, loss_rewards_harmony_s), + 'loss_policy': (self.policy_loss_weight, loss_policy_harmony_s), + 'loss_value': (self.value_loss_weight, loss_value_harmony_s), + 'loss_ends': (self.ends_loss_weight, loss_ends_harmony_s), + 'latent_recon_loss': (self.latent_recon_loss_weight, latent_recon_loss_harmony_s), + 'perceptual_loss': (self.perceptual_loss_weight, perceptual_loss_harmony_s) + } + + # Iterate through kwargs to process the losses for k, v in kwargs.items(): - if k == 'loss_obs': - if harmony_s_dict is None: - self.loss_total += self.obs_loss_weight * v - elif loss_obs_harmony_s is not None: - self.loss_total += ((v / torch.exp(loss_obs_harmony_s)) + torch.log(torch.exp(loss_obs_harmony_s) + 1)) - else: - self.loss_total += self.obs_loss_weight * v - - elif k == 'loss_rewards': - if harmony_s_dict is None: - self.loss_total += self.reward_loss_weight * v - elif loss_rewards_harmony_s is not None: - self.loss_total += ((v / torch.exp(loss_rewards_harmony_s)) + torch.log(torch.exp(loss_rewards_harmony_s) + 1)) - else: - self.loss_total += self.reward_loss_weight * v - - elif k == 'loss_policy': - if harmony_s_dict is None: - self.loss_total += self.policy_loss_weight * v - elif loss_policy_harmony_s is not None: - self.loss_total += ((v / torch.exp(loss_policy_harmony_s)) + torch.log(torch.exp(loss_policy_harmony_s) + 1)) - else: - self.loss_total += self.policy_loss_weight * v - - elif k == 'loss_value': - if harmony_s_dict is None: - self.loss_total += self.value_loss_weight * v - elif loss_value_harmony_s is not None: - self.loss_total += ((v / torch.exp(loss_value_harmony_s)) + torch.log(torch.exp(loss_value_harmony_s) + 1)) - else: - self.loss_total += self.value_loss_weight * v - - elif k == 'loss_ends': - if harmony_s_dict is None: - self.loss_total += self.ends_loss_weight * v - elif loss_ends_harmony_s is not None: - self.loss_total += ((v / torch.exp(loss_ends_harmony_s)) + torch.log(torch.exp(loss_ends_harmony_s) + 1)) - else: - self.loss_total += self.ends_loss_weight * v - - elif k == 'latent_recon_loss': - if harmony_s_dict is None: - self.loss_total += self.latent_recon_loss_weight * v - elif latent_recon_loss_harmony_s is not None: - self.loss_total += ((v / torch.exp(latent_recon_loss_harmony_s)) + torch.log(torch.exp(latent_recon_loss_harmony_s) + 1)) - else: - self.loss_total += self.latent_recon_loss_weight * v - - elif k == 'perceptual_loss': + if k in loss_weights: + weight, harmony_weight = loss_weights[k] if harmony_s_dict is None: - self.loss_total += self.perceptual_loss_weight * v - elif perceptual_loss_harmony_s is not None: - self.loss_total += ((v / torch.exp(perceptual_loss_harmony_s)) + torch.log(torch.exp(perceptual_loss_harmony_s) + 1)) + self.loss_total += weight * v + elif harmony_weight is not None: + self.loss_total += (v / torch.exp(harmony_weight)) + torch.log(torch.exp(harmony_weight) + 1) else: - self.loss_total += self.perceptual_loss_weight * v + self.loss_total += weight * v self.intermediate_losses = { k: v if isinstance(v, dict) else (v if isinstance(v, float) else v.item()) diff --git 
a/lzero/policy/unizero.py b/lzero/policy/unizero.py index c91054501..557ea9da7 100644 --- a/lzero/policy/unizero.py +++ b/lzero/policy/unizero.py @@ -432,7 +432,6 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle, harmony_s_dict=self.harmony_s_dict ) - # else: weighted_total_loss = losses.loss_total for loss_name, loss_value in losses.intermediate_losses.items():
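
Note (added for reference, not part of the patch series): every harmony-balanced loss in the changes above enters the total objective as loss / exp(s) + log(exp(s) + 1), where s is a learnable scalar (one per loss, initialized to -log(1.0) = 0) that is optimized jointly with the world model via the additional_params argument of configure_optimizers_nanogpt, so exp(-s) acts as a dynamic loss weight and the log term keeps s from diverging. The minimal standalone PyTorch sketch below illustrates only this weighting scheme; the names harmony_total_loss, model_param and the toy losses are illustrative and do not exist in LightZero.

import torch

# Hypothetical loss names; in the patches they come from cfg.policy.harmony_loss_names.
loss_names = ['loss_obs', 'loss_rewards', 'loss_value', 'loss_policy']

# One learnable scale s per balanced loss, initialized to -log(1.0) = 0,
# mirroring torch.nn.Parameter(-torch.log(torch.tensor(1.0))) in _init_learn.
harmony_s = {f"{name}_s": torch.nn.Parameter(torch.zeros(())) for name in loss_names}

def harmony_total_loss(losses: dict) -> torch.Tensor:
    # Combine per-loss scalars into one balanced objective.
    total = torch.zeros(())
    for name, value in losses.items():
        s = harmony_s.get(f"{name}_s")
        if s is None:
            # Losses without a learnable scale keep their fixed weight (here simply 1).
            total = total + value
        else:
            # exp(-s) rescales the loss; log(exp(s) + 1) penalizes s -> +inf.
            total = total + value / torch.exp(s) + torch.log(torch.exp(s) + 1)
    return total

# Toy usage: the scales are optimized together with the model parameters,
# analogous to passing additional_params into configure_optimizers_nanogpt.
model_param = torch.nn.Parameter(torch.randn(4))
optimizer = torch.optim.AdamW([model_param, *harmony_s.values()], lr=1e-3)

toy_losses = {name: weight * (model_param ** 2).mean()
              for name, weight in zip(loss_names, [10.0, 1.0, 0.5, 1.0])}
optimizer.zero_grad()
harmony_total_loss(toy_losses).backward()
optimizer.step()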