@@ -894,14 +894,13 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
894894}
895895
896896/**
897- * @brief Get or expand a cached float32 tensor filled with a scalar value.
897+ * @brief Get or expand a cached tensor filled with a scalar value.
898898 *
899- * This function manages cached device memory for float32 tensors. If the current
899+ * This function manages cached device memory for tensors. If the current
900900 * cache size is insufficient for the requested tensor shape, the old memory will
901- * be released and new memory will be allocated. The allocated buffer is then
902- * initialized either with zeros (when @p value == 0.0f) or with the given scalar
903- * value using CANN operations. Finally, an aclTensor object is created from the
904- * cached memory and returned.
901+ * be released and new memory will be allocated. The allocated buffer is
902+ * initialized with the given scalar value using CANN operations.
903+ * Finally, an aclTensor object is created from the cached memory and returned.
905904 *
906905 * @param ctx The CANN backend context that manages device memory.
907906 * @param buffer A pointer to the cached device buffer (will be allocated
@@ -910,25 +909,27 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
910909 * updated when the cache is expanded.
911910 * @param ne The tensor shape array (number of elements in each dimension).
912911 * @param nb The stride size for each dimension.
912+ * @param dtype Data type of the cached tensor.
913913 * @param dims The number of tensor dimensions.
914914 * @param value The scalar value used to fill the tensor (supports zero
915915 * initialization via memset or arbitrary values via fill_scalar).
916916 * @return An aclTensor pointer created from the cached buffer.
917917 */
918- static aclTensor* get_f32_cache_acl_tensor (
918+ static aclTensor* get_cache_acl_tensor (
919919 ggml_backend_cann_context& ctx,
920920 void ** buffer,
921921 int64_t &cache_element,
922922 int64_t * ne,
923923 size_t * nb,
924+ ggml_type dtype,
924925 int64_t dims,
925926 float value) {
926927 // Calculate total number of elements
927928 int64_t n_element = 1 ;
928929 for (int i = 0 ; i < dims; i++) {
929930 n_element *= ne[i];
930931 }
931- size_t size = n_element * sizeof ( float );
932+ size_t size = n_element * ggml_type_size (dtype );
932933
933934 // Allocate or expand cache if needed
934935 if (cache_element < n_element) {
@@ -941,19 +942,17 @@ static aclTensor* get_f32_cache_acl_tensor(
941942 cache_element = n_element;
942943
943944 // Initialize cache
944- if (value == 0.0f) {
945- ACL_CHECK (aclrtMemsetAsync (*buffer, size, 0 , size, ctx.stream ()));
946- } else {
947- int64_t pool_ne[1 ] = { n_element };
948- size_t pool_nb[1 ] = { sizeof (float ) };
949- aclTensor* acl_value = ggml_cann_create_tensor (
950- *buffer, ACL_FLOAT, sizeof (float ), pool_ne, pool_nb, 1 );
951- aclnn_fill_scalar (ctx, 1 , acl_value);
952- ggml_cann_release_resources (ctx, acl_value);
953- }
945+ int64_t pool_ne[1 ] = { n_element };
946+ size_t pool_nb[1 ] = { ggml_type_size (dtype) };
947+ aclTensor* acl_value = ggml_cann_create_tensor (
948+ *buffer, ggml_cann_type_mapping (dtype), ggml_type_size (dtype),
949+ pool_ne, pool_nb, 1 );
950+ aclnn_fill_scalar (ctx, value, acl_value);
951+ ggml_cann_release_resources (ctx, acl_value);
954952 }
955953
956- return ggml_cann_create_tensor (*buffer, ACL_FLOAT, sizeof (float ), ne, nb, dims);
954+ return ggml_cann_create_tensor (*buffer, ggml_cann_type_mapping (dtype),
955+ ggml_type_size (dtype), ne, nb, dims);
957956}
958957
959958void ggml_cann_rms_norm (ggml_backend_cann_context& ctx, ggml_tensor* dst) {
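A quick way to see what the cache helper above does: it only reallocates and refills when the requested element count outgrows the cache, so a given cache must always hold a single fill value (which is why the context keeps separate one- and zero-filled caches for rms_norm). A minimal host-side sketch of that control flow, using plain malloc and a fill loop in place of the CANN pool and aclnn_fill_scalar:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstdlib>

// Grow-only scalar-filled cache: mirrors the control flow of
// get_cache_acl_tensor, but on host memory and without aclTensor creation.
static float* get_cached_fill(float** buffer, int64_t& cache_element,
                              int64_t n_element, float value) {
    if (cache_element < n_element) {
        free(*buffer);                                  // release the old, too-small buffer
        *buffer = (float*)malloc(n_element * sizeof(float));
        cache_element = n_element;
        for (int64_t i = 0; i < n_element; i++) {       // fill only on (re)allocation
            (*buffer)[i] = value;
        }
    }
    return *buffer;                                     // reused as-is when large enough
}

int main() {
    float* cache = nullptr;
    int64_t cached = 0;
    float* a = get_cached_fill(&cache, cached, 8, 1.0f);   // allocates and fills
    float* b = get_cached_fill(&cache, cached, 4, 1.0f);   // reuses, no refill
    printf("%p %p %f\n", (void*)a, (void*)b, a[0]);
    free(cache);
    return 0;
}
```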
@@ -965,35 +964,39 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
965964 float eps;
966965 memcpy (&eps, dst->op_params , sizeof (float ));
967966
968- // build gamma, one...
967+ // build gamma.
969968 size_t acl_gamma_nb[GGML_MAX_DIMS];
970- acl_gamma_nb[0 ] = sizeof (float );
969+ // gamma's type is the same as dst's.
970+ acl_gamma_nb[0 ] = ggml_type_size (dst->type );
971971 for (int i = 1 ; i < GGML_MAX_DIMS; i++) {
972972 acl_gamma_nb[i] = acl_gamma_nb[i - 1 ] * src->ne [i - 1 ];
973973 }
974- aclTensor* acl_gamma = get_f32_cache_acl_tensor (
974+ aclTensor* acl_gamma = get_cache_acl_tensor (
975975 ctx,
976976 &ctx.rms_norm_one_tensor_cache .cache ,
977977 ctx.rms_norm_one_tensor_cache .size ,
978978 src->ne ,
979979 acl_gamma_nb,
980+ dst->type ,
980981 1 , // dims
981982 1.0f // value
982983 );
983984
984- // build rstd, zero...
985+ // build rstd.
985986 int64_t acl_rstd_ne[] = {src->ne [1 ], src->ne [2 ], src->ne [3 ]};
986987 size_t acl_rstd_nb[GGML_MAX_DIMS - 1 ];
988+ // rstd will always be F32.
987989 acl_rstd_nb[0 ] = sizeof (float );
988990 for (int i = 1 ; i < GGML_MAX_DIMS - 1 ; i++) {
989991 acl_rstd_nb[i] = acl_rstd_nb[i - 1 ] * acl_rstd_ne[i - 1 ];
990992 }
991- aclTensor* acl_rstd = get_f32_cache_acl_tensor (
993+ aclTensor* acl_rstd = get_cache_acl_tensor (
992994 ctx,
993995 &ctx.rms_norm_zero_tensor_cache .cache ,
994996 ctx.rms_norm_zero_tensor_cache .size ,
995997 acl_rstd_ne,
996998 acl_rstd_nb,
999+ GGML_TYPE_F32,
9971000 GGML_MAX_DIMS - 1 ,
9981002 0.0f // value
9991002 );
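For reference, with the all-ones gamma built above the RmsNorm call reduces to plain RMS normalization; ggml applies the learned weight as a separate MUL, so the cached ones tensor keeps the op's gamma multiply a no-op. A small host-side sketch of the per-row math (not the CANN kernel), assuming the usual definition rms(x) = sqrt(mean(x^2) + eps):

```cpp
#include <cmath>
#include <cstdio>

// Per-row RMS normalization with gamma == 1: y[i] = x[i] / sqrt(mean(x^2) + eps).
static void rms_norm_row(const float* x, float* y, int n, float eps) {
    float sum_sq = 0.0f;
    for (int i = 0; i < n; i++) sum_sq += x[i] * x[i];
    float inv_rms = 1.0f / std::sqrt(sum_sq / n + eps);
    for (int i = 0; i < n; i++) y[i] = x[i] * inv_rms;   // gamma == 1 drops out
}

int main() {
    float x[4] = {1.0f, 2.0f, 3.0f, 4.0f}, y[4];
    rms_norm_row(x, y, 4, 1e-6f);
    printf("%f %f\n", y[0], y[3]);
    return 0;
}
```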
@@ -1765,41 +1768,42 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
17651768 ggml_tensor* src0 = dst->src [0 ]; // src
17661769 ggml_tensor* src1 = dst->src [1 ]; // index
17671770
1771+ GGML_ASSERT (dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1772+
17681773 switch (src0->type ) {
1769- case GGML_TYPE_F32: {
1770- aclnn_index_select_4d (ctx, src0->data , src0->ne , src0->nb ,
1771- dst->data , dst->ne , dst->nb ,
1772- src1, dst->type );
1773- break ;
1774- }
1775- case GGML_TYPE_F16: {
1776- aclTensor* acl_src0 = ggml_cann_create_tensor (src0);
1777- ggml_cann_pool_alloc src_buffer_allocator (
1778- ctx.pool (), ggml_nelements (src0) * sizeof (float ));
1779- void * src_trans_buffer = src_buffer_allocator.get ();
1780- size_t src_trans_nb[GGML_MAX_DIMS];
1781- src_trans_nb[0 ] = sizeof (float );
1782- for (int i = 1 ; i < GGML_MAX_DIMS; i++) {
1783- src_trans_nb[i] = src_trans_nb[i - 1 ] * src0->ne [i - 1 ];
1774+ case GGML_TYPE_F16:
1775+ case GGML_TYPE_F32:
1776+ if (src0->type == dst->type ) {
1777+ aclnn_index_select_4d (ctx, src0->data , src0->ne , src0->nb ,
1778+ dst->data , dst->ne , dst->nb ,
1779+ src1, dst->type );
1780+ } else {
1781+ aclTensor* acl_src0 = ggml_cann_create_tensor (src0);
1782+ ggml_cann_pool_alloc src_buffer_allocator (
1783+ ctx.pool (), ggml_nelements (src0) * ggml_element_size (dst));
1784+ void * src_trans_buffer = src_buffer_allocator.get ();
1785+ size_t src_trans_nb[GGML_MAX_DIMS];
1786+ src_trans_nb[0 ] = dst->nb [0 ];
1787+ for (int i = 1 ; i < GGML_MAX_DIMS; i++) {
1788+ src_trans_nb[i] = src_trans_nb[i - 1 ] * src0->ne [i - 1 ];
1789+ }
1790+ aclTensor* src_trans_tensor = ggml_cann_create_tensor (
1791+ src_trans_buffer, ggml_cann_type_mapping (dst->type ), ggml_type_size (dst->type ),
1792+ src0->ne , src_trans_nb, GGML_MAX_DIMS);
1793+ aclnn_cast (ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping (dst->type ));
1794+ aclnn_index_select_4d (ctx, src_trans_buffer, src0->ne , src_trans_nb,
1795+ dst->data , dst->ne , dst->nb ,
1796+ src1, dst->type );
1797+ ggml_cann_release_resources (ctx, acl_src0, src_trans_tensor);
17841798 }
1785- aclTensor* src_trans_tensor = ggml_cann_create_tensor (
1786- src_trans_buffer, ACL_FLOAT, ggml_type_size (dst->type ),
1787- src0->ne , src_trans_nb, GGML_MAX_DIMS);
1788- aclnn_cast (ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping (dst->type ));
1789- aclnn_index_select_4d (ctx, src_trans_buffer, src0->ne , src_trans_nb,
1790- dst->data , dst->ne , dst->nb ,
1791- src1, dst->type );
1792- ggml_cann_release_resources (ctx, acl_src0, src_trans_tensor);
17931799 break ;
1794- }
17951800 case GGML_TYPE_Q8_0: {
17961801 // add 1 dim for bcast mul.
17971802 size_t weight_nb[GGML_MAX_DIMS + 1 ], scale_nb[GGML_MAX_DIMS + 1 ],
17981803 dequant_nb[GGML_MAX_DIMS + 1 ];
17991804 int64_t weight_ne[GGML_MAX_DIMS + 1 ], scale_ne[GGML_MAX_DIMS + 1 ],
18001805 *dequant_ne;
18011806 int64_t scale_offset = 0 ;
1802-
18031807 // [3,4,5,64] -> [3,4,5,2,32]
18041808 weight_ne[0 ] = QK8_0;
18051809 weight_ne[1 ] = src0->ne [0 ] / QK8_0;
@@ -1809,7 +1813,6 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
18091813 weight_ne[i] = src0->ne [i - 1 ];
18101814 weight_nb[i] = weight_nb[i - 1 ] * weight_ne[i - 1 ];
18111815 }
1812-
18131816 // [3,4,5,64] -> [3,4,5,2,1]
18141817 scale_ne[0 ] = 1 ;
18151818 scale_ne[1 ] = src0->ne [0 ] / QK8_0;
@@ -1819,35 +1822,30 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
18191822 scale_ne[i] = src0->ne [i - 1 ];
18201823 scale_nb[i] = scale_nb[i - 1 ] * scale_ne[i - 1 ];
18211824 }
1822-
18231825 // [3,4,5,64] -> [3,4,5,2,32]
18241826 dequant_ne = weight_ne;
1825- dequant_nb[0 ] = sizeof ( float );
1827+ dequant_nb[0 ] = ggml_type_size (dst-> type );
18261828 for (int i = 1 ; i < GGML_MAX_DIMS + 1 ; i++) {
18271829 dequant_nb[i] = dequant_nb[i - 1 ] * dequant_ne[i - 1 ];
18281830 }
1829-
18301831 scale_offset = ggml_nelements (src0) * sizeof (int8_t );
18311832 ggml_cann_pool_alloc dequant_buffer_allocator (
1832- ctx.pool (), ggml_nelements (src0) * sizeof (float ));
1833-
1833+ ctx.pool (), ggml_nelements (src0) * ggml_type_size (dst->type ));
18341834 aclTensor* acl_weight_tensor = ggml_cann_create_tensor (
18351835 src0->data , ACL_INT8, sizeof (int8_t ), weight_ne, weight_nb,
18361836 GGML_MAX_DIMS + 1 );
18371837 aclTensor* acl_scale_tensor = ggml_cann_create_tensor (
18381838 src0->data , ACL_FLOAT16, sizeof (uint16_t ), scale_ne, scale_nb,
18391839 GGML_MAX_DIMS + 1 , ACL_FORMAT_ND, scale_offset);
18401840 aclTensor* dequant_tensor = ggml_cann_create_tensor (
1841- dequant_buffer_allocator.get (), ACL_FLOAT, sizeof ( float ),
1841+ dequant_buffer_allocator.get (), ggml_cann_type_mapping (dst-> type ), ggml_type_size (dst-> type ),
18421842 dequant_ne, dequant_nb, GGML_MAX_DIMS + 1 );
1843-
18441843 aclnn_mul (ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor);
1845- dequant_nb[0 ] = sizeof ( float );
1844+ dequant_nb[0 ] = ggml_type_size (dst-> type );
18461845 dequant_ne = src0->ne ;
18471846 for (int i = 1 ; i < GGML_MAX_DIMS; i++) {
18481847 dequant_nb[i] = dequant_nb[i - 1 ] * src0->ne [i - 1 ];
18491848 }
1850-
18511849 aclnn_index_select_4d (ctx, dequant_buffer_allocator.get (),
18521850 dequant_ne, dequant_nb,
18531851 dst->data , dst->ne , dst->nb ,
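Conceptually, the Q8_0 branch above views the quantized row as blocks of QK8_0 int8 weights with one scale per block and dequantizes by a broadcast multiply (on device via aclnn_mul over a [..., n/QK8_0, QK8_0] weight view and a [..., n/QK8_0, 1] scale view). A host-side sketch of that multiply, with plain float standing in for the f16 scales of the real buffer:

```cpp
#include <cstdint>
#include <cstdio>

#define QK8_0 32

// Dequantize n int8 weights using one scale per block of QK8_0 elements.
static void dequant_q8_0_row(const int8_t* qs, const float* scales,
                             float* out, int n) {
    for (int i = 0; i < n; i++) {
        out[i] = qs[i] * scales[i / QK8_0];   // per-block scale broadcast
    }
}

int main() {
    int8_t qs[QK8_0 * 2];
    for (int i = 0; i < QK8_0 * 2; i++) qs[i] = (int8_t)(i % 7 - 3);
    float scales[2] = {0.5f, 2.0f};
    float out[QK8_0 * 2];
    dequant_q8_0_row(qs, scales, out, QK8_0 * 2);
    printf("%f %f\n", out[0], out[QK8_0]);
    return 0;
}
```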
@@ -1965,16 +1963,8 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
19651963 // Only check env once.
19661964 static bool weight_to_nz = parse_bool (get_env (" GGML_CANN_WEIGHT_NZ" ).value_or (" on" ));
19671965 if (weight_to_nz && is_matmul_weight (weight)) {
1968- int64_t acl_stride[2 ] = {1 , transpose_ne[1 ]};
1969-
1970- // Reverse ne.
1971- std::reverse (transpose_ne, transpose_ne + n_dims);
1972-
1973- std::vector<int64_t > storageDims = {transpose_ne[0 ], transpose_ne[1 ]};
1974-
1975- acl_weight_tensor = aclCreateTensor (
1976- transpose_ne, n_dims, ggml_cann_type_mapping (weight->type ), acl_stride,
1977- 0 , ACL_FORMAT_FRACTAL_NZ, storageDims.data (), 2 , weight->data );
1966+ acl_weight_tensor =
1967+ ggml_cann_create_tensor (weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);
19781968 } else {
19791969 acl_weight_tensor =
19801970 ggml_cann_create_tensor (weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND);
@@ -3178,7 +3168,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
31783168 aclTensor* acl_src0_f16_tensor = nullptr ;
31793169 aclTensor* acl_src1_f16_tensor = nullptr ;
31803170 aclTensor* acl_src2_f16_tensor = nullptr ;
3181- aclTensor* acl_dst_f16_tensor = nullptr ;
31823171
31833172 // Step 1: cast the src0 (Query) to fp16 if needed
31843173 ggml_cann_pool_alloc src0_f16_allocator (ctx.pool ());
@@ -3216,22 +3205,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
32163205 acl_src2_f16_tensor = ggml_cann_create_tensor (src2, src2_bsnd_ne,
32173206 src2_bsnd_nb, GGML_MAX_DIMS);
32183207
3219- ggml_cann_pool_alloc out_f16_allocator (ctx.pool ());
3220- void * out_f16_buffer = out_f16_allocator.alloc (
3221- ggml_nelements (dst) * faElemSize);
3222-
3223- int64_t * out_f16_ne = src0_bsnd_ne;
3224- size_t out_f16_nb[GGML_MAX_DIMS];
3225- out_f16_nb[0 ] = faElemSize;
3226- for (int i = 1 ; i < GGML_MAX_DIMS; ++i){
3227- out_f16_nb[i] = out_f16_nb[i - 1 ] * out_f16_ne[i - 1 ];
3228- }
3229-
3230- acl_dst_f16_tensor = ggml_cann_create_tensor (
3231- out_f16_buffer, faDataType, faElemSize,
3232- out_f16_ne, out_f16_nb, GGML_MAX_DIMS
3233- );
3234-
32353208 // Step 3: create the PSEShift tensor if needed
32363209 // this tensor is considered as mask (f16) in the llama.cpp
32373210 aclTensor* bcast_pse_tensor = nullptr ;
@@ -3334,8 +3307,29 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
33343307 int64_t keyAntiquantMode = 0 ;
33353308 int64_t valueAntiquantMode = 0 ;
33363309
3337- // Step 5: launch the FusedInferAttentionScoreV2 kernel.
3338- // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
3310+ GGML_ASSERT (dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
3311+ aclTensor * fa_dst_tensor = nullptr ;
3312+ aclTensor * acl_dst_tensor = nullptr ;
3313+ ggml_cann_pool_alloc out_f16_allocator (ctx.pool ());
3314+ if (dst->type == GGML_TYPE_F32) {
3315+ void * out_f16_buffer = out_f16_allocator.alloc (
3316+ ggml_nelements (dst) * faElemSize);
3317+
3318+ int64_t * out_f16_ne = src0_bsnd_ne;
3319+ size_t out_f16_nb[GGML_MAX_DIMS];
3320+ out_f16_nb[0 ] = faElemSize;
3321+ for (int i = 1 ; i < GGML_MAX_DIMS; ++i){
3322+ out_f16_nb[i] = out_f16_nb[i - 1 ] * out_f16_ne[i - 1 ];
3323+ }
3324+
3325+ fa_dst_tensor = ggml_cann_create_tensor (
3326+ out_f16_buffer, faDataType, faElemSize,
3327+ out_f16_ne, out_f16_nb, GGML_MAX_DIMS
3328+ );
3329+ }
3330+ else {
3331+ fa_dst_tensor = ggml_cann_create_tensor (dst);
3332+ }
33393333
33403334 GGML_CANN_CALL_ACLNN_OP (ctx, FusedInferAttentionScoreV2,
33413335 acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
@@ -3357,23 +3351,24 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
33573351 blockSize, antiquantMode, // blockSize, antiquantMode
33583352 softmaxLseFlag, // softmaxLseFlag
33593353 keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode
3360- acl_dst_f16_tensor , // attentionOut
3354+ fa_dst_tensor , // attentionOut
33613355 nullptr // softmaxLse
33623356 );
33633357
3364- // Step 6: post-processing, permute and cast to f32
3365- aclTensor* acl_dst_tensor = ggml_cann_create_tensor (dst);
3366- // TODO: when dst is fp16, don't need cast
3367- aclnn_cast (ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping (dst->type ));
3368- ggml_cann_release_resources (ctx, acl_src0_f16_tensor,
3369- acl_src1_f16_tensor,
3370- acl_src2_f16_tensor,
3371- acl_dst_f16_tensor,
3372- acl_dst_tensor);
3373- if (src3 != nullptr ){
3374- ggml_cann_release_resources (ctx, bcast_pse_tensor);
3358+ if (dst->type == GGML_TYPE_F32) {
3359+ // Step 6: post-processing, cast the f16 result back to f32
3360+ acl_dst_tensor = ggml_cann_create_tensor (dst);
3361+ aclnn_cast (ctx, fa_dst_tensor, acl_dst_tensor, ggml_cann_type_mapping (dst->type ));
33753362 }
3376- }else {
3363+
3364+ ggml_cann_release_resources (ctx, acl_src0_f16_tensor,
3365+ acl_src1_f16_tensor,
3366+ acl_src2_f16_tensor,
3367+ fa_dst_tensor,
3368+ acl_dst_tensor,
3369+ bcast_pse_tensor);
3370+
3371+ } else {
33773372 GGML_ABORT (" Function is not implemented." );
33783373 }
33793374}
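The net effect of the dst handling above: FusedInferAttentionScoreV2 produces f16, so an f16 dst can be written directly and only an f32 dst needs the scratch buffer plus a cast afterwards. A small host-side sketch of that pattern, with uint16_t standing in for the device f16 type and stubs in place of the CANN calls:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Pretend attention kernel that always writes f16 output.
static void run_kernel_f16(uint16_t* out, size_t n) {
    for (size_t i = 0; i < n; i++) out[i] = (uint16_t)i;
}

static float f16_to_f32_stub(uint16_t h) { return (float)h; }  // placeholder conversion

// Write into dst directly when it is f16; otherwise go through a scratch
// buffer and convert, mirroring the F32/F16 branch added above.
static void attention_to_dst(void* dst, bool dst_is_f32, size_t n) {
    if (dst_is_f32) {
        std::vector<uint16_t> scratch(n);       // extra buffer + cast only for f32 dst
        run_kernel_f16(scratch.data(), n);
        float* d = (float*)dst;
        for (size_t i = 0; i < n; i++) d[i] = f16_to_f32_stub(scratch[i]);
    } else {
        run_kernel_f16((uint16_t*)dst, n);      // f16 dst: write in place, no cast
    }
}

int main() {
    float out_f32[4];
    attention_to_dst(out_f32, true, 4);
    printf("%f\n", out_f32[3]);
    return 0;
}
```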