Align default parameters with typical benchmarks (#89)
Summary:
Pull Request resolved: #89

This changes various default parameters and input generators to align
more closely with usage in generative-recommenders

Reviewed By: xuzhao9, adamomainz

Differential Revision: D66528330

fbshipit-source-id: 5295215ce86780412f56d4929ff040dc41992d09
bertmaher authored and facebook-github-bot committed Dec 2, 2024
1 parent 2474f1e commit 7e1f269
Showing 2 changed files with 72 additions and 45 deletions.
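Before the diffs, a minimal CPU-only sketch (illustrative only, not part of the commit; the split sizes and shapes come from the hstu.py changes below, the toy batch values are assumed) of the new input layout, where a fused (L, num_heads, 2*attn_dim + hidden_dim) tensor is split into separate q/k/v tensors up front instead of being sliced inside forward():

import torch

# Ragged-batch bookkeeping: per-sequence lengths become cumulative offsets.
batch_size, num_heads, attn_dim, hidden_dim = 2, 4, 128, 128
lengths = torch.tensor([3, 5])
seq_offsets = torch.zeros(batch_size + 1, dtype=torch.int64)
seq_offsets[1:] = torch.cumsum(lengths, dim=0)
L = int(seq_offsets[-1].item())  # total tokens across the ragged batch (8 here)

# The benchmark allocates this on CUDA in bfloat16; CPU is used here so the
# sketch runs anywhere.
qkv = torch.randn(L, num_heads, attn_dim * 2 + hidden_dim)
q, k, v = torch.split(qkv, [attn_dim, attn_dim, hidden_dim], dim=-1)
print(q.shape, k.shape, v.shape)  # each is (L, num_heads, 128) with these sizes

Passing q, k, and v separately is what lets operator.py expose attn_dim and hidden_dim as independent flags (--attn-dim, --hidden-dim) in the second file of this commit.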
64 changes: 40 additions & 24 deletions tritonbench/operators/ragged_attention/hstu.py
@@ -64,15 +64,15 @@ def __init__(
self.all_ts_weights = torch.nn.Parameter(
torch.randn(
(self.num_buckets + 1,),
dtype=torch.bfloat16,
dtype=torch.float32,
)
.requires_grad_(requires_grad)
.cuda()
)
self.all_pos_weights = torch.nn.Parameter(
torch.randn(
(2 * self.max_seq_len - 1,),
dtype=torch.bfloat16,
dtype=torch.float32,
)
.requires_grad_(requires_grad)
.cuda()
@@ -81,17 +81,16 @@ def __init__(

def forward(
self,
qkv: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_offsets: torch.Tensor,
timestamps: torch.Tensor,
num_targets: torch.Tensor,
) -> torch.Tensor:
NUM_BUCKETS = self.num_buckets
torch._check(timestamps.size(0) + 1 == seq_offsets.size(0))

q = qkv[:, :, :128]
k = qkv[:, :, 128:256]
v = qkv[:, :, 256:384]
out = torch.zeros_like(v)

Z = timestamps.size(0)
@@ -134,13 +133,13 @@ def forward(
"DeltaSize": None,
"num_buckets": NUM_BUCKETS,
"max_pos_ind": None,
"time_bucket_incr": 60.0,
"time_bucket_incr": 60,
"time_bucket_div": 1.0,
"time_delta": 0.0,
"INVALID_MASK_TYPE": "lower_triangular",
"CAUSAL": True,
"BUCKET_FN": "sqrt",
"ATTN_BIAS_TYPE": "fused",
"ATTN_BIAS_TYPE": "ALL",
"USE_TIME_BIAS": False,
"USE_POS_BIAS": False,
"HAS_MAX_POS_IND": False,
@@ -150,7 +149,7 @@ def forward(
"ALLOW_TF32": True,
"BLOCK_D_Q": DimQ,
"BLOCK_D_V": DimV,
"MAX_ATTN_LEN": 0,
"MAX_ATTN_LEN": None,
"CONTEXTUAL_SEQ_LEN": 0,
"HAS_SORT_BY_LENGTH_INDICES": False,
"sort_by_length_indices": None,
@@ -219,27 +218,42 @@ def generate_sparse_seq_len(
)


try:
from hammer.benchmark.module_factory.hstu_utils import (
apply_SL,
generate_hstu_timestamps,
)
except ImportError:

def apply_SL(lengths: torch.Tensor, alpha: float, max_seq_len: int):
return lengths

def generate_hstu_timestamps(batch_size, seq_len):
ts = torch.rand(batch_size, seq_len + 1, device="cuda") ** -0.8
ts = torch.clamp(torch.abs(ts * 86400), max=1e7)
ts, _ = torch.sort(ts, dim=1)
return ts.long()


def get_test_inputs(
batch_size,
num_heads,
attn_dim,
hidden_dim,
max_seq_len,
sparsity,
target_size,
sort_by_length,
requires_grad,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
timestamp_deltas: torch.Tensor = torch.randint(
86400,
size=(batch_size, max_seq_len + 1),
).cuda()
timestamps = timestamp_deltas.cumsum(dim=1)

timestamps = generate_hstu_timestamps(batch_size, max_seq_len)
lengths = generate_sparse_seq_len(
size=batch_size,
max_seq_len=max_seq_len,
sparsity=sparsity,
device=torch.device("cuda"),
)
lengths = apply_SL(lengths, alpha=2.0, max_seq_len=max_seq_len)
# assume has_delta_q is False
num_targets = None
if target_size != 0:
@@ -254,19 +268,21 @@ def get_test_inputs(
seq_offsets = torch.zeros(
(batch_size + 1,),
dtype=torch.int64,
).cuda()
device="cuda",
)
seq_offsets[1:] = torch.cumsum(
lengths,
dim=0,
)
L = int(seq_offsets[-1].item())

qkv = (
torch.randn(
(L, num_heads, 512),
dtype=torch.bfloat16,
)
.requires_grad_(requires_grad)
.cuda()
qkv = torch.randn(
(L, num_heads, attn_dim * 2 + hidden_dim),
dtype=torch.bfloat16,
device="cuda",
)
return qkv, seq_offsets, timestamps, num_targets
q, k, v = torch.split(qkv, [attn_dim, attn_dim, hidden_dim], dim=-1)
q.requires_grad_(True)
k.requires_grad_(True)
v.requires_grad_(True)
return q, k, v, seq_offsets, timestamps, num_targets, max_seq_len
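For context, the OSS fallback generate_hstu_timestamps added above can be exercised on its own. A CPU-only sketch adapted from the diff (the original allocates on CUDA; the sample sizes here are arbitrary):

import torch

def generate_hstu_timestamps(batch_size: int, seq_len: int) -> torch.Tensor:
    # Heavy-tailed draws (rand ** -0.8) scaled to seconds-in-a-day, clamped to
    # at most 1e7, then sorted so each row is a non-decreasing timestamp series.
    ts = torch.rand(batch_size, seq_len + 1) ** -0.8
    ts = torch.clamp(torch.abs(ts * 86400), max=1e7)
    ts, _ = torch.sort(ts, dim=1)
    return ts.long()

timestamps = generate_hstu_timestamps(batch_size=4, seq_len=16)
assert timestamps.shape == (4, 17)
assert bool((timestamps[:, 1:] >= timestamps[:, :-1]).all())  # rows are sorted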
53 changes: 32 additions & 21 deletions tritonbench/operators/ragged_attention/operator.py
@@ -19,13 +19,15 @@

def parse_op_args(args: List[str]):
parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, default=8, help="Batch size")
parser.add_argument("--batch-size", type=int, default=256, help="Batch size")
parser.add_argument("--heads", type=int, default=4, help="Number of heads")
parser.add_argument("--max-seq-len-log2", type=int, default=9)
parser.add_argument("--attn-dim", type=int, default=128)
parser.add_argument("--hidden-dim", type=int, default=128)
parser.add_argument("--max-seq-len-log2", type=int, default=15)
parser.add_argument("--num-buckets", type=int, default=2048)
parser.add_argument("--seq-sparsity", type=float, default=0.8)
parser.add_argument("--seq-sparsity", type=float, default=0.95)
parser.add_argument("--target-size", type=int, default=20)
parser.add_argument("--sort-by-length", type=bool, default=False)
parser.add_argument("--sort-by-length", type=bool, default=True)
return parser.parse_args(args)
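A quick illustration (not part of the commit; assumes a tritonbench checkout in which this module imports cleanly) of driving the parser above directly and observing the new defaults, with an arbitrary override:

from tritonbench.operators.ragged_attention.operator import parse_op_args

args = parse_op_args(["--batch-size", "64"])
print(args.batch_size)        # 64 (overridden)
print(args.seq_sparsity)      # 0.95 (new default)
print(args.max_seq_len_log2)  # 15 (new default)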


@@ -39,71 +41,82 @@ def __init__(
args = parse_op_args(self.extra_args)
self.batch_size = args.batch_size
self.num_heads = args.heads
self.max_seq_len = 2**args.max_seq_len_log2
self.attn_dim = args.attn_dim
self.hidden_dim = args.hidden_dim
self.max_seq_len_log2 = args.max_seq_len_log2
self.num_buckets = args.num_buckets
self.sparsity = args.seq_sparsity
self.target_size = args.target_size
self.sort_by_length = args.sort_by_length
# set a default number of inputs
self._num_inputs = 10 if self._num_inputs is None else self._num_inputs
self.requires_grad = not (self.mode == Mode.FWD_NO_GRAD)

@register_benchmark()
def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps, num_targets):
def hstu_triton_ragged_attention(
self, q, k, v, seq_offsets, timestamps, num_targets, seq_len
):
attn = RaggedHSTUAttn(
self.batch_size,
self.num_heads,
self.max_seq_len,
seq_len,
self.num_buckets,
self.sparsity,
self.target_size,
self.sort_by_length,
self.requires_grad,
persistent_kernel=False,
)
return lambda: attn(qkv, seq_offsets, timestamps, num_targets)
return lambda: attn(q, k, v, seq_offsets, timestamps, num_targets)

# TODO: enable persistent kernels when the OSS backward is ready
@register_benchmark(enabled=False)
def hstu_triton_ragged_attention_persistent(
self, qkv, seq_offsets, timestamps, num_targets
self,
q,
k,
v,
seq_offsets,
timestamps,
num_targets,
seq_len,
):
attn = RaggedHSTUAttn(
self.batch_size,
self.num_heads,
self.max_seq_len,
seq_len,
self.num_buckets,
self.sparsity,
self.target_size,
self.sort_by_length,
self.requires_grad,
persistent_kernel=True,
)
return lambda: attn(qkv, seq_offsets, timestamps, num_targets)
return lambda: attn(q, k, v, seq_offsets, timestamps, num_targets)

def get_x_val(self, example_inputs):
seq_len = example_inputs[-1]
return (
self.batch_size,
self.num_heads,
self.max_seq_len,
seq_len,
self.num_buckets,
self.sparsity,
self.target_size,
self.sort_by_length,
)

def get_input_iter(self):
for _input_id in range(self._num_inputs):
inputs = get_test_inputs(
for seq_len in [2**i for i in range(8, self.max_seq_len_log2)]:
yield get_test_inputs(
self.batch_size,
self.num_heads,
self.max_seq_len,
self.attn_dim,
self.hidden_dim,
seq_len,
self.sparsity,
self.target_size,
self.sort_by_length,
self.requires_grad,
)
yield inputs

def get_bwd_fn(self, fwd_fn: Callable[..., Any]) -> Callable[..., Any]:
o = fwd_fn()
@@ -123,9 +136,7 @@ def tflops(
f1 = 0.0
f2 = 0.0
jagged = True
qkv, seq_offsets, timestamps, num_targets = example_inputs
q = qkv[:, :, :128]
v = qkv[:, :, 256:384]
q, k, v, seq_offsets, timestamps, num_targets = example_inputs
_, nheads, attn_dim = q.shape
_, _, hidden_dim = v.shape
max_seqlen = timestamps.size(1) - 1

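One consequence of the new defaults worth spelling out: get_input_iter now sweeps sequence lengths instead of repeating a fixed one, so with --max-seq-len-log2=15 the benchmark covers 2^8 through 2^14 (the range end is exclusive). A small illustrative check:

seq_lens = [2**i for i in range(8, 15)]  # range(8, self.max_seq_len_log2) with the default of 15
print(seq_lens)  # [256, 512, 1024, 2048, 4096, 8192, 16384]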
