Skip to content

Commit e782d28

Browse files
committed
[CI] Change torch #include to make it work with torch 2.1 Philox
1 parent 073afd5 commit e782d28

File tree

4 files changed

+6
-3
lines changed

4 files changed

+6
-3
lines changed

csrc/flash_attn/flash_api.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
#include <torch/nn/functional.h>
88
#include <c10/cuda/CUDAGuard.h>
99
#include <c10/cuda/CUDAStream.h>
10-
#include <ATen/cuda/CUDAGeneratorImpl.h>
10+
#include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
11+
#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
1112

1213
#include <cutlass/numeric_types.h>
1314

csrc/flash_attn/src/flash.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#include <cuda.h>
88
#include <vector>
99

10-
#include <ATen/cuda/PhiloxUtils.cuh> // For at::cuda::philox::unpack
10+
#include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
1111

1212
constexpr int TOTAL_DIM = 0;
1313
constexpr int H_DIM = 1;

csrc/flash_attn/src/flash_fwd_kernel.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
#pragma once
66

7+
#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
8+
79
#include <cute/tensor.hpp>
810

911
#include <cutlass/cutlass.h>

flash_attn/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "2.7.1.post2"
1+
__version__ = "2.7.1.post3"
22

33
from flash_attn.flash_attn_interface import (
44
flash_attn_func,

0 commit comments

Comments
 (0)