Skip to content

Commit

Permalink
Added trace alias to jnp.linalg
Browse files Browse the repository at this point in the history
Related to #21088
  • Loading branch information
vfdev-5 committed May 29, 2024
1 parent 24509f3 commit 0030d22
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 9 deletions.
58 changes: 51 additions & 7 deletions jax/_src/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ def _slogdet_qr(a: Array) -> tuple[Array, Array]:
@partial(jit, static_argnames=('method',))
def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult:
"""
Computes the sign and (natural) logarithm of the determinant of an array.
Compute the sign and (natural) logarithm of the determinant of an array.
JAX implementation of :func:`numpy.linalg.slotdet`.
Expand Down Expand Up @@ -662,7 +662,7 @@ def _det_3x3(a: Array) -> Array:
@jit
def det(a: ArrayLike) -> Array:
"""
Computes the determinant of an array.
Compute the determinant of an array.
JAX implementation of :func:`numpy.linalg.det`.
Expand Down Expand Up @@ -706,7 +706,7 @@ def _det_jvp(primals, tangents):

def eig(a: ArrayLike) -> tuple[Array, Array]:
"""
Computes the eigenvalues and eigenvectors of a square array.
Compute the eigenvalues and eigenvectors of a square array.
JAX implementation of :func:`numpy.linalg.eig`.
Expand Down Expand Up @@ -750,7 +750,7 @@ def eig(a: ArrayLike) -> tuple[Array, Array]:
@jit
def eigvals(a: ArrayLike) -> Array:
"""
Computes the eigenvalues of a general matrix.
Compute the eigenvalues of a general matrix.
JAX implementation of :func:`numpy.linalg.eigvals`.
Expand Down Expand Up @@ -788,7 +788,7 @@ def eigvals(a: ArrayLike) -> Array:
def eigh(a: ArrayLike, UPLO: str | None = None,
symmetrize_input: bool = True) -> EighResult:
"""
Computes the eigenvalues and eigenvectors of a Hermitian matrix.
Compute the eigenvalues and eigenvectors of a Hermitian matrix.
JAX implementation of :func:`numpy.linalg.eigh`.
Expand Down Expand Up @@ -842,7 +842,7 @@ def eigh(a: ArrayLike, UPLO: str | None = None,
@partial(jit, static_argnames=('UPLO',))
def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
"""
Computes the eigenvalues of a Hermitian matrix.
Compute the eigenvalues of a Hermitian matrix.
JAX implementation of :func:`numpy.linalg.eigvalsh`.
Expand Down Expand Up @@ -1599,7 +1599,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array:

def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = False,
ord: int | str = 2) -> Array:
"""Computes the vector norm of a vector or batch of vectors.
"""Compute the vector norm of a vector or batch of vectors.
JAX implementation of :func:`numpy.linalg.vector_norm`.
Expand Down Expand Up @@ -2136,3 +2136,47 @@ def cond(x: ArrayLike, p=None):
r = norm(x, ord=p, axis=(-2, -1)) * norm(inv(x), ord=p, axis=(-2, -1))
# Convert NaNs to infs where original array has no NaNs.
return jnp.where(ufuncs.isnan(r) & ~ufuncs.isnan(x).any(axis=(-2, -1)), jnp.inf, r)


def trace(x: ArrayLike, /, *,
offset: int = 0, dtype: DTypeLike | None = None) -> Array:
"""Compute the trace of a matrix.
JAX implementation of :func:`numpy.linalg.trace`.
Args:
x: array of shape ``(..., M, N)`` and whose innermost two
dimensions form MxN matrices for which to take the trace.
offset: positive or negative offset from the main diagonal
(default: 0).
dtype: data type of the returned array (default: ``None``). If ``None``,
then output dtype will match the dtype of ``x``, promoted to default
precision in the case of integer types.
Returns:
array of batched traces with shape ``x.shape[:-2]``
See also:
- :func:`jax.numpy.trace`: similar API in the ``jax.numpy`` namespace.
Examples:
Trace of a single matrix:
>>> x = jnp.array([[1, 2, 3, 4],
... [5, 6, 7, 8],
... [9, 10, 11, 12]])
>>> jnp.linalg.trace(x)
Array(18, dtype=int32)
>>> jnp.linalg.trace(x, offset=1)
Array(21, dtype=int32)
>>> jnp.linalg.trace(x, offset=-1, dtype="float32")
Array(15., dtype=float32)
Batched traces:
>>> x = jnp.arange(24).reshape(2, 3, 4)
>>> jnp.linalg.trace(x)
Array([15, 51], dtype=int32)
"""
check_arraylike('jnp.linalg.trace', x)
return jnp.trace(x, offset=offset, axis1=-2, axis2=-1, dtype=dtype)
3 changes: 1 addition & 2 deletions jax/experimental/array_api/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@
vector_norm as vector_norm,
)

# TODO(micky774): Add trace to jax.numpy.linalg
from jax.numpy import trace as trace
from jax.numpy.linalg import trace as trace

from jax.experimental.array_api._linear_algebra_functions import (
matrix_rank as matrix_rank,
Expand Down
1 change: 1 addition & 0 deletions jax/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
tensordot as tensordot,
tensorinv as tensorinv,
tensorsolve as tensorsolve,
trace as trace,
vector_norm as vector_norm,
vecdot as vecdot,
)
12 changes: 12 additions & 0 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1305,6 +1305,18 @@ def testDiagonal(self, shape, dtype, offset):
self._CheckAgainstNumpy(np_fun, lax_fun, args_maker)
self._CompileAndCheck(lax_fun, args_maker)

def testTrace(self):
shape, dtype, offset, out_dtype = (3, 4), "float32", 0, None
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
lax_fun = partial(jnp.linalg.trace, offset=offset, dtype=out_dtype)
if jtu.numpy_version() >= (2, 0, 0):
np_fun = partial(np.linalg.trace, offset=offset)
else:
np_fun = partial(np.trace, offset=offset, axis1=-2, axis2=-1, dtype=out_dtype)
self._CheckAgainstNumpy(np_fun, lax_fun, args_maker)
self._CompileAndCheck(lax_fun, args_maker)


class ScipyLinalgTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 0030d22

Please sign in to comment.