convert_hf : fix Gemma v1 conversion #8597

Merged
merged 2 commits into master on Jul 21, 2024

Conversation

compilade
Collaborator

Should fix #7897 and #7923. Conversion for Gemma v1 instruct models was broken by #7827, because the Gemma conversion originally (and erroneously) relied on duplicated GGUF keys, which are no longer allowed since that PR.

I've also allowed renaming tokens, but with a warning. This allows converting Gemma finetunes like https://huggingface.co/Columbia-NLP/gemma-2b-zephyr-dpo which change the text of some control tokens.
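As a rough sketch of the idea (illustrative names only, not the actual diff), the change boils down to logging instead of asserting when a token's text differs from what the base vocab has:

import logging

logger = logging.getLogger("convert_hf")

def rename_token(tokens: list[str], token_id: int, new_text: str) -> None:
    # Sketch only: 'tokens' stands in for the vocab list built during conversion.
    if tokens[token_id] != new_text:
        # Previously this case failed an assertion; now it is allowed, but loudly.
        logger.warning(f"replacing token {token_id}: {tokens[token_id]!r} -> {new_text!r}")
    tokens[token_id] = new_text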

cc @bartowski1182, @maab19, @TobiasKlapper, since you've noticed that problem before.

* convert_hf : allow renaming tokens, but with a warning
compilade added the bugfix, python, and Review Complexity : Low labels Jul 19, 2024
@Galunid
Collaborator

Galunid commented Jul 20, 2024

@compilade Have you tried running the converted model? It outputs garbage for me.

$ ./llama-cli -m models-local/gemma-7b/ggml-model-Q4_K_M.gguf -p "Write me a poem about Machine Learning" -ngl 99 --ctx-size 128

Resulted in:

Write me a poem about Machine Learning Poetry Machine Poetry Machine Poetry Machine Poetry Machine Poetry Machine Poetry Machine Poetry Poetry Machine Poetry Poetry Machine Poetry Poetry Poetry Machine Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry

full log

main: build = 3422 (57b1d4f)
main: built with cc (GCC) 14.1.1 20240522 for x86_64-pc-linux-gnu
main: seed = 1721438773
llama_model_loader: loaded meta data with 27 key-value pairs and 254 tensors from models-local/gemma-7b/ggml-model-Q4_K_M.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv 0: general.architecture str = gemma
llama_model_loader: - kv 1: general.type str = model
llama_model_loader: - kv 2: general.name str = Gemma 7b
llama_model_loader: - kv 3: general.basename str = gemma
llama_model_loader: - kv 4: general.size_label str = 7B
llama_model_loader: - kv 5: general.license str = gemma
llama_model_loader: - kv 6: gemma.context_length u32 = 8192
llama_model_loader: - kv 7: gemma.embedding_length u32 = 3072
llama_model_loader: - kv 8: gemma.block_count u32 = 28
llama_model_loader: - kv 9: gemma.feed_forward_length u32 = 24576
llama_model_loader: - kv 10: gemma.attention.head_count u32 = 16
llama_model_loader: - kv 11: gemma.attention.head_count_kv u32 = 16
llama_model_loader: - kv 12: gemma.attention.layer_norm_rms_epsilon f32 = 0,000001
llama_model_loader: - kv 13: gemma.attention.key_length u32 = 256
llama_model_loader: - kv 14: gemma.attention.value_length u32 = 256
llama_model_loader: - kv 15: general.file_type u32 = 15
llama_model_loader: - kv 16: tokenizer.ggml.model str = llama
llama_model_loader: - kv 17: tokenizer.ggml.pre str = default
llama_model_loader: - kv 18: tokenizer.ggml.tokens arr[str,256000] = ["<pad>", "<eos>", "<bos>", "<unk>", ...
llama_model_loader: - kv 19: tokenizer.ggml.scores arr[f32,256000] = [-1000,000000, -1000,000000, -1000,00...
llama_model_loader: - kv 20: tokenizer.ggml.token_type arr[i32,256000] = [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ...
llama_model_loader: - kv 21: tokenizer.ggml.prefix_token_id u32 = 67
llama_model_loader: - kv 22: tokenizer.ggml.suffix_token_id u32 = 69
llama_model_loader: - kv 23: tokenizer.ggml.middle_token_id u32 = 68
llama_model_loader: - kv 24: tokenizer.ggml.eot_token_id u32 = 107
llama_model_loader: - kv 25: tokenizer.ggml.add_space_prefix bool = false
llama_model_loader: - kv 26: general.quantization_version u32 = 2
llama_model_loader: - type f32: 57 tensors
llama_model_loader: - type q4_K: 168 tensors
llama_model_loader: - type q6_K: 29 tensors
llm_load_vocab: special tokens cache size = 187
llm_load_vocab: token to piece cache size = 1,6014 MB
llm_load_print_meta: format = GGUF V3 (latest)
llm_load_print_meta: arch = gemma
llm_load_print_meta: vocab type = SPM
llm_load_print_meta: n_vocab = 256000
llm_load_print_meta: n_merges = 0
llm_load_print_meta: vocab_only = 0
llm_load_print_meta: n_ctx_train = 8192
llm_load_print_meta: n_embd = 3072
llm_load_print_meta: n_layer = 28
llm_load_print_meta: n_head = 16
llm_load_print_meta: n_head_kv = 16
llm_load_print_meta: n_rot = 256
llm_load_print_meta: n_swa = 0
llm_load_print_meta: n_embd_head_k = 256
llm_load_print_meta: n_embd_head_v = 256
llm_load_print_meta: n_gqa = 1
llm_load_print_meta: n_embd_k_gqa = 4096
llm_load_print_meta: n_embd_v_gqa = 4096
llm_load_print_meta: f_norm_eps = 0,0e+00
llm_load_print_meta: f_norm_rms_eps = 1,0e-06
llm_load_print_meta: f_clamp_kqv = 0,0e+00
llm_load_print_meta: f_max_alibi_bias = 0,0e+00
llm_load_print_meta: f_logit_scale = 0,0e+00
llm_load_print_meta: n_ff = 24576
llm_load_print_meta: n_expert = 0
llm_load_print_meta: n_expert_used = 0
llm_load_print_meta: causal attn = 1
llm_load_print_meta: pooling type = 0
llm_load_print_meta: rope type = 2
llm_load_print_meta: rope scaling = linear
llm_load_print_meta: freq_base_train = 10000,0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn = 8192
llm_load_print_meta: rope_finetuned = unknown
llm_load_print_meta: ssm_d_conv = 0
llm_load_print_meta: ssm_d_inner = 0
llm_load_print_meta: ssm_d_state = 0
llm_load_print_meta: ssm_dt_rank = 0
llm_load_print_meta: model type = 7B
llm_load_print_meta: model ftype = Q4_K - Medium
llm_load_print_meta: model params = 8,54 B
llm_load_print_meta: model size = 4,96 GiB (4,99 BPW)
llm_load_print_meta: general.name = Gemma 7b
llm_load_print_meta: BOS token = 1 '<eos>'
llm_load_print_meta: EOS token = 2 '<bos>'
llm_load_print_meta: UNK token = 0 ''
llm_load_print_meta: LF token = 227 '<0x0A>'
llm_load_print_meta: PRE token = 67 ''
llm_load_print_meta: SUF token = 69 ''
llm_load_print_meta: MID token = 68 ''
llm_load_print_meta: EOT token = 107 '<end_of_turn>'
llm_load_print_meta: max token length = 93
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 2060, compute capability 7.5, VMM: yes
llm_load_tensors: ggml ctx size = 0,24 MiB
llm_load_tensors: offloading 20 repeating layers to GPU
llm_load_tensors: offloaded 20/29 layers to GPU
llm_load_tensors: CPU buffer size = 5077,09 MiB
llm_load_tensors: CUDA0 buffer size = 3168,47 MiB
.............................................................................
llama_new_context_with_model: n_ctx = 1024
llama_new_context_with_model: n_batch = 1024
llama_new_context_with_model: n_ubatch = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base = 10000,0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: CUDA_Host KV buffer size = 128,00 MiB
llama_kv_cache_init: CUDA0 KV buffer size = 320,00 MiB
llama_new_context_with_model: KV self size = 448,00 MiB, K (f16): 224,00 MiB, V (f16): 224,00 MiB
llama_new_context_with_model: CUDA_Host output buffer size = 0,98 MiB
llama_new_context_with_model: CUDA0 compute buffer size = 1127,24 MiB
llama_new_context_with_model: CUDA_Host compute buffer size = 18,01 MiB
llama_new_context_with_model: graph nodes = 931
llama_new_context_with_model: graph splits = 92

system_info: n_threads = 6 / 12 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 |
sampling:
repeat_last_n = 64, repeat_penalty = 1,000, frequency_penalty = 0,000, presence_penalty = 0,000
top_k = 40, tfs_z = 1,000, top_p = 0,950, min_p = 0,050, typical_p = 1,000, temp = 0,800
mirostat = 0, mirostat_lr = 0,100, mirostat_ent = 5,000
sampling order:
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature
generate: n_ctx = 1024, n_batch = 2048, n_predict = -1, n_keep = 1

Write me a poem about Machine Learning Poetry Machine Poetry Machine Poetry Machine Poetry Machine Poetry Machine Poetry Machine Poetry Poetry Machine Poetry Poetry Machine Poetry Poetry Poetry Machine Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry Poetry

I tried bigger contexts and flash attention, but it made no difference. I had the same issue in #7992. The model with duplicated keys worked.

With master:

Write me a poem about Machine Learning, and I will reply with a haiku.

I will not reply with a haiku if:

  • you use racist or sexist language.
  • your poem is not in English.
  • your poem is not in the form of a haiku.
  • your poem does not include a specific machine learning term (e.g. "word2vec")
  • your poem contains a misspelling of a machine learning term (e.g. "wrod2vec" instead of "word2vec")
  • your poem contains more than 14 words.
  • your poem contains more than 4 lines.
  • your poem contains more than 7 syllables in total.
  • your poem does not include a haiku form (e.g. the third line does not have 5 syllables).
  • your poem does not rhyme.
  • your poem does not use haiku-specific poetic devices such as shiori or yose.

@mofosyne
Collaborator

mofosyne commented Jul 20, 2024

The concept makes sense, but we should check @Galunid's observation as well.

Would it make sense to error on certain warnings by default, but not if a --no-warning-error flag is set?

I'm not sure you can easily fix models that were previously converted incorrectly.
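A minimal sketch of the --no-warning-error suggestion above (the flag and the helper are hypothetical; neither exists in convert_hf_to_gguf.py):

import argparse
import logging

logger = logging.getLogger("convert_hf")

parser = argparse.ArgumentParser()
# Hypothetical flag: error on risky conversions by default, warn only when opted out.
parser.add_argument("--no-warning-error", action="store_true",
                    help="downgrade issues such as renamed control tokens from errors to warnings")
args = parser.parse_args()

def warn_or_raise(msg: str) -> None:
    if args.no_warning_error:
        logger.warning(msg)
    else:
        raise ValueError(msg)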

@compilade
Collaborator Author

compilade commented Jul 20, 2024

Have you tried running the converted model? It outputs garbage for me.

@Galunid I hadn't tried it yet (the download didn't finish), but now I've tried https://huggingface.co/google/gemma-1.1-2b-it and I can confirm it's outputting garbage. I'll try to figure out what's wrong. Thanks for warning me about this.

Would it make sense to error on certain warnings by default but not if --no-warning-error flag is set?

@mofosyne In this case, what was prevented before was renaming tokens that were not UNUSED. Some finetunes rename control tokens, which I think should probably be allowed, but not silently (hence the logger.warning).

Not sure you can easily fix models that was previously converted incorrectly.

In this case, these models don't exist, because conversion previously failed. The warnings for token renaming replace assertions. Unless you're talking about the changes in #8228, which make the Gemma v1 tokenizer correctly pre-tokenize HTML tags.


I'm turning this PR into a draft until I figure out whether the garbage output is caused by something related to the tokenizer, lazy conversion, or something else.

compilade marked this pull request as draft July 20, 2024 02:04
@compilade
Collaborator Author

Wait, this seems wrong:

llm_load_print_meta: BOS token        = 1 '<eos>'
llm_load_print_meta: EOS token        = 2 '<bos>'

@Galunid
Collaborator

Galunid commented Jul 20, 2024

Also, gemma-7b doesn't have tokenizer.chat_template, which is why I thought it worked in #7964 (comment). I tried gemma-1.1-7b-it, that one didn't output anything at all, just used 100% of GPU for a while (which could (?) theoretically be explained by missing bos token).

Wait, this seems wrong:

llm_load_print_meta: BOS token        = 1 '<eos>'
llm_load_print_meta: EOS token        = 2 '<bos>'

The difference between master and this PR seems to be that tokenizer.ggml.bos_token_id is not set at all (and the same for eos).

$ python gguf-py/examples/reader.py models-local/gemma-7b/ggml-model-Q4_K_M.gguf | grep -i bos

tokenizer.ggml.bos_token_id            : [2]
tokenizer.ggml.add_bos_token           : [ True]

That's the model from master

$ python gguf-py/examples/reader.py models-local/gemma-7b/Gemma-7B-F16.gguf | grep -i bos

outputs nothing (that's for the model from this PR).
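For reference, the same check can be done directly with the GGUFReader class that gguf-py/examples/reader.py is built on (a convenience sketch; the path is the one from above):

from gguf.gguf_reader import GGUFReader

reader = GGUFReader("models-local/gemma-7b/Gemma-7B-F16.gguf")
# reader.fields maps GGUF key names to their parsed fields.
for key in ("tokenizer.ggml.bos_token_id", "tokenizer.ggml.eos_token_id"):
    print(key, "present" if key in reader.fields else "MISSING")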

Here's a diff between the two (diff broken-model-fields working-model-fields)

4c4
< GGUF.kv_count                          : [27]
---
> GGUF.kv_count                          : [33]
20c20
< general.file_type                      : [1]
---
> general.file_type                      : [15]
25a26,31
> tokenizer.ggml.bos_token_id            : [2]
> tokenizer.ggml.eos_token_id            : [1]
> tokenizer.ggml.unknown_token_id        : [3]
> tokenizer.ggml.padding_token_id        : [0]
> tokenizer.ggml.add_bos_token           : [ True]
> tokenizer.ggml.add_eos_token           : [False]

@compilade
Collaborator Author

compilade commented Jul 20, 2024

that one didn't output anything at all, just used 100% of GPU for a while (which could (?) theoretically be explained by missing bos token).

This happened for me as well; from the logs, it was because the model was continuously outputting <eos>, which wasn't displayed since it's a control token, and wasn't stopping because <bos> was set as the EOS token.

Difference between master and PR seems to be tokenizer.ggml.bos_token_id not being set at all (and the same for eos).

Yes, this is because SpecialVocab doesn't extend the special_token_types list when called; it uses either the default types or the provided types. I did not notice this at first.

Running it twice could work, but special_vocab.chat_template would have to be set to None the second time before special_vocab.add_to_gguf(self.gguf_writer).

Alternatively, making the special_token_types be extended instead of replaced in SpecialVocab could be another way.
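Roughly, the two-pass approach looks like this (a sketch of the idea, not the exact change; SpecialVocab and GGUFWriter come from gguf-py, and the explicit FIM/EOT token id wiring is omitted):

from pathlib import Path
import gguf

def write_gemma_special_vocab(dir_model: Path, writer: gguf.GGUFWriter) -> None:
    # First pass: default special tokens (bos/eos/unk/pad) plus the chat template.
    special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
    special_vocab.add_to_gguf(writer)

    # Second pass: the FIM/EOT token types, with the chat template dropped so
    # tokenizer.chat_template is not written a second time.
    special_vocab = gguf.SpecialVocab(dir_model, load_merges=False,
                                      special_token_types=["prefix", "suffix", "middle", "eot"])
    special_vocab.chat_template = None
    special_vocab.add_to_gguf(writer)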

compilade marked this pull request as ready for review July 20, 2024 02:47
@compilade
Collaborator Author

@Galunid I think the problem has been fixed in 50d1a03. It still runs SpecialVocab twice (same as on master), but the chat template is ignored the second time, so that there are no duplicate keys.

$ ./build/bin/llama-cli --log-disable -m models/gemma-1.1-2B-it-Q8_0.gguf -p "To make pizza," -n 60 --temp 0
To make pizza, you need to:

A. Preheat the oven to 500 degrees Fahrenheit.
B. Spread pizza dough on a baking sheet.
C. Bake the pizza for 10 minutes.
D. Add cheese and toppings to the pizza and bake for an additional 5 minutes.

Collaborator

Galunid left a comment

That did the trick

compilade merged commit c69c630 into master Jul 21, 2024
12 checks passed
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Jul 27, 2024
* convert_hf : fix Gemma v1 conversion

* convert_hf : allow renaming tokens, but with a warning

* convert_hf : fix Gemma v1 not setting BOS and EOS tokens