
Optimize cache update. #151

Merged · 57 commits · Aug 6, 2024

Conversation

Collaborator

@wang2yn84 commented Jul 19, 2024

We used to insert the cache inside attention and then use the updated cache for the calculation. With the help of flash attention/ragged attention, we can delay the cache insertion to the end of each step. By switching to a left-aligned stacked cache, we minimize the data transfer to HBM and therefore improve performance. The decode step time dropped from 52ms to 42ms. The left-aligned cache also improves insert efficiency. Overall, benchmark performance is boosted by 15%.
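
For illustration, a minimal hedged sketch of the layout the description refers to (shapes and names are illustrative, not the repo's actual API): the cache is one left-aligned, stacked array with a leading layer axis, and the step's new k/v is written once at the end of the step instead of inside every attention call.

  import jax
  import jax.numpy as jnp

  # Hypothetical shapes: (layers, batch, kv_heads, cache_len, head_dim)
  layers, batch, heads, cache_len, head_dim = 2, 1, 4, 16, 8
  cache_k = jnp.zeros((layers, batch, heads, cache_len, head_dim), dtype=jnp.bfloat16)

  def insert_at_end_of_step(cache, new_k, pos):
      # new_k holds the current step's key for every layer at once:
      # (layers, batch, kv_heads, 1, head_dim). One left-aligned write per step,
      # rather than one write per layer inside attention.
      return jax.lax.dynamic_update_slice(cache, new_k, (0, 0, 0, pos, 0))

  new_k = jnp.ones((layers, batch, heads, 1, head_dim), dtype=jnp.bfloat16)
  cache_k = insert_at_end_of_step(cache_k, new_k, pos=5)  # this step's slot in the cache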

) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
"""Ragged multi query attention."""
with jax.named_scope("ragged_mqa"):
batch_size, num_heads, head_dim = q.shape
seq_len = k.shape[1]
batch_size, time, head_dim = q.shape
Collaborator

Any reason to change num_heads to time?

Collaborator Author

After vmap, the number-of-heads dimension is gone, so this is indeed the sequence-length dimension, which we can also call "time".
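
As a standalone sketch of that point (illustrative, not the repo's kernel): once jax.vmap maps a per-head function over the heads axis, the function body only ever sees (batch, time, head_dim).

  import jax
  import jax.numpy as jnp

  def per_head_attn(q, k, v):
      # q: (batch, time, head_dim); k, v: (batch, seq_len, head_dim) -- no heads axis here
      scores = jnp.einsum("btd,bsd->bts", q, k)
      return jnp.einsum("bts,bsd->btd", jax.nn.softmax(scores, axis=-1), v)

  mha = jax.vmap(per_head_attn, in_axes=(1, 1, 1), out_axes=1)  # map over the heads axis

  q = jnp.ones((2, 8, 1, 16))       # (batch, heads, time=1, head_dim)
  k = v = jnp.ones((2, 8, 64, 16))  # (batch, heads, seq_len, head_dim)
  out = mha(q, k, v)                # (2, 8, 1, 16)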

Collaborator

I feel "Time" is kind of misleading variable name here. Can we use q_seq_len instead of time?

If we are only using ragged attention in decode sate, do we need this query seq len as it always be 1?

seq_len = k.shape[-2]

stacked = False
if k.ndim == 5:
Collaborator

Can you share an example where k.ndim is 5 (with block quantization)?

Collaborator Author

I'm not sure why the block quantization matters. If the cache is stacked, it will have these 5 dimensions, layer, batch, number of heads, time, head dim, whether or not it's quantized.

Collaborator

Thanks for the clarification!

normalize_var: bool,
quantized: bool,
):
"""Pallas kernel for flash attention."""
Collaborator

Replace "flash" with "ragged"?

Collaborator Author

Thanks, updated!

def run():
q = q_ref[...].astype(jnp.float32)
k = k_ref[...].astype(jnp.float32)
v = v_ref[...].astype(jnp.float32)
Collaborator

Do we have to convert to fp32? Can we use bf16?

Collaborator Author

All the arithmetic operations only support f32, and it reports an error if forced to bf16. Confirmed the constraint with the XLA team: b/340263269 and b/341729764.

return layer_ref[0], b_next, i_next, 0
return b_next, i_next, 0

def kv_scale_index_map(b, i, layer_ref, start_ref, end_ref, *_):
Collaborator

Can you share why i_next is assigned to a different position in kv_index_map versus kv_scale_index_map?

Collaborator Author

i_next doesn't get a different value. It's in a different position because the scale has the shape (batch, 1, kv_length), and grid[1] applies to the last dimension here. That's why we put i_next in that dimension.
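
To make the layout difference concrete, here is an illustrative pair of index maps (hypothetical, simplified from the actual kernel): the same grid index i lands on whichever dimension is the sequence axis of that operand.

  # k/v blocks are laid out as (batch, block_of_seq, head_dim): i picks the sequence block.
  def kv_index_map(b, i, *_):
      return b, i, 0

  # the scale is (batch, 1, kv_length): the sequence axis is last, so i goes there.
  def kv_scale_index_map(b, i, *_):
      return b, 0, i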

Collaborator

Combined with precompute_ragged_block_indices, for a given decode step with start = jnp.asarray([11, 0, 10]), input_pos = jnp.asarray([15, 9, 8]), cache_len = 16, and block_size = 4, can you share what the expected kv index map is?

@@ -310,15 +586,78 @@ def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
return output


def flash_attention(
Collaborator

Flash attention uses blocked q, k, v to do tiled compute. Is this function just vanilla attention?

Collaborator Author

Flash attention's key capability is computing the local softmax blockwise, which is exactly what we are doing here. How to divide the blocks is up to the user; we leveraged this to split the attention calculation between the existing cache and the new cache. So this is indeed flash attention.

Collaborator

Is there any upper-level function that calls this local attention? If this function is only for each q_block, k_block and v_block in the for loop, should we rename it block_attention?

In general, flash attention needs to dynamically select the max and rescale the softmax. The code below, from the ragged_mqa... function, looks like flash attention:

  m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,))
  m_next = jnp.maximum(m_prev, m_curr)
  alpha = jnp.exp(m_prev - m_next)
  beta = jnp.exp(m_curr - m_next)
  l_next = alpha * l_prev + beta * l_curr
  l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next)

Correct me if I'm wrong.
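
For reference, a self-contained sketch of the two-block online-softmax combination both comments describe (standalone, illustrative names, not the kernel's exact code): each block produces an unnormalized output plus a running max and denominator, and the blocks are merged afterwards.

  import jax.numpy as jnp

  def block_stats(q, k, v):
      # Per-block statistics: unnormalized output o, running max m, denominator l.
      s = jnp.einsum("bhqd,bhkd->bhqk", q, k)
      m = jnp.max(s, axis=-1)
      p = jnp.exp(s - m[..., None])
      l = jnp.sum(p, axis=-1)
      o = jnp.einsum("bhqk,bhkd->bhqd", p, v)
      return o, m, l

  def merge_blocks(o1, m1, l1, o2, m2, l2):
      # Combine e.g. the "existing cache" block with the "new cache" block.
      m = jnp.maximum(m1, m2)
      alpha, beta = jnp.exp(m1 - m), jnp.exp(m2 - m)
      l = alpha * l1 + beta * l2
      o = alpha[..., None] * o1 + beta[..., None] * o2
      l_safe = jnp.where(l == 0.0, 1.0, l)
      return o / l_safe[..., None]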

self.input_pos,
)

def update(self, key, value, layer_id: int):
Collaborator

Great implementation! But in general, I feel the logic is too complex to maintain. Can we have different KVCacheGenerate classes to handle ring buffer, ragged attention, and stacked vs. not?

Collaborator Author

I was thinking about merging Int8KVCacheGenerate and KVCacheGenerate, because there is a lot of shared code. I can combine all 4 additional flags (lazy_cache_update, generate_cache_stacked, new_cache_stacked, flash_attention) into 1 to simplify the logic, since these flags only help my experimentation and should not be exposed to the user. Wdyt?

Collaborator

Thanks. My main concern is that the current code logic is too complex to read and maintain. The cache manager was a very straightforward implementation before, but right now the logic is very complex. Let's only keep the most optimized code in the repo.

required=False,
)
flags.DEFINE_bool(
"generate_cache_stacked",
Collaborator

What are the benefits of cache_stacked?

Collaborator Author

It reduces the DMA transfer time; minimizing the number of DMA transfers helps.

Collaborator Author

Also, XLA handles cache insertion for all the layers much more efficiently than the user iterating over the layer dimension.
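
As a hedged illustration of that point (assumed shapes, not the repo's code): compare a per-layer Python loop, which issues one update per layer, with a single update over a stacked cache.

  import jax
  import jax.numpy as jnp

  def insert_per_layer(caches, new_kvs, pos):
      # caches: a list of (batch, heads, cache_len, head_dim) arrays, one per layer.
      return [jax.lax.dynamic_update_slice(c, n, (0, 0, pos, 0))
              for c, n in zip(caches, new_kvs)]

  def insert_stacked(cache, new_kv, pos):
      # cache: (layers, batch, heads, cache_len, head_dim) -- one update covers every layer.
      return jax.lax.dynamic_update_slice(cache, new_kv, (0, 0, 0, pos, 0))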

"Whether to enable ring buffer",
required=False,
)
flags.DEFINE_bool(
"flash_attention",
Collaborator

Do you plan to enable flash_attention by itself without ragged attention?

Collaborator Author

No, ragged attention has better performance than flash attention. As I indicated in the description, it only takes effect in test mode, which means the user cannot directly enable it in interactive, offline, or server mode.

input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 4)),
output_specs=(qkv_pspec, (others_pspec, others_pspec)),
sharding_axis=self.shard_axis,
input_specs=(q_pspec, q_pspec, q_pspec, *([others_pspec] * 7)),
Collaborator

Correct me if I'm wrong: ragged_attention_new doesn't support generate_cache_stacked.

Collaborator Author

ragged_attention_new is for the new cache from the current step, which has a length of 1, so there is nothing to stack.

@FanhaiLu1
Collaborator

We used to insert the cache inside attention and then use the updated cache for the calculation. With the help of flash attention/ragged attention, we can delay the cache insertion to the end of each step. By switching to a left-aligned stacked cache, we minimize the data transfer to HBM and therefore improve performance. The decode step time dropped from 52ms to 42ms. Overall, benchmark performance is boosted by 15%.

we can delay the cache insertion to the end of each step

15% improvement is a great achievement! I assume the tests use the left-aligned stacked cache + ragged attention; do you have any performance numbers for left aligned (without stacking) + ragged attention?

@wang2yn84
Collaborator Author

We used to insert the cache inside attention and then use the updated cache for the calculation. With the help of flash attention/ragged attention, we can delay the cache insertion to the end of each step. By switching to a left-aligned stacked cache, we minimize the data transfer to HBM and therefore improve performance. The decode step time dropped from 52ms to 42ms. Overall, benchmark performance is boosted by 15%.

we can delay the cache insertion to the end of each step

15% improvement is a great achievement! I assume the tests use the left-aligned stacked cache + ragged attention; do you have any performance numbers for left aligned (without stacking) + ragged attention?


When the cache is left aligned but unstacked, the data transfer overhead is non-negligible. I tried flash attention, which takes 90ms per step. That overhead has nothing to do with which attention you are using.

@wang2yn84 closed this Jul 19, 2024
@wang2yn84 reopened this Jul 19, 2024
self.new_v_scaler,
]
(
self.cache_k._elem,
Collaborator

._elem seems spurious, as _elem is already a jax array.

So it's either x._elem = foo(jax_array_inputs) OR x = call_jax(foo, torch_tensor_inputs).

Collaborator Author

Makes sense; I decided to remove _elem since it violates lint anyway.

@@ -367,9 +367,25 @@ def apply_rotary_emb(
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)."""

bs, n_kv_heads, slen, head_dim = x.shape
bs, n_kv_heads, slen, head_dim = (
Collaborator

bs, n_kv_heads, slen, head_dim, *_ = x.shape

Collaborator Author

I see. Should it be *_, bs, n_kv_heads, slen, head_dim = x.shape?

x.shape[-2],
x.shape[-1],
)
if x.ndim == 5:
Collaborator

stacked = x.ndim == 5

Collaborator Author

Better! Thanks!

if n_rep == 1:
return x
if stacked:
Collaborator

Or just put the ndim == 5 check here and remove the stacked var.

Collaborator Author

I'd probably prefer to keep stacked to make the code clearer.
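
For concreteness, a hedged sketch of the variant being discussed (assumed to mirror the diff, not copied from it): repeat_kv accepts either the usual 4-D tensor or a stacked 5-D tensor with a leading layer axis.

  import torch

  def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
      """Repeat KV heads n_rep times; x is (bs, n_kv_heads, slen, head_dim) or stacked 5-D."""
      *_, bs, n_kv_heads, slen, head_dim = x.shape
      stacked = x.ndim == 5
      if n_rep == 1:
          return x
      if stacked:
          layers = x.shape[0]
          return (
              x[:, :, :, None, :, :]
              .expand(layers, bs, n_kv_heads, n_rep, slen, head_dim)
              .reshape(layers, bs, n_kv_heads * n_rep, slen, head_dim)
          )
      return (
          x[:, :, None, :, :]
          .expand(bs, n_kv_heads, n_rep, slen, head_dim)
          .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
      )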

@wang2yn84
Collaborator Author

Fixed based on your comments, and fixed all the unit test and lint errors. Please let me know if you have any other comments/suggestions. @qihqi @FanhaiLu1

@qihqi
Collaborator

qihqi commented Jul 20, 2024

There are some updates to deps/Jetstream; is that intentional?


seq_len = k.shape[-2]

stacked = False
if k.ndim == 4:
Collaborator

The vmap reduces the head dim, so the stacked ndim becomes 4 instead of 5. Correct me if I'm wrong.

I'm also wondering: do we need a vmap in ragged attention at all? The shard_map already did the first reduction, cutting the head dim from 32 to 4 (taking Llama2 7B on v5e-8 as an example). Can we process 4 heads in a single kernel? Is there a performance regression when using multiple heads in ragged attention compared with single-head attention?

Collaborator Author

It's not a must. By reducing the number-of-heads dimension to 1, the MHA becomes MQA. That's just for compatibility.
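
A hedged sketch of the setup the question describes (mesh/axis names and shapes are illustrative): shard_map splits the heads across devices, and vmap maps the per-head kernel over whatever heads remain local.

  import numpy as np
  import jax
  import jax.numpy as jnp
  from jax.experimental.shard_map import shard_map
  from jax.sharding import Mesh, PartitionSpec as P

  mesh = Mesh(np.array(jax.devices()), axis_names=("x",))

  def per_head(q, k, v):  # q: (batch, time, head_dim)
      s = jnp.einsum("btd,bsd->bts", q, k)
      return jnp.einsum("bts,bsd->btd", jax.nn.softmax(s, axis=-1), v)

  # Inside each shard, vmap handles the local heads (e.g. 32 heads / 8 devices = 4 per shard).
  local_mha = jax.vmap(per_head, in_axes=(1, 1, 1), out_axes=1)

  sharded_mha = shard_map(
      local_mha, mesh=mesh,
      in_specs=(P(None, "x"), P(None, "x"), P(None, "x")),
      out_specs=P(None, "x"),
  )

  q = jnp.ones((2, 8, 1, 16))       # (batch, heads, time, head_dim); heads divisible by device count
  k = v = jnp.ones((2, 8, 64, 16))
  out = sharded_mha(q, k, v)        # (2, 8, 1, 16)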


jnp.array([layer]),
start,
end,
end, # line_end, not actually used
Collaborator

Thanks for clarifying!



@FanhaiLu1
Collaborator

Fixed based on your comments, and fixed all the unit test and lint errors. Please let me know if you have any other comments/suggestions. @qihqi @FanhaiLu1

There are new lint errors, can you fix them?

@wang2yn84
Collaborator Author

Fixed based on your comments, and fixed all the unit test and lint errors. Please let me know if you have any other comments/suggestions. @qihqi @FanhaiLu1

There are new lint errors, can you fix them?

Fixed all the lint issues.

@wang2yn84
Collaborator Author

I will remove precompute_ragged_block_indices, clean up the ragged attention implementation (e.g. remove the one for the ring buffer), and simplify the flags for the non-ring-buffer case, thereby simplifying the cache manager, in a subsequent PR. I will push this PR first since it's been standing alone for a while.

@wang2yn84 merged commit ee040a4 into main Aug 6, 2024
3 of 4 checks passed