Fix backends in flash_attention and gemm (#58)
Summary:
To run the PT2 CUTLASS backend, we have to add a cutlass submodule pinned to the same version that pytorch pins in https://github.com/pytorch/pytorch/tree/main/third_party.

The pin points to https://github.com/NVIDIA/cutlass/tree/bbe579a9e3beb6ea6626d9227ec32d0dae119a49, which is nine months old; the cutlass bundled with FBGEMM is much newer.
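
For context, a minimal sketch of how the pinned submodule is meant to be consumed. Only `torch._inductor.config.cuda.cutlass_dir` appears in this change; the autotune settings and shapes below are illustrative assumptions, not part of the diff:

```python
# Hypothetical usage sketch: point inductor at the pinned CUTLASS checkout and
# let max-autotune consider CUTLASS GEMM kernels.
import torch

torch._inductor.config.cuda.cutlass_dir = "submodules/cutlass"  # the pinned checkout
torch._inductor.config.max_autotune_gemm_backends = "ATEN,TRITON,CUTLASS"  # assumed value

a = torch.randn(256, 256, device="cuda", dtype=torch.float16)
b = torch.randn(256, 256, device="cuda", dtype=torch.float16)
compiled_mm = torch.compile(torch.mm, mode="max-autotune")
out = compiled_mm(a, b)  # inductor can now codegen against the pinned CUTLASS headers
```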

Pull Request resolved: #58

Test Plan:
```
$ python run.py --op gemm --mode fwd --only pt2_cutlass_matmul --num-inputs 1
      (M, N, K)    pt2_cutlass_matmul-speedup    pt2_cutlass_matmul-tflops    pt2_cutlass_matmul-gbps
---------------  ----------------------------  ---------------------------  -------------------------
(256, 256, 256)                                                    3.51871                    41.2349
```

Fixes #17

Reviewed By: FindHao

Differential Revision: D66211890

Pulled By: xuzhao9

fbshipit-source-id: 995b0280c138adfb6c6c959c1bdc3c92cad05369
xuzhao9 authored and facebook-github-bot committed Nov 20, 2024
1 parent 23f5346 commit 17b38a4
Showing 10 changed files with 56 additions and 49 deletions.
.ci/tritonbench/test-gpu.sh (4 additions & 1 deletion)

```diff
@@ -8,8 +8,11 @@ fi

 . "${SETUP_SCRIPT}"

-# FIXME: patch hstu
+# FIXME: patch and install hstu
+sudo apt-get install -y patch
 python install.py --hstu

+# FIXME: install colfax
+python install.py --colfax
+
 python -m unittest test.test_gpu.main -v
```

.gitmodules (3 additions & 0 deletions)

```diff
@@ -19,3 +19,6 @@
 [submodule "submodules/xformers"]
 	path = submodules/xformers
 	url = https://github.com/facebookresearch/xformers.git
+[submodule "submodules/cutlass"]
+	path = submodules/cutlass
+	url = https://github.com/NVIDIA/cutlass.git
```

install.py (8 additions & 16 deletions)

```diff
@@ -52,12 +52,6 @@ def test_fbgemm():
     print("OK")


-def install_cutlass():
-    from tools.cutlass_kernels.install import install_colfax_cutlass
-
-    install_colfax_cutlass()
-
-
 def install_fa2(compile=False):
     if compile:
         # compile from source (slow)
@@ -83,12 +77,6 @@ def install_liger():
     subprocess.check_call(cmd)


-def install_tk():
-    from tools.tk.install import install_tk
-
-    install_tk()
-
-
 def install_xformers():
     os_env = os.environ.copy()
     os_env["TORCH_CUDA_ARCH_LIST"] = "8.0;9.0;9.0a"
@@ -101,7 +89,7 @@ def install_xformers():
     parser = argparse.ArgumentParser(allow_abbrev=False)
     parser.add_argument("--fbgemm", action="store_true", help="Install FBGEMM GPU")
     parser.add_argument(
-        "--cutlass", action="store_true", help="Install optional CUTLASS kernels"
+        "--colfax", action="store_true", help="Install optional Colfax CUTLASS kernels"
     )
     parser.add_argument(
         "--fa2", action="store_true", help="Install optional flash_attention 2 kernels"
@@ -139,14 +127,18 @@ def install_xformers():
     if args.fa3 or args.all:
         logger.info("[tritonbench] installing fa3...")
         install_fa3()
-    if args.cutlass or args.all:
-        logger.info("[tritonbench] installing cutlass-kernels...")
-        install_cutlass()
+    if args.colfax or args.all:
+        logger.info("[tritonbench] installing colfax cutlass-kernels...")
+        from tools.cutlass_kernels.install import install_colfax_cutlass
+
+        install_colfax_cutlass()
     if args.jax or args.all:
         logger.info("[tritonbench] installing jax...")
         install_jax()
     if args.tk or args.all:
         logger.info("[tritonbench] installing thunderkittens...")
+        from tools.tk.install import install_tk
+
         install_tk()
     if args.liger or args.all:
         logger.info("[tritonbench] installing liger-kernels...")
```
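
Note the pattern in the last hunk: optional-tool imports now live inside their CLI branches. A condensed sketch of why (taken from the hunk above; nothing new beyond the comment):

```python
# Deferred import: install.py no longer imports tools.cutlass_kernels at module
# load time, so e.g. `python install.py --fa2` keeps working even when the
# colfax submodule and its install helper are not checked out.
if args.colfax or args.all:
    from tools.cutlass_kernels.install import install_colfax_cutlass

    install_colfax_cutlass()
```
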
submodules/cutlass (1 addition & 0 deletions)
Submodule cutlass added at bbe579

test/test_gpu/main.py (7 additions & 5 deletions)

```diff
@@ -30,12 +30,10 @@
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

-# Ops that we can skip the unit tests
-SKIP_OPS = {
-    "test_op",
-}
+# Ops that we run forward only
+FWD_ONLY_OPS = skip_tests.get("fwd_only_ops", [])

-TEST_OPERATORS = set(list_operators_by_collection(op_collection="default")) - SKIP_OPS
+TEST_OPERATORS = set(list_operators_by_collection(op_collection="default"))
@@ -67,6 +65,8 @@ def _run_one_operator(args: List[str]):
         op.run()
         check_ci_output(op)
     # Test backward (if applicable)
+    if tb_args.op in FWD_ONLY_OPS:
+        return
     if op.has_bwd():
         del op
         tb_args.mode = "bwd"
@@ -89,6 +89,8 @@ def _run_operator_in_task(op: str, args: List[str]):
     task.run()
     task.check_output()
     # Test backward (if applicable)
+    if op in FWD_ONLY_OPS:
+        return
     if task.get_attribute("has_bwd", method=True):
         task.del_op_instance()
         args.extend(["--bwd"])
```

test/test_gpu/skip_tests_h100_pytorch.yaml (3 additions & 6 deletions)

```diff
@@ -8,10 +8,7 @@ bf16xint16_gemm:
   # LLVM ERROR: mma16816 data type not supported
   - bf16xint16
 flash_attention:
-  # FIXME: enable colfax_cutlass and tk
-  - xformers
-  - xformers_splitk
-  - colfax_cutlass
+  # thunderkittens cannot handle the default input shapes
   - tk
   # triton_tutorial_* kernels require triton-main
   - triton_tutorial_flash_v2
@@ -37,8 +34,6 @@ gemm:
   - triton_tma_persistent_cached_matmul
   - hstu_triton_matmul
   - colfax_cutlass_matmul
-  # FIXME: PT2 CUTLASS backend failed
-  - pt2_cutlass_matmul
 # jagged tests are slow, so disable them in OSS
 jagged_layer_norm:
 jagged_mean:
@@ -47,3 +42,5 @@ jagged_sum:
 ragged_attention:
   - hstu_triton_ragged_attention_persistent
 test_op:
+fwd_only_ops:
+  - flash_attention
```
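
The new `fwd_only_ops` key is what `test/test_gpu/main.py` consumes above. A self-contained sketch of the flow; the yaml loading shown here and the fwd/bwd callables are assumptions, while the `skip_tests.get("fwd_only_ops", [])` lookup is taken verbatim from the diff:

```python
# Sketch: ops listed under fwd_only_ops skip the backward test pass.
from typing import Callable

import yaml

with open("test/test_gpu/skip_tests_h100_pytorch.yaml") as f:
    skip_tests = yaml.safe_load(f) or {}

FWD_ONLY_OPS = skip_tests.get("fwd_only_ops", [])  # e.g. ["flash_attention"]

def run_op_tests(op_name: str, run_fwd: Callable, run_bwd: Callable) -> None:
    run_fwd(op_name)             # forward always runs
    if op_name in FWD_ONLY_OPS:  # flash_attention returns here
        return
    run_bwd(op_name)             # backward only for ops not marked fwd-only
```
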
tools/cutlass_kernels/install.py (4 additions & 5 deletions)

```diff
@@ -8,8 +8,7 @@
     "/usr/local/cuda" if "CUDA_HOME" not in os.environ else os.environ["CUDA_HOME"]
 )
 REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent
-FBGEMM_PATH = REPO_PATH.joinpath("submodules", "FBGEMM", "fbgemm_gpu")
-FBGEMM_CUTLASS_PATH = FBGEMM_PATH.parent.joinpath("external", "cutlass")
+TORCH_CUTLASS_PATH = REPO_PATH.joinpath("submodules", "cutlass")
 COLFAX_CUTLASS_PATH = REPO_PATH.joinpath("submodules", "cutlass-kernels")
 COLFAX_CUTLASS_TRITONBENCH_PATH = REPO_PATH.joinpath("tools", "cutlass_kernels")
@@ -41,9 +40,9 @@
     f"-I{str(COLFAX_CUTLASS_PATH.joinpath('lib').resolve())}",
     f"-I{str(COLFAX_CUTLASS_PATH.joinpath('include').resolve())}",
     f"-I{str(COLFAX_CUTLASS_PATH.joinpath('src', 'fmha').resolve())}",
-    f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('include').resolve())}",
-    f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('examples', 'commmon').resolve())}",
-    f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('tools', 'util', 'include').resolve())}",
+    f"-I{str(TORCH_CUTLASS_PATH.joinpath('include').resolve())}",
+    f"-I{str(TORCH_CUTLASS_PATH.joinpath('examples', 'commmon').resolve())}",
+    f"-I{str(TORCH_CUTLASS_PATH.joinpath('tools', 'util', 'include').resolve())}",
     f"-I{CUDA_HOME}/include",
     f"-I{str(TORCH_BASE_PATH.joinpath('include').resolve())}",
     f"-I{str(COLFAX_CUTLASS_TRITONBENCH_PATH.joinpath('include').resolve())}",
```

tritonbench/operators/flash_attention/operator.py (9 additions & 15 deletions)

```diff
@@ -85,16 +85,17 @@
     )
 except (ImportError, IOError, AttributeError):
     HAS_FLASH_V3 = False
-    pass

 # [Optional] xformers backend
 try:
     import xformers  # @manual=//fair/xformers:xformers
     import xformers.ops.fmha as xformers_fmha  # @manual=//fair/xformers:xformers

+    from .test_fmha_utils import permute_qkv
+
     HAS_XFORMERS = True
 except (ImportError, IOError, AttributeError):
-    pass
+    HAS_XFORMERS = False

 # [Optional] colfax cutlass backend
 try:
@@ -122,7 +123,6 @@
     tk_fwd = torch.ops.tk
 except (ImportError, IOError, AttributeError):
     tk_fwd = None
-    tk_fwd_causal = None

 from typing import Any, Generator, List

@@ -145,9 +145,6 @@ def parse_op_args(args: List[str]):
     parser.add_argument("--n-heads", type=int, default=48, help="Number of heads")
     parser.add_argument("--d-head", type=int, default=64, help="specify head dimension")
     parser.add_argument("--causal", action="store_true", help="enable causal")
-    parser.add_argument(
-        "--xformers-splitk", action="store_true", help="benchmark xformers-split impl"
-    )
     return parser.parse_args(args)
@@ -168,7 +165,6 @@ def __init__(
         self.N_CTX = None
         self.causal = args.causal
         self.sm_scale = 1.3
-        self.xformers_splitk = args.xformers_splitk

     @register_benchmark()
     def aten(
@@ -335,7 +331,7 @@ def xformers_preprocess(
         )
         return fhma_input

-    @register_benchmark(enabled=False)
+    @register_benchmark(enabled=HAS_XFORMERS)
     def xformers(
         self,
         q: torch.Tensor,
@@ -346,7 +342,7 @@ def xformers(
         xformers_cutlass_fhma = xformers.ops.fmha.cutlass.FwOp
         return lambda: xformers_cutlass_fhma().apply(fhma_input, needs_gradient=False)

-    @register_benchmark(enabled=False)
+    @register_benchmark(enabled=HAS_XFORMERS)
     def xformers_splitk(
         self,
         q: torch.Tensor,
@@ -364,7 +360,7 @@ def colfax_cutlass_preprocess(self, q, k, v):
             torch.transpose(v, 1, 2),
         )

-    @register_benchmark(enabled=False)
+    @register_benchmark(enabled=bool(colfax_cutlass_fmha is not None))
     def colfax_cutlass(self, q, k, v):
         default_scale = 1.0 / math.sqrt(float(self.D_HEAD))
         colfax_q, colfax_k, colfax_v = self.colfax_cutlass_preprocess(q, k, v)
@@ -378,15 +374,13 @@ def colfax_cutlass(self, q, k, v):
             default_scale,
         )

-    @register_benchmark(enabled=False)
+    @register_benchmark(enabled=bool(tk_fwd is not None))
     def tk(self, q, k, v):
         o = torch.zeros_like(v)
+        l_tensor = torch.zeros_like(o).to(torch.float32)

         def tk_dispatcher():
-            if self.causal:
-                tk_fwd_causal.attention_forward_causal(q, k, v, o)
-            else:
-                tk_fwd.attention_forward(q, k, v, o)
+            tk_fwd.attention_forward(q, k, v, o, l_tensor, causal=self.causal)
             return o

         return tk_dispatcher
```
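
The recurring fix in this file is one pattern: each optional backend sets a capability flag at import time, and `register_benchmark(enabled=...)` consumes it, so a benchmark is enabled exactly when its backend imported successfully. A condensed sketch (identifiers from the diff; treating `register_benchmark` and `BenchmarkOperator` as importable from `tritonbench.utils.triton_op` is an assumption):

```python
from tritonbench.utils.triton_op import BenchmarkOperator, register_benchmark

try:
    import xformers  # optional backend
    import xformers.ops.fmha as xformers_fmha

    HAS_XFORMERS = True
except (ImportError, IOError, AttributeError):
    # The except branch must set the flag: the previous bare `pass` left
    # HAS_XFORMERS undefined whenever the import failed.
    HAS_XFORMERS = False

class Operator(BenchmarkOperator):
    @register_benchmark(enabled=HAS_XFORMERS)  # previously enabled=False
    def xformers(self, q, k, v):
        ...
```
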
tritonbench/utils/env_utils.py (14 additions & 0 deletions)

```diff
@@ -5,6 +5,8 @@

 from typing import Optional

+from tritonbench.utils.path_utils import REPO_PATH
+
 log = logging.getLogger(__name__)

 MAIN_RANDOM_SEED = 1337
@@ -22,6 +24,18 @@
 ]


+def set_env():
+    # set cutlass dir
+    # by default we use the cutlass version built with pytorch
+    import torch
+
+    current_cutlass_dir = torch._inductor.config.cuda.cutlass_dir
+    if not os.path.exists(current_cutlass_dir):
+        tb_cutlass_dir = REPO_PATH.joinpath("submodules", "cutlass")
+        if tb_cutlass_dir.is_dir():
+            torch._inductor.config.cuda.cutlass_dir = str(tb_cutlass_dir)
+
+
 def set_random_seed():
     """Make torch manual seed deterministic. Helps with accuracy testing."""
     import random
```
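
`set_env()` only redirects the CUTLASS path when the default (PyTorch-bundled) directory is missing, e.g. on a pip wheel that ships without third_party sources. A quick check of the effective path (assumed snippet, not part of the diff):

```python
import torch
from tritonbench.utils.env_utils import set_env

set_env()
# Prints .../submodules/cutlass only when the default cutlass_dir did not exist.
print(torch._inductor.config.cuda.cutlass_dir)
```
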
tritonbench/utils/triton_op.py (3 additions & 1 deletion)

```diff
@@ -26,6 +26,7 @@
 from tritonbench.utils.env_utils import (
     apply_precision,
     fresh_triton_cache,
+    set_env,
     set_random_seed,
 )
 from tritonbench.utils.input import input_cast
@@ -561,6 +562,7 @@ class BenchmarkOperator(metaclass=PostInitProcessor):
     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
     ):
+        set_env()
         set_random_seed()
         self.name = _find_op_name_from_module_path(self.__class__.__module__)
         self._raw_extra_args = copy.deepcopy(extra_args)
@@ -619,7 +621,7 @@ def __post__init__(self):

     def _get_bm_func(self, bm_func_name: str):
         fwd_fn_lambda = getattr(self, bm_func_name, None)
-        assert fwd_fn_lambda, (
+        assert callable(fwd_fn_lambda), (
             f"Could not find benchmark {bm_func_name} registered in {self.name}. "
             f"Available benchmarks: {REGISTERED_BENCHMARKS[self.name].keys()}. "
         )
```
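
The `assert callable(...)` tightening in the last hunk matters because `getattr` happily returns non-callable attributes; truthiness alone would let those through. A hypothetical illustration:

```python
class Op:
    aten = "not-a-benchmark"  # hypothetical non-callable attribute

fn = getattr(Op(), "aten", None)
assert fn              # passes: a non-empty string is truthy
assert callable(fn)    # fails early with a clear message, not later at call time
```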
