Potentially harmful use of assert statements in solver.py #13

Open
itk22 opened this issue Jul 6, 2023 · 0 comments
@itk22 (Collaborator) commented Jul 6, 2023

Dear @tianjuxue,
I noticed a problematic use of assert statements in solver.py of the FEM code that can cause failures when the solver is traced by JAX. Here is the relevant code:

def jax_solve(problem, A_fn, b, x0, precond):
    pc = get_jacobi_precond(jacobi_preconditioner(problem)) if precond else None
    x, info = jax.scipy.sparse.linalg.bicgstab(A_fn, b, x0=x0, M=pc, tol=1e-10, atol=1e-10, maxiter=10000)

    # Verify convergence
    err = np.linalg.norm(A_fn(x) - b)
    print(f"JAX scipy linear solve res = {err}")

    # HERE IS THE PROBLEMATIC ASSERT:
    assert err < 0.1, f"JAX linear solver failed to converge with err = {err}"

    return x
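
A minimal reproduction of the failure outside the solver, assuming only `jax` itself. The function `check_residual` and the residual values are hypothetical stand-ins for the convergence check in `jax_solve`: called eagerly the assert works, but under `jax.vmap` its condition is a batched tracer and `bool()` on it raises.

```python
import jax
import jax.numpy as jnp


def check_residual(err):
    # A bare Python assert calls bool() on its condition, which needs a
    # concrete value; under vmap/jit the condition is an abstract tracer.
    assert err < 0.1, f"JAX linear solver failed to converge with err = {err}"
    return err


# Eager call: err is a concrete array, so the assert behaves normally.
eager = check_residual(jnp.array(0.05))

errs = jnp.array([0.0, 0.05])  # hypothetical residuals for a batch of solves

try:
    jax.vmap(check_residual)(errs)
    assert_failed = False
except jax.errors.ConcretizationTypeError:
    # Newer JAX versions raise the subclass TracerBoolConversionError here.
    assert_failed = True

print(assert_failed)  # True
```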

The assert statement above is Python-level control flow: it calls bool() on its condition, which only works when err is a concrete value. In one of my use cases, err is instead a tracer:

Traced<ShapedArray(float64[])>with<BatchTrace(level=1/0)> with
 val = Array([0.], dtype=float64)
 batch_dim = 0

and the assert statement fails. I am not sure why err is not a concrete value in my case (the BatchTrace in the repr suggests the solver is being called under jax.vmap), but in any case a bare assert here does not fit the functional purity philosophy of JAX. A feasible alternative could be jax.lax.cond, or perhaps the assertions from the Chex library.
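
One more option, sketched here under assumptions: JAX's own `jax.experimental.checkify` turns runtime checks into values returned alongside the output, so they survive `jit` and `vmap`. The sketch below is not the solver's code; `solve_with_check` is a hypothetical toy that uses `cg` on a small dense SPD matrix in place of `bicgstab` with the Jacobi preconditioner, keeping only the residual check from `jax_solve`.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify
from jax.scipy.sparse.linalg import cg


def solve_with_check(A, b, max_err=0.1):
    """Hypothetical toy stand-in for jax_solve: solve A x = b with CG and
    verify the residual with a functional check instead of a Python assert."""
    x, _ = cg(A, b, tol=1e-6, maxiter=100)
    err = jnp.linalg.norm(A @ x - b)
    # checkify.check records a failure as data, so it works on tracers.
    checkify.check(err < max_err,
                   "JAX linear solver failed to converge with err = {err}",
                   err=err)
    return x


# Batch over right-hand sides (like the BatchTrace in the issue), then
# functionalize the checks and jit the whole thing.
batched_solve = jax.vmap(solve_with_check, in_axes=(None, 0))
checked_solve = jax.jit(checkify.checkify(batched_solve))

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])    # small symmetric positive definite matrix
bs = jnp.array([[1.0, 2.0], [0.5, -1.0]])  # two right-hand sides

err, xs = checked_solve(A, bs)
err.throw()  # raises only if some check failed; a no-op here
```

With this design the caller decides when to raise by calling `err.throw()` (or inspecting `err.get()`) outside the traced code. Note that `jax.lax.cond` on its own cannot raise an exception from traced code; it can only select between outputs, for example returning NaNs when the residual is too large.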
