
[Core tokenization] add_dummy_prefix_space option to help with latest issues #28010

Merged
40 commits merged into main from add-prefix-space on Feb 20, 2024

Conversation

@ArthurZucker (Collaborator) commented Dec 13, 2023

What does this PR do?

Allows users to control whether a prefix space is added when calling tokenizer.tokenize. Let's also update the fast tokenizer!

fixes #28622
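
A minimal sketch of the intended usage (the checkpoint name is illustrative, and the exact token strings depend on the model's vocabulary):

from transformers import AutoTokenizer

# Default behavior: a dummy prefix space (SPIECE_UNDERLINE) is prepended,
# so "Hello" is tokenized as if it were " Hello".
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
print(tok.tokenize("Hello"))  # e.g. ['▁Hello']

# With add_prefix_space=False, no dummy prefix space is added.
tok = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-hf", add_prefix_space=False
)
print(tok.tokenize("Hello"))  # e.g. ['H', 'ello']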

@huggingface deleted a comment from github-actions bot Jan 15, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker mentioned this pull request Jan 16, 2024
@ArthurZucker marked this pull request as ready for review January 18, 2024 11:05
@gabegrand

Just wanted to say this would be hugely helpful for us over at https://github.com/probcomp/hfppl !

@haileyschoelkopf (Contributor)

Likewise, the ability to not include an extra SPIECE_UNDERLINE / Llama token 29871 when encoding a word with a space in front (" <word>") would be huge for https://github.com/EleutherAI/lm-evaluation-harness !
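
For concreteness, a hedged illustration of the behavior described above (token IDs refer to the Llama-2 vocabulary, in which 29871 is the bare '▁' piece; outputs are indicative):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Encoding " word": the dummy prefix space plus the literal leading space
# produces an extra bare SPIECE_UNDERLINE token.
ids = tok.encode(" word", add_special_tokens=False)
print(ids)  # typically starts with 29871, the bare '▁' piece

# With add_prefix_space=False, the extra token is no longer inserted and
# " word" maps to a single '▁word' piece.
tok = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-hf", add_prefix_space=False
)
print(tok.encode(" word", add_special_tokens=False))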

@LysandreJik self-requested a review February 20, 2024 10:51
@ArthurZucker (Collaborator, Author) left a comment


I'll let @Lysandre decide, but instead of following what we do with Bloom, I'd rather we convert from slow. It's a bit slower, but at least we are sure we use the correct logic.
This is done with a warning.

4 review threads on src/transformers/models/llama/tokenization_llama_fast.py (outdated, resolved)
@ArthurZucker (Collaborator, Author)

Failing test is unrelated 😉

@LysandreJik (Member) left a comment


Ok, this looks good to me

Comment on lines +127 to +131

if add_prefix_space is not None:
    logger.warning_once(
        "You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers"
    )
    kwargs["from_slow"] = True
@LysandreJik (Member)


How long does it take to convert the tokenizer from slow? If it's quick, we can move it to info.

@ArthurZucker (Collaborator, Author)

Around 10 seconds I believe!
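
For context, a minimal sketch of the code path being discussed (the checkpoint name is illustrative): passing add_prefix_space sets kwargs["from_slow"] = True, so the fast tokenizer is rebuilt from the slow sentencepiece tokenizer instead of loading the serialized tokenizer.json, which is where the one-off cost comes from.

from transformers import LlamaTokenizerFast

# Emits the warning quoted above and forces a slow->fast conversion, so the
# normalizer/pre-tokenizer are rebuilt with the requested setting.
tok = LlamaTokenizerFast.from_pretrained(
    "meta-llama/Llama-2-7b-hf", add_prefix_space=False
)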

@ArthurZucker merged commit 15cfe38 into main Feb 20, 2024
19 of 21 checks passed
@ArthurZucker deleted the add-prefix-space branch February 20, 2024 11:50
itazap pushed a commit that referenced this pull request May 14, 2024
[Core tokenization] add_dummy_prefix_space option to help with latest issues (#28010)

* add add_dummy_prefix_space option to slow

* checking kwargs might be better. Should be there for all spm tokenizer IMO

* nits

* fix copies

* more copied

* nits

* add prefix space

* nit

* nits

* Update src/transformers/convert_slow_tokenizer.py

* fix init

* revert wrong styling

* fix

* nits

* style

* updates

* make sure we use slow tokenizer for conversion instead of looking for the decoder

* support llama as well

* update llama tokenizer fast

* nits

* nits nits nits

* update the doc

* update

* update to fix tests

* skip unrelated failing test

* Update src/transformers/convert_slow_tokenizer.py

* add proper testing

* test decode as well

* more testing

* format

* fix llama test

* Apply suggestions from code review
Successfully merging this pull request may close these issues:

Can LlamaTokenizerFast support the argument add_prefix_space = False (#28622)