
Commit 5c412bb

xiongjyu and puyuan authored
feature(xjy): add encoder_decoder_type option for jericho's world model (#391)
* Qwen is tested as a policy in the Jericho environment
* Fixed the bug that bad reflections could not be collected
* Support options for selecting the encoder/decoder
* Fixed a few bugs and standardized the format
* Standardized the format again

---------

Co-authored-by: puyuan <[email protected]>
1 parent c2eb518 commit 5c412bb

File tree

11 files changed: +347 −118 lines changed


lzero/entry/utils.py

Lines changed: 30 additions & 0 deletions
@@ -111,6 +111,36 @@ def initialize_zeros_batch(observation_shape: Union[int, List[int], Tuple[int]],
 
     return torch.zeros(shape).to(device)
 
+def initialize_pad_batch(observation_shape: Union[int, List[int], Tuple[int]], batch_size: int, device: str, pad_token_id: int = 0) -> torch.Tensor:
+    """
+    Overview:
+        Initialize a tensor filled with `pad_token_id` for batch observations.
+        This function is designed to be flexible and can handle both textual
+        and non-textual observations:
+
+        - For textual observations: it initializes `input_ids` with padding tokens,
+          ensuring consistent sequence lengths within a batch.
+        - For non-textual observations: it provides a convenient way to fill
+          observation tensors with a default of 0,
+          ensuring shape compatibility and preventing uninitialized values.
+    Arguments:
+        - observation_shape (:obj:`Union[int, List[int], Tuple[int]]`): The shape of the observation tensor.
+        - batch_size (:obj:`int`): The batch size.
+        - device (:obj:`str`): The device to store the tensor.
+        - pad_token_id (:obj:`int`): The token ID (or placeholder value) used for padding.
+    Returns:
+        - padded_tensor (:obj:`torch.Tensor`): A tensor of the given shape,
+          filled with `pad_token_id`.
+    """
+    if isinstance(observation_shape, (list, tuple)):
+        shape = [batch_size, *observation_shape]
+    elif isinstance(observation_shape, int):
+        shape = [batch_size, observation_shape]
+    else:
+        raise TypeError(f"observation_shape must be int, list, or tuple, but got {type(observation_shape).__name__}")
+
+    return torch.full(shape, fill_value=pad_token_id, dtype=torch.long, device=device)
+
 def random_collect(
         policy_cfg: 'EasyDict',  # noqa
         policy: 'Policy',  # noqa
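As a quick illustration of the new helper, a minimal sketch (the batch size, sequence length, and device below are arbitrary examples, not values from this commit):

import torch
from lzero.entry.utils import initialize_pad_batch

# A batch of 4 tokenized text observations, each padded to length 128.
obs = initialize_pad_batch(observation_shape=128, batch_size=4, device='cpu', pad_token_id=0)
assert obs.shape == (4, 128) and obs.dtype == torch.long  # every entry equals pad_token_id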

lzero/mcts/buffer/game_buffer.py

Lines changed: 6 additions & 3 deletions
@@ -156,19 +156,22 @@ def _sample_orig_data(self, batch_size: int) -> Tuple:
                # For some environments (e.g., Jericho), the action space size may be different.
                # To ensure we can always unroll `num_unroll_steps` steps starting from the sampled position (without exceeding segment length),
                # we avoid sampling from the last `num_unroll_steps` steps of the game segment.
-               if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps:
-                   pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps, 1).item()
+               if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps:
+                   pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps, 1).item()
+               if pos_in_game_segment >= len(game_segment.action_segment) - 1:
+                   pos_in_game_segment = np.random.choice(len(game_segment.action_segment) - 1, 1).item()
            else:
                # For environments with a fixed action space (e.g., Atari),
                # we can safely sample from the entire game segment range.
                if pos_in_game_segment >= self._cfg.game_segment_length:
                    pos_in_game_segment = np.random.choice(self._cfg.game_segment_length, 1).item()
+               if pos_in_game_segment >= len(game_segment.action_segment) - 1:
+                   pos_in_game_segment = np.random.choice(len(game_segment.action_segment) - 1, 1).item()
 
            pos_in_game_segment_list.append(pos_in_game_segment)
 
 
        make_time = [time.time() for _ in range(len(batch_index_list))]
-
        orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time)
        return orig_data
 
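In effect, the new bound reserves `td_steps` extra transitions beyond the unroll window and additionally clamps the sampled position to the actions actually stored in the segment. A minimal sketch of the same clamping logic with made-up numbers (illustrative only, not values from this commit):

import numpy as np

game_segment_length, num_unroll_steps, td_steps = 400, 10, 5  # illustrative config values
action_segment_len = 120  # number of actions actually collected in this segment
pos_in_game_segment = 395

# Keep num_unroll_steps + td_steps future steps available from the sampled position.
limit = game_segment_length - num_unroll_steps - td_steps
if pos_in_game_segment >= limit:
    pos_in_game_segment = np.random.choice(limit, 1).item()
# Also stay within the range of recorded actions.
if pos_in_game_segment >= action_segment_len - 1:
    pos_in_game_segment = np.random.choice(action_segment_len - 1, 1).item()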

lzero/mcts/tree_search/mcts_ctree.py

Lines changed: 1 addition & 1 deletion
@@ -185,7 +185,7 @@ def search(
                 current_latent_state_index, discount_factor, reward_batch, value_batch, policy_logits_batch,
                 min_max_stats_lst, results, virtual_to_play_batch
             )
-
+
         return first_action_latent_map
 

lzero/model/common.py

Lines changed: 113 additions & 2 deletions
@@ -15,12 +15,14 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.init as init
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from ding.torch_utils import MLP, ResBlock
 from ding.torch_utils.network.normalization import build_normalization
 from ding.utils import SequenceType
 from ditk import logging
 from ding.utils import set_pkg_seed, get_rank, get_world_size
-import torch
+
+
 
 def MLP_V2(
     in_channels: int,
@@ -361,6 +363,116 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         return output
 
+class QwenNetwork(nn.Module):
+    def __init__(self,
+                 model_path: str = 'Qwen/Qwen3-1.7B',
+                 embedding_size: int = 768,
+                 final_norm_option_in_encoder: str = "layernorm",
+                 group_size: int = 8,
+                 tokenizer=None):
+        super().__init__()
+
+        logging.info(f"Loading Qwen model from: {model_path}")
+
+        local_rank = get_rank()
+        if local_rank == 0:
+            self.pretrained_model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                torch_dtype="auto",
+                device_map={"": local_rank},
+                attn_implementation="flash_attention_2"
+            )
+        if get_world_size() > 1:
+            torch.distributed.barrier()
+        if local_rank != 0:
+            self.pretrained_model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                torch_dtype="auto",
+                device_map={"": local_rank},
+                attn_implementation="flash_attention_2"
+            )
+
+        for p in self.pretrained_model.parameters():
+            p.requires_grad = False
+
+        if tokenizer is None:
+            if local_rank == 0:
+                self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+            if get_world_size() > 1:
+                torch.distributed.barrier()
+            if local_rank != 0:
+                self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+        else:
+            self.tokenizer = tokenizer
+
+        qwen_hidden_size = self.pretrained_model.config.hidden_size
+
+        self.embedding_head = nn.Sequential(
+            nn.Linear(qwen_hidden_size, embedding_size),
+            self._create_norm_layer(final_norm_option_in_encoder, embedding_size, group_size)
+        )
+
+    def _create_norm_layer(self, norm_option, embedding_size, group_size):
+        if norm_option.lower() == "simnorm":
+            return SimNorm(simnorm_dim=group_size)
+        elif norm_option.lower() == "layernorm":
+            return nn.LayerNorm(embedding_size)
+        else:
+            raise NotImplementedError(f"Normalization type '{norm_option}' is not implemented.")
+
+    def encode(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
+        """
+        Overview:
+            Encode the input token sequence `x` into a latent representation
+            using a pretrained language model backbone followed by a projection head.
+        Arguments:
+            - x (:obj:`torch.Tensor`): Input token ids of shape (B, L).
+            - no_grad (:obj:`bool`, optional, default=True): If True, encoding is performed under `torch.no_grad()` to save memory and computation (no gradient tracking).
+        Returns:
+            - latent (:obj:`torch.Tensor`): Encoded latent state of shape (B, D).
+        """
+        pad_id = self.tokenizer.pad_token_id
+        attention_mask = (x != pad_id).long().to(x.device)
+        context = {'input_ids': x.long(), 'attention_mask': attention_mask}
+        if no_grad:
+            with torch.no_grad():
+                outputs = self.pretrained_model(**context, output_hidden_states=True, return_dict=True)
+        else:
+            outputs = self.pretrained_model(**context, output_hidden_states=True, return_dict=True)
+        last_hidden = outputs.hidden_states[-1]
+
+        B, L, H = last_hidden.size()
+        lengths = attention_mask.sum(dim=1)  # [B]
+        positions = torch.clamp(lengths - 1, min=0)  # [B]
+        batch_idx = torch.arange(B, device=last_hidden.device)
+
+        selected = last_hidden[batch_idx, positions]  # [B, H]
+
+        latent = self.embedding_head(selected.to(self.embedding_head[0].weight.dtype))
+        return latent
+
+    def decode(self, embeddings: torch.Tensor, max_length: int = 512) -> str:
+        """
+        Decodes embeddings into text via the decoder network.
+        """
+        embeddings_detached = embeddings.detach()
+        self.pretrained_model.eval()
+
+        # Directly generate using provided embeddings
+        with torch.no_grad():
+            param = next(self.pretrained_model.parameters())
+            embeddings = embeddings_detached.to(device=param.device, dtype=param.dtype)
+            gen_ids = self.pretrained_model.generate(
+                inputs_embeds=embeddings,
+                max_length=max_length
+            )
+        texts = self.tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
+        self.pretrained_model.train()
+        return texts[0] if len(texts) == 1 else texts
+
+    def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
+        return self.encode(x, no_grad=no_grad)
+
 
 class HFLanguageRepresentationNetwork(nn.Module):
     def __init__(self,
@@ -542,7 +654,6 @@ def __init__(
         else:
             raise ValueError(f"Unsupported final_norm_option_in_encoder: {self.final_norm_option_in_encoder}")
 
-
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Shapes:
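A rough usage sketch for the new encoder (a sketch only: it assumes the Qwen/Qwen3-1.7B weights are downloadable, a GPU that supports flash_attention_2, and an illustrative prompt):

import torch
from lzero.model.common import QwenNetwork

net = QwenNetwork(model_path='Qwen/Qwen3-1.7B', embedding_size=768)
device = next(net.pretrained_model.parameters()).device
tokens = net.tokenizer(["You are standing in an open field west of a white house."],
                       return_tensors='pt', padding=True)['input_ids'].to(device)
latent = net(tokens)  # forward() delegates to encode(); output shape (1, 768), backbone stays frozen

Note that decode() feeds embeddings straight into generate(inputs_embeds=...), so it expects vectors in the backbone's hidden size rather than the projected embedding_size; in unizero_model.py the Tokenizer's projection ([embed_dim, hidden_size]) handles that mapping before decoding.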

lzero/model/unizero_model.py

Lines changed: 32 additions & 16 deletions
@@ -8,7 +8,7 @@
 
 from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \
     VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook, \
-    HFLanguageRepresentationNetwork
+    HFLanguageRepresentationNetwork, QwenNetwork
 from .unizero_world_models.tokenizer import Tokenizer
 from .unizero_world_models.world_model import WorldModel
 from ding.utils import ENV_REGISTRY, set_pkg_seed, get_rank, get_world_size
@@ -96,21 +96,37 @@ def __init__(
             print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder')
             print('==' * 20)
         elif world_model_cfg.obs_type == 'text':
-            self.representation_network = HFLanguageRepresentationNetwork(model_path=kwargs['encoder_url'], embedding_size=world_model_cfg.embed_dim, final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder)
-            # print(self.representation_network.model.encoder.layer[0].attention.output.LayerNorm.weight)
-
-            if self.rank == 0:
-                self.decoder_network = T5ForConditionalGeneration.from_pretrained("t5-small")
-                self.decoder_network_tokenizer = T5Tokenizer.from_pretrained("t5-small")
-            if self.world_size > 1:
-                # Wait until rank 0 finishes loading the tokenizer
-                torch.distributed.barrier()
-            if self.rank != 0:
-                self.decoder_network = T5ForConditionalGeneration.from_pretrained("t5-small")
-                self.decoder_network_tokenizer = T5Tokenizer.from_pretrained("t5-small")
-
-            projection = [self.representation_network.pretrained_model.config.hidden_size, self.decoder_network.config.d_model]
-            self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=self.decoder_network, decoder_network_tokenizer=self.decoder_network_tokenizer, with_lpips=False, projection=projection)
+            if kwargs['encoder_option'] == 'legacy':
+                self.representation_network = HFLanguageRepresentationNetwork(model_path=kwargs['encoder_url'], embedding_size=world_model_cfg.embed_dim, final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder)
+                if world_model_cfg.decode_loss_mode is None or world_model_cfg.decode_loss_mode.lower() == 'none':
+                    self.decoder_network = None
+                    self.decoder_network_tokenizer = None
+                    projection = None
+                else:
+                    if self.rank == 0:
+                        self.decoder_network = T5ForConditionalGeneration.from_pretrained("t5-small")
+                        self.decoder_network_tokenizer = T5Tokenizer.from_pretrained("t5-small")
+                    if self.world_size > 1:
+                        # Wait until rank 0 finishes loading the tokenizer
+                        torch.distributed.barrier()
+                    if self.rank != 0:
+                        self.decoder_network = T5ForConditionalGeneration.from_pretrained("t5-small")
+                        self.decoder_network_tokenizer = T5Tokenizer.from_pretrained("t5-small")
+                    projection = [world_model_cfg.embed_dim, self.decoder_network.config.d_model]
+            elif kwargs['encoder_option'] == 'qwen':
+                self.representation_network = QwenNetwork(model_path=kwargs['encoder_url'], embedding_size=world_model_cfg.embed_dim, final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder)
+                if world_model_cfg.decode_loss_mode is None or world_model_cfg.decode_loss_mode.lower() == 'none':
+                    self.decoder_network = None
+                    self.decoder_network_tokenizer = None
+                    projection = None
+                else:
+                    projection = [world_model_cfg.embed_dim, self.representation_network.pretrained_model.config.hidden_size]
+                    self.decoder_network = self.representation_network
+                    self.decoder_network_tokenizer = None
+            else:
+                raise ValueError(f"Unsupported encoder option: {kwargs['encoder_option']}")
+            self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=self.decoder_network, decoder_network_tokenizer=self.decoder_network_tokenizer,
+                                       with_lpips=False, projection=projection, encoder_option=kwargs['encoder_option'])
             self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer)
             print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model')
             print('==' * 20)

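The choice between the two text encoders is driven by `kwargs['encoder_option']` and `kwargs['encoder_url']`. A hedged sketch of the relevant keys for a Jericho text-observation setup (the key names follow this diff; the concrete values and the surrounding config structure are assumptions):

# Illustrative fragment only; exact nesting depends on the Jericho config files.
model_kwargs = dict(
    encoder_option='qwen',          # 'legacy' -> HFLanguageRepresentationNetwork, 'qwen' -> QwenNetwork
    encoder_url='Qwen/Qwen3-1.7B',  # forwarded as model_path to the chosen encoder
)
world_model_cfg = dict(
    obs_type='text',
    embed_dim=768,
    final_norm_option_in_encoder='layernorm',
    decode_loss_mode=None,          # None or 'none' -> no decoder network / projection is built
)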