Skip to content

Commit

Permalink
Added device kwarg to jnp.fft.fftfreq and jnp.fft.rfftfreq
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Jun 28, 2024
1 parent db2e347 commit 4c2ac74
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 36 deletions.
71 changes: 39 additions & 32 deletions jax/_src/numpy/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@

from jax import dtypes
from jax import lax
from jax._src.lib import xla_client
from jax.sharding import Sharding
from jax._src.lib import xla_client as xc
from jax._src.util import safe_zip
from jax._src.numpy.util import check_arraylike, implements, promote_dtypes_inexact
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import ufuncs, reductions
from jax._src.typing import Array, ArrayLike
from jax._src.typing import Array, ArrayLike, DTypeLike

Shape = Sequence[int]

Expand All @@ -44,7 +45,7 @@ def _fft_norm(s: Array, func_name: str, norm: str) -> Array:
'"ortho" or "forward".')


def _fft_core(func_name: str, fft_type: xla_client.FftType, a: ArrayLike,
def _fft_core(func_name: str, fft_type: xc.FftType, a: ArrayLike,
s: Shape | None, axes: Sequence[int] | None,
norm: str | None) -> Array:
full_name = f"jax.numpy.fft.{func_name}"
Expand Down Expand Up @@ -86,14 +87,14 @@ def _fft_core(func_name: str, fft_type: xla_client.FftType, a: ArrayLike,
in_s = list(arr.shape)
for axis, x in safe_zip(axes, s):
in_s[axis] = x
if fft_type == xla_client.FftType.IRFFT:
if fft_type == xc.FftType.IRFFT:
in_s[-1] = (in_s[-1] // 2 + 1)
# Cropping
arr = arr[tuple(map(slice, in_s))]
# Padding
arr = jnp.pad(arr, [(0, x-y) for x, y in zip(in_s, arr.shape)])
else:
if fft_type == xla_client.FftType.IRFFT:
if fft_type == xc.FftType.IRFFT:
s = [arr.shape[axis] for axis in axes[:-1]]
if axes:
s += [max(0, 2 * (arr.shape[axes[-1]] - 1))]
Expand All @@ -113,28 +114,28 @@ def _fft_core(func_name: str, fft_type: xla_client.FftType, a: ArrayLike,
def fftn(a: ArrayLike, s: Shape | None = None,
axes: Sequence[int] | None = None,
norm: str | None = None) -> Array:
return _fft_core('fftn', xla_client.FftType.FFT, a, s, axes, norm)
return _fft_core('fftn', xc.FftType.FFT, a, s, axes, norm)


@implements(np.fft.ifftn)
def ifftn(a: ArrayLike, s: Shape | None = None,
axes: Sequence[int] | None = None,
norm: str | None = None) -> Array:
return _fft_core('ifftn', xla_client.FftType.IFFT, a, s, axes, norm)
return _fft_core('ifftn', xc.FftType.IFFT, a, s, axes, norm)


@implements(np.fft.rfftn)
def rfftn(a: ArrayLike, s: Shape | None = None,
axes: Sequence[int] | None = None,
norm: str | None = None) -> Array:
return _fft_core('rfftn', xla_client.FftType.RFFT, a, s, axes, norm)
return _fft_core('rfftn', xc.FftType.RFFT, a, s, axes, norm)


@implements(np.fft.irfftn)
def irfftn(a: ArrayLike, s: Shape | None = None,
axes: Sequence[int] | None = None,
norm: str | None = None) -> Array:
return _fft_core('irfftn', xla_client.FftType.IRFFT, a, s, axes, norm)
return _fft_core('irfftn', xc.FftType.IRFFT, a, s, axes, norm)


def _axis_check_1d(func_name: str, axis: int | None):
Expand All @@ -145,7 +146,7 @@ def _axis_check_1d(func_name: str, axis: int | None):
"Got axis = %r." % (full_name, full_name, axis)
)

def _fft_core_1d(func_name: str, fft_type: xla_client.FftType,
def _fft_core_1d(func_name: str, fft_type: xc.FftType,
a: ArrayLike, n: int | None, axis: int | None,
norm: str | None) -> Array:
_axis_check_1d(func_name, axis)
Expand All @@ -157,25 +158,25 @@ def _fft_core_1d(func_name: str, fft_type: xla_client.FftType,
@implements(np.fft.fft)
def fft(a: ArrayLike, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
return _fft_core_1d('fft', xla_client.FftType.FFT, a, n=n, axis=axis,
return _fft_core_1d('fft', xc.FftType.FFT, a, n=n, axis=axis,
norm=norm)

@implements(np.fft.ifft)
def ifft(a: ArrayLike, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
return _fft_core_1d('ifft', xla_client.FftType.IFFT, a, n=n, axis=axis,
return _fft_core_1d('ifft', xc.FftType.IFFT, a, n=n, axis=axis,
norm=norm)

@implements(np.fft.rfft)
def rfft(a: ArrayLike, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
return _fft_core_1d('rfft', xla_client.FftType.RFFT, a, n=n, axis=axis,
return _fft_core_1d('rfft', xc.FftType.RFFT, a, n=n, axis=axis,
norm=norm)

@implements(np.fft.irfft)
def irfft(a: ArrayLike, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
return _fft_core_1d('irfft', xla_client.FftType.IRFFT, a, n=n, axis=axis,
return _fft_core_1d('irfft', xc.FftType.IRFFT, a, n=n, axis=axis,
norm=norm)

@implements(np.fft.hfft)
Expand All @@ -184,7 +185,7 @@ def hfft(a: ArrayLike, n: int | None = None,
conj_a = ufuncs.conj(a)
_axis_check_1d('hfft', axis)
nn = (conj_a.shape[axis] - 1) * 2 if n is None else n
return _fft_core_1d('hfft', xla_client.FftType.IRFFT, conj_a, n=n, axis=axis,
return _fft_core_1d('hfft', xc.FftType.IRFFT, conj_a, n=n, axis=axis,
norm=norm) * nn

@implements(np.fft.ihfft)
Expand All @@ -193,12 +194,12 @@ def ihfft(a: ArrayLike, n: int | None = None,
_axis_check_1d('ihfft', axis)
arr = jnp.asarray(a)
nn = arr.shape[axis] if n is None else n
output = _fft_core_1d('ihfft', xla_client.FftType.RFFT, arr, n=n, axis=axis,
output = _fft_core_1d('ihfft', xc.FftType.RFFT, arr, n=n, axis=axis,
norm=norm)
return ufuncs.conj(output) * (1 / nn)


def _fft_core_2d(func_name: str, fft_type: xla_client.FftType, a: ArrayLike,
def _fft_core_2d(func_name: str, fft_type: xc.FftType, a: ArrayLike,
s: Shape | None, axes: Sequence[int],
norm: str | None) -> Array:
full_name = f"jax.numpy.fft.{func_name}"
Expand All @@ -213,34 +214,37 @@ def _fft_core_2d(func_name: str, fft_type: xla_client.FftType, a: ArrayLike,
@implements(np.fft.fft2)
def fft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
norm: str | None = None) -> Array:
return _fft_core_2d('fft2', xla_client.FftType.FFT, a, s=s, axes=axes,
return _fft_core_2d('fft2', xc.FftType.FFT, a, s=s, axes=axes,
norm=norm)

@implements(np.fft.ifft2)
def ifft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
norm: str | None = None) -> Array:
return _fft_core_2d('ifft2', xla_client.FftType.IFFT, a, s=s, axes=axes,
return _fft_core_2d('ifft2', xc.FftType.IFFT, a, s=s, axes=axes,
norm=norm)

@implements(np.fft.rfft2)
def rfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
norm: str | None = None) -> Array:
return _fft_core_2d('rfft2', xla_client.FftType.RFFT, a, s=s, axes=axes,
return _fft_core_2d('rfft2', xc.FftType.RFFT, a, s=s, axes=axes,
norm=norm)

@implements(np.fft.irfft2)
def irfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
norm: str | None = None) -> Array:
return _fft_core_2d('irfft2', xla_client.FftType.IRFFT, a, s=s, axes=axes,
return _fft_core_2d('irfft2', xc.FftType.IRFFT, a, s=s, axes=axes,
norm=norm)


@implements(np.fft.fftfreq, extra_params="""
dtype : Optional
dtype : dtype, optional
The dtype of the returned frequencies. If not specified, JAX's default
floating point dtype will be used.
device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.
""")
def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array:
def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None,
device: xc.Device | Sharding | None = None) -> Array:
dtype = dtype or dtypes.canonicalize_dtype(jnp.float_)
if isinstance(n, (list, tuple)):
raise ValueError(
Expand All @@ -252,13 +256,13 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array:
"The d argument of jax.numpy.fft.fftfreq only takes a single value. "
"Got d = %s." % list(d))

k = jnp.zeros(n, dtype=dtype)
k = jnp.zeros(n, dtype=dtype, device=device)
if n % 2 == 0:
# k[0: n // 2 - 1] = jnp.arange(0, n // 2 - 1)
k = k.at[0: n // 2].set( jnp.arange(0, n // 2, dtype=dtype))
k = k.at[0: n // 2].set(jnp.arange(0, n // 2, dtype=dtype))

# k[n // 2:] = jnp.arange(-n // 2, -1)
k = k.at[n // 2:].set( jnp.arange(-n // 2, 0, dtype=dtype))
k = k.at[n // 2:].set(jnp.arange(-n // 2, 0, dtype=dtype))

else:
# k[0: (n - 1) // 2] = jnp.arange(0, (n - 1) // 2)
Expand All @@ -267,15 +271,18 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array:
# k[(n - 1) // 2 + 1:] = jnp.arange(-(n - 1) // 2, -1)
k = k.at[(n - 1) // 2 + 1:].set(jnp.arange(-(n - 1) // 2, 0, dtype=dtype))

return k / jnp.array(d * n, dtype=dtype)
return k / jnp.array(d * n, dtype=dtype, device=device)


@implements(np.fft.rfftfreq, extra_params="""
dtype : Optional
dtype : dtype, optional
The dtype of the returned frequencies. If not specified, JAX's default
floating point dtype will be used.
device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.
""")
def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array:
def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None,
device: xc.Device | Sharding | None = None) -> Array:
dtype = dtype or dtypes.canonicalize_dtype(jnp.float_)
if isinstance(n, (list, tuple)):
raise ValueError(
Expand All @@ -288,12 +295,12 @@ def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array:
"Got d = %s." % list(d))

if n % 2 == 0:
k = jnp.arange(0, n // 2 + 1, dtype=dtype)
k = jnp.arange(0, n // 2 + 1, dtype=dtype, device=device)

else:
k = jnp.arange(0, (n - 1) // 2 + 1, dtype=dtype)
k = jnp.arange(0, (n - 1) // 2 + 1, dtype=dtype, device=device)

return k / jnp.array(d * n, dtype=dtype)
return k / jnp.array(d * n, dtype=dtype, device=device)


@implements(np.fft.fftshift)
Expand Down
22 changes: 18 additions & 4 deletions tests/fft_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,13 +346,16 @@ def testFft2Errors(self, inverse, real):
dtype=all_dtypes,
size=[9, 10, 101, 102],
d=[0.1, 2.],
device=[None, -1],
)
def testFftfreq(self, size, d, dtype):
def testFftfreq(self, size, d, dtype, device):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng([size], dtype),)
jnp_op = jnp.fft.fftfreq
np_op = np.fft.fftfreq
jnp_fn = lambda a: jnp_op(size, d=d)
if device is not None:
device = jax.devices()[device]
jnp_fn = lambda a: jnp_op(size, d=d, device=device)
np_fn = lambda a: np_op(size, d=d)
# Numpy promotes to complex128 aggressively.
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
Expand All @@ -362,6 +365,10 @@ def testFftfreq(self, size, d, dtype):
if dtype in inexact_dtypes:
tol = 0.15 # TODO(skye): can we be more precise?
jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)
# Test device
if device is not None:
out = jnp_fn(args_maker())
self.assertEqual(out.devices(), {device})

@jtu.sample_product(n=[[0, 1, 2]])
def testFftfreqErrors(self, n):
Expand All @@ -384,13 +391,16 @@ def testFftfreqErrors(self, n):
dtype=all_dtypes,
size=[9, 10, 101, 102],
d=[0.1, 2.],
device=[None, -1],
)
def testRfftfreq(self, size, d, dtype):
def testRfftfreq(self, size, d, dtype, device):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng([size], dtype),)
jnp_op = jnp.fft.rfftfreq
np_op = np.fft.rfftfreq
jnp_fn = lambda a: jnp_op(size, d=d)
if device is not None:
device = jax.devices()[device]
jnp_fn = lambda a: jnp_op(size, d=d, device=device)
np_fn = lambda a: np_op(size, d=d)
# Numpy promotes to complex128 aggressively.
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
Expand All @@ -400,6 +410,10 @@ def testRfftfreq(self, size, d, dtype):
if dtype in inexact_dtypes:
tol = 0.15 # TODO(skye): can we be more precise?
jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)
# Test device
if device is not None:
out = jnp_fn(args_maker())
self.assertEqual(out.devices(), {device})

@jtu.sample_product(n=[[0, 1, 2]])
def testRfftfreqErrors(self, n):
Expand Down

0 comments on commit 4c2ac74

Please sign in to comment.