diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index e15788b0cc12..d28843e19b67 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -496,7 +496,7 @@ def _slogdet_qr(a: Array) -> tuple[Array, Array]: @partial(jit, static_argnames=('method',)) def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult: """ - Computes the sign and (natural) logarithm of the determinant of an array. + Compute the sign and (natural) logarithm of the determinant of an array. JAX implementation of :func:`numpy.linalg.slotdet`. @@ -662,7 +662,7 @@ def _det_3x3(a: Array) -> Array: @jit def det(a: ArrayLike) -> Array: """ - Computes the determinant of an array. + Compute the determinant of an array. JAX implementation of :func:`numpy.linalg.det`. @@ -706,7 +706,7 @@ def _det_jvp(primals, tangents): def eig(a: ArrayLike) -> tuple[Array, Array]: """ - Computes the eigenvalues and eigenvectors of a square array. + Compute the eigenvalues and eigenvectors of a square array. JAX implementation of :func:`numpy.linalg.eig`. @@ -750,7 +750,7 @@ def eig(a: ArrayLike) -> tuple[Array, Array]: @jit def eigvals(a: ArrayLike) -> Array: """ - Computes the eigenvalues of a general matrix. + Compute the eigenvalues of a general matrix. JAX implementation of :func:`numpy.linalg.eigvals`. @@ -788,7 +788,7 @@ def eigvals(a: ArrayLike) -> Array: def eigh(a: ArrayLike, UPLO: str | None = None, symmetrize_input: bool = True) -> EighResult: """ - Computes the eigenvalues and eigenvectors of a Hermitian matrix. + Compute the eigenvalues and eigenvectors of a Hermitian matrix. JAX implementation of :func:`numpy.linalg.eigh`. @@ -842,7 +842,7 @@ def eigh(a: ArrayLike, UPLO: str | None = None, @partial(jit, static_argnames=('UPLO',)) def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: """ - Computes the eigenvalues of a Hermitian matrix. + Compute the eigenvalues of a Hermitian matrix. JAX implementation of :func:`numpy.linalg.eigvalsh`. @@ -1599,7 +1599,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array: def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = False, ord: int | str = 2) -> Array: - """Computes the vector norm of a vector or batch of vectors. + """Compute the vector norm of a vector or batch of vectors. JAX implementation of :func:`numpy.linalg.vector_norm`. @@ -2136,3 +2136,47 @@ def cond(x: ArrayLike, p=None): r = norm(x, ord=p, axis=(-2, -1)) * norm(inv(x), ord=p, axis=(-2, -1)) # Convert NaNs to infs where original array has no NaNs. return jnp.where(ufuncs.isnan(r) & ~ufuncs.isnan(x).any(axis=(-2, -1)), jnp.inf, r) + + +def trace(x: ArrayLike, /, *, + offset: int = 0, dtype: DTypeLike | None = None) -> Array: + """Compute the trace of a matrix. + + JAX implementation of :func:`numpy.linalg.trace`. + + Args: + x: array of shape ``(..., M, N)`` and whose innermost two + dimensions form MxN matrices for which to take the trace. + offset: positive or negative offset from the main diagonal + (default: 0). + dtype: data type of the returned array (default: ``None``). If ``None``, + then output dtype will match the dtype of ``x``, promoted to default + precision in the case of integer types. + + Returns: + array of batched traces with shape ``x.shape[:-2]`` + + See also: + - :func:`jax.numpy.trace`: similar API in the ``jax.numpy`` namespace. + + Examples: + Trace of a single matrix: + + >>> x = jnp.array([[1, 2, 3, 4], + ... [5, 6, 7, 8], + ... [9, 10, 11, 12]]) + >>> jnp.linalg.trace(x) + Array(18, dtype=int32) + >>> jnp.linalg.trace(x, offset=1) + Array(21, dtype=int32) + >>> jnp.linalg.trace(x, offset=-1, dtype="float32") + Array(15., dtype=float32) + + Batched traces: + + >>> x = jnp.arange(24).reshape(2, 3, 4) + >>> jnp.linalg.trace(x) + Array([15, 51], dtype=int32) + """ + check_arraylike('jnp.linalg.trace', x) + return jnp.trace(x, offset=offset, axis1=-2, axis2=-1, dtype=dtype) diff --git a/jax/experimental/array_api/linalg.py b/jax/experimental/array_api/linalg.py index f19955409d5f..6494884135fe 100644 --- a/jax/experimental/array_api/linalg.py +++ b/jax/experimental/array_api/linalg.py @@ -35,8 +35,7 @@ vector_norm as vector_norm, ) -# TODO(micky774): Add trace to jax.numpy.linalg -from jax.numpy import trace as trace +from jax.numpy.linalg import trace as trace from jax.experimental.array_api._linear_algebra_functions import ( matrix_rank as matrix_rank, diff --git a/jax/numpy/linalg.py b/jax/numpy/linalg.py index 98b9ca3e0694..c342fde0ae6e 100644 --- a/jax/numpy/linalg.py +++ b/jax/numpy/linalg.py @@ -44,6 +44,7 @@ tensordot as tensordot, tensorinv as tensorinv, tensorsolve as tensorsolve, + trace as trace, vector_norm as vector_norm, vecdot as vecdot, ) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 52d359ec2ffe..f5471a881764 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -1305,6 +1305,18 @@ def testDiagonal(self, shape, dtype, offset): self._CheckAgainstNumpy(np_fun, lax_fun, args_maker) self._CompileAndCheck(lax_fun, args_maker) + def testTrace(self): + shape, dtype, offset, out_dtype = (3, 4), "float32", 0, None + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + lax_fun = partial(jnp.linalg.trace, offset=offset, dtype=out_dtype) + if jtu.numpy_version() >= (2, 0, 0): + np_fun = partial(np.linalg.trace, offset=offset) + else: + np_fun = partial(np.trace, offset=offset, axis1=-2, axis2=-1, dtype=out_dtype) + self._CheckAgainstNumpy(np_fun, lax_fun, args_maker) + self._CompileAndCheck(lax_fun, args_maker) + class ScipyLinalgTest(jtu.JaxTestCase):