
Handling preprocessing with token-classification pipeline #15785

Closed
tkon3 opened this issue Feb 23, 2022 · 7 comments

tkon3 commented Feb 23, 2022

Hi,

Token classification tasks (e.g. NER) usually rely on pre-split inputs (a list of words). The tokenizer is then called with the is_split_into_words=True argument during training.
However, the token-classification pipeline does not handle this preprocessing and tokenizes the raw input instead.
This can lead to different predictions when custom preprocessing was used during training, because the resulting tokens differ.

from transformers import AutoTokenizer
from tokenizers import pre_tokenizers
from tokenizers.pre_tokenizers import Punctuation, WhitespaceSplit

MODEL_NAME = "camembert-base"

pre_tokenizer = pre_tokenizers.Sequence([WhitespaceSplit(), Punctuation()])
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

text = "Ceci est un exemple n'est-ce pas ?"
splitted_text = [w[0] for w in pre_tokenizer.pre_tokenize_str(text)]

print(tokenizer.convert_ids_to_tokens(tokenizer(text).input_ids))
# ['<s>', '▁Ceci', '▁est', '▁un', '▁exemple', '▁n', "'", 'est', '-', 'ce', '▁pas', '▁?', '</s>']

print(tokenizer.convert_ids_to_tokens(tokenizer(splitted_text, is_split_into_words=True).input_ids))
# ['<s>', '▁Ceci', '▁est', '▁un', '▁exemple', '▁n', '▁', "'", '▁est', '▁-', '▁ce', '▁pas', '▁?', '</s>']

How can the pipeline make reliable predictions unless it knows exactly how the input was preprocessed?
The pre-tokenizer and the tokenizer's offset_mapping would have to be merged somehow.

LysandreJik (Member) commented:

cc @SaulLu @Narsil

Narsil (Contributor) commented Feb 23, 2022

Hi @tkon3,

Normally this should be taken care of directly by the pre_tokenizer.
It seems camembert-base uses WhitespaceSplit() and Metaspace, and you merely want to add Punctuation.

Doing something like

from transformers import AutoTokenizer
from tokenizers import pre_tokenizers
from tokenizers.pre_tokenizers import Punctuation, WhitespaceSplit, Metaspace

MODEL_NAME = "camembert-base"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Replace the backend pre-tokenizer so the custom pre-tokenization is stored with the tokenizer itself
tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Sequence([WhitespaceSplit(), Punctuation(), Metaspace()])

text = "Ceci est un exemple n'est-ce pas ?"

print(tokenizer.convert_ids_to_tokens(tokenizer(text).input_ids))
print(tokenizer(text).tokens())
# ['<s>', '▁Ceci', '▁est', '▁un', '▁exemple', '▁n', '▁', "'", '▁est', '▁-', '▁ce', '▁pas', '▁?', '</s>']
# ['<s>', '▁Ceci', '▁est', '▁un', '▁exemple', '▁n', '▁', "'", '▁est', '▁-', '▁ce', '▁pas', '▁?', '</s>']

tokenizer.save_pretrained("camembert-punctuation")

That should be enough so that no one forgets how the input was preprocessed at training time, and so that pipelines work correctly with offsets. Does that work?

There's no other real way to handle custom pre-tokenization in pipelines; it needs to be included directly in the tokenizer itself.
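
For reference, here is a minimal sketch of wiring the saved tokenizer back into a pipeline (the model path below is only a placeholder for a model fine-tuned with the same preprocessing):

from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

# Reload the tokenizer that now carries the custom pre-tokenizer (saved above).
tokenizer = AutoTokenizer.from_pretrained("camembert-punctuation")
# Placeholder: a token-classification model fine-tuned with this exact preprocessing.
model = AutoModelForTokenClassification.from_pretrained("path/to/finetuned-model")

ner = pipeline("token-classification", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
print(ner("Ceci est un exemple n'est-ce pas ?"))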

tkon3 (Author) commented Feb 23, 2022

Hi @Narsil,

Thank you for the quick answer; your idea works.
However, the pipeline only partially returns the expected result (with aggregation_strategy="simple"):

  • the "start" and "end" indexes are fine
  • the returned "word" is wrong (the pipeline adds extra spaces)

To get the correct word, I have to extract it manually from the input sentence. Not a big deal in my case, but it could be a problem for someone else:

sentence = "Ceci est un exemple n'est-ce pas ?"

pipe = TokenClassificationPipeline(model, tokenizer, aggregation_strategy="simple")
entities = pipe(sentence)
true_entities = [sentence[e["start"]: e["end"]] for e in entities]

Narsil (Contributor) commented Feb 23, 2022

Well, "words" don't really exist for this particular brand of tokenizer (Unigram with WhitespaceSplit). If I am not mistaken.

Extracting from the original string like you did is the only way to do it consistently.
The "word" is the decoded version of the token(s) and it's what the model sees. (so there's an extra space since there's no difference between " this" and "this" for this tokenizer.)

Do you mind creating a new issue for this, since it seems like a new problem?

I do think what you're implying is correct: we should have entity["word"] == sentence[entity["start"]:entity["end"]].
Unfortunately that would be a breaking change, so extra caution would have to be taken should we make such a change (or we could add a new key, for instance).

tkon3 (Author) commented Feb 24, 2022

You are right, this is more general than I expected. The same thing happens with uncased tokenizers:

from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

MODEL_NAME = "elastic/distilbert-base-uncased-finetuned-conll03-english"
text = "My name is Clara and I live in Berkeley, California."

model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

pipe = pipeline("token-classification", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
entities = pipe(text)

print([(text[entity["start"]:entity["end"]], entity["word"]) for entity in entities])
# [('Clara', 'clara'), ('Berkeley', 'berkeley'), ('California', 'california')]

Adding a new key is probably the easiest way.
I guess this could be implemented inside the postprocess function, gated by a postprocess parameter:

grouped_entities = self.aggregate(pre_entities, aggregation_strategy)
if some_postprocess_param:
    grouped_entities = [{"raw_word": sentence[e["start"]:e["end"]], **e} for e in grouped_entities]
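
In the meantime, the same result can be obtained outside the pipeline with a small wrapper (just a sketch; add_raw_words is a hypothetical helper, not part of transformers):

def add_raw_words(sentence, entities):
    # Attach the exact surface form sliced from the original sentence,
    # leaving the existing keys (including "word") untouched.
    return [{**e, "raw_word": sentence[e["start"]:e["end"]]} for e in entities]

entities_with_raw = add_raw_words(text, pipe(text))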

Want me to open a feature request?

Narsil (Contributor) commented Feb 24, 2022

I think we should move this to another issue/feature request, yes (the discussion has diverged since the first message).

Can you ping me and Lysandre on it, and refer back to this original issue for people wanting more context?

github-actions (bot) commented:

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
