File tree Expand file tree Collapse file tree 4 files changed +10
-6
lines changed Expand file tree Collapse file tree 4 files changed +10
-6
lines changed Original file line number Diff line number Diff line change 80
80
cuda-version : ' 11.7.1'
81
81
- torch-version : ' 2.1.0.dev20230731'
82
82
cuda-version : ' 11.8.0'
83
+ # Pytorch >= 2.1 with nvcc 12.1.0 segfaults during compilation, so
84
+ # we only use CUDA 12.2. setup.py as a special case that will
85
+ # download the wheel for CUDA 12.2 instead.
86
+ - torch-version : ' 2.1.0.dev20230731'
87
+ cuda-version : ' 12.1.0'
83
88
84
89
steps :
85
90
- name : Checkout
Original file line number Diff line number Diff line change 1
- __version__ = "2.2.3.post1 "
1
+ __version__ = "2.2.3.post2 "
2
2
3
3
from flash_attn .flash_attn_interface import (
4
4
flash_attn_func ,
Original file line number Diff line number Diff line change @@ -223,6 +223,8 @@ def get_wheel_url():
223
223
# _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
224
224
torch_cuda_version = parse (torch .version .cuda )
225
225
torch_version_raw = parse (torch .__version__ )
226
+ if torch_version_raw .major == 2 and torch_version_raw .minor == 1 :
227
+ torch_cuda_version = parse ("12.2" )
226
228
python_version = f"cp{ sys .version_info .major } { sys .version_info .minor } "
227
229
platform_name = get_platform ()
228
230
flash_version = get_package_version ()
Original file line number Diff line number Diff line change @@ -85,14 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
85
85
RUN pip install git+https://github.com/mlcommons/
[email protected]
86
86
87
87
# Install FlashAttention
88
- RUN pip install flash-attn==2.2.3.post1
88
+ RUN pip install flash-attn==2.2.3.post2
89
89
90
90
# Install CUDA extensions for cross-entropy, fused dense, layer norm
91
91
RUN git clone https://github.com/HazyResearch/flash-attention \
92
- && cd flash-attention && git checkout v2.2.3.post1 \
93
- && cd csrc/fused_softmax && pip install . && cd ../../ \
94
- && cd csrc/rotary && pip install . && cd ../../ \
92
+ && cd flash-attention && git checkout v2.2.3.post2 \
95
93
&& cd csrc/layer_norm && pip install . && cd ../../ \
96
94
&& cd csrc/fused_dense_lib && pip install . && cd ../../ \
97
- && cd csrc/ft_attention && pip install . && cd ../../ \
98
95
&& cd .. && rm -rf flash-attention
You can’t perform that action at this time.
0 commit comments