When '--gather_all_token_logits' is enabled, the first token appears to be abnormal #639

Closed
StarrickLiu opened this issue Dec 12, 2023 · 3 comments
Assignees
Labels
triaged Issue has been triaged by maintainers

Comments

StarrickLiu commented Dec 12, 2023

Problem Description:

Building the engine with the '--gather_all_token_logits' option appears to cause a problem.

When the engine is built with '--gather_all_token_logits', the first token is very likely to come out garbled.

When it is built without '--gather_all_token_logits' and all other parameters are kept the same, the first token looks normal.

llama1 7B:

With '--gather_all_token_logits'

Build Command:

python3 build.py --model_dir=/path/to/llama-7b-hf/ \
                  --dtype bfloat16 \
                  --use_gpt_attention_plugin bfloat16 \
                  --use_gemm_plugin bfloat16 \
                  --output_dir /path/to/llama-7b-trt/0.6.1-cf-pe1-gatl-mb-bf16-8_gpu-8k-2k-bs4 \
                  --world_size 8 \
                  --tp_size 8 \
                  --max_input_len 8192 \
                  --max_output_len 2048 \
                  --max_batch_size 4 \
                  --remove_input_padding \
                  --enable_context_fmha \
                  --parallel_build \
                  --multi_block_mode \
                  --gather_all_token_logits \
                  --use_parallel_embedding \
                  --embedding_sharding_dim 1

image

Without '--gather_all_token_logits'

Build Command:

python3 build.py --model_dir=/path/to/llama-7b-hf/ \
                  --dtype bfloat16 \
                  --use_gpt_attention_plugin bfloat16 \
                  --use_gemm_plugin bfloat16 \
                  --output_dir /path/to/llama-7b-trt/0.6.1-cf-pe1-mb-bf16-8_gpu-8k-2k-bs4 \
                  --world_size 8 \
                  --tp_size 8 \
                  --max_input_len 8192 \
                  --max_output_len 2048 \
                  --max_batch_size 4 \
                  --remove_input_padding \
                  --enable_context_fmha \
                  --parallel_build \
                  --multi_block_mode \
                  --use_parallel_embedding \
                  --embedding_sharding_dim 1

image

Note that I have reproduced this issue in version 0.6.1 as well as in earlier versions.
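
One way to confirm that two build commands differ only in the intended flag is to parse them and compare the options. The following is a minimal sketch, not part of TensorRT-LLM: `parse_flags` and `diff_flags` are hypothetical helper names, and the two example commands are shortened stand-ins with made-up output paths rather than the full commands above.

```python
import shlex


def parse_flags(command: str) -> dict:
    """Map each --flag to its value, or to True for a bare switch."""
    # Join shell line continuations before tokenizing.
    tokens = shlex.split(command.replace("\\\n", " "))
    flags = {}
    i = 0
    while i < len(tokens):
        tok = tokens[i]
        if tok.startswith("--"):
            if "=" in tok:
                # Form: --name=value
                name, value = tok.split("=", 1)
                flags[name] = value
            elif i + 1 < len(tokens) and not tokens[i + 1].startswith("--"):
                # Form: --name value
                flags[tok] = tokens[i + 1]
                i += 1
            else:
                # Bare switch such as --gather_all_token_logits
                flags[tok] = True
        i += 1
    return flags


def diff_flags(cmd_a: str, cmd_b: str) -> dict:
    """Return {flag: (value_in_a, value_in_b)} for every flag that differs."""
    a, b = parse_flags(cmd_a), parse_flags(cmd_b)
    return {
        name: (a.get(name), b.get(name))
        for name in sorted(set(a) | set(b))
        if a.get(name) != b.get(name)
    }


# Shortened, hypothetical commands for illustration only.
with_flag = (
    "python3 build.py --dtype bfloat16 --tp_size 8 "
    "--gather_all_token_logits --output_dir /tmp/engine-a"
)
without_flag = (
    "python3 build.py --dtype bfloat16 --tp_size 8 "
    "--output_dir /tmp/engine-b"
)

for name, (left, right) in diff_flags(with_flag, without_flag).items():
    print(f"{name}: {left!r} vs {right!r}")
```

A flag that is missing from one command shows up as `None` on that side, so an unintended difference is easy to spot.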

@byshiue byshiue self-assigned this Dec 15, 2023
@byshiue byshiue added the triaged Issue has been triaged by maintainers label Dec 15, 2023
Collaborator

byshiue commented Dec 15, 2023

This appears to be a bug in 0.6.1 and should be fixed on the latest main branch. Please give it a try.

@StarrickLiu
Author

image

After testing with the new version, everything works fine. Thank you.

salaki commented Feb 20, 2024

@StarrickLiu, I'm wondering whether you successfully got the logits. Are the logits only for the output tokens, or for every token in the vocabulary?
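
The two readings of that question can be made concrete with a toy example. The sketch below uses made-up numbers and a hypothetical four-entry vocabulary, and it does not call TensorRT-LLM. It only shows the difference between a full logits row per position (one score per vocabulary entry) and the single score of the token that was actually chosen. Which of these the engine returns is not answered in this thread.

```python
def greedy_tokens(logits):
    """Return the highest-scoring token id for each position."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def emitted_token_logits(logits, token_ids):
    """Pick, for each position, only the logit of the chosen token."""
    return [row[token_id] for row, token_id in zip(logits, token_ids)]


# Toy data: 3 positions, hypothetical vocabulary of 4 entries.
# "Logits for the whole vocabulary" means one full row per position.
toy_logits = [
    [0.1, 2.5, -1.0, 0.3],
    [1.7, 0.2, 0.0, -0.4],
    [-0.5, 0.0, 0.9, 3.1],
]

token_ids = greedy_tokens(toy_logits)
# "Logits for the output tokens" means one number per position.
per_token = emitted_token_logits(toy_logits, token_ids)

print("shape of full logits:", (len(toy_logits), len(toy_logits[0])))
print("chosen token ids:", token_ids)
print("logit of each chosen token:", per_token)
```

If the full rows are available, the per-token values can always be recovered from them as shown, but not the other way around.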
