Skip to content

Commit

Permalink
Merge pull request #217 from frasercrmck/fix-riscv-local-size-for-sub…
Browse files Browse the repository at this point in the history
…group-count

[mux][riscv] Fix getLocalSizeForSubGroupCount
  • Loading branch information
frasercrmck authored Nov 16, 2023
2 parents 5f7fb86 + 5245e49 commit ef8133e
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 10 deletions.
43 changes: 37 additions & 6 deletions modules/mux/targets/riscv/source/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,19 +117,50 @@ mux_result_t kernel_s::getLocalSizeForSubGroupCount(size_t sub_group_count,
size_t *out_local_size_x,
size_t *out_local_size_y,
size_t *out_local_size_z) {
// FIXME: For a single sub-group, we know we can satisfy that with a
// work-group of 1,1,1. For any other sub-group count, we should ensure that
// the work-group size we report comes back through getKernelVariantForWGSize
// when it comes to run it. See CA-4784.
if (sub_group_count == 1) {
// Grab the maximum sub-group size we've compiled for.
uint32_t max_sub_group_size = 1;
for (auto &v : variant_data) {
max_sub_group_size = std::max(max_sub_group_size, v.sub_group_size);
}

// For simplicity, if we're being asked for just the one sub-group, or the
// kernel's sub-group size is 1, we know we can satisfy the query with a
// work-group of 1,1,1.
if (sub_group_count == 1 || max_sub_group_size == 1) {
*out_local_size_x = 1;
*out_local_size_y = 1;
*out_local_size_z = 1;
} else {
return mux_success;
}

// For any other sub-group count, we should ensure that the work-group size
// we report comes back through getKernelVariantForWGSize when it comes to
// run it.
*out_local_size_x = sub_group_count * max_sub_group_size;
*out_local_size_y = 1;
*out_local_size_z = 1;

// If the required local work-group size would be an invalid work-group size,
// return 0,0,0 as per the specification.
if (*out_local_size_x > device->info->max_work_group_size_x) {
*out_local_size_x = 0;
*out_local_size_y = 0;
*out_local_size_z = 0;
return mux_success;
}

#ifndef NDEBUG
// Double-check that if we were to be asked for the kernel variant for this
// work-group size we've reported, we'd receive a kernel variant with the
// same sub-group size as we've assumed for the calculations.
mux::hal::kernel_variant_s variant;
mux_result_t res = getKernelVariantForWGSize(
*out_local_size_x, *out_local_size_y, *out_local_size_z, &variant);
if (res != mux_success || variant.sub_group_size != max_sub_group_size) {
return mux_error_internal;
}
#endif

return mux_success;
}

Expand Down
2 changes: 1 addition & 1 deletion modules/mux/test/muxQueryLocalSizeForSubGroupCount.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ TEST_P(muxQueryLocalSizeForSubGroupCountTest, ValidateLocalSize) {
// be 1.
const auto one_dimensional_counts =
std::count(std::begin(local_sizes), std::end(local_sizes), 1);
ASSERT_EQ(one_dimensional_counts, 2);
ASSERT_GE(one_dimensional_counts, 2);

// The local size must be evenly divisible by the sub-group size with no
// remainders.
Expand Down
3 changes: 0 additions & 3 deletions source/cl/scripts/cts-3.0-online-ignore-linux-riscv.csv
Original file line number Diff line number Diff line change
Expand Up @@ -1306,6 +1306,3 @@ SPIR (cl_khr_spir),spir/test_spir conversions_double no-unzip
Mipmaps (Kernel),images/kernel_read_write/test_image_streams test_mipmaps CL_FILTER_NEAREST
Mipmaps (clCopyImage),images/clCopyImage/test_cl_copy_images
Mipmaps (clReadWriteImage),images/clReadWriteImage/test_cl_read_write_images test_mipmaps

# TODO(CA-4540): Fix cl12/api/test_api/sub_group_dispatch.
API,api/test_api sub_group_dispatch

0 comments on commit ef8133e

Please sign in to comment.