
Commit

Fixes requires_grad
xuzhao9 committed Nov 14, 2024
1 parent dceef3e · commit 4823828
Showing 2 changed files with 7 additions and 7 deletions.
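
The removed lines below called .requires_grad_() on integer-valued test inputs (timestamps drawn with torch.randint, offsets built with torch.zeros(dtype=torch.int64)), and PyTorch only allows gradients on floating-point and complex tensors. The flag therefore moves onto the module's bfloat16 weight parameters and is plumbed through the operator instead. A minimal sketch of the underlying constraint (illustrative shapes and names, not from this repo):

import torch

# Floating-point tensors may carry gradients ...
weights = torch.randn((8,), dtype=torch.bfloat16).requires_grad_(True)

# ... integer tensors may not: this raises a RuntimeError
# ("only Tensors of floating point and complex dtype can require gradients").
try:
    torch.randint(86400, size=(4, 16)).requires_grad_(True)
except RuntimeError as err:
    print(err)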
8 changes: 3 additions & 5 deletions tritonbench/operators/ragged_attention/hstu.py
@@ -43,6 +43,7 @@ def __init__(
         num_heads,
         max_seq_len,
         num_buckets,
+        requires_grad,
         persistent_kernel: bool = False,
     ) -> None:
         super().__init__()
@@ -54,13 +55,13 @@ def __init__(
             torch.randn(
                 (self.num_buckets + 1,),
                 dtype=torch.bfloat16,
-            ).cuda()
+            ).requires_grad_(requires_grad).cuda()
         )
         self.all_pos_weights = torch.nn.Parameter(
             torch.randn(
                 (2 * self.max_seq_len - 1,),
                 dtype=torch.bfloat16,
-            ).cuda()
+            ).requires_grad_(requires_grad).cuda()
         )
         self.persistent_kernel = persistent_kernel
 
@@ -179,7 +180,6 @@ def get_test_inputs(
             86400,
             size=(batch_size, max_seq_len + 1),
         )
-        .requires_grad_(requires_grad)
         .cuda()
     )
     timestamps = timestamp_deltas.cumsum(dim=1)
@@ -189,15 +189,13 @@
             max_seq_len + 1,
             size=(batch_size,),
         )
-        .requires_grad_(requires_grad)
         .cuda()
     )
     seq_offsets = (
         torch.zeros(
             (batch_size + 1,),
             dtype=torch.int64,
         )
-        .requires_grad_(requires_grad)
         .cuda()
     )
     seq_offsets[1:] = torch.cumsum(
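
With the flag on the weights, a backward pass accumulates gradients on the module parameters while the integer test inputs stay grad-free. A rough CPU-only sketch of that flow, with made-up shapes and a stand-in for the attention forward:

import torch

requires_grad = True
weights = torch.nn.Parameter(
    torch.randn((8,), dtype=torch.bfloat16).requires_grad_(requires_grad)
)
timestamps = torch.randint(86400, size=(2, 8))  # int64, never requires grad

# stand-in for the attention forward: any op consuming both tensors
out = weights.float().sum() * timestamps.float().mean()
out.backward()

print(weights.grad is not None)   # True  -> gradients reach the parameter
print(timestamps.requires_grad)   # False -> integer inputs stay grad-free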
6 changes: 4 additions & 2 deletions tritonbench/operators/ragged_attention/operator.py
@@ -32,6 +32,7 @@ def __init__(
         self.num_buckets = args.num_buckets
         # set a default number of inputs
         self._num_inputs = 10 if self._num_inputs is None else self._num_inputs
+        self.requires_grad = not (self.mode == Mode.FWD_NO_GRAD)
 
     @register_benchmark()
     def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps):
@@ -40,6 +41,7 @@ def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps):
             self.num_heads,
             self.max_seq_len,
             self.num_buckets,
+            self.requires_grad,
             persistent_kernel=False,
         )
         return lambda: attn(qkv, seq_offsets, timestamps)
@@ -52,6 +54,7 @@ def hstu_triton_ragged_attention_persistent(self, qkv, seq_offsets, timestamps):
             self.num_heads,
             self.max_seq_len,
             self.num_buckets,
+            self.requires_grad,
             persistent_kernel=True,
         )
         return lambda: attn(qkv, seq_offsets, timestamps)
@@ -60,9 +63,8 @@ def get_x_val(self, example_inputs):
         return (self.batch_size, self.num_heads, self.max_seq_len, self.num_buckets)
 
     def get_input_iter(self):
-        requires_grad = not (self.mode == Mode.FWD_NO_GRAD)
         for _input_id in range(self._num_inputs):
-            inputs = get_test_inputs(self.batch_size, self.num_heads, self.max_seq_len, requires_grad)
+            inputs = get_test_inputs(self.batch_size, self.num_heads, self.max_seq_len, self.requires_grad)
             yield inputs
 
     def get_bwd_fn(self, fwd_fn: Callable[..., Any]) -> Callable[..., Any]:
