
Incorrect dtype promotion in DynamicLossScale #47

Open
joeryjoery opened this issue Jan 26, 2024 · 2 comments

@joeryjoery

Following the provided example for DynamicLossScale raises an error when run directly.

import jax
import jax.numpy as jnp
import jmp

dyn = jmp.DynamicLossScale(jnp.float16(2**15))

g = jnp.ones(5, jnp.float16)
finite = jmp.all_finite(g)
dyn.adjust(~finite)
>> TypeError: lax.select requires arguments to have the same dtypes, got float32, float16. (Tip: jnp.where is a similar function that does automatic type promotion on inputs).

This is of course fixed by doing jmp.DynamicLossScale(jnp.float32(2**15)), but doesn't this defeat the purpose of this object?
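
For what it's worth, here is a minimal sketch of where the promotion seems to go wrong, assuming adjust() ends up comparing the float16 loss scale against a float32 default min_loss_scale inside lax.select. The internals here are my guess, not taken from the library source:

import jax
import jax.numpy as jnp
import numpy as np

loss_scale = jnp.float16(2**15)
min_loss_scale = np.float32(1)   # assumed float32 default for min_loss_scale
grads_finite = jnp.array(False)

# Mixing the float16 scale with a float32 scalar promotes the "shrunk" branch
# to float32, while the "keep" branch stays float16:
shrunk = jnp.maximum(min_loss_scale, loss_scale / 2)
print(shrunk.dtype)  # float32

try:
    jax.lax.select(grads_finite, loss_scale, shrunk)
except TypeError as e:
    print(e)  # same "lax.select requires arguments to have the same dtypes" error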

@ekuznetsov139

What is needed is to construct the scale as jmp.DynamicLossScale(jnp.float32(2**15)) and to change loss_scale.py:132 to
return jax.tree_util.tree_map(lambda x: (x * self.loss_scale).astype(x.dtype), tree)
This way the gradients are still computed in float16, and loss_scale.loss_scale won't overflow after the first 2000 steps (if it were kept in float16, jmp would increase it to 2**16, which is outside the legal range of float16).
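
Roughly, a runnable sketch of the combination being suggested; only the tree_map expression is from the lines above, the standalone wrapper and its names are illustrative:

import jax
import jax.numpy as jnp

# Hypothetical standalone version of the patched scale(): the tree_map line is
# the proposed change, the wrapper function is just for illustration.
def scale_preserving_dtype(loss_scale, tree):
    return jax.tree_util.tree_map(
        lambda x: (x * loss_scale).astype(x.dtype), tree)

loss_scale = jnp.float32(2**15)          # keep the scale itself in float32
grads = {"w": jnp.ones(5, jnp.float16)}  # half-precision gradients
scaled = scale_preserving_dtype(loss_scale, grads)
print(scaled["w"].dtype)                 # float16: gradients stay in half precision

# Why a float16 scale overflows: float16 tops out at 65504, so doubling 2**15
# after the first period of finite steps would give inf.
print(jnp.finfo(jnp.float16).max)        # 65504.0
print(jnp.float16(2**15) * 2)            # inf

That way the scale (and the counter) live in float32, while the gradient tree itself never leaves float16.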

What really puzzles me is that this is the only JAX mixed-precision package that comes up in searches, and it is evidently not just dead: it has been broken for months and no one seems to care. That raises two possibilities:

  1. Does everyone use jax to train their models strictly in float32 or bf16?
  2. Does no one use jax any more?

@joeryjoery
Author

I'll try that with a fork of this repo when I have time; thanks for the suggestion!

I think JAX is growing in popularity though, haha ;p. That said, these open-source projects might not be a DeepMind priority.
