Helpful tools and examples for working with flex-attention

drisspg/attention-gym

Attention Gym

Attention Gym is a collection of helpful tools and examples for working with flex-attention.

Overview

This repository aims to provide a playground for experimenting with various attention mechanisms using the FlexAttention API. It includes implementations of different attention variants, performance comparisons, and utility functions to help researchers and developers explore and optimize attention mechanisms in their models.

Features

  • Implementations of various attention mechanisms using FlexAttention
  • Utility functions for creating and combining attention masks
  • Examples of how to use FlexAttention in real-world scenarios
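The snippet below is a plain-Python sketch of what "creating and combining attention masks" means. It needs no PyTorch. It follows the FlexAttention mask_mod convention: a function of (batch, head, q_idx, kv_idx) that returns True where attention is allowed. The helpers `and_mask`, `or_mask`, `sliding_window` and `dense_mask` are hypothetical names used only for illustration. PyTorch itself provides `and_masks` and `or_masks` in `torch.nn.attention.flex_attention`, and real mask_mods receive tensor indices rather than Python ints.

```python
from typing import Callable, List

# A mask_mod maps (batch, head, q_idx, kv_idx) -> bool: True means "may attend".
MaskMod = Callable[[int, int, int, int], bool]


def causal(b: int, h: int, q_idx: int, kv_idx: int) -> bool:
    # Each query attends only to itself and earlier positions.
    return q_idx >= kv_idx


def sliding_window(window: int) -> MaskMod:
    # Hypothetical helper: allow keys fewer than `window` positions behind the query.
    # On its own this also allows future keys, so it is meant to be combined with causal.
    def mask(b: int, h: int, q_idx: int, kv_idx: int) -> bool:
        return q_idx - kv_idx < window
    return mask


def and_mask(*mods: MaskMod) -> MaskMod:
    # Intersection: a position is allowed only if every mask allows it.
    return lambda b, h, q, kv: all(m(b, h, q, kv) for m in mods)


def or_mask(*mods: MaskMod) -> MaskMod:
    # Union: a position is allowed if any mask allows it.
    return lambda b, h, q, kv: any(m(b, h, q, kv) for m in mods)


def dense_mask(mod: MaskMod, q_len: int, kv_len: int) -> List[List[bool]]:
    # Materialize the mask for batch 0 / head 0, for inspection only.
    return [[mod(0, 0, q, kv) for kv in range(kv_len)] for q in range(q_len)]


if __name__ == "__main__":
    sliding_causal = and_mask(causal, sliding_window(2))
    for row in dense_mask(sliding_causal, 4, 4):
        print("".join("x" if allowed else "." for allowed in row))
    # prints:
    # x...
    # xx..
    # .xx.
    # ..xx
```

The masks are composed as plain functions instead of precomputed boolean matrices. That is the same design choice FlexAttention makes: a mask stays cheap to combine and is only evaluated where it is needed.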

Getting Started

Prerequisites

  • PyTorch (version 2.5 or higher)
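FlexAttention (`torch.nn.attention.flex_attention`) was introduced as a prototype in PyTorch 2.5. A script may therefore want to fail early on older versions. The sketch below assumes a version string shaped like `torch.__version__` (for example `"2.5.1+cu121"`). The helper `meets_min_version` is a hypothetical name and is not part of Attention Gym.

```python
import re
from typing import Tuple


def meets_min_version(version: str, minimum: Tuple[int, int] = (2, 5)) -> bool:
    # Compare only the leading "major.minor" numbers. Local tags such as "+cu121"
    # are ignored, and pre-releases such as "2.5.0.dev..." count as 2.5.
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= minimum
```

For example, `meets_min_version(torch.__version__)` could guard the `flex_attention` import. The version parts are compared as integer tuples, not strings, so `"2.10"` correctly ranks above `"2.5"`.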

Installation

git clone https://github.com/drisspg/attention-gym.git
cd attention-gym
pip install .

Usage

Here's a quick example of how to use the FlexAttention API with a causal mask from Attention Gym:

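import torch
from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention
from attn_gym.masks import causal_mask

# Example inputs of shape (batch, heads, seq_len, head_dim) on a CUDA device
query, key, value = (torch.randn(1, 8, 1024, 64, device="cuda") for _ in range(3))

# Build a BlockMask from the causal mask_mod (B=1, H=1: one mask shared by all batches and heads)
Q_LEN, KV_LEN = query.size(-2), key.size(-2)
block_mask: BlockMask = create_block_mask(causal_mask, 1, 1, Q_LEN, KV_LEN)

# Run FlexAttention with the BlockMask (pass block_mask, not the mask_mod itself)
output = flex_attention(query, key, value, block_mask=block_mask)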
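Why build a BlockMask instead of passing the mask function directly? create_block_mask evaluates the mask_mod over tiles of the attention matrix, 128×128 by default. Tiles that are entirely masked out can then be skipped, and tiles that are entirely visible can be computed without evaluating the mask. The plain-Python sketch below classifies tiles for a causal mask to show the idea. `classify_blocks` is a hypothetical helper, and the sketch does not reproduce the actual BlockMask storage format.

```python
from typing import Callable, Dict, Tuple

MaskMod = Callable[[int, int, int, int], bool]


def causal(b: int, h: int, q_idx: int, kv_idx: int) -> bool:
    return q_idx >= kv_idx


def classify_blocks(
    mask_mod: MaskMod, q_len: int, kv_len: int, block_size: int
) -> Dict[Tuple[int, int], str]:
    # Label each (q_block, kv_block) tile as "full", "partial" or "empty".
    labels: Dict[Tuple[int, int], str] = {}
    for qb in range(0, q_len, block_size):
        for kb in range(0, kv_len, block_size):
            allowed = [
                mask_mod(0, 0, q, kv)
                for q in range(qb, min(qb + block_size, q_len))
                for kv in range(kb, min(kb + block_size, kv_len))
            ]
            if all(allowed):
                label = "full"      # compute without evaluating the mask
            elif any(allowed):
                label = "partial"   # compute, then apply the mask elementwise
            else:
                label = "empty"     # skip the tile entirely
            labels[(qb // block_size, kb // block_size)] = label
    return labels


if __name__ == "__main__":
    labels = classify_blocks(causal, q_len=512, kv_len=512, block_size=128)
    counts = {name: list(labels.values()).count(name) for name in ("full", "partial", "empty")}
    print(counts)  # prints {'full': 6, 'partial': 4, 'empty': 6}
```

For a causal mask, only the diagonal tiles need per-element masking. Roughly half of the remaining tiles can be skipped outright, which is where the speedup over a dense mask comes from.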

Examples

Check out the examples/ directory for more detailed examples of different attention mechanisms and how to implement them using FlexAttention.
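In FlexAttention, attention variants are expressed as a score_mod, which rewrites each scaled attention score, and/or a mask_mod, which decides which query/key pairs may interact. As a mental model, below is a naive plain-Python reference of what flex_attention computes for one batch element and one head. It assumes the default 1/sqrt(head_dim) scaling. It is a slow sketch for checking intuition, not how the fused kernel works. `reference_attention` and the ALiBi-style `alibi_like` bias (with a hypothetical slope of 0.5) are illustrative names, not code from the repository.

```python
import math
from typing import Callable, List

Matrix = List[List[float]]
ScoreMod = Callable[[float, int, int, int, int], float]
MaskMod = Callable[[int, int, int, int], bool]


def causal(b: int, h: int, q_idx: int, kv_idx: int) -> bool:
    return q_idx >= kv_idx


def alibi_like(score: float, b: int, h: int, q_idx: int, kv_idx: int) -> float:
    # Hypothetical ALiBi-style bias: penalize keys linearly by distance (slope 0.5).
    return score - 0.5 * (q_idx - kv_idx)


def reference_attention(
    q: Matrix, k: Matrix, v: Matrix, score_mod: ScoreMod, mask_mod: MaskMod,
    b: int = 0, h: int = 0,
) -> Matrix:
    # softmax(score_mod(q.k * scale)) @ v, with masked-out keys given zero weight.
    scale = 1.0 / math.sqrt(len(q[0]))
    out: Matrix = []
    for q_idx, q_row in enumerate(q):
        scores = []
        for kv_idx, k_row in enumerate(k):
            if not mask_mod(b, h, q_idx, kv_idx):
                scores.append(-math.inf)
                continue
            s = scale * sum(a * c for a, c in zip(q_row, k_row))
            scores.append(score_mod(s, b, h, q_idx, kv_idx))
        m = max(scores)  # assumes every query keeps at least one key
        weights = [math.exp(s - m) for s in scores]  # exp(-inf) == 0.0
        total = sum(weights)
        out.append([
            sum(w * v_row[j] for w, v_row in zip(weights, v)) / total
            for j in range(len(v[0]))
        ])
    return out


if __name__ == "__main__":
    q = k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    v = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
    print(reference_attention(q, k, v, alibi_like, causal))
```

A reference like this is mainly useful for testing. On small inputs, a new score_mod or mask_mod can be compared against the flex_attention output.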

Dev

Install dev requirements

pip install -e ".[dev]"

Install pre-commit hooks

pre-commit install
