@@ -113,8 +113,17 @@ class UniZeroPolicy(MuZeroPolicy):
        perceptual_loss_weight=0.,
        # (float) The weight of the policy entropy loss.
        policy_entropy_weight=0,
-        # (str) The type of loss for predicting latent variables. Options could be ['group_kl', 'mse'].
-        predict_latent_loss_type='group_kl',
+        # (str) The normalization type for the final layer in both the head and the encoder.
+        # This option must be the same for both 'final_norm_option_in_head' and 'final_norm_option_in_encoder'.
+        # Valid options are 'LayerNorm' and 'SimNorm'.
+        # When set to 'LayerNorm', the 'predict_latent_loss_type' should be 'mse'.
+        # When set to 'SimNorm', the 'predict_latent_loss_type' should be 'group_kl'.
+        final_norm_option_in_head="LayerNorm",
+        final_norm_option_in_encoder="LayerNorm",
+        # (str) The type of loss function for predicting latent variables.
+        # Options are 'mse' (Mean Squared Error) or 'group_kl' (Group Kullback-Leibler divergence).
+        # This choice is dependent on the normalization method selected above.
+        predict_latent_loss_type='mse',
        # (str) The type of observation. Options are ['image', 'vector'].
        obs_type='image',
        # (float) The discount factor for future rewards.
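The comment block added above couples two settings: the final-layer normalization ('LayerNorm' or 'SimNorm') and the latent-prediction loss ('mse' or 'group_kl'). Below is a minimal, self-contained sketch of that pairing rule; the helper name is hypothetical and not part of LightZero:

# Illustrative check (not part of the library) of the norm/loss pairing described above.
def check_latent_loss_coupling(final_norm_option: str, predict_latent_loss_type: str) -> None:
    # 'LayerNorm' latents are regressed with MSE; 'SimNorm' latents use a group KL loss.
    expected = {'LayerNorm': 'mse', 'SimNorm': 'group_kl'}
    assert final_norm_option in expected, f"unknown norm option: {final_norm_option}"
    assert predict_latent_loss_type == expected[final_norm_option], (
        f"'{final_norm_option}' expects predict_latent_loss_type="
        f"'{expected[final_norm_option]}', got '{predict_latent_loss_type}'"
    )

check_latent_loss_coupling("LayerNorm", "mse")      # the new defaults in this diff
check_latent_loss_coupling("SimNorm", "group_kl")   # the previous pairing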
@@ -345,8 +354,6 @@ def _init_learn(self) -> None:
        )
        self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device)
        self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device)
-        assert self.value_support.size == self._learn_model.value_support_size  # if these assertions fails, somebody introduced...
-        assert self.reward_support.size == self._learn_model.reward_support_size  # ...incoherence between policy and model
        self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution)
        self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution)