Add unit tests for generation models #3018
Conversation
unk_token=unk_token,
pad_token=pad_token,
mask_token=mask_token,
**kwargs)
This block (170--813) doesn't need to be added; a hook handles it.
Done.
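For context, a minimal sketch of the kind of hook being referred to, assuming a base class that records each subclass's __init__ kwargs automatically; all names here are illustrative, not PaddleNLP's actual implementation:

import functools

def record_init_args(init):
    # Capture keyword arguments (e.g. special tokens) on the instance,
    # so subclasses need not forward them to the base class by hand.
    @functools.wraps(init)
    def wrapper(self, *args, **kwargs):
        self.init_kwargs = dict(kwargs)  # hypothetical attribute name
        return init(self, *args, **kwargs)
    return wrapper

class PretrainedTokenizerBase:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Wrap each subclass __init__ so its kwargs are recorded.
        cls.__init__ = record_init_args(cls.__init__)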
bos_token=bos_token,
eos_token=eos_token,
eol_token=eol_token,
**kwargs)
Same as above; this doesn't need to be added either.
Done.
import json
import jieba
import shutil
import sentencepiece as spm
Are jieba and sentencepiece actually needed here?
Done.
generated_summaries = tok.batch_decode(
    hypotheses_batch.tolist(),
    clean_up_tokenization_spaces=True,
    skip_special_tokens=True)
Should the generated results be checked here? The HF version has an assertion for this.
The assert was removed for now because bart-large generates garbled, meaningless sentences. Once all the unit tests are finished, the model weights themselves need to be verified.
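For reference, the HF-style check mentioned above pins the expected decoded strings and asserts equality; a minimal sketch, with placeholder expectations (real values would require verified model weights):

# Placeholder expectations, not real bart-large outputs.
EXPECTED_SUMMARIES = ["<expected summary 1>", "<expected summary 2>"]
self.assertListEqual(EXPECTED_SUMMARIES, generated_summaries)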
decoder_start_token_id = (
    decoder_start_token_id if decoder_start_token_id is not None else
    getattr(self, pretrained_model_name).config.get(
Is pretrained_model_name obtained from the model? Where is it set? It doesn't seem appropriate as an attribute of the model.
Done.
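One way to resolve the value without a pretrained_model_name attribute, as a sketch (the dict-style config access is assumed from the snippet above; names are illustrative):

def resolve_decoder_start_token_id(model, decoder_start_token_id=None):
    # Prefer the explicit argument, then fall back to the model's own
    # config rather than a pretrained_model_name attribute.
    if decoder_start_token_id is not None:
        return decoder_start_token_id
    return model.config.get("decoder_start_token_id", None)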
input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams,
                                              axis=0)

kwargs["use_cache"] = True
Don't the _sample_generate tests above need this set as well?
This is set to align with HF: some unit tests run with use_cache=True and others with use_cache=False, so both configurations are covered.
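A sketch of how a test can exercise both settings in one place (the helper name and generate()'s (ids, scores) return value are assumptions):

def check_generate_with_and_without_cache(self, model, input_ids, **kwargs):
    # Run generation under both cache settings and check the outputs
    # agree, so neither configuration regresses silently.
    outputs = {}
    for use_cache in (True, False):
        kwargs["use_cache"] = use_cache
        output_ids, _ = model.generate(input_ids, **kwargs)
        outputs[use_cache] = output_ids.tolist()
    self.assertEqual(outputs[True], outputs[False])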
PR types
Others
PR changes
Others
Description
Add unit tests for generation models.
Done:
Unit tests:
Updated test_tokenizer_common.py to support more from_pretrained_filter
Functions:
TODO: