Conversation
Xreki
left a comment
There was a problem hiding this comment.
PR描述补充下吧,包括:
(1)本PR的工作
(2)非advance分支动态库的大小
(3)advance分支动态库的大小、打包的whl包名字、打包了哪些内容、打包的命令
PR title也完善下,准确描述该PR的工作内容
| "SHELL:-gencode arch=compute_80,code=sm_80 -gencode arch=compute_90,code=sm_90" | ||
| >) | ||
|
|
||
| target_compile_options(flashattn_with_bias_mask PRIVATE $<$<COMPILE_LANGUAGE:CUDA>: |
| "SHELL:-gencode arch=compute_80,code=sm_80 -gencode arch=compute_90,code=sm_90" | ||
| >) | ||
|
|
||
| INSTALL(TARGETS flashattn |
There was a problem hiding this comment.
最终生成的动态库名称,在关闭、开启advance功能时,最好有所区分,这样Paddle框架中在加载动态库时容易区分些。
There was a problem hiding this comment.
关闭是就是libflashattn.so 开启的时候是libflashattn_advanced.so
| #else | ||
| BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { | ||
| BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { | ||
| BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { |
There was a problem hiding this comment.
这个分支是不是应该保留is_equal_qk模板?我理解非advance分支,需要是causal最优的性能版本
| os.environ["PY_VERSION"] = python_version | ||
|
|
||
| paddle_include_path = paddle.sysconfig.get_include() | ||
| paddle_lib_path = paddle.sysconfig.get_lib() |
There was a problem hiding this comment.
为了将安装的libflash_attn_advanced.so拷贝到paddle路径下
There was a problem hiding this comment.
这种方式我不太确定,请@sneaxiy 也看下。我理解:
- FA动态图即使是安装在自己的目录下,应该也是能找到的
- FA打包成.whl后,对FA的依赖应该需要写到Paddle的requirements.txt里面,那安装FA的时候很可能是还没有安装Paddle的
There was a problem hiding this comment.
FA应该作为Paddle侧的一个外部算子?类似于xformers跟PyTorch的关系,而不是新的so直接去替换Paddle里的FA的so?
| Check whether CUDA is available. | ||
| """ | ||
| try: | ||
| assert len(paddle.static.cuda_places()) > 0 |
| cxx11_abi = "" # str(paddle._C.-D_GLIBCXX_USE_CXX11_ABI).upper() | ||
|
|
||
| # Determine wheel URL based on CUDA version, paddle version, python version and OS | ||
| wheel_filename = f'{PACKAGE_NAME}-{flash_version}-cu{cuda_version}-paddle{paddle_version}-{python_version}-{python_version}-{platform_name}.whl' |
There was a problem hiding this comment.
whl包里面不需要加paddle版本吧?本身flashattn对paddle版本并没有依赖,是paddle对flashattn版本存在依赖
There was a problem hiding this comment.
好的这个后续会去掉,现在的是默认版本:paddle_flash_attn-2.0.8-cp37-none-any.whl
csrc/setup.py
Outdated
| # Determine the version numbers that will be used to determine the correct wheel | ||
| # We're using the CUDA version used to build paddle, not the one currently installed | ||
| # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) | ||
| paddle_cuda_version = "234" # parse(paddle.version.cuda) |
There was a problem hiding this comment.
这个之前获取paddle cuda版本的时候报错了, 临时加的,后续会删除
4a63f7c to
3aca223
Compare
csrc/CMakeLists.txt
Outdated
| ) | ||
| endif() | ||
|
|
||
| if (WITH_ADVANCED) |
| os.environ["PY_VERSION"] = python_version | ||
|
|
||
| paddle_include_path = paddle.sysconfig.get_include() | ||
| paddle_lib_path = paddle.sysconfig.get_lib() |
There was a problem hiding this comment.
这种方式我不太确定,请@sneaxiy 也看下。我理解:
- FA动态图即使是安装在自己的目录下,应该也是能找到的
- FA打包成.whl后,对FA的依赖应该需要写到Paddle的requirements.txt里面,那安装FA的时候很可能是还没有安装Paddle的
不建议提供这么复杂的配置方式。建议基础版本和扩展版本提供的causal计算能力和性能保持一致;若Paddle中使用到了扩展版本中新增的功能,则Paddle应自动去调用扩展库,用户唯一的感知就是可能需要手动执行下 |
|
FA动态图即使是安装在自己的目录下,paddle是找不到的,paddle只会在自己的libs下找, 可以看Paddle的修改https://github.com/PaddlePaddle/Paddle/pull/59802/files |
|
建议基础版本和扩展版本提供的causal计算能力和性能保持一致-》 已经是一致的了,描述存在问题,已修改 |
| }); | ||
| }); | ||
| }); | ||
| #endif |
There was a problem hiding this comment.
这里的代码是否可以简化下?以免出现2个分支?比如对BOOL_SWITCH做个改进,WITH_ADVANCED前后走不同的定义,以免以后维护更加困难?类似于:
#ifdef PADDLE_WITH_ADVANCED
#define BOOL_SWITCH(...) ...
#else
#define BOOL_SWITCH(...) ...
#endif
| }); | ||
| }); | ||
| #else | ||
| BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { |
There was a problem hiding this comment.
同上。最好对BOOL_SWITCH进行改进。在WITH_ADVANCED开启前后走不同的定义,避免重复代码。
| version = version_detail[0] + version_detail[1] / 10 | ||
| env_version = os.getenv("PY_VERSION") | ||
|
|
||
| if version < 3.7: |
There was a problem hiding this comment.
这个判断比较粗糙。version是浮点数而不是整数,建议改成使用version_detail整数的判断。
| os.environ["PY_VERSION"] = python_version | ||
|
|
||
| paddle_include_path = paddle.sysconfig.get_include() | ||
| paddle_lib_path = paddle.sysconfig.get_lib() |
There was a problem hiding this comment.
FA应该作为Paddle侧的一个外部算子?类似于xformers跟PyTorch的关系,而不是新的so直接去替换Paddle里的FA的so?
|
|
||
| add_dependencies(flashattn flashattn_with_bias_mask) | ||
|
|
||
| set(NVCC_ARCH_BIN 80 CACHE STRING "CUDA architectures") |
There was a problem hiding this comment.
这个写法在外部设置了-DNVCC_ARCH_BIN=...的情况下,取值会是多少,是80还是外部设置的值?
|
|
PR描述:
将csrc/build编译产生的结果编译成whl 或者 .so
使用方法:
1. attn_mask功能支持,
2. fa反向确定算法支持,用于逐位对齐精度调试)
1. flash-attention/csrc/build 下会产生 libflashattn_advanced.so
2. flash-attention/csrc/build/dist 下会生成 paddle_flash_attn-2.0.8-cp37-none-any.whl 【注意这个打包只打包了libflashattn_advanced.so】
3. 2.0.8表示当前paddle flash_attention的版本, cp37表示python版本为3.7, any 表示任意平台均可
1. flash-attention/csrc/build 下会产生 libflashattn.so 不会对.so进行打包【用于paddle内部源码编译】
paddle_flash_attn-2.0.8-cp37-none-any.whl 使用方法:
直接pip install 即可
说明:安装后后在/usr/local/lib/python3.7/dist-packages/paddle/libs/下新增 libflashattn_advanced.so。


触发条件: 当使用additional_mask功能 | 设置确定算法环境变量 | 设置 FLAGS_flash_attention_with_advanced 环境变量时 会调用 libflashattn_advanced.so 中的实现,如果未安装libflashattn_advanced.so 则会直接报错。
本PR修改后.so大小:
