Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add data distillation for UIE #3136

Merged
merged 9 commits into from
Sep 5, 2022
Merged

Conversation

linjieccc
Copy link
Contributor

PR types

New features

PR changes

APIs

Description

  • 新增通过UIE及数据蒸馏的方式训练封闭域信息抽取模型的示例

@linjieccc linjieccc requested a review from wawltor August 24, 2022 14:26
@linjieccc linjieccc self-assigned this Aug 24, 2022
@linjieccc linjieccc added ie Issues related to Information Extraction taskflow Taskflow labels Aug 24, 2022
@@ -0,0 +1,31 @@
# UIE数据蒸馏
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的UIE数据蒸馏需要解释一下,同时整体的文档太过于简单还是需要优化,概念,数据来源,部署方面都有欠缺

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已补充UIE数据蒸馏相关概念和数据来源,具体示例和部署说明待补充

# yapf: enable

# Define your schema here
schema = ["观点词", {"评价维度": ["观点词", "情感倾向[正向,负向]"]}]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里需要用户自己来自定义schema,看看有没有将字符串转成python code的方式
通过传参的方式传入

relation2id[child.name] = len(relation2id)
schema_list.append(child)

entity2id['OBJECT'] = len(entity2id)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的key为什么要大写了?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

for text in tqdm(infer_texts, desc="Predicting: ", leave=False):
infer_results.append(uie(text))

train_synthetic_lines = synthetic2distill(texts, infer_results,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的逻辑没有太搞懂,这里的unlabel的数据的结果是来自己Taskflow的输出结果吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

嗯嗯,unlabel数据通过Taskflow load定制模型的形式推理得到合成数据

return schema_tree


def schema2label_maps(task_type, schema=None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里在schema的分析的需要特别的指出情感分类不在这个数据蒸馏的范围,可以在代码做个提示或者报错

在文档中也要特别的指出

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thx


#### UIE数据蒸馏三步

- **Step 1**: 使用UIE模型对标注数据进行fine-tune,得到Teacher Model。
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fine-tune -> finetune,或者整体查看一下,微调这块的统一英文术语是啥样子的了

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

@@ -0,0 +1,50 @@
# UIE Slim 数据蒸馏

在UIE强大的抽取能力背后,是需要同样强大的算力才能支撑起如此大规模模型的训练和预测。很多工业应用场景对性能要求较高,若不能有效压缩则无法实际应用。因此,我们基于数据蒸馏技术构建了UIE Slim数据蒸馏系统。其原理是通过数据作为桥梁,将UIE模型的知识迁移到小模型,以达到精度损失较小的情况下却能达到大幅度预测速度提升的效果。
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在UIE强大的抽取能力背后,是需要同样强大的算力才能支撑起如此大规模模型的训练和预测 -> 在UIE强大的抽取能力背后,同样需要较大的算力支持计算

很多工业应用场景对性能要求较高 -> 在一些工业应用场景中对性能的要求较高

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thx

import paddle.nn as nn


class Criterion(nn.Layer):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块的文档说明需要是英文的

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

y_true,
y_pred,
mask_zero=False):
"""稀疏版多标签分类的交叉熵
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thx

y_pred = paddle.concat([-infs, y_pred[..., 1:]], axis=-1)
y_pos_2 = paddle.take_along_axis(y_pred, y_true, axis=-1)

pos_loss = (-y_pos_1).exp().sum(axis=-1).log()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

paddle看起来是有logsumexp这个API的,可以尝试提交一下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里之前profile使用logsumexp对四维Tensor输入的计算特别慢,这块我整理下最小复现代码反馈给框架同学



def do_train():
paddle.disable_static()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里为什么和metric那块都要disable_static,默认看起来就是disable的

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删除


train_ds = load_dataset(reader, data_path=args.train_path, lazy=False)
dev_ds = load_dataset(reader, data_path=args.dev_path, lazy=False)
tokenizer = AutoTokenizer.from_pretrained("ernie-3.0-base-zh")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议的model_name不要默认,可以让用户选择

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thx

task_type=args.task_type)

encoder = AutoModel.from_pretrained("ernie-3.0-base-zh")
if args.task_type == "entity_extraction":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里不太明白的是,单独实体抽取是走这个分支,如果是实体抽取、关系抽取在一起的任务是不是也可以了?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

实体关系抽取可以指定任务类型为relation_extraction

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里不太明白的是,单独实体抽取是走这个分支,如果是实体抽取、关系抽取在一起的任务是不是也可以了?

请问事件抽取也是选择 task_type为“relation_extraction”吗?

# All bias and LayerNorm parameters are excluded.
decay_params = [
p.name for n, p in model.named_parameters()
if not any(nd in n for nd in ["bias", "norm", "LayerNorm.weight"])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里多了一个LayerNorm.weight,之前的经验是norm已经覆盖了,在GPLinkerForRelationExtraction这些模型增加了LayerNorm是吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删除,这里GP没有增加LayerNorm

label_maps = get_label_maps(args.task_type, args.label_maps_path)

tokenizer = AutoTokenizer.from_pretrained("ernie-3.0-base-zh")
encoder = AutoModel.from_pretrained("ernie-3.0-base-zh")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thx

else:
model = GPLinkerForRelationExtraction(encoder, label_maps)

state_dict = paddle.load(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里需要判断一下文件是否存在

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thx

#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gp.py -> globalpoint.py

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thx

Copy link
Collaborator

@wawltor wawltor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ie Issues related to Information Extraction taskflow Taskflow
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants