-
Notifications
You must be signed in to change notification settings - Fork 120
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature(wrh): add harmony dream in unizero #255
Open
ruiheng123
wants to merge
6
commits into
opendilab:main
Choose a base branch
from
ruiheng123:dev-unizero-harmonydream
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 4 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
78967ed
feature(wrh): add harmony dream in unizero
ruiheng123 98cd24a
feature(wrh): add harmony dream in unizero
ruiheng123 f46efd1
feature(wrh): add harmony dream in unizero
ruiheng123 ad1d215
feature(wrh): add harmony dream in unizero
ruiheng123 6e27ddd
feature(wrh): add harmony dream in unizero
ruiheng123 2f751ab
feature(wrh): add harmony dream in unizero
ruiheng123 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -173,7 +173,7 @@ class UniZeroPolicy(MuZeroPolicy): | |
ignore_done=False, | ||
# (int) How many updates(iterations) to train after collector's one collection. | ||
# Bigger "update_per_collect" means bigger off-policy. | ||
# collect data -> update policy-> collect data -> ... | ||
# collect data -> update policy -> collect data -> ... | ||
# For different env, we have different episode_length, | ||
# we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. | ||
# If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically. | ||
|
@@ -268,6 +268,14 @@ class UniZeroPolicy(MuZeroPolicy): | |
# (int) The decay steps from start to end eps. | ||
decay=int(1e5), | ||
), | ||
|
||
# ****** Harmony Dream for balancing ****** | ||
# (bool) Whether harmony dream is used to balance different weights of loss functions. | ||
harmony_balance=False, | ||
# (List) Loss list involving dynamic balance adjustment when harmony_balance is True. | ||
# If not appeared, the weight of loss will remain fixed as defined. | ||
harmony_loss_names=['loss_obs', 'loss_rewards', 'loss_value', 'loss_policy'] | ||
# ['loss_obs', 'loss_rewards', 'loss_value', 'loss_policy', 'orig_policy_loss', 'policy_entropy'] | ||
) | ||
|
||
def default_model(self) -> Tuple[str, List[str]]: | ||
|
@@ -290,12 +298,24 @@ def _init_learn(self) -> None: | |
Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. | ||
""" | ||
# NOTE: nanoGPT optimizer | ||
self.harmony_balance = self._cfg.harmony_balance | ||
if self.harmony_balance: | ||
self.harmony_loss_names = self._cfg.harmony_loss_names | ||
assert self.harmony_loss_names != None | ||
harmony_s_names = [f"{name}_s" for name in self.harmony_loss_names] | ||
for name in harmony_s_names: | ||
param = torch.nn.Parameter(-torch.log(torch.tensor(1.0))) | ||
setattr(self, name, param) | ||
|
||
self.harmony_s_dict = {name: getattr(self, name) for name in harmony_s_names} if self.harmony_balance else None | ||
|
||
self._optimizer_world_model = configure_optimizers_nanogpt( | ||
model=self._model.world_model, | ||
learning_rate=self._cfg.learning_rate, | ||
weight_decay=self._cfg.weight_decay, | ||
device_type=self._cfg.device, | ||
betas=(0.9, 0.95), | ||
additional_params=self.harmony_s_dict | ||
) | ||
|
||
# use model_wrapper for specialized demands of different modes | ||
|
@@ -409,10 +429,12 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in | |
|
||
# Update world model | ||
losses = self._learn_model.world_model.compute_loss( | ||
batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle | ||
batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle, harmony_s_dict=self.harmony_s_dict | ||
) | ||
|
||
|
||
# else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 删除 |
||
weighted_total_loss = losses.loss_total | ||
|
||
for loss_name, loss_value in losses.intermediate_losses.items(): | ||
self.intermediate_losses[f"{loss_name}"] = loss_value | ||
|
||
|
@@ -430,7 +452,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in | |
dormant_ratio_encoder = self.intermediate_losses['dormant_ratio_encoder'] | ||
dormant_ratio_world_model = self.intermediate_losses['dormant_ratio_world_model'] | ||
latent_state_l2_norms = self.intermediate_losses['latent_state_l2_norms'] | ||
|
||
|
||
|
||
assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values" | ||
assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values" | ||
|
||
|
@@ -516,9 +539,13 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in | |
'analysis/grad_norm_before': self.grad_norm_before, | ||
'analysis/grad_norm_after': self.grad_norm_after, | ||
} | ||
|
||
if self.harmony_balance: | ||
harmony_s_dict_monitor = {k: v.item() for k, v in self.harmony_s_dict.items()} | ||
harmony_s_exp_recip_dict_monitor = {f"{k}_exp_recip": (torch.reciprocal(torch.exp(v))).item() for k, v in self.harmony_s_dict.items()} | ||
return_loss_dict.update(harmony_s_dict_monitor) | ||
return_loss_dict.update(harmony_s_exp_recip_dict_monitor) | ||
return return_loss_dict | ||
|
||
def monitor_weights_and_grads(self, model): | ||
for name, param in model.named_parameters(): | ||
if param.requires_grad: | ||
|
@@ -862,7 +889,7 @@ def _monitor_vars_learn(self) -> List[str]: | |
Register the variables to be monitored in learn mode. The registered variables will be logged in | ||
tensorboard according to the return value ``_forward_learn``. | ||
""" | ||
return [ | ||
return_list = [ | ||
'analysis/dormant_ratio_encoder', | ||
'analysis/dormant_ratio_world_model', | ||
'analysis/latent_state_l2_norms', | ||
|
@@ -912,6 +939,12 @@ def _monitor_vars_learn(self) -> List[str]: | |
'reconstruction_loss', | ||
'perceptual_loss', | ||
] | ||
if self.harmony_balance: | ||
harmony_s_monitor_list = self.harmony_s_dict.keys() | ||
harmony_s_exp_recip_monitor_list = [f"{k}_exp_recip" for k in self.harmony_s_dict.keys()] | ||
return_list.extend(harmony_s_monitor_list) | ||
return_list.extend(harmony_s_exp_recip_monitor_list) | ||
return return_list | ||
|
||
def _state_dict_learn(self) -> Dict[str, Any]: | ||
""" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.