From 55f8284e27f970aaddcec6e2b82420c24adab7dc Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 16 May 2024 15:39:31 +0000 Subject: [PATCH] Added correction arg in jnp.var and jnp.std Description: - Added correction arg in jnp.var and jnp.std - Addresses https://github.com/google/jax/issues/21088 - Updated signatures in init.pyi - Updated tests --- jax/_src/numpy/reductions.py | 24 ++++++++++++------- .../array_api/_statistical_functions.py | 7 +++--- jax/numpy/__init__.pyi | 4 ++-- tests/lax_numpy_reducers_test.py | 23 +++++++++++++----- tests/lax_numpy_test.py | 4 ++-- 5 files changed, 40 insertions(+), 22 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index a6f4e1f7b8e8..84fae48ac2cf 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -433,13 +433,17 @@ def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, @implements(np.var, skip_params=['out']) def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, - where: ArrayLike | None = None) -> Array: - return _var(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims, + where: ArrayLike | None = None, correction: int | float | None = None) -> Array: + if correction is None: + correction = ddof + elif not isinstance(ddof, int) or ddof != 0: + raise ValueError("ddof and correction can't be provided simultaneously.") + return _var(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims, where=where) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, - out: None = None, ddof: int = 0, keepdims: bool = False, *, + out: None = None, correction: int | float = 0, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: check_arraylike("var", a) dtypes.check_user_dtype_supported(dtype, "var") @@ -465,7 +469,7 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, else: normalizer = sum(_broadcast_to(where, np.shape(a)), axis, dtype=computation_dtype, keepdims=keepdims) - normalizer = lax.sub(normalizer, lax.convert_element_type(ddof, computation_dtype)) + normalizer = lax.sub(normalizer, lax.convert_element_type(correction, computation_dtype)) result = sum(centered, axis, dtype=computation_dtype, keepdims=keepdims, where=where) return lax.div(result, normalizer).astype(dtype) @@ -494,13 +498,17 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy @implements(np.std, skip_params=['out']) def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, - where: ArrayLike | None = None) -> Array: - return _std(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims, + where: ArrayLike | None = None, correction: int | float | None = None) -> Array: + if correction is None: + correction = ddof + elif not isinstance(ddof, int) or ddof != 0: + raise ValueError("ddof and correction can't be provided simultaneously.") + return _std(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims, where=where) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, - out: None = None, ddof: int = 0, keepdims: bool = False, *, + out: None = None, correction: int | float = 0, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: check_arraylike("std", a) dtypes.check_user_dtype_supported(dtype, "std") @@ -508,7 +516,7 @@ def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, raise ValueError(f"dtype argument to jnp.std must be inexact; got {dtype}") if out is not None: raise NotImplementedError("The 'out' argument to jnp.std is not supported.") - return lax.sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where)) + return lax.sqrt(var(a, axis=axis, dtype=dtype, correction=correction, keepdims=keepdims, where=where)) @implements(np.ptp, skip_params=['out']) diff --git a/jax/experimental/array_api/_statistical_functions.py b/jax/experimental/array_api/_statistical_functions.py index c34fb1fc3af4..8ee6a39198ee 100644 --- a/jax/experimental/array_api/_statistical_functions.py +++ b/jax/experimental/array_api/_statistical_functions.py @@ -14,13 +14,12 @@ import jax -# TODO(micky774): Remove after deprecating ddof-->correction in jnp.std and -# jnp.var + def std(x, /, *, axis=None, correction=0.0, keepdims=False): """Calculates the standard deviation of the input array x.""" - return jax.numpy.std(x, axis=axis, ddof=correction, keepdims=keepdims) + return jax.numpy.std(x, axis=axis, correction=correction, keepdims=keepdims) def var(x, /, *, axis=None, correction=0.0, keepdims=False): """Calculates the variance of the input array x.""" - return jax.numpy.var(x, axis=axis, ddof=correction, keepdims=keepdims) + return jax.numpy.var(x, axis=axis, correction=correction, keepdims=keepdims) diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 36a6347d35ba..23a21021e5ef 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -783,7 +783,7 @@ def stack( ) -> Array: ... def std(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *, - where: Optional[ArrayLike] = ...) -> Array: ... + where: Optional[ArrayLike] = ..., correction: int | float | None = ...) -> Array: ... def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: ... def sum( a: ArrayLike, @@ -894,7 +894,7 @@ def vander( ) -> Array: ... def var(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *, - where: Optional[ArrayLike] = ...) -> Array: ... + where: Optional[ArrayLike] = ..., correction: int | float | None = ...) -> Array: ... def vdot( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ..., preferred_element_type: Optional[DTypeLike] = ...) -> Array: ... diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 861b9014c589..975393767af9 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -540,31 +540,42 @@ def testAverage(self, shape, dtype, axis, weights_shape, returned, keepdims): rtol=tol, atol=tol) @jtu.sample_product( + test_fns=[(np.var, jnp.var), (np.std, jnp.std)], shape=[(5,), (10, 5)], dtype=all_dtypes, out_dtype=inexact_dtypes, axis=[None, 0, -1], - ddof=[0, 1, 2], + ddof_correction=[(0, None), (1, None), (1, 0), (0, 0), (0, 1), (0, 2)], keepdims=[False, True], ) - def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims): + def testStdOrVar(self, test_fns, shape, dtype, out_dtype, axis, ddof_correction, keepdims): + np_fn, jnp_fn = test_fns + ddof, correction = ddof_correction rng = jtu.rand_default(self.rng()) args_maker = self._GetArgsMaker(rng, [shape], [dtype]) @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.") @jtu.ignore_warning(category=NumpyComplexWarning) def np_fun(x): + # setup ddof and correction kwargs excluding case when correction is not specified + ddof_correction_kwargs = {"ddof": ddof} + if correction is not None: + key = "correction" if numpy_version >= (2, 0) else "ddof" + ddof_correction_kwargs[key] = correction # Numpy fails with bfloat16 inputs - out = np.var(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype), + out = np_fn(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype), dtype=np.float32 if out_dtype == dtypes.bfloat16 else out_dtype, - axis=axis, ddof=ddof, keepdims=keepdims) + axis=axis, keepdims=keepdims, **ddof_correction_kwargs) return out.astype(out_dtype) - jnp_fun = partial(jnp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims) + jnp_fun = partial(jnp_fn, dtype=out_dtype, axis=axis, ddof=ddof, correction=correction, + keepdims=keepdims) tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3, np.float64: 1e-3, np.complex128: 1e-6}) if (jnp.issubdtype(dtype, jnp.complexfloating) and not jnp.issubdtype(out_dtype, jnp.complexfloating)): - self.assertRaises(ValueError, lambda: jnp_fun(*args_maker())) + self.assertRaises(ValueError, jnp_fun, *args_maker()) + elif (correction is not None and ddof != 0): + self.assertRaises(ValueError, jnp_fun, *args_maker()) else: self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 0a441d2e513a..f46933c169c9 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -5960,9 +5960,9 @@ def testWrappedSignaturesMatch(self): 'reshape': ['shape', 'copy'], 'row_stack': ['casting'], 'stack': ['casting'], - 'std': ['correction', 'mean'], + 'std': ['mean'], 'tri': ['like'], - 'var': ['correction', 'mean'], + 'var': ['mean'], 'vstack': ['casting'], 'zeros_like': ['subok', 'order'] }