Skip to content

Commit

Permalink
Merge pull request #646 from patrick-nicodemus/ssqr_diff_fix
Browse files Browse the repository at this point in the history
Changed def of ssqr_diff' to not modify inputs. Added two tests.
  • Loading branch information
jzstark committed Mar 7, 2024
2 parents 16d9bd5 + c4fb941 commit d6d6bf4
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 9 deletions.
16 changes: 8 additions & 8 deletions src/owl/core/owl_ndarray_maths_stub.c
Original file line number Diff line number Diff line change
Expand Up @@ -4992,34 +4992,34 @@
// ssqr_diff

#define FUN11 float32_ssqr_diff
#define INIT float r = 0.
#define INIT float r = 0. ; float diff
#define NUMBER float
#define NUMBER1 float
#define ACCFN(A,X,Y) X -= Y; X *= X; A += X
#define ACCFN(A,X,Y) diff=X-Y; diff*=diff; A+=diff
#define COPYNUM(A) (caml_copy_double(A))
#include OWL_NDARRAY_MATHS_FOLD

#define FUN11 float64_ssqr_diff
#define INIT double r = 0.
#define INIT double r = 0. ; double diff
#define NUMBER double
#define NUMBER1 double
#define ACCFN(A,X,Y) X -= Y; X *= X; A += X
#define ACCFN(A,X,Y) diff=X-Y; diff*=diff; A+=diff
#define COPYNUM(A) (caml_copy_double(A))
#include OWL_NDARRAY_MATHS_FOLD

#define FUN11 complex32_ssqr_diff
#define INIT complex_float r = { 0.0, 0.0 }
#define INIT complex_float r = { 0.0, 0.0 }; complex_float diff
#define NUMBER complex_float
#define NUMBER1 complex_float
#define ACCFN(A,X,Y) X.r -= Y.r; X.i -= Y.i; A.r += (X.r - X.i) * (X.r + X.i); A.i += 2 * A.r * A.i
#define ACCFN(A,X,Y) diff.r = X.r - Y.r; diff.i = X.i - Y.i; A.r += (diff.r - diff.i) * (diff.r + diff.i); A.i += 2 * A.r * A.i
#define COPYNUM(A) (cp_two_doubles(A.r, A.i))
#include OWL_NDARRAY_MATHS_FOLD

#define FUN11 complex64_ssqr_diff
#define INIT complex_double r = { 0.0, 0.0 }
#define INIT complex_double r = { 0.0, 0.0 }; complex_double diff
#define NUMBER complex_double
#define NUMBER1 complex_double
#define ACCFN(A,X,Y) X.r -= Y.r; X.i -= Y.i; A.r += (X.r - X.i) * (X.r + X.i); A.i += 2 * A.r * A.i
#define ACCFN(A,X,Y) diff.r = X.r - Y.r; diff.i = X.i - Y.i; A.r += (diff.r - diff.i) * (diff.r + diff.i); A.i += 2 * A.r * A.i
#define COPYNUM(A) (cp_two_doubles(A.r, A.i))
#include OWL_NDARRAY_MATHS_FOLD

Expand Down
24 changes: 23 additions & 1 deletion test/unit_dense_ndarray.ml
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,24 @@ module To_test = struct
let sum_reduce () =
M.sum_reduce ~axis:[| 0; 2 |] x4 = M.of_array Float64 [| 8.; 8.; 8. |] [| 1; 3; 1 |]


let ssqr_diff32 () =
let a = M.of_array Float32 [| 3.; 4.; 5.; |] [| 1; 3 |] in
let a' = M.copy a in
let b = M.of_array Float32 [| 1.; 2.; 3.; |] [| 1; 3 |] in
let b' = M.copy b in
let ssqrdiff = M.ssqr_diff' a b in
ssqrdiff = 12. && a = a' && b = b'

let ssqr_diff64 () =
let a = M.of_array Float64 [| 3.; 4.; 5.; |] [| 1; 3 |] in
let a' = M.copy a in
let b = M.of_array Float64 [| 1.; 2.; 3.; |] [| 1; 3 |] in
let b' = M.copy b in
let ssqrdiff = M.ssqr_diff' a b in
ssqrdiff = 12. && a = a' && b = b'



let min' () = M.min' x0 = 0.

let max' () = M.max' x0 = 3.
Expand Down Expand Up @@ -530,6 +547,10 @@ let sort1 () = Alcotest.(check bool) "sort1" true (To_test.sort1 ())

let sum_reduce () = Alcotest.(check bool) "sum_reduce" true (To_test.sum_reduce ())

let ssqr_diff32 () = Alcotest.(check bool) "ssqr_diff32" true (To_test.ssqr_diff32 ())

let ssqr_diff64 () = Alcotest.(check bool) "ssqr_diff64" true (To_test.ssqr_diff64 ())

let min' () = Alcotest.(check bool) "min'" true (To_test.min' ())

let max' () = Alcotest.(check bool) "max'" true (To_test.max' ())
Expand Down Expand Up @@ -674,6 +695,7 @@ let test_set =
; "mul", `Slow, mul; "add_scalar", `Slow, add_scalar; "mul_scalar", `Slow, mul_scalar
; "abs", `Slow, abs; "neg", `Slow, neg; "sum'", `Slow, sum'; "median'", `Slow, median'
; "median", `Slow, median; "sort1", `Slow, sort1; "sum_reduce", `Slow, sum_reduce
; "ssqr_diff32", `Slow, ssqr_diff32 ; "ssqr_diff64", `Slow, ssqr_diff64
; "min'", `Slow, min'; "max'", `Slow, max'; "minmax_i", `Slow, minmax_i
; "init_nd", `Slow, init_nd; "is_zero", `Slow, is_zero
; "is_positive", `Slow, is_positive; "is_negative", `Slow, is_negative
Expand Down

0 comments on commit d6d6bf4

Please sign in to comment.