Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PaliGemma Support #7553

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft

Conversation

abetlen
Copy link
Collaborator

@abetlen abetlen commented May 27, 2024

Very much still a work in progress however I've been able to convert the weights and update the clip model to support the differences in the PaliGemma implementation of SigLIP (second projector mlp layer is missing).

The next missing piece is that PaliGemma evaluates the prefix part of the prompt (which includes image embeddings) using a fully-connected attention mask as opposed to causal attention (See HF Implementation). I haven't played around too much with llama_set_causal_attn function but I believe this may be sufficient, otherwise it will be necessary to update the API to specify a custom attention mask.

I've created a custom script to generate the f16 ggufs here I've opted to do this in a custom script as the current convert-hf-to-gguf.py is not suited for converting vlms at the moment.

@ggerganov
Copy link
Owner

The next missing piece is that PaliGemma evaluates the prefix part of the prompt (which includes image embeddings) using a fully-connected attention mask as opposed to causal attention

Hm yes, this is not currently supported. Need to figure out what changes would be necessary. Explicitly setting the mask through the API would be possible, but I think it would be too difficult to use. There are other options such as:

  • for each token in the batch, along with it's pos, we provide one more integer of how many "future" token it attends to (defaults to 0)
  • be able to set attention mask rules, such as "positions [p0, p1] are non-causal" and we apply those rules on top of the mask that we normally construct

Not sure yet what would be best

@mofosyne mofosyne added the Review Complexity : Low Trivial changes to code that most beginner devs (or those who want a break) can tackle. e.g. UI fix label May 27, 2024
@abetlen
Copy link
Collaborator Author

abetlen commented May 27, 2024

Explicitly setting the mask through the API would be possible, but I think it would be too difficult to use.

I'm partial to this if it's the most straightforward to implement just because it offers the most general API but I understand the concern.

Otherwise, I think option 1 is the better approach as it's centered around the batch api and would require minimal work to modify existing code.

@abetlen
Copy link
Collaborator Author

abetlen commented May 30, 2024

@ggerganov I'll see what I can come up with along those lines.

We could probably limit complexity by only allowing future attention to work within a batch, otherwise we would need multiple graph compute calls to update previously computed kv positions. I might be mistaken about that though.

@ggerganov
Copy link
Owner

We could probably limit complexity by only allowing future attention to work within a batch

Yes you are correct, future attention can only be done within the batch.

Not sure if option 1 would require less changes though. We have to remember for each token what is the set of tokens that it attends to. So even if we can provide this information for the current batch through llama_batch, then for the next batch we would still need it in order to construct the correct mask, so it would likely require to update llama_kv_cache to store it for example.

On the other hand option 2 might need to just maintain a container with ranges in the llama_context and blindly apply it on top of whatever mask has been constructed normally. It would be also much easier to remove in the future if we think of a better way to support this

@ggerganov
Copy link
Owner

cc @iamlemec, @compilade - bringing your attention to this discussion in case you get some additional ideas how to support this functionality

@iamlemec
Copy link
Collaborator

iamlemec commented Jun 6, 2024

I think that the choice between 1 and 2 might come down to whether the causal/non-causal block positions are generally fixed. It seems like with PaliGemma the image block is of fixed length determined by the image processing model. With something like GritLM the non-causal segments can be of varying size. Also, a 3rd slightly more general implementation would be to add a per-token [p0,p1] to the batch that specifies which positions this token attends to. Keeping in mind that the referenced positions must either be in the current batch or in the kv_cache. Somehow thinking in terms of positions seems more natural to me than future offsets.

That said, there might be a way to do it with the current codebase. Can we just evaluate the image tokens in a first batch with causal_attn=False, then evaluate the prompt in a subsequent batch with causal_attn=True? The first batch will populate the kv_cache appropriately.

@ggerganov
Copy link
Owner

ggerganov commented Jun 7, 2024

Can we just evaluate the image tokens in a first batch with causal_attn=False, then evaluate the prompt in a subsequent batch with causal_attn=True? The first batch will populate the kv_cache appropriately.

The problem I think is that during the second batch, the attention mask will be causal for the entire context, which would lead to incorrect masking of the tokens from the first batch, no?

@iamlemec
Copy link
Collaborator

iamlemec commented Jun 7, 2024

The problem I think is that during the second batch, the attention mask will be causal for the entire context, which would lead to incorrect masking of the tokens from the first batch, no?

But those KV's from the first block are already computed, cached, and wont be recomputed. So any given token in the second batch will be attending to every single token in the first batch, and the KV's they pick up will have been computed non-causally in the first batch execution.

@ggerganov
Copy link
Owner

ggerganov commented Jun 8, 2024

The KV data is cached yes, but the softmax computation in the attention during the new batch still needs the correct mask for the entire context. With a causal mask, the tokens from the new batch would correctly attend to all previous ones, but the previous tokens would not attend correctly to themselves - they still attend to each other via the softmax, even though they are from the previous batch

@iamlemec
Copy link
Collaborator

iamlemec commented Jun 8, 2024

Thanks for the reply @ggerganov. I'm still not 100% convinced, but it seems possible that I'm confused about the premise here. Partially for my own edification, and partially to clear up any ambiguity, I decided to write up a minimal working example in Python with transformers.

Here's a link to the Gist: https://gist.github.com/iamlemec/3febf59b41b7f32a450fcfcb4be0713c. I used RoBERTa because it's still relatively small and allows one to specify full 2D attention matrices, rather than just 1D attention masks like in base BERT. Anyway, those asserts pass! So that's potential validation?

@ggerganov
Copy link
Owner

@iamlemec I stand corrected. Thanks for this example and helping figuring this out!

So this is actually very good news - we should be able to support PaliGemma with the existing API, correct?

@abetlen
Copy link
Collaborator Author

abetlen commented Jun 10, 2024

@iamlemec @ggerganov sounds good, so if I understand correctly the approach would be to update the path for causal_attn == false in decode_internal to also populate the kv cache and then ensure the mask is handled appropriately when causal attention is re-enabled?

This may negatively impact performance of pure embedding models as they don't store anything in the cache afaik but maybe there's a way to distinguish the two scenarios.

@ggerganov
Copy link
Owner

@abetlen Yes, I think this should work

so if I understand correctly the approach would be to update the path for causal_attn == false in decode_internal to also populate the kv cache and then ensure the mask is handled appropriately when causal attention is re-enabled?

Hopefully no change would be needed. Note that we have two parameters:

  • hparams.causal_attn
  • cparams.causal_attn

The latter can be changed via llama_set_causal_attn, while the former is determined during loading the model. In general, hparams.causal_attn is false only for embedding models, so for PaliGemma it should be true and thus the KV cache in llama_decode_internal would be correctly updated

@iamlemec
Copy link
Collaborator

@abetlen @ggerganov Do we need to actually modify llama_decode_internal in llama.cpp for this? I would think that this would be best handled at the examples level, either as a generalization of llava or separately. Like, if this was llava, you should just be able to change (approx lines 186-188 in llava-cli.cpp)

eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, true);
llava_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past);
eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);

to

llama_set_causal_attn(ctx_llava->ctx_llama, true);
eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, true);
llama_set_causal_attn(ctx_llava->ctx_llama, false);
llava_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past);
llama_set_causal_attn(ctx_llava->ctx_llama, true);
eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);

I'm not quite sure how the system prompt works with PaliGemma, but basically just putting the appropriate llama_set_causal_attn before evals should be enough, right?

@ggerganov
Copy link
Owner

Yes, it should work without changes to llama_decode_internal, just as how you demonstrated

@abetlen
Copy link
Collaborator Author

abetlen commented Jun 12, 2024

That's the approach I was initially trying but it caused this assert to fail as the logits aren't reserved when cparams.causal_attn is false.

However I think I was just missing a one line change in llama_output_reserve I'll test it out.

@arseniybelkov
Copy link

Hello, I have cloned abetlen's work. I am trying to run the converting script on this model https://huggingface.co/google/paligemma-3b-pt-224/tree/main, but I keep getting the following error:

Traceback (most recent call last):
  File "/home/belkov.arseniy/paligemma/convert.py", line 305, in <module>
    special_vocab.add_to_gguf(fout)
  File "/home/belkov.arseniy/paligemma/llama.cpp/gguf-py/gguf/vocab.py", line 69, in add_to_gguf
    add_handler(value)
  File "/home/belkov.arseniy/paligemma/llama.cpp/gguf-py/gguf/gguf_writer.py", line 530, in add_add_bos_token
    self.add_bool(Keys.Tokenizer.ADD_BOS, value)
  File "/home/belkov.arseniy/paligemma/llama.cpp/gguf-py/gguf/gguf_writer.py", line 201, in add_bool
    self.add_key_value(key, val, GGUFValueType.BOOL)
  File "/home/belkov.arseniy/paligemma/llama.cpp/gguf-py/gguf/gguf_writer.py", line 166, in add_key_value
    raise ValueError(f'Duplicated key name {key!r}')
ValueError: Duplicated key name 'tokenizer.ggml.add_bos_token'

Can someone please help?

@abetlen abetlen marked this pull request as ready for review August 10, 2024 21:30
@abetlen abetlen marked this pull request as draft August 10, 2024 21:30
@The-TallGuy
Copy link

I don't get it, is the feature meant to have already been implemented or is it a work in progress?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
examples Review Complexity : Low Trivial changes to code that most beginner devs (or those who want a break) can tackle. e.g. UI fix
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants