Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pallas] Improve some error messages and add API tests. #22173

Merged
merged 1 commit into from
Jul 4, 2024

Conversation

gnecula
Copy link
Collaborator

@gnecula gnecula commented Jun 28, 2024

We make the following improvements:

  • pytree structural disequality messages attempt to localize the
    mismatch using tree_util.KeyPath.
  • we generate a simpler error message for when in_specs is not
    a sequence, instead of the current PyTreeDef mismatch error.
  • we generate an error message for when the index map function
    returns an unexpected number of values.
  • added error localization to the existing shape polymorphism
    check that the block shapes are static.
  • we check that the rank of the block_shape matches the rank of
    the overall array. Without this we used to get a safe_zip
    error. We also carry the pytree paths to localize the error.
  • We check that the kernel function returns None. Without this
    we used to get body_fun output and input must have same type structure
    in the interpreter, assert len(jaxpr.outvars) == 0 on GPU,
    and INTERNAL: Mosaic failed to compile TPU kernel: has 1 operands, but enclosing function (@main) returns 0
    on TPU.

To simplify the generation of the error messages we added a helper
function tree_util.equality_errors_pytreedef, which is just like
tree_util.equality_errors but takes PyTreeDef inputs rather than
PyTrees. We then used this new helper function in pjit.py and stages.py.

@gnecula gnecula self-assigned this Jun 28, 2024
@gnecula gnecula added the pull ready Ready for copybara import and testing label Jun 28, 2024
@gnecula gnecula force-pushed the pallas_errors branch 9 times, most recently from e821c8b to 0d09d05 Compare July 2, 2024 09:46
@gnecula gnecula force-pushed the pallas_errors branch 10 times, most recently from 5417817 to c5c67b6 Compare July 3, 2024 09:36
jax/_src/pallas/core.py Show resolved Hide resolved
@gnecula gnecula force-pushed the pallas_errors branch 2 times, most recently from 7936254 to 3dcbc6f Compare July 4, 2024 06:38
We make the following improvements:

  * pytree structural disequality messages now attempt to localize the
    mismatch using tree_util.KeyPath.
  * we generate a simpler error message for when `in_specs` is not
    a sequence, instead of the current PyTreeDef mismatch error.
  * we generate an error message for when the index map function
    in a BlockSpec returns an unexpected number of results.
  * added error localization to the existing shape polymorphism
    check that the block shapes are static.
  * We check that the kernel function returns None. Without this
    we used to get `body_fun output and input must have same type structure`
    in the interpreter, `assert len(jaxpr.outvars) == 0` on GPU,
    and `INTERNAL: Mosaic failed to compile TPU kernel: has 1 operands, but enclosing function (@main) returns 0`
    on TPU.
  * we check that the rank of the block_shape matches the rank of
    the overall array. Without this we used to get a `safe_zip`
    error. We also carry the pytree paths to localize the error.

To simplify the generation of the error messages we added a helper
function `tree_util.equality_errors_pytreedef`, which is just like
`tree_util.equality_errors` but takes `PyTreeDef` inputs rather than
PyTrees. We then used this new helper function in `pjit.py` and `stages.py`.
@copybara-service copybara-service bot merged commit 6c00cd1 into google:main Jul 4, 2024
15 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants