Skip to content

Commit df53834

Browse files
author
Prasada Ch, Venkatesha
authored
AOCL-LAPACK: DPOTRF performance tuning (flame#82)
AOCL-LAPACK: DPOTRF performance tuning -> Included a function to set the number of threads and block size for dpotrf. AMD-Internal: [CPUPL-6637]
1 parent 45afda3 commit df53834

File tree

1 file changed

+101
-37
lines changed

1 file changed

+101
-37
lines changed

src/lapack/dec/chol/front/flamec/lapack_dpotrf.c

Lines changed: 101 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,56 @@
1111
#if FLA_ENABLE_AMD_OPT
1212
int fla_dpotrf_small_avx2(char *uplo, integer *n, doublereal *a, integer *lda, integer *info);
1313
#endif
14+
15+
/* Threshold values used for tuning thread binding and workload partitioning.*/
16+
#define PROC_BIND_CLOSE_THREADS 96
17+
#define PROC_BIND_CLOSE_SIZE 9760
18+
1419
void xerbla_(const char *srname, const integer *info, ftnlen srname_len);
20+
extern int fla_thread_get_num_threads();
21+
22+
/**
 * @brief Select the thread count and tile size used by the tiled DPOTRF path.
 *
 * The values come from a fixed lookup on the matrix order n:
 *
 *   n <= 300          ->   8 threads, block size  64
 *   300  < n <= 1024  ->   8 threads, block size 128
 *   1024 < n <= 2048  ->  32 threads, block size 128
 *   2048 < n <  9760  ->  64 threads, block size 224
 *   n >= 9760         -> 128 threads, block size 256
 *
 * The thread count is then limited to the value returned by
 * fla_thread_get_num_threads() and is never reported as less than 1.
 *
 * @param n            Order of the matrix being factorized.
 * @param num_threads  Output: number of threads to request.
 * @param block_size   Output: tile dimension to use.
 */
void dpotrf_auto_tune_params(integer n, int *num_threads, int *block_size)
{
    /* Upper bound on the thread count, as reported by the threading layer. */
    int max_threads = fla_thread_get_num_threads();

    /* Library default; every size class below overrides it. */
    *block_size = FLA_POTRF_BLOCK_SIZE;

    if(n <= 300)
    {
        *num_threads = 8;
        *block_size = 64;
    }
    else if(n <= 1024)
    {
        *num_threads = 8;
        *block_size = 128;
    }
    else if(n <= 2048)
    {
        /* Covers both the (1024, 1280] and (1280, 2048] ranges, which use
         * the same settings. */
        *num_threads = 32;
        *block_size = 128;
    }
    else if(n < 9760)
    {
        /* Medium-large problems */
        *num_threads = 64;
        *block_size = 224;
    }
    else
    {
        /* Large problems */
        *num_threads = 128;
        *block_size = 256;
    }

    /* Ensure we don't exceed available threads */
    *num_threads = fla_min(*num_threads, max_threads);

    /* The result is used in an OpenMP num_threads() clause, which requires a
     * positive value; fall back to a single thread if the reported maximum
     * is zero or negative. */
    if(*num_threads < 1)
    {
        *num_threads = 1;
    }
}
1764

1865
/* Table of constant values */
1966
static integer c__1 = 1;
@@ -272,7 +319,8 @@ static size_t A21;
272319
static size_t A12;
273320
static size_t A22;
274321

275-
#define A(m, n, gm, gn, mb, nb, uplo) (double *)get_tile_addr_triangle(A, gm, m, n, gm, gn, mb, nb, uplo)
322+
#define A(m, n, gm, gn, mb, nb, uplo) \
323+
(double *)get_tile_addr_triangle(A, gm, m, n, gm, gn, mb, nb, uplo)
276324

277325
static inline integer get_tile_lda(integer mb, integer gm, integer k)
278326
{
@@ -321,7 +369,6 @@ static inline void *get_tile_addr_triangle(double *A, integer N, integer m, inte
321369
return (void *)((char *)A + (offset * eltsize));
322370
}
323371

324-
325372
static inline integer get_tile_rows(integer mt, integer mb, integer m, integer k)
326373
{
327374
if(k < mt - 1)
@@ -334,23 +381,22 @@ static inline integer get_tile_rows(integer mt, integer mb, integer m, integer k
334381

335382
void dpotrf_tile(char *uplo, integer n, double *A, integer lda, integer *iinfo
336383
#if AOCL_FLA_PROGRESS_H
337-
,
338-
aocl_fla_progress_callback aocl_fla_progress_ptr,
339-
integer progress_step_count,
340-
integer progress_total_threads
384+
,
385+
aocl_fla_progress_callback aocl_fla_progress_ptr, integer progress_step_count,
386+
integer progress_total_threads
341387
#endif
342388
)
343389
{
344-
#pragma omp task depend(inout : A [0:lda * n])
390+
#pragma omp task depend(inout : A[0 : lda * n])
345391
{
346392
__builtin_prefetch(A, 1, 3);
347393
#if AOCL_FLA_PROGRESS_H
348394
integer thread_id = omp_get_thread_num();
349-
if(aocl_fla_progress_ptr)
350-
{
351-
AOCL_FLA_PROGRESS_FUNC_PTR("DPOTRF", 6, &progress_step_count, &thread_id,
352-
&progress_total_threads);
353-
}
395+
if(aocl_fla_progress_ptr)
396+
{
397+
AOCL_FLA_PROGRESS_FUNC_PTR("DPOTRF", 6, &progress_step_count, &thread_id,
398+
&progress_total_threads);
399+
}
354400
#endif
355401
lapack_dpotrf(uplo, &n, A, &lda, iinfo);
356402
}
@@ -364,9 +410,7 @@ void dtrsm_tile(char *side, char *uplo, char *transa, char *diag, integer m, int
364410
ak = m;
365411
else
366412
ak = n;
367-
#pragma omp task depend(in \
368-
: A [0:(lda)*ak]) depend(inout \
369-
: B [0:(ldb) * (n)])
413+
#pragma omp task depend(in : A[0 : (lda) * ak]) depend(inout : B[0 : (ldb) * (n)])
370414
{
371415
__builtin_prefetch(B, 1, 3);
372416
dtrsm_(side, uplo, transa, diag, &m, &n, alpha, A, &lda, B, &ldb);
@@ -381,9 +425,7 @@ void dsyrk_tile(char *uplo, char *trans, integer n, integer k, double *alpha, do
381425
ak = k;
382426
else
383427
ak = n;
384-
#pragma omp task depend(in \
385-
: A [0:(lda)*ak]) depend(inout \
386-
: C [0:(ldc) * (n)])
428+
#pragma omp task depend(in : A[0 : (lda) * ak]) depend(inout : C[0 : (ldc) * (n)])
387429
{
388430
__builtin_prefetch(A, 1, 3);
389431
dsyrk_(uplo, trans, &n, &k, alpha, A, &lda, beta, C, &ldc);
@@ -405,10 +447,8 @@ void dgemm_tile(char *transa, char *transb, integer m, integer n, integer k, dou
405447
bk = n;
406448
else
407449
bk = k;
408-
#pragma omp task depend(in \
409-
: A [0:(lda)*ak]) depend(in \
410-
: B [0:(ldb)*bk]) depend(inout \
411-
: C [0:(ldc) * (n)])
450+
#pragma omp task depend(in : A[0 : (lda) * ak]) depend(in : B[0 : (ldb) * bk]) \
451+
depend(inout : C[0 : (ldc) * (n)])
412452
{
413453
__builtin_prefetch(C, 1, 3);
414454
dgemm_(transa, transb, &m, &n, &k, alpha, A, &lda, B, &ldb, beta, C, &ldc);
@@ -572,7 +612,7 @@ void omp_dpotrf(char *uplo, double *A, integer *n, integer *lda, integer mt, int
572612

573613
void dlacpy_tile(integer m, integer n, double *A, integer lda, double *B, integer ldb)
574614
{
575-
#pragma omp task depend(in : A [0:(lda) * (n)]) depend(out : B [0:(ldb) * (n)])
615+
#pragma omp task depend(in : A[0 : (lda) * (n)]) depend(out : B[0 : (ldb) * (n)])
576616
{
577617
dlacpy_("Full", &m, &n, A, &lda, B, &ldb);
578618
}
@@ -639,7 +679,7 @@ void matrix_untile(double *pA, integer lda, double *A, integer nb, integer mb, c
639679
* - Frees temporary memory
640680
*
641681
* Memory allocation size calculation:
642-
* - Block sizes: nb (column blocks) = 256, mb (row blocks) = 256
682+
* - Block sizes: nb (column blocks), mb (row blocks) are auto-tuned based on problem size
643683
* - Number of tiles: mt, nt calculated based on matrix dimensions
644684
* - Total memory includes full tiles plus remainder elements
645685
*
@@ -650,8 +690,8 @@ void matrix_untile(double *pA, integer lda, double *A, integer nb, integer mb, c
650690
* @param info Pointer to error information output
651691
*
652692
* @note Memory allocation failure results in fallback to reference algorithm(lapack_dpotrf).
653-
*
654-
* @note Reference: "A class of parallel tiled linear algebra algorithms for
693+
*
694+
* @note Reference: "A class of parallel tiled linear algebra algorithms for
655695
* multicore architectures. Parallel Computing, 35(1), 38-53"
656696
* by "Buttari, A., Langou, J., Kurzak, J., & Dongarra, J"
657697
*/
@@ -677,12 +717,17 @@ void lapack_dpotrf_var1(char *uplo, integer *n, doublereal *A, integer *lda, int
677717
return;
678718
}
679719
// Quick return if possible
680-
if( *n == 0)
720+
if(*n == 0)
681721
{
682722
return;
683723
}
684-
integer nb = BLOCK_SIZE;
685-
integer mb = BLOCK_SIZE;
724+
725+
// Auto-tune parameters based on problem size
726+
integer auto_num_threads, auto_block_size;
727+
dpotrf_auto_tune_params(*n, &auto_num_threads, &auto_block_size);
728+
729+
integer nb = auto_block_size;
730+
integer mb = auto_block_size;
686731
integer mt = (*n == 0) ? 0 : (*n - 1) / nb + 1;
687732
integer nt = (*n == 0) ? 0 : (*n - 1) / nb + 1;
688733
integer lm1 = *n / mb;
@@ -706,17 +751,36 @@ void lapack_dpotrf_var1(char *uplo, integer *n, doublereal *A, integer *lda, int
706751
{
707752
uplo_local = 'U';
708753
}
709-
#pragma omp parallel
754+
755+
if(*n <= PROC_BIND_CLOSE_SIZE || auto_num_threads < PROC_BIND_CLOSE_THREADS)
756+
{
757+
#pragma omp parallel num_threads(auto_num_threads) proc_bind(close)
710758
#pragma omp single
759+
{
760+
// Translate to tile layout.
761+
matrix_tile(A, *lda, temp_A, nb, mb, &uplo_local, mt, nt, *n, *n, *n);
762+
763+
// Call to tiled potrf path
764+
omp_dpotrf(&uplo_local, temp_A, n, lda, mt, mb, nb, *n, *n, info);
765+
766+
// Translate the tiled result in temp_A back into A.
767+
matrix_untile(A, *lda, temp_A, nb, mb, &uplo_local, mt, nt, *n, *n, *n);
768+
}
769+
}
770+
else
711771
{
712-
// Translate to tile layout.
713-
matrix_tile(A, *lda, temp_A, nb, mb, &uplo_local, mt, nt, *n, *n, *n);
772+
#pragma omp parallel num_threads(auto_num_threads) proc_bind(spread)
773+
#pragma omp single
774+
{
775+
// Translate to tile layout.
776+
matrix_tile(A, *lda, temp_A, nb, mb, &uplo_local, mt, nt, *n, *n, *n);
714777

715-
// Call to tiled potrf path
716-
omp_dpotrf(&uplo_local, temp_A, n, lda, mt, mb, nb, *n, *n, info);
778+
// Call to tiled potrf path
779+
omp_dpotrf(&uplo_local, temp_A, n, lda, mt, mb, nb, *n, *n, info);
717780

718-
// Get time for matrix_untile
719-
matrix_untile(A, *lda, temp_A, nb, mb, &uplo_local, mt, nt, *n, *n, *n);
781+
// Translate the tiled result in temp_A back into A.
782+
matrix_untile(A, *lda, temp_A, nb, mb, &uplo_local, mt, nt, *n, *n, *n);
783+
}
720784
}
721785
free(temp_A);
722786
}

0 commit comments

Comments
 (0)