Skip to content

Commit

Permalink
add improvements and fallback
Browse files Browse the repository at this point in the history
  • Loading branch information
bashbaug committed May 10, 2024
1 parent 9ea809e commit a5658c8
Showing 1 changed file with 27 additions and 32 deletions.
59 changes: 27 additions & 32 deletions samples/99_matrixexperiments/matrix_helpers.cl
Original file line number Diff line number Diff line change
Expand Up @@ -58,49 +58,44 @@ float8 activation(float8 f)
#define __builtin_expect(x)
#endif

// Returns the linear ID of the calling work-group in a 2D dispatch, computed
// as get_group_id(1) * get_num_groups(0) + get_group_id(0), i.e. dimension 0
// varies fastest.
// NOTE(review): this is a diff rendering without +/- markers, so two versions
// of the signature and body appear back to back: an `int` version followed by
// a `uint` version. Presumably the `int` lines are the removed ones and the
// `uint` lines are what the commit keeps -- confirm against the repository.
inline int get_2d_group_linear_id()
inline uint get_2d_group_linear_id()
{
// Signed variant: each work-item query result is cast to int before use.
int lid = (int)get_group_id(1);
lid *= (int)get_num_groups(0);
lid += (int)get_group_id(0);
// Unsigned variant: same arithmetic, with every term cast to uint instead.
uint lid = (uint)get_group_id(1);
lid *= (uint)get_num_groups(0);
lid += (uint)get_group_id(0);
return lid;
}

inline int get_swizzled_group_id(int dim)
{
#if defined(SWIZZLE_SNAKE)
const int max_wg_num = 64; // number of work-groups in flight to optimize for
const int wg_num_n = 8; // tuneable
const int wg_num_m = max_wg_num / wg_num_n;
int group_range_n = get_num_groups(0);
int wg_repeat_n = group_range_n / wg_num_n;
int repeat_id = get_2d_group_linear_id() / max_wg_num;
int repeat_id_n = repeat_id % wg_repeat_n;
int repeat_id_m = repeat_id / wg_repeat_n;
if (dim == 0) {
int repeat_start_n_0 = repeat_id_n * wg_num_n;
int repeat_start_n_1 = (wg_repeat_n - repeat_id_n - 1) * wg_num_n;
int repeat_start_n = (repeat_id_m & 1) == 0 ? repeat_start_n_0 : repeat_start_n_1;
int wg_inner_id = get_2d_group_linear_id() % max_wg_num;
int wg_coord_n = wg_inner_id % wg_num_n;
int start_n_id = repeat_start_n + wg_coord_n;
//if (start_n_id >= get_num_groups(0)) {
// printf("ERROR: swizzled group ID(0) is out of range for gid (%d, %d): computed %d, num groups is %d\n",
// (int)get_group_id(0), (int)get_group_id(1), start_n_id, (int)get_num_groups(0));
//}
const uint max_wg_num = 64; // number of work-groups in flight to optimize for
const uint wg_num_n = 8; // tuneable
const uint wg_num_m = max_wg_num / wg_num_n;
uint group_range_n = get_num_groups(0);
uint wg_repeat_n = group_range_n / wg_num_n;
uint repeat_id = get_2d_group_linear_id() / max_wg_num;
uint repeat_id_n = repeat_id % wg_repeat_n;
uint repeat_id_m = repeat_id / wg_repeat_n;
if (dim == 0 & (uint)get_num_groups(0) >= wg_num_n & (uint)get_num_groups(1) >= wg_num_m) {
uint repeat_start_n_0 = repeat_id_n * wg_num_n;
uint repeat_start_n_1 = (wg_repeat_n - repeat_id_n - 1) * wg_num_n;
uint repeat_start_n = (repeat_id_m & 1) == 0 ? repeat_start_n_0 : repeat_start_n_1;
uint wg_inner_id = get_2d_group_linear_id() % max_wg_num;
uint wg_coord_n = wg_inner_id % wg_num_n;
uint start_n_id = repeat_start_n + wg_coord_n;
return start_n_id;
}
if (dim == 1) {
int repeat_start_m = repeat_id_m * wg_num_m;
int wg_inner_id = get_2d_group_linear_id() % max_wg_num;
int wg_coord_m = wg_inner_id / wg_num_n;
int start_m_id = repeat_start_m + wg_coord_m;
//if (start_m_id >= get_num_groups(1)) {
// printf("ERROR: swizzled group ID(1) is out of range for gid (%d, %d): computed %d, num groups is %d\n",
// (int)get_group_id(0), (int)get_group_id(1), start_m_id, (int)get_num_groups(1));
//}
if (dim == 1 & (uint)get_num_groups(0) >= wg_num_n & (uint)get_num_groups(1) >= wg_num_m) {
uint repeat_start_m = repeat_id_m * wg_num_m;
uint wg_inner_id = get_2d_group_linear_id() % max_wg_num;
uint wg_coord_m = wg_inner_id / wg_num_n;
uint start_m_id = repeat_start_m + wg_coord_m;
return start_m_id;
}
//if (get_group_id(0) < 2 && get_group_id(1) < 2) {
// printf("Fallback: get_group_id() = (%d, %d), get_num_groups() = (%d, %d), wg_num_n = %d, wg_num_m = %d\n", (int)get_group_id(0), (int)get_group_id(1), (int)get_num_groups(0), (int)get_num_groups(1), wg_num_n, wg_num_m);
//}
return get_group_id(dim);
#else
return get_group_id(dim);
Expand Down

0 comments on commit a5658c8

Please sign in to comment.