
Different sized encoder for TransformerDecoder #182

Open
Adamits opened this issue Apr 24, 2024 · 4 comments
Labels: enhancement (New feature or request)

Comments

Adamits (Collaborator) commented Apr 24, 2024

It would be convenient to allow the encoder output_size to differ from the TransformerDecoder embedding size. To illustrate the issue, the code snippet below

```python
import torch
import math


def generate_square_subsequent_mask(length: int) -> torch.Tensor:
    return torch.triu(torch.full((length, length), -math.inf), diagonal=1)


# INITIALIZE A TRANSFORMER WITH THIS HIDDEN AND EMBEDDING SIZE
hid = 128
emb = 64
decoder_layer = torch.nn.TransformerDecoderLayer(
    d_model=emb,
    dim_feedforward=hid,
    nhead=2,
    dropout=0.2,
    activation="relu",
    batch_first=True,
)
frank_transformer = torch.nn.TransformerDecoder(
    decoder_layer=decoder_layer,
    num_layers=2,
    norm=torch.nn.LayerNorm(emb),
)

# INITIALIZE TARGETS WITH EMBEDDING SIZE
# AND A FAKE ENCODER OUTPUT WITH HIDDEN SIZE
b = 4
seq_len = 10
target_embedding = torch.randn((b, seq_len, emb))
encoder_hidden = torch.randn(b, seq_len, hid)
target_sequence_length = target_embedding.size(1)
# -> seq_len x seq_len.
causal_mask = generate_square_subsequent_mask(seq_len)
# -> B x seq_len x d_model.
output = frank_transformer(
    target_embedding,
    encoder_hidden,
    tgt_mask=causal_mask,
    # memory_key_padding_mask=source_mask,
    # tgt_key_padding_mask=target_mask,
)
```

throws:

File "test.py", line 34, in <module>
    output = frank_transformer(
  File "torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "torch/nn/modules/transformer.py", line 460, in forward
    output = mod(output, memory, tgt_mask=tgt_mask,
  File "torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "torch/nn/modules/transformer.py", line 847, in forward
    x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
  File "torch/nn/modules/transformer.py", line 865, in _mha_block
    x = self.multihead_attn(x, mem, mem,
  File "torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "torch/nn/modules/activation.py", line 1241, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "torch/nn/functional.py", line 5300, in multi_head_attention_forward
    q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
  File "torch/nn/functional.py", line 4836, in _in_projection_packed
    kv_proj = linear(k, w_kv, b_kv)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (40x128 and 64x128)

But if I change `encoder_hidden = torch.randn(b, seq_len, hid)` to `encoder_hidden = torch.randn(b, seq_len, emb)`, then this works fine.

Essentially, we need the self-attention and multihead-attention blocks to expect different input sizes (which may require the layer norms to change as well).

I am putting this up and will try to work out a solution. The easiest way to allow this behavior in yoyodyne would be to project the encoder output size into the decoder embedding size, or vice versa, but I feel that this changes the architecture more than necessary. Instead, I would like to consider whether there is an elegant way to update the sa_block and mha_block such that it does not break other things in the transformer (e.g., layer norm).
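For concreteness, here is a minimal sketch of the projection workaround mentioned above, continuing the snippet from this issue; the `bridge` name is purely illustrative and is not anything in yoyodyne:

```python
# Hypothetical workaround sketch: project the encoder's hidden size down to
# the decoder's embedding size before it is used as cross-attention memory.
bridge = torch.nn.Linear(hid, emb)

output = frank_transformer(
    target_embedding,
    bridge(encoder_hidden),  # -> B x seq_len x emb, matching d_model.
    tgt_mask=causal_mask,
)
print(output.shape)  # torch.Size([4, 10, 64])
```

This makes the shapes compatible, but, as noted, it adds a learned layer that the stock architecture does not have.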

kylebgorman added the enhancement label Apr 24, 2024
Adamits (Collaborator, Author) commented Apr 24, 2024

I thought about this more. Since the residual connections in the transformer just sum the self-attention and multihead-attention outputs (with layer norm in between), I don't think we can make this update without fundamentally changing the transformer architecture (e.g., by concatenating them, or by projecting one into the size of the other).

I think the best thing to do is either:

  1. Force the encoder output size to be the same as the decoder embedding size (and raise an error if it is not).
  2. Infer when the encoder output size differs from the decoder embedding size and create an additional layer in the yoyodyne model class that projects the encoder output into the decoder input size (both options are sketched below).

One place where option 1 causes an issue is if we want to use an LSTM encoder with a transformer decoder: the encoder then outputs hidden_size * num_directions features while the transformer expects embedding_size. This limits the shape of a valid architecture quite a bit. Not sure whether that is a problem or not, though.
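To make the two options concrete, here is a rough, hypothetical sketch; the names (`make_bridge`, `strict`, etc.) are illustrative only and do not correspond to actual yoyodyne attributes:

```python
import torch
from torch import nn


def make_bridge(
    encoder_output_size: int,
    decoder_embedding_size: int,
    strict: bool = True,
) -> nn.Module:
    """Returns a module applied to encoder output before the decoder."""
    if encoder_output_size == decoder_embedding_size:
        return nn.Identity()
    if strict:
        # Option 1: treat the size mismatch as a configuration error.
        raise ValueError(
            f"Encoder output size ({encoder_output_size}) does not match "
            f"decoder embedding size ({decoder_embedding_size})"
        )
    # Option 2: project encoder output into the decoder embedding size.
    return nn.Linear(encoder_output_size, decoder_embedding_size)


# E.g., a bidirectional LSTM encoder with hidden_size=64 outputs
# 64 * 2 = 128 features, while the decoder embeds into 64 dimensions.
bridge = make_bridge(128, 64, strict=False)
memory = bridge(torch.randn(4, 10, 128))  # -> 4 x 10 x 64
```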

kylebgorman (Contributor) commented

I think either would be fine. This is a good example of a second type of presupposition we will want to test for before training begins.

bonham79 (Collaborator) commented Jun 3, 2024

I want to say there was a variant of the transformer a while back that approached this problem (Sumformer, maybe). But I think the ideal solution would be to simply add an additional perceptron layer to force alignment. Personally, I don't think it's too much variation on the transformer architecture, since everyone and their grandmother creates an in-house variant. (You'll note that no one uses PyTorch's base form.)

Regarding layer norms: what you can mess with is swapping them out for batch norm. It's a bit too late for me to do the maths for the main issue, but it may give more flexibility with variations in depth.

kylebgorman (Contributor) commented

@bonham79 is on point about how everyone has a slightly different transformer variant and that's okay.
