MultiSteps OOM on TPU #320 (Closed)

agemagician opened this issue Mar 22, 2022 · 6 comments

agemagician commented Mar 22, 2022

Hello,

I am testing "MultiSteps" for gradient accumulation on Colab TPU V2, but every time I use anything above 1, I get an OOM.

It seems that it increases the memory requirements equivalently to increase the batch size, which should not be the case.

My understanding is that all I need to do is to use it, is to call it after the optimizer and increase the training batch size by multiplying it by the gradient accumulation steps.
My understanding is also that I should be able to increase the batch size by using gradient accumulation.
In my case, if the per_device_train_batch_size is 8, then I can only set gradient_accumulation_steps to 1, and if the per_device_train_batch_size is 2, then I can then increase the gradient_accumulation_steps to 4.

train_batch_size = (
    int(training_args.per_device_train_batch_size)
    * jax.device_count()
    * training_args.gradient_accumulation_steps
)
optimizer = optax.adamw(
    learning_rate=learning_schedule_fn,
    b1=training_args.optim_beta1,
    b2=training_args.optim_beta2,
    weight_decay=training_args.optim_weight_decay,
    mask=decay_mask_fn,
)

if training_args.gradient_accumulation_steps > 1:
    optimizer = optax.MultiSteps(
        optimizer, training_args.gradient_accumulation_steps
    )

The rest of the code should be the same.

Is my understanding correct, or is there something else we should take care of when using MultiSteps?

mkunesch (Member) commented Mar 24, 2022

Hi!

I think the problem is that you are increasing the training batch size - this isn't necessary. The wrapper counts the batches as you pass them in, returns all-zero updates whenever the count hasn't reached k, and returns the accumulated gradients at the k-th step. This means there shouldn't be any extra memory requirement.
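
To illustrate, here is a minimal sketch of that behaviour (a toy SGD optimizer and made-up gradients, not code from your setup):

import jax.numpy as jnp
import optax

params = {"w": jnp.ones((3,))}
# Accumulate over k = 2 mini-steps before applying an update.
opt = optax.MultiSteps(optax.sgd(learning_rate=0.1), every_k_schedule=2)
state = opt.init(params)

for step in range(4):
    grads = {"w": jnp.ones((3,))}  # every mini-batch has the same physical size
    updates, state = opt.update(grads, state, params)
    # updates are all zeros on calls 0 and 2, and the accumulated
    # (averaged) gradient step on calls 1 and 3.
    print(step, updates["w"])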

Is this something that's confusing in the documentation? If so, is there something we could change to make it clearer?

Thanks a lot!

agemagician (Author) commented Mar 24, 2022

Thanks @mkunesch for your reply.

Please correct me if I am mistaken.

From my experience with PyTorch + GPU + gradient accumulation, I can increase the accumulation steps without causing GPU memory usage to increase and hitting an OOM.
For example, if the actual maximum GPU batch size is 16, then I can set the gradient accumulation to any number, like 2 or 32; the gradients are accumulated over every accumulation step, which virtually increases the global batch size.

In the case of MultiSteps, if the actual maximum TPU batch size is 16, then I can only make MultiSteps work with 1 gradient accumulation step. If I want to increase the accumulation steps, I have to decrease the actual TPU batch size. For example, to set the gradient accumulation steps to 2, I must decrease the TPU batch size to 8. This means the virtual global batch size always stays the same.

My understanding is that I should be able to increase the gradient accumulation steps in MultiSteps without having to decrease the actual TPU batch size.

Please correct me if my understanding is wrong, or, if my understanding is correct, whether there is a bug that causes the TPU to OOM.

mkunesch (Member) commented Mar 24, 2022

Hi! Yes, you should totally be able to increase the accumulation steps in optax.MultiSteps without causing the memory usage to increase - but you are also increasing the batch size in your code. This isn't how the optax implementation works: the wrapper accumulates results from k subsequent batches that you pass in, and the batches should have the same size independent of k.

So the following line in your code shouldn't be there:

        * training_args.gradient_accumulation_steps
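
To make the arithmetic concrete with some hypothetical numbers (not taken from your setup):

devices = 8
per_device = 8
k = 4  # gradient_accumulation_steps

# With the extra multiplier, every forward pass sees a batch k times
# larger, which is what causes the OOM:
oom_batch = per_device * devices * k   # 256 samples per forward pass

# Without the multiplier, the physical batch stays fixed, and MultiSteps
# still yields the same effective batch size across k calls:
fixed_batch = per_device * devices     # 64 samples per forward pass
effective_batch = fixed_batch * k      # 256 samples per optimizer update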

Does that make sense or am I misunderstanding the problem?

@agemagician (Author)

Yes, it makes a lot of sense now.

So the train_batch_size should always be fixed and independent of the gradient_accumulation_steps.

The correct code should be:

train_batch_size = (
    int(training_args.per_device_train_batch_size)
    * jax.device_count()
)
optimizer = optax.adamw(
    learning_rate=learning_schedule_fn,
    b1=training_args.optim_beta1,
    b2=training_args.optim_beta2,
    weight_decay=training_args.optim_weight_decay,
    mask=decay_mask_fn,
)

if training_args.gradient_accumulation_steps > 1:
    optimizer = optax.MultiSteps(
        optimizer, training_args.gradient_accumulation_steps
    )

MultiSteps and the optimizer expect to receive batches of the actual physical batch size, and the inner optimizer's update is not applied until the step count reaches gradient_accumulation_steps.
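
To see the whole flow end to end, here is a minimal runnable sketch (the toy linear model, shapes, and learning rate are my own illustrative assumptions, not from the original training code):

import jax
import jax.numpy as jnp
import optax

def loss_fn(params, batch):
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)

params = {"w": jnp.zeros((4,))}
# Accumulate gradients over 4 physical batches per optimizer update.
optimizer = optax.MultiSteps(optax.adamw(1e-3), every_k_schedule=4)
opt_state = optimizer.init(params)

@jax.jit
def train_step(params, opt_state, batch):
    grads = jax.grad(loss_fn)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state

for _ in range(8):
    # Each batch has the fixed physical size (16 here); the params only
    # actually change on every 4th call.
    batch = {"x": jnp.ones((16, 4)), "y": jnp.zeros((16,))}
    params, opt_state = train_step(params, opt_state, batch)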

It is very clear now.

It would be great if you could explain that in the documentation.

Thanks a lot for your explanation.

@agemagician (Author)

My suggestion is:

An optimiser wrapper to spread gradient computation over multiple steps.

This wrapper will allow multiple mini-steps to accumulate their gradients together before applying them. It wraps another optimiser, and makes sure that this optimiser updates its state only when enough mini-steps have been performed. At any other mini-step, the inner optimiser is not used and the updates returned by the wrapper are all 0.

The number of mini-steps per gradient update is controlled by a function, and it can vary over training. This offers a means of varying the effective batch size over training.

[New] ====> Important: the training batch size should not be changed in the training code when using MultiSteps. MultiSteps and the wrapped optimizer expect to receive the physical batch size; the accumulated gradient is then applied every k-th step.

Example:

train_batch_size = (
    per_device_train_batch_size
    * jax.device_count()
)
optimizer = optax.adamw(learning_rate=0.01)

if gradient_accumulation_steps > 1:
    optimizer = optax.MultiSteps(
        optimizer, gradient_accumulation_steps
    )

Maybe something like that.
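
Since the docstring mentions that the number of mini-steps per update can be controlled by a function, a small sketch of that variant might also help (the schedule values here are made up for illustration):

import jax.numpy as jnp
import optax

# Accumulate 2 mini-steps per update for the first 1000 updates,
# then 8 mini-steps per update afterwards.
def every_k_schedule(step):
    return jnp.where(step < 1000, 2, 8)

optimizer = optax.MultiSteps(optax.adamw(1e-3), every_k_schedule)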

@agemagician (Author)

My initial assumption was that MultiSteps would take the virtual batch size and perform a loop over the mini-batches internally.

Thanks again for your excellent explanation.
