-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add data distillation for UIE #3136
Changes from 3 commits
7ee1828
1504d41
c0f91c9
66cacee
f9dbd5c
77a7974
d83a9a8
c2b6a26
0ff0ca2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# UIE Slim 数据蒸馏 | ||
|
||
在UIE强大的抽取能力背后,是需要同样强大的算力才能支撑起如此大规模模型的训练和预测。很多工业应用场景对性能要求较高,若不能有效压缩则无法实际应用。因此,我们基于数据蒸馏技术构建了UIE Slim数据蒸馏系统。其原理是通过数据作为桥梁,将UIE模型的知识迁移到小模型,以达到精度损失较小的情况下却能达到大幅度预测速度提升的效果。 | ||
|
||
#### UIE数据蒸馏三步 | ||
|
||
- **Step 1**: 使用UIE模型对标注数据进行fine-tune,得到Teacher Model。 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fine-tune -> finetune,或者整体查看一下,微调这块的统一英文术语是啥样子的了 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
|
||
- **Step 2**: 用户提供大规模无标注数据,需与标注数据同源。使用Taskflow UIE对无监督数据进行预测。 | ||
|
||
- **Step 3**: 使用标注数据以及步骤2得到的合成数据训练出Student Model。 | ||
|
||
|
||
## 1.数据蒸馏 | ||
|
||
- 合成封闭域训练及评估数据 | ||
|
||
```shell | ||
python data_generate.py --data_dir ../law_data --output_dir law_distill --task_type relation_extraction --synthetic_ratio 1 | ||
``` | ||
|
||
## 2.封闭域训练 | ||
|
||
```shell | ||
python train.py --task_type relation_extraction --train_path law_distill/train_data.json --dev_path law_distill/dev_data.json --label_maps_path law_distill/label_maps.json | ||
``` | ||
|
||
## 3.Taskflow装载 | ||
|
||
```shell | ||
from paddlenlp import Taskflow | ||
|
||
ie = Taskflow("information_extraction", model="uie-data-distill-gp", task_path="checkpoint/model_best/") | ||
``` | ||
|
||
## 实验效果 | ||
|
||
5-shot | ||
|
||
5-shot + UIE数据蒸馏 | ||
|
||
full-shot | ||
|
||
# References | ||
|
||
- **[GlobalPointer](https://kexue.fm/search/globalpointer/)** | ||
|
||
- **[GPLinker](https://kexue.fm/archives/8888)** | ||
|
||
- **[JunnYu/GPLinker_pytorch](https://github.com/JunnYu/GPLinker_pytorch)** |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import numpy as np | ||
import paddle | ||
import paddle.nn as nn | ||
|
||
|
||
class Criterion(nn.Layer): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这块的文档说明需要是英文的 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
'''Criterion for TPLinkerPlus''' | ||
|
||
def __init__(self, mask_zero=True): | ||
self.mask_zero = mask_zero | ||
|
||
def _sparse_multilabel_categorical_crossentropy(self, | ||
y_true, | ||
y_pred, | ||
mask_zero=False): | ||
"""稀疏版多标签分类的交叉熵 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done, thx |
||
说明: | ||
1. y_true.shape=[..., num_positive], | ||
y_pred.shape=[..., num_classes]; | ||
2. 请保证y_pred的值域是全体实数,换言之一般情况下 | ||
y_pred不用加激活函数,尤其是不能加sigmoid或者 | ||
softmax; | ||
3. 预测阶段则输出y_pred大于0的类; | ||
4. 详情请看:https://kexue.fm/archives/7359 。 | ||
""" | ||
paddle.disable_static() | ||
zeros = paddle.zeros_like(y_pred[..., :1]) | ||
y_pred = paddle.concat([y_pred, zeros], axis=-1) | ||
if mask_zero: | ||
infs = zeros + 1e12 | ||
y_pred = paddle.concat([infs, y_pred[..., 1:]], axis=-1) | ||
y_pos_2 = paddle.take_along_axis(y_pred, y_true, axis=-1) | ||
y_pos_1 = paddle.concat([y_pos_2, zeros], axis=-1) | ||
if mask_zero: | ||
y_pred = paddle.concat([-infs, y_pred[..., 1:]], axis=-1) | ||
y_pos_2 = paddle.take_along_axis(y_pred, y_true, axis=-1) | ||
|
||
pos_loss = (-y_pos_1).exp().sum(axis=-1).log() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. paddle看起来是有logsumexp这个API的,可以尝试提交一下 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里之前profile使用logsumexp对四维Tensor输入的计算特别慢,这块我整理下最小复现代码反馈给框架同学 |
||
all_loss = y_pred.exp().sum(axis=-1).log() | ||
aux_loss = y_pos_2.exp().sum(axis=-1).log() - all_loss | ||
aux_loss = paddle.clip(1 - paddle.exp(aux_loss), min=1e-10, max=1) | ||
neg_loss = all_loss + paddle.log(aux_loss) | ||
return pos_loss + neg_loss | ||
|
||
def __call__(self, y_pred, y_true): | ||
shape = y_pred.shape | ||
y_true = y_true[..., 0] * shape[2] + y_true[..., 1] | ||
# bs, nclass, seqlen * seqlen | ||
y_pred = paddle.reshape(y_pred, | ||
shape=[shape[0], -1, | ||
np.prod(shape[2:])]) | ||
|
||
loss = self._sparse_multilabel_categorical_crossentropy( | ||
y_true, y_pred, self.mask_zero) | ||
return loss.sum(axis=1).mean() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Dict, List, Optional, Union | ||
from dataclasses import dataclass | ||
|
||
import paddle | ||
from paddlenlp.transformers.tokenizer_utils_base import PretrainedTokenizerBase, PaddingStrategy | ||
|
||
ignore_list = ["offset_mapping", "text"] | ||
|
||
|
||
@dataclass | ||
class DataCollator: | ||
tokenizer: PretrainedTokenizerBase | ||
padding: Union[bool, str, PaddingStrategy] = True | ||
max_length: Optional[int] = None | ||
label_maps: Optional[dict] = None | ||
task_type: Optional[str] = None | ||
|
||
def __call__( | ||
self, features: List[Dict[str, Union[List[int], paddle.Tensor]]] | ||
) -> Dict[str, paddle.Tensor]: | ||
labels = ([feature["labels"] for feature in features] | ||
if "labels" in features[0].keys() else None) | ||
new_features = [{ | ||
k: v | ||
for k, v in f.items() if k not in ["labels"] + ignore_list | ||
} for f in features] | ||
|
||
batch = self.tokenizer.pad( | ||
new_features, | ||
padding=self.padding, | ||
) | ||
|
||
batch = [paddle.to_tensor(batch[k]) for k in batch.keys()] | ||
|
||
if labels is None: # for test | ||
if "offset_mapping" in features[0].keys(): | ||
batch.append( | ||
[feature["offset_mapping"] for feature in features]) | ||
if "text" in features[0].keys(): | ||
batch.append([feature["text"] for feature in features]) | ||
return batch | ||
|
||
bs = batch[0].shape[0] | ||
if self.task_type == "entity_extraction": | ||
max_ent_num = max([len(lb["ent_labels"]) for lb in labels]) | ||
num_ents = len(self.label_maps["entity2id"]) | ||
batch_entity_labels = paddle.zeros( | ||
shape=[bs, num_ents, max_ent_num, 2], dtype="int64") | ||
for i, lb in enumerate(labels): | ||
for eidx, (l, eh, et) in enumerate(lb['ent_labels']): | ||
batch_entity_labels[i, l, | ||
eidx, :] = paddle.to_tensor([eh, et]) | ||
|
||
batch.append([batch_entity_labels]) | ||
else: | ||
max_ent_num = max([len(lb["ent_labels"]) for lb in labels]) | ||
max_spo_num = max([len(lb["rel_labels"]) for lb in labels]) | ||
num_ents = len(self.label_maps["entity2id"]) | ||
if "relation2id" in self.label_maps.keys(): | ||
num_rels = len(self.label_maps["relation2id"]) | ||
else: | ||
num_rels = len(self.label_maps["sentiment2id"]) | ||
batch_entity_labels = paddle.zeros( | ||
shape=[bs, num_ents, max_ent_num, 2], dtype="int64") | ||
batch_head_labels = paddle.zeros( | ||
shape=[bs, num_rels, max_spo_num, 2], dtype="int64") | ||
batch_tail_labels = paddle.zeros( | ||
shape=[bs, num_rels, max_spo_num, 2], dtype="int64") | ||
|
||
for i, lb in enumerate(labels): | ||
for eidx, (l, eh, et) in enumerate(lb['ent_labels']): | ||
batch_entity_labels[i, l, | ||
eidx, :] = paddle.to_tensor([eh, et]) | ||
for spidx, (sh, st, p, oh, ot) in enumerate(lb['rel_labels']): | ||
batch_head_labels[i, p, | ||
spidx, :] = paddle.to_tensor([sh, oh]) | ||
batch_tail_labels[i, p, | ||
spidx, :] = paddle.to_tensor([st, ot]) | ||
batch.append( | ||
[batch_entity_labels, batch_head_labels, batch_tail_labels]) | ||
return batch |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在UIE强大的抽取能力背后,是需要同样强大的算力才能支撑起如此大规模模型的训练和预测 -> 在UIE强大的抽取能力背后,同样需要较大的算力支持计算
很多工业应用场景对性能要求较高 -> 在一些工业应用场景中对性能的要求较高
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, thx