Skip to content

Commit

Permalink
Set requires_grad
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Nov 14, 2024
1 parent 056b832 commit dceef3e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
10 changes: 5 additions & 5 deletions tritonbench/operators/ragged_attention/hstu.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,14 +172,14 @@ def forward(


def get_test_inputs(
batch_size, num_heads, max_seq_len
batch_size, num_heads, max_seq_len, requires_grad
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
timestamp_deltas: torch.Tensor = (
torch.randint(
86400,
size=(batch_size, max_seq_len + 1),
)
.requires_grad_(False)
.requires_grad_(requires_grad)
.cuda()
)
timestamps = timestamp_deltas.cumsum(dim=1)
Expand All @@ -189,15 +189,15 @@ def get_test_inputs(
max_seq_len + 1,
size=(batch_size,),
)
.requires_grad_(False)
.requires_grad_(requires_grad)
.cuda()
)
seq_offsets = (
torch.zeros(
(batch_size + 1,),
dtype=torch.int64,
)
.requires_grad_(False)
.requires_grad_(requires_grad)
.cuda()
)
seq_offsets[1:] = torch.cumsum(
Expand All @@ -211,7 +211,7 @@ def get_test_inputs(
(L, num_heads, 512),
dtype=torch.bfloat16,
)
.requires_grad_(False)
.requires_grad_(requires_grad)
.cuda()
)
return qkv, seq_offsets, timestamps
3 changes: 2 additions & 1 deletion tritonbench/operators/ragged_attention/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ def get_x_val(self, example_inputs):
return (self.batch_size, self.num_heads, self.max_seq_len, self.num_buckets)

def get_input_iter(self):
requires_grad = not (self.mode == Mode.FWD_NO_GRAD)
for _input_id in range(self._num_inputs):
inputs = get_test_inputs(self.batch_size, self.num_heads, self.max_seq_len)
inputs = get_test_inputs(self.batch_size, self.num_heads, self.max_seq_len, requires_grad)
yield inputs

def get_bwd_fn(self, fwd_fn: Callable[..., Any]) -> Callable[..., Any]:
Expand Down

0 comments on commit dceef3e

Please sign in to comment.