Skip to content

Commit feb6a01

Browse files
author
puyuan
committed
polish(pu): polish weight decay and add latent_norm_loss
1 parent 961f3be commit feb6a01

File tree

4 files changed

+86
-24
lines changed

4 files changed

+86
-24
lines changed

lzero/model/unizero_world_models/utils.py

Lines changed: 18 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -257,24 +257,25 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, continu
257257
if not kwargs:
258258
raise ValueError("At least one loss must be provided")
259259

260+
260261
# Get a reference device from one of the provided losses
261262
device = next(iter(kwargs.values())).device
262263

263264
# NOTE: Define the weights for each loss type
264265
if not continuous_action_space:
265266
# orig, for atari and memory
266-
# self.obs_loss_weight = 10
267-
# self.value_loss_weight = 0.5
268-
# self.reward_loss_weight = 1.
269-
# self.policy_loss_weight = 1.
270-
# self.ends_loss_weight = 0.
267+
self.obs_loss_weight = 10
268+
self.value_loss_weight = 0.5
269+
self.reward_loss_weight = 1.
270+
self.policy_loss_weight = 1.
271+
self.ends_loss_weight = 0.
271272

272273
# muzero loss weight
273-
self.obs_loss_weight = 2
274-
self.value_loss_weight = 0.25
275-
self.reward_loss_weight = 1
276-
self.policy_loss_weight = 1
277-
self.ends_loss_weight = 0.
274+
# self.obs_loss_weight = 2
275+
# self.value_loss_weight = 0.25
276+
# self.reward_loss_weight = 1
277+
# self.policy_loss_weight = 1
278+
# self.ends_loss_weight = 0.
278279

279280
# EZV2, for atari and memory
280281
# self.obs_loss_weight = 5
@@ -297,6 +298,11 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, continu
297298
# self.reward_loss_weight = 0.1
298299
# self.ends_loss_weight = 0.
299300

301+
# TODO(pu)
302+
# self.latent_norm_loss_weight = 0.1
303+
self.latent_norm_loss_weight = 0.01
304+
305+
300306
self.latent_recon_loss_weight = latent_recon_loss_weight
301307
self.perceptual_loss_weight = perceptual_loss_weight
302308

@@ -317,6 +323,8 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, continu
317323
self.loss_total += self.latent_recon_loss_weight * v
318324
elif k == 'perceptual_loss':
319325
self.loss_total += self.perceptual_loss_weight * v
326+
elif k == 'latent_norm_loss':
327+
self.loss_total += self.latent_norm_loss_weight * v
320328

321329
self.intermediate_losses = {
322330
k: v if isinstance(v, dict) or isinstance(v, torch.Tensor) else (v if isinstance(v, float) else v.item())

lzero/model/unizero_world_models/world_model.py

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1596,6 +1596,39 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
15961596
# Calculate the L2 norm of the latent action
15971597
latent_action_l2_norms = torch.norm(self.act_embedding_table(act_tokens), p=2, dim=2).mean()
15981598

1599+
if self.config.latent_norm_loss:
1600+
# ==================== L2 penalty loss computation (final fixed version v2) ====================
1601+
# 1. Compute the squared L2 norm of each latent_state vector.
1602+
# According to debug output, obs_embeddings has shape (B*L, 1, E),
1603+
# so latent_norm_sq has shape (B*L, 1).
1604+
latent_norm_sq = torch.norm(obs_embeddings, p=2, dim=-1).pow(2)
1605+
# 2. Fetch the source mask.
1606+
# According to debug output, mask_source has shape (B, L).
1607+
mask_source = batch['mask_padding']
1608+
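# 3. Reshape the source mask from (B, L) to (B*L, 1) so that it matches the shape of latent_norm_sq.
1609+
# This is the key to resolving the dimension-mismatch error.
1610+
# view(-1, 1) is used to perform this reshape.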
1611+
correct_mask = mask_source.contiguous().view(-1, 1)
1612+
# 4. Check that the shapes match after the reshape.
1613+
# This is a defensive check ensuring the first dimension of the two tensors is the same.
1614+
if latent_norm_sq.shape[0] != correct_mask.shape[0]:
1615+
# If the shapes do not match, raise an exception with a descriptive error message; this helps locate any new problems faster in the future.
1616+
raise RuntimeError(
1617+
f"Shape mismatch for L2 norm loss calculation! "
1618+
f"latent_norm_sq shape: {latent_norm_sq.shape}, "
1619+
f"but correct_mask shape after reshape is: {correct_mask.shape}. "
1620+
f"Original mask_source shape was: {mask_source.shape}"
1621+
)
1622+
# 5. Multiply element-wise. Both tensors now have shape (B*L, 1), so the multiplication is safe.
1623+
masked_latent_norm_sq = latent_norm_sq * correct_mask
1624+
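# 6. Compute the mean loss. The denominator is the sum of all ones in the mask, i.e. the number of valid elements.
1625+
# A tiny epsilon (1e-8) is added to prevent division by zero.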
1626+
latent_norm_loss = masked_latent_norm_sq.sum() / (correct_mask.sum() + 1e-8)
1627+
# =================================================================
1628+
else:
1629+
latent_norm_loss = torch.tensor(0.)
1630+
1631+
15991632
# Forward pass to obtain predictions for observations, rewards, and policies
16001633
outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, start_pos=start_pos)
16011634

@@ -1849,6 +1882,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
18491882
policy_mu=mu,
18501883
policy_sigma=sigma,
18511884
target_sampled_actions=target_sampled_actions,
1885+
latent_norm_loss=latent_norm_loss, # 新增
18521886
)
18531887
else:
18541888
return LossWithIntermediateLosses(
@@ -1870,6 +1904,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
18701904
dormant_ratio_world_model=dormant_ratio_world_model,
18711905
latent_state_l2_norms=latent_state_l2_norms,
18721906
latent_action_l2_norms=latent_action_l2_norms,
1907+
latent_norm_loss=latent_norm_loss, # 新增
18731908

18741909
)
18751910

zoo/atari/config/atari_muzero_segment_config.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -30,8 +30,10 @@ def main(env_id, seed):
3030

3131
# Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs.
3232
# buffer_reanalyze_freq = 1
33+
buffer_reanalyze_freq = 1/2
34+
# buffer_reanalyze_freq = 1/10
3335
# buffer_reanalyze_freq = 1/50
34-
buffer_reanalyze_freq = 1/10000000000
36+
# buffer_reanalyze_freq = 1/10000000000
3537
# Each reanalyze process will reanalyze <reanalyze_batch_size> sequences (<cfg.policy.num_unroll_steps> transitions per sequence)
3638
reanalyze_batch_size = 160
3739
# The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer.
@@ -165,16 +167,16 @@ def main(env_id, seed):
165167
parser.add_argument('--seed', type=int, help='The seed to use', default=0)
166168
args = parser.parse_args()
167169

168-
args.env = 'MsPacmanNoFrameskip-v4'
169-
# args.env = 'QbertNoFrameskip-v4'
170+
# args.env = 'MsPacmanNoFrameskip-v4'
171+
args.env = 'QbertNoFrameskip-v4'
170172
# args.env = 'SeaquestNoFrameskip-v4'
171173
# args.env = 'BreakoutNoFrameskip-v4'
172174

173175
args.seed = 0
174176
main(args.env, args.seed)
175177

176178
"""
177-
export CUDA_VISIBLE_DEVICES=4
179+
export CUDA_VISIBLE_DEVICES=3
178180
cd /fs-computility/niuyazhe/puyuan/code/LightZero
179181
python /fs-computility/niuyazhe/puyuan/code/LightZero/zoo/atari/config/atari_muzero_segment_config.py
180182
"""

zoo/atari/config/atari_unizero_segment_config.py

Lines changed: 27 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -9,13 +9,13 @@ def main(env_id, seed):
99
# ==============================================================
1010
# begin of the most frequently changed config specified by the user
1111
# ==============================================================
12-
collector_env_num = 8
13-
num_segments = 8
14-
evaluator_env_num = 3
12+
# collector_env_num = 8
13+
# num_segments = 8
14+
# evaluator_env_num = 3
1515

16-
# collector_env_num = 1
17-
# num_segments = 1
18-
# evaluator_env_num = 1
16+
collector_env_num = 1
17+
num_segments = 1
18+
evaluator_env_num = 1
1919

2020
num_simulations = 50
2121
collect_num_simulations = 25
@@ -25,6 +25,8 @@ def main(env_id, seed):
2525
max_env_step = int(50e6)
2626
batch_size = 256
2727
# batch_size = 64 # debug
28+
# batch_size = 4 # debug
29+
2830
num_layers = 2
2931
# replay_ratio = 0.25
3032
replay_ratio = 0.1
@@ -33,6 +35,10 @@ def main(env_id, seed):
3335
num_unroll_steps = 10
3436
infer_context_length = 4
3537

38+
# game_segment_length = 40
39+
# num_unroll_steps = 20
40+
# infer_context_length = 8
41+
3642
# game_segment_length = 200
3743
# num_unroll_steps = 16
3844
# infer_context_length = 8
@@ -93,6 +99,8 @@ def main(env_id, seed):
9399
norm_type=norm_type,
94100
num_res_blocks=2,
95101
num_channels=128,
102+
# num_res_blocks=1, # TODO
103+
# num_channels=64,
96104
support_size=601,
97105
policy_entropy_weight=5e-3,
98106
# policy_entropy_weight=5e-2, # TODO(pu)
@@ -125,6 +133,13 @@ def main(env_id, seed):
125133
# final_norm_option_in_encoder="SimNorm",
126134
# final_norm_option_in_obs_head="SimNorm",
127135
# predict_latent_loss_type='group_kl',
136+
137+
# weight_decay=1e-2,
138+
# latent_norm_loss=True,
139+
140+
latent_norm_loss=False,
141+
weight_decay=1e-4, # TODO
142+
128143
),
129144
),
130145
# gradient_scale=True, #TODO
@@ -160,8 +175,8 @@ def main(env_id, seed):
160175
grad_clip_value=5,
161176
replay_buffer_size=int(1e6),
162177
# eval_freq=int(5e3),
163-
# eval_freq=int(1e4),
164-
eval_freq=int(2e4),
178+
eval_freq=int(1e4), # TODO
179+
# eval_freq=int(2e4),
165180
collector_env_num=collector_env_num,
166181
evaluator_env_num=evaluator_env_num,
167182
# ============= The key different params for reanalyze =============
@@ -193,8 +208,10 @@ def main(env_id, seed):
193208
# ============ use muzero_segment_collector instead of muzero_collector =============
194209
from lzero.entry import train_unizero_segment
195210

196-
main_config.exp_name = f'data_unizero_longrun_20250819/{env_id[:-14]}/{env_id[:-14]}_uz_fix-init-recur_clear20_mulossweight_spsi20_envnum{collector_env_num}_encoder-head-ln_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}'
211+
# main_config.exp_name = f'data_unizero_longrun_20250827/{env_id[:-14]}/{env_id[:-14]}_uz_wd1e-2_fix-init-recur_clear{game_segment_length}_originlossweight_spsi{game_segment_length}_envnum{collector_env_num}_encoder-head-ln_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}'
212+
main_config.exp_name = f'data_unizero_longrun_20250827/{env_id[:-14]}/{env_id[:-14]}_uz_lnlw001_fix-init-recur_clear{game_segment_length}_originlossweight_spsi{game_segment_length}_envnum{collector_env_num}_encoder-head-ln_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}'
197213

214+
# main_config.exp_name = f'data_unizero_longrun_20250819/{env_id[:-14]}/{env_id[:-14]}_uz_fix-init-recur_clear{game_segment_length}_mulossweight_spsi{game_segment_length}_envnum{collector_env_num}_encoder-head-ln_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}'
198215

199216
# main_config.exp_name = f'data_unizero_longrun_20250819/{env_id[:-14]}/{env_id[:-14]}_uz_fix-init-recur_clear20_origlossweight_spsi20_envnum{collector_env_num}_encoder-head-l2norm_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}'
200217

@@ -256,7 +273,7 @@ def main(env_id, seed):
256273
main(args.env, args.seed)
257274

258275
"""
259-
export CUDA_VISIBLE_DEVICES=0
276+
export CUDA_VISIBLE_DEVICES=6
260277
cd /fs-computility/niuyazhe/puyuan/code/LightZero
261278
python /fs-computility/niuyazhe/puyuan/code/LightZero/zoo/atari/config/atari_unizero_segment_config.py
262279
"""

0 commit comments

Comments
 (0)