
1.5 bit quantization #5453

Merged · 15 commits · Feb 18, 2024

Conversation

@ikawrakow (Contributor) commented Feb 11, 2024

This draft PR is a WIP that demonstrates 1.5 bits-per-weight (bpw) quantization, added as a new quantization type IQ1_S. CUDA, AVX2 and ARM_NEON are implemented; Metal is still missing. But given the keen interest in 1-bit quantization, and the recent hype out there (e.g., the BiLLM and PB-LLM papers), I decided to show what I have so far to see if I should proceed.

Don't expect literary masterpieces or anything of the sort, but the output is not complete gibberish either.

The table shows a PPL comparison between this work and the two papers linked above. The PPL values for BiLLM and PB-LLM were taken from the BiLLM paper, so only LLaMA-v1 and LLaMA-v2 are covered. (Note: I have edited the PR description to put the final version here.)

| Model | PPL PB-LLM | PPL BiLLM | PPL this PR |
|---|---|---|---|
| LLaMA-v1-7B | 102.36 | 35.04 | 17.93 |
| LLaMA-v1-13B | 36.60 | 15.14 | 9.619 |
| LLaMA-v1-30B | 33.67 | 10.52 | 7.564 |
| LLaMA-v2-7B | 69.20 | 32.48 | 15.33 |
| LLaMA-v2-13B | 151.09 | 16.77 | 9.167 |
| LLaMA-v2-70B | 28.37 | 8.41 | 5.675 |
| Mistral-7B | - | - | 10.80 |
| Mixtral8x7B | - | - | 7.261 |

This is the PPL comparison as initially posted:

| Model | PPL PB-LLM | PPL BiLLM | PPL this PR |
|---|---|---|---|
| LLaMA-v1-7B | 102.36 | 35.04 | 27.53 |
| LLaMA-v1-13B | 36.60 | 15.14 | 12.44 |
| LLaMA-v1-30B | 33.67 | 10.52 | 8.605 |
| LLaMA-v1-65B | 12.53 | 8.49 | 6.456 |
| LLaMA-v2-7B | 69.20 | 32.48 | 25.11 |
| LLaMA-v2-13B | 151.09 | 16.77 | 11.40 |
| LLaMA-v2-70B | 28.37 | 8.41 | 6.065 |

Here are the responses to the sample prompts from the BiLLM paper, using LLaMA-v1-13B:

./bin/main -m q1s.gguf -s 1234 --ignore-eos -n 64  -t 1 -ngl 100 -p "Alaska is a state of"

 Alaska is a state of the USA that lies in the northwest area of the United States. It was discovered by the British explorer William de Hale who had been expelled from the Queen’s College and had no hope to be seen in 1768. The main reason why he was expelled was because he gave an

---

./bin/main -m q1s.gguf -s 1234 --ignore-eos -n 64  -t 1 -ngl 100 -p "Yangtze and Yellow River locate in"

Yangtze and Yellow River locate in the western area of Shanxi Province, which is known as "the first one" among the seven. In the past, the west bank of Yarrow River is called "Yi", while its east river side was named "Zhou".

---

./bin/main -m q1s.gguf -s 1234 --ignore-eos -n 64  -t 1 -ngl 100 -p "After climbing thirty flights of stairs, he was drenched in sweat and didn't feel much like climbing. I advised him"

After climbing thirty flights of stairs, he was drenched in sweat and didn't feel much like climbing. I advised him to walk upstairs, but his legs were not yet ready for that.
 This is why she wants to go back to her house, where it is not yet 1870's and all over a little bit of 20th century. She wants to be in her home. She

PB-LLM is 1.7 bpw, so this PR is massively better. BiLLM claims ~1.1 bpw (but we don't know the final balance after block scales and bits for non-repeating layers have been added), so it is not surprising to see a better result from this PR at 1.5 bpw.

CUDA performance is impressive: 212 t/s for a 7B model, 130 t/s for 13B, and 33.5 t/s for 70B running on an RTX-4080. Oh, LLaMA-v2-70B finally fits on my 16 GB GPU!

The BiLLM approach separates salient and non-salient weights. They use 2 bpw for salient and 1 bpw for non-salient weights (so, if about 10% of the model weights are declared salient, one needs 1.1 bpw). The thing about separating salient and non-salient weights is that storing the separation itself already costs 1 bpw, unless one has a better idea, and this is the key insight of the BiLLM paper: they basically make a per-tensor column separation. This could easily be done here too (one takes the imatrix, which is already per column, multiplies it by the sum of the squared model weights in the column, and uses this as a measure to pick the top-k percent of columns). Unfortunately, ggml lacks the infrastructure to do that sort of thing. Hence, this PR uses the same quantization for all weights.

Unlike the quoted papers, which have binary quants (-1, 1), I use 3 allowed values (-1, 0, 1), and squeeze these to 1.125 bpw by selecting 512 8D points out of the 3^8 = 6561 possibilities. This is similar to the IQ2_XS quants, but here it is no longer an E8 lattice, as I do not impose the condition that the sum of the coordinates be even. With an additional 3 bits for an unsigned scale per group of 8, we end up with 1.5 bpw (see the bit-budget sketch after the list below).

If we wanted to further squeeze the model, the salient/non-salient separation would be essential. For this, I would need support from @ggerganov and @slaren to have

  • Per-tensor metadata, where the column indices of the salient and non-salient columns would be recorded. One could also add row-wise scales and such.
  • A new ggml op that takes a tensor holding activations and reorders the columns as per the salient/non-salient separation.
  • Fixing all places in ggml where the assumption is made that tensor rows consist of a given number of block structs of a fixed size.
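
For readers following the bit accounting, here is a minimal sketch of the per-group budget described above. The arithmetic (a 9-bit index into 512 codebook entries per group of 8 ternary values, plus a 3-bit unsigned scale per group) is taken from the description; the program itself is only an illustration, not the actual IQ1_S layout in ggml.

```cpp
// Minimal sketch of the IQ1_S bit budget described above (illustrative only,
// not the real block layout in ggml): each group of 8 weights stores a 9-bit
// index into a 512-entry codebook of ternary {-1, 0, +1} points, plus a 3-bit
// unsigned scale.
#include <cstdio>

int main() {
    const double group_size    = 8.0;
    const double codebook_bits = 9.0;  // log2(512); 512 points chosen out of 3^8 = 6561
    const double scale_bits    = 3.0;  // unsigned scale per group of 8

    const double bpw_index = codebook_bits / group_size;  // 1.125 bpw
    const double bpw_scale = scale_bits    / group_size;  // 0.375 bpw
    std::printf("index %.3f + scale %.3f = %.3f bpw\n",
                bpw_index, bpw_scale, bpw_index + bpw_scale);  // 1.500 bpw
    return 0;
}
```

Adding the per-super-block fp16 scale that ggml requires (16 bits per 256 weights, i.e. 0.0625 bpw, discussed further down in this thread) appears to account for the "IQ1_S - 1.5625 bpw" figure printed by the model loader.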

@ikawrakow added the "demo" label (Demonstrate some concept or idea, not intended to be merged) on Feb 11, 2024
@Nexesenex (Contributor)

I converted Miqu 70b and Kyllene 34b, and tested them on my KoboldCPP and ST.
I'm impressed by the coherence despite the hallucinations, which are even stronger on 34b than on 70b, as expected.
Roughly speaking, I think we need about 0.2 bpw more to have something remotely usable, plus whatever optimizations you might find along the way.

@ikawrakow (Contributor, Author)

Pushed a small improvement. The perplexities are now as follows:

| Model | PPL PB-LLM | PPL BiLLM | PPL this PR |
|---|---|---|---|
| LLaMA-v1-7B | 102.36 | 35.04 | 27.81 |
| LLaMA-v1-13B | 36.60 | 15.14 | 12.16 |
| LLaMA-v1-30B | 33.67 | 10.52 | 8.38 |
| LLaMA-v1-65B | 12.53 | 8.49 | 6.412 |
| LLaMA-v2-7B | 69.20 | 32.48 | 23.88 |
| LLaMA-v2-13B | 151.09 | 16.77 | 11.09 |
| LLaMA-v2-70B | 28.37 | 8.41 | 5.975 |
| Mistral-7B | - | - | 13.49 |
| Mixtral8x7B | - | - | 8.136 |

If you want to play with this quantization, it is worth experimenting with the f_norm_rms_eps parameter. This is normally defined by the model, but here we are using such an extreme quantization that this epsilon may no longer make sense for the RMS norm operation. You can override the model-defined value using

--override-kv llama.attention.layer_norm_rms_epsilon=float:value
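
As a bit of context on why this knob matters at such low bit rates: the epsilon sits inside the RMS normalization, which in its usual textbook form (not quoted from the llama.cpp source) is

$$
y_i = \frac{x_i}{\sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^2 + \epsilon}} \; g_i
$$

With a quantization this extreme the activation statistics shift, so a different epsilon can give slightly lower perplexity, which is what the table below explores.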

For the LLaMA models the best results (as in lowest perplexities) are obtained using

| Model | f_norm_rms_eps |
|---|---|
| LLaMA-v1-7B | 5e-5 |
| LLaMA-v1-13B | 4e-5 |
| LLaMA-v1-30B | 1.5e-5 |
| LLaMA-v1-65B | 1.5e-5 |
| LLaMA-v2-7B | 3e-5 |
| LLaMA-v2-13B | 2.5e-5 |
| LLaMA-v2-70B | 2.5e-5 |

One does not gain anything by modifying f_norm_rms_eps for Mistral/Mixtral. But for Mixtral, PPL goes down to 7.589 by using more than 2 "experts":

./perplexity -m iq1s.gguf -f tests/wiki.test.raw -t 1 -ngl 100 --override-kv llama.expert_used_count=int:3

@Nexesenex (Contributor) commented Feb 12, 2024

Superb job. I will test your improvements on 70b models as they come, starting with this update.

I remember your tests about epsilon values a while ago, with 5e-06 being better than 1e-05 in some cases for Llama 2 models, if I recall correctly.
Because I'm a bit limited on storage space, can I change the epsilon value starting from a Q8_0 (I use them as a base to requantize into lower quants), or must I always restart from FP16 weights and change it in config.json?

@ikawrakow (Contributor, Author)

> Because I'm a bit storage space limited, can I change the epsilon value from a Q8_0 (I use them as base), or must I restart always from FP16 weights and change it in the config.json?

You simply change it on the command line when you use the model, not when you quantize. You quantize once without worrying about f_norm_rms_eps. Then, when you run main or perplexity or any other example, you just add

--override-kv llama.attention.layer_norm_rms_epsilon=float:value

to the command arguments, and it will use the value you specified.

@Nexesenex (Contributor) commented Feb 12, 2024

On Miqu 70B IQ1_S (quant made last night with commit 2ee4281), I get the following error when I test perplexity with your b2128 merge including the IQ1_S improvement (9803f7a):

U:\Lla\LLAMA_CUDA_121>perplexity -m Z:\text-generation-webui\models\miqu-1-70b-Requant-b2116-iMat-c32_ch400-IQ1_S.gguf -f hellaswag_val_full.txt --hellaswag --hellaswag-tasks 1000 --parallel 2 -ngl 100 -b 128 -mg 0 -ts 1,0
main: build = 2138 (9803f7a)
main: built with Clang 12.0.0 for
main: seed  = 1707750205
ggml_init_cublas: GGML_CUDA_FORCE_MMQ:   no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 2 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
  Device 1: NVIDIA GeForce RTX 3060, compute capability 8.6, VMM: yes
llama_model_loader: loaded meta data with 23 key-value pairs and 723 tensors from Z:\text-generation-webui\models\miqu-1-70b-Requant-b2116-iMat-c32_ch400-IQ1_S.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = D:\HF
llama_model_loader: - kv   2:                       llama.context_length u32              = 32764
llama_model_loader: - kv   3:                     llama.embedding_length u32              = 8192
llama_model_loader: - kv   4:                          llama.block_count u32              = 80
llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 28672
llama_model_loader: - kv   6:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv   7:                 llama.attention.head_count u32              = 64
llama_model_loader: - kv   8:              llama.attention.head_count_kv u32              = 8
llama_model_loader: - kv   9:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  10:                       llama.rope.freq_base f32              = 1000000.000000
llama_model_loader: - kv  11:                          general.file_type u32              = 24
llama_model_loader: - kv  12:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  13:                      tokenizer.ggml.tokens arr[str,32000]   = ["<unk>", "<s>", "</s>", "<0x00>", "<...
llama_model_loader: - kv  14:                      tokenizer.ggml.scores arr[f32,32000]   = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  15:                  tokenizer.ggml.token_type arr[i32,32000]   = [2, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...
llama_model_loader: - kv  16:                tokenizer.ggml.bos_token_id u32              = 1
llama_model_loader: - kv  17:                tokenizer.ggml.eos_token_id u32              = 2
llama_model_loader: - kv  18:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  19:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  20:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  21:                    tokenizer.chat_template str              = {{ bos_token }}{% for message in mess...
llama_model_loader: - kv  22:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:  161 tensors
llama_model_loader: - type q2_K:   11 tensors
llama_model_loader: - type q4_K:   80 tensors
llama_model_loader: - type q5_K:    1 tensors
llama_model_loader: - type iq1_s:  470 tensors
llm_load_vocab: special tokens definition check successful ( 259/32000 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = SPM
llm_load_print_meta: n_vocab          = 32000
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: n_ctx_train      = 32764
llm_load_print_meta: n_embd           = 8192
llm_load_print_meta: n_head           = 64
llm_load_print_meta: n_head_kv        = 8
llm_load_print_meta: n_layer          = 80
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 8
llm_load_print_meta: n_embd_k_gqa     = 1024
llm_load_print_meta: n_embd_v_gqa     = 1024
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: n_ff             = 28672
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 1000000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 32764
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: model type       = 70B
llm_load_print_meta: model ftype      = IQ1_S - 1.5625 bpw
llm_load_print_meta: model params     = 68.98 B
llm_load_print_meta: model size       = 13.22 GiB (1.65 BPW)
llm_load_print_meta: general.name     = D:\HF
llm_load_print_meta: BOS token        = 1 '<s>'
llm_load_print_meta: EOS token        = 2 '</s>'
llm_load_print_meta: UNK token        = 0 '<unk>'
llm_load_print_meta: LF token         = 13 '<0x0A>'
llm_load_tensors: ggml ctx size =    0.55 MiB
llm_load_tensors: offloading 80 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 81/81 layers to GPU
llm_load_tensors:        CPU buffer size =    82.03 MiB
llm_load_tensors:      CUDA0 buffer size = 13459.41 MiB
....................................................................................................
llama_new_context_with_model: n_ctx      = 512
llama_new_context_with_model: freq_base  = 1000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:      CUDA0 KV buffer size =   160.00 MiB
llama_new_context_with_model: KV self size  =  160.00 MiB, K (f16):   80.00 MiB, V (f16):   80.00 MiB
llama_new_context_with_model:  CUDA_Host input buffer size   =     4.25 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =    39.88 MiB
llama_new_context_with_model:      CUDA1 compute buffer size =     0.00 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =     4.40 MiB
llama_new_context_with_model: graph splits (measure): 3

system_info: n_threads = 4 / 8 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 |
hellaswag_score : loaded 10042 tasks from prompt.
================================= is_spm = 1
hellaswag_score : selecting 1000 randomized tasks.
hellaswag_score : calculating hellaswag score over selected tasks.

task    acc_norm
ggml_tallocr_alloc: not enough space in the buffer to allocate ffn_up-0 (needed 14680064, largest block available 14483507)
GGML_ASSERT: ..\..\..\ggml-alloc.c:114: !"not enough space in the buffer"

With 34b Yi models, no problem.

@ikawrakow (Contributor, Author)

Nothing has changed with the last commit that would make a difference in the ability to run the model. It is the exact same size. I think the issue is that you are using -b 128. Try without -b 128.

@slaren (Collaborator) commented Feb 12, 2024

The PR #5452 merged a few hours ago should have fixed the out of space errors with some batch sizes. You shouldn't get this error anymore if you merge master into this PR.

@Nexesenex (Contributor)

@ikawrakow: indeed, I tested with -b 512 and was about to report that it works.

@slaren : I'll test this fix, thanks!

@Nexesenex (Contributor) commented Feb 12, 2024

With the PR #5452 merge on the IQ1_S branch, I get this 👍

U:\Lla\LLAMA_CUDA_121>perplexity -m Z:\text-generation-webui\models\miqu-1-70b-Requant-b2116-iMat-c32_ch400-IQ1_S.gguf -f hellaswag_val_full.txt --hellaswag --hellaswag-tasks 1000 --parallel 2 -ngl 100 -b 128 -mg 0 -ts 1,0
main: build = 2139 (f1b6081)
main: built with Clang 12.0.0 for
main: seed  = 1707754581
ggml_init_cublas: GGML_CUDA_FORCE_MMQ:   no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 2 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
  Device 1: NVIDIA GeForce RTX 3060, compute capability 8.6, VMM: yes
llama_model_loader: loaded meta data with 23 key-value pairs and 723 tensors from Z:\text-generation-webui\models\miqu-1-70b-Requant-b2116-iMat-c32_ch400-IQ1_S.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = D:\HF
llama_model_loader: - kv   2:                       llama.context_length u32              = 32764
llama_model_loader: - kv   3:                     llama.embedding_length u32              = 8192
llama_model_loader: - kv   4:                          llama.block_count u32              = 80
llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 28672
llama_model_loader: - kv   6:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv   7:                 llama.attention.head_count u32              = 64
llama_model_loader: - kv   8:              llama.attention.head_count_kv u32              = 8
llama_model_loader: - kv   9:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  10:                       llama.rope.freq_base f32              = 1000000.000000
llama_model_loader: - kv  11:                          general.file_type u32              = 24
llama_model_loader: - kv  12:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  13:                      tokenizer.ggml.tokens arr[str,32000]   = ["<unk>", "<s>", "</s>", "<0x00>", "<...
llama_model_loader: - kv  14:                      tokenizer.ggml.scores arr[f32,32000]   = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  15:                  tokenizer.ggml.token_type arr[i32,32000]   = [2, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...
llama_model_loader: - kv  16:                tokenizer.ggml.bos_token_id u32              = 1
llama_model_loader: - kv  17:                tokenizer.ggml.eos_token_id u32              = 2
llama_model_loader: - kv  18:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  19:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  20:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  21:                    tokenizer.chat_template str              = {{ bos_token }}{% for message in mess...
llama_model_loader: - kv  22:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:  161 tensors
llama_model_loader: - type q2_K:   11 tensors
llama_model_loader: - type q4_K:   80 tensors
llama_model_loader: - type q5_K:    1 tensors
llama_model_loader: - type iq1_s:  470 tensors
llm_load_vocab: special tokens definition check successful ( 259/32000 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = SPM
llm_load_print_meta: n_vocab          = 32000
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: n_ctx_train      = 32764
llm_load_print_meta: n_embd           = 8192
llm_load_print_meta: n_head           = 64
llm_load_print_meta: n_head_kv        = 8
llm_load_print_meta: n_layer          = 80
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 8
llm_load_print_meta: n_embd_k_gqa     = 1024
llm_load_print_meta: n_embd_v_gqa     = 1024
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: n_ff             = 28672
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 1000000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 32764
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: model type       = 70B
llm_load_print_meta: model ftype      = IQ1_S - 1.5625 bpw
llm_load_print_meta: model params     = 68.98 B
llm_load_print_meta: model size       = 13.22 GiB (1.65 BPW)
llm_load_print_meta: general.name     = D:\HF
llm_load_print_meta: BOS token        = 1 '<s>'
llm_load_print_meta: EOS token        = 2 '</s>'
llm_load_print_meta: UNK token        = 0 '<unk>'
llm_load_print_meta: LF token         = 13 '<0x0A>'
llm_load_tensors: ggml ctx size =    0.55 MiB
llm_load_tensors: offloading 80 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 81/81 layers to GPU
llm_load_tensors:        CPU buffer size =    82.03 MiB
llm_load_tensors:      CUDA0 buffer size = 13459.41 MiB
....................................................................................................
llama_new_context_with_model: n_ctx      = 512
llama_new_context_with_model: freq_base  = 1000000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:      CUDA0 KV buffer size =   160.00 MiB
llama_new_context_with_model: KV self size  =  160.00 MiB, K (f16):   80.00 MiB, V (f16):   80.00 MiB
llama_new_context_with_model:  CUDA_Host input buffer size   =     4.25 MiB
ggml_gallocr_reserve_n: reallocating CUDA0 buffer from size 0.00 MiB to 36.25 MiB
ggml_gallocr_reserve_n: reallocating CUDA_Host buffer from size 0.00 MiB to 4.00 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =    36.25 MiB
llama_new_context_with_model:      CUDA1 compute buffer size =     0.00 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =     4.00 MiB
llama_new_context_with_model: graph splits (measure): 3
ggml_gallocr_needs_realloc: graph has different number of nodes
ggml_gallocr_alloc_graph: cannot reallocate multi buffer graph automatically, call reserve
ggml_backend_sched: failed to allocate graph, reserving

system_info: n_threads = 4 / 8 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 |
hellaswag_score : loaded 10042 tasks from prompt.
================================= is_spm = 1
hellaswag_score : selecting 1000 randomized tasks.
hellaswag_score : calculating hellaswag score over selected tasks.

task    acc_norm
ggml_gallocr_needs_realloc: node inp_embd is not valid
ggml_gallocr_alloc_graph: cannot reallocate multi buffer graph automatically, call reserve
ggml_backend_sched: failed to allocate graph, reserving
ggml_gallocr_reserve_n: reallocating CUDA0 buffer from size 36.25 MiB to 40.06 MiB
ggml_gallocr_needs_realloc: node CUDA0#KQ_mask is not valid
ggml_gallocr_alloc_graph: cannot reallocate multi buffer graph automatically, call reserve
ggml_backend_sched: failed to allocate graph, reserving
ggml_gallocr_reserve_n: reallocating CUDA0 buffer from size 40.06 MiB to 40.13 MiB
ggml_gallocr_needs_realloc: node CUDA0#KQ_mask is not valid
ggml_gallocr_alloc_graph: cannot reallocate multi buffer graph automatically, call reserve
ggml_backend_sched: failed to allocate graph, reserving
ggml_gallocr_needs_realloc: node CUDA0#KQ_mask is not valid
ggml_gallocr_alloc_graph: cannot reallocate multi buffer graph automatically, call reserve
ggml_backend_sched: failed to allocate graph, reserving
1       0.00000000
2       0.00000000
ggml_gallocr_needs_realloc: node inp_embd is not valid
ggml_gallocr_alloc_graph: cannot reallocate multi buffer graph automatically, call reserve
ggml_backend_sched: failed to allocate graph, reserving
ggml_gallocr_needs_realloc: node CUDA0#KQ_mask is not valid
ggml_gallocr_alloc_graph: cannot reallocate multi buffer graph automatically, call reserve
ggml_backend_sched: failed to allocate graph, reserving
ggml_gallocr_needs_realloc: node CUDA0#KQ_mask is not valid
ggml_gallocr_alloc_graph: cannot reallocate multi buffer graph automatically, call reserve
ggml_backend_sched: failed to allocate graph, reserving
3       0.00000000
4       0.00000000
5       20.00000000
6       16.66666667
7       28.57142857
8       37.50000000
9       44.44444444
10      40.00000000
ggml_gallocr_needs_realloc: node CUDA0#KQ_mask is not valid
ggml_gallocr_alloc_graph: cannot reallocate multi buffer graph automatically, call reserve
ggml_backend_sched: failed to allocate graph, reserving
11      36.36363636
12      41.66666667

There's still a problem, but it doesn't crash anymore.

@slaren (Collaborator) commented Feb 12, 2024

These messages are only printed in debug builds and do not necessarily indicate a problem; they are just a trace of what is happening in the allocator.

@Nexesenex (Contributor)

Ok!
Then it works, and the frequency of the allocator messages decreases when I increase the batch size.
Thank you!

@cebtenzzre (Collaborator)

@Nexesenex For readability, please paste output into code blocks:

```
like this
```

@ikawrakow (Contributor, Author) commented Feb 12, 2024

OK, another update (apart from merging latest master to pick up #5452): using IQ2_XXS for the attn_output.weight tensors. This adds 0.04 bpw to the quantized model size, but massively improves PPL (relative to size increase):

| Model | PPL PB-LLM | PPL BiLLM | PPL this PR |
|---|---|---|---|
| LLaMA-v1-7B | 102.36 | 35.04 | 17.93 |
| LLaMA-v1-13B | 36.60 | 15.14 | 9.619 |
| LLaMA-v1-30B | 33.67 | 10.52 | 7.564 |
| LLaMA-v2-7B | 69.20 | 32.48 | 15.33 |
| LLaMA-v2-13B | 151.09 | 16.77 | 9.167 |
| LLaMA-v2-70B | 28.37 | 8.41 | 5.675 |
| Mistral-7B | - | - | 10.80 |
| Mixtral8x7B | - | - | 7.261 |
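
As a rough sanity check of the quoted 0.04 bpw increase (assuming the usual LLaMA-7B shape of n_embd = 4096 and 32 layers, which is not stated in this thread, and taking IQ1_S at 1.5625 bpw and IQ2_XXS at 2.0625 bpw):

$$
\frac{32 \times 4096^2}{6.74\times 10^{9}} \approx 0.08, \qquad 0.08 \times (2.0625 - 1.5625) \approx 0.04 \ \text{bpw}
$$

i.e. attn_output.weight is roughly 8% of the parameters, and moving that 8% up by half a bit per weight costs about 0.04 bpw overall.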

Mixtral8x7B with 3 experts has PPL = 6.902. Model size is 8.9 GiB, so at least theoretically it should fit on a 12 GB GPU (I don't have one to test).

@Nexesenex (Contributor) commented Feb 12, 2024

@ikawrakow 👍

The second version of IQ1_S already showed massive progress. I will test the third tonight.

Noted about the code blocks, but I obviously failed to do it properly. :X

@InvincibleDude

Is all of this applicable to higher bpw quants?

@Artefact2 (Collaborator)

> Model size is 8.9 GiB, so at least theoretically it should fit on a 12 GB GPU (I don't have one to test).

It does. Here are some results on my 6750XT.

| model | size | params | backend | ngl | test | t/s |
|---|---|---|---|---|---|---|
| llama 7B IQ1_S - 1.5625 bpw | 8.88 GiB | 46.70 B | ROCm | 99 | pp 512 | 267.10 ± 0.69 |
| llama 7B IQ1_S - 1.5625 bpw | 8.88 GiB | 46.70 B | ROCm | 99 | pp 2048 | 257.14 ± 0.30 |
| llama 7B IQ1_S - 1.5625 bpw | 8.88 GiB | 46.70 B | ROCm | 99 | pp 8192 | 220.46 ± 0.07 |
| llama 7B IQ1_S - 1.5625 bpw | 8.88 GiB | 46.70 B | ROCm | 99 | tg 128 | 40.89 ± 0.04 |

For reference, the PPL for this model (Fish-8x7B-IQ1_S) over wiki.test.raw, at 512 context is 8.1976 +/- 0.05212.

@Nexesenex (Contributor) commented Feb 12, 2024

Here are some benchmarks, @ikawrakow:

- TeeZee_Kyllene-34B-v1.1-b2116-iMat-c32_ch3250-IQ1_S.gguf,-,Hellaswag,31,,400,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,
- TeeZee_Kyllene-34B-v1.1-b2116-iMat-c32_ch3250-IQ1_S.gguf,-,Hellaswag,26.8,,1000,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,
- TeeZee_Kyllene-34B-v1.1-b2116-iMat-c32_ch3250-IQ1_S.gguf,-,Arc-Challenge,20.06688963,,299,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,
- TeeZee_Kyllene-34B-v1.1-b2116-iMat-c32_ch3250-IQ1_S.gguf,-,Arc-Easy,24.73684211,,570,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,
- TeeZee_Kyllene-34B-v1.1-b2116-iMat-c32_ch3250-IQ1_S.gguf,-,MMLU,27.15654952,,313,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,
- TeeZee_Kyllene-34B-v1.1-b2116-iMat-c32_ch3250-IQ1_S.gguf,-,Thruthful-QA,30.23255814,,817,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,
- TeeZee_Kyllene-34B-v1.1-b2116-iMat-c32_ch3250-IQ1_S.gguf,-,Winogrande,47.9084,,1267,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,
- TeeZee_Kyllene-34B-v1.1-b2116-iMat-c32_ch3250-IQ1_S.gguf,-,wikitext,724599.9720,512,512,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,327

- TeeZee_Kyllene-34B-v1.1-b2128-iMat-c32_ch3250-IQ1_S_v2.gguf,-,Hellaswag,62.75,,400,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,
- TeeZee_Kyllene-34B-v1.1-b2128-iMat-c32_ch3250-IQ1_S_v2.gguf,-,Hellaswag,62.9,,1000,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,
- TeeZee_Kyllene-34B-v1.1-b2128-iMat-c32_ch3250-IQ1_S_v2.gguf,-,Arc-Challenge,36.78929766,,299,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,
- TeeZee_Kyllene-34B-v1.1-b2128-iMat-c32_ch3250-IQ1_S_v2.gguf,-,Arc-Easy,56.49122807,,570,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,
- TeeZee_Kyllene-34B-v1.1-b2128-iMat-c32_ch3250-IQ1_S_v2.gguf,-,MMLU,30.67092652,,313,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,
- TeeZee_Kyllene-34B-v1.1-b2128-iMat-c32_ch3250-IQ1_S_v2.gguf,-,Thruthful-QA,27.90697674,,817,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,
- TeeZee_Kyllene-34B-v1.1-b2128-iMat-c32_ch3250-IQ1_S_v2.gguf,-,Winogrande,60.6946,,1267,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,
- TeeZee_Kyllene-34B-v1.1-b2128-iMat-c32_ch3250-IQ1_S_v2.gguf,-,wikitext,12.8712,512,512,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,
- TeeZee_Kyllene-34B-v1.1-b2128-iMat-c32_ch3250-IQ1_S_v2.gguf,-,wikitext,10.0199,4096,4096,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,
- TeeZee_Kyllene-34B-v1.1-b2128-iMat-c32_ch3250-IQ1_S_v2.gguf,-,wikitext,10.0193,8192,8192,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,

- TeeZee_Kyllene-34B-v1.1-b2131-iMat-c32_ch3250-IQ1_S_v3.gguf,-,Hellaswag,63,,400,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,
- TeeZee_Kyllene-34B-v1.1-b2131-iMat-c32_ch3250-IQ1_S_v3.gguf,-,Hellaswag,64,,1000,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,
- TeeZee_Kyllene-34B-v1.1-b2131-iMat-c32_ch3250-IQ1_S_v3.gguf,-,Arc-Challenge,34.44816054,,299,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,
- TeeZee_Kyllene-34B-v1.1-b2131-iMat-c32_ch3250-IQ1_S_v3.gguf,-,Arc-Easy,54.03508772,,570,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,
- TeeZee_Kyllene-34B-v1.1-b2131-iMat-c32_ch3250-IQ1_S_v3.gguf,-,MMLU,32.90734824,,313,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,
- TeeZee_Kyllene-34B-v1.1-b2131-iMat-c32_ch3250-IQ1_S_v3.gguf,-,Thruthful-QA,26.68298654,,817,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,
- TeeZee_Kyllene-34B-v1.1-b2131-iMat-c32_ch3250-IQ1_S_v3.gguf,-,Winogrande,63.6148,,1267,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,
- TeeZee_Kyllene-34B-v1.1-b2131-iMat-c32_ch3250-IQ1_S_v3.gguf,-,wikitext,11.6058,512,512,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,
- TeeZee_Kyllene-34B-v1.1-b2131-iMat-c32_ch3250-IQ1_S_v3.gguf,-,wikitext,8.9842,4096,4096,2024-02-12 00:00:00,,34b,Yi,2000000,,,GGUF,TeeZee,Nexesenex,

- miqu-1-70b-Requant-b2116-iMat-c32_ch400-IQ1_S.gguf,-,Hellaswag,24.25,400,,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2116-iMat-c32_ch400-IQ1_S.gguf,-,Hellaswag,22.5,1000,,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2116-iMat-c32_ch400-IQ1_S.gguf,-,Arc-Challenge,25.08361204,,299,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2116-iMat-c32_ch400-IQ1_S.gguf,-,Arc-Easy,24.56140351,,570,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2116-iMat-c32_ch400-IQ1_S.gguf,-,MMLU,24.92012780,,313,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2116-iMat-c32_ch400-IQ1_S.gguf,-,Thruthful-QA,19.33904529,,817,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2116-iMat-c32_ch400-IQ1_S.gguf,-,Winogrande,50.8287,,1267,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2116-iMat-c32_ch400-IQ1_S.gguf,-,wikitext,117089.7230,512,512,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,327

- miqu-1-70b-Requant-b2128-iMat-c32_ch400-IQ1_S_v2.gguf,-,Hellaswag,76,400,,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2128-iMat-c32_ch400-IQ1_S_v2.gguf,-,Hellaswag,76.3,1000,,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2128-iMat-c32_ch400-IQ1_S_v2.gguf,-,Arc-Challenge,45.15050167,,299,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2128-iMat-c32_ch400-IQ1_S_v2.gguf,-,Arc-Easy,67.54385965,,570,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2128-iMat-c32_ch400-IQ1_S_v2.gguf,-,MMLU,39.93610224,,313,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2128-iMat-c32_ch400-IQ1_S_v2.gguf,-,Thruthful-QA,29.37576499,,817,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2128-iMat-c32_ch400-IQ1_S_v2.gguf,-,Winogrande,72.6914,,1267,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2128-iMat-c32_ch400-IQ1_S_v2.gguf,-,wikitext,7.0861,512,512,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2128-iMat-c32_ch400-IQ1_S_v2.gguf,-,wikitext,5.8372,4096,4096,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2128-iMat-c32_ch400-IQ1_S_v2.gguf,-,wikitext,5.7746,8192,8192,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,

- miqu-1-70b-Requant-b2131-iMat-c32_ch400-IQ1_S_v3.gguf,-,Hellaswag,78.75,,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2131-iMat-c32_ch400-IQ1_S_v3.gguf,-,Hellaswag,78.1,1000,,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2131-iMat-c32_ch400-IQ1_S_v3.gguf,-,Arc-Challenge,45.15050167,,299,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2131-iMat-c32_ch400-IQ1_S_v3.gguf,-,Arc-Easy,70.70175439,,570,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2131-iMat-c32_ch400-IQ1_S_v3.gguf,-,MMLU,38.97763578,,313,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2131-iMat-c32_ch400-IQ1_S_v3.gguf,-,Thruthful-QA,33.29253366,,817,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2131-iMat-c32_ch400-IQ1_S_v3.gguf,-,Winogrande,72.2178,,1267,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2131-iMat-c32_ch400-IQ1_S_v3.gguf,-,wikitext,6.7606,512,512,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2131-iMat-c32_ch400-IQ1_S_v3.gguf,-,wikitext,5.5886,4096,4096,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,
- miqu-1-70b-Requant-b2131-iMat-c32_ch400-IQ1_S_v3.gguf,-,wikitext,5.5291,8192,8192,2024-02-12 00:00:00,,70b,Mistral_Medium,32768,,,GGUF,Miqudev,Nexesenex,

On Miqu 70b, the progression from one revision to the next is neat.

On a Yi 34b model, it's a bit weirder.

As for the output of Miqu 70b in IQ1_S "v3", it's definitely better than v1. The formatting is respected more, and the model can give a relatively sensible and detailed answer to a question, with some elaboration.

Example obtained with my latest Frankenstein KoboldCPP (IQ1_S v3):

How to build a pyramid standing on its top?

I'm sorry, but my masters have not taught me how to physically manipulate objects in the real world. However, I can help you with information or advice on how to build a pyramid standing on its tip. In this case, it would be more accurate to say that you are trying to create a pyramid-shaped structure balanced on its tip. Here are some suggestions for creating such a structure:

1. Start by designing your pyramid using a lightweight material like cardboard or foam. This will make it easier to balance the structure on its tip.
2. Create a small pedestal as the base of your pyramid. The pedestal should be slightly wider than the pyramid's base so that it can support the weight of the structure without falling over.
3. Carve a small notch at the tip of the pyramid, which will act as a fulcrum point when you balance the structure.
4. Carefully place the pyramid on the pedals, ensuring that the fulcum point is centered and properly aligned with the base.
5. Gently lift the structure and carefully move it into position, balancing it on its tip.
6. Once in place, use a piece of tape or magnet to hold the structure in place until it settles naturally.
7. If you want to add more stability to your pyramid, consider adding weights to the base of the structure.
8. Finally, take a step back and admire your creation!

But still, I think it needs a slight quality bump (or an intermediate IQ1 quant halfway to IQ2_XXS) to be really usable, especially for the models below 70b, where the results are weirdly inconsistent between v2 and v3 of the IQ1_S PR.

@ghchris2021

@ikawrakow et al. who are helping implement and test this -- bravo, this is superb to see!

I think there are many use cases for these kinds of "as high quality and speed as possible" quantizations in the 1-4 bits-per-parameter range, pragmatically just to fit the next echelon of model sizes into GPUs / CPUs with modest RAM/VRAM limits.

Particularly above ~16-24 GB of VRAM, many people run out of the practical ability (or tolerance) to "solve the problem with more money" and buy a bigger GPU ad hoc, so being able to run 30B, 70B, 120B and larger models with compellingly useful quality on hardware one can actually access over the next year or two is quite a boon.

I have no idea to what extent it would be useful for comparing these highly quantized techniques, but there is a newer model out there that could be an interesting larger-size test subject. It is relatively unlikely to receive such contemporary SOTA low-bit quantizations any time soon otherwise (until this work reaches mainstream use), yet it can be compared with the older Q2/Q4 options already present.

e.g.

https://huggingface.co/abacusai/TheProfessor-155b

https://huggingface.co/abacusai/TheProfessor-155b-gguf/tree/main

(There are Q2 and Q4 quants there, so it would perhaps make a good apples-to-apples comparison when going to lower / better-calculated quants.)

@benxh1995 commented Feb 12, 2024

I want to thank you, @ikawrakow, for the amazing job! I've rented a box with an A5000 and downloaded @Nexesenex's miqu q1_s_v2; here are my perplexity outputs running on WikiText-2 wiki.test.raw:
command: ./perplexity -m models/miqu-1s.gguf -ngl 100 -f wiki.test.raw

perplexity: tokenizing the input ..
perplexity: tokenization took 1270.28 ms
perplexity: calculating perplexity over 655 chunks, batch_size=512
perplexity: 1.26 seconds per pass - ETA 13.78 minutes
[1]5.7467,[2]6.4647,[3]6.7486,[4]7.3598,[5]7.2764,[6]7.1045,[7]7.1943,[8]7.2727,[9]7.5140,[10]7.7969,[11]8.0505,[12]8.1259,[13]8.1811,[14]8.3588,[15]8.7169,[16]8.3215,[17]8.1265,[18]8.2420,[19]7.8164,[20]7.7751,[21]7.7082,[22]7.6068,[23]7.5423,[24]7.4539,[25]7.4459,[26]7.2090,[27]6.9749,[28]6.8829,[29]6.7688,[30]6.6016,[31]6.5263,[32]6.5478,[33]6.4736,[34]6.5062,[35]6.5037,[36]6.5622,[37]6.5595,[38]6.5789,[39]6.6356,[40]6.7092,[41]6.7643,[42]6.8172,[43]6.7586,[44]6.7886,[45]6.8020,[46]6.7732,[47]6.8026,[48]6.7859,[49]6.7834,[50]6.7357,[51]6.7377,[52]6.7239,[53]6.7725,[54]6.7472,[55]6.7216,[56]6.7651,[57]6.7835,[58]6.8047,[59]6.8232,[60]6.8830,[61]6.8851,[62]6.9564,[63]6.9832,[64]6.9838,[65]7.0248,[66]7.0268,[67]7.0326,[68]7.0564,[69]7.0974,[70]7.1408,[71]7.1688,[72]7.2080,[73]7.2686,[74]7.2703,[75]7.2795,[76]7.2869,[77]7.3110,[78]7.3069,[79]7.3344,[80]7.3312,[81]7.3590,[82]7.3690,[83]7.3100,[84]7.3032,[85]7.3032,[86]7.2742,[87]7.2305,[88]7.1987,[89]7.1743,[90]7.1626,[91]7.1888,[92]7.1823,[93]7.1748,[94]7.1690,[95]7.1963,[96]7.1854,[97]7.1773,[98]7.1687,[99]7.1466,[100]7.1431,[101]7.1640,[102]7.1487,[103]7.1586,[104]7.1579,[105]7.1550,[106]7.1640,[107]7.1671,[108]7.1712,[109]7.1598,[110]7.1512,[111]7.1710,[112]7.1967,[113]7.1931,[114]7.1869,[115]7.1934,[116]7.1857,[117]7.1906,[118]7.2183,[119]7.2428,[120]7.2835,[121]7.3054,[122]7.3213,[123]7.3633,[124]7.3837,[125]7.3671,[126]7.4047,[127]7.4437,[128]7.4664,[129]7.4469,[130]7.4522,[131]7.4483,[132]7.4409,[133]7.4350,[134]7.4572,[135]7.4516,[136]7.4453,[137]7.4442,[138]7.4326,[139]7.4312,[140]7.4301,[141]7.4149,[142]7.4218,[143]7.4137,[144]7.3947,[145]7.4002,[146]7.3873,[147]7.3966,[148]7.4020,[149]7.3962,[150]7.4010,[151]7.4086,[152]7.4004,[153]7.3830,[154]7.3795,[155]7.3864,[156]7.3879,[157]7.4010,[158]7.4023,[159]7.4073,[160]7.4111,[161]7.4258,[162]7.3906,[163]7.3711,[164]7.3434,[165]7.3123,[166]7.2773,[167]7.2297,[168]7.1946,[169]7.1736,[170]7.1587,[171]7.1268,[172]7.1061,[173]7.0879,[174]7.0601,[175]7.0335,[176]7.0225,[177]7.0037,[178]6.9837,[179]6.9656,[180]6.9569,[181]6.9364,[182]6.9159,[183]6.8985,[184]6.8954,[185]6.8899,[186]6.8905,[187]6.8967,[188]6.8983,[189]6.9193,[190]6.9250,[191]6.9444,[192]6.9589,[193]6.9804,[194]6.9937,[195]7.0226,[196]7.0370,[197]7.0625,[198]7.0807,[199]7.0836,[200]7.0809,[201]7.0748,[202]7.0949,[203]7.1015,[204]7.1122,[205]7.1213,[206]7.1261,[207]7.1213,[208]7.1313,[209]7.1342,[210]7.1361,[211]7.1472,[212]7.1523,[213]7.1595,[214]7.1663,[215]7.1727,[216]7.1879,[217]7.2055,[218]7.2225,[219]7.2219,[220]7.2153,[221]7.2040,[222]7.1989,[223]7.1870,[224]7.1760,[225]7.1712,[226]7.1935,[227]7.2150,[228]7.2238,[229]7.2326,[230]7.2352,[231]7.2516,[232]7.2444,[233]7.2226,[234]7.2089,[235]7.1989,[236]7.1906,[237]7.1808,[238]7.1874,[239]7.1697,[240]7.1590,[241]7.1653,[242]7.1698,[243]7.1642,[244]7.1547,[245]7.1545,[246]7.1423,[247]7.1334,[248]7.1260,[249]7.1230,[250]7.1275,[251]7.1208,[252]7.1153,[253]7.1070,[254]7.1020,[255]7.0913,[256]7.0767,[257]7.0666,[258]7.0624,[259]7.0640,[260]7.0575,[261]7.0574,[262]7.0551,[263]7.0500,[264]7.0371,[265]7.0365,[266]7.0306,[267]7.0212,[268]7.0262,[269]7.0237,[270]7.0230,[271]7.0288,[272]7.0331,[273]7.0315,[274]7.0326,[275]7.0415,[276]7.0470,[277]7.0629,[278]7.0747,[279]7.0846,[280]7.0920,[281]7.1029,[282]7.1075,[283]7.1230,[284]7.1327,[285]7.1415,[286]7.1559,[287]7.1566,[288]7.1646,[289]7.1568,[290]7.1423,[291]7.1269,[292]7.1088,[293]7.0941,[294]7.0953,[295]7.0933,[296]7.0987,[297]7.0988,[298]7.1008,[299]7.0957,[300]7.0853,[301]7.0844,[302]7.0772,[303]7.0681,[304]7.0566,[305]7.0552,[30
6]7.0437,[307]7.0421,[308]7.0425,[309]7.0257,[310]7.0205,[311]7.0133,[312]7.0138,[313]7.0056,[314]7.0037,[315]6.9876,[316]6.9905,[317]6.9735,[318]6.9573,[319]6.9681,[320]6.9812,[321]6.9864,[322]6.9785,[323]6.9802,[324]6.9819,[325]6.9971,[326]6.9984,[327]7.0039,[328]7.0093,[329]7.0140,[330]7.0227,[331]7.0352,[332]7.0303,[333]7.0398,[334]7.0338,[335]7.0243,[336]7.0256,[337]7.0218,[338]7.0230,[339]7.0173,[340]7.0108,[341]7.0157,[342]7.0147,[343]7.0174,[344]7.0154,[345]7.0154,[346]7.0124,[347]7.0138,[348]7.0135,[349]7.0151,[350]7.0136,[351]7.0122,[352]7.0118,[353]7.0033,[354]7.0035,[355]7.0088,[356]7.0127,[357]7.0068,[358]7.0170,[359]7.0213,[360]7.0146,[361]7.0107,[362]7.0181,[363]7.0291,[364]7.0367,[365]7.0413,[366]7.0434,[367]7.0514,[368]7.0464,[369]7.0455,[370]7.0451,[371]7.0391,[372]7.0433,[373]7.0477,[374]7.0459,[375]7.0424,[376]7.0511,[377]7.0443,[378]7.0450,[379]7.0493,[380]7.0401,[381]7.0378,[382]7.0329,[383]7.0310,[384]7.0282,[385]7.0257,[386]7.0257,[387]7.0248,[388]7.0195,[389]7.0122,[390]7.0035,[391]6.9951,[392]6.9942,[393]6.9980,[394]7.0015,[395]6.9981,[396]6.9895,[397]6.9981,[398]7.0020,[399]7.0090,[400]7.0069,[401]7.0083,[402]7.0109,[403]7.0128,[404]7.0178,[405]7.0127,[406]7.0077,[407]7.0104,[408]7.0099,[409]7.0216,[410]7.0334,[411]7.0458,[412]7.0630,[413]7.0748,[414]7.0846,[415]7.0893,[416]7.0964,[417]7.1082,[418]7.1114,[419]7.1159,[420]7.1238,[421]7.1339,[422]7.1377,[423]7.1463,[424]7.1567,[425]7.1649,[426]7.1708,[427]7.1730,[428]7.1810,[429]7.1854,[430]7.1928,[431]7.2081,[432]7.2086,[433]7.2048,[434]7.1972,[435]7.1986,[436]7.2009,[437]7.2090,[438]7.2171,[439]7.2106,[440]7.2104,[441]7.2054,[442]7.2037,[443]7.2049,[444]7.2062,[445]7.2034,[446]7.2048,[447]7.2057,[448]7.2072,[449]7.2042,[450]7.2036,[451]7.1984,[452]7.1961,[453]7.1903,[454]7.1857,[455]7.1859,[456]7.1889,[457]7.1905,[458]7.1877,[459]7.1873,[460]7.1960,[461]7.1923,[462]7.1912,[463]7.1959,[464]7.1951,[465]7.1930,[466]7.1876,[467]7.1899,[468]7.1906,[469]7.1914,[470]7.1924,[471]7.1892,[472]7.1942,[473]7.1872,[474]7.1930,[475]7.1939,[476]7.1935,[477]7.1876,[478]7.1872,[479]7.2016,[480]7.2081,[481]7.2101,[482]7.2044,[483]7.2009,[484]7.2046,[485]7.2037,[486]7.1988,[487]7.2021,[488]7.2016,[489]7.1972,[490]7.1965,[491]7.1949,[492]7.1895,[493]7.1858,[494]7.1834,[495]7.1843,[496]7.1821,[497]7.1790,[498]7.1810,[499]7.1745,[500]7.1657,[501]7.1611,[502]7.1628,[503]7.1600,[504]7.1509,[505]7.1544,[506]7.1549,[507]7.1564,[508]7.1517,[509]7.1489,[510]7.1518,[511]7.1569,[512]7.1603,[513]7.1616,[514]7.1677,[515]7.1620,[516]7.1611,[517]7.1635,[518]7.1625,[519]7.1657,[520]7.1668,[521]7.1669,[522]7.1695,[523]7.1709,[524]7.1763,[525]7.1799,[526]7.1812,[527]7.1835,[528]7.1804,[529]7.1837,[530]7.1765,[531]7.1748,[532]7.1824,[533]7.1854,[534]7.1828,[535]7.1862,[536]7.1797,[537]7.1770,[538]7.1818,[539]7.1831,[540]7.1882,[541]7.1902,[542]7.1911,[543]7.1949,[544]7.1948,[545]7.1920,[546]7.1928,[547]7.1874,[548]7.1785,[549]7.1768,[550]7.1737,[551]7.1697,[552]7.1675,[553]7.1628,[554]7.1586,[555]7.1546,[556]7.1534,[557]7.1567,[558]7.1541,[559]7.1554,[560]7.1531,[561]7.1539,[562]7.1521,[563]7.1537,[564]7.1614,[565]7.1640,[566]7.1647,[567]7.1610,[568]7.1611,[569]7.1574,[570]7.1596,[571]7.1614,[572]7.1613,[573]7.1605,[574]7.1569,[575]7.1570,[576]7.1572,[577]7.1556,[578]7.1535,[579]7.1536,[580]7.1455,[581]7.1406,[582]7.1402,[583]7.1394,[584]7.1407,[585]7.1347,[586]7.1285,[587]7.1289,[588]7.1328,[589]7.1399,[590]7.1413,[591]7.1418,[592]7.1397,[593]7.1337,[594]7.1328,[595]7.1295,[596]7.1327,[597]7.1277,[598]7.1268,[599]7.1278,[600]7.1246,[601]7.1220,[602]7
.1264,[603]7.1283,[604]7.1289,[605]7.1315,[606]7.1328,[607]7.1333,[608]7.1289,[609]7.1283,[610]7.1334,[611]7.1299,[612]7.1325,[613]7.1286,[614]7.1220,[615]7.1120,[616]7.1154,[617]7.1066,[618]7.0999,[619]7.0928,[620]7.0752,[621]7.0658,[622]7.0627,[623]7.0647,[624]7.0623,[625]7.0630,[626]7.0629,[627]7.0660,[628]7.0666,[629]7.0666,[630]7.0696,[631]7.0740,[632]7.0797,[633]7.0773,[634]7.0798,[635]7.0794,[636]7.0782,[637]7.0770,[638]7.0801,[639]7.0768,[640]7.0771,[641]7.0759,[642]7.0830,[643]7.0842,[644]7.0838,[645]7.0815,[646]7.0858,[647]7.0846,[648]7.0847,[649]7.0825,[650]7.0854,[651]7.0895,[652]7.0896,[653]7.0940,[654]7.0882,[655]7.0865,
Final estimate: PPL = 7.0865 +/- 0.04023

llama_print_timings:        load time =    3595.00 ms
llama_print_timings:      sample time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings: prompt eval time =  863697.23 ms / 335360 tokens (    2.58 ms per token,   388.28 tokens per second)
llama_print_timings:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time =  870279.91 ms / 335361 tokens

I hope this is useful. I have some saved low-temperature outputs from miqu q5_k_m; compared with those, the output of iq1_s_v2 (with default settings) is of comparable quality.

Looking forward to being able to run this on M1!

EDIT: Did v3 too:

perplexity: tokenizing the input ..
perplexity: tokenization took 1008.64 ms
perplexity: calculating perplexity over 655 chunks, batch_size=512
perplexity: 1.28 seconds per pass - ETA 13.92 minutes
[1]5.5194,[2]6.1329,[3]6.5139,[4]7.0681,[5]7.0457,[6]6.9202,[7]7.0180,[8]7.0711,[9]7.2671,[10]7.5602,[11]7.8078,[12]7.8416,[13]7.8681,[14]7.9968,[15]8.3418,[16]7.9559,[17]7.8083,[18]7.9101,[19]7.5053,[20]7.4614,[21]7.3718,[22]7.2770,[23]7.2251,[24]7.1246,[25]7.1013,[26]6.8871,[27]6.6712,[28]6.5849,[29]6.4811,[30]6.3286,[31]6.2568,[32]6.2751,[33]6.2022,[34]6.2336,[35]6.2365,[36]6.2778,[37]6.2613,[38]6.2750,[39]6.3268,[40]6.3880,[41]6.4373,[42]6.4834,[43]6.4267,[44]6.4511,[45]6.4653,[46]6.4344,[47]6.4654,[48]6.4467,[49]6.4460,[50]6.4011,[51]6.4050,[52]6.3861,[53]6.4326,[54]6.4056,[55]6.3886,[56]6.4230,[57]6.4427,[58]6.4696,[59]6.4881,[60]6.5401,[61]6.5315,[62]6.5977,[63]6.6231,[64]6.6252,[65]6.6603,[66]6.6617,[67]6.6705,[68]6.6927,[69]6.7363,[70]6.7797,[71]6.8108,[72]6.8495,[73]6.9120,[74]6.9155,[75]6.9244,[76]6.9378,[77]6.9620,[78]6.9622,[79]6.9859,[80]6.9840,[81]7.0098,[82]7.0179,[83]6.9620,[84]6.9582,[85]6.9590,[86]6.9328,[87]6.8941,[88]6.8645,[89]6.8413,[90]6.8294,[91]6.8553,[92]6.8460,[93]6.8420,[94]6.8376,[95]6.8635,[96]6.8548,[97]6.8468,[98]6.8404,[99]6.8174,[100]6.8139,[101]6.8338,[102]6.8196,[103]6.8281,[104]6.8287,[105]6.8280,[106]6.8362,[107]6.8393,[108]6.8459,[109]6.8345,[110]6.8239,[111]6.8422,[112]6.8651,[113]6.8631,[114]6.8586,[115]6.8646,[116]6.8598,[117]6.8611,[118]6.8877,[119]6.9122,[120]6.9522,[121]6.9733,[122]6.9886,[123]7.0303,[124]7.0487,[125]7.0339,[126]7.0686,[127]7.1032,[128]7.1255,[129]7.1080,[130]7.1122,[131]7.1092,[132]7.1052,[133]7.0969,[134]7.1139,[135]7.1103,[136]7.1053,[137]7.1044,[138]7.0953,[139]7.0943,[140]7.0899,[141]7.0745,[142]7.0819,[143]7.0715,[144]7.0549,[145]7.0572,[146]7.0464,[147]7.0574,[148]7.0645,[149]7.0604,[150]7.0630,[151]7.0678,[152]7.0599,[153]7.0430,[154]7.0390,[155]7.0487,[156]7.0518,[157]7.0648,[158]7.0641,[159]7.0698,[160]7.0732,[161]7.0860,[162]7.0519,[163]7.0339,[164]7.0081,[165]6.9786,[166]6.9454,[167]6.9036,[168]6.8698,[169]6.8521,[170]6.8387,[171]6.8092,[172]6.7918,[173]6.7760,[174]6.7497,[175]6.7233,[176]6.7134,[177]6.6940,[178]6.6756,[179]6.6579,[180]6.6500,[181]6.6298,[182]6.6086,[183]6.5922,[184]6.5883,[185]6.5824,[186]6.5843,[187]6.5910,[188]6.5932,[189]6.6130,[190]6.6201,[191]6.6378,[192]6.6524,[193]6.6704,[194]6.6824,[195]6.7092,[196]6.7236,[197]6.7476,[198]6.7651,[199]6.7687,[200]6.7669,[201]6.7595,[202]6.7773,[203]6.7841,[204]6.7946,[205]6.8021,[206]6.8077,[207]6.8045,[208]6.8154,[209]6.8174,[210]6.8200,[211]6.8328,[212]6.8381,[213]6.8451,[214]6.8515,[215]6.8555,[216]6.8717,[217]6.8878,[218]6.9038,[219]6.9021,[220]6.8960,[221]6.8844,[222]6.8783,[223]6.8674,[224]6.8563,[225]6.8518,[226]6.8730,[227]6.8909,[228]6.8995,[229]6.9085,[230]6.9110,[231]6.9268,[232]6.9202,[233]6.9013,[234]6.8875,[235]6.8785,[236]6.8713,[237]6.8614,[238]6.8680,[239]6.8512,[240]6.8403,[241]6.8447,[242]6.8477,[243]6.8436,[244]6.8332,[245]6.8332,[246]6.8204,[247]6.8119,[248]6.8044,[249]6.8024,[250]6.8057,[251]6.7990,[252]6.7936,[253]6.7855,[254]6.7790,[255]6.7684,[256]6.7527,[257]6.7415,[258]6.7349,[259]6.7361,[260]6.7289,[261]6.7270,[262]6.7243,[263]6.7201,[264]6.7057,[265]6.7037,[266]6.6980,[267]6.6893,[268]6.6942,[269]6.6924,[270]6.6910,[271]6.6970,[272]6.7004,[273]6.6987,[274]6.6981,[275]6.7067,[276]6.7112,[277]6.7277,[278]6.7371,[279]6.7474,[280]6.7539,[281]6.7643,[282]6.7686,[283]6.7828,[284]6.7908,[285]6.8000,[286]6.8141,[287]6.8141,[288]6.8196,[289]6.8126,[290]6.7972,[291]6.7815,[292]6.7656,[293]6.7520,[294]6.7522,[295]6.7503,[296]6.7552,[297]6.7543,[298]6.7560,[299]6.7508,[300]6.7418,[301]6.7417,[302]6.7330,[303]6.7243,[304]6.7140,[305]6.7124,[30
6]6.7012,[307]6.7016,[308]6.7023,[309]6.6875,[310]6.6825,[311]6.6754,[312]6.6749,[313]6.6669,[314]6.6654,[315]6.6491,[316]6.6510,[317]6.6348,[318]6.6178,[319]6.6276,[320]6.6399,[321]6.6437,[322]6.6351,[323]6.6363,[324]6.6379,[325]6.6518,[326]6.6545,[327]6.6596,[328]6.6646,[329]6.6682,[330]6.6770,[331]6.6902,[332]6.6854,[333]6.6947,[334]6.6898,[335]6.6810,[336]6.6821,[337]6.6784,[338]6.6789,[339]6.6739,[340]6.6681,[341]6.6721,[342]6.6719,[343]6.6750,[344]6.6732,[345]6.6734,[346]6.6704,[347]6.6720,[348]6.6734,[349]6.6740,[350]6.6725,[351]6.6709,[352]6.6713,[353]6.6629,[354]6.6626,[355]6.6673,[356]6.6721,[357]6.6665,[358]6.6765,[359]6.6799,[360]6.6746,[361]6.6718,[362]6.6792,[363]6.6895,[364]6.6970,[365]6.7002,[366]6.7029,[367]6.7115,[368]6.7082,[369]6.7078,[370]6.7083,[371]6.7024,[372]6.7064,[373]6.7107,[374]6.7103,[375]6.7077,[376]6.7159,[377]6.7101,[378]6.7108,[379]6.7144,[380]6.7054,[381]6.7035,[382]6.6993,[383]6.6979,[384]6.6955,[385]6.6941,[386]6.6923,[387]6.6919,[388]6.6874,[389]6.6808,[390]6.6731,[391]6.6649,[392]6.6640,[393]6.6670,[394]6.6715,[395]6.6686,[396]6.6604,[397]6.6697,[398]6.6747,[399]6.6817,[400]6.6799,[401]6.6817,[402]6.6849,[403]6.6866,[404]6.6917,[405]6.6860,[406]6.6819,[407]6.6831,[408]6.6820,[409]6.6926,[410]6.7036,[411]6.7165,[412]6.7321,[413]6.7435,[414]6.7529,[415]6.7577,[416]6.7646,[417]6.7756,[418]6.7793,[419]6.7838,[420]6.7910,[421]6.8003,[422]6.8037,[423]6.8109,[424]6.8219,[425]6.8305,[426]6.8367,[427]6.8387,[428]6.8458,[429]6.8506,[430]6.8579,[431]6.8719,[432]6.8739,[433]6.8703,[434]6.8632,[435]6.8645,[436]6.8662,[437]6.8737,[438]6.8814,[439]6.8755,[440]6.8764,[441]6.8716,[442]6.8705,[443]6.8714,[444]6.8734,[445]6.8706,[446]6.8720,[447]6.8724,[448]6.8743,[449]6.8718,[450]6.8716,[451]6.8670,[452]6.8659,[453]6.8611,[454]6.8567,[455]6.8575,[456]6.8604,[457]6.8618,[458]6.8589,[459]6.8589,[460]6.8673,[461]6.8627,[462]6.8620,[463]6.8666,[464]6.8650,[465]6.8628,[466]6.8571,[467]6.8590,[468]6.8582,[469]6.8586,[470]6.8587,[471]6.8553,[472]6.8584,[473]6.8524,[474]6.8569,[475]6.8579,[476]6.8578,[477]6.8524,[478]6.8522,[479]6.8659,[480]6.8725,[481]6.8736,[482]6.8681,[483]6.8644,[484]6.8670,[485]6.8660,[486]6.8610,[487]6.8631,[488]6.8634,[489]6.8579,[490]6.8574,[491]6.8552,[492]6.8503,[493]6.8470,[494]6.8451,[495]6.8446,[496]6.8421,[497]6.8388,[498]6.8399,[499]6.8337,[500]6.8258,[501]6.8203,[502]6.8221,[503]6.8203,[504]6.8114,[505]6.8147,[506]6.8150,[507]6.8163,[508]6.8125,[509]6.8094,[510]6.8120,[511]6.8174,[512]6.8208,[513]6.8218,[514]6.8284,[515]6.8236,[516]6.8228,[517]6.8236,[518]6.8231,[519]6.8263,[520]6.8274,[521]6.8280,[522]6.8305,[523]6.8306,[524]6.8353,[525]6.8385,[526]6.8395,[527]6.8417,[528]6.8394,[529]6.8422,[530]6.8358,[531]6.8343,[532]6.8418,[533]6.8453,[534]6.8433,[535]6.8478,[536]6.8421,[537]6.8400,[538]6.8451,[539]6.8459,[540]6.8503,[541]6.8528,[542]6.8536,[543]6.8573,[544]6.8572,[545]6.8550,[546]6.8559,[547]6.8510,[548]6.8428,[549]6.8416,[550]6.8388,[551]6.8355,[552]6.8337,[553]6.8290,[554]6.8250,[555]6.8218,[556]6.8208,[557]6.8241,[558]6.8220,[559]6.8238,[560]6.8213,[561]6.8224,[562]6.8209,[563]6.8220,[564]6.8292,[565]6.8317,[566]6.8325,[567]6.8288,[568]6.8292,[569]6.8255,[570]6.8279,[571]6.8293,[572]6.8291,[573]6.8289,[574]6.8254,[575]6.8248,[576]6.8249,[577]6.8232,[578]6.8211,[579]6.8205,[580]6.8134,[581]6.8094,[582]6.8083,[583]6.8076,[584]6.8091,[585]6.8032,[586]6.7969,[587]6.7974,[588]6.8010,[589]6.8083,[590]6.8100,[591]6.8106,[592]6.8086,[593]6.8035,[594]6.8024,[595]6.7992,[596]6.8025,[597]6.7980,[598]6.7961,[599]6.7972,[600]6.7948,[601]6.7926,[602]6
.7972,[603]6.7989,[604]6.7996,[605]6.8030,[606]6.8042,[607]6.8044,[608]6.8001,[609]6.8003,[610]6.8054,[611]6.8023,[612]6.8047,[613]6.8013,[614]6.7953,[615]6.7862,[616]6.7894,[617]6.7813,[618]6.7748,[619]6.7685,[620]6.7517,[621]6.7431,[622]6.7392,[623]6.7414,[624]6.7397,[625]6.7403,[626]6.7402,[627]6.7436,[628]6.7442,[629]6.7435,[630]6.7465,[631]6.7509,[632]6.7560,[633]6.7543,[634]6.7563,[635]6.7560,[636]6.7547,[637]6.7535,[638]6.7568,[639]6.7525,[640]6.7522,[641]6.7513,[642]6.7576,[643]6.7586,[644]6.7580,[645]6.7559,[646]6.7602,[647]6.7595,[648]6.7595,[649]6.7574,[650]6.7601,[651]6.7636,[652]6.7646,[653]6.7684,[654]6.7628,[655]6.7611,
Final estimate: PPL = 6.7611 +/- 0.03827

llama_print_timings:        load time =    3042.86 ms
llama_print_timings:      sample time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings: prompt eval time =  863938.68 ms / 335360 tokens (    2.58 ms per token,   388.18 tokens per second)
llama_print_timings:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time =  870160.60 ms / 335361 tokens

@Nexesenex
Copy link
Contributor

When I'm looking at an ongoing quant:

[Screenshot: text-generation-webui evaluation log (evaluations.csv) open in Notepad++]

Would attn_k in Q4_K (like attn_v), and attn_q in IQ2_XXS (or attn_q in IQ2_XS with attn_output in IQ2_XS as well), give us a viable intermediate "IQ1" quant between IQ1_S and IQ2_XXS?

@ikawrakow
Copy link
Contributor Author

Thank you all for testing and feedback.

So, what do we do with this? Leave it as a demo, or add the missing kernels and merge? @ggerganov

@ggerganov
Copy link
Owner

It's fine to merge it. The bullet points, though, would have to remain for the future.

@ikawrakow ikawrakow added the enhancement New feature or request label Feb 13, 2024
@ikawrakow ikawrakow marked this pull request as ready for review February 13, 2024 13:28
@tsengalb99
Copy link

@ikawrakow Just wondering, is this 1.5 bpw end-to-end (including the embedding + lm head) or just the decoder weights? Do you know if BiLLM quantizes the embedding + lm head or just the decoder weights?

@ikawrakow
Copy link
Contributor Author

ikawrakow commented Feb 15, 2024

@ikawrakow Just wondering, is this 1.5 bpw end-to-end (including the embedding + lm head) or just the decoder weights? Do you know if BiLLM quantizes the embedding + lm head or just the decoder weights?

For the results reported in this PR:

  • Token embedding tensor (token_embd.weight in llama.cpp) is 2.5 bit (Q2_K)
  • Output tensor (tensor output.weight in llama.cpp) is 5.5 bit (Q5_K)
  • The V tensor in the attention part (tensor attn_v.weight in llama.cpp) is 2.5 bit (Q2_K)
  • The first 1/8 of layers of ffn_down are 2.5 bit (Q2_K)
  • The attention output tensors (attn_output.weight in llama.cpp) are 2 bit (IQ2_XXS).
  • Everything else is 1.5 bit (IQ1_S)

We end up using effectively about 1.8 bpw for 7B models, or about 1.69 bpw for 70B, when everything is counted in. If the approach were implemented in a different framework that does not require blocks of predetermined size like ggml does, we could remove 0.0625 bpw from this balance (one fp16 scale per block of 256 weights) and replace it with a negligible per-row scale (one fp16 per 4096+ weights).
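
As a quick back-of-the-envelope check of the block-scale overhead mentioned above (my own sketch, not code from this PR; the 4096 row length is just a typical LLaMA value):

```cpp
// Overhead of quantization scales, in bits per weight (bpw).
#include <cstdio>

int main() {
    const double fp16_bits  = 16.0;    // one fp16 scale
    const double block_size = 256.0;   // ggml super-block size (QK_K)
    const double row_size   = 4096.0;  // typical LLaMA row length (assumed)

    const double per_block = fp16_bits / block_size; // 0.0625 bpw, as stated above
    const double per_row   = fp16_bits / row_size;   // ~0.0039 bpw, i.e. negligible

    printf("per-block scale:  %.4f bpw\n", per_block);
    printf("per-row scale:    %.4f bpw\n", per_row);
    printf("potential saving: %.4f bpw\n", per_block - per_row);
    return 0;
}
```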

I don't know what BiLLM does. But my experience with researchers putting papers up on arXiv is that they don't care about the token embedding and output tensors and never count them in the balance, so my assumption is that this applies to the BiLLM paper as well. The paper doesn't mention bits spent on block scales (they do have those, per the paper) or any other metadata they may need. So, overall, it is hard to tell what the actual total balance will be when everything is said and done. I did try to test, but gave up after unsuccessfully fighting with their requirements.

One more thing: as I wrote in the PR description above, there is currently no way in llama.cpp to separate salient and non-salient weights as the BiLLM paper does. I'm quite confident that, if that were possible, one could significantly shrink the quantized model while maintaining the same quantization error as here.

@ggerganov ggerganov merged commit bd2d4e3 into master Feb 18, 2024
47 checks passed
@Artefact2
Copy link
Collaborator

I can't get CPU (AVX2) to work. GPU (ROCm) works fine.

% ./main -m ~/KoboldCpp/models/Cat-8x7B-IQ1_S.gguf -p "[INST]Write some poetry about typography.[/INST]" -n 128

 [INST]Write some poetry about typography.[/INST]ptonENOptonCalled launchricane forward tri Становништвоchant east called indexENOENOENOENOLaunchENO reinpton launchENO triCLUD Iv triptonENOENOivotcalledpton reinadelphENOENO lic Становништво СтановништвоENO tri called calledCalled tri rein cy СтановништвоENO eastptontanptonCLUD licENOplementsENOENO launch watersENOwy launch >pton Index eastpgfscope >ptonpton rein conformnung>ptonplementsadelphptonENOptonENO tri СтановништвоENO DCHECKpton licCLUD surn licptonLaunch>ENO launch forward licenseENO launch east called specificallyptonENOricaneENOENO>calledykENOcalledchantCalledakhplements reinENO indexpton groundCalled called .. forward

#if QK_K == 256
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;

If this kernel is in a critical path for performance, it may help a bit to ensure that y is aligned on a 16-byte boundary (if that is indeed the case). Right now all results are stored in y one by one, and for this relatively low-compute kernel memory I/O will likely be the bottleneck. An aligned pointer would allow the compiler to perform the store as a single 128-bit operation.

https://cuda.godbolt.org/z/a1Eczhv5a

[Screenshot: compiler output from the godbolt example above]
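
For illustration, a minimal sketch of the suggestion, assuming dst_t is half and that y is 16-byte aligned; the kernel and all names here are hypothetical, not the actual IQ1_S dequantize kernel:

```cuda
#include <cuda_fp16.h>

// Hypothetical kernel illustrating the 128-bit store idea.
// Each thread converts 8 floats to 8 halves and writes them with one 16-byte store.
__global__ void store8_aligned(half * __restrict__ y, const float * __restrict__ x, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (8*i + 7 >= n) return;

    half2 tmp[4];                       // 8 half values = 16 bytes
    #pragma unroll
    for (int k = 0; k < 4; ++k) {
        tmp[k] = __floats2half2_rn(x[8*i + 2*k], x[8*i + 2*k + 1]);
    }
    // y + 8*i stays 16-byte aligned as long as y itself is, because the offset is a
    // multiple of 8 halves (16 bytes); the compiler can then emit a single 128-bit
    // store instead of 8 scalar stores.
    *reinterpret_cast<int4 *>(y + 8*i) = *reinterpret_cast<const int4 *>(tmp);
}
```

Whether this actually helps depends on the kernel really being memory-bound, which a profile would need to confirm.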

@BarfingLemurs
Copy link
Contributor

@Artefact2 For me, with CPU (AVX2), I'm able to get the quantized 8x7B instruct model to produce proper results on Ubuntu Jammy.

See if the HF-converted instruct model works. Tested on this calibration text: #5263 (comment)

I did get the same gibberish as you on an ARM device.

Nexesenex pushed a commit to Nexesenex/croco.cpp that referenced this pull request Feb 19, 2024
* iq1_s: WIP basics

* iq1_s: CUDA is working

* iq1_s: scalar CPU dot product

* iq1_s: WIP AVX2 dot product - something is not right

* Fix tests

* Fix shadow warnings

* Fix after merge with latest master

* iq1_s: AVX2 finally works

* iq1_s: ARM_NEON dot product. Works, but not very fast

* iq1_s: better grid

* iq1_s: use IQ2_XXS for attn_output

At a cost of 0.04 extra bpw this gives a big improvement in PPL.

* iq1_s: Metal basics

Dequantize works, but not dot product

* iq1_s: Metal works, but quite slow

As usual, Apple Silicon does not like the code I write.

* iq1_s: Tests

* iq1_s: slightly faster dot product

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
Nexesenex pushed a commit to Nexesenex/croco.cpp that referenced this pull request Feb 19, 2024
@BadisG
Copy link

BadisG commented Feb 20, 2024

@ikawrakow take a look at that paper, they have some insane PPL numbers; maybe you can use their techniques to improve your craft.
https://arxiv.org/abs/2402.11960

[Screenshots: perplexity tables from the linked paper]

@ikawrakow
Copy link
Contributor Author

@BadisG Normally the fp16 PPL for LLaMA-2 reported in papers is 5.12 (7B), 4.57 (13B), and 3.12 (70B). Go figure what these results are supposed to be, especially considering that their LLaMA-1 fp16 perplexities match the commonly reported values. Why is this important? Because if they just got confused and put the wrong fp16 values in the LLaMA-v2 table, the results are not particularly impressive (you can find better results, for instance, in this paper). But if the results are for a smaller context than the 4096 commonly used to report LLaMA-2 PPL results in papers, then they may have achieved something noteworthy. I tend to assume the former, based on their LLaMA-1 results being not particularly impressive (e.g., the 2-bit quantization IQ2_XXS in this repo has a PPL of 4.30 for LLaMA-v1-65B).

@tsengalb99
Copy link

tsengalb99 commented Feb 20, 2024 via email

@ikawrakow
Copy link
Contributor Author

@tsengalb99 No, I haven't seen your tweet (I'm a relic of the past who does not hang around social networks such as Twitter, err, X). Congratulations! When you say "end-to-end", do you mean really end-to-end, as stored on disk and including the output tensor? But either way, 5.9 is higher than this PR, even after factoring in the ~3% difference in LLaMA-v2 fp16 PPL between llama.cpp and the Python PPL calculation. Sure, your model is smaller, but if producing basically useless quantized LLM models became the priority of the day, I'm sure the apparatus necessary to shrink the IQ1_S model from its current 13.5 GB to something more in line with your 9 GB would get implemented here as well, so there wouldn't be much difference. Oh, wait, there would be one: 3 CPU minutes of quantization time for LLaMA-v2-70B vs your 60 GPU hours.

Concerning the quoted paper: thanks for the clarification. If their quantization is 2.25 bpw, then we need to compare to IQ2_XS from this repo, which has a LLaMA-v1-65B PPL of 4.07, so miles ahead of their 4.84 (and I don't feel like running LLaMA-v2 at a context of 2048 to also compare this result).

@tsengalb99
Copy link

tsengalb99 commented Feb 20, 2024 via email

@ikawrakow
Copy link
Contributor Author

@tsengalb99 The ratio PPL(quantized)/PPL(fp16) is not particularly sensitive to how PPL is being calculated, and is nearly independent of context length. This ratio is closely related to KL-divergence, I wrote about that elsewhere. Hence, one can absolutely use that ratio to compare your results to mine.
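
For what it's worth, a minimal sketch of that relation (my notation, not from this PR): with PPL defined as the exponential of the average negative log-likelihood of the observed tokens,

$$
\ln\frac{\mathrm{PPL}_{q}}{\mathrm{PPL}_{\mathrm{fp16}}}
= \frac{1}{N}\sum_{i=1}^{N}\left[\ln p_{\mathrm{fp16}}(x_i \mid x_{<i}) - \ln p_{q}(x_i \mid x_{<i})\right].
$$

The log of the ratio is thus the average per-token log-likelihood gap between the two models; if the observed tokens are distributed roughly like the fp16 model's own predictions, this is an empirical estimate of the average KL divergence from the quantized to the fp16 model over contexts, which is why the ratio is largely insensitive to context length and to the details of the PPL calculation.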

@Nexesenex
Copy link
Contributor

If you want to play with this quantization, it is worth experimenting with the f_norm_rms_eps parameter. This is normally defined by the model. But here we are using such an extreme quantization that the epsilon may no longer make sense for the RMS norm operation. You can override the model defined value using

@ikawrakow: Indeed, it's worth it: one can decrease perplexity by almost 1% with the right value (9e-05 seems to be a pertinent value for the sub-2 bpw quants).

By the way, I reiterate my request to be able to set that parameter during quantization, so the correct f_norm_rms_eps value gets loaded and used by any 3rd-party tool like KoboldCPP that does not support modifying this parameter from the CLI. It's a bit frustrating not to be able to make and share the best quants possible for everyone, and I would greatly appreciate the feature.

Also, do you have in store something smaller than the current GGML_TYPE_IQ1_S, which could be an IQ1_XS? Of course, it could not be efficiently used for all tensors, but it could most probably be used for attn_q.weight, and to some extent for attn_k.weight or even ffn_up / ffn_gate. I don't need that to be integrated in an existing or a new LLAMA_FTYPE_MOSTLY_XXX quant, just such a GGML_TYPE_IQ1_XS to be available for use in a quant strategy.

jordankanter pushed a commit to jordankanter/llama.cpp that referenced this pull request Mar 13, 2024
jordankanter pushed a commit to jordankanter/llama.cpp that referenced this pull request Mar 13, 2024
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024
@mofosyne mofosyne added the Tensor Encoding Scheme and Review Complexity : High labels May 25, 2024
Labels
  • demo — Demonstrate some concept or idea, not intended to be merged
  • enhancement — New feature or request
  • Review Complexity : High — Generally require indepth knowledge of LLMs or GPUs
  • Tensor Encoding Scheme — https://github.com/ggerganov/llama.cpp/wiki/Tensor-Encoding-Schemes