diff --git a/tritonbench/operators/ragged_attention/hstu.py b/tritonbench/operators/ragged_attention/hstu.py
index ec6a2d5b..e7ad9d30 100644
--- a/tritonbench/operators/ragged_attention/hstu.py
+++ b/tritonbench/operators/ragged_attention/hstu.py
@@ -43,6 +43,7 @@ def __init__(
         num_heads,
         max_seq_len,
         num_buckets,
+        requires_grad,
         persistent_kernel: bool = False,
     ) -> None:
         super().__init__()
@@ -54,13 +55,15 @@ def __init__(
             torch.randn(
                 (self.num_buckets + 1,),
                 dtype=torch.bfloat16,
-            ).cuda()
+            ).cuda(),
+            requires_grad=requires_grad,
         )
         self.all_pos_weights = torch.nn.Parameter(
             torch.randn(
                 (2 * self.max_seq_len - 1,),
                 dtype=torch.bfloat16,
-            ).cuda()
+            ).cuda(),
+            requires_grad=requires_grad,
         )
         self.persistent_kernel = persistent_kernel
 
@@ -179,7 +182,6 @@ def get_test_inputs(
             86400,
             size=(batch_size, max_seq_len + 1),
         )
-        .requires_grad_(requires_grad)
         .cuda()
     )
     timestamps = timestamp_deltas.cumsum(dim=1)
@@ -189,7 +191,6 @@ def get_test_inputs(
             max_seq_len + 1,
             size=(batch_size,),
         )
-        .requires_grad_(requires_grad)
         .cuda()
     )
     seq_offsets = (
@@ -197,7 +198,6 @@ def get_test_inputs(
             (batch_size + 1,),
             dtype=torch.int64,
         )
-        .requires_grad_(requires_grad)
         .cuda()
     )
     seq_offsets[1:] = torch.cumsum(
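
The requires_grad_ calls removed above sat on integer test inputs (the timestamp deltas, lengths, and seq_offsets are int64 tensors), presumably because autograd only tracks floating-point or complex tensors, so requires_grad_(True) on them raises at runtime. The flag instead moves to the bfloat16 bias weights, threaded through the Parameter constructor, which is what actually controls the resulting Parameter. A minimal, standalone sketch of both behaviors (made-up sizes, CPU-only):

    import torch

    # Integer tensors (like the seq_offsets / lengths / timestamp inputs above)
    # cannot carry requires_grad: autograd only tracks floating-point/complex dtypes.
    lengths = torch.randint(10, size=(4,))
    try:
        lengths.requires_grad_(True)
    except RuntimeError as err:
        print(err)  # "only Tensors of floating point ... can require gradients"

    # Floating-point weights can, but nn.Parameter defaults to requires_grad=True
    # regardless of its input tensor's flag, so the switch has to go through the
    # Parameter constructor to take effect.
    weights = torch.nn.Parameter(
        torch.randn((9,), dtype=torch.bfloat16),  # made-up size (num_buckets + 1)
        requires_grad=False,  # e.g. the FWD_NO_GRAD benchmark mode
    )
    print(weights.requires_grad)  # False
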
diff --git a/tritonbench/operators/ragged_attention/operator.py b/tritonbench/operators/ragged_attention/operator.py
index 4620fd41..9a6656ab 100644
--- a/tritonbench/operators/ragged_attention/operator.py
+++ b/tritonbench/operators/ragged_attention/operator.py
@@ -32,6 +32,7 @@ def __init__(
         self.num_buckets = args.num_buckets
         # set a default number of inputs
         self._num_inputs = 10 if self._num_inputs is None else self._num_inputs
+        self.requires_grad = self.mode != Mode.FWD_NO_GRAD
 
     @register_benchmark()
     def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps):
@@ -40,6 +41,7 @@ def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps):
             self.num_heads,
             self.max_seq_len,
             self.num_buckets,
+            self.requires_grad,
             persistent_kernel=False,
         )
         return lambda: attn(qkv, seq_offsets, timestamps)
@@ -52,6 +54,7 @@ def hstu_triton_ragged_attention_persistent(self, qkv, seq_offsets, timestamps):
             self.num_heads,
             self.max_seq_len,
             self.num_buckets,
+            self.requires_grad,
             persistent_kernel=True,
         )
         return lambda: attn(qkv, seq_offsets, timestamps)
@@ -60,9 +63,8 @@ def get_x_val(self, example_inputs):
         return (self.batch_size, self.num_heads, self.max_seq_len, self.num_buckets)
 
     def get_input_iter(self):
-        requires_grad = not (self.mode == Mode.FWD_NO_GRAD)
         for _input_id in range(self._num_inputs):
-            inputs = get_test_inputs(self.batch_size, self.num_heads, self.max_seq_len, requires_grad)
+            inputs = get_test_inputs(self.batch_size, self.num_heads, self.max_seq_len, self.requires_grad)
             yield inputs
 
     def get_bwd_fn(self, fwd_fn: Callable[..., Any]) -> Callable[..., Any]:
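
On the operator side, the mode-to-flag mapping now lives in __init__, so the registered benchmarks and get_input_iter all read the same self.requires_grad and pass it down to the attention module. A short sketch of the intended flow, with hypothetical stand-in classes (the Mode values and ToyRaggedAttn below are illustrative, not the real tritonbench API):

    from enum import Enum

    import torch


    class Mode(Enum):  # stand-in for the benchmark framework's Mode enum
        FWD = "fwd"
        BWD = "bwd"
        FWD_BWD = "fwd_bwd"
        FWD_NO_GRAD = "fwd_no_grad"


    class ToyRaggedAttn(torch.nn.Module):
        """Hypothetical stand-in for the HSTU ragged-attention module above."""

        def __init__(self, num_buckets: int, requires_grad: bool) -> None:
            super().__init__()
            self.all_ts_weights = torch.nn.Parameter(
                torch.randn((num_buckets + 1,), dtype=torch.bfloat16),
                requires_grad=requires_grad,
            )

        def forward(self, bucket_ids: torch.Tensor) -> torch.Tensor:
            return self.all_ts_weights[bucket_ids].sum()


    mode = Mode.FWD_NO_GRAD
    requires_grad = mode != Mode.FWD_NO_GRAD  # same derivation as the operator's __init__
    attn = ToyRaggedAttn(num_buckets=8, requires_grad=requires_grad)
    out = attn(torch.arange(4))
    print(out.requires_grad)  # False: no autograd graph is built in FWD_NO_GRAD mode
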