diff --git a/tritonbench/operators/ragged_attention/hstu.py b/tritonbench/operators/ragged_attention/hstu.py
index f6f491cb..ec6a2d5b 100644
--- a/tritonbench/operators/ragged_attention/hstu.py
+++ b/tritonbench/operators/ragged_attention/hstu.py
@@ -172,14 +172,14 @@ def forward(
 
 
 def get_test_inputs(
-    batch_size, num_heads, max_seq_len
+    batch_size, num_heads, max_seq_len, requires_grad
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     timestamp_deltas: torch.Tensor = (
         torch.randint(
             86400,
             size=(batch_size, max_seq_len + 1),
         )
-        .requires_grad_(False)
+        .requires_grad_(requires_grad)
         .cuda()
     )
     timestamps = timestamp_deltas.cumsum(dim=1)
@@ -189,7 +189,7 @@ def get_test_inputs(
             max_seq_len + 1,
             size=(batch_size,),
         )
-        .requires_grad_(False)
+        .requires_grad_(requires_grad)
         .cuda()
     )
     seq_offsets = (
@@ -197,7 +197,7 @@ def get_test_inputs(
             (batch_size + 1,),
             dtype=torch.int64,
         )
-        .requires_grad_(False)
+        .requires_grad_(requires_grad)
         .cuda()
     )
     seq_offsets[1:] = torch.cumsum(
@@ -211,7 +211,7 @@ def get_test_inputs(
             (L, num_heads, 512),
             dtype=torch.bfloat16,
         )
-        .requires_grad_(False)
+        .requires_grad_(requires_grad)
        .cuda()
     )
     return qkv, seq_offsets, timestamps
diff --git a/tritonbench/operators/ragged_attention/operator.py b/tritonbench/operators/ragged_attention/operator.py
index e941ab29..4620fd41 100644
--- a/tritonbench/operators/ragged_attention/operator.py
+++ b/tritonbench/operators/ragged_attention/operator.py
@@ -60,8 +60,9 @@ def get_x_val(self, example_inputs):
         return (self.batch_size, self.num_heads, self.max_seq_len, self.num_buckets)
 
     def get_input_iter(self):
+        requires_grad = not (self.mode == Mode.FWD_NO_GRAD)
         for _input_id in range(self._num_inputs):
-            inputs = get_test_inputs(self.batch_size, self.num_heads, self.max_seq_len)
+            inputs = get_test_inputs(self.batch_size, self.num_heads, self.max_seq_len, requires_grad)
             yield inputs
 
     def get_bwd_fn(self, fwd_fn: Callable[..., Any]) -> Callable[..., Any]:
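# --- Illustration (not part of the patch) ----------------------------------
# A minimal sketch of how the new `requires_grad` flag is intended to flow
# from the benchmark mode into the generated inputs. The Mode enum and
# make_qkv helper below are hypothetical stand-ins for tritonbench's own
# definitions; only the qkv tensor's shape/dtype are taken from the diff,
# and torch.randn is an assumed constructor (the hunk does not show which
# factory the real code uses). A CUDA device is assumed, as in the benchmark.
from enum import Enum

import torch


class Mode(Enum):
    FWD = "fwd"
    BWD = "bwd"
    FWD_NO_GRAD = "fwd_no_grad"


def make_qkv(L: int, num_heads: int, requires_grad: bool) -> torch.Tensor:
    # Loosely mirrors the qkv construction above: the requires_grad flag now
    # follows the benchmark mode instead of being hard-coded to False.
    return (
        torch.randn((L, num_heads, 512), dtype=torch.bfloat16)
        .requires_grad_(requires_grad)
        .cuda()
    )


mode = Mode.BWD
requires_grad = not (mode == Mode.FWD_NO_GRAD)  # same predicate as get_input_iter
qkv = make_qkv(L=64, num_heads=4, requires_grad=requires_grad)
assert qkv.requires_grad  # gradients are tracked for every mode except FWD_NO_GRAD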