From 621cbff42e99d9b2ba0f4ad85272714fc280de2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?T=C3=B6r=C3=B6k=20Edwin?= Date: Mon, 25 Mar 2024 16:32:39 +0000 Subject: [PATCH] Avoid calling `log(0)` when generating gaussian random variables (#662) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * owl_stats_ziggurat: avoid log(0) sfmt_f64_1 is documented to include 0, which would result in `log(0) = neg_infinity`. use sfmt_f64_3 instead which is documented to return `(0, 1)` instead. This should also match the paper, which says "UNI floats it to (0,1)": > Marsaglia, George, and Wai Wan Tsang. > "The ziggurat method for generating random variables." > Journal of statistical software 5 (2000): 1-7. See https://github.com/owlbarn/owl/issues/661 Signed-off-by: Edwin Török * Owl_base_stats_dist_uniform: add a function that returns a random float in (0,1) Signed-off-by: Edwin Török * Owl_base_stats_dist_gaussian.{std_gaussian_rvs,gaussian_rvs}: avoid infinity on Random.float returning 0 https://github.com/owlbarn/owl/issues/661 Signed-off-by: Edwin Török --------- Signed-off-by: Edwin Török --- src/base/stats/owl_base_stats_dist_gaussian.ml | 10 +++++----- src/base/stats/owl_base_stats_dist_uniform.ml | 5 +++++ src/owl/stats/owl_stats_ziggurat.c | 10 +++++----- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/base/stats/owl_base_stats_dist_gaussian.ml b/src/base/stats/owl_base_stats_dist_gaussian.ml index 75874d338..f447d4dfd 100644 --- a/src/base/stats/owl_base_stats_dist_gaussian.ml +++ b/src/base/stats/owl_base_stats_dist_gaussian.ml @@ -2,7 +2,7 @@ * OWL - OCaml Scientific Computing * Copyright (c) 2016-2022 Liang Wang *) - +open Owl_base_stats_dist_uniform let _u1 = ref 0. let _u2 = ref 0. @@ -20,8 +20,8 @@ let std_gaussian_rvs () = !_z1) else ( _case := true; - _u1 := Random.float 1.; - _u2 := Random.float 1.; + _u1 := rand01_exclusive (); + _u2 := rand01_exclusive (); _z0 := sqrt (~-.2. *. log !_u1) *. cos (2. *. Owl_const.pi *. !_u2); _z1 := sqrt (~-.2. *. log !_u1) *. sin (2. *. Owl_const.pi *. !_u2); !_z0) @@ -35,8 +35,8 @@ let gaussian_rvs ~mu ~sigma = mu +. (sigma *. !_z1)) else ( _case := true; - _u1 := Random.float 1.; - _u2 := Random.float 1.; + _u1 := rand01_exclusive (); + _u2 := rand01_exclusive (); _z0 := sqrt (~-.2. *. log !_u1) *. cos (2. *. Owl_const.pi *. !_u2); _z1 := sqrt (~-.2. *. log !_u1) *. sin (2. *. Owl_const.pi *. !_u2); mu +. (sigma *. !_z0)) diff --git a/src/base/stats/owl_base_stats_dist_uniform.ml b/src/base/stats/owl_base_stats_dist_uniform.ml index 2c55e15a6..5137534f3 100644 --- a/src/base/stats/owl_base_stats_dist_uniform.ml +++ b/src/base/stats/owl_base_stats_dist_uniform.ml @@ -9,3 +9,8 @@ let uniform_int_rvs n = Random.int n let std_uniform_rvs () = Random.float 1. let uniform_rvs ~a ~b = a +. ((b -. a) *. Random.float 1.) + +(* The constants below are Printf.printf "%h,%h" (Float.succ 0.) (Float.pred 1.) + Also [Float.succ 0. +. Float.pred 1. < 1.] + *) +let rand01_exclusive () = 0x0.0000000000001p-1022 +. Random.float 0x1.fffffffffffffp-1 diff --git a/src/owl/stats/owl_stats_ziggurat.c b/src/owl/stats/owl_stats_ziggurat.c index 259698979..23cd38d9e 100644 --- a/src/owl/stats/owl_stats_ziggurat.c +++ b/src/owl/stats/owl_stats_ziggurat.c @@ -22,13 +22,13 @@ inline double std_exponential_rvs ( ) { else { for ( ; ; ) { if ( iz == 0 ) { - value = 7.69711 - log ( sfmt_f64_1 ); + value = 7.69711 - log ( sfmt_f64_3 ); break; } x = jz * we[iz]; - if ( fe[iz] + sfmt_f64_1 * ( fe[iz-1] - fe[iz] ) < exp ( - x ) ) { + if ( fe[iz] + sfmt_f64_3 * ( fe[iz-1] - fe[iz] ) < exp ( - x ) ) { value = x; break; } @@ -92,8 +92,8 @@ inline double std_gaussian_rvs ( ) { for ( ; ; ) { if ( iz == 0 ) { for ( ; ; ) { - x = - 0.2904764 * log ( sfmt_f64_1 ); - y = - log ( sfmt_f64_1 ); + x = - 0.2904764 * log ( sfmt_f64_3 ); + y = - log ( sfmt_f64_3 ); if ( x * x <= y + y ) break; } @@ -103,7 +103,7 @@ inline double std_gaussian_rvs ( ) { x = hz * wn[iz]; - if ( fn[iz] + ( sfmt_f64_1 ) * ( fn[iz-1] - fn[iz] ) < exp ( - 0.5 * x * x ) ) { + if ( fn[iz] + ( sfmt_f64_3 ) * ( fn[iz-1] - fn[iz] ) < exp ( - 0.5 * x * x ) ) { value = x; break; }