[pallas] Improve some error messages and add API tests. #22173
Merged
+338
−81
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
We make the following improvements:
mismatch using
tree_util.KeyPath
.in_specs
is nota sequence, instead of the current PyTreeDef mismatch error.
returns an unexpected number of values.
check that the block shapes are static.
the overall array. Without this we used to get a
safe_zip
error. We also carry the pytree paths to localize the error.
we used to get
body_fun output and input must have same type structure
in the interpreter,
assert len(jaxpr.outvars) == 0
on GPU,and
INTERNAL: Mosaic failed to compile TPU kernel: has 1 operands, but enclosing function (@main) returns 0
on TPU.
To simplify the generation of the error messages we added a helper
function
tree_util.equality_errors_pytreedef
, which is just liketree_util.equality_errors
but takesPyTreeDef
inputs rather thanPyTrees. We then used this new helper function in
pjit.py
andstages.py
.