From 48238285c0b40e7361dd81cb271904119ef31bb4 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Thu, 14 Nov 2024 15:09:00 -0500
Subject: [PATCH] Fixes requires grad

---
 tritonbench/operators/ragged_attention/hstu.py     | 8 +++-----
 tritonbench/operators/ragged_attention/operator.py | 6 ++++--
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/tritonbench/operators/ragged_attention/hstu.py b/tritonbench/operators/ragged_attention/hstu.py
index ec6a2d5b..e7ad9d30 100644
--- a/tritonbench/operators/ragged_attention/hstu.py
+++ b/tritonbench/operators/ragged_attention/hstu.py
@@ -43,6 +43,7 @@ def __init__(
         num_heads,
         max_seq_len,
         num_buckets,
+        requires_grad,
         persistent_kernel: bool = False,
     ) -> None:
         super().__init__()
@@ -54,13 +55,13 @@ def __init__(
             torch.randn(
                 (self.num_buckets + 1,),
                 dtype=torch.bfloat16,
-            ).cuda()
+            ).requires_grad_(requires_grad).cuda()
         )
         self.all_pos_weights = torch.nn.Parameter(
             torch.randn(
                 (2 * self.max_seq_len - 1,),
                 dtype=torch.bfloat16,
-            ).cuda()
+            ).requires_grad_(requires_grad).cuda()
         )
         self.persistent_kernel = persistent_kernel
@@ -179,7 +180,6 @@ def get_test_inputs(
             86400,
             size=(batch_size, max_seq_len + 1),
         )
-        .requires_grad_(requires_grad)
         .cuda()
     )
     timestamps = timestamp_deltas.cumsum(dim=1)
@@ -189,7 +189,6 @@
             max_seq_len + 1,
             size=(batch_size,),
         )
-        .requires_grad_(requires_grad)
         .cuda()
     )
     seq_offsets = (
@@ -197,7 +196,6 @@
             (batch_size + 1,),
             dtype=torch.int64,
         )
-        .requires_grad_(requires_grad)
         .cuda()
     )
     seq_offsets[1:] = torch.cumsum(
diff --git a/tritonbench/operators/ragged_attention/operator.py b/tritonbench/operators/ragged_attention/operator.py
index 4620fd41..9a6656ab 100644
--- a/tritonbench/operators/ragged_attention/operator.py
+++ b/tritonbench/operators/ragged_attention/operator.py
@@ -32,6 +32,7 @@ def __init__(
         self.num_buckets = args.num_buckets
         # set a default number of inputs
         self._num_inputs = 10 if self._num_inputs is None else self._num_inputs
+        self.requires_grad = not (self.mode == Mode.FWD_NO_GRAD)

     @register_benchmark()
     def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps):
@@ -40,6 +41,7 @@ def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps):
             self.num_heads,
             self.max_seq_len,
             self.num_buckets,
+            self.requires_grad,
             persistent_kernel=False,
         )
         return lambda: attn(qkv, seq_offsets, timestamps)
@@ -52,6 +54,7 @@ def hstu_triton_ragged_attention_persistent(self, qkv, seq_offsets, timestamps):
             self.num_heads,
             self.max_seq_len,
             self.num_buckets,
+            self.requires_grad,
             persistent_kernel=True,
         )
         return lambda: attn(qkv, seq_offsets, timestamps)
@@ -60,9 +63,8 @@ def get_x_val(self, example_inputs):
         return (self.batch_size, self.num_heads, self.max_seq_len, self.num_buckets)

     def get_input_iter(self):
-        requires_grad = not (self.mode == Mode.FWD_NO_GRAD)
         for _input_id in range(self._num_inputs):
-            inputs = get_test_inputs(self.batch_size, self.num_heads, self.max_seq_len, requires_grad)
+            inputs = get_test_inputs(self.batch_size, self.num_heads, self.max_seq_len, self.requires_grad)
             yield inputs

     def get_bwd_fn(self, fwd_fn: Callable[..., Any]) -> Callable[..., Any]:
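
Note (reviewer sketch, not part of the patch): a likely motivation for moving the flag is that
the tensors built in get_test_inputs (timestamp deltas, lengths, seq_offsets) are integer
tensors, and PyTorch only allows requires_grad on floating-point and complex dtypes, so the
old .requires_grad_(requires_grad) calls would raise a RuntimeError whenever the mode is not
FWD_NO_GRAD. The patch instead threads requires_grad into RaggedHSTUAttn and applies it to the
bfloat16 weight tensors. A minimal, CPU-only illustration of that behavior (shapes are
hypothetical, not code from the repository):

    import torch

    # Float weights can carry gradients, as the patched __init__ now arranges.
    weights = torch.randn((8,), dtype=torch.bfloat16).requires_grad_(True)
    param = torch.nn.Parameter(weights)  # learnable parameter

    # Integer inputs cannot; this mirrors what the removed lines attempted.
    timestamps = torch.randint(86400, size=(2, 17))  # int64 tensor
    try:
        timestamps.requires_grad_(True)
    except RuntimeError as err:
        # "only Tensors of floating point and complex dtype can require gradients"
        print(err)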