
Will attention_mask be extended to 3D? (concatenate short samples for efficient training) #432

Closed
sz128 opened this issue Aug 7, 2023 · 19 comments

Comments

@sz128
Contributor

sz128 commented Aug 7, 2023

It seems unpad_input only supports a 2-D attention_mask matrix, while it would also be meaningful to support a 3-D attention_mask matrix (batch_size x seq_len x seq_len).

There are several scenarios where we would like to use a 3-D attention_mask:

  • graph attention;
  • concatenating short samples to reduce the padding rate, which makes training more efficient (see the sketch below). [supported now]
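
For illustration only (this is not an existing API in this repo, and the helper names and shapes are my own), a 3-D mask for one batch row that packs two short samples could look like this:

    import torch

    # Hypothetical sketch: one batch row of seq_len 5 packs two short samples of
    # lengths 3 and 2. Position i may attend to position j only when both belong
    # to the same packed sample, giving a block-diagonal 3-D mask.
    sample_lengths = [3, 2]
    sample_id = torch.repeat_interleave(
        torch.arange(len(sample_lengths)), torch.tensor(sample_lengths)
    )  # tensor([0, 0, 0, 1, 1])
    mask_3d = (sample_id[:, None] == sample_id[None, :]).unsqueeze(0)  # [1, 5, 5]
    print(mask_3d.int())
    # -> a 3x3 block of ones followed by a 2x2 block of ones on the diagonal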
@tridao
Contributor

tridao commented Aug 7, 2023

Cool, thanks for the context. Attention mask will eventually be there (hopefully) but we have other priorities right now.
Btw we already support concatenating short samples to avoid padding (e.g. the "varlen" functions, see our BERT implementation).
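
For readers landing here, a minimal sketch of that "varlen" route (shapes, dtypes and sequence lengths are illustrative assumptions, and a CUDA device is required):

    import torch
    from flash_attn import flash_attn_varlen_func

    # Two samples of lengths 3 and 5 are packed into one (total_tokens, nheads,
    # headdim) tensor; cu_seqlens holds the cumulative boundaries so attention
    # never crosses sample borders and no padding tokens are computed.
    nheads, headdim = 8, 64
    seqlens = [3, 5]
    total = sum(seqlens)
    q = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    cu_seqlens = torch.tensor([0, 3, 8], device="cuda", dtype=torch.int32)
    out = flash_attn_varlen_func(
        q, k, v, cu_seqlens, cu_seqlens, max(seqlens), max(seqlens), causal=True
    )  # (total, nheads, headdim)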

@sz128
Contributor Author

sz128 commented Aug 8, 2023

Glad to hear that. For concatenating short samples, I would like the real sequences in a batch to have similar lengths, since training efficiency is dominated by the max length in the batch. I drew a figure to show the procedure.
[figure: packing short samples so that the sequences within a batch have similar lengths]
In my understanding, this is different from the "varlen" functions. Am I right?
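
A rough sketch of that packing idea, not taken from this repo (plain first-fit-decreasing over sample lengths; the function name is mine):

    # Sort samples by length, then greedily place each into the first row that
    # still has room, so every row stays close to target_len and the per-batch
    # max length (and hence padding) stays small.
    def pack_by_length(lengths, target_len):
        order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
        rows, row_lens = [], []
        for i in order:
            for r, used in enumerate(row_lens):
                if used + lengths[i] <= target_len:
                    rows[r].append(i)
                    row_lens[r] += lengths[i]
                    break
            else:
                rows.append([i])
                row_lens.append(lengths[i])
        return rows  # each row lists the sample indices to concatenate

    print(pack_by_length([7, 2, 5, 3, 6, 1], target_len=8))  # [[0, 5], [4, 1], [2, 3]]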

@tridao
Contributor

tridao commented Aug 8, 2023

I see, that's different from varlen yeah.

@lumosity4tpj

In the Triton code, it seems we can use the bias to represent the attention mask if the attention mask we pass in is already a concatenated mask.

@lumosity4tpj

> In the Triton code, it seems we can use the bias to represent the attention mask if the attention mask we pass in is already a concatenated mask.

That is, when calling the flash_attn_func function, the attention mask can be passed in as the bias (with the same shape). There are two problematic cases for the concatenated attention mask:

  1. The attention mask of a block may consist entirely of values that should be masked, so qk becomes all -inf after qk = qk * softmax_scale + bias, producing NaN in the following computation;
  2. Due to the padded sequences, lse_i and m_i may be -inf, also producing NaN in the following computation.

Solution:
The workaround I tried was to replace -inf with the minimum value of float32 (torch.finfo(torch.float32).min = -3.40282e+38) when building the mask, so that instead of -inf - (-inf) = nan we get -inf - (min value) = -inf:

  1. For m_ij: p = tl.exp(qk - m_ij[:, None]), acc_o_scale = tl.exp(m_i - m_ij), l_i_new = tl.exp(lse_i - m_ij) + l_ij;
  2. For lse_i: o_scale = tl.exp(m_i - lse_i).

For the bwd_kernel, the lse_i saved from the forward pass is reused, so nothing needs to change there.
This is difficult to verify; I can only see that the loss trend stays consistent during training, but there is a small amount of deviation. Is my solution above correct?
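
A small numerical illustration of the issue in plain PyTorch (not the Triton kernel itself), under the assumption that the fully masked rows correspond to padding queries whose outputs are discarded anyway:

    import torch

    # With a fully masked block the running max equals -inf, and
    # exp(-inf - (-inf)) = exp(nan) poisons everything downstream.
    blk_inf = torch.full((4,), float("-inf"))
    print(torch.exp(blk_inf - blk_inf.max()))  # tensor([nan, nan, nan, nan])

    # Using the float32 minimum instead keeps the arithmetic finite: masked
    # entries still vanish against any real score, and a fully masked row
    # yields finite (if meaningless) values instead of NaN.
    neg = torch.finfo(torch.float32).min
    blk_min = torch.full((4,), neg)
    print(torch.exp(blk_min - blk_min.max()))  # tensor([1., 1., 1., 1.])

    mixed = torch.tensor([2.0, neg, neg, neg])
    print(torch.softmax(mixed, dim=0))  # ~[1, 0, 0, 0]: masked scores contribute nothing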
@tridao @sz128

@gavin1332

gavin1332 commented Aug 29, 2023

@sz128 Could we concatenate the batched examples into a single one so that it works together with the BlockDiagonalCausalMask, at the cost of extra computation on the padding tokens?
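
For reference, a hedged sketch of that alternative using xformers' memory-efficient attention with BlockDiagonalCausalMask (shapes and sequence lengths are assumptions; the whole batch is flattened into a single sequence):

    import torch
    from xformers.ops import fmha

    # Concatenate every sample of the batch into one long sequence (batch dim = 1)
    # and let the block-diagonal causal mask keep attention inside each sample.
    nheads, headdim = 8, 64
    seqlens = [3, 5, 4]
    total = sum(seqlens)
    q = torch.randn(1, total, nheads, headdim, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    attn_bias = fmha.BlockDiagonalCausalMask.from_seqlens(seqlens)
    out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
    # out: (1, total, nheads, headdim)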

@sz128
Contributor Author

sz128 commented Aug 29, 2023

@tridao @lumosity4tpj @gavin1332 Hi, guys. I found a method (#499, 4bdc7e1) to apply the mask for concatenated short samples, which is very simple and highly compatible with the current flash-attn implementation.
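
As far as I understand the PR (so take the exact convention as an assumption), the boundaries are described by an attention_mask_in_length tensor, where each row lists the lengths of the short samples packed into that row, zero-padded to seq_len:

    import torch

    # Batch of 2 rows with seq_len 10. Row 0 packs samples of lengths 3, 2 and 4
    # (one trailing padding token); row 1 packs samples of lengths 6 and 4.
    attention_mask_in_length = torch.tensor(
        [
            [3, 2, 4, 0, 0, 0, 0, 0, 0, 0],
            [6, 4, 0, 0, 0, 0, 0, 0, 0, 0],
        ]
    )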

@sz128 changed the title from "Will attention_mask be extended to 3D?" to "Will attention_mask be extended to 3D? (concatenate short samples for efficient training)" on Aug 29, 2023
@sz128
Contributor Author

sz128 commented Aug 29, 2023

> For concatenating short samples, I would like the real sequences in a batch to have similar lengths [...] In my understanding, this is different from the "varlen" functions. Am I right?

Now, a solution is provided in #499.

@gavin1332

@sz128 It seems your PR is missing a paired pad_input_for_concatenated_sequences function to recover the data?

@gavin1332

And since the short-sample-concatenated examples are already almost compact, it may be a waste of memory to unpad them into a new hidden_states tensor, especially in an LLM. As an alternative, we could let the padding tokens take part in the computation, so that we only need to reshape the tensor in place and reuse its memory.

@sz128
Contributor Author

sz128 commented Aug 30, 2023

> @sz128 It seems your PR is missing a paired pad_input_for_concatenated_sequences function to recover the data?

No. Just use pad_input.

@sz128
Contributor Author

sz128 commented Aug 30, 2023

> And since the short-sample-concatenated examples are already almost compact, it may be a waste of memory to unpad them into a new hidden_states tensor, especially in an LLM. As an alternative, we could let the padding tokens take part in the computation, so that we only need to reshape the tensor in place and reuse its memory.

Actually, the short samples are concatenated before being fed into the LLM. unpad_input_for_concatenated_sequences just plays the role of telling the flash-attn kernel where the boundaries between the different short samples are.

@sz128
Contributor Author

sz128 commented Aug 30, 2023

[figure]

For the usage of unpad_input_for_concatenated_sequences, this pseudocode may help:

    import torch
    from einops import rearrange
    from flash_attn import flash_attn_varlen_qkvpacked_func
    from flash_attn.bert_padding import pad_input, unpad_input_for_concatenated_sequences

    # query_states / key_states / value_states: [bsz, nh, q_len, hd]
    qkv = torch.stack(
        [query_states, key_states, value_states], dim=2
    )  # [bsz, nh, 3, q_len, hd]
    qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
    nheads = qkv.shape[-2]
    x = rearrange(qkv, "b s three h d -> b s (three h d)")
    # attention_mask_in_length tells the kernel the lengths of the short samples
    # packed into each row, so attention never crosses sample boundaries.
    x_unpad, indices, cu_q_lens, max_s = unpad_input_for_concatenated_sequences(
        x, attention_mask_in_length
    )
    x_unpad = rearrange(
        x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
    )
    output_unpad = flash_attn_varlen_qkvpacked_func(
        x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
    )
    # pad_input scatters the unpadded tokens back to [bsz, q_len, ...].
    output = rearrange(
        pad_input(
            rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
        ),
        "b s (h d) -> b s h d",
        h=nheads,
    )
    attn_output = rearrange(output, "b s h d -> b s (h d)")

@gavin1332


Thanks for your detailed explanation!

@sz128
Contributor Author

sz128 commented Aug 31, 2023

This issue is closed now. Concatenating short samples is now supported. If you are interested in a causal 3-D attention mask, #57 may help.

@clarence-lee-sheng

Hello! Does this solution also work for grouped/multi-query attention? I noticed in the code it is mentioned in flash_attn_varlen_qkvpacked_func that for multi-query and grouped-query attention (MQA/GQA) we should look at flash_attn_varlen_kvpacked_func and flash_attn_varlen_func instead.

https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L912C1-L913C64

@sz128
Contributor Author

sz128 commented Aug 7, 2024

It works for grouped/multi-query attention, since this solution just reshapes the hidden states before the flash-attn function to prevent self-attention between different samples and then rearranges the flash-attn outputs.

You can find a detailed implementation for Llama 3 with Hugging Face Transformers here: https://github.com/sz128/LLMs_implementations/blob/main/sample_mask_with_flash-attn-2.ipynb .
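
For completeness, a hedged sketch of the GQA/MQA case with flash_attn_varlen_func, where q keeps all heads while k/v carry fewer heads; the shapes, head counts, and sequence lengths are assumptions:

    import torch
    from flash_attn import flash_attn_varlen_func

    # The same cu_seqlens (sample boundaries of the packed batch) is passed for
    # queries and keys, so attention never crosses sample borders; nheads_q must
    # be a multiple of nheads_kv for GQA/MQA.
    nheads_q, nheads_kv, headdim = 32, 8, 128
    seqlens = [100, 60, 96]
    total = sum(seqlens)
    q = torch.randn(total, nheads_q, headdim, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(total, nheads_kv, headdim, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(total, nheads_kv, headdim, device="cuda", dtype=torch.bfloat16)
    cu_seqlens = torch.tensor([0, 100, 160, 256], device="cuda", dtype=torch.int32)
    out = flash_attn_varlen_func(
        q, k, v, cu_seqlens, cu_seqlens, max(seqlens), max(seqlens), causal=True
    )  # (total, nheads_q, headdim)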

@sz128
Contributor Author

sz128 commented Aug 19, 2024

I also want to recommend FlexAttention, which supports flexible attention masks (e.g. block-diagonal causal masks for concatenated samples) and achieves 85%~90% of FlashAttention-2's performance.
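
A hedged sketch of how the concatenated-samples case looks with FlexAttention (PyTorch 2.5+); the document_id layout and shapes are assumptions:

    import torch
    from torch.nn.attention.flex_attention import flex_attention, create_block_mask

    # document_id maps each packed position to its original sample; the mask keeps
    # attention causal and inside the same sample (block-diagonal causal).
    B, H, S, D = 1, 8, 256, 64
    document_id = torch.zeros(S, dtype=torch.long, device="cuda")
    document_id[128:] = 1  # two packed samples: [0, 128) and [128, 256)

    def doc_causal(b, h, q_idx, kv_idx):
        return (document_id[q_idx] == document_id[kv_idx]) & (q_idx >= kv_idx)

    block_mask = create_block_mask(doc_causal, B=None, H=None, Q_LEN=S, KV_LEN=S)
    q = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    # In practice, wrap flex_attention with torch.compile for speed.
    out = flex_attention(q, k, v, block_mask=block_mask)  # (B, H, S, D)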

@clarence-lee-sheng

Thank you for the suggestion, and also for your helpful response on adapting the masking with flash attention. FlexAttention is great work as well; I appreciate you sharing this.
