Use torch sdpa implementation in ASR mha #9590

Draft · wants to merge 6 commits into base: main
Changes from all commits
47 changes: 37 additions & 10 deletions nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -57,10 +57,12 @@ class MultiHeadAttention(nn.Module):
         dropout_rate (float): dropout rate
     """
 
-    def __init__(self, n_head, n_feat, dropout_rate, max_cache_len=0):
+    def __init__(self, n_head, n_feat, dropout_rate, max_cache_len=0, use_sdpa=False):
         """Construct an MultiHeadedAttention object."""
         super(MultiHeadAttention, self).__init__()
+        self.use_sdpa = use_sdpa
         self.cache_drop_size = None
+        self.dropout_rate = dropout_rate
         assert n_feat % n_head == 0
         # We assume d_v always equals d_k
         self.d_k = n_feat // n_head
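
A minimal usage sketch of the new flag (the sizes and the self-attention call below are illustrative assumptions, not taken from the PR): constructing the module with use_sdpa=True routes the forward pass through torch.nn.functional.scaled_dot_product_attention, while the default False keeps the existing matmul/softmax path.

```python
import torch

from nemo.collections.asr.parts.submodules.multi_head_attention import MultiHeadAttention

# Hypothetical sizes; the PR does not prescribe any particular configuration.
mha = MultiHeadAttention(n_head=8, n_feat=512, dropout_rate=0.1, use_sdpa=True)

x = torch.randn(2, 50, 512)  # (batch, time, d_model)
out = mha(query=x, key=x, value=x, mask=None)  # self-attention, no padding mask
print(out.shape)  # torch.Size([2, 50, 512])
```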
@@ -139,8 +141,21 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None):
         # temporary until we solve this more gracefully
         with avoid_float16_autocast_context():
             q, k, v = self.forward_qkv(query, key, value)
-            scores = torch.matmul(q, k.transpose(-2, -1)) / self.s_d_k
-            out = self.forward_attention(v, scores, mask)
+
+            if self.use_sdpa:
+                scale = 1 / self.s_d_k
+                n_batch = value.size(0)
+
+                if mask is not None:
+                    mask = mask.unsqueeze(1)
+
+                out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.dropout_rate, scale=scale)
+                out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)
+                out = self.linear_out(out)  # (batch, time1, d_model)
+            else:
+                scores = torch.matmul(q, k.transpose(-2, -1)) / self.s_d_k
+                out = self.forward_attention(v, scores, mask)
 
         if cache is None:
             return out
         else:
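
The use_sdpa branch above is intended to be numerically equivalent to the manual scores/softmax path it replaces, provided no dropout is applied (with dropout_p > 0 the two paths draw different random masks). A self-contained sketch of that equivalence, independent of NeMo; note that the explicit scale= argument is only available in recent PyTorch releases (2.1 and later):

```python
import torch
import torch.nn.functional as F

# Minimal sketch: SDPA with an explicit scale should match the manual
# matmul/softmax path. Dropout stays at 0 so the comparison is exact.
torch.manual_seed(0)
batch, heads, time1, d_k = 2, 4, 16, 32
q = torch.randn(batch, heads, time1, d_k)
k = torch.randn(batch, heads, time1, d_k)
v = torch.randn(batch, heads, time1, d_k)

scale = 1 / d_k ** 0.5  # plays the role of 1 / self.s_d_k

# Manual path (what the removed scores / forward_attention code computes without a mask).
scores = torch.matmul(q, k.transpose(-2, -1)) * scale
manual = torch.matmul(torch.softmax(scores, dim=-1), v)

# Fused path, as in the use_sdpa branch.
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, scale=scale)

print(torch.allclose(manual, fused, atol=1e-6))  # expected: True
```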
@@ -163,9 +178,9 @@ class RelPositionMultiHeadAttention(MultiHeadAttention):
         dropout_rate (float): dropout rate
     """
 
-    def __init__(self, n_head, n_feat, dropout_rate, pos_bias_u, pos_bias_v, max_cache_len=0):
+    def __init__(self, n_head, n_feat, dropout_rate, pos_bias_u, pos_bias_v, max_cache_len=0, use_sdpa=False):
         """Construct an RelPositionMultiHeadedAttention object."""
-        super().__init__(n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, max_cache_len=max_cache_len)
+        super().__init__(n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, max_cache_len=max_cache_len, use_sdpa=use_sdpa)
         # linear transformation for positional encoding
         self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
         # these two learnable biases are used in matrix c and matrix d
@@ -219,6 +234,7 @@ def forward(self, query, key, value, mask, pos_emb, cache=None):
         q = q.transpose(1, 2)  # (batch, time1, head, d_k)
 
         n_batch_pos = pos_emb.size(0)
+        n_batch = value.size(0)
         p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
         p = p.transpose(1, 2)  # (batch, head, time1, d_k)
@@ -231,18 +247,29 @@ def forward(self, query, key, value, mask, pos_emb, cache=None):
         # first compute matrix a and matrix c
         # as described in https://arxiv.org/abs/1901.02860 Section 3.3
         # (batch, head, time1, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
 
         # compute matrix b and matrix d
         # (batch, head, time1, time2)
         matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
         matrix_bd = self.rel_shift(matrix_bd)
-        # drops extra elements in the matrix_bd to match the matrix_ac's size
-        matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)]
 
-        scores = (matrix_ac + matrix_bd) / self.s_d_k  # (batch, head, time1, time2)
+        if self.use_sdpa:
+            scale_factor = 1 / math.sqrt(q_with_bias_u.size(-1))
+            matrix_bd = matrix_bd[:, :, :, : k.size(-2)] * scale_factor
 
-        out = self.forward_attention(v, scores, mask)
+            if mask is not None:
+                mask = mask.unsqueeze(1)
+                matrix_bd.masked_fill_(mask.logical_not(), float("-inf"))
+
+            out = torch.nn.functional.scaled_dot_product_attention(q_with_bias_u, k, v, attn_mask=matrix_bd, dropout_p=self.dropout_rate, scale=scale_factor)
+            out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)
+            out = self.linear_out(out)  # (batch, time1, d_model)
+        else:
+            # drops extra elements in the matrix_bd to match the matrix_ac's size
+            matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+            matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)]
+            scores = (matrix_ac + matrix_bd) / self.s_d_k  # (batch, head, time1, time2)
+            out = self.forward_attention(v, scores, mask)
 
         if cache is None:
             return out
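
In the relative-position variant above, the positional term matrix_bd cannot be folded into the q @ kᵀ product, so the use_sdpa branch pre-scales it by scale_factor and hands it to SDPA as an additive float attn_mask (padding positions are handled by writing -inf into that same tensor). The identity this relies on is that SDPA computes softmax(q @ kᵀ · scale + attn_mask) @ v. A sketch of that identity, outside NeMo and without dropout or padding:

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, time1, d_k = 2, 4, 16, 32

q_with_bias_u = torch.randn(batch, heads, time1, d_k)
k = torch.randn(batch, heads, time1, d_k)
v = torch.randn(batch, heads, time1, d_k)
matrix_bd = torch.randn(batch, heads, time1, time1)  # stand-in for the shifted rel-pos term

scale_factor = 1 / math.sqrt(d_k)

# Eager path, as in the else branch of the diff.
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
scores = (matrix_ac + matrix_bd) * scale_factor  # same as dividing by sqrt(d_k)
reference = torch.matmul(torch.softmax(scores, dim=-1), v)

# SDPA path: the pre-scaled matrix_bd rides along as an additive float mask.
fused = F.scaled_dot_product_attention(
    q_with_bias_u, k, v, attn_mask=matrix_bd * scale_factor, dropout_p=0.0, scale=scale_factor
)

print(torch.allclose(reference, fused, atol=1e-6))  # expected: True
```

Because matrix_bd is multiplied by scale_factor before being added, the combined logits match (matrix_ac + matrix_bd) / sqrt(d_k) from the eager path.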