diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 7a85edbdcdb84..689bfb5d11a36 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -1970,7 +1970,9 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) { const bool has_mask = op->src[3] != nullptr; if (ggml_metal_op_flash_attn_ext_use_vec(op)) { - const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0; + // note: always reserve the padding space to avoid graph reallocations + //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0; + const bool has_kvpad = true; if (has_kvpad) { res += OP_FLASH_ATTN_EXT_VEC_NCPSG*( @@ -1979,7 +1981,8 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) { (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0)); } } else { - const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0; + //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0; + const bool has_kvpad = true; if (has_kvpad) { res += OP_FLASH_ATTN_EXT_NCPSG*( @@ -2015,9 +2018,10 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) { const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op); // this optimization is not useful for the vector kernels - if (is_vec) { - return res; - } + // note: always reserve the blk buffer to avoid graph reallocations + //if (is_vec) { + // return res; + //} const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG; const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG; @@ -2044,13 +2048,16 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) { size_t res = 0; - if (ggml_metal_op_flash_attn_ext_use_vec(op)) { + // note: always reserve the temp buffer to avoid graph reallocations + //if (ggml_metal_op_flash_attn_ext_use_vec(op)) { + if (true) { const int64_t nwg = 32; + const int64_t ne01_max = std::min(ne01, 32); // temp buffer for writing the results from each workgroup // - ne20: the size of the Value head // - + 2: the S and M values for each intermediate result - res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2)); + res += ggml_type_size(GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2)); } return res;