Skip to content

Commit 039f524

Browse files
committed
Extract common logic from single_task and parallel_for handler-less submit,
add a kernel move unit test for single_task
1 parent 67bbc2a commit 039f524

File tree

3 files changed

+110
-71
lines changed

3 files changed

+110
-71
lines changed

sycl/include/sycl/khr/free_function_commands.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,9 @@ void launch_task(handler &h, const KernelType &k) {
319319
h.single_task(k);
320320
}
321321

322-
template <typename KernelType>
323-
void launch_task(const sycl::queue &q, const KernelType &k,
322+
template <typename KernelType, typename = typename std::enable_if_t<
323+
enable_kernel_function_overload<KernelType>>>
324+
void launch_task(const sycl::queue &q, KernelType &&k,
324325
const sycl::detail::code_location &codeLoc =
325326
sycl::detail::code_location::current()) {
326327
// TODO The handler-less path does not support kernel function properties
@@ -331,7 +332,8 @@ void launch_task(const sycl::queue &q, const KernelType &k,
331332
!(detail::KernelLambdaHasKernelHandlerArgT<KernelType,
332333
void>::value)) {
333334
detail::submit_kernel_direct_single_task(
334-
q, ext::oneapi::experimental::empty_properties_t{}, k, codeLoc);
335+
q, ext::oneapi::experimental::empty_properties_t{},
336+
std::forward<KernelType>(k), codeLoc);
335337
} else {
336338
submit(q, [&](handler &h) { launch_task<KernelType>(h, k); }, codeLoc);
337339
}

sycl/include/sycl/queue.hpp

Lines changed: 53 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -157,70 +157,12 @@ class __SYCL_EXPORT SubmissionInfo {
157157

158158
} // namespace v1
159159

160-
template <typename KernelName = detail::auto_name, bool EventNeeded = false,
160+
template <detail::WrapAs WrapAs, typename LambdaArgType,
161+
typename KernelName = detail::auto_name, bool EventNeeded = false,
161162
typename PropertiesT, typename KernelTypeUniversalRef, int Dims>
162-
auto submit_kernel_direct_parallel_for(
163-
const queue &Queue, PropertiesT Props, const nd_range<Dims> &Range,
164-
KernelTypeUniversalRef &&KernelFunc,
165-
const detail::code_location &CodeLoc = detail::code_location::current()) {
166-
// TODO Properties not supported yet
167-
(void)Props;
168-
static_assert(
169-
std::is_same_v<PropertiesT,
170-
ext::oneapi::experimental::empty_properties_t>,
171-
"Setting properties not supported yet for no-CGH kernel submit.");
172-
detail::tls_code_loc_t TlsCodeLocCapture(CodeLoc);
173-
174-
using KernelType =
175-
std::remove_const_t<std::remove_reference_t<KernelTypeUniversalRef>>;
176-
177-
using NameT =
178-
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
179-
using LambdaArgType =
180-
sycl::detail::lambda_arg_type<KernelType, nd_item<Dims>>;
181-
static_assert(
182-
std::is_convertible_v<sycl::nd_item<Dims>, LambdaArgType>,
183-
"Kernel argument of a sycl::parallel_for with sycl::nd_range "
184-
"must be either sycl::nd_item or be convertible from sycl::nd_item");
185-
using TransformedArgType = sycl::nd_item<Dims>;
186-
187-
#ifndef __SYCL_DEVICE_ONLY__
188-
detail::checkValueRange<Dims>(Range);
189-
#endif
190-
191-
detail::KernelWrapper<detail::WrapAs::parallel_for, NameT, KernelType,
192-
TransformedArgType, PropertiesT>::wrap(KernelFunc);
193-
194-
HostKernelRef<KernelType, KernelTypeUniversalRef, TransformedArgType, Dims>
195-
HostKernel(std::forward<KernelTypeUniversalRef>(KernelFunc));
196-
197-
// Instantiating the kernel on the host improves debugging.
198-
// Passing this pointer to another translation unit prevents optimization.
199-
#ifndef NDEBUG
200-
// TODO: call library to prevent dropping call due to optimization
201-
(void)
202-
detail::GetInstantiateKernelOnHostPtr<KernelType, LambdaArgType, Dims>();
203-
#endif
204-
205-
detail::DeviceKernelInfo *DeviceKernelInfoPtr =
206-
&detail::getDeviceKernelInfo<NameT>();
207-
208-
if constexpr (EventNeeded) {
209-
return submit_kernel_direct_with_event_impl(
210-
Queue, Range, HostKernel, DeviceKernelInfoPtr,
211-
TlsCodeLocCapture.query(), TlsCodeLocCapture.isToplevel());
212-
} else {
213-
submit_kernel_direct_without_event_impl(
214-
Queue, Range, HostKernel, DeviceKernelInfoPtr,
215-
TlsCodeLocCapture.query(), TlsCodeLocCapture.isToplevel());
216-
}
217-
}
218-
219-
template <typename KernelName = detail::auto_name, bool EventNeeded = false,
220-
typename PropertiesT, typename KernelTypeUniversalRef>
221-
auto submit_kernel_direct_single_task(
163+
auto submit_kernel_direct(
222164
const queue &Queue, [[maybe_unused]] PropertiesT Props,
223-
KernelTypeUniversalRef &&KernelFunc,
165+
const nd_range<Dims> &Range, KernelTypeUniversalRef &&KernelFunc,
224166
const detail::code_location &CodeLoc = detail::code_location::current()) {
225167
// TODO Properties not supported yet
226168
static_assert(
@@ -235,17 +177,18 @@ auto submit_kernel_direct_single_task(
235177
using NameT =
236178
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
237179

238-
detail::KernelWrapper<detail::WrapAs::single_task, NameT, KernelType, void,
180+
detail::KernelWrapper<WrapAs, NameT, KernelType, LambdaArgType,
239181
PropertiesT>::wrap(KernelFunc);
240182

241-
HostKernelRef<KernelType, KernelTypeUniversalRef, void, 1> HostKernel(
242-
std::forward<KernelTypeUniversalRef>(KernelFunc));
183+
HostKernelRef<KernelType, KernelTypeUniversalRef, LambdaArgType, Dims>
184+
HostKernel(std::forward<KernelTypeUniversalRef>(KernelFunc));
243185

244186
// Instantiating the kernel on the host improves debugging.
245187
// Passing this pointer to another translation unit prevents optimization.
246188
#ifndef NDEBUG
247189
// TODO: call library to prevent dropping call due to optimization.
248-
(void)detail::GetInstantiateKernelOnHostPtr<KernelType, void, 1>();
190+
(void)
191+
detail::GetInstantiateKernelOnHostPtr<KernelType, LambdaArgType, Dims>();
249192
#endif
250193

251194
detail::DeviceKernelInfo *DeviceKernelInfoPtr =
@@ -269,15 +212,57 @@ auto submit_kernel_direct_single_task(
269212

270213
if constexpr (EventNeeded) {
271214
return submit_kernel_direct_with_event_impl(
272-
Queue, nd_range<1>{1, 1}, HostKernel, DeviceKernelInfoPtr,
215+
Queue, Range, HostKernel, DeviceKernelInfoPtr,
273216
TlsCodeLocCapture.query(), TlsCodeLocCapture.isToplevel());
274217
} else {
275218
submit_kernel_direct_without_event_impl(
276-
Queue, nd_range<1>{1, 1}, HostKernel, DeviceKernelInfoPtr,
219+
Queue, Range, HostKernel, DeviceKernelInfoPtr,
277220
TlsCodeLocCapture.query(), TlsCodeLocCapture.isToplevel());
278221
}
279222
}
280223

224+
template <typename KernelName = detail::auto_name, bool EventNeeded = false,
225+
typename PropertiesT, typename KernelTypeUniversalRef, int Dims>
226+
auto submit_kernel_direct_parallel_for(
227+
const queue &Queue, PropertiesT Props, const nd_range<Dims> &Range,
228+
KernelTypeUniversalRef &&KernelFunc,
229+
const detail::code_location &CodeLoc = detail::code_location::current()) {
230+
231+
using KernelType =
232+
std::remove_const_t<std::remove_reference_t<KernelTypeUniversalRef>>;
233+
234+
using LambdaArgType =
235+
sycl::detail::lambda_arg_type<KernelType, nd_item<Dims>>;
236+
static_assert(
237+
std::is_convertible_v<sycl::nd_item<Dims>, LambdaArgType>,
238+
"Kernel argument of a sycl::parallel_for with sycl::nd_range "
239+
"must be either sycl::nd_item or be convertible from sycl::nd_item");
240+
using TransformedArgType = sycl::nd_item<Dims>;
241+
242+
#ifndef __SYCL_DEVICE_ONLY__
243+
detail::checkValueRange<Dims>(Range);
244+
#endif
245+
246+
return submit_kernel_direct<detail::WrapAs::parallel_for, TransformedArgType,
247+
KernelName, EventNeeded, PropertiesT,
248+
KernelTypeUniversalRef, Dims>(
249+
Queue, Props, Range, std::forward<KernelTypeUniversalRef>(KernelFunc),
250+
CodeLoc);
251+
}
252+
253+
template <typename KernelName = detail::auto_name, bool EventNeeded = false,
254+
typename PropertiesT, typename KernelTypeUniversalRef>
255+
auto submit_kernel_direct_single_task(
256+
const queue &Queue, PropertiesT Props, KernelTypeUniversalRef &&KernelFunc,
257+
const detail::code_location &CodeLoc = detail::code_location::current()) {
258+
259+
return submit_kernel_direct<detail::WrapAs::single_task, void, KernelName,
260+
EventNeeded, PropertiesT, KernelTypeUniversalRef,
261+
1>(
262+
Queue, Props, nd_range<1>{1, 1},
263+
std::forward<KernelTypeUniversalRef>(KernelFunc), CodeLoc);
264+
}
265+
281266
} // namespace detail
282267

283268
namespace ext ::oneapi ::experimental {

sycl/unittests/Extensions/FreeFunctionCommands/FreeFunctionCommandsEvents.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ class FreeFunctionCommandsEventsTests : public ::testing::Test {
7878
protected:
7979
void SetUp() override {
8080
counter_urEnqueueKernelLaunch = 0;
81+
counter_urEnqueueKernelLaunchWithEvent = 0;
8182
counter_urUSMEnqueueMemcpy = 0;
8283
counter_urUSMEnqueueFill = 0;
8384
counter_urUSMEnqueuePrefetch = 0;
@@ -281,6 +282,57 @@ TEST_F(FreeFunctionCommandsEventsTests,
281282
ASSERT_EQ(counter_urEnqueueKernelLaunchWithEvent, size_t{1});
282283
}
283284

285+
TEST_F(FreeFunctionCommandsEventsTests, LaunchTaskShortcutMoveKernel) {
286+
mock::getCallbacks().set_replace_callback("urEnqueueKernelLaunch",
287+
&redefined_urEnqueueKernelLaunch);
288+
289+
TestMoveFunctor::MoveCtorCalls = 0;
290+
TestMoveFunctor MoveOnly;
291+
std::mutex CvMutex;
292+
std::condition_variable Cv;
293+
bool ready = false;
294+
295+
// This kernel submission uses scheduler-bypass path, so the HostKernel
296+
// shouldn't be constructed.
297+
298+
sycl::khr::launch_task(Queue, std::move(MoveOnly));
299+
300+
ASSERT_EQ(TestMoveFunctor::MoveCtorCalls, 0);
301+
ASSERT_EQ(counter_urEnqueueKernelLaunch, size_t{1});
302+
303+
// Another kernel submission is queued behind a host task,
304+
// to force the scheduler-based submission. In this case, the HostKernel
305+
// should be constructed.
306+
307+
// Replace the callback with an event based one, since the scheduler
308+
// needs to create an event internally
309+
mock::getCallbacks().set_replace_callback(
310+
"urEnqueueKernelLaunch", &redefined_urEnqueueKernelLaunchWithEvent);
311+
312+
Queue.submit([&](sycl::handler &CGH) {
313+
CGH.host_task([&] {
314+
std::unique_lock<std::mutex> lk(CvMutex);
315+
Cv.wait(lk, [&ready] { return ready; });
316+
});
317+
});
318+
319+
sycl::khr::launch_task(Queue, std::move(MoveOnly));
320+
321+
{
322+
std::unique_lock<std::mutex> lk(CvMutex);
323+
ready = true;
324+
}
325+
Cv.notify_one();
326+
327+
Queue.wait();
328+
329+
// Move ctor for TestMoveFunctor is called during move construction of
330+
// HostKernel. Copy ctor is called by InstantiateKernelOnHost, so we can't delete
331+
// it.
332+
ASSERT_EQ(TestMoveFunctor::MoveCtorCalls, 1);
333+
ASSERT_EQ(counter_urEnqueueKernelLaunchWithEvent, size_t{1});
334+
}
335+
284336
TEST_F(FreeFunctionCommandsEventsTests, SubmitLaunchGroupedKernelNoEvent) {
285337
mock::getCallbacks().set_replace_callback("urEnqueueKernelLaunch",
286338
&redefined_urEnqueueKernelLaunch);

0 commit comments

Comments
 (0)