Skip to content

Commit 0d3292e

Browse files
authored
Merge pull request #546 from abergeron/test_elem_scalar
Test for the fix in #545
2 parents c228dc9 + c08f33c commit 0d3292e

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed

tests/check_elemwise.c

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,26 @@
66
#include "gpuarray/error.h"
77
#include "gpuarray/types.h"
88

9+
#if CHECK_MINOR_VERSION < 11
10+
11+
#ifndef CK_FLOATING_DIG
12+
# define CK_FLOATING_DIG 6
13+
#endif /* CK_FLOATING_DIG */
14+
15+
#define _ck_assert_floating(X, OP, Y, TP, TM) do { \
16+
TP _ck_x = (X); \
17+
TP _ck_y = (Y); \
18+
ck_assert_msg(_ck_x OP _ck_y, \
19+
"Assertion '%s' failed: %s == %.*"TM"g, %s == %.*"TM"g", \
20+
#X" "#OP" "#Y, \
21+
#X, (int)CK_FLOATING_DIG, _ck_x, \
22+
#Y, (int)CK_FLOATING_DIG, _ck_y); \
23+
} while (0)
24+
25+
#define ck_assert_float_eq(X, Y) _ck_assert_floating(X, ==, Y, float, "")
26+
#endif
27+
28+
929
extern void *ctx;
1030

1131
void setup(void);
@@ -434,6 +454,60 @@ START_TEST(test_basic_scalar) {
434454
}
435455
END_TEST
436456

457+
START_TEST(test_basic_scalar_dtype) {
458+
GpuArray x;
459+
GpuArray y;
460+
float a = 1.1f;
461+
462+
GpuElemwise *ge;
463+
464+
static const int32_t data1[4] = {0, 1, 2, 3};
465+
static const float data2[4] = {2.0, 2.0, 2.0, 2.0};
466+
float data3[4];
467+
468+
size_t dims[2] = {2, 2};
469+
470+
gpuelemwise_arg args[3] = {{0}};
471+
void *rargs[3];
472+
473+
ga_assert_ok(GpuArray_empty(&x, ctx, GA_INT, 2, dims, GA_C_ORDER));
474+
ga_assert_ok(GpuArray_write(&x, data1, sizeof(data1)));
475+
476+
ga_assert_ok(GpuArray_empty(&y, ctx, GA_FLOAT, 2, dims, GA_F_ORDER));
477+
ga_assert_ok(GpuArray_write(&y, data2, sizeof(data2)));
478+
479+
args[0].name = "a";
480+
args[0].typecode = GA_FLOAT;
481+
args[0].flags = GE_SCALAR;
482+
483+
args[1].name = "x";
484+
args[1].typecode = GA_INT;
485+
args[1].flags = GE_READ;
486+
487+
args[2].name = "y";
488+
args[2].typecode = GA_FLOAT;
489+
args[2].flags = GE_READ|GE_WRITE;
490+
491+
ge = GpuElemwise_new(ctx, "", "y = a * x + y", 3, args, 2, 0);
492+
493+
ck_assert_ptr_ne(ge, NULL);
494+
495+
rargs[0] = &a;
496+
rargs[1] = &x;
497+
rargs[2] = &y;
498+
499+
ga_assert_ok(GpuElemwise_call(ge, rargs, 0));
500+
501+
ga_assert_ok(GpuArray_read(data3, sizeof(data3), &y));
502+
503+
ck_assert_float_eq(data3[0], 2.0f);
504+
ck_assert_float_eq(data3[1], 4.2f);
505+
506+
ck_assert_float_eq(data3[2], 3.1f);
507+
ck_assert_float_eq(data3[3], 5.3f);
508+
}
509+
END_TEST
510+
437511
START_TEST(test_basic_remove1) {
438512
GpuArray a;
439513
GpuArray b;
@@ -820,6 +894,7 @@ Suite *get_suite(void) {
820894
tcase_add_test(tc, test_basic_simple);
821895
tcase_add_test(tc, test_basic_f16);
822896
tcase_add_test(tc, test_basic_scalar);
897+
tcase_add_test(tc, test_basic_scalar_dtype);
823898
tcase_add_test(tc, test_basic_offset);
824899
tcase_add_test(tc, test_basic_remove1);
825900
tcase_add_test(tc, test_basic_broadcast);

0 commit comments

Comments
 (0)