Skip to content

Commit

Permalink
Merge pull request #21725 from rajasekharporeddy:testbranch3
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 642332950
  • Loading branch information
jax authors committed Jun 11, 2024
2 parents 8199267 + 7989c70 commit 5cf52b8
Showing 1 changed file with 56 additions and 0 deletions.
56 changes: 56 additions & 0 deletions jax/_src/scipy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1137,6 +1137,32 @@ def expm(A: ArrayLike, *, upper_triangular: bool = False, max_squarings: int = 1
See Also:
:func:`jax.scipy.linalg.expm_frechet`
Example:
``expm`` is the matrix exponential, and has similar properties to the more
familiar scalar exponential. For scalars ``a`` and ``b``, :math:`e^{a + b}
= e^a e^b`. However, for matrices, this property only holds when ``A`` and
``B`` commute (``AB = BA``). In this case, ``expm(A+B) = expm(A) @ expm(B)``
>>> A = jnp.array([[2, 0],
... [0, 1]])
>>> B = jnp.array([[3, 0],
... [0, 4]])
>>> jnp.allclose(jax.scipy.linalg.expm(A+B),
... jax.scipy.linalg.expm(A) @ jax.scipy.linalg.expm(B),
... rtol=0.0001)
Array(True, dtype=bool)
If a matrix ``X`` is invertible, then
``expm(X @ A @ inv(X)) = X @ expm(A) @ inv(X)``
>>> X = jnp.array([[3, 1],
... [2, 5]])
>>> X_inv = jax.scipy.linalg.inv(X)
>>> jnp.allclose(jax.scipy.linalg.expm(X @ A @ X_inv),
... X @ jax.scipy.linalg.expm(A) @ X_inv)
Array(True, dtype=bool)
"""
A, = promote_dtypes_inexact(A)

Expand Down Expand Up @@ -1642,6 +1668,36 @@ def polar(a: ArrayLike, side: str = 'right', *, method: str = 'qdwh', eps: float
``posdef`` is either :math:`n \times n` or :math:`m \times m` depending on
whether ``side`` is ``"right"`` or ``"left"``, respectively.
Example:
Polar decomposition of a 3x3 matrix:
>>> a = jnp.array([[1., 2., 3.],
... [5., 4., 2.],
... [3., 2., 1.]])
>>> U, P = jax.scipy.linalg.polar(a)
U is a Unitary Matrix:
>>> jnp.round(U.T @ U)
Array([[ 1., -0., -0.],
[-0., 1., 0.],
[-0., 0., 1.]], dtype=float32)
P is positive-semidefinite Matrix:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(P)
[[4.79 3.25 1.23]
[3.25 3.06 2.01]
[1.23 2.01 2.91]]
The original matrix can be reconstructed by multiplying the U and P:
>>> a_reconstructed = U @ P
>>> jnp.allclose(a, a_reconstructed)
Array(True, dtype=bool)
.. _QDWH: https://epubs.siam.org/doi/abs/10.1137/090774999
"""
arr = jnp.asarray(a)
Expand Down

0 comments on commit 5cf52b8

Please sign in to comment.