1111#if FLA_ENABLE_AMD_OPT
1212int fla_dpotrf_small_avx2 (char * uplo , integer * n , doublereal * a , integer * lda , integer * info );
1313#endif
14+
15+ /* Threshold values used for tuning thread binding and workload partitioning.*/
16+ #define PROC_BIND_CLOSE_THREADS 96
17+ #define PROC_BIND_CLOSE_SIZE 9760
18+
1419void xerbla_ (const char * srname , const integer * info , ftnlen srname_len );
20+ extern int fla_thread_get_num_threads ();
21+
22+ void dpotrf_auto_tune_params (integer n , int * num_threads , int * block_size )
23+ {
24+ * block_size = FLA_POTRF_BLOCK_SIZE ; // Default block size
1525
16- #define BLOCK_SIZE FLA_POTRF_BLOCK_SIZE
26+ // Get maximum available threads
27+ int max_threads = fla_thread_get_num_threads ();
28+
29+ if (n <= 300 )
30+ {
31+ * num_threads = 8 ;
32+ * block_size = 64 ;
33+ }
34+ else if (n <= 1024 )
35+ {
36+ * num_threads = 8 ;
37+ * block_size = 128 ;
38+ }
39+ else if (n <= 1280 )
40+ {
41+ * num_threads = 32 ;
42+ * block_size = 128 ;
43+ }
44+ else if (n <= 2048 )
45+ {
46+ * num_threads = 32 ;
47+ * block_size = 128 ;
48+ }
49+ else if (n >= 9760 )
50+ {
51+ // Large problems
52+ * num_threads = 128 ;
53+ * block_size = 256 ;
54+ }
55+ else
56+ {
57+ // Medium-large problems
58+ * num_threads = 64 ;
59+ * block_size = 224 ;
60+ }
61+ // Ensure we don't exceed available threads
62+ * num_threads = fla_min (* num_threads , max_threads );
63+ }
1764
1865/* Table of constant values */
1966static integer c__1 = 1 ;
@@ -272,7 +319,8 @@ static size_t A21;
272319static size_t A12 ;
273320static size_t A22 ;
274321
275- #define A (m , n , gm , gn , mb , nb , uplo ) (double *)get_tile_addr_triangle(A, gm, m, n, gm, gn, mb, nb, uplo)
322+ #define A (m , n , gm , gn , mb , nb , uplo ) \
323+ (double *)get_tile_addr_triangle(A, gm, m, n, gm, gn, mb, nb, uplo)
276324
277325static inline integer get_tile_lda (integer mb , integer gm , integer k )
278326{
@@ -321,7 +369,6 @@ static inline void *get_tile_addr_triangle(double *A, integer N, integer m, inte
321369 return (void * )((char * )A + (offset * eltsize ));
322370}
323371
324-
325372static inline integer get_tile_rows (integer mt , integer mb , integer m , integer k )
326373{
327374 if (k < mt - 1 )
@@ -334,23 +381,22 @@ static inline integer get_tile_rows(integer mt, integer mb, integer m, integer k
334381
335382void dpotrf_tile (char * uplo , integer n , double * A , integer lda , integer * iinfo
336383#if AOCL_FLA_PROGRESS_H
337- ,
338- aocl_fla_progress_callback aocl_fla_progress_ptr ,
339- integer progress_step_count ,
340- integer progress_total_threads
384+ ,
385+ aocl_fla_progress_callback aocl_fla_progress_ptr , integer progress_step_count ,
386+ integer progress_total_threads
341387#endif
342388)
343389{
344- #pragma omp task depend(inout : A [0: lda * n])
390+ #pragma omp task depend(inout : A[0 : lda * n])
345391 {
346392 __builtin_prefetch (A , 1 , 3 );
347393#if AOCL_FLA_PROGRESS_H
348394 integer thread_id = omp_get_thread_num ();
349- if (aocl_fla_progress_ptr )
350- {
351- AOCL_FLA_PROGRESS_FUNC_PTR ("DPOTRF" , 6 , & progress_step_count , & thread_id ,
352- & progress_total_threads );
353- }
395+ if (aocl_fla_progress_ptr )
396+ {
397+ AOCL_FLA_PROGRESS_FUNC_PTR ("DPOTRF" , 6 , & progress_step_count , & thread_id ,
398+ & progress_total_threads );
399+ }
354400#endif
355401 lapack_dpotrf (uplo , & n , A , & lda , iinfo );
356402 }
@@ -364,9 +410,7 @@ void dtrsm_tile(char *side, char *uplo, char *transa, char *diag, integer m, int
364410 ak = m ;
365411 else
366412 ak = n ;
367- #pragma omp task depend(in \
368- : A [0:(lda)*ak]) depend(inout \
369- : B [0:(ldb) * (n)])
413+ #pragma omp task depend(in : A[0 : (lda) * ak]) depend(inout : B[0 : (ldb) * (n)])
370414 {
371415 __builtin_prefetch (B , 1 , 3 );
372416 dtrsm_ (side , uplo , transa , diag , & m , & n , alpha , A , & lda , B , & ldb );
@@ -381,9 +425,7 @@ void dsyrk_tile(char *uplo, char *trans, integer n, integer k, double *alpha, do
381425 ak = k ;
382426 else
383427 ak = n ;
384- #pragma omp task depend(in \
385- : A [0:(lda)*ak]) depend(inout \
386- : C [0:(ldc) * (n)])
428+ #pragma omp task depend(in : A[0 : (lda) * ak]) depend(inout : C[0 : (ldc) * (n)])
387429 {
388430 __builtin_prefetch (A , 1 , 3 );
389431 dsyrk_ (uplo , trans , & n , & k , alpha , A , & lda , beta , C , & ldc );
@@ -405,10 +447,8 @@ void dgemm_tile(char *transa, char *transb, integer m, integer n, integer k, dou
405447 bk = n ;
406448 else
407449 bk = k ;
408- #pragma omp task depend(in \
409- : A [0:(lda)*ak]) depend(in \
410- : B [0:(ldb)*bk]) depend(inout \
411- : C [0:(ldc) * (n)])
450+ #pragma omp task depend(in : A[0 : (lda) * ak]) depend(in : B[0 : (ldb) * bk]) \
451+ depend(inout : C[0 : (ldc) * (n)])
412452 {
413453 __builtin_prefetch (C , 1 , 3 );
414454 dgemm_ (transa , transb , & m , & n , & k , alpha , A , & lda , B , & ldb , beta , C , & ldc );
@@ -572,7 +612,7 @@ void omp_dpotrf(char *uplo, double *A, integer *n, integer *lda, integer mt, int
572612
573613void dlacpy_tile (integer m , integer n , double * A , integer lda , double * B , integer ldb )
574614{
575- #pragma omp task depend(in : A [0: (lda) * (n)]) depend(out : B [0: (ldb) * (n)])
615+ #pragma omp task depend(in : A[0 : (lda) * (n)]) depend(out : B[0 : (ldb) * (n)])
576616 {
577617 dlacpy_ ("Full" , & m , & n , A , & lda , B , & ldb );
578618 }
@@ -639,7 +679,7 @@ void matrix_untile(double *pA, integer lda, double *A, integer nb, integer mb, c
639679 * - Frees temporary memory
640680 *
641681 * Memory allocation size calculation:
642- * - Block sizes: nb (column blocks) = 256 , mb (row blocks) = 256
682+ * - Block sizes: nb (column blocks), mb (row blocks) are auto-tuned based on problem size
643683 * - Number of tiles: mt, nt calculated based on matrix dimensions
644684 * - Total memory includes full tiles plus remainder elements
645685 *
@@ -650,8 +690,8 @@ void matrix_untile(double *pA, integer lda, double *A, integer nb, integer mb, c
650690 * @param info Pointer to error information output
651691 *
652692 * @note Memory allocation failure results in fallback to reference algorithm(lapack_dpotrf).
653- *
654- * @note Reference: "A class of parallel tiled linear algebra algorithms for
693+ *
694+ * @note Reference: "A class of parallel tiled linear algebra algorithms for
655695 * multicore architectures. Parallel Computing, 35(1), 38-53"
656696 * by "Buttari, A., Langou, J., Kurzak, J., & Dongarra, J"
657697 */
@@ -677,12 +717,17 @@ void lapack_dpotrf_var1(char *uplo, integer *n, doublereal *A, integer *lda, int
677717 return ;
678718 }
679719 // Quick return if possible
680- if ( * n == 0 )
720+ if (* n == 0 )
681721 {
682722 return ;
683723 }
684- integer nb = BLOCK_SIZE ;
685- integer mb = BLOCK_SIZE ;
724+
725+ // Auto-tune parameters based on problem size
726+ integer auto_num_threads , auto_block_size ;
727+ dpotrf_auto_tune_params (* n , & auto_num_threads , & auto_block_size );
728+
729+ integer nb = auto_block_size ;
730+ integer mb = auto_block_size ;
686731 integer mt = (* n == 0 ) ? 0 : (* n - 1 ) / nb + 1 ;
687732 integer nt = (* n == 0 ) ? 0 : (* n - 1 ) / nb + 1 ;
688733 integer lm1 = * n / mb ;
@@ -706,17 +751,36 @@ void lapack_dpotrf_var1(char *uplo, integer *n, doublereal *A, integer *lda, int
706751 {
707752 uplo_local = 'U' ;
708753 }
709- #pragma omp parallel
754+
755+ if (* n <= PROC_BIND_CLOSE_SIZE || auto_num_threads < PROC_BIND_CLOSE_THREADS )
756+ {
757+ #pragma omp parallel num_threads(auto_num_threads) proc_bind(close)
710758#pragma omp single
759+ {
760+ // Translate to tile layout.
761+ matrix_tile (A , * lda , temp_A , nb , mb , & uplo_local , mt , nt , * n , * n , * n );
762+
763+ // Call to tiled potrf path
764+ omp_dpotrf (& uplo_local , temp_A , n , lda , mt , mb , nb , * n , * n , info );
765+
766+ // Get time for matrix_untile
767+ matrix_untile (A , * lda , temp_A , nb , mb , & uplo_local , mt , nt , * n , * n , * n );
768+ }
769+ }
770+ else
711771 {
712- // Translate to tile layout.
713- matrix_tile (A , * lda , temp_A , nb , mb , & uplo_local , mt , nt , * n , * n , * n );
772+ #pragma omp parallel num_threads(auto_num_threads) proc_bind(spread)
773+ #pragma omp single
774+ {
775+ // Translate to tile layout.
776+ matrix_tile (A , * lda , temp_A , nb , mb , & uplo_local , mt , nt , * n , * n , * n );
714777
715- // Call to tiled potrf path
716- omp_dpotrf (& uplo_local , temp_A , n , lda , mt , mb , nb , * n , * n , info );
778+ // Call to tiled potrf path
779+ omp_dpotrf (& uplo_local , temp_A , n , lda , mt , mb , nb , * n , * n , info );
717780
718- // Get time for matrix_untile
719- matrix_untile (A , * lda , temp_A , nb , mb , & uplo_local , mt , nt , * n , * n , * n );
781+ // Get time for matrix_untile
782+ matrix_untile (A , * lda , temp_A , nb , mb , & uplo_local , mt , nt , * n , * n , * n );
783+ }
720784 }
721785 free (temp_A );
722786}
0 commit comments