Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use torch sdpa implementation in ASR mha #9590

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

WoodieDudy
Copy link

@WoodieDudy WoodieDudy commented Jul 2, 2024

Hola. I changed the mha implementation for the ASR modules so that it uses torch.nn.functional.scaled_dot_product_attention.
This accelerated forward in the mha by 27% and backward by 17% on the A100.
Pytorch sdpa is continuously being optimized, ensuring that we benefit from the latest performance improvements.
My code uses memory efficient backend in sdpa because flash attention doesn't support custom attention bias. There is ongoing work to contribute custom bias support in the flash-attention repository. PR.

What else do I need to do to merge this pr?

Usage

There is also my benchmark:

import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark
from nemo.collections.asr.parts.submodules.multi_head_attention import RelPositionMultiHeadAttention

torch.manual_seed(123)

device = "cuda"
batch_size = 32
seq_len = 1024
d_model = 512
n_head = 8

query = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)
key = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)
value = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)
mask = torch.ones(batch_size, seq_len, seq_len, device=device, requires_grad=False)
mask = torch.triu(mask, diagonal=1).bool()
mask = None
pos_emb = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)

attention_sdpa = RelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None, use_sdpa=True).to(device)
attention_original = RelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None).to(device)
for original_param, sdpa_param in zip(attention_original.parameters(), attention_sdpa.parameters()):
    original_param.data.copy_(sdpa_param.data)

# attention_sdpa = torch.compile(attention_sdpa)
# attention_original = torch.compile(attention_original)


def measure_time(attention, query, key, value, mask, pos_emb):
    timer = benchmark.Timer(
        stmt='attention(query, key, value, mask, pos_emb);torch.cuda.synchronize()',
        setup='torch.cuda.synchronize()',
        globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask, 'pos_emb': pos_emb}
    )

    with torch.no_grad():
        torch.cuda.synchronize()
        results = timer.blocked_autorange(min_run_time=10)
        forward_time = results.mean
        output = attention(query, key, value, mask, pos_emb)
    return forward_time, output


def measure_fwd_bwd_time(attention, query, key, value, mask, pos_emb):
    timer = benchmark.Timer(
        stmt='loss=attention(query, key, value, mask, pos_emb).sum();torch.cuda.synchronize();loss.backward();torch.cuda.synchronize()',
        globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask, 'pos_emb': pos_emb}
    )
    torch.cuda.synchronize()
    results = timer.blocked_autorange(min_run_time=10)
    fwd_bwd_time = results.mean
    return fwd_bwd_time


time_fwd_original, output_original = measure_time(attention_original, query, key, value, mask, pos_emb)
time_fwd_sdpa, output_sdpa = measure_time(attention_sdpa, query, key, value, mask, pos_emb)

print(f"Original implementation time: {time_fwd_original:.6f} seconds")
print(f"SDPA implementation time: {time_fwd_sdpa:.6f} seconds")
print(f"SDPA boost {(time_fwd_original - time_fwd_sdpa) / time_fwd_original * 100:.2f}%")

time_fwd_bwd_original = measure_fwd_bwd_time(attention_original, query, key, value, mask, pos_emb)
time_fwd_bwd_sdpa = measure_fwd_bwd_time(attention_sdpa, query, key, value, mask, pos_emb)
time_bwd_original = time_fwd_bwd_original - time_fwd_original
time_bwd_sdpa = time_fwd_bwd_sdpa - time_fwd_sdpa

print(f"Original implementation backward time: {time_bwd_original:.6f} seconds")
print(f"SDPA implementation backward time: {time_bwd_sdpa:.6f} seconds")
print(f"SDPA backward boost {(time_bwd_original - time_bwd_sdpa) / time_bwd_original * 100:.2f}%")

print(f"Outputs are {'the same' if torch.allclose(output_original, output_sdpa, atol=1e-5) else 'different'}")

# Original implementation time: 0.049075 seconds
# SDPA implementation time: 0.035598 seconds
# SDPA boost 27.46%
# Original implementation backward time: 0.127004 seconds
# SDPA implementation backward time: 0.104986 seconds
# SDPA backward boost 17.34%
# Outputs are the same

PR Type:

  • New Feature
  • Bugfix
  • Documentation

Who can review?

cc @titu1994 @SeanNaren

Additional Information

@github-actions github-actions bot added the ASR label Jul 2, 2024
@WoodieDudy
Copy link
Author

I also attempted to run the tests in the repository but encountered an issue. NaNs appear when a mask with a fully False row is passed to MHA. Because of such mask, filling the matrix_bd with -inf values using matrix_bd.masked_fill_(mask.logical_not(), float("-inf")) results in a row of only -inf, and after the softmax, this entire row becomes NaNs. I am unsure how to resolve this since the softmax and multiplication by value occur within torch.nn.functional.scaled_dot_product_attention, and I cannot intervene. In your implementation, this is handled by manually filling with zeros attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0).

What can be done about this? Should we write a separate test that doesn't use such masks as input?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant