 import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.init as init
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from ding.torch_utils import MLP, ResBlock
 from ding.torch_utils.network.normalization import build_normalization
 from ding.utils import SequenceType
 from ditk import logging
 from ding.utils import set_pkg_seed, get_rank, get_world_size
-import torch
+
+

 def MLP_V2(
     in_channels: int,
@@ -361,6 +363,116 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

         return output

+class QwenNetwork(nn.Module):
+    def __init__(self,
+                 model_path: str = 'Qwen/Qwen3-1.7B',
+                 embedding_size: int = 768,
+                 final_norm_option_in_encoder: str = "layernorm",
+                 group_size: int = 8,
+                 tokenizer=None):
+        super().__init__()
+
+        logging.info(f"Loading Qwen model from: {model_path}")
+
+        local_rank = get_rank()
+        if local_rank == 0:
+            self.pretrained_model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                torch_dtype="auto",
+                device_map={"": local_rank},
+                attn_implementation="flash_attention_2"
+            )
+        if get_world_size() > 1:
+            torch.distributed.barrier()
+        if local_rank != 0:
+            self.pretrained_model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                torch_dtype="auto",
+                device_map={"": local_rank},
+                attn_implementation="flash_attention_2"
+            )
+
+        for p in self.pretrained_model.parameters():
+            p.requires_grad = False
+
+        if tokenizer is None:
+            if local_rank == 0:
+                self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+            if get_world_size() > 1:
+                torch.distributed.barrier()
+            if local_rank != 0:
+                self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+        else:
+            self.tokenizer = tokenizer
+
+        qwen_hidden_size = self.pretrained_model.config.hidden_size
+
+        self.embedding_head = nn.Sequential(
+            nn.Linear(qwen_hidden_size, embedding_size),
+            self._create_norm_layer(final_norm_option_in_encoder, embedding_size, group_size)
+        )
+
+    def _create_norm_layer(self, norm_option, embedding_size, group_size):
+        if norm_option.lower() == "simnorm":
+            return SimNorm(simnorm_dim=group_size)
+        elif norm_option.lower() == "layernorm":
+            return nn.LayerNorm(embedding_size)
+        else:
+            raise NotImplementedError(f"Normalization type '{norm_option}' is not implemented.")
+
+    def encode(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
+        """
+        Overview:
+            Encode the input token sequence `x` into a latent representation
+            using a pretrained language model backbone followed by a projection head.
+        Arguments:
+            - x (:obj:`torch.Tensor`): Input token ids of shape (B, L)
+            - no_grad (:obj:`bool`, optional, default=True): If True, encoding is performed under `torch.no_grad()` to save memory and computation (no gradient tracking).
+        Returns:
+            - latent (:obj:`torch.Tensor`): Encoded latent state of shape (B, D).
+        """
+        pad_id = self.tokenizer.pad_token_id
+        attention_mask = (x != pad_id).long().to(x.device)
+        context = {'input_ids': x.long(), 'attention_mask': attention_mask}
+        if no_grad:
+            with torch.no_grad():
+                outputs = self.pretrained_model(**context, output_hidden_states=True, return_dict=True)
+        else:
+            outputs = self.pretrained_model(**context, output_hidden_states=True, return_dict=True)
+        last_hidden = outputs.hidden_states[-1]
+
+        B, L, H = last_hidden.size()
+        lengths = attention_mask.sum(dim=1)  # [B]
+        positions = torch.clamp(lengths - 1, min=0)  # [B]
+        batch_idx = torch.arange(B, device=last_hidden.device)
+
+        selected = last_hidden[batch_idx, positions]  # [B, H]
+
+        latent = self.embedding_head(selected.to(self.embedding_head[0].weight.dtype))
+        return latent
+
+    def decode(self, embeddings: torch.Tensor, max_length: int = 512) -> str:
+        """
+        Decode latent embeddings back into text with the pretrained language model's `generate` method.
+        """
+        embeddings_detached = embeddings.detach()
+        self.pretrained_model.eval()
+
+        # Generate directly from the provided embeddings (passed to `generate` as `inputs_embeds`).
+        with torch.no_grad():
+            param = next(self.pretrained_model.parameters())
+            embeddings = embeddings_detached.to(device=param.device, dtype=param.dtype)
+            gen_ids = self.pretrained_model.generate(
+                inputs_embeds=embeddings,
+                max_length=max_length
+            )
+        texts = self.tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
+        self.pretrained_model.train()
+        return texts[0] if len(texts) == 1 else texts
+
+    def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
+        return self.encode(x, no_grad=no_grad)
+

 class HFLanguageRepresentationNetwork(nn.Module):
     def __init__(self,
@@ -542,7 +654,6 @@ def __init__(
         else:
             raise ValueError(f"Unsupported final_norm_option_in_encoder: {self.final_norm_option_in_encoder}")

-
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Shapes:
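
A minimal usage sketch for the new QwenNetwork encoder (not part of the diff above; the prompts are made up, and it assumes the 'Qwen/Qwen3-1.7B' checkpoint can be downloaded, flash-attention 2 is installed, and a CUDA device is available; since device_map places the frozen backbone on the local rank's GPU, the projection head is moved there explicitly):

    import torch

    device = torch.device('cuda:0')
    net = QwenNetwork(model_path='Qwen/Qwen3-1.7B', embedding_size=768,
                      final_norm_option_in_encoder="layernorm")
    net.embedding_head.to(device)  # the frozen backbone is already on the GPU via device_map
    net.eval()

    # Hypothetical prompts. encode() takes the hidden state of the last non-pad token,
    # so batches must be right-padded.
    net.tokenizer.padding_side = 'right'
    prompts = ["The agent moves left.", "The agent picks up the key."]
    input_ids = net.tokenizer(prompts, return_tensors='pt', padding=True)['input_ids'].to(device)

    with torch.no_grad():
        latent = net(input_ids)  # same as net.encode(input_ids, no_grad=True)
    print(latent.shape)  # expected: torch.Size([2, 768])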
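
A sketch of decode(), under the same assumptions: generate() consumes the tensor as inputs_embeds, so it must have the backbone's hidden size (pretrained_model.config.hidden_size), not the projected embedding_size produced by encode(). One way to satisfy that, continuing the snippet above, is to feed token embeddings from the model's input embedding layer (the prompt is made up):

    # Continues the sketch above (same `net` and `device`).
    prompt_ids = net.tokenizer("Describe the current state:", return_tensors='pt')['input_ids'].to(device)
    # (1, L, hidden_size) token embeddings from the frozen backbone's input embedding table.
    token_embeds = net.pretrained_model.get_input_embeddings()(prompt_ids)
    continuation = net.decode(token_embeds, max_length=64)  # returns a str for a batch of one
    print(continuation)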