diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py index 03e468fa99a9..3da6c19745ff 100644 --- a/jax/_src/numpy/fft.py +++ b/jax/_src/numpy/fft.py @@ -20,12 +20,13 @@ from jax import dtypes from jax import lax -from jax._src.lib import xla_client +from jax.sharding import Sharding +from jax._src.lib import xla_client as xc from jax._src.util import safe_zip from jax._src.numpy.util import check_arraylike, implements, promote_dtypes_inexact from jax._src.numpy import lax_numpy as jnp from jax._src.numpy import ufuncs, reductions -from jax._src.typing import Array, ArrayLike +from jax._src.typing import Array, ArrayLike, DTypeLike Shape = Sequence[int] @@ -44,7 +45,7 @@ def _fft_norm(s: Array, func_name: str, norm: str) -> Array: '"ortho" or "forward".') -def _fft_core(func_name: str, fft_type: xla_client.FftType, a: ArrayLike, +def _fft_core(func_name: str, fft_type: xc.FftType, a: ArrayLike, s: Shape | None, axes: Sequence[int] | None, norm: str | None) -> Array: full_name = f"jax.numpy.fft.{func_name}" @@ -86,14 +87,14 @@ def _fft_core(func_name: str, fft_type: xla_client.FftType, a: ArrayLike, in_s = list(arr.shape) for axis, x in safe_zip(axes, s): in_s[axis] = x - if fft_type == xla_client.FftType.IRFFT: + if fft_type == xc.FftType.IRFFT: in_s[-1] = (in_s[-1] // 2 + 1) # Cropping arr = arr[tuple(map(slice, in_s))] # Padding arr = jnp.pad(arr, [(0, x-y) for x, y in zip(in_s, arr.shape)]) else: - if fft_type == xla_client.FftType.IRFFT: + if fft_type == xc.FftType.IRFFT: s = [arr.shape[axis] for axis in axes[:-1]] if axes: s += [max(0, 2 * (arr.shape[axes[-1]] - 1))] @@ -113,28 +114,28 @@ def _fft_core(func_name: str, fft_type: xla_client.FftType, a: ArrayLike, def fftn(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] | None = None, norm: str | None = None) -> Array: - return _fft_core('fftn', xla_client.FftType.FFT, a, s, axes, norm) + return _fft_core('fftn', xc.FftType.FFT, a, s, axes, norm) @implements(np.fft.ifftn) def ifftn(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] | None = None, norm: str | None = None) -> Array: - return _fft_core('ifftn', xla_client.FftType.IFFT, a, s, axes, norm) + return _fft_core('ifftn', xc.FftType.IFFT, a, s, axes, norm) @implements(np.fft.rfftn) def rfftn(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] | None = None, norm: str | None = None) -> Array: - return _fft_core('rfftn', xla_client.FftType.RFFT, a, s, axes, norm) + return _fft_core('rfftn', xc.FftType.RFFT, a, s, axes, norm) @implements(np.fft.irfftn) def irfftn(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] | None = None, norm: str | None = None) -> Array: - return _fft_core('irfftn', xla_client.FftType.IRFFT, a, s, axes, norm) + return _fft_core('irfftn', xc.FftType.IRFFT, a, s, axes, norm) def _axis_check_1d(func_name: str, axis: int | None): @@ -145,7 +146,7 @@ def _axis_check_1d(func_name: str, axis: int | None): "Got axis = %r." % (full_name, full_name, axis) ) -def _fft_core_1d(func_name: str, fft_type: xla_client.FftType, +def _fft_core_1d(func_name: str, fft_type: xc.FftType, a: ArrayLike, n: int | None, axis: int | None, norm: str | None) -> Array: _axis_check_1d(func_name, axis) @@ -157,25 +158,25 @@ def _fft_core_1d(func_name: str, fft_type: xla_client.FftType, @implements(np.fft.fft) def fft(a: ArrayLike, n: int | None = None, axis: int = -1, norm: str | None = None) -> Array: - return _fft_core_1d('fft', xla_client.FftType.FFT, a, n=n, axis=axis, + return _fft_core_1d('fft', xc.FftType.FFT, a, n=n, axis=axis, norm=norm) @implements(np.fft.ifft) def ifft(a: ArrayLike, n: int | None = None, axis: int = -1, norm: str | None = None) -> Array: - return _fft_core_1d('ifft', xla_client.FftType.IFFT, a, n=n, axis=axis, + return _fft_core_1d('ifft', xc.FftType.IFFT, a, n=n, axis=axis, norm=norm) @implements(np.fft.rfft) def rfft(a: ArrayLike, n: int | None = None, axis: int = -1, norm: str | None = None) -> Array: - return _fft_core_1d('rfft', xla_client.FftType.RFFT, a, n=n, axis=axis, + return _fft_core_1d('rfft', xc.FftType.RFFT, a, n=n, axis=axis, norm=norm) @implements(np.fft.irfft) def irfft(a: ArrayLike, n: int | None = None, axis: int = -1, norm: str | None = None) -> Array: - return _fft_core_1d('irfft', xla_client.FftType.IRFFT, a, n=n, axis=axis, + return _fft_core_1d('irfft', xc.FftType.IRFFT, a, n=n, axis=axis, norm=norm) @implements(np.fft.hfft) @@ -184,7 +185,7 @@ def hfft(a: ArrayLike, n: int | None = None, conj_a = ufuncs.conj(a) _axis_check_1d('hfft', axis) nn = (conj_a.shape[axis] - 1) * 2 if n is None else n - return _fft_core_1d('hfft', xla_client.FftType.IRFFT, conj_a, n=n, axis=axis, + return _fft_core_1d('hfft', xc.FftType.IRFFT, conj_a, n=n, axis=axis, norm=norm) * nn @implements(np.fft.ihfft) @@ -193,12 +194,12 @@ def ihfft(a: ArrayLike, n: int | None = None, _axis_check_1d('ihfft', axis) arr = jnp.asarray(a) nn = arr.shape[axis] if n is None else n - output = _fft_core_1d('ihfft', xla_client.FftType.RFFT, arr, n=n, axis=axis, + output = _fft_core_1d('ihfft', xc.FftType.RFFT, arr, n=n, axis=axis, norm=norm) return ufuncs.conj(output) * (1 / nn) -def _fft_core_2d(func_name: str, fft_type: xla_client.FftType, a: ArrayLike, +def _fft_core_2d(func_name: str, fft_type: xc.FftType, a: ArrayLike, s: Shape | None, axes: Sequence[int], norm: str | None) -> Array: full_name = f"jax.numpy.fft.{func_name}" @@ -213,34 +214,37 @@ def _fft_core_2d(func_name: str, fft_type: xla_client.FftType, a: ArrayLike, @implements(np.fft.fft2) def fft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), norm: str | None = None) -> Array: - return _fft_core_2d('fft2', xla_client.FftType.FFT, a, s=s, axes=axes, + return _fft_core_2d('fft2', xc.FftType.FFT, a, s=s, axes=axes, norm=norm) @implements(np.fft.ifft2) def ifft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), norm: str | None = None) -> Array: - return _fft_core_2d('ifft2', xla_client.FftType.IFFT, a, s=s, axes=axes, + return _fft_core_2d('ifft2', xc.FftType.IFFT, a, s=s, axes=axes, norm=norm) @implements(np.fft.rfft2) def rfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), norm: str | None = None) -> Array: - return _fft_core_2d('rfft2', xla_client.FftType.RFFT, a, s=s, axes=axes, + return _fft_core_2d('rfft2', xc.FftType.RFFT, a, s=s, axes=axes, norm=norm) @implements(np.fft.irfft2) def irfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), norm: str | None = None) -> Array: - return _fft_core_2d('irfft2', xla_client.FftType.IRFFT, a, s=s, axes=axes, + return _fft_core_2d('irfft2', xc.FftType.IRFFT, a, s=s, axes=axes, norm=norm) @implements(np.fft.fftfreq, extra_params=""" -dtype : Optional +dtype : dtype, optional The dtype of the returned frequencies. If not specified, JAX's default floating point dtype will be used. +device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. """) -def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array: +def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None, + device: xc.Device | Sharding | None = None) -> Array: dtype = dtype or dtypes.canonicalize_dtype(jnp.float_) if isinstance(n, (list, tuple)): raise ValueError( @@ -252,13 +256,13 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array: "The d argument of jax.numpy.fft.fftfreq only takes a single value. " "Got d = %s." % list(d)) - k = jnp.zeros(n, dtype=dtype) + k = jnp.zeros(n, dtype=dtype, device=device) if n % 2 == 0: # k[0: n // 2 - 1] = jnp.arange(0, n // 2 - 1) - k = k.at[0: n // 2].set( jnp.arange(0, n // 2, dtype=dtype)) + k = k.at[0: n // 2].set(jnp.arange(0, n // 2, dtype=dtype)) # k[n // 2:] = jnp.arange(-n // 2, -1) - k = k.at[n // 2:].set( jnp.arange(-n // 2, 0, dtype=dtype)) + k = k.at[n // 2:].set(jnp.arange(-n // 2, 0, dtype=dtype)) else: # k[0: (n - 1) // 2] = jnp.arange(0, (n - 1) // 2) @@ -267,15 +271,18 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array: # k[(n - 1) // 2 + 1:] = jnp.arange(-(n - 1) // 2, -1) k = k.at[(n - 1) // 2 + 1:].set(jnp.arange(-(n - 1) // 2, 0, dtype=dtype)) - return k / jnp.array(d * n, dtype=dtype) + return k / jnp.array(d * n, dtype=dtype, device=device) @implements(np.fft.rfftfreq, extra_params=""" -dtype : Optional +dtype : dtype, optional The dtype of the returned frequencies. If not specified, JAX's default floating point dtype will be used. +device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. """) -def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array: +def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None, + device: xc.Device | Sharding | None = None) -> Array: dtype = dtype or dtypes.canonicalize_dtype(jnp.float_) if isinstance(n, (list, tuple)): raise ValueError( @@ -288,12 +295,12 @@ def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array: "Got d = %s." % list(d)) if n % 2 == 0: - k = jnp.arange(0, n // 2 + 1, dtype=dtype) + k = jnp.arange(0, n // 2 + 1, dtype=dtype, device=device) else: - k = jnp.arange(0, (n - 1) // 2 + 1, dtype=dtype) + k = jnp.arange(0, (n - 1) // 2 + 1, dtype=dtype, device=device) - return k / jnp.array(d * n, dtype=dtype) + return k / jnp.array(d * n, dtype=dtype, device=device) @implements(np.fft.fftshift) diff --git a/tests/fft_test.py b/tests/fft_test.py index ce7455fdb4f8..05fa96a93fae 100644 --- a/tests/fft_test.py +++ b/tests/fft_test.py @@ -346,13 +346,16 @@ def testFft2Errors(self, inverse, real): dtype=all_dtypes, size=[9, 10, 101, 102], d=[0.1, 2.], + device=[None, -1], ) - def testFftfreq(self, size, d, dtype): + def testFftfreq(self, size, d, dtype, device): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng([size], dtype),) jnp_op = jnp.fft.fftfreq np_op = np.fft.fftfreq - jnp_fn = lambda a: jnp_op(size, d=d) + if device is not None: + device = jax.devices()[device] + jnp_fn = lambda a: jnp_op(size, d=d, device=device) np_fn = lambda a: np_op(size, d=d) # Numpy promotes to complex128 aggressively. self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, @@ -362,6 +365,10 @@ def testFftfreq(self, size, d, dtype): if dtype in inexact_dtypes: tol = 0.15 # TODO(skye): can we be more precise? jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol) + # Test device + if device is not None: + out = jnp_fn(args_maker()) + self.assertEqual(out.devices(), {device}) @jtu.sample_product(n=[[0, 1, 2]]) def testFftfreqErrors(self, n): @@ -384,13 +391,16 @@ def testFftfreqErrors(self, n): dtype=all_dtypes, size=[9, 10, 101, 102], d=[0.1, 2.], + device=[None, -1], ) - def testRfftfreq(self, size, d, dtype): + def testRfftfreq(self, size, d, dtype, device): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng([size], dtype),) jnp_op = jnp.fft.rfftfreq np_op = np.fft.rfftfreq - jnp_fn = lambda a: jnp_op(size, d=d) + if device is not None: + device = jax.devices()[device] + jnp_fn = lambda a: jnp_op(size, d=d, device=device) np_fn = lambda a: np_op(size, d=d) # Numpy promotes to complex128 aggressively. self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, @@ -400,6 +410,10 @@ def testRfftfreq(self, size, d, dtype): if dtype in inexact_dtypes: tol = 0.15 # TODO(skye): can we be more precise? jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol) + # Test device + if device is not None: + out = jnp_fn(args_maker()) + self.assertEqual(out.devices(), {device}) @jtu.sample_product(n=[[0, 1, 2]]) def testRfftfreqErrors(self, n):