Skip to content

Commit

Permalink
Begin deprecation of implicit input conversion in FFT module
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 committed Apr 18, 2024
1 parent 7cb0e60 commit f2de13b
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 22 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ Remember to align the itemized text with the first line of an item within a list
* The {func}`jax.numpy.hypot` function now issues a deprecation warning when
passing complex-valued inputs to it. This will raise an error when the
deprecation is completed.
* The {mod}`jax.numpy.fft` module now issues a deprecation warning when
passing inputs that would require implicit conversion (e.g. `jnp.float32`
with `fft`, `jnp.int32` with `rftt`). Either manually onvert the inputs
to preserve old behavior, or use a more appropriate `fft` function.

## jaxlib 0.4.27

Expand Down
54 changes: 52 additions & 2 deletions jax/_src/numpy/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,65 @@
from collections.abc import Sequence
import operator
import numpy as np
import warnings

from jax import dtypes
from jax import lax
from jax._src import dtypes
from jax._src.lib import xla_client
from jax._src.util import safe_zip
from jax._src.numpy.util import check_arraylike, implements, promote_dtypes_inexact
from jax._src.numpy.util import (
check_arraylike, implements,
promote_dtypes_inexact, promote_dtypes_complex)
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import ufuncs, reductions
from jax._src.typing import Array, ArrayLike

Shape = Sequence[int]

NEEDS_COMPLEX_IN = {'fft', 'fftn', 'hfft', 'ifft', 'ifftn', 'irfft', 'irfftn'}
NEEDS_REAL_IN = {
# These are already handled in lax.fft
# 'rfft', 'rfftn', 'ihfft',
'fftshift', 'ifftshift'
}

# TODO(micky774): Promote warnings to ValueErrors when deprecation is completed
# and uncomment the portion of NEEDS_REAL_IN which currently defers type
# checking to lax.fft. Deprecation began 4-18-24.
def _check_input_fft(func_name: str, x: Array):
kind = x.dtype.kind
suggest_alternative_msg = (
" or consider using a more appropriate fft function if applicable."
)
if func_name in NEEDS_COMPLEX_IN and kind != "c":
warnings.warn(
f"Passing non-complex valued inputs to {func_name} is deprecated and "
"will soon raise a ValueError. Please explicitly convert the input to a "
f"complex dtype before passing to {func_name} in order to suppress this "
"warning," + suggest_alternative_msg,
DeprecationWarning, stacklevel=2
)
return promote_dtypes_complex(x)[0]
if func_name in NEEDS_REAL_IN:
if kind == "c":
warnings.warn(
f"Passing complex-valued inputs to {func_name} is deprecated and "
"will soon raise a ValueError. To suppress this warning, please convert "
"to real values first, such as by using jnp.real or jnp.imag to take "
"the real or imaginary components respectively," + suggest_alternative_msg,
DeprecationWarning, stacklevel=2
)
elif kind != "f":
warnings.warn(
f"Passing integral inputs to {func_name} is deprecated and "
"will soon raise a ValueError. Please convert to a real-valued "
"floating-point input first.",
DeprecationWarning, stacklevel=2
)
return promote_dtypes_inexact(x)
return x


def _fft_norm(s: Array, func_name: str, norm: str) -> Array:
if norm == "backward":
return jnp.array(1)
Expand All @@ -50,6 +97,7 @@ def _fft_core(func_name: str, fft_type: xla_client.FftType, a: ArrayLike,
full_name = f"jax.numpy.fft.{func_name}"
check_arraylike(full_name, a)
arr = jnp.asarray(a)
arr = _check_input_fft(func_name, arr)

if s is not None:
s = tuple(map(operator.index, s))
Expand Down Expand Up @@ -300,6 +348,7 @@ def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array:
def fftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array:
check_arraylike("fftshift", x)
x = jnp.asarray(x)
arr = _check_input_fft("fftshift", x)
shift: int | Sequence[int]
if axes is None:
axes = tuple(range(x.ndim))
Expand All @@ -316,6 +365,7 @@ def fftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array:
def ifftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array:
check_arraylike("ifftshift", x)
x = jnp.asarray(x)
arr = _check_input_fft("ifftshift", x)
shift: int | Sequence[int]
if axes is None:
axes = tuple(range(x.ndim))
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/scipy/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def dct(x: Array, type: int = 2, n: int | None = None,
for a in range(x.ndim)])

N = x.shape[axis]
v = _dct_interleave(x, axis)
v, = promote_dtypes_complex(_dct_interleave(x, axis))
V = jnp.fft.fft(v, axis=axis)
k = lax.expand_dims(jnp.arange(N, dtype=V.real.dtype), [a for a in range(x.ndim) if a != axis])
out = V * _W4(N, k)
Expand All @@ -68,7 +68,7 @@ def dct(x: Array, type: int = 2, n: int | None = None,
def _dct2(x: Array, axes: Sequence[int], norm: str | None) -> Array:
axis1, axis2 = map(partial(canonicalize_axis, num_dims=x.ndim), axes)
N1, N2 = x.shape[axis1], x.shape[axis2]
v = _dct_interleave(_dct_interleave(x, axis1), axis2)
v, = promote_dtypes_complex(_dct_interleave(_dct_interleave(x, axis1), axis2))
V = jnp.fft.fftn(v, axes=axes)
k1 = lax.expand_dims(jnp.arange(N1, dtype=V.dtype),
[a for a in range(x.ndim) if a != axis1])
Expand Down
47 changes: 46 additions & 1 deletion jax/experimental/array_api/_fft_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,46 +13,89 @@
# limitations under the License.

import jax.numpy as jnp

from jax._src.numpy.fft import NEEDS_COMPLEX_IN, NEEDS_REAL_IN as _NEEDS_REAL_IN
from jax._src.numpy.util import check_arraylike

NEEDS_REAL_IN = _NEEDS_REAL_IN.union({'rfft', 'rfftn', 'ihfft'})

# TODO(micky774): Remove when jax.numpy.fft deprecation completes. Deprecation
# began 4-18-24.
def _check_input_fft(func_name: str, x):
check_arraylike('jax.experimental.array_api.' + func_name, x)
arr = jnp.asarray(x)
kind = arr.dtype.kind
suggest_alternative_msg = (
" or consider using a more appropriate fft function if applicable."
)
if func_name in NEEDS_COMPLEX_IN and kind != "c":
raise ValueError(
f"{func_name} requires complex-valued input, but received input with type "
f"{arr.dtype} instead. Please explicitly convert to a complex-valued input "
"first," + suggest_alternative_msg,
)
if func_name in NEEDS_REAL_IN:
needs_real_msg = (
f"{func_name} requires real-valued floating-point input, but received "
f"input with type {arr.dtype} instead. Please convert to a real-valued "
"floating-point input first"
)
if kind == "c":
raise ValueError(
needs_real_msg + ", such as by using jnp.real or jnp.imag to take the "
"real or imaginary components respectively," + suggest_alternative_msg,
)
elif kind != "f":
raise ValueError(needs_real_msg + '.')
return arr

def fft(x, /, *, n=None, axis=-1, norm='backward'):
"""Computes the one-dimensional discrete Fourier transform."""
_check_input_fft('fft', x)
return jnp.fft.fft(x, n=n, axis=axis, norm=norm)

def ifft(x, /, *, n=None, axis=-1, norm='backward'):
"""Computes the one-dimensional inverse discrete Fourier transform."""
_check_input_fft('ifft', x)
return jnp.fft.ifft(x, n=n, axis=axis, norm=norm)

def fftn(x, /, *, s=None, axes=None, norm='backward'):
"""Computes the n-dimensional discrete Fourier transform."""
_check_input_fft('fftn', x)
return jnp.fft.fftn(x, s=s, axes=axes, norm=norm)

def ifftn(x, /, *, s=None, axes=None, norm='backward'):
"""Computes the n-dimensional inverse discrete Fourier transform."""
_check_input_fft('ifftn', x)
return jnp.fft.ifftn(x, s=s, axes=axes, norm=norm)

def rfft(x, /, *, n=None, axis=-1, norm='backward'):
"""Computes the one-dimensional discrete Fourier transform for real-valued input."""
_check_input_fft('rfft', x)
return jnp.fft.rfft(x, n=n, axis=axis, norm=norm)

def irfft(x, /, *, n=None, axis=-1, norm='backward'):
"""Computes the one-dimensional inverse of rfft for complex-valued input."""
_check_input_fft('irfft', x)
return jnp.fft.irfft(x, n=n, axis=axis, norm=norm)

def rfftn(x, /, *, s=None, axes=None, norm='backward'):
"""Computes the n-dimensional discrete Fourier transform for real-valued input."""
_check_input_fft('rfftn', x)
return jnp.fft.rfftn(x, s=s, axes=axes, norm=norm)

def irfftn(x, /, *, s=None, axes=None, norm='backward'):
"""Computes the n-dimensional inverse of rfftn for complex-valued input."""
_check_input_fft('irfftn', x)
return jnp.fft.irfftn(x, s=s, axes=axes, norm=norm)

def hfft(x, /, *, n=None, axis=-1, norm='backward'):
"""Computes the one-dimensional discrete Fourier transform of a signal with Hermitian symmetry."""
_check_input_fft('hfft', x)
return jnp.fft.hfft(x, n=n, axis=axis, norm=norm)

def ihfft(x, /, *, n=None, axis=-1, norm='backward'):
"""Computes the one-dimensional inverse discrete Fourier transform of a signal with Hermitian symmetry."""
_check_input_fft('ihfft', x)
return jnp.fft.ihfft(x, n=n, axis=axis, norm=norm)

def fftfreq(n, /, *, d=1.0, device=None):
Expand All @@ -65,8 +108,10 @@ def rfftfreq(n, /, *, d=1.0, device=None):

def fftshift(x, /, *, axes=None):
"""Shift the zero-frequency component to the center of the spectrum."""
_check_input_fft('fftshift', x)
return jnp.fft.fftshift(x, axes=axes)

def ifftshift(x, /, *, axes=None):
"""Inverse of fftshift."""
_check_input_fft('ifftshift', x)
return jnp.fft.ifftshift(x, axes=axes)
30 changes: 28 additions & 2 deletions tests/array_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from jax._src import config, test_util as jtu
from jax._src.dtypes import _default_types, canonicalize_dtype
from jax.experimental import array_api

from jax.experimental.array_api._fft_functions import (
NEEDS_COMPLEX_IN, NEEDS_REAL_IN)
config.parse_flags_with_absl()

MAIN_NAMESPACE = {
Expand Down Expand Up @@ -326,7 +327,7 @@ def test_dtypes_info(self, kind):
target_dict = control[kind]
assert info_dict == target_dict

class ArrayAPIErrors(absltest.TestCase):
class ArrayAPIErrors(jtu.JaxTestCase):
"""Test that our array API implementations raise errors where required"""

# TODO(micky774): Remove when jnp.clip deprecation is completed
Expand All @@ -347,6 +348,31 @@ def test_clip_complex(self):
with self.assertRaisesRegex(ValueError, complex_msg):
array_api.clip(x, max=-1+5j)

@jtu.sample_product(
[dict(dtype=dtype,func_name=func_name)
for real in [True, False]
for dtype in (jtu.dtypes.complex if real else jtu.dtypes.floating)
+ jtu.dtypes.integer + jtu.dtypes.boolean
for func_name in (NEEDS_REAL_IN if real else NEEDS_COMPLEX_IN)
])
def testFftWarnings(self, dtype, func_name):
shape = (2, 3, 4)
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
func = getattr(array_api.fft, func_name)

if func_name in NEEDS_COMPLEX_IN:
msg = "complex-valued input"
else:
msg = "real-valued"
if x.dtype.kind == 'c':
msg += ".*real or imaginary"
if x.dtype.kind in {'c', 'r'}:
msg += ".*or consider using a more"

with self.assertRaisesRegex(ValueError, expected_regex=msg):
func(x)


if __name__ == '__main__':
absltest.main()
Loading

0 comments on commit f2de13b

Please sign in to comment.