-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add data distillation for UIE #3136
Conversation
model_zoo/uie/data_distill/README.md
Outdated
@@ -0,0 +1,31 @@ | |||
# UIE数据蒸馏 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的UIE数据蒸馏需要解释一下,同时整体的文档太过于简单还是需要优化,概念,数据来源,部署方面都有欠缺
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已补充UIE数据蒸馏相关概念和数据来源,具体示例和部署说明待补充
# yapf: enable | ||
|
||
# Define your schema here | ||
schema = ["观点词", {"评价维度": ["观点词", "情感倾向[正向,负向]"]}] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里需要用户自己来自定义schema,看看有没有将字符串转成python code的方式
通过传参的方式传入
relation2id[child.name] = len(relation2id) | ||
schema_list.append(child) | ||
|
||
entity2id['OBJECT'] = len(entity2id) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的key为什么要大写了?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
for text in tqdm(infer_texts, desc="Predicting: ", leave=False): | ||
infer_results.append(uie(text)) | ||
|
||
train_synthetic_lines = synthetic2distill(texts, infer_results, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的逻辑没有太搞懂,这里的unlabel的数据的结果是来自己Taskflow的输出结果吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
嗯嗯,unlabel数据通过Taskflow load定制模型的形式推理得到合成数据
return schema_tree | ||
|
||
|
||
def schema2label_maps(task_type, schema=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里在schema的分析的需要特别的指出情感分类不在这个数据蒸馏的范围,可以在代码做个提示或者报错
在文档中也要特别的指出
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, thx
model_zoo/uie/data_distill/README.md
Outdated
|
||
#### UIE数据蒸馏三步 | ||
|
||
- **Step 1**: 使用UIE模型对标注数据进行fine-tune,得到Teacher Model。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fine-tune -> finetune,或者整体查看一下,微调这块的统一英文术语是啥样子的了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
model_zoo/uie/data_distill/README.md
Outdated
@@ -0,0 +1,50 @@ | |||
# UIE Slim 数据蒸馏 | |||
|
|||
在UIE强大的抽取能力背后,是需要同样强大的算力才能支撑起如此大规模模型的训练和预测。很多工业应用场景对性能要求较高,若不能有效压缩则无法实际应用。因此,我们基于数据蒸馏技术构建了UIE Slim数据蒸馏系统。其原理是通过数据作为桥梁,将UIE模型的知识迁移到小模型,以达到精度损失较小的情况下却能达到大幅度预测速度提升的效果。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在UIE强大的抽取能力背后,是需要同样强大的算力才能支撑起如此大规模模型的训练和预测 -> 在UIE强大的抽取能力背后,同样需要较大的算力支持计算
很多工业应用场景对性能要求较高 -> 在一些工业应用场景中对性能的要求较高
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, thx
import paddle.nn as nn | ||
|
||
|
||
class Criterion(nn.Layer): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这块的文档说明需要是英文的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
y_true, | ||
y_pred, | ||
mask_zero=False): | ||
"""稀疏版多标签分类的交叉熵 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, thx
y_pred = paddle.concat([-infs, y_pred[..., 1:]], axis=-1) | ||
y_pos_2 = paddle.take_along_axis(y_pred, y_true, axis=-1) | ||
|
||
pos_loss = (-y_pos_1).exp().sum(axis=-1).log() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
paddle看起来是有logsumexp这个API的,可以尝试提交一下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里之前profile使用logsumexp对四维Tensor输入的计算特别慢,这块我整理下最小复现代码反馈给框架同学
model_zoo/uie/data_distill/train.py
Outdated
|
||
|
||
def do_train(): | ||
paddle.disable_static() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里为什么和metric那块都要disable_static,默认看起来就是disable的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已删除
model_zoo/uie/data_distill/train.py
Outdated
|
||
train_ds = load_dataset(reader, data_path=args.train_path, lazy=False) | ||
dev_ds = load_dataset(reader, data_path=args.dev_path, lazy=False) | ||
tokenizer = AutoTokenizer.from_pretrained("ernie-3.0-base-zh") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议的model_name不要默认,可以让用户选择
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, thx
task_type=args.task_type) | ||
|
||
encoder = AutoModel.from_pretrained("ernie-3.0-base-zh") | ||
if args.task_type == "entity_extraction": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里不太明白的是,单独实体抽取是走这个分支,如果是实体抽取、关系抽取在一起的任务是不是也可以了?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
实体关系抽取可以指定任务类型为relation_extraction
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里不太明白的是,单独实体抽取是走这个分支,如果是实体抽取、关系抽取在一起的任务是不是也可以了?
请问事件抽取也是选择 task_type为“relation_extraction”吗?
model_zoo/uie/data_distill/train.py
Outdated
# All bias and LayerNorm parameters are excluded. | ||
decay_params = [ | ||
p.name for n, p in model.named_parameters() | ||
if not any(nd in n for nd in ["bias", "norm", "LayerNorm.weight"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里多了一个LayerNorm.weight,之前的经验是norm已经覆盖了,在GPLinkerForRelationExtraction这些模型增加了LayerNorm是吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已删除,这里GP没有增加LayerNorm
label_maps = get_label_maps(args.task_type, args.label_maps_path) | ||
|
||
tokenizer = AutoTokenizer.from_pretrained("ernie-3.0-base-zh") | ||
encoder = AutoModel.from_pretrained("ernie-3.0-base-zh") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, thx
else: | ||
model = GPLinkerForRelationExtraction(encoder, label_maps) | ||
|
||
state_dict = paddle.load( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里需要判断一下文件是否存在
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, thx
paddlenlp/layers/gp.py
Outdated
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
gp.py -> globalpoint.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, thx
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
APIs
Description