
IQ4_NL: 4-bit non-linear quants with blocks of 32 #5590

Merged
ikawrakow merged 6 commits from ik/iq4_nl_no_superblock into master on Feb 21, 2024

Conversation

ikawrakow (Contributor) commented Feb 19, 2024

TL;DR

The main purpose of this PR is to provide a 4-bit quantization type that can be used when k- and i-quants that use blocks of 256 are not available (because the number of columns in some tensors is not a multiple of 256).

In short

  • IQ4_NL uses blocks of 32 weights with an fp16 block scale, exactly like Q4_0, so models quantized with IQ4_NL are the exact same size as Q4_0 and Q4_K_S (see the layout sketch after this list).
  • IQ4_NL uses a non-linear mapping to convert quants to weights (more on this below).
  • Quantization quality as measured by perplexity is much better than Q4_0 and almost on par with Q4_K_S.
  • Inference performance is almost the same as Q4_0, except on Metal, where it is 8% (prompt processing) or 20% (token generation) slower than Q4_0.
  • If implemented row-wise, with the fp16 block scales replaced by int8_t block scales (plus one floating-point scale per row, which adds a negligible number of bits), this would be a 4.25 bpw quantization with the same quantization error as the 4.5 bpw IQ4_NL added by this PR.
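
For concreteness, here is a minimal sketch of the block layout this implies, with the bpw arithmetic worked out in the comments. The struct is illustrative (the fp16 type is a stand-in); see the ggml sources for the actual definition.

```c
#include <stdint.h>

typedef uint16_t ggml_fp16_t;   // stand-in for the fp16 type defined in ggml.h

#define QK4_NL 32               // weights per block

typedef struct {
    ggml_fp16_t d;              // one fp16 scale per block of 32, as in Q4_0
    uint8_t     qs[QK4_NL/2];   // 32 x 4-bit indices into the non-linear value table
} block_iq4_nl;

// bits per block: 32*4 (indices) + 16 (fp16 scale) = 144  ->  144/32 = 4.5  bpw
// row-wise idea:  32*4 (indices) +  8 (int8 scale) = 136  ->  136/32 = 4.25 bpw
//                 (plus one floating-point scale per row, which is negligible)
```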

PPL comparisons

The following tables show PPL comparisons between Q4_0, Q4_K_S, and IQ4_NL. We start with the case of not using an importance matrix (I find this an important use case: at 4-bit quantization one should ideally not have to worry too much about having a suitable imatrix for the model being quantized).

Table 1 PPL comparison without imatrix for context of 512 tokens

| Model | PPL (fp16) | PPL (Q4_K_S) | PPL (Q4_0) | PPL (IQ4_NL) |
|---|---|---|---|---|
| LLaMA-v1-7B | 5.9066 | 6.0078 | 6.1161 | 6.0173 |
| LLaMA-v2-7B | 5.7891 | 5.8855 | 5.9633 | 5.8842 |
| Mistral-7B | 5.6924 | 5.7764 | 5.8189 | 5.7743 |
| LLaMA-v1-13B | 5.2551 | 5.3162 | 5.3639 | 5.3229 |
| LLaMA-v2-13B | 5.1001 | 5.1852 | 5.1993 | 5.1925 |

The next table is with an imatrix created from wiki.train.raw

Table 2 PPL comparison with imatrix for context of 512 tokens

| Model | PPL (fp16) | PPL (Q4_K_S) | PPL (Q4_0) | PPL (IQ4_NL) |
|---|---|---|---|---|
| LLaMA-v1-7B | 5.9066 | 5.9734 | 6.0245 | 5.9976 |
| LLaMA-v2-7B | 5.7891 | 5.8675 | 5.9158 | 5.8737 |
| Mistral-7B | 5.6924 | 5.7374 | 5.7957 | 5.7462 |
| LLaMA-v1-13B | 5.2551 | 5.2955 | 5.3103 | 5.3168 |
| LLaMA-v2-13B | 5.1001 | 5.1674 | 5.1874 | 5.1710 |

Just in case researchers working on quantization happen to see this PR, here are some PPL results for a context of 4096 (LLaMA-v2 and Mistral) or 2048 (LLaMA-v1)

Table 3 PPL comparison with imatrix for context of 4096/2048 tokens

| Model | PPL (fp16) | PPL (Q4_K_S) | PPL (Q4_0) | PPL (IQ4_NL) |
|---|---|---|---|---|
| LLaMA-v1-7B | 5.2351 | 5.2923 | 5.3364 | 5.3061 |
| LLaMA-v2-7B | 4.9352 | 4.9831 | 5.0259 | 4.9894 |
| Mistral-7B | 4.7920 | 4.8291 | 4.8570 | 4.8338 |
| LLaMA-v1-13B | 4.6646 | 4.6995 | 4.7194 | 4.7160 |
| LLaMA-v2-13B | 4.4173 | 4.4526 | 4.4733 | 4.4641 |
| LLaMA-v2-70B | 3.0262 | 3.0616 | | 3.0699 |

To compare with the approaches that are currently claiming to be SOTA, the next table shows the quantization error, defined as QErr = PPL(Quantized)/PPL(fp16) - 1. I took the values for AQLM and QuIP# from the latest QuIP# paper.
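
For example, taking LLaMA-v2-7B at a context of 4096 from Table 3: QErr(IQ4_NL) = 4.9894/4.9352 - 1 ≈ 0.0110, i.e. the 1.10% shown below.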

Table 4 Quantization error comparisons

| Model | QErr (Q4_K_S) | QErr (Q4_0) | QErr (IQ4_NL) | QErr (AQLM) | QErr (QuIP#) |
|---|---|---|---|---|---|
| LLaMA-v2-7B | 0.97% | 1.84% | 1.10% | 1.76% | 1.37% |
| LLaMA-v2-13B | 0.80% | 1.27% | 1.06% | 1.53% | 1.31% |
| LLaMA-v2-70B | 1.17% | | 1.44% | 1.60% | 1.92% |

Performance comparisons

Table 5 shows PP-512 and TG-128 values for a 7B LLaMA on various platforms

  • Metal is on an M2-Max 30-core GPU
  • ARM_NEON is on an M2-Max CPU using the 8 performance cores
  • CUDA is on an RTX-4080
  • AVX2 is on a Ryzen-7950X CPU using 16 (PP-512) or 8 (TG-128) threads.

| model | backend | test | t/s |
|---|---|---|---|
| llama 7B IQ4_NL - 4.5 bpw | Metal | pp 512 | 508.75 ± 0.39 |
| llama 7B Q4_0 | Metal | pp 512 | 547.27 ± 0.66 |
| llama 7B IQ4_NL - 4.5 bpw | Metal | tg 128 | 51.01 ± 0.45 |
| llama 7B Q4_0 | Metal | tg 128 | 61.84 ± 0.09 |
| llama 7B IQ4_NL - 4.5 bpw | ARM_NEON | pp 512 | 94.35 ± 0.67 |
| llama 7B Q4_0 | ARM_NEON | pp 512 | 96.81 ± 0.20 |
| llama 7B IQ4_NL - 4.5 bpw | ARM_NEON | tg 128 | 27.48 ± 0.21 |
| llama 7B Q4_0 | ARM_NEON | tg 128 | 28.09 ± 0.20 |
| llama 7B IQ4_NL - 4.5 bpw | CUDA | pp 512 | 5698.79 ± 14.63 |
| llama 7B Q4_0 | CUDA | pp 512 | 5707.98 ± 18.64 |
| llama 7B IQ4_NL - 4.5 bpw | CUDA | tg 128 | 129.14 ± 0.12 |
| llama 7B Q4_0 | CUDA | tg 128 | 129.31 ± 0.39 |
| llama 7B IQ4_NL - 4.5 bpw | AVX2 | pp 512 | 62.59 ± 0.37 |
| llama 7B Q4_0 | AVX2 | pp 512 | 67.54 ± 0.41 |
| llama 7B IQ4_NL - 4.5 bpw | AVX2 | tg 128 | 14.82 ± 0.01 |
| llama 7B Q4_0 | AVX2 | tg 128 | 14.65 ± 0.06 |

Additional details

It all comes down to this set of 16 magic values:

```c
static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
```

Where do they come from? I had implemented a K-means clustering based quantization in my private repository (similar to what, e.g., SqueezeLLM does), with clustering done per tensor row. Although I was getting similar or even slightly better results than SqueezeLLM, I was not particularly happy with the quantization quality, so I decided to see what happens if I apply block-wise scaling before clustering. It turned out that the cluster means end up being (nearly) independent of the tensor/tensor row. I collected statistics of the cluster means from a few quantized models and saw that the 16 averaged cluster means can be fit with a 3rd-order polynomial that maps quant index to a (scaled) model weight.

Using the polynomial fit directly results in very decent performance on CUDA and acceptable performance on Metal, but is a no-go for CPU SIMD instructions. On the CPU the only thing that gives good performance is a lookup table containing int8_t values. So, after scaling the polynomial fit to the full int8_t range and rounding to the nearest integer, we end up with the above 16 values.
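
To make the mapping concrete, here is a minimal dequantization sketch using that table. The nibble packing (low nibble for the first half of the block, high nibble for the second) is my assumption; ggml-quants.c has the authoritative implementation.

```c
#include <stdint.h>
#include "ggml.h"   // ggml_fp16_t, ggml_fp16_to_fp32()

#define QK4_NL 32

static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};

typedef struct {
    ggml_fp16_t d;            // fp16 block scale
    uint8_t     qs[QK4_NL/2]; // 32 x 4-bit table indices, two per byte
} block_iq4_nl;

// Dequantize one block of 32 weights: each 4-bit index selects a table value,
// which is then multiplied by the per-block scale.
static void dequantize_block_iq4_nl(const block_iq4_nl * x, float * y) {
    const float d = ggml_fp16_to_fp32(x->d);
    for (int j = 0; j < QK4_NL/2; ++j) {
        y[j]            = d * kvalues_iq4nl[x->qs[j] & 0x0f]; // low nibble
        y[j + QK4_NL/2] = d * kvalues_iq4nl[x->qs[j] >> 4];   // high nibble
    }
}
```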

The initial work on this was done before I implemented the importance matrix. Without an imatrix, the non-linear quantization was basically on par with Q4_K in terms of quantization error (see Table 1), while using ~7% fewer bits (if implemented row-wise with blocks of 32). But after the imatrix was added, Q4_K became slightly better again (Tables 2 and 3). The non-linear quantization outperforms Q4_K when using blocks of 16. If implemented using super-blocks of 256 with 6-bit block scales, this would be a 4.4375 bpw SOTA quantization (SOTA in the sense that I'm not aware of a quantization approach that achieves a lower quantization error with less than 5 bpw).
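
As a quick check of that figure (my arithmetic, assuming blocks of 16 inside a super-block of 256 plus one fp16 super-block scale): 256×4 (indices) + 16×6 (block scales) + 16 (fp16 scale) = 1136 bits per super-block, and 1136/256 = 4.4375 bpw.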

Squashed commit log:

* Basics (quantize, dequantize)
* CUDA dequantize and dot product
* Slightly faster CUDA dot product (120 t/s)
* Switch to 6-bit scales
* Scalar dot product
* AVX2 dot product
* ARM_NEON dot product
* Works on metal, but still slow
* Slightly better Metal dot product
* Another small Metal improvement
* Metal dot product is getting there
* Faster CUDA dot product
* Add 1/8 ffn_down layers as Q5_K when no imatrix has been provided
* Report the actual bpw
* Add _xs mix that is 4.05 bpw for non-MoE models
* Remove IQ4_XS for now, slightly adjust kvalues_iq4nl
* AVX2 dot product uses Q8_0 instead of Q8_K
* Add to test-backend-ops
* Minor fix
* Also use Q5_K for attn_output in MoE models
* Fixes after merging latest master
* Switching to blocks of 32
* AVX2 for blocks of 32
* Scalar dot product for blocks of 32
* ARM_NEON dot product for blocks of 32
* Metal kernels for blocks of 32
* Slightly faster Metal kernels
sorasoras commented Feb 19, 2024

Looking good.

Q4KM:
```
llama_model_loader: - type f32: 121 tensors
llama_model_loader: - type q5_0: 20 tensors
llama_model_loader: - type q8_0: 20 tensors
llama_model_loader: - type q4_K: 121 tensors
llama_model_loader: - type q5_K: 40 tensors
llama_model_loader: - type q6_K: 1 tensors
```

IQ4:
```
llama_model_loader: - type f32: 121 tensors
llama_model_loader: - type q6_K: 1 tensors
llama_model_loader: - type iq4_nl: 201 tensors
```

So basically,
```
llama_model_loader: - type q5_0: 20 tensors
llama_model_loader: - type q8_0: 20 tensors
llama_model_loader: - type q4_K: 121 tensors
```
turn into
```
llama_model_loader: - type iq4_nl: 201 tensors
```

The question is: can I have a mix of Q4_K (super-blocks of 256) and IQ4_NL (blocks of 32) to get an even bigger space saving?

P.S. For your reference, going from 9005.66 MiB to 7794.73 MiB is not a small saving.

ikawrakow (Contributor, Author) commented

@sorasoras When k-quants cannot be used, IQ1_S, IQ2_XXS, IQ2_XS, Q2_K and IQ3_XXS will all fall back to IQ4_NL (previously they fell back to Q4_0). Q4_K will become Q5_0 (this is how it currently is; one may also consider using IQ4_NL for Q4_K). The other two are unchanged (Q5_K becomes Q5_1 and Q6_K becomes Q8_0).
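
To make that mapping concrete, a hypothetical sketch in ggml terms (the function name and the exact column check are mine; the real selection lives in llama.cpp's quantization code):

```c
#include <stdint.h>
#include "ggml.h"

// Illustrative only: pick the fallback type for tensors whose row size does not
// fit the 256-wide super-blocks required by k-/i-quants.
static enum ggml_type fallback_quant_type(enum ggml_type wanted, int64_t n_cols) {
    if (n_cols % 256 == 0) {
        return wanted;                                   // k-/i-quants fit, keep the type
    }
    switch (wanted) {
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_IQ3_XXS: return GGML_TYPE_IQ4_NL; // previously GGML_TYPE_Q4_0
        case GGML_TYPE_Q4_K:    return GGML_TYPE_Q5_0;   // could arguably also be IQ4_NL
        case GGML_TYPE_Q5_K:    return GGML_TYPE_Q5_1;
        case GGML_TYPE_Q6_K:    return GGML_TYPE_Q8_0;
        default:                return wanted;
    }
}
```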

sorasoras commented

> @sorasoras When k-quants cannot be used, IQ1_S, IQ2_XXS, IQ2_XS, Q2_K and IQ3_XXS will all fall back to IQ4_NL (previously they fell back to Q4_0). Q4_K will become Q5_0 (this is how it currently is; one may also consider using IQ4_NL for Q4_K). The other two are unchanged (Q5_K becomes Q5_1 and Q6_K becomes Q8_0).

Hmm, could we expect an even denser version of IQ4 in the future? A 4-bit non-linear quant with super-blocks of 256 should reduce the overall size even further compared to blocks of 32.

Artefact2 (Collaborator) commented

KL divergence data over wikitext for bagel-dpo-34b-v0.2

[image: KL divergence plot]

ROCm benchmarks

| model | size | params | backend | ngl | n_batch | test | t/s |
|---|---|---|---|---|---|---|---|
| llama 13B Q4_K - Small | 6.91 GiB | 13.02 B | ROCm | 99 | 128 | pp 1024 | 372.45 ± 0.31 |
| llama 13B Q4_K - Small | 6.91 GiB | 13.02 B | ROCm | 99 | 256 | pp 1024 | 385.95 ± 3.38 |
| llama 13B Q4_K - Small | 6.91 GiB | 13.02 B | ROCm | 99 | 512 | pp 1024 | 397.44 ± 0.24 |
| llama 13B Q4_0 | 6.88 GiB | 13.02 B | ROCm | 99 | 128 | pp 1024 | 447.96 ± 0.44 |
| llama 13B Q4_0 | 6.88 GiB | 13.02 B | ROCm | 99 | 256 | pp 1024 | 469.20 ± 0.12 |
| llama 13B Q4_0 | 6.88 GiB | 13.02 B | ROCm | 99 | 512 | pp 1024 | 481.67 ± 0.14 |
| llama 13B IQ4_NL - 4.5 bpw | 6.86 GiB | 13.02 B | ROCm | 99 | 128 | pp 1024 | 286.10 ± 0.39 |
| llama 13B IQ4_NL - 4.5 bpw | 6.86 GiB | 13.02 B | ROCm | 99 | 256 | pp 1024 | 396.42 ± 1.30 |
| llama 13B IQ4_NL - 4.5 bpw | 6.86 GiB | 13.02 B | ROCm | 99 | 512 | pp 1024 | 398.01 ± 0.19 |

sorasoras commented Feb 20, 2024

7900XTX at 400W TGP

| model | size | params | backend | ngl | test | t/s |
|---|---|---|---|---|---|---|
| qwen 13B IQ4_NL - 4.5 bpw | 7.61 GiB | 14.17 B | ROCm | 99 | pp 512 | 1472.16 ± 11.37 |
| qwen 13B IQ4_NL - 4.5 bpw | 7.61 GiB | 14.17 B | ROCm | 99 | tg 128 | 71.99 ± 0.27 |
| qwen 13B Q4_K - Medium | 8.79 GiB | 14.17 B | ROCm | 99 | pp 512 | 1551.11 ± 19.85 |
| qwen 13B Q4_K - Medium | 8.79 GiB | 14.17 B | ROCm | 99 | tg 128 | 61.20 ± 0.16 |
| qwen 13B IQ3_XXS - 3.0625 bpw | 6.02 GiB | 14.17 B | ROCm | 99 | pp 512 | 1455.30 ± 12.56 |
| qwen 13B IQ3_XXS - 3.0625 bpw | 6.02 GiB | 14.17 B | ROCm | 99 | tg 128 | 75.06 ± 0.82 |
| qwen 13B Q4_1 | 8.39 GiB | 14.17 B | ROCm | 99 | pp 512 | 1499.75 ± 14.64 |
| qwen 13B Q4_1 | 8.39 GiB | 14.17 B | ROCm | 99 | tg 128 | 71.33 ± 0.25 |
| qwen 13B Q2_K - Small | 5.34 GiB | 14.17 B | ROCm | 99 | pp 512 | 1499.88 ± 15.80 |
| qwen 13B Q2_K - Small | 5.34 GiB | 14.17 B | ROCm | 99 | tg 128 | 86.23 ± 0.49 |

It's surprising that IQ4_NL offers performance comparable to Q4_1.

JianbangZ commented Feb 20, 2024

Tested on QWEN1.5-14B: saved about 150 MB of file size with 3K_X_S (3.71 bpw → 3.63 bpw) at roughly the same PPL. Thanks for the contribution.

sorasoras commented

With the change introduced by IQ4_NL, IQ2_XS can beat the mainline Q2_K_S in terms of PPL with the same imatrix:
PPL drops from 5.1274 to 5.0616,
and size from 5227.47 MiB to 4721.82 MiB.

ikawrakow merged commit a14679c into master on Feb 21, 2024 (57 checks passed).
ikawrakow deleted the ik/iq4_nl_no_superblock branch on February 21, 2024 at 09:39.
JianbangZ commented

@ikawrakow due to the recent big changes and new k-quant implementations, could you help compile a table showing the differences among all the quant types?

EinhartStratos commented Feb 27, 2024

Cannot run IQ4_NL with mmq on a 4070 Ti:

```
GGML_ASSERT: C:\llama.cpp\ggml-cuda.cu:9539: false
```

The GGML_ASSERT is in ggml_cuda_op_dequantize_mul_mat_vec.

jordankanter pushed a commit to jordankanter/llama.cpp that referenced this pull request Mar 13, 2024
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024
mishig25 pushed a commit to huggingface/huggingface.js that referenced this pull request Apr 11, 2024
netrunnereve mentioned this pull request Jun 22, 2024