inference_cost_matmul: Confusion or Bug regarding the MAC-count of Scaled Dot-Product Attention #60

Open
iksnagreb opened this issue Jun 28, 2023 · 1 comment
Labels
bug Something isn't working

Comments

@iksnagreb
Contributor

Prerequisites

Please make sure to check off these prerequisites before submitting a bug report.

  • Test that the bug appears on the current version of the main branch. Make sure to include the commit hash of the commit you checked out: 6ca8f8e
  • Check that the issue hasn't already been reported, by checking the currently open issues.
  • If there are steps to reproduce the problem, make sure to write them down below.
  • If relevant, please include the ONNX files, which were created directly before and/or after the bug.

Quick summary

While working on our characterization of the transformer data-flow, we encountered some discrepancies when validating against the QONNX inference_cost estimates of the MatMul operators within the attention mechanism. We are not entirely sure whether this is indeed a bug on the QONNX side or still some confusion/error on our side, so we would like to start a discussion to understand this issue.

Details

Multi-Head Scaled Dot-Product Attention involves two consecutive MatMul operations in which both inputs depend dynamically on the model inputs. The heads are independent of each other and are typically treated similarly to a batch dimension. Our cost model assumes HxTxTxd MAC operations for each of the two MatMuls, i.e. H heads, each producing a TxT attention matrix (T being the sequence length) in which each element is the result of a d-dimensional dot product. However, the QONNX analysis function inference_cost_matmul seems to be off by an additional factor of H (i.e. HxHxTxTxd), indicating that the heads are not treated like a batch dimension.
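
For reference, the following is a minimal sketch of the cost model we assume for such a batched (per-head) MatMul; the helper count_batched_matmul_macs is our own illustration and not part of QONNX.

import numpy as np


def count_batched_matmul_macs(i_shape, w_shape):
    # A @ B with A: (..., M, K) and B: (..., K, N) and shared leading
    # (batch/head) dimensions: every element of the (..., M, N) output
    # is a K-dimensional dot product, i.e. costs K MACs.
    assert i_shape[-1] == w_shape[-2], "contraction dimension mismatch"
    batch = int(np.prod(i_shape[:-2]))
    m, k, n = i_shape[-2], i_shape[-1], w_shape[-1]
    return batch * m * n * k


# Shapes of the first attention MatMul below: H=4 heads, T=64, d_head=32
print(count_batched_matmul_macs((4, 64, 32), (4, 32, 64)))  # 524288 = HxTxTxd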

My suspicion is further raised by the following lines from the QONNX inference_cost_matmul function:

# exclude common dim (last axis) from one side to avoid duplication
n_macs = np.prod(i_shape[:-1]) * np.prod(w_shape)

Is this actually always the case? At least for the model graph I have attached, it seems like the last axis is not the common dimension.
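
To make the suspicion concrete, the following sketch only plugs the shapes of the first attention MatMul from the attached graph into the two lines quoted above (this is not the full QONNX code path, just the formula applied to those shapes):

import numpy as np

# Shapes of the first attention MatMul in the attached graph: H=4, T=64, d_head=32
i_shape = (4, 64, 32)  # queries with the heads brought to the front
w_shape = (4, 32, 64)  # keys, transposed per head

# Formula quoted from inference_cost_matmul above
n_macs = np.prod(i_shape[:-1]) * np.prod(w_shape)
print(n_macs)  # 2097152 = HxHxTxTxd, i.e. one factor of H too many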

In the following, I provide a minimal working example of scaled dot-product attention in isolation in PyTorch, exported to an ONNX graph. I have also attached the already preprocessed graph, which in particular already includes the InferShapes transform. Note that running the qonnx.util.inference_cost script on the raw PyTorch ONNX export breaks at the FoldConstants transform with an IndexError, which is probably unrelated and should be investigated separately (I have "fixed" it by removing that transformation step for now).

Steps to Reproduce

The following code produces a minimal example of scaled dot-product attention and exports to ONNX.

import torch


# Minimal working example of the Scaled Dot-Product Attention mechanism
class ScaleDotProductAttention(torch.nn.Module):
    # Initializes the module parameters
    def __init__(self, num_heads):
        # Initialize the PyTorch base Module
        super().__init__()
        # Set the number of attention heads
        self.num_heads = num_heads

    # Forward pass computing scaled dot-product attention between q, k and v
    def forward(self, q, k, v):
        # Assume the most simple case of q, k and v all having the same
        # dimensions
        assert q.shape == k.shape == v.shape, \
            "Q, K and V must have the same shape"
        # Embedding dimension must be divisible by number of heads
        assert q.shape[-1] % self.num_heads == 0, \
            f"Dimensions must be divisible by heads ({self.num_heads})"

        # Assume sequence first layout and get the sizes per axis
        s, b, d = q.shape
        # Number of heads and dimension per head
        n_head, d_head = self.num_heads, d // self.num_heads

        # Reshape tensors to treat the heads like batch dimensions
        q = q.reshape(s, b, n_head, d_head).reshape(s, b * n_head, d_head)
        k = k.reshape(s, b, n_head, d_head).reshape(s, b * n_head, d_head)
        v = v.reshape(s, b, n_head, d_head).reshape(s, b * n_head, d_head)
        # Compute the not-yet-normalized attention matrices for each head.
        #   Note: permute brings batch x heads to front and transposes k
        a = torch.matmul(q.permute(1, 0, 2), k.permute(1, 2, 0))
        # Scale and normalize the attention matrix
        a = torch.softmax(a * (d_head ** -0.5), dim=-1)
        # Apply the attention matrices to the value projection
        #   Note: Second permute brings sequence dimension back to front
        o = torch.matmul(a, v.permute(1, 0, 2)).permute(1, 0, 2)
        # Reshape heads into feature dimension
        o = o.reshape(s, b, n_head, d_head).reshape(s, b, n_head * d_head)

        # Return the scaled dot-product attention output
        return o


# Script entrypoint
if __name__ == '__main__':
    # Instantiate a scaled dot-product attention with 4 attention heads
    sdp = ScaleDotProductAttention(num_heads=4)
    # Generate random query, key and value tensors
    #   Note: Sequence of length 64, single instance batch, 128 dim embeddings
    q, k, v = torch.randn(3, 64, 1, 128)
    # Export the attention module to ONNX
    torch.onnx.export(sdp, args=(q, k, v), f='sdp.onnx')

Get MAC operation counts by running

python -m qonnx.util.inference_cost sdp.onnx

Outputs something like

{'op_mac_FLOAT32_FLOAT32': 4194304.0, 'mem_w_FLOAT32': 0.0, 'mem_o_FLOAT32': 24576.0, 'unsupported': "{'Softmax', 'Pow', 'Constant'}", 'discount_sparsity': True, 'total_bops': 4294967296.0, 'total_mem_w_bits': 0.0, 'total_mem_o_bits': 786432.0}

Expected behavior

According to our cost model, the MAC count should be 2x HxTxTxd, which for the given example model is 2x 4x64x64x32 = 1048576.

Actual behavior

The MAC count is reported as 4194304, which is 4x (i.e. Hx) our expectation, indicating an effective cost function of 2x HxHxTxTxd.
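
For completeness, plugging the example sizes into the two formulas reproduces both numbers (a plain arithmetic check, nothing QONNX-specific):

H, T, d_head = 4, 64, 32

expected = 2 * H * T * T * d_head      # our cost model: 2x HxTxTxd
reported = 2 * H * H * T * T * d_head  # matches the tool output: 2x HxHxTxTxd

print(expected)  # 1048576
print(reported)  # 4194304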

Attached ONNX Graph

sdp.costs.onnx.zip

@iksnagreb iksnagreb added the bug Something isn't working label Jun 28, 2023
@iksnagreb iksnagreb changed the title inference_cost_matmul: Confusion or Bug regarding the of Scaled Dot-Product Attention inference_cost_matmul: Confusion or Bug regarding the MAC-count of Scaled Dot-Product Attention Jun 28, 2023
@maltanar maltanar self-assigned this Jul 7, 2023
@Harsh9650
Collaborator

Harsh9650 commented Nov 21, 2023

Hello Christoph, thanks for highlighting the issue. We've addressed it in #90. Please try it and let us know if you encounter any further issues.
