Skip to content

Commit 7b1c02e

Browse files
authored
[SYCLomatic] Make sure thrust::sort_by_key is migrated correctly for USM (#769)
Signed-off-by: Lu, John <[email protected]>
1 parent 6a623c2 commit 7b1c02e

File tree

6 files changed

+389
-161
lines changed

6 files changed

+389
-161
lines changed

clang/lib/DPCT/APINamesThrust.inc

Lines changed: 10 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -242,125 +242,11 @@ CALL_FACTORY_ENTRY("thrust::swap", CALL("std::swap", ARG(0), ARG(1)))
242242
CALL_FACTORY_ENTRY("thrust::make_pair", CALL("std::make_pair", ARG(0), ARG(1)))
243243

244244
// thrust::stable_sort_by_key
245-
FEATURE_REQUEST_FACTORY(HelperFeatureEnum::DplExtrasAlgorithm_stable_sort,
246-
CONDITIONAL_FACTORY_ENTRY(
247-
CheckArgCount(5),
248-
CONDITIONAL_FACTORY_ENTRY(
249-
makeCheckAnd(CheckIsPtr(1), makeCheckNot(checkIsUSM())),
250-
// Handling case: thrust::stable_sort_by_key(policy, ptr, ptr, ptr, comp)
251-
IFELSE_FACTORY_ENTRY(
252-
"thrust::stable_sort_by_key",
253-
FEATURE_REQUEST_FACTORY(HelperFeatureEnum::Memory_is_device_ptr,
254-
CALL_FACTORY_ENTRY("thrust::stable_sort_by_key", CALL(MapNames::getDpctNamespace() + "is_device_ptr", ARG(1)))),
255-
FEATURE_REQUEST_FACTORY(HelperFeatureEnum::DplExtrasMemory_device_pointer_forward_decl,
256-
CALL_FACTORY_ENTRY("thrust::stable_sort_by_key",
257-
CALL(MapNames::getDpctNamespace() + "stable_sort",
258-
CALL("oneapi::dpl::execution::make_device_policy", QUEUESTR),
259-
CALL(TEMPLATED_CALLEE_WITH_ARGS(MapNames::getDpctNamespace() + "device_pointer", getDerefedType(1)), ARG(1)),
260-
CALL(TEMPLATED_CALLEE_WITH_ARGS(MapNames::getDpctNamespace() + "device_pointer", getDerefedType(2)), ARG(2)),
261-
CALL(TEMPLATED_CALLEE_WITH_ARGS(MapNames::getDpctNamespace() + "device_pointer", getDerefedType(3)), ARG(3)),
262-
ARG(4)))),
263-
CALL_FACTORY_ENTRY("thrust::stable_sort_by_key",
264-
CALL(MapNames::getDpctNamespace() + "stable_sort",
265-
ARG("oneapi::dpl::execution::seq"),
266-
ARG(1), ARG(2), ARG(3), ARG(4)))),
267-
// Handling case: thrust::stable_sort_by_key(thrust::device,keys_first, keys_last, values_first, comp)
268-
CALL_FACTORY_ENTRY("thrust::stable_sort_by_key",
269-
CALL(MapNames::getDpctNamespace() + "stable_sort",
270-
makeMappedThrustPolicyEnum(0),
271-
ARG(1), ARG(2), ARG(3), ARG(4)))
272-
),
273-
CONDITIONAL_FACTORY_ENTRY(
274-
CheckArgCount(3),
275-
CONDITIONAL_FACTORY_ENTRY(
276-
makeCheckAnd(CheckIsPtr(1), makeCheckNot(checkIsUSM())),
277-
// Handling case: thrust::stable_sort_by_key(ptr, ptr, ptr)
278-
IFELSE_FACTORY_ENTRY(
279-
"thrust::stable_sort_by_key",
280-
FEATURE_REQUEST_FACTORY(HelperFeatureEnum::Memory_is_device_ptr,
281-
CALL_FACTORY_ENTRY("thrust::stable_sort_by_key", CALL(MapNames::getDpctNamespace() + "is_device_ptr", ARG(1)))),
282-
FEATURE_REQUEST_FACTORY(HelperFeatureEnum::DplExtrasMemory_device_pointer_forward_decl,
283-
CALL_FACTORY_ENTRY("thrust::stable_sort_by_key",
284-
CALL(MapNames::getDpctNamespace() + "stable_sort",
285-
CALL("oneapi::dpl::execution::make_device_policy", QUEUESTR),
286-
CALL(TEMPLATED_CALLEE_WITH_ARGS(MapNames::getDpctNamespace() + "device_pointer", getDerefedType(0)), ARG(0)),
287-
CALL(TEMPLATED_CALLEE_WITH_ARGS(MapNames::getDpctNamespace() + "device_pointer", getDerefedType(1)), ARG(1)),
288-
CALL(TEMPLATED_CALLEE_WITH_ARGS(MapNames::getDpctNamespace() + "device_pointer", getDerefedType(2)), ARG(2))))),
289-
CALL_FACTORY_ENTRY("thrust::stable_sort_by_key",
290-
CALL(MapNames::getDpctNamespace() + "stable_sort",
291-
ARG("oneapi::dpl::execution::seq"),
292-
ARG(0), ARG(1), ARG(2)))),
293-
CONDITIONAL_FACTORY_ENTRY(
294-
CheckThrustArgType(1, "thrust::device_ptr"),
295-
// Handling case: thrust::stable_sort_by_key(h_keys.begin(), h_keys.end(), h_values.begin());
296-
CALL_FACTORY_ENTRY("thrust::stable_sort_by_key",
297-
CALL(MapNames::getDpctNamespace() + "stable_sort",
298-
CALL("oneapi::dpl::execution::make_device_policy", QUEUESTR),
299-
ARG(0), ARG(1), ARG(2))),
300-
CALL_FACTORY_ENTRY("thrust::stable_sort_by_key",
301-
CALL(MapNames::getDpctNamespace() + "stable_sort",
302-
ARG("oneapi::dpl::execution::seq"),
303-
ARG(0), ARG(1), ARG(2))))
304-
),
305-
CONDITIONAL_FACTORY_ENTRY(
306-
makeCheckAnd(CheckIsPtr(1), makeCheckNot(checkIsUSM())),
307-
CONDITIONAL_FACTORY_ENTRY(
308-
IsPolicyArgType(0),
309-
// Handling case: thrust::stable_sort_by_key(policy, ptr, ptr, ptr);
310-
IFELSE_FACTORY_ENTRY(
311-
"thrust::stable_sort_by_key",
312-
FEATURE_REQUEST_FACTORY(HelperFeatureEnum::Memory_is_device_ptr,
313-
CALL_FACTORY_ENTRY("thrust::stable_sort_by_key", CALL(MapNames::getDpctNamespace() + "is_device_ptr", ARG(1)))),
314-
FEATURE_REQUEST_FACTORY(HelperFeatureEnum::DplExtrasMemory_device_pointer_forward_decl,
315-
CALL_FACTORY_ENTRY("thrust::stable_sort_by_key",
316-
CALL(MapNames::getDpctNamespace() + "stable_sort",
317-
CALL("oneapi::dpl::execution::make_device_policy", QUEUESTR),
318-
CALL(TEMPLATED_CALLEE_WITH_ARGS(MapNames::getDpctNamespace() + "device_pointer", getDerefedType(1)), ARG(1)),
319-
CALL(TEMPLATED_CALLEE_WITH_ARGS(MapNames::getDpctNamespace() + "device_pointer", getDerefedType(2)), ARG(2)),
320-
CALL(TEMPLATED_CALLEE_WITH_ARGS(MapNames::getDpctNamespace() + "device_pointer", getDerefedType(3)), ARG(3))))),
321-
CALL_FACTORY_ENTRY("thrust::unique_by_key_copy",
322-
CALL(MapNames::getDpctNamespace() + "stable_sort",
323-
ARG("oneapi::dpl::execution::seq"),
324-
ARG(1), ARG(2), ARG(3)))),
325-
//Handling case: thrust::stable_sort_by_key(ptr, ptr, ptr,pred);
326-
IFELSE_FACTORY_ENTRY(
327-
"thrust::stable_sort_by_key",
328-
FEATURE_REQUEST_FACTORY(HelperFeatureEnum::Memory_is_device_ptr,
329-
CALL_FACTORY_ENTRY("thrust::stable_sort_by_key", CALL(MapNames::getDpctNamespace() + "is_device_ptr", ARG(1)))),
330-
FEATURE_REQUEST_FACTORY(HelperFeatureEnum::DplExtrasMemory_device_pointer_forward_decl,
331-
CALL_FACTORY_ENTRY("thrust::stable_sort_by_key",
332-
CALL(MapNames::getDpctNamespace() + "stable_sort",
333-
CALL("oneapi::dpl::execution::make_device_policy", QUEUESTR),
334-
CALL(TEMPLATED_CALLEE_WITH_ARGS(MapNames::getDpctNamespace() + "device_pointer", getDerefedType(0)), ARG(0)),
335-
CALL(TEMPLATED_CALLEE_WITH_ARGS(MapNames::getDpctNamespace() + "device_pointer", getDerefedType(1)), ARG(1)),
336-
CALL(TEMPLATED_CALLEE_WITH_ARGS(MapNames::getDpctNamespace() + "device_pointer", getDerefedType(2)), ARG(2)),
337-
ARG(3)))),
338-
CALL_FACTORY_ENTRY("thrust::stable_sort_by_key",
339-
CALL(MapNames::getDpctNamespace() + "stable_sort",
340-
ARG("oneapi::dpl::execution::seq"),
341-
ARG(0), ARG(1), ARG(2), ARG(3))))),
342-
CONDITIONAL_FACTORY_ENTRY(
343-
IsPolicyArgType(0),
344-
// Handling case: thrust::stable_sort_by_key(thrust::device, d_keys.begin(), d_keys.end(), d_values.begin());
345-
CALL_FACTORY_ENTRY("thrust::stable_sort_by_key",
346-
CALL(MapNames::getDpctNamespace() + "stable_sort",
347-
makeMappedThrustPolicyEnum(0),
348-
ARG(1), ARG(2), ARG(3))),
349-
CONDITIONAL_FACTORY_ENTRY(
350-
CheckThrustArgType(1, "thrust::device_ptr"),
351-
// Handling case: thrust::stable_sort_by_key(d_keys.begin(), d_keys.end(), d_values.begin(),binary_pred);
352-
CALL_FACTORY_ENTRY("thrust::stable_sort_by_key",
353-
CALL(MapNames::getDpctNamespace() + "stable_sort",
354-
CALL("oneapi::dpl::execution::make_device_policy", QUEUESTR),
355-
ARG(0), ARG(1), ARG(2), ARG(3))),
356-
CALL_FACTORY_ENTRY("thrust::stable_sort_by_key",
357-
CALL(MapNames::getDpctNamespace() + "stable_sort",
358-
ARG("oneapi::dpl::execution::seq"),
359-
ARG(0), ARG(1), ARG(2), ARG(3)))))
360-
)
361-
)
362-
)
363-
)
245+
thrustFactory("thrust::stable_sort_by_key",
246+
{{5,PolicyState::HasPolicy,3,MapNames::getDpctNamespace() + "stable_sort", HelperFeatureEnum::DplExtrasAlgorithm_stable_sort},
247+
{4,PolicyState::HasPolicy,3,MapNames::getDpctNamespace() + "stable_sort", HelperFeatureEnum::DplExtrasAlgorithm_stable_sort},
248+
{4,PolicyState::NoPolicy ,3,MapNames::getDpctNamespace() + "stable_sort", HelperFeatureEnum::DplExtrasAlgorithm_stable_sort},
249+
{3,PolicyState::NoPolicy ,3,MapNames::getDpctNamespace() + "stable_sort", HelperFeatureEnum::DplExtrasAlgorithm_stable_sort}}),
364250

365251
// thrust::find
366252
CONDITIONAL_FACTORY_ENTRY(
@@ -384,47 +270,11 @@ CONDITIONAL_FACTORY_ENTRY(
384270
)
385271

386272
// thrust::sort_by_key
387-
FEATURE_REQUEST_FACTORY(HelperFeatureEnum::DplExtrasAlgorithm_sort,
388-
CONDITIONAL_FACTORY_ENTRY(
389-
CheckArgCount(5),
390-
//Handling case: thrust::sort_by_key(policy, device.begin(), device.end(), device.begin(), comp);
391-
CALL_FACTORY_ENTRY("thrust::sort_by_key",
392-
CALL(MapNames::getDpctNamespace() + "sort",
393-
makeMappedThrustPolicyEnum(0),
394-
ARG(1), ARG(2), ARG(3), ARG(4))),
395-
CONDITIONAL_FACTORY_ENTRY(
396-
CheckArgCount(3),
397-
//Handling case: thrust::sort_by_key(host.begin(), host.end(), host.begin());
398-
CONDITIONAL_FACTORY_ENTRY(
399-
CheckThrustArgType(1, "thrust::device_ptr"),
400-
CALL_FACTORY_ENTRY("thrust::sort_by_key",
401-
CALL(MapNames::getDpctNamespace() + "sort",
402-
CALL("oneapi::dpl::execution::make_device_policy", QUEUESTR),
403-
ARG(0), ARG(1), ARG(2))),
404-
CALL_FACTORY_ENTRY("thrust::sort_by_key",
405-
CALL(MapNames::getDpctNamespace() + "sort",
406-
ARG("oneapi::dpl::execution::seq"),
407-
ARG(0), ARG(1), ARG(2)))),
408-
CONDITIONAL_FACTORY_ENTRY(
409-
IsPolicyArgType(0),
410-
//Handling case: thrust::sort_by_key(policy, device.begin(), device.end(), device.begin());
411-
CALL_FACTORY_ENTRY("thrust::sort_by_key",
412-
CALL(MapNames::getDpctNamespace() + "sort",
413-
makeMappedThrustPolicyEnum(0),
414-
ARG(1), ARG(2), ARG(3))),
415-
//Handling case: thrust::sort_by_key(device.begin(), device.end(), device.begin(), comp);
416-
CONDITIONAL_FACTORY_ENTRY(
417-
CheckThrustArgType(1, "thrust::device_ptr"),
418-
CALL_FACTORY_ENTRY("thrust::sort_by_key",
419-
CALL(MapNames::getDpctNamespace() + "sort",
420-
CALL("oneapi::dpl::execution::make_device_policy", QUEUESTR),
421-
ARG(0), ARG(1), ARG(2), ARG(3))),
422-
CALL_FACTORY_ENTRY("thrust::sort_by_key",
423-
CALL(MapNames::getDpctNamespace() + "sort",
424-
ARG("oneapi::dpl::execution::seq"),
425-
ARG(0), ARG(1), ARG(2), ARG(3))))))
426-
))
427-
273+
thrustFactory("thrust::sort_by_key",
274+
{{5,PolicyState::HasPolicy,3,MapNames::getDpctNamespace() + "sort", HelperFeatureEnum::DplExtrasAlgorithm_sort},
275+
{4,PolicyState::HasPolicy,3,MapNames::getDpctNamespace() + "sort", HelperFeatureEnum::DplExtrasAlgorithm_sort},
276+
{4,PolicyState::NoPolicy ,3,MapNames::getDpctNamespace() + "sort", HelperFeatureEnum::DplExtrasAlgorithm_sort},
277+
{3,PolicyState::NoPolicy ,3,MapNames::getDpctNamespace() + "sort", HelperFeatureEnum::DplExtrasAlgorithm_sort}}),
428278

429279
// thrust::inner_product
430280
FEATURE_REQUEST_FACTORY(HelperFeatureEnum::DplExtrasNumeric_inner_product,

clang/test/dpct/test_api_level/DplExtrasAlgorithm/api_test10.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
// RUN: FileCheck --input-file %T/DplExtrasAlgorithm/api_test10_out/count.txt --match-full-lines %s
66
// RUN: rm -rf %T/DplExtrasAlgorithm/api_test10_out
77

8-
// CHECK: 5
8+
// CHECK: 38
99
// TEST_FEATURE: DplExtrasAlgorithm_sort
1010

1111
#include <thrust/sort.h>

0 commit comments

Comments
 (0)