I am writing a Pallas kernel where one of the lines includes a sum() with keepdims:

frobenius_sq_norm = square_norm(w_tl).sum(keepdims=True)

I don't expect sum() to work without keepdims, as that would produce a scalar. However, with keepdims=True the reduction is vec->vec, so it should work. Instead, lowering fails with:

Cannot lower reductions to scalar. Reduce to one element vector instead, using keepdims=True.
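Outside Pallas, keepdims=True gives exactly the vec->vec mapping described above (shapes here are illustrative, not taken from the kernel):

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

# Without keepdims, the reduction collapses to a 0-d scalar.
s = jnp.sum(x)
print(s.shape)  # ()

# With keepdims=True, the reduced axis is kept with size 1,
# so a length-3 vector maps to a length-1 vector.
v = jnp.sum(x, keepdims=True)
print(v.shape)  # (1,)
```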
This is because the actual implementation first removes the dimension and then adds it back.
In reductions.py in jax.numpy, we can find:
if keepdims:
    result = lax.expand_dims(result, pos_dims)
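The scalar intermediate is visible in the trace: even with keepdims=True, the jaxpr contains a reduce_sum that produces a rank-0 value, and only afterwards restores the shape (a sketch; the exact shape-restoring primitive may vary across JAX versions):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x, keepdims=True)

# Trace f to a jaxpr on a length-4 vector.
jaxpr = jax.make_jaxpr(f)(jnp.ones((4,)))
print(jaxpr)
# The printed jaxpr shows reduce_sum collapsing f32[4] to a rank-0
# f32[] first, then a reshape/broadcast back to f32[1] -- the scalar
# step is what Pallas refuses to lower.
```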
The real fix would probably be to pass keepdims along to all the leaf locations where the op is actually invoked, and to ensure that the op respects it and preserves the vector, instead of repackaging a scalar into a vector.
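Until such a fix lands, one possible workaround is to arrange the input so the reduction itself never produces a scalar, e.g. reducing over only one axis of a rank-2 view (a hedged sketch; whether reshape is supported inside a kernel depends on the Pallas backend):

```python
import jax.numpy as jnp

def sum_keep_vec(x):
    # Reduce over axis 1 of a (1, N) view: the reduce_sum output is
    # already a one-element vector, so no scalar intermediate is ever
    # created and nothing has to be expanded back afterwards.
    return jnp.sum(x.reshape(1, -1), axis=1)

y = sum_keep_vec(jnp.array([1.0, 2.0, 3.0]))
print(y.shape)  # (1,)
```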