Skip to content

Commit 9375ac9

Browse files
committed
[CI] Don't include <ATen/cuda/CUDAGraphsUtils.cuh>
1 parent e782d28 commit 9375ac9

File tree

5 files changed

+8
-4
lines changed

csrc/flash_attn/flash_api.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#include <c10/cuda/CUDAGuard.h>
99
#include <c10/cuda/CUDAStream.h>
1010
#include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
11-
#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
11+
#include "philox_unpack.cuh" // For at::cuda::philox::unpack
1212

1313
#include <cutlass/numeric_types.h>
1414

csrc/flash_attn/src/flash_fwd_kernel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
#pragma once
66

7-
#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
7+
#include "philox_unpack.cuh" // For at::cuda::philox::unpack
88

99
#include <cute/tensor.hpp>
1010

philox_unpack.cuh (new file)

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
// This is purely so that it works with torch 2.1. For torch 2.2+ we can include ATen/cuda/PhiloxUtils.cuh
2+
3+
#pragma once
4+
#include <ATen/cuda/detail/UnpackRaw.cuh>

flash_attn/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "2.7.1.post3"
1+
__version__ = "2.7.1.post4"
22

33
from flash_attn.flash_attn_interface import (
44
flash_attn_func,

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def check_if_rocm_home_none(global_option: str) -> None:
114114

115115

116116
def append_nvcc_threads(nvcc_extra_args):
117-
nvcc_threads = os.getenv("NVCC_THREADS") or "4"
117+
nvcc_threads = os.getenv("NVCC_THREADS") or "2"
118118
return nvcc_extra_args + ["--threads", nvcc_threads]
119119

120120

0 commit comments

Comments
 (0)