IQ3_S: a much better alternative to Q3_K #5676

Merged: ikawrakow merged 27 commits from ik/iq3_xs_new2 into master on Feb 24, 2024
Conversation

@ikawrakow (Contributor) commented Feb 23, 2024

This PR adds IQ3_S, a 3.4375 bpw quantization (i.e., the exact same size as Q3_K) that has a significantly lower PPL compared to Q3_K_S (see below).

In addition:

  • The existing Q3_K_XS quantization mix (a mix of Q3_K, Q2_K and Q4_K) is replaced with a simpler and much better mix of IQ3_XXS and IQ3_S with an approximate bpw of 3.25.
  • It adds IQ3_M, a mix between the new IQ3_S and Q4_K. It has basically the same PPL as the existing Q3_K_M at 0.15 bpw less.

The graph shows a summary of PPL results for LLaMA-v1 and LLaMA-v2 models, plus Mistral-7B. Each point represents the ratio of the quantized PPL to the PPL of the base (fp16) model. The x-axis is bpw - bits-per-weight - excluding the bpw added by the higher bit quantization of the output.weight tensor. The magenta circles show the results for Q3_K_S, the orange for the new IQ3_S. The improvement in quantization error (defined as PPL(Q)/PPL(fp16)-1) is 40-70% depending on model. The cyan circles represent the existing Q3_K_M quantization mix. The dark green circles are for the new IQ3_M, showing the ~0.15 bpw saving for essentially the same quantization error. The new Q3_K_XS mix, shown in indigo, is designed to be roughly in the middle between IQ3_XXS and IQ3_S in terms of bpw. The dashed line is for visual guidance (it connects the average of the data points at each bpw).

[Figure 3bit_PR: PPL(Q)/PPL(fp16) vs. bits per weight for the quantization types discussed above]

Inference performance of the new IQ3_S quants is similar to Q3_K on CUDA (RTX-4080), AVX2 (Ryzen 7950X), and Metal (30-core M2 Max). Performance on the M2 Max CPU with ARM_NEON intrinsics is pathetic: only about 10 t/s for a 7B model compared to 22.5 t/s for Q3_K_S. The IQ series of quants use "codebooks" to encode groups of 4 or 8 weights. For IQ3_S this requires 4 memory loads from a lookup table of 2048 bytes to set up one 128-bit SIMD register. It seems Apple Silicon does not like this very much. Let's hope that someone more knowledgeable than me will be able to optimize it.
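
To give a concrete picture of the access pattern (a minimal sketch under my own assumptions, not the actual ggml kernel; `gather4` and its arguments are hypothetical names), filling one 128-bit register takes four scalar gathers from the 512-entry, 2048-byte codebook:

```cpp
// Sketch only: four scalar loads from the codebook to build one 128-bit register.
// The real IQ3_S NEON kernel in ggml is organized differently.
#include <arm_neon.h>
#include <stdint.h>

static inline uint32x4_t gather4(const uint32_t * codebook,   // 512 entries * 4 bytes = 2048 bytes
                                 uint16_t i0, uint16_t i1, uint16_t i2, uint16_t i3) {
    uint32x4_t v = vdupq_n_u32(codebook[i0]);   // lane 0: 4 quant magnitudes
    v = vsetq_lane_u32(codebook[i1], v, 1);     // lane 1
    v = vsetq_lane_u32(codebook[i2], v, 2);     // lane 2
    v = vsetq_lane_u32(codebook[i3], v, 3);     // lane 3
    return v;                                   // 16 packed quant magnitudes in total
}
```

It is these serialized table loads, rather than the arithmetic, that appear to be what Apple Silicon dislikes.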

The extra 0.375 bits per weight spent compared to IQ3_XXS are due to:

  • Using a "codebook" with 512 entries for IQ3_S instead of 256 for IQ3_XXS. This adds 1 bit per 4 weights, so 0.25 bpw.
  • Not enforcing an even number of signs in a group of 8 weights. This adds 1 bit per 8 quants, so 0.125 bpw.
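
Taking the IQ3_XXS size of 3.0625 bpw as the baseline (that figure is my assumption, it is not quoted above), the accounting works out to:

$$3.0625 + 0.25 + 0.125 = 3.4375\ \text{bpw}$$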

If this PR is accepted, one could retire the Q3_K quants. I haven't done that mainly for two reasons:

  • To avoid breaking compatibility with the many quantized models floating around the Internet that use Q3_K quants
  • The bad ARM_NEON performance of IQ3_S.

* Basics (quantize, dequantize)
* CUDA dequantize and dot product
* Slightly faster CUDA dot product (120 t/s)
* Switch to 6-bit scales
* Scalar dot product
* AVX2 dot product
* ARM_NEON dot product
* Works on metal, but still slow
* Slightly better Metal dot product
* Another small Metal improvement
* Metal dot product is getting there
* Faster CUDA dot product
* Add 1/8 ffn_down layers as Q5_K when no imatrix has been provided
* Report the actual bpw
* Add _xs mix that is 4.05 bpw for non-MoE models
* Remove IQ4_XS for now, slightly adjust kvalues_iq4nl
* AVX2 dot product uses Q8_0 instead of Q8_K
* Add to test-backend-ops
* Minor fix
* Also use Q5_K for attn_output in MoE models
* Fixes after merging latest master
* Switching to blocks of 32
* AVX2 for blocks of 32
* Scalar dot product for blocks of 32
* ARM_NEON dot product for blocks of 32
* Metal kernels for blocks of 32
* Slightly faster Metal kernels
After all the experimentation, nothing was better than this.
Performance is very similar to Q3_K_S
@ikawrakow (Contributor, Author)

I see the ROCm builds failing. The compiler claims not to know about __vcmpeq4, but according to this ROCm compatibility guide, __vcmpeq4 is indeed supported. Does someone know what is going on? I don't have an AMD GPU to test.

@Artefact2 (Collaborator) commented Feb 23, 2024

The page you linked says it isn't supported (the HIP column is empty).

@Nexesenex (Contributor)

Great work!

Btw, wouldn't it be more pertinent to rename the new Q3_K_XS to IQ3_XS, to avoid confusion and considering that it sits exactly where an IQ3_XS should be, while that name is still available?

@Xonar92 commented Feb 23, 2024

Impressive work! Thanks again @ikawrakow

I agree with the previous poster that it might be a good idea to think about naming conventions, and possibly even have “v1/v2/v3” etc. once small improvements are made to an existing format.

Would also be cool if we could find a way to optimize for ARM / Apple silicon.

@@ -196,6 +196,17 @@ static __device__ __forceinline__ int __vsub4(const int a, const int b) {
return __vsubss4(a, b);
}

static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
Review comment (Collaborator) with a suggested fix:

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 92f9309b..9729ad73 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -172,6 +172,7 @@
 #endif
 
 typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
+typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
 static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
     const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
     const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);

@ggerganov (Owner)

The build fails with the smaller super blocks (QK_K 64):

make clean && LLAMA_QKK_64=1 make -j tests && ./tests/test-quantize-fns
In file included from ggml-quants.c:1:
./ggml-quants.h:202:1: error: static assertion failed due to requirement 'sizeof(block_iq3_s) == sizeof(unsigned short) + 27 * (64 / 64)': wrong iq3_s block size/padding
static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 27*(QK_K/64), "wrong iq3_s block size/padding");
^             ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include/assert.h:113:23: note: expanded from macro 'static_assert'
#define static_assert _Static_assert
                      ^
./ggml-quants.h:202:35: note: expression evaluates to '30 == 29'
static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 27*(QK_K/64), "wrong iq3_s block size/padding");
              ~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1 error generated.
make: *** [ggml-quants.o] Error 1
make: *** Waiting for unfinished jobs....
In file included from ggml.c:5:
./ggml-quants.h:202:1: error: static assertion failed due to requirement 'sizeof(block_iq3_s) == sizeof(unsigned short) + 27 * (64 / 64)': wrong iq3_s block size/padding
static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 27*(QK_K/64), "wrong iq3_s block size/padding");
^             ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include/assert.h:113:23: note: expanded from macro 'static_assert'
#define static_assert _Static_assert
                      ^
./ggml-quants.h:202:35: note: expression evaluates to '30 == 29'
static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 27*(QK_K/64), "wrong iq3_s block size/padding");
              ~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1 error generated.
make: *** [ggml.o] Error 1

@ggerganov (Owner)

ggml-ci does not run test-quantize-fns for QK_K 64. It only downloads OpenLLaMA 3B, quantizes it with different K-quants, runs inference and measures PPL values (the tests in the link in the previous comment). I ran the test-quantize-fns test because of the reported issue on master.

@ikawrakow (Contributor, Author)

So, now it compiles and even works on ARM_NEON and AVX2 with QK_K = 64. Metal is not yet supported (just like any of the IQ quants) for QK_K = 64.

Pleasant surprise: the coding was super-block size independent,
so all it took was to delete some QK_K == 256 guards.
@sorasoras

rocmforiq3.txt
@ikawrakow I still cannot build for ROCm for gfx1100. Hope this helps.

@sorasoras commented Feb 23, 2024

@ikawrakow ROCm works now, good job.
P.S. some tests: test_backend_and_fns.txt

Some perplexity numbers for reference:

with imatrix:

```
Q2KS    5.1283 +/- 0.04957  (old)
Q2KS    5.1090 +/- 0.04962  (new, with IQ4NL)
IQ3XXS  4.7257 +/- 0.04595
IQ3M    4.6586 +/- 0.04476
IQ3S    4.6568 +/- 0.04432
Q4KM    4.6321 +/- 0.04518  (with AWQ)
Q4KM    4.6321
IQ4NL   4.6048 +/- 0.04419
Q6K     4.5787 +/- 0.04407
Q5_KS   4.5761 +/- 0.04412
Q8_0    4.5685
```

without imatrix:

```
Q8_0    4.5745
Q6_K    4.5899 +/- 0.04430
Q4KM    4.6769 +/- 0.04636
```

@PeterReid commented Feb 23, 2024

For the NEON performance, have you considered making the table be something you can compute (vectorized) instead of doing lookups?

The bytes in the lookup table only have 8 distinct values, which makes them possibly a product of a shuffle instruction. If you start with x in [0, 512), you can do (x * 123605091) & 0x07070707 and that will give you 512 distinct sets of shuffle indices. (They don't match the exact shuffle indices required to make your table, but I bet you could get somewhat close by searching through values to replace 123605091. Is the exact structure of that table important? About 20% of numbers you plug in there end up with 512 distinct results, so you'd have a lot of potential codebooks that could be chosen from.) And, you could do all 8 of these lookups at once in one big register, I think.
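
The distinctness claim is easy to check with a few lines (a standalone sketch, not llama.cpp code; the multiplier is the one quoted above):

```cpp
// Count how many distinct 4x3-bit shuffle-index sets (x * 123605091) & 0x07070707
// produces for x in [0, 512). If the count is 512, every codebook index maps to a
// unique set of shuffle indices.
#include <cstdint>
#include <cstdio>
#include <set>

int main() {
    std::set<uint32_t> distinct;
    for (uint32_t x = 0; x < 512; ++x) {
        distinct.insert((x * 123605091u) & 0x07070707u);
    }
    printf("%zu distinct index sets out of 512\n", distinct.size());
    return 0;
}
```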

@ikawrakow (Contributor, Author)

@PeterReid The order of the values in the lookup table is not important, but the specific values are. Each uint32_t entry encodes the quantized values of 4 model weights (i.e., each byte in the uint32_t corresponds to one quant). The 8 possible quant values are {4, 12, 20, 28, 36, 44, 52, 62} (so, apart from the last one, basically 8*q + 4, q = 0...7). These are the absolute values of the quants; the sign is stored separately and applied during dequantization/dot products. With 8 possible quant values and 4 quants, one has 8^4 = 4096 possible combinations. The selected 512 are not the result of a random choice. They are determined by the following procedure:

  • Quantize some models using quants that can take values {-62, -52, -44, -36, -28, -20, -12, -4, 4, 12, 20, 28, 36, 44, 52, 62}
  • Remove the signs and collect statistics on how often each of the 4096 possible codes occurs
  • Pick 512 such that (a) the total count of the selected points is maximized, and (b) the sum of distances from the non-selected points to their respective closest point on the grid, weighted by their counts, is minimized

So, these 512 values are not random at all, and the quantization error depends strongly on the selection. A straightforward 3-bit quantization would allow only 8 distinct values. Instead, here we have 16 distinct values (taking the sign into account), but with the restriction that only a subset of the possible combinations of the 16 values is allowed.
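
To make the encoding concrete, here is a rough sketch of how one codebook entry expands into four signed quants (illustrative only; the actual bit layout of the signs in ggml differs, and the names are hypothetical):

```cpp
// Expand one 32-bit codebook entry (4 packed magnitudes from {4, 12, ..., 62})
// into four signed quant values, given one sign bit per quant.
#include <cstdint>

static void decode4(uint32_t entry, uint8_t sign_bits, int8_t out[4]) {
    for (int i = 0; i < 4; ++i) {
        const int q = (entry >> (8*i)) & 0xff;   // magnitude of quant i
        out[i] = (sign_bits >> i) & 1 ? -q : q;  // apply the separately stored sign
    }
}
```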

@JianbangZ

Good stuff, the new Q3_K_XS is better.
But with the recent constant changes, most people can't follow or know what to use anymore. There either needs to be a table illustrating the differences among all the quant types, or a consolidation, or better naming.

@PeterReid commented Feb 23, 2024

Do you happen to have those statistics in a way you can send? I would like to see how much worse the closest weighted distance would be if the codebook is constrained to be one of the ~800 million (some of those would be duplicates) that can be generated with this method.

@cebtenzzre (Collaborator)

> But with recent constant changes, most ppl can't follow or know what to use anymore. There either needs a table to illustrate difference among all the quant types, or a consolidation or better naming.

Agreed - I still mostly use the legacy Q4_0 and Q5_0 quants because they run fast on my Tesla P40. It would be nice to have a summary of all of the new quantization types so I know which ones to explore. I got a little discouraged after Q2_K halved in tg speed on my GPU, and since then I haven't experimented much with sub-4-bit quants.

@ikawrakow (Contributor, Author)

@PeterReid Here is one such sample file with statistics. It is just binary data containing 4096 ints that are the counts for the 4096 possible combinations. In C++, simply:

#include <fstream>
#include <vector>

std::vector<int> H(4096);                      // counts for the 4096 possible codes
std::ifstream in("h8.out", std::ios::binary);
in.read((char *)H.data(), H.size()*sizeof(int));

But I did not mention the last step in the process, which is the most tedious and lengthy: after generating a new set of points (codes), I go change the lookup table and run a bunch of perplexity calculations to see how the new codebook performs. I always use the 7B and 13B models of LLaMA-v1 and LLaMA-v2 along with Mistral-7B. If the results look promising, I also run LLaMA-v1-30B and Mixtral-8x7B. If that also looks promising, I run LLaMA-v2-70B and possibly LLaMA-v1-65B. Without this verification step it doesn't work: more often than not, a tweak to the codebook generation that looked better on paper (lower mean squared distance of the codes not in the codebook) would end up with a worse PPL than the codebook that looked worse on paper.

If you can find a way to encode this exact set of 512 entries via some clever trick, this would be great!

h8.out.gz

@sorasoras

> But with recent constant changes, most ppl can't follow or know what to use anymore. There either needs a table to illustrate difference among all the quant types, or a consolidation or better naming.
>
> Agreed - I still mostly use the legacy Q4_0 and Q5_0 quants because they run fast on my Tesla P40. It would be nice to have a summary of all of the new quantization types so I know which ones to explore. I got a little discouraged after Q2_K halved in tg speed on my GPU, and since then I haven't experimented much with sub-4-bit quants.

That's strange that you got Q2_K that slow on a P40.
FYI:

ggml_init_cublas: GGML_CUDA_FORCE_MMQ: no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 1 CUDA devices:
Device 0: Tesla P40, compute capability 6.1, VMM: no

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| qwen 13B Q4_K - Medium | 8.79 GiB | 14.17 B | CUDA | 99 | pp 512 | 431.84 ± 0.31 |
| qwen 13B Q4_K - Medium | 8.79 GiB | 14.17 B | CUDA | 99 | tg 128 | 19.66 ± 0.01 |
| qwen 13B Q2_K - Small | 5.34 GiB | 14.17 B | CUDA | 99 | pp 512 | 336.90 ± 0.28 |
| qwen 13B Q2_K - Small | 5.34 GiB | 14.17 B | CUDA | 99 | tg 128 | 26.81 ± 0.03 |
| qwen 13B IQ4_NL - 4.5 bpw | 7.61 GiB | 14.17 B | CUDA | 99 | pp 512 | 243.57 ± 0.20 |
| qwen 13B IQ4_NL - 4.5 bpw | 7.61 GiB | 14.17 B | CUDA | 99 | tg 128 | 18.06 ± 0.00 |

IQ4_NL might be a bit slower, but it's 1.2 GiB smaller with quite a bit better perplexity.
The other fast one I can think of is the 1.5-bit quant (#5453).
Most IQ2 quants aren't fast on older GPUs.

@Artefact2 (Collaborator)

KL-divergence data for Mistral-7B (over wikitext)

[Plot: KL-divergence data for Mistral-7B over wikitext]

| Quant | Bits per weight | KL-divergence median | KL-divergence q99 | Top tokens differ | ln(PPL(Q)/PPL(base)) |
| --- | --- | --- | --- | --- | --- |
| IQ1_S | 1.78 | 0.5495 | 5.5174 | 0.3840 | 0.9235 |
| IQ2_XXS | 2.20 | 0.1751 | 2.4983 | 0.2313 | 0.2988 |
| IQ2_XS | 2.43 | 0.1146 | 1.7693 | 0.1943 | 0.2046 |
| Q2_K_S | 2.79 | 0.0829 | 1.5111 | 0.1735 | 0.1600 |
| Q2_K | 3.00 | 0.0588 | 1.0337 | 0.1492 | 0.1103 |
| IQ3_XXS | 3.21 | 0.0330 | 0.5492 | 0.1137 | 0.0589 |
| Q3_K_XS | 3.32 | 0.0296 | 0.4550 | 0.1071 | 0.0458 |
| Q3_K_S | 3.50 | 0.0304 | 0.4481 | 0.1068 | 0.0511 |
| IQ3_S | 3.52 | 0.0205 | 0.3018 | 0.0895 | 0.0306 |
| IQ3_M | 3.63 | 0.0186 | 0.2740 | 0.0859 | 0.0268 |
| Q3_K_M | 3.89 | 0.0171 | 0.2546 | 0.0839 | 0.0258 |
| Q3_K_L | 4.22 | 0.0152 | 0.2202 | 0.0797 | 0.0205 |
| IQ4_NL | 4.56 | 0.0085 | 0.1077 | 0.0605 | 0.0074 |
| Q4_K_S | 4.57 | 0.0083 | 0.1012 | 0.0600 | 0.0081 |
| Q4_K_M | 4.83 | 0.0075 | 0.0885 | 0.0576 | 0.0060 |
| Q5_K_S | 5.52 | 0.0045 | 0.0393 | 0.0454 | 0.0005 |
| Q5_K_M | 5.67 | 0.0043 | 0.0368 | 0.0444 | 0.0005 |
| Q6_K | 6.57 | 0.0032 | 0.0222 | 0.0394 | −0.0008 |
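
For readers unfamiliar with the columns, here is a rough sketch of the per-token quantities behind them (purely illustrative; this is not the script used to produce the table):

```cpp
// Per-token comparison of the base model's distribution p and the quantized
// model's distribution q over the vocabulary. The "median" and "q99" columns
// are percentiles of kl over all evaluated positions; "Top tokens differ" is
// the fraction of positions where top_differs is true.
#include <algorithm>
#include <cmath>
#include <vector>

struct TokenStats {
    double kl;           // KL(p || q) at this position
    bool   top_differs;  // argmax(p) != argmax(q)
};

static TokenStats token_stats(const std::vector<double>& p, const std::vector<double>& q) {
    double kl = 0.0;
    for (size_t i = 0; i < p.size(); ++i) {
        if (p[i] > 0.0 && q[i] > 0.0) {
            kl += p[i] * std::log(p[i] / q[i]);
        }
    }
    const size_t top_p = std::max_element(p.begin(), p.end()) - p.begin();
    const size_t top_q = std::max_element(q.begin(), q.end()) - q.begin();
    return { kl, top_p != top_q };
}
```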

@PeterReid

Thank you for the details on your process, and that file! I am quite certain that I will not be able to generate that exact codebook, but am mildly hopeful that I will be able to generate one as good.

@thorvaldur-arnar commented Feb 24, 2024

Looking at the 3-bit quants in @Artefact2’s plot, it looks to me that the low 3-bit quants hold up rather well in terms of top-token agreement compared with the high 3-bit quants. The difference is bigger when considering the median divergence, though. I would conclude that the difference between low and high quants in the 3-bit range matters less at lower temperatures, like one would use for coding and logic.

On the other hand, maybe it’s just a flatter curve overall rather than a true difference in shape.

@ikawrakow merged commit 4c4cb30 into master on Feb 24, 2024
57 checks passed
@ikawrakow deleted the ik/iq3_xs_new2 branch on February 24, 2024 at 14:23
@dranger003 (Contributor) commented Feb 25, 2024

EDIT: It appears this bug occurs when using more than 2 threads with quantize. @ikawrakow Could this point to some thread synchronization issue?

I am getting an invalid memory read in kmap_q3xs (it looks like u is out of bounds):

int grid_index = kmap_q3xs[u];
and reaching the assert below. Now, this model is a finetune of Mixtral, which actually quantizes fine, so I'm wondering if this is an issue with the model? Also, only the new quants reach the assert (i.e. IQ3_M/IQ3_S/Q3_K_XS); IQ3_XXS works fine. The imatrix was generated with:

./build/bin/imatrix -m /mnt/c/LLM_MODELS/abacusai/ggml-smaug-mixtral-v0.1-q8_0.gguf -f /mnt/c/LLM_MODELS/groups_merged.txt -o /mnt/c/LLM_MODELS/abacusai/ggml-smaug-mixtral-v0.1-q8_0-imatrix.dat -ngl 15
assert crash (8 threads)
~/llama.cpp$ ./build/bin/quantize --imatrix /mnt/c/LLM_MODELS/abacusai/ggml-smaug-mixtral-v0.1-q8_0-imatrix.dat /mnt/c/LLM_MODELS/abacusai/ggml-smaug-mixtral-v0.1-f16.gguf /mnt/c/LLM_MODELS/abacusai/ggml-smaug-mixtral-v0.1-iq3_m.gguf IQ3_M 8
load_imatrix: loaded 928 importance matrix entries from /mnt/c/LLM_MODELS/abacusai/ggml-smaug-mixtral-v0.1-q8_0-imatrix.dat
prepare_imatrix: have 928 importance matrix entries
ggml_init_cublas: GGML_CUDA_FORCE_MMQ:   no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
main: build = 2259 (930b1780)
main: built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu
main: quantizing '/mnt/c/LLM_MODELS/abacusai/ggml-smaug-mixtral-v0.1-f16.gguf' to '/mnt/c/LLM_MODELS/abacusai/ggml-smaug-mixtral-v0.1-q3_k_m.gguf' as IQ3_M using 8 threads
llama_model_loader: loaded meta data with 24 key-value pairs and 995 tensors from /mnt/c/LLM_MODELS/abacusai/ggml-smaug-mixtral-v0.1-f16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = C:\LLM_MODELS\abacusai
llama_model_loader: - kv   2:                       llama.context_length u32              = 32768
llama_model_loader: - kv   3:                     llama.embedding_length u32              = 4096
llama_model_loader: - kv   4:                          llama.block_count u32              = 32
llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 14336
llama_model_loader: - kv   6:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv   7:                 llama.attention.head_count u32              = 32
llama_model_loader: - kv   8:              llama.attention.head_count_kv u32              = 8
llama_model_loader: - kv   9:                         llama.expert_count u32              = 8
llama_model_loader: - kv  10:                    llama.expert_used_count u32              = 2
llama_model_loader: - kv  11:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  12:                       llama.rope.freq_base f32              = 1000000.000000
llama_model_loader: - kv  13:                          general.file_type u32              = 0
llama_model_loader: - kv  14:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  15:                      tokenizer.ggml.tokens arr[str,32000]   = ["<unk>", "<s>", "</s>", "<0x00>", "<...
llama_model_loader: - kv  16:                      tokenizer.ggml.scores arr[f32,32000]   = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  17:                  tokenizer.ggml.token_type arr[i32,32000]   = [2, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...
llama_model_loader: - kv  18:                tokenizer.ggml.bos_token_id u32              = 1
llama_model_loader: - kv  19:                tokenizer.ggml.eos_token_id u32              = 2
llama_model_loader: - kv  20:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  21:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  22:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  23:                    tokenizer.chat_template str              = {{ bos_token }}{% for message in mess...
llama_model_loader: - type  f32:  995 tensors
================================ Have weights data with 928 entries
llama_model_quantize_internal ============ Strange model: n_attention_wv = 32, n_ffn_down = 256, hparams.n_layer = 32
llama_model_quantize_internal: meta size = 780224 bytes
[   1/ 995]                    token_embd.weight - [ 4096, 32000,     1,     1], type =    f32,
====== llama_model_quantize_internal: did not find weights for token_embd.weight
quantizing to iq3_s .. ================================================================= iq3xs_init_impl(grid_size = 512)
iq3xs_init_impl: 24733 neighbours in total
Oops: found point 975 not on grid: 15 0 15 0
GGML_ASSERT: ~/llama.cpp/ggml-quants.c:11145: false
Aborted
segfault crash (3 threads)
~/llama.cpp$ ./build/bin/quantize --imatrix /mnt/c/LLM_MODELS/abacusai/ggml-smaug-mixtral-v0.1-q8_0-2-imatrix.dat /mnt/c/LLM_MODELS/abacusai/ggml-smaug-mixtral-v0.1-f16.gguf /mnt/c/LLM_MODELS/abacusai/ggml-smaug-mixtral-v0.1-iq3_s.gguf IQ3_S 3
load_imatrix: loaded 928 importance matrix entries from /mnt/c/LLM_MODELS/abacusai/ggml-smaug-mixtral-v0.1-q8_0-2-imatrix.dat
prepare_imatrix: have 928 importance matrix entries
ggml_init_cublas: GGML_CUDA_FORCE_MMQ:   no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
main: build = 2271 (67fd3313)
main: built with cc (GCC) 13.2.1 20230801 for x86_64-pc-linux-gnu
main: quantizing '/mnt/c/LLM_MODELS/abacusai/ggml-smaug-mixtral-v0.1-f16.gguf' to '/mnt/c/LLM_MODELS/abacusai/ggml-smaug-mixtral-v0.1-iq3_s.gguf' as IQ3_S using 3 threads
llama_model_loader: loaded meta data with 24 key-value pairs and 995 tensors from /mnt/c/LLM_MODELS/abacusai/ggml-smaug-mixtral-v0.1-f16.gguf (version GGUF V3 (latest))
llama_model_loader: [model metadata dump identical to the 8-thread run above]
================================ Have weights data with 928 entries
llama_model_quantize_internal ============ Strange model: n_attention_wv = 32, n_ffn_down = 256, hparams.n_layer = 32
llama_model_quantize_internal: meta size = 780224 bytes
[   1/ 995]                    token_embd.weight - [ 4096, 32000,     1,     1], type =    f32,
====== llama_model_quantize_internal: did not find weights for token_embd.weight
quantizing to iq3_s .. ================================================================= iq3xs_init_impl(grid_size = 512)
iq3xs_init_impl: 24733 neighbours in total
Segmentation fault (core dumped)

@ikawrakow (Contributor, Author)

@dranger003 I downloaded your model and imatrices but I cannot reproduce the problem. There is definitely no race in the code (else all quantization types would be affected as the multi-threading mechanism is exactly the same for all quants). Not sure what the issue might be.

@dranger003 (Contributor)

@ikawrakow Thanks for looking into it, appreciate it. For now, I only get the issue with that specific model and I can still get it to work if I run the quantize multiple times until it goes through, so I think this is fine. I'll report back if this becomes a larger issue.

jordankanter pushed a commit to jordankanter/llama.cpp that referenced this pull request Mar 13, 2024
* iq4_nl: squash commits for easier rebase

* Basics (quantize, dequantize)
* CUDA dequantize and dot product
* Slightly faster CUDA dot product (120 t/s)
* Switch to 6-bit scales
* Scalar dot product
* AVX2 dot product
* ARM_NEON dot product
* Works on metal, but still slow
* Slightly better Metal dot product
* Another small Metal improvement
* Metal dot product is getting there
* Faster CUDA dot product
* Add 1/8 ffn_down layers as Q5_K when no imatrix has been provided
* Report the actual bpw
* Add _xs mix that is 4.05 bpw for non-MoE models
* Remove IQ4_XS for now, slightly adjust kvalues_iq4nl
* AVX2 dot product uses Q8_0 instead of Q8_K
* Add to test-backend-ops
* Minor fix
* Also use Q5_K for attn_output in MoE models
* Fixes after merging latest master
* Switching to blocks of 32
* AVX2 for blocks of 32
* Scalar dot product for blocks of 32
* ARM_NEON dot product for blocks of 32
* Metal kernels for blocks of 32
* Slightly faster Metal kernels

* Resurrecting iq3_xs

After all the experimentation, nothing was better than this.

* Minor PPL improvement via a block scale fudge factor

* Minor improvement via 3 neighbours

* iq3_xs: working scalar and AVX2 dot products

* iq3_xs: ARM_NEON dot product - works but extremely slow (10 t/s)

* iq3_xs: working Metal implementation

* Adding IQ3_M - IQ3_XS mix with mostly Q4_K

* iiq3_xs: a 3.4375 bpw variant

* iq3_xs: make CUDA work for new version

* iq3_xs: make scalar and AVX2 work for new version

* iq3_s: make ARM_NEON work with new version

* iq3_xs: make new version work on metal

Performance is very similar to Q3_K_S

* iq3_xs: tiny Metal speed improvement

* iq3_xs: tiny Metal speed improvement

* Fix stupid warning

* Q3_K_XS now uses a mix of IQ3_XS and IQ3_XXS

* iq3_xs: rename to iq3_s

* iq3_s: make tests pass

* Move Q3_K_XS mix to 3.25 bpw

* Attempt to fix failing tests

* Another attempt to fix the Windows builds

* Attempt to fix ROCm

* ROCm again

* iq3_s: partial fix for QK_K = 64

* iq3_s: make it work on metal for QK_K = 64

Pleasant surprise: the coding was super-block size independent,
so all it took was to delete some QK_K == 256 guards.

* Will this fix ROCm?

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024
@mofosyne added labels on May 25, 2024: Review Complexity : High (Generally require in-depth knowledge of LLMs or GPUs), Tensor Encoding Scheme (https://github.com/ggerganov/llama.cpp/wiki/Tensor-Encoding-Schemes)