feature(wrh): add harmony dream in unizero #255

Status: Open. Wants to merge 6 commits into base: main (showing changes from 4 commits).
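The diff below adds HarmonyDream-style loss balancing to UniZero: instead of a fixed weight w * L, each selected loss contributes L / exp(s) + log(exp(s) + 1), where s is a learnable scalar (one per loss) trained jointly with the world model. As a minimal, self-contained sketch of that term (PyTorch; the function name and values are illustrative, not part of the PR):

import torch

def harmony_term(loss: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    # Scale the loss by 1/exp(s) and add log(exp(s) + 1) as a regularizer,
    # which keeps s from growing without bound (that would zero out the loss).
    return loss / torch.exp(s) + torch.log(torch.exp(s) + 1)

# s starts at 0, so exp(s) = 1 and the term is initially loss + log(2).
s = torch.nn.Parameter(torch.zeros(()))
loss = torch.tensor(2.5)
total = harmony_term(loss, s)
total.backward()        # gradients flow into s as well as through the loss
print(total.item())     # 2.5 / 1 + ln(2) ≈ 3.193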
71 changes: 62 additions & 9 deletions lzero/model/unizero_world_models/utils.py
@@ -142,14 +142,26 @@ class LossWithIntermediateLosses:
Returns:
- None
"""
def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, **kwargs):
def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, harmony_s_dict=None, **kwargs):
# Ensure that kwargs is not empty
if not kwargs:
    raise ValueError("At least one loss must be provided")

# Get a reference device from one of the provided losses
device = next(iter(kwargs.values())).device


if harmony_s_dict is not None:
    for k, v in harmony_s_dict.items():
        print(f"{k} has s {v}")

    loss_obs_harmony_s = harmony_s_dict.get("loss_obs_s", None)
    loss_rewards_harmony_s = harmony_s_dict.get("loss_rewards_s", None)
    loss_value_harmony_s = harmony_s_dict.get("loss_value_s", None)
    loss_policy_harmony_s = harmony_s_dict.get("loss_policy_s", None)
    loss_ends_harmony_s = harmony_s_dict.get("loss_ends_s", None)
    latent_recon_loss_harmony_s = harmony_s_dict.get("latent_recon_loss_s", None)
    perceptual_loss_harmony_s = harmony_s_dict.get("perceptual_loss_s", None)

# Define the weights for each loss type
self.obs_loss_weight = 10
self.reward_loss_weight = 1.
@@ -164,19 +176,60 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, **kwarg
self.loss_total = torch.tensor(0., device=device)
for k, v in kwargs.items():
Collaborator review comment, suggesting a table-driven refactor of the per-loss weighting logic:

# Define a dictionary mapping each loss to its fixed weight and its key in harmony_s_dict
loss_weights = {
    'loss_obs': (self.obs_loss_weight, 'loss_obs_s'),
    'loss_rewards': (self.reward_loss_weight, 'loss_rewards_s'),
    'loss_policy': (self.policy_loss_weight, 'loss_policy_s'),
    'loss_value': (self.value_loss_weight, 'loss_value_s'),
    'loss_ends': (self.ends_loss_weight, 'loss_ends_s'),
    'latent_recon_loss': (self.latent_recon_loss_weight, 'latent_recon_loss_s'),
    'perceptual_loss': (self.perceptual_loss_weight, 'perceptual_loss_s')
}

# Iterate through kwargs to process the losses
for k, v in kwargs.items():
    if k in loss_weights:
        weight, harmony_key = loss_weights[k]
        # Look up the harmony coefficient for this loss, if one was provided
        harmony_s = harmony_s_dict.get(harmony_key) if harmony_s_dict is not None else None

        if harmony_s is not None:
            self.loss_total += (v / torch.exp(harmony_s)) + torch.log(torch.exp(harmony_s) + 1)
        else:
            self.loss_total += weight * v

    if k == 'loss_obs':
        if harmony_s_dict is None:
            self.loss_total += self.obs_loss_weight * v
        elif loss_obs_harmony_s is not None:
            self.loss_total += (v / torch.exp(loss_obs_harmony_s)) + torch.log(torch.exp(loss_obs_harmony_s) + 1)
        else:
            self.loss_total += self.obs_loss_weight * v

    elif k == 'loss_rewards':
        if harmony_s_dict is None:
            self.loss_total += self.reward_loss_weight * v
        elif loss_rewards_harmony_s is not None:
            self.loss_total += (v / torch.exp(loss_rewards_harmony_s)) + torch.log(torch.exp(loss_rewards_harmony_s) + 1)
        else:
            self.loss_total += self.reward_loss_weight * v

    elif k == 'loss_policy':
        if harmony_s_dict is None:
            self.loss_total += self.policy_loss_weight * v
        elif loss_policy_harmony_s is not None:
            self.loss_total += (v / torch.exp(loss_policy_harmony_s)) + torch.log(torch.exp(loss_policy_harmony_s) + 1)
        else:
            self.loss_total += self.policy_loss_weight * v

    elif k == 'loss_value':
        if harmony_s_dict is None:
            self.loss_total += self.value_loss_weight * v
        elif loss_value_harmony_s is not None:
            self.loss_total += (v / torch.exp(loss_value_harmony_s)) + torch.log(torch.exp(loss_value_harmony_s) + 1)
        else:
            self.loss_total += self.value_loss_weight * v

    elif k == 'loss_ends':
        if harmony_s_dict is None:
            self.loss_total += self.ends_loss_weight * v
        elif loss_ends_harmony_s is not None:
            self.loss_total += (v / torch.exp(loss_ends_harmony_s)) + torch.log(torch.exp(loss_ends_harmony_s) + 1)
        else:
            self.loss_total += self.ends_loss_weight * v

    elif k == 'latent_recon_loss':
        if harmony_s_dict is None:
            self.loss_total += self.latent_recon_loss_weight * v
        elif latent_recon_loss_harmony_s is not None:
            self.loss_total += (v / torch.exp(latent_recon_loss_harmony_s)) + torch.log(torch.exp(latent_recon_loss_harmony_s) + 1)
        else:
            self.loss_total += self.latent_recon_loss_weight * v

    elif k == 'perceptual_loss':
        if harmony_s_dict is None:
            self.loss_total += self.perceptual_loss_weight * v
        elif perceptual_loss_harmony_s is not None:
            self.loss_total += (v / torch.exp(perceptual_loss_harmony_s)) + torch.log(torch.exp(perceptual_loss_harmony_s) + 1)
        else:
            self.loss_total += self.perceptual_loss_weight * v

self.intermediate_losses = {
k: v if isinstance(v, dict) else (v if isinstance(v, float) else v.item())
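For orientation, a minimal usage sketch of the updated LossWithIntermediateLosses constructor; the coefficients and loss values below are illustrative, not taken from the PR:

import torch
from lzero.model.unizero_world_models.utils import LossWithIntermediateLosses

# Illustrative harmony coefficients; in this PR they are created in UniZeroPolicy._init_learn
# and keyed as "<loss_name>_s".
harmony_s_dict = {
    'loss_obs_s': torch.nn.Parameter(torch.zeros(())),
    'loss_rewards_s': torch.nn.Parameter(torch.zeros(())),
}

losses = LossWithIntermediateLosses(
    harmony_s_dict=harmony_s_dict,
    loss_obs=torch.tensor(1.2),      # balanced via 'loss_obs_s'
    loss_rewards=torch.tensor(0.4),  # balanced via 'loss_rewards_s'
    loss_policy=torch.tensor(0.8),   # no matching '*_s' entry, so its fixed weight applies
)
print(losses.loss_total)  # combines harmony-weighted and fixed-weight terms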
4 changes: 3 additions & 1 deletion lzero/model/unizero_world_models/world_model.py
@@ -850,7 +850,8 @@ def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int,
return self.keys_values_wm_size_list

def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar_transform_handle=None, **kwargs: Any) -> LossWithIntermediateLosses:
# Encode observations into latent state representations

harmony_s_dict = kwargs.get("harmony_s_dict", None)
obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'])

# ========= for visual analysis =========
@@ -1069,6 +1070,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
return LossWithIntermediateLosses(
latent_recon_loss_weight=self.latent_recon_loss_weight,
perceptual_loss_weight=self.perceptual_loss_weight,
harmony_s_dict=harmony_s_dict,
loss_obs=discounted_loss_obs,
loss_rewards=discounted_loss_rewards,
loss_value=discounted_loss_value,
47 changes: 40 additions & 7 deletions lzero/policy/unizero.py
@@ -173,7 +173,7 @@ class UniZeroPolicy(MuZeroPolicy):
ignore_done=False,
# (int) How many updates(iterations) to train after collector's one collection.
# Bigger "update_per_collect" means bigger off-policy.
# collect data -> update policy-> collect data -> ...
# collect data -> update policy -> collect data -> ...
# For different env, we have different episode_length,
# we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor.
# If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically.
@@ -268,6 +268,14 @@ class UniZeroPolicy(MuZeroPolicy):
# (int) The decay steps from start to end eps.
decay=int(1e5),
),

# ****** Harmony Dream for balancing ******
# (bool) Whether HarmonyDream is used to dynamically balance the weights of the different loss terms.
harmony_balance=False,
# (List) Losses whose weights are dynamically balanced when harmony_balance is True.
# Losses not listed here keep their fixed weights as originally defined.
harmony_loss_names=['loss_obs', 'loss_rewards', 'loss_value', 'loss_policy'],
# ['loss_obs', 'loss_rewards', 'loss_value', 'loss_policy', 'orig_policy_loss', 'policy_entropy']
)

def default_model(self) -> Tuple[str, List[str]]:
@@ -290,12 +298,24 @@ def _init_learn(self) -> None:
Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils.
"""
# NOTE: nanoGPT optimizer
self.harmony_balance = self._cfg.harmony_balance
if self.harmony_balance:
    self.harmony_loss_names = self._cfg.harmony_loss_names
    assert self.harmony_loss_names is not None
    harmony_s_names = [f"{name}_s" for name in self.harmony_loss_names]
    for name in harmony_s_names:
        # Each coefficient starts at -log(1.0) = 0, i.e. an effective weight of 1 / exp(0) = 1.
        param = torch.nn.Parameter(-torch.log(torch.tensor(1.0)))
        setattr(self, name, param)

self.harmony_s_dict = {name: getattr(self, name) for name in harmony_s_names} if self.harmony_balance else None

self._optimizer_world_model = configure_optimizers_nanogpt(
model=self._model.world_model,
learning_rate=self._cfg.learning_rate,
weight_decay=self._cfg.weight_decay,
device_type=self._cfg.device,
betas=(0.9, 0.95),
additional_params=self.harmony_s_dict
)

# use model_wrapper for specialized demands of different modes
@@ -409,10 +429,12 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in

# Update world model
losses = self._learn_model.world_model.compute_loss(
batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle
batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle, harmony_s_dict=self.harmony_s_dict
)


# else:
Collaborator review comment: delete this leftover "# else:" line.

weighted_total_loss = losses.loss_total

for loss_name, loss_value in losses.intermediate_losses.items():
self.intermediate_losses[f"{loss_name}"] = loss_value

@@ -430,7 +452,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
dormant_ratio_encoder = self.intermediate_losses['dormant_ratio_encoder']
dormant_ratio_world_model = self.intermediate_losses['dormant_ratio_world_model']
latent_state_l2_norms = self.intermediate_losses['latent_state_l2_norms']



assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values"
assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values"

@@ -516,9 +539,13 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
'analysis/grad_norm_before': self.grad_norm_before,
'analysis/grad_norm_after': self.grad_norm_after,
}

if self.harmony_balance:
    harmony_s_dict_monitor = {k: v.item() for k, v in self.harmony_s_dict.items()}
    harmony_s_exp_recip_dict_monitor = {f"{k}_exp_recip": torch.reciprocal(torch.exp(v)).item() for k, v in self.harmony_s_dict.items()}
    return_loss_dict.update(harmony_s_dict_monitor)
    return_loss_dict.update(harmony_s_exp_recip_dict_monitor)
return return_loss_dict

def monitor_weights_and_grads(self, model):
for name, param in model.named_parameters():
if param.requires_grad:
@@ -862,7 +889,7 @@ def _monitor_vars_learn(self) -> List[str]:
Register the variables to be monitored in learn mode. The registered variables will be logged in
tensorboard according to the return value ``_forward_learn``.
"""
return [
return_list = [
'analysis/dormant_ratio_encoder',
'analysis/dormant_ratio_world_model',
'analysis/latent_state_l2_norms',
Expand Down Expand Up @@ -912,6 +939,12 @@ def _monitor_vars_learn(self) -> List[str]:
'reconstruction_loss',
'perceptual_loss',
]
if self.harmony_balance:
    harmony_s_monitor_list = self.harmony_s_dict.keys()
    harmony_s_exp_recip_monitor_list = [f"{k}_exp_recip" for k in self.harmony_s_dict.keys()]
    return_list.extend(harmony_s_monitor_list)
    return_list.extend(harmony_s_exp_recip_monitor_list)
return return_list

def _state_dict_learn(self) -> Dict[str, Any]:
"""
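To see how the policy-side pieces fit together, here is a sketch: the config fields come from the defaults in the diff above, while the concrete values and variable names are illustrative:

import torch

# 1) Enable the feature in the policy config.
policy_cfg = dict(
    harmony_balance=True,
    harmony_loss_names=['loss_obs', 'loss_rewards', 'loss_value', 'loss_policy'],
)

# 2) _init_learn creates one learnable scalar per listed loss, initialized to
#    -log(1.0) = 0, so every loss starts with an effective weight of 1 / exp(0) = 1.
harmony_s_dict = {
    f"{name}_s": torch.nn.Parameter(-torch.log(torch.tensor(1.0)))
    for name in policy_cfg['harmony_loss_names']
}

# 3) These parameters are passed to the optimizer via additional_params and to
#    compute_loss via harmony_s_dict; _forward_learn then logs both s and its
#    effective weight 1 / exp(s) under '<name>_s' and '<name>_s_exp_recip'.
effective_weights = {k: torch.reciprocal(torch.exp(v)).item() for k, v in harmony_s_dict.items()}
print(effective_weights)  # {'loss_obs_s': 1.0, 'loss_rewards_s': 1.0, ...}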
12 changes: 11 additions & 1 deletion lzero/policy/utils.py
@@ -200,23 +200,33 @@ def forward(self, input):


# modified from https://github.com/karpathy/nanoGPT/blob/master/model.py#L263
def configure_optimizers_nanogpt(model, weight_decay, learning_rate, betas, device_type):
def configure_optimizers_nanogpt(model, weight_decay, learning_rate, betas, device_type, additional_params=None):
# start with all of the candidate parameters
param_dict = {pn: p for pn, p in model.named_parameters()}
# filter out those that do not require grad
param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}

# Add additional parameters to the param_dict if provided
if additional_params is not None:
    for name, param in additional_params.items():
        if param.requires_grad:
            param_dict[name] = param

# Create optim groups. Any parameter that is 2D will be weight-decayed, otherwise not.
# i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]

optim_groups = [
{'params': decay_params, 'weight_decay': weight_decay},
{'params': nodecay_params, 'weight_decay': 0.0}
]

num_decay_params = sum(p.numel() for p in decay_params)
num_nodecay_params = sum(p.numel() for p in nodecay_params)
print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")

# Create AdamW optimizer and use the fused version if it is available
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
if torch.cuda.is_available():
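Because the harmony coefficients are 0-dim parameters (dim < 2), the grouping logic above places them in the weight_decay=0.0 group alongside biases and LayerNorm parameters, which is typically what you want for loss-balancing scalars. A quick illustrative sketch (the stand-in model and values are placeholders, not from the PR):

import torch
from lzero.policy.utils import configure_optimizers_nanogpt

model = torch.nn.Linear(4, 4)  # stand-in for the world model
harmony_s_dict = {'loss_obs_s': torch.nn.Parameter(torch.zeros(()))}

optimizer = configure_optimizers_nanogpt(
    model=model,
    weight_decay=0.1,
    learning_rate=1e-4,
    betas=(0.9, 0.95),
    device_type='cpu',
    additional_params=harmony_s_dict,
)
# The 4x4 weight (dim == 2) lands in the decayed group; the bias and the 0-dim
# harmony scalar land in the no-decay group.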
6 changes: 3 additions & 3 deletions zoo/atari/config/atari_unizero_config.py
@@ -1,13 +1,13 @@
from easydict import EasyDict
from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map

env_id = 'PongNoFrameskip-v4' # You can specify any Atari game here
env_id = 'QbertNoFrameskip-v4' # You can specify any Atari game here
# 'QbertNoFrameskip-v4'
action_space_size = atari_env_action_space_map[env_id]

# ==============================================================
# begin of the most frequently changed config specified by the user
# ==============================================================
update_per_collect = None
update_per_collect = 1000
replay_ratio = 0.25
collector_env_num = 8
n_episode = 8