Skip to content

Commit d6d6bf4

Browse files
authored
Merge pull request #646 from patrick-nicodemus/ssqr_diff_fix
Changed def of ssqr_diff' to not modify inputs. Added two tests.
2 parents 16d9bd5 + c4fb941 commit d6d6bf4

File tree

2 files changed

+31
-9
lines changed

2 files changed

+31
-9
lines changed

src/owl/core/owl_ndarray_maths_stub.c

+8-8
Original file line numberDiff line numberDiff line change
@@ -4992,34 +4992,34 @@
49924992
// ssqr_diff
49934993

49944994
#define FUN11 float32_ssqr_diff
4995-
#define INIT float r = 0.
4995+
#define INIT float r = 0. ; float diff
49964996
#define NUMBER float
49974997
#define NUMBER1 float
4998-
#define ACCFN(A,X,Y) X -= Y; X *= X; A += X
4998+
#define ACCFN(A,X,Y) diff=X-Y; diff*=diff; A+=diff
49994999
#define COPYNUM(A) (caml_copy_double(A))
50005000
#include OWL_NDARRAY_MATHS_FOLD
50015001

50025002
#define FUN11 float64_ssqr_diff
5003-
#define INIT double r = 0.
5003+
#define INIT double r = 0. ; double diff
50045004
#define NUMBER double
50055005
#define NUMBER1 double
5006-
#define ACCFN(A,X,Y) X -= Y; X *= X; A += X
5006+
#define ACCFN(A,X,Y) diff=X-Y; diff*=diff; A+=diff
50075007
#define COPYNUM(A) (caml_copy_double(A))
50085008
#include OWL_NDARRAY_MATHS_FOLD
50095009

50105010
#define FUN11 complex32_ssqr_diff
5011-
#define INIT complex_float r = { 0.0, 0.0 }
5011+
#define INIT complex_float r = { 0.0, 0.0 }; complex_float diff
50125012
#define NUMBER complex_float
50135013
#define NUMBER1 complex_float
5014-
#define ACCFN(A,X,Y) X.r -= Y.r; X.i -= Y.i; A.r += (X.r - X.i) * (X.r + X.i); A.i += 2 * A.r * A.i
5014+
#define ACCFN(A,X,Y) diff.r = X.r - Y.r; diff.i = X.i - Y.i; A.r += (diff.r - diff.i) * (diff.r + diff.i); A.i += 2 * A.r * A.i
50155015
#define COPYNUM(A) (cp_two_doubles(A.r, A.i))
50165016
#include OWL_NDARRAY_MATHS_FOLD
50175017

50185018
#define FUN11 complex64_ssqr_diff
5019-
#define INIT complex_double r = { 0.0, 0.0 }
5019+
#define INIT complex_double r = { 0.0, 0.0 }; complex_double diff
50205020
#define NUMBER complex_double
50215021
#define NUMBER1 complex_double
5022-
#define ACCFN(A,X,Y) X.r -= Y.r; X.i -= Y.i; A.r += (X.r - X.i) * (X.r + X.i); A.i += 2 * A.r * A.i
5022+
#define ACCFN(A,X,Y) diff.r = X.r - Y.r; diff.i = X.i - Y.i; A.r += (diff.r - diff.i) * (diff.r + diff.i); A.i += 2 * A.r * A.i
50235023
#define COPYNUM(A) (cp_two_doubles(A.r, A.i))
50245024
#include OWL_NDARRAY_MATHS_FOLD
50255025

test/unit_dense_ndarray.ml

+23-1
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,24 @@ module To_test = struct
132132
let sum_reduce () =
133133
M.sum_reduce ~axis:[| 0; 2 |] x4 = M.of_array Float64 [| 8.; 8.; 8. |] [| 1; 3; 1 |]
134134

135-
135+
let ssqr_diff32 () =
136+
let a = M.of_array Float32 [| 3.; 4.; 5.; |] [| 1; 3 |] in
137+
let a' = M.copy a in
138+
let b = M.of_array Float32 [| 1.; 2.; 3.; |] [| 1; 3 |] in
139+
let b' = M.copy b in
140+
let ssqrdiff = M.ssqr_diff' a b in
141+
ssqrdiff = 12. && a = a' && b = b'
142+
143+
let ssqr_diff64 () =
144+
let a = M.of_array Float64 [| 3.; 4.; 5.; |] [| 1; 3 |] in
145+
let a' = M.copy a in
146+
let b = M.of_array Float64 [| 1.; 2.; 3.; |] [| 1; 3 |] in
147+
let b' = M.copy b in
148+
let ssqrdiff = M.ssqr_diff' a b in
149+
ssqrdiff = 12. && a = a' && b = b'
150+
151+
152+
136153
let min' () = M.min' x0 = 0.
137154

138155
let max' () = M.max' x0 = 3.
@@ -530,6 +547,10 @@ let sort1 () = Alcotest.(check bool) "sort1" true (To_test.sort1 ())
530547

531548
let sum_reduce () = Alcotest.(check bool) "sum_reduce" true (To_test.sum_reduce ())
532549

550+
let ssqr_diff32 () = Alcotest.(check bool) "ssqr_diff32" true (To_test.ssqr_diff32 ())
551+
552+
let ssqr_diff64 () = Alcotest.(check bool) "ssqr_diff64" true (To_test.ssqr_diff64 ())
553+
533554
let min' () = Alcotest.(check bool) "min'" true (To_test.min' ())
534555

535556
let max' () = Alcotest.(check bool) "max'" true (To_test.max' ())
@@ -674,6 +695,7 @@ let test_set =
674695
; "mul", `Slow, mul; "add_scalar", `Slow, add_scalar; "mul_scalar", `Slow, mul_scalar
675696
; "abs", `Slow, abs; "neg", `Slow, neg; "sum'", `Slow, sum'; "median'", `Slow, median'
676697
; "median", `Slow, median; "sort1", `Slow, sort1; "sum_reduce", `Slow, sum_reduce
698+
; "ssqr_diff32", `Slow, ssqr_diff32 ; "ssqr_diff64", `Slow, ssqr_diff64
677699
; "min'", `Slow, min'; "max'", `Slow, max'; "minmax_i", `Slow, minmax_i
678700
; "init_nd", `Slow, init_nd; "is_zero", `Slow, is_zero
679701
; "is_positive", `Slow, is_positive; "is_negative", `Slow, is_negative

0 commit comments

Comments
 (0)