Will attention_mask be extended to 3D? (concatenate short samples for efficient training) #432
Comments
Cool, thanks for the context. Attention mask will eventually be there (hopefully) but we have other priorities right now.
I see, that's different from varlen, yeah.
In Triton's code, it seems that we can use the bias to represent an attention mask, if the attention mask we pass in is already a concatenated mask.
That is, when calling the flash_attn_func function, the attention mask plays the same role as the bias (and has the same shape), and there are two possible cases for the spliced attention mask:
Solution:
In the case of bwd_kernel, the lse_i from the forward pass is saved, so there is no need to change it.
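To make the bias idea concrete, here is a minimal sketch (my own illustration, not code from the thread) of building a block-diagonal additive bias for one packed row from the lengths of the concatenated samples, assuming the Triton flash_attn_func accepts a bias broadcastable to (batch, nheads, seqlen, seqlen):

```python
import torch

def block_diagonal_bias(seqlens, total_len, device="cuda", dtype=torch.float16):
    """Additive bias for one packed row: -inf outside each sample's block,
    causal (-inf above the diagonal) inside it. `seqlens` are the lengths of
    the concatenated short samples, e.g. [5, 3, 8] with sum <= total_len."""
    bias = torch.full((total_len, total_len), float("-inf"), device=device, dtype=dtype)
    offset = 0
    for n in seqlens:
        block = torch.zeros(n, n, device=device, dtype=dtype)
        # causal mask inside the sample's own block
        block.masked_fill_(torch.ones(n, n, device=device, dtype=torch.bool).triu(1), float("-inf"))
        bias[offset:offset + n, offset:offset + n] = block
        offset += n
    # expand to (1, 1, total_len, total_len) so it broadcasts over batch and heads
    return bias[None, None]
```

Whether rows that are entirely -inf (e.g. pure padding positions) are handled gracefully depends on the kernel, so padding rows may need extra care.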
@sz128 Could we concatenate the batched examples into only one so that it works together with the
@tridao @lumosity4tpj @gavin1332 Hi, guys. I found a method (#499, 4bdc7e1) to perform masking for concatenated short samples, which is very simple and highly compatible with the current
Now, a solution is provided in #499.
@sz128 It seems your PR is missing a paired
And since the short-samples-concatenated examples are already almost fully packed, it may be a waste of memory to unpad them into a new
No. Just use
Actually, the short samples are concatenated before feeding them into the LLM.
For the usage of
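For reference, here is a minimal sketch (shapes and values are illustrative, not from the thread) of feeding concatenated short samples to flash_attn_varlen_func via cu_seqlens, so that causal attention is computed independently inside each sample:

```python
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_func

# three short samples of lengths 5, 3 and 8, concatenated token-wise into one row
seqlens = torch.tensor([5, 3, 8], dtype=torch.int32, device="cuda")
cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))  # [0, 5, 8, 16]
total, nheads, headdim = int(seqlens.sum()), 8, 64

q = torch.randn(total, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# each token only attends to earlier tokens of the same sample
out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=int(seqlens.max()), max_seqlen_k=int(seqlens.max()),
    causal=True,
)  # (total, nheads, headdim)
```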
This issue is closed now. Concatenating short samples is supported now. If you are interested in a causal 3-D attention mask, #57 may help.
Hello! Does this solution also work for grouped/multi-query attention? I noticed the code mentions in flash_attn_varlen_qkvpacked_func that for multi-query and grouped-query attention (MQA/GQA), we should look at flash_attn_varlen_kvpacked_func and flash_attn_varlen_func instead.
It works for grouped/multi-query attention, since this solution just reshapes hidden states before the flash-attn function to prevent self-attention between different samples and then rearranges the flash-attn outputs. You can find its detailed implementation for Llama 3 in Hugging Face Transformers: https://github.com/sz128/LLMs_implementations/blob/main/sample_mask_with_flash-attn-2.ipynb .
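As a rough illustration of the GQA/MQA point (this is not the notebook's code, just the varlen call with fewer key/value heads than query heads, which flash_attn_varlen_func supports when the number of query heads is a multiple of the number of KV heads):

```python
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_func

nheads_q, nheads_kv, headdim = 32, 8, 64  # grouped-query attention: 4 query heads per KV head
seqlens = torch.tensor([7, 9], dtype=torch.int32, device="cuda")
cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
total = int(seqlens.sum())

q = torch.randn(total, nheads_q, headdim, dtype=torch.float16, device="cuda")
k = torch.randn(total, nheads_kv, headdim, dtype=torch.float16, device="cuda")
v = torch.randn_like(k)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=int(seqlens.max()), max_seqlen_k=int(seqlens.max()),
    causal=True,
)  # (total, nheads_q, headdim): sample masking and GQA compose without extra work
```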
I also want to recommend FlexAttention, which supports causal attention masks and achieves 85%~90% of FlashAttention-2's performance.
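A minimal sketch of what the FlexAttention version could look like (assuming PyTorch ≥ 2.5 with torch.nn.attention.flex_attention; the document_id layout and shapes are illustrative):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# per-position sample (document) ids of the packed sequence, e.g. samples of length 5, 3, 8
seq_len = 16
document_id = torch.tensor([0] * 5 + [1] * 3 + [2] * 8, device="cuda")

def causal_document_mask(b, h, q_idx, kv_idx):
    # attend only within the same sample, and only to non-future positions
    return (document_id[q_idx] == document_id[kv_idx]) & (q_idx >= kv_idx)

block_mask = create_block_mask(causal_document_mask, B=None, H=None,
                               Q_LEN=seq_len, KV_LEN=seq_len, device="cuda")

q = torch.randn(1, 8, seq_len, 64, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
out = flex_attention(q, k, v, block_mask=block_mask)  # (1, 8, seq_len, 64)
```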
Thank you for the suggestion, and also for your helpful response about adapting the masking with flash attention. FlexAttention is great work as well; I appreciate you sharing this info.
It seems unpad_input only supports a 2-D attention_mask matrix, while it would also be meaningful to support a 3-D attention_mask matrix (batch_size x seq_len x seq_len). There are several scenarios in which we would like to use a 3-D attention_mask:
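For context, here is a minimal sketch (illustrative shapes, not from the issue) of the 2-D usage that unpad_input supports today via flash_attn.bert_padding; depending on the flash-attn version it returns four or five values, so only the first four are unpacked:

```python
import torch
from flash_attn.bert_padding import unpad_input, pad_input

batch, seq_len, hidden = 2, 8, 64
hidden_states = torch.randn(batch, seq_len, hidden, dtype=torch.float16, device="cuda")
# 2-D padding mask: 1 for real tokens, 0 for padding
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 0, 0, 0, 0, 0]], device="cuda")

unpadded, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)[:4]
# ... run a flash_attn_varlen_* kernel on `unpadded` using cu_seqlens / max_seqlen ...

# scatter the (total_tokens, hidden) result back to the padded (batch, seq_len, hidden) layout
repadded = pad_input(unpadded, indices, batch, seq_len)
```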