Fix backends in flash_attention and gemm #58

Closed
wants to merge 16 commits
.ci/tritonbench/test-gpu.sh (5 changes: 4 additions & 1 deletion)

@@ -8,8 +8,11 @@ fi

 . "${SETUP_SCRIPT}"

-# FIXME: patch hstu
+# FIXME: patch and install hstu
+sudo apt-get install -y patch
 python install.py --hstu

+# FIXME: install colfax
+python install.py --colfax

 python -m unittest test.test_gpu.main -v
.gitmodules (3 changes: 3 additions & 0 deletions)

@@ -19,3 +19,6 @@
 [submodule "submodules/xformers"]
 	path = submodules/xformers
 	url = https://github.com/facebookresearch/xformers.git
+[submodule "submodules/cutlass"]
+	path = submodules/cutlass
+	url = https://github.com/NVIDIA/cutlass.git
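Note: fresh checkouts need to initialize the new submodule (e.g. git submodule update --init submodules/cutlass, the standard git workflow) before anything that compiles against it runs.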
install.py (24 changes: 8 additions & 16 deletions)

@@ -52,12 +52,6 @@ def test_fbgemm():
     print("OK")


-def install_cutlass():
-    from tools.cutlass_kernels.install import install_colfax_cutlass
-
-    install_colfax_cutlass()
-
-
 def install_fa2(compile=False):
     if compile:
         # compile from source (slow)
@@ -83,12 +77,6 @@ def install_liger():
     subprocess.check_call(cmd)


-def install_tk():
-    from tools.tk.install import install_tk
-
-    install_tk()
-
-
 def install_xformers():
     os_env = os.environ.copy()
     os_env["TORCH_CUDA_ARCH_LIST"] = "8.0;9.0;9.0a"
@@ -101,7 +89,7 @@ def install_xformers():
     parser = argparse.ArgumentParser(allow_abbrev=False)
     parser.add_argument("--fbgemm", action="store_true", help="Install FBGEMM GPU")
     parser.add_argument(
-        "--cutlass", action="store_true", help="Install optional CUTLASS kernels"
+        "--colfax", action="store_true", help="Install optional Colfax CUTLASS kernels"
     )
     parser.add_argument(
         "--fa2", action="store_true", help="Install optional flash_attention 2 kernels"
@@ -139,14 +127,18 @@ def install_xformers():
     if args.fa3 or args.all:
         logger.info("[tritonbench] installing fa3...")
         install_fa3()
-    if args.cutlass or args.all:
-        logger.info("[tritonbench] installing cutlass-kernels...")
-        install_cutlass()
+    if args.colfax or args.all:
+        logger.info("[tritonbench] installing colfax cutlass-kernels...")
+        from tools.cutlass_kernels.install import install_colfax_cutlass
+
+        install_colfax_cutlass()
     if args.jax or args.all:
         logger.info("[tritonbench] installing jax...")
         install_jax()
     if args.tk or args.all:
         logger.info("[tritonbench] installing thunderkittens...")
+        from tools.tk.install import install_tk
+
+        install_tk()
     if args.liger or args.all:
         logger.info("[tritonbench] installing liger-kernels...")
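With this change the install flag matches the backend name: python install.py --colfax builds the optional Colfax CUTLASS kernels (replacing the old --cutlass flag), and the colfax and thunderkittens build helpers are now imported lazily inside their branches, so install.py no longer pays for, or fails on, those imports unless the corresponding flag is passed.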
submodules/cutlass (1 change: 1 addition & 0 deletions)

Submodule cutlass added at bbe579
test/test_gpu/main.py (12 changes: 7 additions & 5 deletions)

@@ -30,12 +30,10 @@
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

-# Ops that we can skip the unit tests
-SKIP_OPS = {
-    "test_op",
-}
+# Ops that we run forward only
+FWD_ONLY_OPS = skip_tests.get("fwd_only_ops", [])

-TEST_OPERATORS = set(list_operators_by_collection(op_collection="default")) - SKIP_OPS
+TEST_OPERATORS = set(list_operators_by_collection(op_collection="default"))


 def check_ci_output(op):
@@ -67,6 +65,8 @@ def _run_one_operator(args: List[str]):
     op.run()
     check_ci_output(op)
     # Test backward (if applicable)
+    if tb_args.op in FWD_ONLY_OPS:
+        return
     if op.has_bwd():
         del op
         tb_args.mode = "bwd"
@@ -89,6 +89,8 @@ def _run_operator_in_task(op: str, args: List[str]):
     task.run()
     task.check_output()
     # Test backward (if applicable)
+    if op in FWD_ONLY_OPS:
+        return
     if task.get_attribute("has_bwd", method=True):
         task.del_op_instance()
         args.extend(["--bwd"])
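For context, FWD_ONLY_OPS is read from the same skip-list file that drives the per-op exclusions. A minimal sketch of that lookup, assuming skip_tests is the parsed YAML mapping (the loader shown here is hypothetical; main.py already constructs skip_tests):

    # Hypothetical loader for illustration; main.py builds `skip_tests`
    # from a file like test/test_gpu/skip_tests_h100_pytorch.yaml.
    import yaml

    with open("test/test_gpu/skip_tests_h100_pytorch.yaml") as f:
        skip_tests = yaml.safe_load(f) or {}

    # Ops listed under fwd_only_ops are exercised in forward mode only.
    FWD_ONLY_OPS = skip_tests.get("fwd_only_ops", [])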
test/test_gpu/skip_tests_h100_pytorch.yaml (9 changes: 3 additions & 6 deletions)

@@ -8,10 +8,7 @@ bf16xint16_gemm:
   # LLVM ERROR: mma16816 data type not supported
   - bf16xint16
 flash_attention:
-  # FIXME: enable colfax_cutlass and tk
-  - xformers
-  - xformers_splitk
-  - colfax_cutlass
+  # thunderkittens cannot handle the default input shapes
   - tk
   # triton_tutorial_* kernels require triton-main
   - triton_tutorial_flash_v2
@@ -37,8 +34,6 @@ gemm:
   - triton_tma_persistent_cached_matmul
   - hstu_triton_matmul
   - colfax_cutlass_matmul
-  # FIXME: PT2 CUTLASS backend failed
-  - pt2_cutlass_matmul
 # jagged tests are slow, so disable them in OSS
 jagged_layer_norm:
 jagged_mean:
@@ -47,3 +42,5 @@ jagged_sum:
 ragged_attention:
   - hstu_triton_ragged_attention_persistent
 test_op:
+fwd_only_ops:
+  - flash_attention
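The new top-level fwd_only_ops key is what feeds FWD_ONLY_OPS in test/test_gpu/main.py above, so flash_attention now runs forward-only in this CI configuration instead of being skipped outright.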
tools/cutlass_kernels/install.py (9 changes: 4 additions & 5 deletions)

@@ -8,8 +8,7 @@
     "/usr/local/cuda" if "CUDA_HOME" not in os.environ else os.environ["CUDA_HOME"]
 )
 REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent
-FBGEMM_PATH = REPO_PATH.joinpath("submodules", "FBGEMM", "fbgemm_gpu")
-FBGEMM_CUTLASS_PATH = FBGEMM_PATH.parent.joinpath("external", "cutlass")
+TORCH_CUTLASS_PATH = REPO_PATH.joinpath("submodules", "cutlass")
 COLFAX_CUTLASS_PATH = REPO_PATH.joinpath("submodules", "cutlass-kernels")
 COLFAX_CUTLASS_TRITONBENCH_PATH = REPO_PATH.joinpath("tools", "cutlass_kernels")

@@ -41,9 +40,9 @@
     f"-I{str(COLFAX_CUTLASS_PATH.joinpath('lib').resolve())}",
     f"-I{str(COLFAX_CUTLASS_PATH.joinpath('include').resolve())}",
     f"-I{str(COLFAX_CUTLASS_PATH.joinpath('src', 'fmha').resolve())}",
-    f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('include').resolve())}",
-    f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('examples', 'commmon').resolve())}",
-    f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('tools', 'util', 'include').resolve())}",
+    f"-I{str(TORCH_CUTLASS_PATH.joinpath('include').resolve())}",
+    f"-I{str(TORCH_CUTLASS_PATH.joinpath('examples', 'common').resolve())}",
+    f"-I{str(TORCH_CUTLASS_PATH.joinpath('tools', 'util', 'include').resolve())}",
     f"-I{CUDA_HOME}/include",
     f"-I{str(TORCH_BASE_PATH.joinpath('include').resolve())}",
     f"-I{str(COLFAX_CUTLASS_TRITONBENCH_PATH.joinpath('include').resolve())}",
tritonbench/operators/flash_attention/operator.py (24 changes: 9 additions & 15 deletions)

@@ -85,16 +85,17 @@
     )
 except (ImportError, IOError, AttributeError):
     HAS_FLASH_V3 = False
-    pass

 # [Optional] xformers backend
 try:
     import xformers  # @manual=//fair/xformers:xformers
+    import xformers.ops.fmha as xformers_fmha  # @manual=//fair/xformers:xformers

     from .test_fmha_utils import permute_qkv

+    HAS_XFORMERS = True
 except (ImportError, IOError, AttributeError):
-    pass
+    HAS_XFORMERS = False

 # [Optional] colfax cutlass backend
 try:
@@ -122,7 +123,6 @@
     tk_fwd = torch.ops.tk
 except (ImportError, IOError, AttributeError):
     tk_fwd = None
-    tk_fwd_causal = None

 from typing import Any, Generator, List

@@ -145,9 +145,6 @@ def parse_op_args(args: List[str]):
     parser.add_argument("--n-heads", type=int, default=48, help="Number of heads")
     parser.add_argument("--d-head", type=int, default=64, help="specify head dimension")
     parser.add_argument("--causal", action="store_true", help="enable causal")
-    parser.add_argument(
-        "--xformers-splitk", action="store_true", help="benchmark xformers-split impl"
-    )
     return parser.parse_args(args)


@@ -168,7 +165,6 @@ def __init__(
         self.N_CTX = None
         self.causal = args.causal
         self.sm_scale = 1.3
-        self.xformers_splitk = args.xformers_splitk

     @register_benchmark()
     def aten(
@@ -335,7 +331,7 @@ def xformers_preprocess(
         )
         return fhma_input

-    @register_benchmark(enabled=False)
+    @register_benchmark(enabled=HAS_XFORMERS)
     def xformers(
         self,
         q: torch.Tensor,
@@ -346,7 +342,7 @@ def xformers(
         xformers_cutlass_fhma = xformers.ops.fmha.cutlass.FwOp
         return lambda: xformers_cutlass_fhma().apply(fhma_input, needs_gradient=False)

-    @register_benchmark(enabled=False)
+    @register_benchmark(enabled=HAS_XFORMERS)
     def xformers_splitk(
         self,
         q: torch.Tensor,
@@ -364,7 +360,7 @@ def colfax_cutlass_preprocess(self, q, k, v):
             torch.transpose(v, 1, 2),
         )

-    @register_benchmark(enabled=False)
+    @register_benchmark(enabled=bool(colfax_cutlass_fmha is not None))
     def colfax_cutlass(self, q, k, v):
         default_scale = 1.0 / math.sqrt(float(self.D_HEAD))
         colfax_q, colfax_k, colfax_v = self.colfax_cutlass_preprocess(q, k, v)
@@ -378,15 +374,13 @@ def colfax_cutlass(self, q, k, v):
             default_scale,
         )

-    @register_benchmark(enabled=False)
+    @register_benchmark(enabled=bool(tk_fwd is not None))
     def tk(self, q, k, v):
         o = torch.zeros_like(v)
+        l_tensor = torch.zeros_like(o).to(torch.float32)

         def tk_dispatcher():
-            if self.causal:
-                tk_fwd_causal.attention_forward_causal(q, k, v, o)
-            else:
-                tk_fwd.attention_forward(q, k, v, o)
+            tk_fwd.attention_forward(q, k, v, o, l_tensor, causal=self.causal)
             return o

         return tk_dispatcher
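The recurring fix in this file: every optional backend records whether its import succeeded, and that value, rather than a hard-coded enabled=False, gates the benchmark registration. The pattern in minimal form, mirroring the xformers hunk above (a sketch, not a verbatim excerpt):

    # Guarded optional import: the flag records backend availability once.
    try:
        import xformers.ops.fmha as xformers_fmha

        HAS_XFORMERS = True
    except (ImportError, IOError, AttributeError):
        HAS_XFORMERS = False

    # Registration is enabled exactly when the import succeeded, so the
    # benchmark shows up automatically on machines that have the backend.
    @register_benchmark(enabled=HAS_XFORMERS)
    def xformers(self, q, k, v):
        ...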
tritonbench/utils/env_utils.py (14 changes: 14 additions & 0 deletions)

@@ -5,6 +5,8 @@

 from typing import Optional

+from tritonbench.utils.path_utils import REPO_PATH
+
 log = logging.getLogger(__name__)

 MAIN_RANDOM_SEED = 1337
@@ -22,6 +24,18 @@
 ]


+def set_env():
+    # set cutlass dir
+    # by default we use the cutlass version built with pytorch
+    import torch
+
+    current_cutlass_dir = torch._inductor.config.cuda.cutlass_dir
+    if not os.path.exists(current_cutlass_dir):
+        tb_cutlass_dir = REPO_PATH.joinpath("submodules", "cutlass")
+        if tb_cutlass_dir.is_dir():
+            torch._inductor.config.cuda.cutlass_dir = str(tb_cutlass_dir)
+
+
 def set_random_seed():
     """Make torch manual seed deterministic. Helps with accuracy testing."""
     import random
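Usage sketch of the new hook (assumed behavior, following the code above): when the inductor default cutlass_dir does not exist on disk, set_env() repoints it at the tritonbench submodule, which lines up with pt2_cutlass_matmul coming off the gemm skip list.

    import torch
    from tritonbench.utils.env_utils import set_env

    set_env()
    # Falls back to <repo>/submodules/cutlass when the PyTorch-bundled
    # default directory is absent; otherwise left untouched.
    print(torch._inductor.config.cuda.cutlass_dir)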
tritonbench/utils/triton_op.py (4 changes: 3 additions & 1 deletion)

@@ -26,6 +26,7 @@
 from tritonbench.utils.env_utils import (
     apply_precision,
     fresh_triton_cache,
+    set_env,
     set_random_seed,
 )
 from tritonbench.utils.input import input_cast
@@ -561,6 +562,7 @@ class BenchmarkOperator(metaclass=PostInitProcessor):
     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
     ):
+        set_env()
         set_random_seed()
         self.name = _find_op_name_from_module_path(self.__class__.__module__)
         self._raw_extra_args = copy.deepcopy(extra_args)
@@ -619,7 +621,7 @@ def __post__init__(self):

     def _get_bm_func(self, bm_func_name: str):
         fwd_fn_lambda = getattr(self, bm_func_name, None)
-        assert fwd_fn_lambda, (
+        assert callable(fwd_fn_lambda), (
             f"Could not find benchmark {bm_func_name} registered in {self.name}. "
             f"Available benchmarks: {REGISTERED_BENCHMARKS[self.name].keys()}. "
         )
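The assert change is a small hardening: callable() rejects any non-invokable attribute that getattr happens to find, not just the None default, while keeping the same error message.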