Paper: TaCL: Improve BERT Pre-training with Token-aware Contrastive Learning
Authors: Yixuan Su, Fangyu Liu, Zaiqiao Meng, Lei Shu, Ehsan Shareghi, and Nigel Collier
Main Paper Repo: [https://github.com/yxuansu/TaCL]
如果我们提供的资源对你有帮助,请考虑引用我们的文章。
@article{DBLP:journals/corr/abs-2111-04198,
author = {Yixuan Su and
Fangyu Liu and
Zaiqiao Meng and
Lei Shu and
Ehsan Shareghi and
Nigel Collier},
title = {TaCL: Improving {BERT} Pre-training with Token-aware Contrastive Learning},
journal = {CoRR},
volume = {abs/2111.04198},
year = {2021},
url = {https://arxiv.org/abs/2111.04198},
eprinttype = {arXiv},
eprint = {2111.04198},
timestamp = {Wed, 10 Nov 2021 16:07:30 +0100},
biburl = {https://dblp.org/rec/journals/corr/abs-2111-04198.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
python version == 3.8
pip install -r requirements.txt
Chinese TaCL BERT + CRF
Model Name | Model Address |
---|---|
Chinese (cambridgeltl/tacl-bert-base-chinese) | link |
import torch
# initialize model
from transformers import AutoModel, AutoTokenizer
model_name = 'cambridgeltl/tacl-bert-base-chinese'
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# create input ids
text = "中文TaCL-BERT模型真强大!"
text = "[CLS] " + text + " [SEP]"
tokenized_token_list = tokenizer.tokenize(text)
input_ids = torch.LongTensor(tokenizer.convert_tokens_to_ids(tokenized_token_list)).view(1, -1)
# compute hidden states
representation = model(input_ids).last_hidden_state # [1, seqlen, embed_dim]
(1). 命名实体识别: (1) MSRA (2) OntoNotes (3) Resume (4) Weibo
(2). 中文分词: (1) PKU (2) CityU (3) AS
chmod +x ./download_benchmark_data.sh
./download_benchmark_data.sh
chmod +x ./download_checkpoints.sh
./download_checkpoints.sh
cd ./sh_folder/inference/
chmod +x ./inference_{}.sh
./inference_{}.sh
对于不同的数据集{}的取值为['msra', 'ontonotes', 'weibo', 'resume', 'pku', 'cityu', 'as'],相关参数的含义为:
--saved_ckpt_path: 训练好的模型位置
--train_path: 训练集数据路径
--dev_path: 验证集数据路径
--test_path: 测试集数据路径
--label_path: 数据标签路径
--batch_size: inference时的batch size
使用提供的模型进行inference后,可以得到如下结果。
Dataset | Precision | Recall | F1 |
---|---|---|---|
MSRA | 95.41 | 95.47 | 95.44 |
OntoNotes | 81.88 | 82.98 | 82.42 |
Resume | 96.48 | 96.42 | 96.45 |
68.40 | 70.73 | 69.54 | |
PKU | 97.04 | 96.46 | 96.75 |
CityU | 98.16 | 98.19 | 98.18 |
AS | 96.51 | 96.99 | 96.75 |
cd ./sh_folder/train/
chmod +x ./{}.sh
./{}.sh
对于不同的数据集{}的取值为['msra', 'ontonotes', 'weibo', 'resume', 'pku', 'cityu', 'as'],相关参数的含义为:
--model_name: 中文TaCL BERT的模型名称(cambridgeltl/tacl-bert-base-chinese)
--train_path: 训练集数据路径
--dev_path: 验证集数据路径
--test_path: 测试集数据路径
--label_path: 数据标签路径
--learning_rate: 学习率
--number_of_gpu: 可使用的GPU数量
--number_of_runs: 重复试验次数
--save_path_prefix: 模型存储路径
[Note 1] 我们没有对模型进行任何和学习率调参,2e-5只是默认值。通过调整学习率也许可以获得更好的结果。
[Note 2] 实际的batch size等于gradient_accumulation_steps x number_of_gpu x batch_size_per_gpu。我们推荐将其设置为128。
Inference: 使用在./sh_folder/inference/路径中的sh进行inference。将--saved_ckpt_path设置为自己重新训练好的模型的路径。
以下我们使用MSRA数据集作为范例。
(使用以下代码前,请先下载我们提供的训练好的模型以及数据集。具体的指导请见以上章节)
# 载入数据
from dataclass import Data
from transformers import AutoTokenizer
model_name = 'cambridgeltl/tacl-bert-base-chinese'
tokenizer = AutoTokenizer.from_pretrained(model_name)
data_path = r'./benchmark_data/NER/MSRANER/MSRA.test.char.txt'
label_path = r'./benchmark_data/NER/MSRANER/MSRA_NER_Label.txt'
max_len = 128
data = Data(tokenizer, data_path, data_path, data_path, label_path, max_len)
# 载入模型
import torch
from model import NERModel
model = NERModel(model_name, data.num_class)
ckpt_path = r'./pretrained_ckpt/msra/msra_ckpt'
model_ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
model_parameters = model_ckpt['model']
model.load_state_dict(model_parameters)
model.eval()
# 提供输入
text = "中 共 中 央 致 中 国 致 公 党 十 一 大 的 贺 词"
text = "[CLS] " + text + " [SEP]"
tokens = tokenizer.tokenize(text)
# process token input
input_id = tokenizer.convert_tokens_to_ids(tokens)
input_id = torch.LongTensor(input_id).view(1, -1)
attn_mask = ~input_id.eq(data.pad_idx)
tgt_mask = [1.0] * len(tokens)
tgt_mask = torch.tensor(tgt_mask, dtype=torch.uint8).contiguous().view(1,-1)
# 使用模型进行解码
x = model.decode(input_id, attn_mask, tgt_mask)[0][1:-1] # remove [CLS] and [SEP] tokens.
res = ' '.join([data.id2label_dict[tag] for tag in x])
print (res)
# 模型输出结果:
# B-NT M-NT M-NT E-NT O B-NT M-NT M-NT M-NT M-NT M-NT M-NT E-NT O O O
# 标准预测结果:
# B-NT M-NT M-NT E-NT O B-NT M-NT M-NT M-NT M-NT M-NT M-NT E-NT O O O
如果有任何的问题,以下是我的联系方式([email protected])。