
Commit bcb9a83

fix(pu): fix pad dtype bug
1 parent 5c412bb commit bcb9a83

File tree: 3 files changed (+13, −8 lines)

  lzero/entry/utils.py
  lzero/model/unizero_model.py
  lzero/policy/unizero.py

lzero/entry/utils.py
Lines changed: 1 addition & 3 deletions

@@ -6,8 +6,6 @@
 import torch.distributed as dist
 from pympler.asizeof import asizeof
 from tensorboardX import SummaryWriter
-
-
 import torch
 import torch.distributed as dist

@@ -139,7 +137,7 @@ def initialize_pad_batch(observation_shape: Union[int, List[int], Tuple[int]], b
     else:
         raise TypeError(f"observation_shape must be int, list, or tuple, but got {type(observation_shape).__name__}")

-    return torch.full(shape, fill_value=pad_token_id, dtype=torch.long, device=device)
+    return torch.full(shape, fill_value=pad_token_id, dtype=torch.float32, device=device) if pad_token_id == 0 else torch.full(shape, fill_value=pad_token_id, dtype=torch.long, device=device)

 def random_collect(
         policy_cfg: 'EasyDict',  # noqa
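
Note: the fix dispatches on the pad value. A pad_token_id of 0 now yields a float32 batch (padding for continuous observations), while any nonzero pad_token_id keeps the long dtype (padding for discrete token ids). A minimal sketch of the resulting behavior; the helper name and shapes below are illustrative, not from the repository:

import torch

def pad_batch_sketch(shape, pad_token_id, device='cpu'):
    # Mirrors the fixed return statement: pad id 0 -> float32 padding
    # (continuous observations); any other pad id -> long padding (token ids).
    if pad_token_id == 0:
        return torch.full(shape, fill_value=pad_token_id, dtype=torch.float32, device=device)
    return torch.full(shape, fill_value=pad_token_id, dtype=torch.long, device=device)

assert pad_batch_sketch((2, 4), 0).dtype == torch.float32
assert pad_batch_sketch((2, 4), -100).dtype == torch.long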

lzero/model/unizero_model.py
Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@
 import torch.nn as nn
 from ding.utils import MODEL_REGISTRY, SequenceType
 from easydict import EasyDict
-from transformers import T5ForConditionalGeneration, T5Tokenizer
+# from transformers import T5ForConditionalGeneration, T5Tokenizer

 from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \
     VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook, \
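
Note: commenting out the module-level T5 import removes a hard dependency on transformers when unizero_model is imported. If the T5 components are needed again, one option (an assumption, not part of this commit) is a lazy import inside the function that uses them:

def load_t5(model_name: str = "t5-small"):
    # Hypothetical helper: import transformers only when T5 is actually requested,
    # so importing this module does not require the package to be installed.
    from transformers import T5ForConditionalGeneration, T5Tokenizer
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    return model, tokenizer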

lzero/policy/unizero.py
Lines changed: 11 additions & 4 deletions

@@ -113,8 +113,17 @@ class UniZeroPolicy(MuZeroPolicy):
         perceptual_loss_weight=0.,
         # (float) The weight of the policy entropy loss.
         policy_entropy_weight=0,
-        # (str) The type of loss for predicting latent variables. Options could be ['group_kl', 'mse'].
-        predict_latent_loss_type='group_kl',
+        # (str) The normalization type for the final layer in both the head and the encoder.
+        # This option must be the same for both 'final_norm_option_in_head' and 'final_norm_option_in_encoder'.
+        # Valid options are 'LayerNorm' and 'SimNorm'.
+        # When set to 'LayerNorm', 'predict_latent_loss_type' should be 'mse'.
+        # When set to 'SimNorm', 'predict_latent_loss_type' should be 'group_kl'.
+        final_norm_option_in_head="LayerNorm",
+        final_norm_option_in_encoder="LayerNorm",
+        # (str) The type of loss function for predicting latent variables.
+        # Options are 'mse' (Mean Squared Error) or 'group_kl' (grouped Kullback-Leibler divergence).
+        # This choice depends on the normalization method selected above.
+        predict_latent_loss_type='mse',
         # (str) The type of observation. Options are ['image', 'vector'].
         obs_type='image',
         # (float) The discount factor for future rewards.

@@ -345,8 +354,6 @@ def _init_learn(self) -> None:
         )
         self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device)
         self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device)
-        assert self.value_support.size == self._learn_model.value_support_size  # if these assertions fail, someone introduced...
-        assert self.reward_support.size == self._learn_model.reward_support_size  # ...an incoherence between policy and model
         self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution)
         self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution)
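
Note: the new config comments pin down a coupling between the final normalization layer and the latent-prediction loss: 'LayerNorm' pairs with 'mse', 'SimNorm' pairs with 'group_kl', and the default loss moves from 'group_kl' to 'mse' to match the new 'LayerNorm' defaults. A hypothetical validation helper (not part of this commit) that enforces the documented pairing:

def check_latent_loss_config(final_norm_option: str, predict_latent_loss_type: str) -> None:
    # Encodes the pairing documented in the config comments above.
    expected = {'LayerNorm': 'mse', 'SimNorm': 'group_kl'}
    if expected.get(final_norm_option) != predict_latent_loss_type:
        raise ValueError(
            f"final_norm_option {final_norm_option!r} requires "
            f"predict_latent_loss_type {expected.get(final_norm_option)!r}, "
            f"got {predict_latent_loss_type!r}"
        )

check_latent_loss_config('LayerNorm', 'mse')  # consistent with the new defaults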
