File tree Expand file tree Collapse file tree 4 files changed +6
-3
lines changed Expand file tree Collapse file tree 4 files changed +6
-3
lines changed Original file line number Diff line number Diff line change 7
7
#include < torch/nn/functional.h>
8
8
#include < c10/cuda/CUDAGuard.h>
9
9
#include < c10/cuda/CUDAStream.h>
10
- #include < ATen/cuda/CUDAGeneratorImpl.h>
10
+ #include < ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
11
+ #include < ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
11
12
12
13
#include < cutlass/numeric_types.h>
13
14
Original file line number Diff line number Diff line change 7
7
#include < cuda.h>
8
8
#include < vector>
9
9
10
- #include < ATen/cuda/PhiloxUtils.cuh > // For at::cuda::philox::unpack
10
+ #include < ATen/cuda/CUDAGeneratorImpl.h > // For at::Generator and at::PhiloxCudaState
11
11
12
12
constexpr int TOTAL_DIM = 0 ;
13
13
constexpr int H_DIM = 1 ;
Original file line number Diff line number Diff line change 4
4
5
5
#pragma once
6
6
7
+ #include < ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
8
+
7
9
#include < cute/tensor.hpp>
8
10
9
11
#include < cutlass/cutlass.h>
Original file line number Diff line number Diff line change 1
- __version__ = "2.7.1.post2 "
1
+ __version__ = "2.7.1.post3 "
2
2
3
3
from flash_attn .flash_attn_interface import (
4
4
flash_attn_func ,
You can’t perform that action at this time.
0 commit comments