move redundant code
Seventeen17 committed Nov 15, 2024
1 parent eaa2d2e commit 6500544
Showing 2 changed files with 1 addition and 8 deletions.
2 changes: 1 addition & 1 deletion bazel/flash_attn.BUILD
@@ -35,4 +35,4 @@ genrule(
"popd",
"cp external/flash_attn/build/*/*.so $(location flash_attn_cuda.so)"]),
visibility = ["//visibility:public"],
)
)
7 changes: 0 additions & 7 deletions torch_xla/csrc/ops/flash_attention_varlen_forward.cpp
@@ -92,11 +92,6 @@ void custom_call_flash_attention_varlen_forward(cudaStream_t stream,
torch::from_blob(buffers[8 + buf_offset], {params.b + 1}, opts);
at::Tensor rng_state =
torch::from_blob(buffers[6 + buf_offset], {2}, opts.dtype(torch::kInt64));
// Fill zeros for outputs.
// cudaMemsetAsync(buffers[4 + buf_offset], 0, params.b * params.h *
// params.seqlen_q * sizeof(torch::kFloat), cuda_stream);
// cudaMemsetAsync(buffers[5 + buf_offset], 0, params.b * params.seqlen_q *
// params.h * params.d * sizeof(scalar_type), cuda_stream);
cudaMemsetAsync(rng_state.data_ptr(), 0, 2 * sizeof(int64_t), cuda_stream);
softmax_lse.fill_(0);
o_output.fill_(0);
@@ -155,8 +150,6 @@ void custom_call_flash_attention_varlen_forward(cudaStream_t stream,
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};

auto dprops = at::cuda::getCurrentDeviceProperties();

Flash_fwd_params launch_params;

// Reset the parameters
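Note on the deletions: the first hunk drops a commented-out pair of cudaMemsetAsync calls whose job (zeroing the lse and output buffers) is already done by the softmax_lse.fill_(0) and o_output.fill_(0) calls kept just below, and the second hunk drops an unused dprops lookup. A minimal sketch of the surviving zero-initialization pattern is below; the buffer names, shapes, and float dtype are illustrative assumptions (the real kernel wraps buffers[4 + buf_offset] and buffers[5 + buf_offset], and the output may be half/bfloat16):

#include <torch/torch.h>

// Sketch only: wrap pre-allocated CUDA device buffers without copying, then
// zero them with fill_(0), which launches on the current CUDA stream and
// makes separate cudaMemsetAsync calls redundant. Names/shapes are assumed.
void zero_outputs_sketch(void* softmax_lse_buf, void* o_buf,
                         int64_t b, int64_t h, int64_t seqlen_q, int64_t d) {
  auto opts = torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat);
  at::Tensor softmax_lse =
      torch::from_blob(softmax_lse_buf, {b, h, seqlen_q}, opts);
  at::Tensor o_output =
      torch::from_blob(o_buf, {b, seqlen_q, h, d}, opts);
  softmax_lse.fill_(0);  // replaces cudaMemsetAsync on the lse buffer
  o_output.fill_(0);     // replaces cudaMemsetAsync on the output buffer
}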
