4 changes: 1 addition & 3 deletions lzero/entry/utils.py
@@ -6,8 +6,6 @@
import torch.distributed as dist
from pympler.asizeof import asizeof
from tensorboardX import SummaryWriter


import torch
import torch.distributed as dist

@@ -139,7 +137,7 @@ def initialize_pad_batch(observation_shape: Union[int, List[int], Tuple[int]], b
else:
raise TypeError(f"observation_shape must be int, list, or tuple, but got {type(observation_shape).__name__}")

return torch.full(shape, fill_value=pad_token_id, dtype=torch.long, device=device)
return torch.full(shape, fill_value=pad_token_id, dtype=torch.float32, device=device) if pad_token_id == 0 else torch.full(shape, fill_value=pad_token_id, dtype=torch.long, device=device)

def random_collect(
policy_cfg: 'EasyDict', # noqa
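For context, a minimal sketch of what the updated return statement does end to end. The shape-handling branches mirror the hunk above; the function signature is truncated in the hunk header, so every parameter name after observation_shape is an assumption here, not taken from the source.

from typing import List, Tuple, Union

import torch


def initialize_pad_batch_sketch(
    observation_shape: Union[int, List[int], Tuple[int]],
    batch_size: int,            # assumed name; the hunk header is cut off after 'b'
    pad_token_id: int = 0,      # assumed name
    device: str = "cpu",        # assumed name
) -> torch.Tensor:
    # Build the batch shape from the observation shape, as in the original helper.
    if isinstance(observation_shape, int):
        shape = (batch_size, observation_shape)
    elif isinstance(observation_shape, (list, tuple)):
        shape = (batch_size, *observation_shape)
    else:
        raise TypeError(f"observation_shape must be int, list, or tuple, but got {type(observation_shape).__name__}")
    # The new return keys the dtype on the pad value: a zero pad is treated as a float
    # observation filler (float32), while a non-zero pad is treated as a token id (long).
    dtype = torch.float32 if pad_token_id == 0 else torch.long
    return torch.full(shape, fill_value=pad_token_id, dtype=dtype, device=device)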
2 changes: 1 addition & 1 deletion lzero/model/unizero_model.py
@@ -4,7 +4,7 @@
import torch.nn as nn
from ding.utils import MODEL_REGISTRY, SequenceType
from easydict import EasyDict
from transformers import T5ForConditionalGeneration, T5Tokenizer
# from transformers import T5ForConditionalGeneration, T5Tokenizer

from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \
VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook, \
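Commenting out the module-level import means the transformers package is no longer pulled in just by importing unizero_model. If the T5 path ever needs to come back, a deferred import along these lines would keep that dependency optional; this helper is purely illustrative and not part of the diff.

def load_t5(model_name: str = "t5-small"):
    # Import transformers only when a T5 encoder is actually requested,
    # so environments without the package can still import this module.
    from transformers import T5ForConditionalGeneration, T5Tokenizer
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    return tokenizer, model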
15 changes: 11 additions & 4 deletions lzero/policy/unizero.py
@@ -113,8 +113,17 @@ class UniZeroPolicy(MuZeroPolicy):
perceptual_loss_weight=0.,
# (float) The weight of the policy entropy loss.
policy_entropy_weight=0,
# (str) The type of loss for predicting latent variables. Options could be ['group_kl', 'mse'].
predict_latent_loss_type='group_kl',
# (str) The normalization type for the final layer in both the head and the encoder.
# This option must be the same for both 'final_norm_option_in_head' and 'final_norm_option_in_encoder'.
# Valid options are 'LayerNorm' and 'SimNorm'.
# When set to 'LayerNorm', the 'predict_latent_loss_type' should be 'mse'.
# When set to 'SimNorm', the 'predict_latent_loss_type' should be 'group_kl'.
final_norm_option_in_head="LayerNorm",
final_norm_option_in_encoder="LayerNorm",
# (str) The type of loss function for predicting latent variables.
# Options are 'mse' (Mean Squared Error) or 'group_kl' (Group Kullback-Leibler divergence).
# This choice is dependent on the normalization method selected above.
predict_latent_loss_type='mse',
# (str) The type of observation. Options are ['image', 'vector'].
obs_type='image',
# (float) The discount factor for future rewards.
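
The comments added above couple the final-norm choice to the latent-prediction loss. A small illustrative summary of the two valid pairings, together with the kind of consistency check a config could run (the check itself is an assumption, not something the diff adds):

# Pairing 1 (the new defaults in this diff): LayerNorm latents with an MSE latent loss.
layernorm_cfg = dict(
    final_norm_option_in_head='LayerNorm',
    final_norm_option_in_encoder='LayerNorm',
    predict_latent_loss_type='mse',
)

# Pairing 2: SimNorm latents with a group-KL latent loss.
simnorm_cfg = dict(
    final_norm_option_in_head='SimNorm',
    final_norm_option_in_encoder='SimNorm',
    predict_latent_loss_type='group_kl',
)

EXPECTED_LOSS = {'LayerNorm': 'mse', 'SimNorm': 'group_kl'}

def check_latent_config(cfg: dict) -> None:
    # Illustrative sanity check for the pairing rules described in the config comments.
    assert cfg['final_norm_option_in_head'] == cfg['final_norm_option_in_encoder'], \
        'head and encoder must use the same final norm'
    assert cfg['predict_latent_loss_type'] == EXPECTED_LOSS[cfg['final_norm_option_in_head']], \
        'predict_latent_loss_type must match the chosen final norm'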
@@ -345,8 +354,6 @@ def _init_learn(self) -> None:
)
self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device)
self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device)
assert self.value_support.size == self._learn_model.value_support_size # if these assertions fail, somebody introduced...
assert self.reward_support.size == self._learn_model.reward_support_size # ...incoherence between policy and model
self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution)
self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution)

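For reference, the inverse scalar transform that these handles wrap converts the categorical output over a DiscreteSupport back into a scalar by taking the expectation over the support values. A simplified, self-contained sketch follows; it omits the invertible value rescaling that LightZero's handle also undoes, and the names below are illustrative.

import torch


def expected_scalar(logits: torch.Tensor, support: torch.Tensor) -> torch.Tensor:
    # logits: (batch, support_size) categorical logits from the value or reward head.
    # support: (support_size,) discrete support values, e.g. torch.arange(-5, 6).
    probs = torch.softmax(logits, dim=-1)
    return (probs * support).sum(dim=-1)


# Usage with an assumed symmetric support of size 11:
support = torch.arange(-5, 6, dtype=torch.float32)
logits = torch.randn(4, support.numel())
values = expected_scalar(logits, support)  # shape: (4,)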