Update HSTU and use the OSS wrapper for non-persistent kernels (#53)
Summary:
Add the following features to ragged_hstu:

1. Add tflops metric
2. Use _RaggedAttentionRelativeBiasFunction to wrap the Triton kernel
3. Add backward

Pull Request resolved: #53

Test Plan:
```
$ python run.py --op ragged_attention --metrics latency,tflops --mode bwd

            x_val    hstu_triton_ragged_attention-tflops    hstu_triton_ragged_attention-latency
-----------------  -------------------------------------  --------------------------------------
(8, 4, 512, 2048)                               0.306747                                2.81939
(8, 4, 512, 2048)                               1.65614                                 0.867936
(8, 4, 512, 2048)                               2.00125                                 0.84768
(8, 4, 512, 2048)                               2.13756                                 0.991968
(8, 4, 512, 2048)                               1.96315                                 0.902976
(8, 4, 512, 2048)                               1.50214                                 0.836192
(8, 4, 512, 2048)                               1.34825                                 0.859936
(8, 4, 512, 2048)                               1.90546                                 0.97408
(8, 4, 512, 2048)                               1.72114                                 0.902368
(8, 4, 512, 2048)                               2.30999                                 1.01107
```

Reviewed By: manman-ren

Differential Revision: D66021701

Pulled By: xuzhao9

fbshipit-source-id: b0d9f32d49e02c113e4aafa597be68c17d952283
xuzhao9 authored and facebook-github-bot committed Nov 15, 2024
1 parent c08a2a8 commit c2ef66a
Showing 6 changed files with 166 additions and 51 deletions.
6 changes: 6 additions & 0 deletions install.py
@@ -114,6 +114,7 @@ def install_xformers():
parser.add_argument(
"--fa3", action="store_true", help="Install optional flash_attention 3 kernels"
)
parser.add_argument("--hstu", action="store_true", help="Install HSTU.")
parser.add_argument("--jax", action="store_true", help="Install jax nightly")
parser.add_argument("--tk", action="store_true", help="Install ThunderKittens")
parser.add_argument("--liger", action="store_true", help="Install Liger-kernel")
@@ -153,6 +154,11 @@ def install_xformers():
if args.xformers or args.all:
logger.info("[tritonbench] installing xformers...")
install_xformers()
if args.hstu or args.all:
logger.info("[tritonbench] installing hstu...")
from tools.hstu.install import install_hstu

install_hstu()
logger.info("[tritonbench] installation complete!")
# run tests to check installation
if args.test:
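With the new flag, the HSTU dependency can be installed on its own. A minimal invocation (assuming the tritonbench repository root as the working directory) looks like:

```
$ python install.py --hstu
```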
13 changes: 13 additions & 0 deletions tools/hstu/hstu.patch
@@ -0,0 +1,13 @@
diff --git a/generative_recommenders/ops/triton/triton_ragged_hstu_attention.py b/generative_recommenders/ops/triton/triton_ragged_hstu_attention.py
index b4e318b..d6bc894 100644
--- a/generative_recommenders/ops/triton/triton_ragged_hstu_attention.py
+++ b/generative_recommenders/ops/triton/triton_ragged_hstu_attention.py
@@ -36,7 +36,7 @@ try:
VersionedSpec,
)
except ImportError:
- from hammer.oss.generative_recommenders.ops.triton.utils import (
+ from generative_recommenders.ops.triton.utils import (
_switch_to_contiguous_if_needed,
autotune_max_seq_len,
NamedSpecType,
34 changes: 34 additions & 0 deletions tools/hstu/install.py
@@ -0,0 +1,34 @@
import os
import subprocess
import sys
from pathlib import Path

PATCH_DIR = str(
Path(__file__)
.parent.parent.parent.joinpath("submodules", "generative-recommenders")
.absolute()
)
PATCH_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "hstu.patch")


def install_hstu():
try:
subprocess.check_output(
[
"patch",
"-p1",
"--forward",
"-i",
PATCH_FILE,
"-r",
"/tmp/rej",
],
cwd=PATCH_DIR,
)
except subprocess.SubprocessError as e:
output_str = str(e.output)
if "previously applied" in output_str:
return
else:
print(str(output_str))
sys.exit(1)
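For reference, the installer above is roughly equivalent to running `patch` by hand against the generative-recommenders submodule; `--forward` keeps the step idempotent if the patch was already applied (paths below assume the default submodule layout):

```
$ cd submodules/generative-recommenders
$ patch -p1 --forward -i ../../tools/hstu/hstu.patch -r /tmp/rej
```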
96 changes: 50 additions & 46 deletions tritonbench/operators/ragged_attention/hstu.py
@@ -13,16 +13,15 @@
)
except ModuleNotFoundError:
# OSS Import
import importlib
with add_path(str(SUBMODULE_PATH.joinpath("generative-recommenders"))):
from generative_recommenders.ops.triton import triton_ragged_hstu_attention

with add_path(str(SUBMODULE_PATH)):
triton_ragged_hstu_attention = importlib.import_module(
"generative-recommenders.ops.triton.triton_ragged_hstu_attention"
)
_ragged_hstu_attn_fwd = triton_ragged_hstu_attention._ragged_hstu_attn_fwd
_ragged_hstu_attn_fwd_persistent = (
triton_ragged_hstu_attention._ragged_hstu_attn_fwd_persistent
)
_RaggedAttentionRelativeBiasFunction = (
triton_ragged_hstu_attention._RaggedAttentionRelativeBiasFunction
)

@torch.fx.wrap
def prev_power_of_2(x: int) -> int:
@@ -47,6 +46,7 @@ def __init__(
num_heads,
max_seq_len,
num_buckets,
requires_grad,
persistent_kernel: bool = False,
) -> None:
super().__init__()
@@ -58,13 +58,17 @@ def __init__(
torch.randn(
(self.num_buckets + 1,),
dtype=torch.bfloat16,
).cuda()
)
.requires_grad_(requires_grad)
.cuda()
)
self.all_pos_weights = torch.nn.Parameter(
torch.randn(
(2 * self.max_seq_len - 1,),
dtype=torch.bfloat16,
).cuda()
)
.requires_grad_(requires_grad)
.cuda()
)
self.persistent_kernel = persistent_kernel

@@ -141,57 +145,57 @@ def forward(
"HAS_SORT_BY_LENGTH_INDICES": False,
"sort_by_length_indices": None,
}
if not IS_FBCODE:
del kwargs["MAX_ATTN_LEN"]
del kwargs["HAS_CONTEXTUAL_SEQ_LEN"]
del kwargs["contextual_seq_len"]
del kwargs["HAS_SORT_BY_LENGTH_INDICES"]
del kwargs["sort_by_length_indices"]
kwargs["HAS_MAX_ATTN_LEN"] = False
kwargs["max_attn_len"] = 0

if self.persistent_kernel:
grid = (1216,)
_ragged_hstu_attn_fwd_persistent[grid](**kwargs)
else:
grid = lambda meta: ( # noqa E731
triton.cdiv(N, meta["BLOCK_M"]),
Z * H,
out = _RaggedAttentionRelativeBiasFunction.apply(
self.max_seq_len, # N
kwargs["alpha"],
q,
k,
v,
kwargs["seq_offsets"],
kwargs["INVALID_MASK_TYPE"],
timestamps,
self.all_ts_weights, # ts_weights
self.all_pos_weights, # pos_weights
kwargs["CAUSAL"], # causal,
kwargs["num_buckets"], # num_buckets
"sqrt", # time_bucket_fn
kwargs["time_bucket_incr"], # time_bucket_incr
kwargs["time_bucket_div"], # time_bucket_div
kwargs["time_delta"], # time_delta
kwargs["max_pos_ind"], # max_pos_ind
kwargs["num_targets"],
None, # attn_scale
kwargs["ATTN_BIAS_TYPE"], # relative_bias_type
kwargs["MAX_ATTN_LEN"], # max_attn_len
kwargs["contextual_seq_len"], # contextual_seq_len
kwargs["sort_by_length_indices"], # sort_by_length
)
_ragged_hstu_attn_fwd[grid](**kwargs)

return out


def get_test_inputs(
batch_size, num_heads, max_seq_len
batch_size, num_heads, max_seq_len, requires_grad
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
timestamp_deltas: torch.Tensor = (
torch.randint(
86400,
size=(batch_size, max_seq_len + 1),
)
.requires_grad_(False)
.cuda()
)
timestamp_deltas: torch.Tensor = torch.randint(
86400,
size=(batch_size, max_seq_len + 1),
).cuda()
timestamps = timestamp_deltas.cumsum(dim=1)

lengths = (
torch.randint(
max_seq_len + 1,
size=(batch_size,),
)
.requires_grad_(False)
.cuda()
)
seq_offsets = (
torch.zeros(
(batch_size + 1,),
dtype=torch.int64,
)
.requires_grad_(False)
.cuda()
)
lengths = torch.randint(
max_seq_len + 1,
size=(batch_size,),
).cuda()
seq_offsets = torch.zeros(
(batch_size + 1,),
dtype=torch.int64,
).cuda()
seq_offsets[1:] = torch.cumsum(
lengths,
dim=0,
@@ -203,7 +207,7 @@ def get_test_inputs(
(L, num_heads, 512),
dtype=torch.bfloat16,
)
.requires_grad_(False)
.requires_grad_(requires_grad)
.cuda()
)
return qkv, seq_offsets, timestamps
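The main change in this file is that the non-persistent path now goes through `_RaggedAttentionRelativeBiasFunction.apply(...)` instead of launching `_ragged_hstu_attn_fwd[grid](...)` directly. Wrapping the Triton kernel in a `torch.autograd.Function` is what makes `--mode bwd` possible: the output then carries a `grad_fn`, so `.backward()` can dispatch to the wrapper's backward kernel. A minimal, purely illustrative sketch of that pattern (not the real HSTU wrapper, which launches Triton kernels and also differentiates through the relative-bias weights):

```
# Illustrative only: shows why an autograd.Function wrapper enables backward.
import torch


class _ToyAttnFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v):
        # The real wrapper launches the Triton forward kernel here.
        attn = torch.matmul(q, k.transpose(-1, -2))
        ctx.save_for_backward(q, k, v)
        return torch.matmul(attn, v)

    @staticmethod
    def backward(ctx, grad_out):
        # The real wrapper launches the Triton backward kernel here.
        q, k, v = ctx.saved_tensors
        attn = torch.matmul(q, k.transpose(-1, -2))
        grad_v = torch.matmul(attn.transpose(-1, -2), grad_out)
        grad_attn = torch.matmul(grad_out, v.transpose(-1, -2))
        grad_q = torch.matmul(grad_attn, k)
        grad_k = torch.matmul(grad_attn.transpose(-1, -2), q)
        return grad_q, grad_k, grad_v
```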
66 changes: 62 additions & 4 deletions tritonbench/operators/ragged_attention/operator.py
@@ -1,8 +1,17 @@
import argparse

from typing import List, Optional
from typing import Any, Callable, List, Optional

from tritonbench.utils.triton_op import BenchmarkOperator, register_benchmark
import torch
from tritonbench.utils.input import input_filter

from tritonbench.utils.triton_op import (
BenchmarkOperator,
BenchmarkOperatorMetrics,
Mode,
register_benchmark,
register_metric,
)

from .hstu import get_test_inputs, RaggedHSTUAttn

@@ -30,6 +39,7 @@ def __init__(
self.num_buckets = args.num_buckets
# set a default number of inputs
self._num_inputs = 10 if self._num_inputs is None else self._num_inputs
self.requires_grad = not (self.mode == Mode.FWD_NO_GRAD)

@register_benchmark()
def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps):
@@ -38,17 +48,20 @@ def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps):
self.num_heads,
self.max_seq_len,
self.num_buckets,
self.requires_grad,
persistent_kernel=False,
)
return lambda: attn(qkv, seq_offsets, timestamps)

@register_benchmark()
# TODO: enable persistent kernels when the OSS backward is ready
@register_benchmark(enabled=False)
def hstu_triton_ragged_attention_persistent(self, qkv, seq_offsets, timestamps):
attn = RaggedHSTUAttn(
self.batch_size,
self.num_heads,
self.max_seq_len,
self.num_buckets,
self.requires_grad,
persistent_kernel=True,
)
return lambda: attn(qkv, seq_offsets, timestamps)
@@ -58,5 +71,50 @@ def get_x_val(self, example_inputs):

def get_input_iter(self):
for _input_id in range(self._num_inputs):
inputs = get_test_inputs(self.batch_size, self.num_heads, self.max_seq_len)
inputs = get_test_inputs(
self.batch_size, self.num_heads, self.max_seq_len, self.requires_grad
)
yield inputs

def get_bwd_fn(self, fwd_fn: Callable[..., Any]) -> Callable[..., Any]:
o = fwd_fn()
o_tensor = input_filter(
lambda x: isinstance(x, torch.Tensor),
o,
)
do = torch.rand_like(o_tensor)
fn = lambda: o_tensor.backward(do, retain_graph=True)
return fn

@register_metric()
def tflops(
self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics
) -> float:
ratio = 2.0 # triangular masking
f1 = 0.0
f2 = 0.0
jagged = True
qkv, seq_offsets, timestamps = example_inputs
q = qkv[:, :, :128]
v = qkv[:, :, 256:384]
_, nheads, attn_dim = q.shape
_, _, hidden_dim = v.shape
max_seqlen = timestamps.size(1) - 1

for i in range(self.batch_size):
seq_len = (
int((seq_offsets[i + 1] - seq_offsets[i]).item())
if jagged
else max_seqlen
)
# (QK^T), dQ = d(QK^T)K, dK^T = Q^Td(QK^T)
f1 += 2 * self.num_heads * attn_dim * seq_len**2 // ratio
# (QK^T)V, d(QK^T) = dOV^T, dV = (QK^T)^TdO,
f2 += 2 * self.num_heads * hidden_dim * seq_len**2 // ratio
if self.mode == Mode.FWD:
tflops = f1 + f2 # computes (QK^T) and (QK^T)V
elif self.mode == Mode.BWD:
tflops = 3 * f1 + 2 * f2 # computes (QK^T), dQ, dK, dV, d(QK^T)
elif self.mode == Mode.FWD_BWD:
tflops = 4 * f1 + 3 * f2
return tflops / metrics.latency * 1e-9
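As a rough sanity check on the numbers in the test plan, plugging assumed average values into the same formula (batch 8, 4 heads, attn_dim = hidden_dim = 128 from the qkv slicing, ragged lengths averaging about half of max_seq_len = 512, backward mode, ~0.9 ms latency) lands near the ~1.5 TFLOPS reported above:

```
# Back-of-the-envelope check of the tflops metric with assumed averages,
# not the exact random ragged lengths used in the actual run.
batch_size, num_heads = 8, 4
attn_dim, hidden_dim = 128, 128   # q = qkv[:, :, :128], v = qkv[:, :, 256:384]
avg_seq_len = 256                 # lengths ~ Uniform[0, 512], so ~256 on average
ratio = 2.0                       # triangular (causal) masking

f1 = batch_size * 2 * num_heads * attn_dim * avg_seq_len**2 / ratio    # QK^T
f2 = batch_size * 2 * num_heads * hidden_dim * avg_seq_len**2 / ratio  # (QK^T)V
flops_bwd = 3 * f1 + 2 * f2       # d(QK^T), dQ, dK, dV

latency_ms = 0.9
print(flops_bwd / latency_ms * 1e-9)  # ~1.49, in line with the table above
```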
