diff --git a/model_zoo/uie/data_distill/README.md b/model_zoo/uie/data_distill/README.md new file mode 100644 index 000000000000..dea0c17c8c84 --- /dev/null +++ b/model_zoo/uie/data_distill/README.md @@ -0,0 +1,140 @@ +# UIE Slim 数据蒸馏 + +在UIE强大的抽取能力背后,同样需要较大的算力支持计算。在一些工业应用场景中对性能的要求较高,若不能有效压缩则无法实际应用。因此,我们基于数据蒸馏技术构建了UIE Slim数据蒸馏系统。其原理是通过数据作为桥梁,将UIE模型的知识迁移到封闭域信息抽取小模型,以达到精度损失较小的情况下却能达到大幅度预测速度提升的效果。 + +#### UIE数据蒸馏三步 + +- **Step 1**: 使用UIE模型对标注数据进行finetune,得到Teacher Model。 + +- **Step 2**: 用户提供大规模无标注数据,需与标注数据同源。使用Taskflow UIE对无监督数据进行预测。 + +- **Step 3**: 使用标注数据以及步骤2得到的合成数据训练出封闭域Student Model。 + +## 数据准备 + +本项目中从CMeIE数据集中采样少量数据展示了UIE数据蒸馏流程,[示例数据下载](https://bj.bcebos.com/paddlenlp/datasets/uie/doccano_ext.json),解压后放在``../data``目录下。 + +示例数据包含以下两部分: + +| 名称 | 数量 | +| :---: | :-----: | +| 标注数据(doccano格式) | 200 | +| 无标注数据 | 1277 | + +## UIE Finetune + +参考[UIE主文档](../README.md)完成UIE模型微调。 + +```shell +python finetune.py --train_path ./data/train.txt --dev_path ./data/dev.txt --learning_rate 5e-6 --batch_size 2 +``` + +## 离线蒸馏 + +#### 通过训练好的UIE定制模型预测无监督数据的标签 + +```shell +python data_distill.py --data_path ../data --save_dir student_data --task_type relation_extraction --synthetic_ratio 10 --model_path ../checkpoint/model_best +``` + +可配置参数说明: + +- `data_path`: 标注数据(`doccano_ext.json`)及无监督文本(`unlabeled_data.txt`)路径。 +- `model_path`: 训练好的UIE定制模型路径。 +- `save_dir`: 学生模型训练数据保存路径。 +- `synthetic_ratio`: 控制合成数据的比例。最大合成数据数量=synthetic_ratio*标注数据数量。 +- `task_type`: 选择任务类型,可选有`entity_extraction`,`relation_extraction`,`event_extraction`和`opinion_extraction`。因为是封闭域信息抽取,需指定任务类型。 +- `seed`: 随机种子,默认为1000。 + +#### 老师模型评估 + +UIE微调阶段针对UIE训练格式数据评估模型效果(该评估方式非端到端评估,不适合关系、事件等任务),可通过以下评估脚本针对原始标注格式数据评估模型效果 + +```shell +python evaluate_teacher.py \ + --task_type relation_extraction \ + --test_path ./student_data/dev_data.json \ + --label_maps_path ./student_data/label_maps.json \ + --model_path ../checkpoint/model_best +``` + +可配置参数说明: + +- `model_path`: 训练好的UIE定制模型路径。 +- `test_path`: 测试数据集路径。 +- `label_maps_path`: 学生模型标签字典。 +- `batch_size`: 批处理大小,默认为8。 +- `max_seq_len`: 最大文本长度,默认为256。 +- `task_type`: 选择任务类型,可选有`entity_extraction`,`relation_extraction`,`event_extraction`和`opinion_extraction`。因为是封闭域信息抽取的评估,需指定任务类型。 + + +#### 学生模型训练 + +```shell +python train.py \ + --task_type relation_extraction \ + --train_path student_data/train_data.json \ + --dev_path student_data/dev_data.json \ + --label_maps_path student_data/label_maps.json \ + --num_epochs 200 \ + --encoder ernie-3.0-mini-zh +``` + +可配置参数说明: + +- `train_path`: 训练集文件路径。 +- `dev_path`: 验证集文件路径。 +- `batch_size`: 批处理大小,默认为16。 +- `learning_rate`: 学习率,默认为3e-5。 +- `save_dir`: 模型存储路径,默认为`./checkpoint`。 +- `max_seq_len`: 最大文本长度,默认为256。 +- `weight_decay`: 表示AdamW优化器中使用的 weight_decay 的系数。 +- `warmup_proportion`: 学习率warmup策略的比例,如果0.1,则学习率会在前10%训练step的过程中从0慢慢增长到learning_rate, 而后再缓慢衰减,默认为0.0。 +- `num_epochs`: 训练轮数,默认为100。 +- `seed`: 随机种子,默认为1000。 +- `encoder`: 选择学生模型的模型底座,默认为`ernie-3.0-mini-zh`。 +- `task_type`: 选择任务类型,可选有`entity_extraction`,`relation_extraction`,`event_extraction`和`opinion_extraction`。因为是封闭域信息抽取,需指定任务类型。 +- `logging_steps`: 日志打印的间隔steps数,默认10。 +- `valid_steps`: evaluate的间隔steps数,默认200。 +- `device`: 选用什么设备进行训练,可选cpu或gpu。 +- `init_from_ckpt`: 可选,模型参数路径,热启动模型训练;默认为None。 + + +## Taskflow部署学生模型 + +- 通过Taskflow一键部署封闭域信息抽取模型,`task_path`为学生模型路径。 + +```python +>>> from pprint import pprint +>>> from paddlenlp import Taskflow + +>>> ie = Taskflow("information_extraction", model="uie-data-distill-gp", task_path="checkpoint/model_best/") # Schema is fixed in closed-domain 
information extraction +>>> pprint(ie("登革热@结果 升高 ### 血清白蛋白水平 检查 结果 检查 在资源匮乏地区和富足地区,对有症状患者均应早期检测。")) +[{'疾病': [{'end': 3, + 'probability': 0.99952424, + 'relations': {'实验室检查': [{'end': 21, + 'probability': 0.994445, + 'relations': {}, + 'start': 14, + 'text': '血清白蛋白水平'}]}, + 'start': 0, + 'text': '登革热'}]}] +``` + +## 效果验证 + +| 模型 | Entity-F1 | SPO-F1 | +| :---: | :--------: | :--------: | +| UIE-Finetune | 78.57 | 56.25 | +| GPLinker-ernie-3.0-mini-zh | 68.18 | 47.06 | +| GPLinker-ernie-3.0-mini-zh + UIE数据蒸馏 | 76.38 | 50.42 | + +# References + +- **[GlobalPointer](https://kexue.fm/search/globalpointer/)** + +- **[GPLinker](https://kexue.fm/archives/8888)** + +- **[JunnYu/GPLinker_pytorch](https://github.com/JunnYu/GPLinker_pytorch)** + +- **[CBLUE](https://github.com/CBLUEbenchmark/CBLUE)** diff --git a/model_zoo/uie/data_distill/criterion.py b/model_zoo/uie/data_distill/criterion.py new file mode 100644 index 000000000000..d23aee3576dd --- /dev/null +++ b/model_zoo/uie/data_distill/criterion.py @@ -0,0 +1,61 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +import paddle.nn as nn + + +class Criterion(nn.Layer): + '''Criterion for GPNet''' + + def __init__(self, mask_zero=True): + self.mask_zero = mask_zero + + def _sparse_multilabel_categorical_crossentropy(self, + y_true, + y_pred, + mask_zero=False): + """Sparse multi-label categorical cross entropy + reference to "https://kexue.fm/archives/7359". + """ + zeros = paddle.zeros_like(y_pred[..., :1]) + y_pred = paddle.concat([y_pred, zeros], axis=-1) + if mask_zero: + infs = zeros + 1e12 + y_pred = paddle.concat([infs, y_pred[..., 1:]], axis=-1) + y_pos_2 = paddle.take_along_axis(y_pred, y_true, axis=-1) + y_pos_1 = paddle.concat([y_pos_2, zeros], axis=-1) + if mask_zero: + y_pred = paddle.concat([-infs, y_pred[..., 1:]], axis=-1) + y_pos_2 = paddle.take_along_axis(y_pred, y_true, axis=-1) + + pos_loss = (-y_pos_1).exp().sum(axis=-1).log() + all_loss = y_pred.exp().sum(axis=-1).log() + aux_loss = y_pos_2.exp().sum(axis=-1).log() - all_loss + aux_loss = paddle.clip(1 - paddle.exp(aux_loss), min=0.1, max=1) + neg_loss = all_loss + paddle.log(aux_loss) + return pos_loss + neg_loss + + def __call__(self, y_pred, y_true): + shape = y_pred.shape + y_true = y_true[..., 0] * shape[2] + y_true[..., 1] + # bs, nclass, seqlen * seqlen + y_pred = paddle.reshape(y_pred, + shape=[shape[0], -1, + np.prod(shape[2:])]) + + loss = self._sparse_multilabel_categorical_crossentropy( + y_true, y_pred, self.mask_zero) + return loss.sum(axis=1).mean() diff --git a/model_zoo/uie/data_distill/data_collator.py b/model_zoo/uie/data_distill/data_collator.py new file mode 100644 index 000000000000..2cde572e324e --- /dev/null +++ b/model_zoo/uie/data_distill/data_collator.py @@ -0,0 +1,95 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional, Union +from dataclasses import dataclass + +import paddle +from paddlenlp.transformers.tokenizer_utils_base import PretrainedTokenizerBase, PaddingStrategy + +ignore_list = ["offset_mapping", "text"] + + +@dataclass +class DataCollator: + tokenizer: PretrainedTokenizerBase + padding: Union[bool, str, PaddingStrategy] = True + max_length: Optional[int] = None + label_maps: Optional[dict] = None + task_type: Optional[str] = None + + def __call__( + self, features: List[Dict[str, Union[List[int], paddle.Tensor]]] + ) -> Dict[str, paddle.Tensor]: + labels = ([feature["labels"] for feature in features] + if "labels" in features[0].keys() else None) + new_features = [{ + k: v + for k, v in f.items() if k not in ["labels"] + ignore_list + } for f in features] + + batch = self.tokenizer.pad( + new_features, + padding=self.padding, + ) + + batch = [paddle.to_tensor(batch[k]) for k in batch.keys()] + + if labels is None: # for test + if "offset_mapping" in features[0].keys(): + batch.append( + [feature["offset_mapping"] for feature in features]) + if "text" in features[0].keys(): + batch.append([feature["text"] for feature in features]) + return batch + + bs = batch[0].shape[0] + if self.task_type == "entity_extraction": + max_ent_num = max([len(lb["ent_labels"]) for lb in labels]) + num_ents = len(self.label_maps["entity2id"]) + batch_entity_labels = paddle.zeros( + shape=[bs, num_ents, max_ent_num, 2], dtype="int64") + for i, lb in enumerate(labels): + for eidx, (l, eh, et) in enumerate(lb['ent_labels']): + batch_entity_labels[i, l, + eidx, :] = paddle.to_tensor([eh, et]) + + batch.append([batch_entity_labels]) + else: + max_ent_num = max([len(lb["ent_labels"]) for lb in labels]) + max_spo_num = max([len(lb["rel_labels"]) for lb in labels]) + num_ents = len(self.label_maps["entity2id"]) + if "relation2id" in self.label_maps.keys(): + num_rels = len(self.label_maps["relation2id"]) + else: + num_rels = len(self.label_maps["sentiment2id"]) + batch_entity_labels = paddle.zeros( + shape=[bs, num_ents, max_ent_num, 2], dtype="int64") + batch_head_labels = paddle.zeros( + shape=[bs, num_rels, max_spo_num, 2], dtype="int64") + batch_tail_labels = paddle.zeros( + shape=[bs, num_rels, max_spo_num, 2], dtype="int64") + + for i, lb in enumerate(labels): + for eidx, (l, eh, et) in enumerate(lb['ent_labels']): + batch_entity_labels[i, l, + eidx, :] = paddle.to_tensor([eh, et]) + for spidx, (sh, st, p, oh, ot) in enumerate(lb['rel_labels']): + batch_head_labels[i, p, + spidx, :] = paddle.to_tensor([sh, oh]) + batch_tail_labels[i, p, + spidx, :] = paddle.to_tensor([st, ot]) + batch.append( + [batch_entity_labels, batch_head_labels, batch_tail_labels]) + return batch diff --git a/model_zoo/uie/data_distill/data_distill.py b/model_zoo/uie/data_distill/data_distill.py new file mode 100644 index 000000000000..550b9fcc64c7 --- /dev/null +++ b/model_zoo/uie/data_distill/data_distill.py @@ -0,0 +1,128 @@ +# coding=utf-8 +# Copyright 
(c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import math +import random +import argparse +from tqdm import tqdm + +import numpy as np +import paddle +from paddlenlp import Taskflow +from paddlenlp.taskflow.utils import SchemaTree +from paddlenlp.utils.log import logger + +from utils import set_seed, build_tree, schema2label_maps, doccano2distill, synthetic2distill + + +def do_data_distill(): + set_seed(args.seed) + + # Generate closed-domain label maps + if not os.path.exists(args.save_dir): + os.mkdir(args.save_dir) + label_maps = schema2label_maps(args.task_type, schema=args.schema) + label_maps_path = os.path.join(args.save_dir, "label_maps.json") + + # Save closed-domain label maps file + with open(label_maps_path, "w") as fp: + fp.write(json.dumps(label_maps, ensure_ascii=False)) + + # Load doccano file and convert to distill format + sample_index = json.loads( + open(os.path.join(args.data_path, "sample_index.json")).readline()) + + train_ids = sample_index["train_ids"] + dev_ids = sample_index["dev_ids"] + test_ids = sample_index["test_ids"] + + json_lines = [] + with open(os.path.join(args.data_path, "doccano_ext.json")) as fp: + for line in fp: + json_lines.append(json.loads(line)) + + train_lines = [json_lines[i] for i in train_ids] + train_lines = doccano2distill(train_lines, args.task_type, label_maps) + + dev_lines = [json_lines[i] for i in dev_ids] + dev_lines = doccano2distill(dev_lines, args.task_type, label_maps) + + test_lines = [json_lines[i] for i in test_ids] + test_lines = doccano2distill(test_lines, args.task_type, label_maps) + + # Load trained UIE model + uie = Taskflow("information_extraction", + schema=args.schema, + task_path=args.model_path) + + if args.synthetic_ratio > 0: + # Generate synthetic data + texts = open(os.path.join(args.data_path, + "unlabeled_data.txt")).readlines() + + actual_ratio = math.ceil(len(texts) / len(train_lines)) + if actual_ratio <= args.synthetic_ratio or args.synthetic_ratio == -1: + infer_texts = texts + else: + idxs = random.sample(range(0, len(texts)), + args.synthetic_ratio * len(train_lines)) + infer_texts = [texts[i] for i in idxs] + + infer_results = [] + for text in tqdm(infer_texts, desc="Predicting: ", leave=False): + infer_results.append(uie(text)) + + train_synthetic_lines = synthetic2distill(texts, infer_results, + args.task_type) + + # Concat origin and synthetic data + train_lines.extend(train_synthetic_lines) + + def _save_examples(save_dir, file_name, examples): + count = 0 + save_path = os.path.join(save_dir, file_name) + with open(save_path, "w", encoding="utf-8") as f: + for example in examples: + f.write(json.dumps(example, ensure_ascii=False) + "\n") + count += 1 + logger.info("Save %d examples to %s." 
% (count, save_path)) + + _save_examples(args.save_dir, "train_data.json", train_lines) + _save_examples(args.save_dir, "dev_data.json", dev_lines) + _save_examples(args.save_dir, "test_data.json", test_lines) + + +if __name__ == "__main__": + # yapf: disable + parser = argparse.ArgumentParser() + + parser.add_argument("--data_path", default="../data", type=str, help="The directory for labeled data in doccano format and the large-scale unlabeled data.") + parser.add_argument("--model_path", type=str, default="../checkpoint/model_best", help="The path of saved model that you want to load.") + parser.add_argument("--save_dir", default="./distill_task", type=str, help="The directory where the generated training data for the student model will be saved.") + parser.add_argument("--synthetic_ratio", default=10, type=int, help="The maximum number of synthetic samples per labeled sample; -1 means all unlabeled data will be used.") + parser.add_argument("--task_type", choices=['relation_extraction', 'event_extraction', 'entity_extraction', 'opinion_extraction'], default="entity_extraction", type=str, help="Select the training task type.") + parser.add_argument("--seed", type=int, default=1000, help="Random seed for initialization") + + args = parser.parse_args() + # yapf: enable + + # Define your schema here + schema = {"疾病": ["手术治疗", "实验室检查", "影像学检查"]} + + args.schema = schema + + do_data_distill() diff --git a/model_zoo/uie/data_distill/evaluate.py b/model_zoo/uie/data_distill/evaluate.py new file mode 100644 index 000000000000..b664c4fc65ca --- /dev/null +++ b/model_zoo/uie/data_distill/evaluate.py @@ -0,0 +1,98 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
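+
+# Evaluation entry point for the closed-domain model: loads a trained
+# GlobalPointer / GPLinker checkpoint, decodes spans with `postprocess` and
+# reports span-level precision, recall and F1 computed by `get_eval`.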
+ +import argparse +import json +import os +from tqdm import tqdm + +import paddle +from paddlenlp.datasets import load_dataset +from paddlenlp.transformers import AutoTokenizer, AutoModel +from paddlenlp.utils.log import logger +from paddlenlp.layers import GlobalPointerForEntityExtraction, GPLinkerForRelationExtraction + +from utils import postprocess, create_dataloader, reader, get_label_maps +from metric import get_eval + + +@paddle.no_grad() +def evaluate(model, dataloader, label_maps, task_type="relation_extraction"): + model.eval() + all_preds = ([], []) if task_type in [ + "opinion_extraction", "relation_extraction", "event_extraction" + ] else [] + for batch in tqdm(dataloader, desc="Evaluating: ", leave=False): + input_ids, attention_masks, offset_mappings, texts = batch + logits = model(input_ids, attention_masks) + batch_outputs = postprocess(logits, offset_mappings, texts, label_maps, + task_type) + if isinstance(batch_outputs, tuple): + all_preds[0].extend(batch_outputs[0]) # Entity output + all_preds[1].extend(batch_outputs[1]) # Relation output + else: + all_preds.extend(batch_outputs) + eval_results = get_eval(all_preds, dataloader.dataset.raw_data, task_type) + model.train() + return eval_results + + +def do_eval(): + label_maps = get_label_maps(args.task_type, args.label_maps_path) + + tokenizer = AutoTokenizer.from_pretrained(args.encoder) + encoder = AutoModel.from_pretrained(args.encoder) + if args.task_type == "entity_extraction": + model = GlobalPointerForEntityExtraction(encoder, label_maps) + else: + model = GPLinkerForRelationExtraction(encoder, label_maps) + + if args.model_path: + state_dict = paddle.load( + os.path.join(args.model_path, "model_state.pdparams")) + model.set_dict(state_dict) + + test_ds = load_dataset(reader, data_path=args.test_path, lazy=False) + + test_dataloader = create_dataloader(test_ds, + tokenizer, + max_seq_len=args.max_seq_len, + batch_size=args.batch_size, + label_maps=label_maps, + mode="test", + task_type=args.task_type) + + eval_result = evaluate(model, + test_dataloader, + label_maps, + task_type=args.task_type) + logger.info("Evaluation precision: " + str(eval_result)) + + +if __name__ == "__main__": + # yapf: disable + parser = argparse.ArgumentParser() + + parser.add_argument("--model_path", type=str, default=None, help="The path of saved model that you want to load.") + parser.add_argument("--test_path", type=str, default=None, help="The path of test set.") + parser.add_argument("--encoder", default="ernie-3.0-base-zh", type=str, help="Select the pretrained encoder model for GP.") + parser.add_argument("--label_maps_path", default="./ner_data/label_maps.json", type=str, help="The file path of the labels dictionary.") + parser.add_argument("--batch_size", type=int, default=16, help="Batch size per GPU/CPU for training.") + parser.add_argument("--max_seq_len", type=int, default=128, help="The maximum total input sequence length after tokenization.") + parser.add_argument("--task_type", choices=['relation_extraction', 'event_extraction', 'entity_extraction', 'opinion_extraction'], default="entity_extraction", type=str, help="Select the training task type.") + + args = parser.parse_args() + # yapf: enable + + do_eval() diff --git a/model_zoo/uie/data_distill/metric.py b/model_zoo/uie/data_distill/metric.py new file mode 100644 index 000000000000..57cf075e62b8 --- /dev/null +++ b/model_zoo/uie/data_distill/metric.py @@ -0,0 +1,73 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def get_eval(all_preds, raw_data, task_type): + if task_type == "entity_extraction": + ex, ey, ez = 1e-10, 1e-10, 1e-10 + for ent_preds, data in zip(all_preds, raw_data): + pred_ent_set = set([tuple(p.values()) for p in ent_preds]) + gold_ent_set = set([tuple(g.values()) for g in data["entity_list"]]) + ex += len(pred_ent_set & gold_ent_set) + ey += len(pred_ent_set) + ez += len(gold_ent_set) + ent_f1 = round(2 * ex / (ey + ez), 5) if ex != 1e-10 else 0. + ent_precision = round(ex / ey, 5) if ey != 1e-10 else 0. + ent_recall = round(ex / ez, 5) if ez != 1e-10 else 0. + + return { + "entity_f1": ent_f1, + "entity_precision": ent_precision, + "entity_recall": ent_recall, + } + else: + all_ent_preds, all_rel_preds = all_preds + + ex, ey, ez = 1e-10, 1e-10, 1e-10 + for ent_preds, data in zip(all_ent_preds, raw_data): + pred_ent_set = set([tuple(p.values()) for p in ent_preds]) + gold_ent_set = set([tuple(g.values()) for g in data["entity_list"]]) + ex += len(pred_ent_set & gold_ent_set) + ey += len(pred_ent_set) + ez += len(gold_ent_set) + ent_f1 = round(2 * ex / (ey + ez), 5) if ex != 1e-10 else 0. + ent_precision = round(ex / ey, 5) if ey != 1e-10 else 0. + ent_recall = round(ex / ez, 5) if ez != 1e-10 else 0. + + rx, ry, rz = 1e-10, 1e-10, 1e-10 + + for rel_preds, raw_data in zip(all_rel_preds, raw_data): + pred_rel_set = set([tuple(p.values()) for p in rel_preds]) + if task_type == "opinion_extraction": + gold_rel_set = set( + [tuple(g.values()) for g in raw_data["aso_list"]]) + else: + gold_rel_set = set( + [tuple(g.values()) for g in raw_data["spo_list"]]) + rx += len(pred_rel_set & gold_rel_set) + ry += len(pred_rel_set) + rz += len(gold_rel_set) + + rel_f1 = round(2 * rx / (ry + rz), 5) if rx != 1e-10 else 0. + rel_precision = round(rx / ry, 5) if ry != 1e-10 else 0. + rel_recall = round(rx / rz, 5) if rz != 1e-10 else 0. + + return { + "entity_f1": ent_f1, + "entity_precision": ent_precision, + "entity_recall": ent_recall, + "relation_f1": rel_f1, + "relation_precision": rel_precision, + "relation_recall": rel_recall + } diff --git a/model_zoo/uie/data_distill/train.py b/model_zoo/uie/data_distill/train.py new file mode 100644 index 000000000000..d5a3e7ee4fbe --- /dev/null +++ b/model_zoo/uie/data_distill/train.py @@ -0,0 +1,192 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
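+
+# Student model training entry point: fine-tunes a GlobalPointer (entity
+# extraction) or GPLinker (relation/opinion/event extraction) head on top of an
+# ERNIE encoder with the distilled data produced by data_distill.py, keeping the
+# checkpoint with the best dev F1 in `model_best`.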
+ +import argparse +import json +import time +import os +from functools import partial + +import paddle +import paddle.nn as nn +from paddlenlp.datasets import load_dataset +from paddlenlp.transformers import AutoTokenizer, AutoModel +from paddlenlp.transformers import LinearDecayWithWarmup +from paddlenlp.utils.log import logger +from paddlenlp.layers import GlobalPointerForEntityExtraction, GPLinkerForRelationExtraction + +from evaluate import evaluate +from criterion import Criterion +from utils import reader, set_seed, get_label_maps, create_dataloader, criteria_map, save_model_config + + +def do_train(): + paddle.set_device(args.device) + rank = paddle.distributed.get_rank() + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() + set_seed(args.seed) + + label_maps = get_label_maps(args.task_type, args.label_maps_path) + + train_ds = load_dataset(reader, data_path=args.train_path, lazy=False) + dev_ds = load_dataset(reader, data_path=args.dev_path, lazy=False) + tokenizer = AutoTokenizer.from_pretrained(args.encoder) + + train_dataloader = create_dataloader(train_ds, + tokenizer, + max_seq_len=args.max_seq_len, + batch_size=args.batch_size, + label_maps=label_maps, + mode="train", + task_type=args.task_type) + + dev_dataloader = create_dataloader(dev_ds, + tokenizer, + max_seq_len=args.max_seq_len, + batch_size=args.batch_size, + label_maps=label_maps, + mode="dev", + task_type=args.task_type) + + encoder = AutoModel.from_pretrained(args.encoder) + if args.task_type == "entity_extraction": + model = GlobalPointerForEntityExtraction(encoder, label_maps) + else: + model = GPLinkerForRelationExtraction(encoder, label_maps) + + model_config = { + "task_type": args.task_type, + "label_maps": label_maps, + "encoder": args.encoder + } + + num_training_steps = len(train_dataloader) * args.num_epochs + lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps, + args.warmup_proportion) + + # Generate parameter names needed to perform weight decay. + # All bias and LayerNorm parameters are excluded. + decay_params = [ + p.name for n, p in model.named_parameters() + if not any(nd in n for nd in ["bias", "norm"]) + ] + optimizer = paddle.optimizer.AdamW( + learning_rate=lr_scheduler, + parameters=model.parameters(), + weight_decay=args.weight_decay, + apply_decay_param_fun=lambda x: x in decay_params) + + if args.init_from_ckpt and os.path.isfile(args.init_from_ckpt): + state_dict = paddle.load(args.init_from_ckpt) + model.set_dict(state_dict) + + if paddle.distributed.get_world_size() > 1: + model = paddle.DataParallel(model) + + criterion = Criterion() + + global_step, best_f1 = 1, 0. 
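+    # `tr_loss` accumulates the loss over the whole run while `logging_loss`
+    # snapshots it at each logging step, so the value reported below is the
+    # average loss over the last `logging_steps` updates.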
+ tr_loss, logging_loss = 0.0, 0.0 + tic_train = time.time() + for epoch in range(1, args.num_epochs + 1): + for batch in train_dataloader: + input_ids, attention_masks, labels = batch + + logits = model(input_ids, attention_masks) + + loss = sum([criterion(o, l) for o, l in zip(logits, labels)]) / 3 + + loss.backward() + + tr_loss += loss.item() + + lr_scheduler.step() + optimizer.step() + optimizer.clear_grad() + + if global_step % args.logging_steps == 0 and rank == 0: + time_diff = time.time() - tic_train + loss_avg = (tr_loss - logging_loss) / args.logging_steps + logger.info( + "global step %d, epoch: %d, loss: %.5f, speed: %.2f step/s" + % (global_step, epoch, loss_avg, + args.logging_steps / time_diff)) + logging_loss = tr_loss + tic_train = time.time() + + if global_step % args.valid_steps == 0 and rank == 0: + save_dir = os.path.join(args.save_dir, "model_%d" % global_step) + if not os.path.exists(save_dir): + os.makedirs(save_dir) + save_param_path = os.path.join(save_dir, "model_state.pdparams") + paddle.save(model.state_dict(), save_param_path) + save_model_config(save_dir, model_config) + logger.disable() + tokenizer.save_pretrained(save_dir) + logger.enable() + + eval_result = evaluate(model, + dev_dataloader, + label_maps, + task_type=args.task_type) + logger.info("Evaluation precision: " + str(eval_result)) + + f1 = eval_result[criteria_map[args.task_type]] + if f1 > best_f1: + logger.info( + f"best F1 performence has been updated: {best_f1:.5f} --> {f1:.5f}" + ) + best_f1 = f1 + save_dir = os.path.join(args.save_dir, "model_best") + if not os.path.exists(save_dir): + os.makedirs(save_dir) + save_param_path = os.path.join(save_dir, + "model_state.pdparams") + paddle.save(model.state_dict(), save_param_path) + save_model_config(save_dir, model_config) + logger.disable() + tokenizer.save_pretrained(save_dir) + logger.enable() + tic_train = time.time() + + global_step += 1 + + +if __name__ == "__main__": + # yapf: disable + parser = argparse.ArgumentParser() + + parser.add_argument("--train_path", default=None, type=str, help="The path of train set.") + parser.add_argument("--dev_path", default=None, type=str, help="The path of dev set.") + parser.add_argument("--batch_size", default=16, type=int, help="Batch size per GPU/CPU for training.") + parser.add_argument("--learning_rate", default=3e-5, type=float, help="The initial learning rate for Adam.") + parser.add_argument("--save_dir", default='./checkpoint', type=str, help="The output directory where the model checkpoints will be written.") + parser.add_argument("--max_seq_len", default=256, type=int, help="The maximum input sequence length.") + parser.add_argument("--label_maps_path", default="./ner_data/label_maps.json", type=str, help="The file path of the labels dictionary.") + parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay rate for L2 regularizer.") + parser.add_argument("--warmup_proportion", default=0.0, type=float, help="Linear warmup proption over the training process.") + parser.add_argument("--num_epochs", default=100, type=int, help="Number of epoches for training.") + parser.add_argument("--seed", default=1000, type=int, help="Random seed for initialization") + parser.add_argument("--encoder", default="ernie-3.0-mini-zh", type=str, help="Select the pretrained encoder model for GP.") + parser.add_argument("--task_type", choices=['relation_extraction', 'event_extraction', 'entity_extraction', 'opinion_extraction'], default="entity_extraction", type=str, help="Select the training 
task type.") + parser.add_argument("--logging_steps", default=10, type=int, help="The interval steps to logging.") + parser.add_argument("--valid_steps", default=200, type=int, help="The interval steps to evaluate model performance.") + parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to train model, defaults to gpu.") + parser.add_argument("--init_from_ckpt", default=None, type=str, help="The path of model parameters for initialization.") + + args = parser.parse_args() + # yapf: enable + + do_train() diff --git a/model_zoo/uie/data_distill/utils.py b/model_zoo/uie/data_distill/utils.py new file mode 100644 index 000000000000..e6f0c8f54984 --- /dev/null +++ b/model_zoo/uie/data_distill/utils.py @@ -0,0 +1,569 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import copy +import json +import os +import random +from itertools import groupby +from tqdm import tqdm + +import numpy as np +import paddle +from paddlenlp.utils.log import logger + +from data_collator import DataCollator + +criteria_map = { + "entity_extraction": "entity_f1", + "opinion_extraction": "relation_f1", # (Aspect, Sentiment, Opinion) + "relation_extraction": "relation_f1", # (Subject, Predicate, Object) + "event_extraction": "relation_f1" # (Trigger, Role, Argument) +} + + +def set_seed(seed): + paddle.seed(seed) + random.seed(seed) + np.random.seed(seed) + + +def reader(data_path): + with open(data_path, 'r', encoding='utf-8') as f: + for line in f: + json_line = json.loads(line) + yield json_line + + +def save_model_config(save_dir, model_config): + model_config_file = os.path.join(save_dir, "model_config.json") + with open(model_config_file, "w", encoding="utf-8") as fp: + fp.write(json.dumps(model_config, ensure_ascii=False, indent=2)) + + +def map_offset(ori_offset, offset_mapping): + """ + map ori offset to token offset + """ + for index, span in enumerate(offset_mapping): + if span[0] <= ori_offset < span[1]: + return index + return -1 + + +def get_label_maps(task_type="relation_extraction", label_maps_path=None): + with open(label_maps_path, 'r', encoding='utf-8') as fp: + label_maps = json.load(fp) + if task_type == "entity_extraction": + entity2id = label_maps['entity2id'] + id2entity = {idx: t for t, idx in entity2id.items()} + label_maps['id2entity'] = id2entity + else: + entity2id = label_maps['entity2id'] + relation2id = label_maps['relation2id'] if task_type in [ + "relation_extraction", "event_extraction" + ] else label_maps['sentiment2id'] + id2entity = {idx: t for t, idx in entity2id.items()} + id2relation = {idx: t for t, idx in relation2id.items()} + label_maps['id2entity'] = id2entity + label_maps['id2relation'] = id2relation + return label_maps + + +def create_dataloader(dataset, + tokenizer, + max_seq_len=128, + batch_size=1, + label_maps=None, + mode="train", + task_type="relation_extraction"): + + def tokenize_and_align_train_labels(example): + tokenized_inputs = tokenizer( + 
example['text'], + max_length=max_seq_len, + padding=False, + truncation=True, + return_attention_mask=True, + return_token_type_ids=False, + return_offsets_mapping=True, + ) + offset_mapping = tokenized_inputs["offset_mapping"] + + ent_labels = [] + for e in example["entity_list"]: + _start, _end = e['start_index'], e['start_index'] + len( + e['text']) - 1 + start = map_offset(_start, offset_mapping) + end = map_offset(_end, offset_mapping) + if start == -1 or end == -1: + continue + label = label_maps['entity2id'][e['type']] + ent_labels.append([label, start, end]) + + outputs = { + "input_ids": tokenized_inputs["input_ids"], + "attention_mask": tokenized_inputs["attention_mask"], + "labels": { + "ent_labels": ent_labels, + "rel_labels": [] + } + } + + if task_type in ["relation_extraction", "event_extraction"]: + rel_labels = [] + for r in example["spo_list"]: + _sh, _oh = r["subject_start_index"], r["object_start_index"] + _st, _ot = _sh + len(r["subject"]) - 1, _oh + len( + r["object"]) - 1 + sh = map_offset(_sh, offset_mapping) + st = map_offset(_st, offset_mapping) + oh = map_offset(_oh, offset_mapping) + ot = map_offset(_ot, offset_mapping) + if sh == -1 or st == -1 or oh == -1 or ot == -1: + continue + p = label_maps["relation2id"][r["predicate"]] + rel_labels.append([sh, st, p, oh, ot]) + outputs['labels']['rel_labels'] = rel_labels + elif task_type == "opinion_extraction": + rel_labels = [] + for r in example["aso_list"]: + _ah, _oh = r["aspect_start_index"], r["opinion_start_index"] + _at, _ot = _ah + len(r["aspect"]) - 1, _oh + len( + r["opinion"]) - 1 + ah = map_offset(_ah, offset_mapping) + at = map_offset(_at, offset_mapping) + oh = map_offset(_oh, offset_mapping) + ot = map_offset(_ot, offset_mapping) + if ah == -1 or at == -1 or oh == -1 or ot == -1: + continue + + s = label_maps["sentiment2id"][r["sentiment"]] + rel_labels.append([ah, at, s, oh, ot]) + outputs['labels']['rel_labels'] = rel_labels + return outputs + + def tokenize(example): + tokenized_inputs = tokenizer( + example['text'], + max_length=max_seq_len, + padding=False, + truncation=True, + return_attention_mask=True, + return_offsets_mapping=True, + return_token_type_ids=False, + ) + tokenized_inputs['text'] = example['text'] + return tokenized_inputs + + if mode == "train": + dataset = dataset.map(tokenize_and_align_train_labels) + else: + dataset_copy = copy.deepcopy(dataset) + dataset = dataset.map(tokenize) + + data_collator = DataCollator(tokenizer, + label_maps=label_maps, + task_type=task_type) + + shuffle = True if mode == "train" else False + batch_sampler = paddle.io.BatchSampler(dataset=dataset, + batch_size=batch_size, + shuffle=shuffle) + dataloader = paddle.io.DataLoader(dataset=dataset, + batch_sampler=batch_sampler, + collate_fn=data_collator, + num_workers=0, + return_list=True) + if mode != "train": + dataloader.dataset.raw_data = dataset_copy + return dataloader + + +def postprocess(batch_outputs, + offset_mappings, + texts, + label_maps, + task_type="relation_extraction"): + if task_type == "entity_extraction": + batch_ent_results = [] + for entity_output, offset_mapping, text in zip(batch_outputs[0].numpy(), + offset_mappings, texts): + entity_output[:, [0, -1]] -= np.inf + entity_output[:, :, [0, -1]] -= np.inf + ent_list = [] + for l, start, end in zip(*np.where(entity_output > 0.)): + start, end = (offset_mapping[start][0], offset_mapping[end][-1]) + ent = { + "text": text[start:end], + "type": label_maps['id2entity'][l], + "start_index": start + } + ent_list.append(ent) + 
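+            # Entity spans are decoded by thresholding the GlobalPointer scores at 0
+            # and mapping token indices back to character offsets; one list of
+            # entity dicts is collected per input text.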
batch_ent_results.append(ent_list) + return batch_ent_results + else: + batch_ent_results = [] + batch_rel_results = [] + for entity_output, head_output, tail_output, offset_mapping, text in zip( + batch_outputs[0].numpy(), + batch_outputs[1].numpy(), + batch_outputs[2].numpy(), + offset_mappings, + texts, + ): + entity_output[:, [0, -1]] -= np.inf + entity_output[:, :, [0, -1]] -= np.inf + ents = set() + ent_list = [] + for l, start, end in zip(*np.where(entity_output > 0.)): + ents.add((start, end)) + start, end = (offset_mapping[start][0], offset_mapping[end][-1]) + ent = { + "text": text[start:end], + "type": label_maps['id2entity'][l], + "start_index": start + } + ent_list.append(ent) + batch_ent_results.append(ent_list) + + rel_list = [] + for sh, st in ents: + for oh, ot in ents: + p1s = np.where(head_output[:, sh, oh] > 0.)[0] + p2s = np.where(tail_output[:, st, ot] > 0.)[0] + ps = set(p1s) & set(p2s) + for p in ps: + if task_type in [ + "relation_extraction", "event_extraction" + ]: + rel = { + "subject": + text[offset_mapping[sh][0]:offset_mapping[st] + [1]], + "predicate": + label_maps['id2relation'][p], + "object": + text[offset_mapping[oh][0]:offset_mapping[ot] + [1]], + "subject_start_index": + offset_mapping[sh][0], + "object_start_index": + offset_mapping[oh][0] + } + else: + rel = { + "aspect": + text[offset_mapping[sh][0]:offset_mapping[st] + [1]], + "sentiment": + label_maps['id2relation'][p], + "opinion": + text[offset_mapping[oh][0]:offset_mapping[ot] + [1]], + "aspect_start_index": + offset_mapping[sh][0], + "opinion_start_index": + offset_mapping[oh][0] + } + rel_list.append(rel) + batch_rel_results.append(rel_list) + return (batch_ent_results, batch_rel_results) + + +def build_tree(schema, name='root'): + """ + Build the schema tree. + """ + schema_tree = SchemaTree(name) + for s in schema: + if isinstance(s, str): + schema_tree.add_child(SchemaTree(s)) + elif isinstance(s, dict): + for k, v in s.items(): + if isinstance(v, str): + child = [v] + elif isinstance(v, list): + child = v + else: + raise TypeError( + "Invalid schema, value for each key:value pairs should be list or string" + "but {} received".format(type(v))) + schema_tree.add_child(build_tree(child, name=k)) + else: + raise TypeError("Invalid schema, element should be string or dict, " + "but {} received".format(type(s))) + return schema_tree + + +def schema2label_maps(task_type, schema=None): + if schema and isinstance(schema, dict): + schema = [schema] + + label_maps = {} + if task_type == "entity_extraction": + entity2id = {} + for s in schema: + entity2id[s] = len(entity2id) + + label_maps["entity2id"] = entity2id + elif task_type == "opinion_extraction": + schema = ["观点词", {"评价维度": ["观点词", "情感倾向[正向,负向]"]}] + logger.info( + "Opinion extracion does not support custom schema, the schema is default to %s." 
+ % schema) + label_maps["entity2id"] = {"评价维度": 0, "观点词": 1} + label_maps["sentiment2id"] = {"正向": 0, "负向": 1} + else: + entity2id = {} + relation2id = {} + schema_tree = build_tree(schema) + schema_list = schema_tree.children[:] + while len(schema_list) > 0: + node = schema_list.pop(0) + + if node.name not in entity2id.keys() and len(node.children) != 0: + entity2id[node.name] = len(entity2id) + + for child in node.children: + if child.name not in relation2id.keys(): + relation2id[child.name] = len(relation2id) + schema_list.append(child) + + entity2id['object'] = len(entity2id) + label_maps["entity2id"] = entity2id + label_maps["relation2id"] = relation2id + + label_maps["schema"] = schema + return label_maps + + +def doccano2distill(json_lines, task_type, label_maps=None): + """Convert doccano to distill format""" + if task_type == "opinion_extraction": + outputs = [] + for json_line in json_lines: + id2ent = {} + text = json_line['text'] + output = {"text": text} + + entity_list = [] + entities = json_line['entities'] + for entity in entities: + ent_text = text[entity['start_offset']:entity['end_offset']] + + ent_type_gather = entity['label'].split("##") + if len(ent_type_gather) == 2: + ent_type, ent_senti = ent_type_gather + else: + ent_type = ent_type_gather[0] + ent_senti = None + + ent_start_idx = entity['start_offset'] + + id2ent[entity['id']] = { + "text": ent_text, + "type": ent_type, + "start_index": ent_start_idx, + "sentiment": ent_senti + } + + ent = { + "text": ent_text, + "type": ent_type, + "start_index": ent_start_idx + } + + entity_list.append(ent) + output["entity_list"] = entity_list + + aso_list = [] + relations = json_line['relations'] + for relation in relations: + _aspect = id2ent[relation["from_id"]] + if _aspect['sentiment']: + _opinion = id2ent[relation["to_id"]] + rel = { + "aspect": _aspect['text'], + "sentiment": _aspect['sentiment'], + "opinion": _opinion['text'], + "aspect_start_index": _aspect["start_index"], + "opinion_start_index": _opinion["start_index"] + } + aso_list.append(rel) + output["aso_list"] = aso_list + outputs.append(output) + else: + outputs = [] + for json_line in json_lines: + id2ent = {} + text = json_line['text'] + output = {"text": text} + + entity_list = [] + entities = json_line['entities'] + for entity in entities: + ent_text = text[entity['start_offset']:entity['end_offset']] + ent_type = "object" if entity['label'] not in label_maps[ + 'entity2id'].keys() else entity['label'] + ent_start_idx = entity['start_offset'] + + id2ent[entity['id']] = { + "text": ent_text, + "type": ent_type, + "start_index": ent_start_idx + } + + ent = { + "text": ent_text, + "type": ent_type, + "start_index": ent_start_idx + } + + entity_list.append(ent) + output["entity_list"] = entity_list + + spo_list = [] + relations = json_line['relations'] + for relation in relations: + _subject = id2ent[relation["from_id"]] + _object = id2ent[relation["to_id"]] + rel = { + "subject": _subject['text'], + "predicate": relation['type'], + "object": _object['text'], + "subject_start_index": _subject["start_index"], + "object_start_index": _object["start_index"] + } + spo_list.append(rel) + output["spo_list"] = spo_list + outputs.append(output) + return outputs + + +def synthetic2distill(texts, infer_results, task_type, label_maps=None): + """Convert synthetic data to distill format""" + if task_type == "opinion_extraction": + outputs = [] + for i, line in enumerate(infer_results): + pred = line + output = {"text": texts[i]} + + entity_list = [] + aso_list = [] + 
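+            # Each top-level key of the Taskflow prediction is an aspect/entity type;
+            # spans nested under "relations" carry the opinion words and the sentiment
+            # label, which are flattened into entity_list and aso_list below.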
for key1 in pred.keys(): + for s in pred[key1]: + ent = { + "text": s["text"], + "type": key1, + "start_index": s["start"] + } + entity_list.append(ent) + + if "relations" in s.keys() and "观点词" in s["relations"].keys( + ) and "情感倾向[正向,负向]" in s["relations"].keys(): + for o in s["relations"]["观点词"]: + rel = { + "aspect": + s["text"], + "sentiment": + s["relations"]["情感倾向[正向,负向]"][0]["text"], + "opinion": + o["text"], + "aspect_start_index": + s["start"], + "opinion_start_index": + o["start"] + } + aso_list.append(rel) + + ent = { + "text": o["text"], + "type": "观点词", + "start_index": o["start"] + } + entity_list.append(ent) + output["entity_list"] = entity_list + output["aso_list"] = aso_list + outputs.append(output) + else: + outputs = [] + for i, line in enumerate(infer_results): + pred = line + output = {"text": texts[i]} + + entity_list = [] + spo_list = [] + for key1 in pred.keys(): + for s in pred[key1]: + ent = { + "text": s['text'], + "type": key1, + "start_index": s['start'] + } + entity_list.append(ent) + if "relations" in s.keys(): + for key2 in s['relations'].keys(): + for o1 in s['relations'][key2]: + if 'start' in o1.keys(): + rel = { + "subject": s['text'], + "predicate": key2, + "object": o1['text'], + "subject_start_index": s['start'], + "object_start_index": o1['start'] + } + spo_list.append(rel) + + if 'relations' not in o1.keys(): + ent = { + "text": o1['text'], + "type": "object", + "start_index": o1['start'] + } + entity_list.append(ent) + else: + ent = { + "text": o1['text'], + "type": key2, + "start_index": o1['start'] + } + entity_list.append(ent) + for key3 in o1['relations'].keys(): + for o2 in o1['relations'][key3]: + ent = { + "text": o2['text'], + "type": "object", + "start_index": o2['start'] + } + entity_list.append(ent) + + rel = { + "subject": + o1['text'], + "predicate": + key3, + "object": + o2['text'], + "subject_start_index": + o1['start'], + "object_start_index": + o2['start'] + } + spo_list.append(rel) + output["entity_list"] = entity_list + output["spo_list"] = spo_list + outputs.append(output) + return outputs diff --git a/model_zoo/uie/doccano.py b/model_zoo/uie/doccano.py index c0e55f5765ba..8f417b4eacca 100644 --- a/model_zoo/uie/doccano.py +++ b/model_zoo/uie/doccano.py @@ -1,3 +1,4 @@ +# coding=utf-8 # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -92,12 +93,25 @@ def _save_examples(save_dir, file_name, examples): else: if args.is_shuffle: indexes = np.random.permutation(len(raw_examples)) + index_list = indexes.tolist() raw_examples = [raw_examples[i] for i in indexes] i1, i2, _ = args.splits p1 = int(len(raw_examples) * i1) p2 = int(len(raw_examples) * (i1 + i2)) + train_ids = index_list[:p1] + dev_ids = index_list[p1:p2] + test_ids = index_list[p2:] + + with open(os.path.join(args.save_dir, "sample_index.json"), "w") as fp: + maps = { + "train_ids": train_ids, + "dev_ids": dev_ids, + "test_ids": test_ids + } + fp.write(json.dumps(maps)) + if args.task_type == "ext": train_examples = _create_ext_examples(raw_examples[:p1], args.negative_ratio, diff --git a/model_zoo/uie/utils.py b/model_zoo/uie/utils.py index b35456d06500..ab220e81ae6f 100644 --- a/model_zoo/uie/utils.py +++ b/model_zoo/uie/utils.py @@ -267,33 +267,33 @@ def unify_prompt_name(prompt): return prompt -def add_negative_example(examples, texts, prompts, label_set, negative_ratio): +def add_entity_negative_example(examples, texts, prompts, label_set, + negative_ratio): negative_examples = [] positive_examples = [] with tqdm(total=len(prompts)) as pbar: for i, prompt in enumerate(prompts): - negative_sample = [] - redundants_list = list(set(label_set) ^ set(prompt)) - redundants_list.sort() + redundants = list(set(label_set) ^ set(prompt)) + redundants.sort() num_positive = len(examples[i]) if num_positive != 0: - actual_ratio = math.ceil(len(redundants_list) / num_positive) + actual_ratio = math.ceil(len(redundants) / num_positive) else: # Set num_positive to 1 for text without positive example num_positive, actual_ratio = 1, 0 if actual_ratio <= negative_ratio or negative_ratio == -1: - idxs = [k for k in range(len(redundants_list))] + idxs = [k for k in range(len(redundants))] else: - idxs = random.sample(range(0, len(redundants_list)), + idxs = random.sample(range(0, len(redundants)), negative_ratio * num_positive) for idx in idxs: negative_result = { "content": texts[i], "result_list": [], - "prompt": redundants_list[idx] + "prompt": redundants[idx] } negative_examples.append(negative_result) positive_examples.extend(examples[i]) @@ -301,6 +301,43 @@ def add_negative_example(examples, texts, prompts, label_set, negative_ratio): return positive_examples, negative_examples +def add_relation_negative_example(redundants, text, num_positive, ratio): + added_example = [] + rest_example = [] + + if num_positive != 0: + actual_ratio = math.ceil(len(redundants) / num_positive) + else: + # Set num_positive to 1 for text without positive example + num_positive, actual_ratio = 1, 0 + + all_idxs = [k for k in range(len(redundants))] + if actual_ratio <= ratio or ratio == -1: + idxs = all_idxs + rest_idxs = [] + else: + idxs = random.sample(range(0, len(redundants)), ratio * num_positive) + rest_idxs = list(set(all_idxs) ^ set(idxs)) + + for idx in idxs: + negative_result = { + "content": text, + "result_list": [], + "prompt": redundants[idx] + } + added_example.append(negative_result) + + for rest_idx in rest_idxs: + negative_result = { + "content": text, + "result_list": [], + "prompt": redundants[rest_idx] + } + rest_example.append(negative_result) + + return added_example, rest_example + + def add_full_negative_example(examples, texts, relation_prompts, predicate_set, subject_goldens): with tqdm(total=len(relation_prompts)) as pbar: @@ -323,17 +360,6 @@ def add_full_negative_example(examples, texts, 
relation_prompts, predicate_set, return examples -def construct_relation_prompt_set(entity_name_set, predicate_set): - relation_prompt_set = set() - for entity_name in entity_name_set: - for predicate in predicate_set: - # The relation prompt is constructed as follows: - # subject + "的" + predicate - relation_prompt = entity_name + "的" + predicate - relation_prompt_set.add(relation_prompt) - return sorted(list(relation_prompt_set)) - - def generate_cls_example(text, labels, prompt_prefix, options): random.shuffle(options) cls_options = ",".join(options) @@ -342,7 +368,7 @@ def generate_cls_example(text, labels, prompt_prefix, options): result_list = [] example = {"content": text, "result_list": result_list, "prompt": prompt} for label in labels: - start = prompt.rfind(label[0]) - len(prompt) - 1 + start = prompt.rfind(label) - len(prompt) - 1 end = start + len(label) result = {"text": label, "start": start, "end": end} example["result_list"].append(result) @@ -386,21 +412,6 @@ def _sep_cls_label(label, separator): return label_list[0], None return label_list[0], label_list[1:] - def _concat_examples(positive_examples, negative_examples, negative_ratio): - examples = [] - if math.ceil(len(negative_examples) / - len(positive_examples)) <= negative_ratio: - examples = positive_examples + negative_examples - else: - # Random sampling the negative examples to ensure overall negative ratio unchanged. - idxs = random.sample(range(0, len(negative_examples)), - negative_ratio * len(positive_examples)) - negative_examples_sampled = [] - for idx in idxs: - negative_examples_sampled.append(negative_examples[idx]) - examples = positive_examples + negative_examples_sampled - return examples - texts = [] entity_examples = [] relation_examples = [] @@ -411,6 +422,8 @@ def _concat_examples(positive_examples, negative_examples, negative_ratio): entity_name_set = [] predicate_set = [] subject_goldens = [] + inverse_relation_list = [] + predicate_list = [] logger.info(f"Converting doccano data...") with tqdm(total=len(raw_examples)) as pbar: @@ -524,6 +537,8 @@ def _concat_examples(positive_examples, negative_examples, negative_ratio): relation_example = [] relation_prompt = [] relation_example_map = {} + inverse_relation = [] + predicates = [] for relation in relations: predicate = relation["type"] subject_id = relation["from_id"] @@ -538,6 +553,12 @@ def _concat_examples(positive_examples, negative_examples, negative_ratio): "start": entity_map[object_id]["start"], "end": entity_map[object_id]["end"] } + + inverse_negative = entity_map[object_id][ + "name"] + "的" + predicate + inverse_relation.append(inverse_negative) + predicates.append(predicate) + if prompt not in relation_example_map.keys(): relation_example_map[prompt] = { "content": text, @@ -557,35 +578,86 @@ def _concat_examples(positive_examples, negative_examples, negative_ratio): relation_examples.append(relation_example) relation_prompts.append(relation_prompt) subject_goldens.append(subject_golden) + inverse_relation_list.append(inverse_relation) + predicate_list.append(predicates) pbar.update(1) logger.info(f"Adding negative samples for first stage prompt...") - positive_examples, negative_examples = add_negative_example( + positive_examples, negative_examples = add_entity_negative_example( entity_examples, texts, entity_prompts, entity_label_set, negative_ratio) if len(positive_examples) == 0: all_entity_examples = [] - elif is_train: - all_entity_examples = _concat_examples(positive_examples, - negative_examples, - negative_ratio) else: 
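+        # add_entity_negative_example already limits negatives to `negative_ratio`
+        # per positive for each text, so positives and negatives can be concatenated
+        # directly without further sampling.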
all_entity_examples = positive_examples + negative_examples all_relation_examples = [] if len(predicate_set) != 0: + logger.info(f"Adding negative samples for second stage prompt...") if is_train: - logger.info(f"Adding negative samples for second stage prompt...") - relation_prompt_set = construct_relation_prompt_set( - entity_name_set, predicate_set) - positive_examples, negative_examples = add_negative_example( - relation_examples, texts, relation_prompts, relation_prompt_set, - negative_ratio) - all_relation_examples = _concat_examples(positive_examples, - negative_examples, - negative_ratio) + + positive_examples = [] + negative_examples = [] + per_n_ratio = negative_ratio // 3 + + with tqdm(total=len(texts)) as pbar: + for i, text in enumerate(texts): + negative_example = [] + collects = [] + num_positive = len(relation_examples[i]) + + # 1. inverse_relation_list + redundants1 = inverse_relation_list[i] + + # 2. entity_name_set ^ subject_goldens[i] + nonentity_list = list( + set(entity_name_set) ^ set(subject_goldens[i])) + nonentity_list.sort() + + redundants2 = [ + nonentity + "的" + predicate_list[i][random.randrange( + len(predicate_list[i]))] + for nonentity in nonentity_list + ] + + # 3. entity_label_set ^ entity_prompts[i] + non_ent_label_list = list( + set(entity_label_set) ^ set(entity_prompts[i])) + non_ent_label_list.sort() + + redundants3 = [ + subject_goldens[i][random.randrange( + len(subject_goldens[i]))] + "的" + non_ent_label + for non_ent_label in non_ent_label_list + ] + + redundants_list = [redundants1, redundants2, redundants3] + + for redundants in redundants_list: + added, rest = add_relation_negative_example( + redundants, + texts[i], + num_positive, + per_n_ratio, + ) + negative_example.extend(added) + collects.extend(rest) + + num_sup = num_positive * negative_ratio - len( + negative_example) + if num_sup > 0 and collects: + if num_sup > len(collects): + idxs = [k for k in range(len(collects))] + else: + idxs = random.sample(range(0, len(collects)), + num_sup) + for idx in idxs: + negative_example.append(collects[idx]) + + positive_examples.extend(relation_examples[i]) + negative_examples.extend(negative_example) + pbar.update(1) else: - logger.info(f"Adding negative samples for second stage prompt...") relation_examples = add_full_negative_example( relation_examples, texts, relation_prompts, predicate_set, subject_goldens) diff --git a/paddlenlp/layers/__init__.py b/paddlenlp/layers/__init__.py index e52ffdb35031..337ffac62364 100644 --- a/paddlenlp/layers/__init__.py +++ b/paddlenlp/layers/__init__.py @@ -15,3 +15,4 @@ from .sequence import sequence_mask from .tcn import TCN, TemporalBlock from .crf import LinearChainCrf, LinearChainCrfLoss, ViterbiDecoder +from .globalpointer import GlobalPointerForEntityExtraction, GPLinkerForRelationExtraction, GPLinkerForEventExtraction \ No newline at end of file diff --git a/paddlenlp/layers/globalpointer.py b/paddlenlp/layers/globalpointer.py new file mode 100644 index 000000000000..aef86cc42fc8 --- /dev/null +++ b/paddlenlp/layers/globalpointer.py @@ -0,0 +1,178 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import numpy as np +import paddle +import paddle.nn as nn + + +class RotaryPositionEmbedding(nn.Layer): + + def __init__(self, dim, max_seq_len=512): + super().__init__() + inv_freq = 1.0 / (10000 + **(paddle.arange(0, dim, 2, dtype='float32') / dim)) + t = paddle.arange(max_seq_len, dtype=inv_freq.dtype) + freqs = paddle.matmul(t.unsqueeze(1), inv_freq.unsqueeze(0)) + self.register_buffer("sin", freqs.sin(), persistable=False) + self.register_buffer("cos", freqs.cos(), persistable=False) + + def forward(self, x, offset=0): + seqlen = paddle.shape(x)[-2] + sin, cos = ( + self.sin[offset:offset + seqlen, :], + self.cos[offset:offset + seqlen, :], + ) + x1, x2 = x[..., 0::2], x[..., 1::2] + # 奇偶交错 + return paddle.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], + axis=-1).flatten(-2, -1) + + +class GlobalPointer(nn.Layer): + + def __init__(self, + hidden_size, + heads, + head_size=64, + RoPE=True, + tril_mask=True, + max_length=512): + super().__init__() + self.heads = heads + self.head_size = head_size + self.RoPE = RoPE + self.tril_mask = tril_mask + self.dense1 = nn.Linear(hidden_size, head_size * 2) + self.dense2 = nn.Linear(head_size * 2, heads * 2) + if RoPE: + self.rotary = RotaryPositionEmbedding(head_size, max_length) + + def forward(self, inputs, attention_mask=None): + inputs = self.dense1(inputs) + qw, kw = inputs[..., ::2], inputs[..., 1::2] + # RoPE编码 + if self.RoPE: + qw, kw = self.rotary(qw), self.rotary(kw) + + # 计算内积 + logits = paddle.einsum("bmd,bnd->bmn", qw, kw) / self.head_size**0.5 + bias = paddle.transpose(self.dense2(inputs), [0, 2, 1]) / 2 + logits = logits[:, None] + bias[:, ::2, None] + bias[:, 1::2, :, None] + + # 排除padding + attn_mask = ( + 1 - + attention_mask[:, None, None, :] * attention_mask[:, None, :, None]) + logits = logits - attn_mask * 1e12 + + # # 排除下三角 + if self.tril_mask: + mask = paddle.tril(paddle.ones_like(logits), diagonal=-1) + + logits = logits - mask * 1e12 + + return logits + + +class GlobalPointerForEntityExtraction(nn.Layer): + + def __init__(self, encoder, label_maps, head_size=64): + super().__init__() + self.encoder = encoder + hidden_size = encoder.config["hidden_size"] + gpcls = GlobalPointer + self.entity_output = gpcls(hidden_size, + len(label_maps['entity2id']), + head_size=head_size) + + def forward(self, input_ids, attention_mask): + # input_ids, attention_mask, token_type_ids: (batch_size, seq_len) + context_outputs = self.encoder(input_ids, attention_mask=attention_mask) + # last_hidden_state: (batch_size, seq_len, hidden_size) + last_hidden_state = context_outputs[0] + + entity_output = self.entity_output(last_hidden_state, attention_mask) + return [entity_output] + + +class GPLinkerForRelationExtraction(nn.Layer): + + def __init__(self, encoder, label_maps, head_size=64): + super().__init__() + self.encoder = encoder + hidden_size = encoder.config["hidden_size"] + num_ents = len(label_maps['entity2id']) + if 'relation2id' in label_maps.keys(): + num_rels = len(label_maps['relation2id']) + else: + num_rels = len(label_maps['sentiment2id']) + gpcls = GlobalPointer + + self.entity_output = 
gpcls(hidden_size, num_ents, head_size=head_size) + self.head_output = gpcls(hidden_size, + num_rels, + head_size=head_size, + RoPE=False, + tril_mask=False) + self.tail_output = gpcls(hidden_size, + num_rels, + head_size=head_size, + RoPE=False, + tril_mask=False) + + def forward(self, input_ids, attention_mask): + # input_ids, attention_mask, token_type_ids: (batch_size, seq_len) + context_outputs = self.encoder(input_ids, attention_mask=attention_mask) + # last_hidden_state: (batch_size, seq_len, hidden_size) + last_hidden_state = context_outputs[0] + + entity_output = self.entity_output(last_hidden_state, attention_mask) + head_output = self.head_output(last_hidden_state, attention_mask) + tail_output = self.tail_output(last_hidden_state, attention_mask) + spo_output = [entity_output, head_output, tail_output] + return spo_output + + +class GPLinkerForEventExtraction(nn.Layer): + + def __init__(self, encoder, label_maps, head_size=64): + super().__init__() + self.encoder = encoder + hidden_size = encoder.config["hidden_size"] + num_labels = len(label_maps['label2id']) + gpcls = GlobalPointer + + self.argu_output = gpcls(hidden_size, num_labels, head_size=head_size) + self.head_output = gpcls(hidden_size, + 1, + head_size=head_size, + RoPE=False) + self.tail_output = gpcls(hidden_size, + 1, + head_size=head_size, + RoPE=False) + + def forward(self, input_ids, attention_mask): + # input_ids, attention_mask, token_type_ids: (batch_size, seq_len) + context_outputs = self.encoder(input_ids, attention_mask=attention_mask) + # last_hidden_state: (batch_size, seq_len, hidden_size) + last_hidden_state = context_outputs[0] + + argu_output = self.argu_output(last_hidden_state, attention_mask) + head_output = self.head_output(last_hidden_state, attention_mask) + tail_output = self.tail_output(last_hidden_state, attention_mask) + aht_output = (argu_output, head_output, tail_output) + return aht_output diff --git a/paddlenlp/taskflow/information_extraction.py b/paddlenlp/taskflow/information_extraction.py index c504097b7246..ba15bef6254d 100755 --- a/paddlenlp/taskflow/information_extraction.py +++ b/paddlenlp/taskflow/information_extraction.py @@ -14,13 +14,16 @@ # limitations under the License. import re +import os +import json import numpy as np import paddle from ..datasets import load_dataset -from ..transformers import AutoTokenizer +from ..transformers import AutoTokenizer, AutoModel +from ..layers import GlobalPointerForEntityExtraction, GPLinkerForRelationExtraction from .models import UIE from .task import Task -from .utils import SchemaTree, get_span, get_id_and_prob, get_bool_ids_greater_than, dbc2sbc +from .utils import SchemaTree, get_span, get_id_and_prob, get_bool_ids_greater_than, dbc2sbc, gp_decode, DataCollatorGP usage = r""" from paddlenlp import Taskflow @@ -706,3 +709,306 @@ def _postprocess(self, inputs): This function will convert the model output to raw text. """ return inputs['result'] + + +class GPTask(Task): + """ + Global Pointer for closed-domain information extraction Task. + Args: + task(string): The name of task. + model(string): The model name in the task. + kwargs (dict, optional): Additional keyword arguments passed along to the specific task. 
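+        Note: the extraction schema and label maps are not supplied by the
+            caller; they are loaded from the ``model_config.json`` saved with
+            the student model (see ``_load_config``).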
+ """ + resource_files_names = { + "model_state": "model_state.pdparams", + "model_config": "model_config.json", + "vocab_file": "vocab.txt", + "special_tokens_map": "special_tokens_map.json", + "tokenizer_config": "tokenizer_config.json" + } + + def __init__(self, task, model, **kwargs): + super().__init__(task=task, model=model, **kwargs) + self._schema_tree = None + self._load_config() + self._construct_tokenizer() + self._get_inference_model() + self._max_seq_len = self.kwargs[ + 'max_seq_len'] if 'max_seq_len' in self.kwargs else 256 + self._batch_size = self.kwargs[ + 'batch_size'] if 'batch_size' in self.kwargs else 64 + self._lazy_load = self.kwargs[ + 'lazy_load'] if 'lazy_load' in self.kwargs else False + self._num_workers = self.kwargs[ + 'num_workers'] if 'num_workers' in self.kwargs else 0 + + def _load_config(self): + model_config_file = os.path.join( + self._task_path, self.resource_files_names["model_config"]) + with open(model_config_file, encoding="utf-8") as f: + model_config = json.load(f) + self._label_maps = model_config["label_maps"] + self._task_type = model_config["task_type"] + self._encoder = model_config["encoder"] + schema = model_config["label_maps"]["schema"] + self._set_schema(schema) + + def _set_schema(self, schema): + if isinstance(schema, dict) or isinstance(schema, str): + schema = [schema] + self._schema_tree = self._build_tree(schema) + + def _construct_input_spec(self): + """ + Construct the input spec for the convert dygraph model to static model. + """ + self._input_spec = [ + paddle.static.InputSpec(shape=[None, None], + dtype="int64", + name='input_ids'), + paddle.static.InputSpec(shape=[None, None], + dtype="int64", + name='att_mask'), + ] + + def _construct_model(self, model): + """ + Construct the inference model for the predictor. + """ + encoder = AutoModel.from_pretrained(self._encoder) + if self._task_type == "entity_extraction": + model_instance = GlobalPointerForEntityExtraction( + encoder, self._label_maps) + else: + model_instance = GPLinkerForRelationExtraction( + encoder, self._label_maps) + model_path = os.path.join(self._task_path, "model_state.pdparams") + state_dict = paddle.load(model_path) + model_instance.set_dict(state_dict) + self._model = model_instance + self._model.eval() + + def _construct_tokenizer(self): + """ + Construct the tokenizer for the predictor. + """ + self._tokenizer = AutoTokenizer.from_pretrained(self._task_path) + + def _preprocess(self, inputs): + """ + Transform the raw text to the model inputs, two steps involved: + 1) Transform the raw text to token ids. + 2) Generate the other model inputs from the raw text and token ids. 
+ """ + inputs = self._check_input_text(inputs) + + def read(inputs): + for x in inputs: + tokenized_inputs = self._tokenizer( + x, + max_length=self._max_seq_len, + padding=False, + truncation=True, + return_attention_mask=True, + return_offsets_mapping=True, + return_token_type_ids=False, + ) + tokenized_inputs['text'] = x + yield tokenized_inputs + + infer_ds = load_dataset(read, inputs=inputs, lazy=self._lazy_load) + + data_collator = DataCollatorGP(self._tokenizer, + label_maps=self._label_maps, + task_type=self._task_type) + + batch_sampler = paddle.io.BatchSampler(dataset=infer_ds, + batch_size=self._batch_size, + shuffle=False) + + infer_data_loader = paddle.io.DataLoader(dataset=infer_ds, + batch_sampler=batch_sampler, + collate_fn=data_collator, + num_workers=self._num_workers, + return_list=True) + outputs = {} + outputs['data_loader'] = infer_data_loader + outputs['input_texts'] = inputs + return outputs + + def _run_model(self, inputs): + all_preds = ([], []) if self._task_type in [ + "opinion_extraction", "relation_extraction" + ] else [] + for batch in inputs['data_loader']: + input_ids, attention_masks, offset_mappings, texts = batch + self.input_handles[0].copy_from_cpu( + input_ids.numpy().astype('int64')) + self.input_handles[1].copy_from_cpu( + attention_masks.numpy().astype('int64')) + self.predictor.run() + logits = [ + paddle.to_tensor(self.output_handle[i].copy_to_cpu()) + for i in range(len(self.output_handle)) + ] + batch_outputs = gp_decode(logits, offset_mappings, texts, + self._label_maps, self._task_type) + if isinstance(batch_outputs, tuple): + all_preds[0].extend(batch_outputs[0]) # Entity output + all_preds[1].extend(batch_outputs[1]) # Relation output + else: + all_preds.extend(batch_outputs) + inputs['result'] = all_preds + return inputs + + @classmethod + def _build_tree(cls, schema, name='root'): + """ + Build the schema tree. 
+ """ + schema_tree = SchemaTree(name) + for s in schema: + if isinstance(s, str): + schema_tree.add_child(SchemaTree(s)) + elif isinstance(s, dict): + for k, v in s.items(): + if isinstance(v, str): + child = [v] + elif isinstance(v, list): + child = v + else: + raise TypeError( + "Invalid schema, value for each key:value pairs should be list or string" + "but {} received".format(type(v))) + schema_tree.add_child(cls._build_tree(child, name=k)) + else: + raise TypeError( + "Invalid schema, element should be string or dict, " + "but {} received".format(type(s))) + return schema_tree + + def _postprocess(self, inputs): + if self._task_type == "entity_extraction": + results = self._postprocess_entity_extraction(inputs) + elif self._task_type == "opinion_extraction": + results = self._postprocess_opinion_extraction(inputs) + else: + results = self._postprocess_relation_extraction(inputs) + return results + + def _postprocess_opinion_extraction(self, inputs): + all_ent_preds, all_rel_preds = inputs['result'] + results = [] + for i in range(len(inputs['input_texts'])): + result = {} + aspect_maps = {} + for ent in all_ent_preds[i]: + ent_res = { + 'text': ent['text'], + 'start': ent['start_index'], + 'end': ent['start_index'] + len(ent['text']), + 'probability': ent['probability'] + } + result.setdefault(ent['type'], []).append(ent_res) + if ent['type'] == "评价维度": + for r in result["评价维度"]: + if ent['text'] == r['text'] and ent['start_index'] == r[ + 'start']: + aspect_maps[(ent['text'], ent['start_index'])] = r + break + + for rel in all_rel_preds[i]: + r = aspect_maps[(rel['aspect'], rel['aspect_start_index'])] + r['relations'] = {} + sentiment = { + 'probability': rel['probability'], + 'text': rel['sentiment'] + } + opinion = { + 'text': rel['opinion'], + 'start': rel['opinion_start_index'], + 'end': rel['opinion_start_index'] + len(rel['opinion']), + 'probability': rel['probability'] + } + r['relations'].setdefault('情感倾向[正向,负向]', []).append(sentiment) + r['relations'].setdefault('观点词', []).append(opinion) + results.append(result) + return results + + def _postprocess_relation_extraction(self, inputs): + all_ent_preds, all_rel_preds = inputs['result'] + results = [] + for input_text_idx in range(len(inputs['input_texts'])): + result = {} + schema_list = self._schema_tree.children[:] + while len(schema_list) > 0: + node = schema_list.pop(0) + if node.parent_relations is None: + prefix = [] + relations = [[]] + cnt = -1 + for ent in all_ent_preds[input_text_idx]: + if node.name == ent['type']: + ent_res = { + 'text': ent['text'], + 'start': ent['start_index'], + 'end': ent['start_index'] + len(ent['text']), + 'probability': ent['probability'] + } + result.setdefault(node.name, []).append(ent_res) + cnt += 1 + result[node.name][cnt]['relations'] = {} + relations[0].append(result[node.name][cnt]) + else: + relations = [[] for _ in range(len(node.parent_relations))] + for i, rs in enumerate(node.parent_relations): + for r in rs: + cnt = -1 + for rel in all_rel_preds[input_text_idx]: + if r['text'] == rel['subject'] and r['start'] == rel[ + 'subject_start_index'] and node.name == rel[ + 'predicate']: + rel_res = { + 'text': + rel['object'], + 'start': + rel['object_start_index'], + 'end': + rel['object_start_index'] + + len(rel['object']), + 'probability': + rel['probability'] + } + r['relations'].setdefault( + node.name, []).append(rel_res) + cnt += 1 + r['relations'][ + node.name][cnt]['relations'] = {} + relations[i].append( + r['relations'][node.name][cnt]) + for child in node.children: + 
child.prefix = prefix + child.parent_relations = relations + schema_list.append(child) + results.append(result) + return results + + def _postprocess_entity_extraction(self, inputs): + all_preds = inputs['result'] + results = [] + for input_text_idx in range(len(inputs['input_texts'])): + result = {} + schema_list = self._schema_tree.children[:] + while len(schema_list) > 0: + node = schema_list.pop(0) + for ent in all_preds[input_text_idx]: + if node.name == ent['type']: + ent_res = { + 'text': ent['text'], + 'start': ent['start_index'], + 'end': ent['start_index'] + len(ent['text']), + 'probability': ent['probability'] + } + result.setdefault(node.name, []).append(ent_res) + results.append(result) + return results diff --git a/paddlenlp/taskflow/taskflow.py b/paddlenlp/taskflow/taskflow.py index 83270831ab45..e57ce9107f29 100644 --- a/paddlenlp/taskflow/taskflow.py +++ b/paddlenlp/taskflow/taskflow.py @@ -35,7 +35,7 @@ from .text_correction import CSCTask from .text_similarity import TextSimilarityTask from .dialogue import DialogueTask -from .information_extraction import UIETask +from .information_extraction import UIETask, GPTask from .code_generation import CodeGenerationTask from .text2image_generation import Text2ImageGenerationTask @@ -263,6 +263,10 @@ "hidden_size": 768, "task_flag": "information_extraction-uie-base-en" }, + "uie-data-distill-gp": { + "task_class": GPTask, + "task_flag": "information_extraction-uie-data-distill-gp" + } }, "default": { "model": "uie-base" diff --git a/paddlenlp/taskflow/utils.py b/paddlenlp/taskflow/utils.py index f189e76d04e3..8ca4623e4645 100644 --- a/paddlenlp/taskflow/utils.py +++ b/paddlenlp/taskflow/utils.py @@ -21,13 +21,16 @@ import pickle import warnings import contextlib +from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import paddle +import paddle.nn.functional as F from paddle.dataset.common import md5file from ..utils.log import logger from ..utils.downloader import get_path_from_url, DownloaderCheck +from ..transformers.tokenizer_utils_base import PretrainedTokenizerBase, PaddingStrategy DOC_FORMAT = r""" Examples: @@ -770,6 +773,7 @@ def __init__(self, name='root', children=None): self.children = [] self.prefix = None self.parent_relations = None + self.parent = None if children is not None: for child in children: self.add_child(child) @@ -1398,3 +1402,137 @@ def extract_spo(self, all_items): res_cand.append(tmp.copy()) continue return res_cand + + +@dataclass +class DataCollatorGP: + tokenizer: PretrainedTokenizerBase + padding: Union[bool, str, PaddingStrategy] = True + max_length: Optional[int] = None + label_maps: Optional[dict] = None + task_type: Optional[str] = None + + def __call__( + self, features: List[Dict[str, Union[List[int], paddle.Tensor]]] + ) -> Dict[str, paddle.Tensor]: + new_features = [{ + k: v + for k, v in f.items() if k not in ["offset_mapping", "text"] + } for f in features] + + batch = self.tokenizer.pad( + new_features, + padding=self.padding, + ) + + batch = [paddle.to_tensor(batch[k]) for k in batch.keys()] + batch.append([feature["offset_mapping"] for feature in features]) + batch.append([feature["text"] for feature in features]) + return batch + + +def gp_decode(batch_outputs, + offset_mappings, + texts, + label_maps, + task_type="relation_extraction"): + if task_type == "entity_extraction": + batch_ent_results = [] + for entity_output, offset_mapping, text in zip(batch_outputs[0].numpy(), + offset_mappings, texts): + entity_output[:, [0, 
-1]] -= np.inf + entity_output[:, :, [0, -1]] -= np.inf + entity_probs = F.softmax(paddle.to_tensor(entity_output), + axis=1).numpy() + ent_list = [] + for l, start, end in zip(*np.where(entity_output > 0.)): + ent_prob = entity_probs[l, start, end] + start, end = (offset_mapping[start][0], offset_mapping[end][-1]) + ent = { + "text": text[start:end], + "type": label_maps['id2entity'][str(l)], + "start_index": start, + "probability": ent_prob + } + ent_list.append(ent) + batch_ent_results.append(ent_list) + return batch_ent_results + else: + batch_ent_results = [] + batch_rel_results = [] + for entity_output, head_output, tail_output, offset_mapping, text in zip( + batch_outputs[0].numpy(), + batch_outputs[1].numpy(), + batch_outputs[2].numpy(), + offset_mappings, + texts, + ): + entity_output[:, [0, -1]] -= np.inf + entity_output[:, :, [0, -1]] -= np.inf + entity_probs = F.softmax(paddle.to_tensor(entity_output), + axis=1).numpy() + head_probs = F.softmax(paddle.to_tensor(head_output), + axis=1).numpy() + tail_probs = F.softmax(paddle.to_tensor(tail_output), + axis=1).numpy() + + ents = set() + ent_list = [] + for l, start, end in zip(*np.where(entity_output > 0.)): + ent_prob = entity_probs[l, start, end] + ents.add((start, end)) + start, end = (offset_mapping[start][0], offset_mapping[end][-1]) + ent = { + "text": text[start:end], + "type": label_maps['id2entity'][str(l)], + "start_index": start, + "probability": ent_prob + } + ent_list.append(ent) + batch_ent_results.append(ent_list) + + rel_list = [] + for sh, st in ents: + for oh, ot in ents: + p1s = np.where(head_output[:, sh, oh] > 0.)[0] + p2s = np.where(tail_output[:, st, ot] > 0.)[0] + ps = set(p1s) & set(p2s) + for p in ps: + rel_prob = head_probs[p, sh, oh] * tail_probs[p, st, ot] + if task_type == "relation_extraction": + rel = { + "subject": + text[offset_mapping[sh][0]:offset_mapping[st] + [1]], + "predicate": + label_maps['id2relation'][str(p)], + "object": + text[offset_mapping[oh][0]:offset_mapping[ot] + [1]], + "subject_start_index": + offset_mapping[sh][0], + "object_start_index": + offset_mapping[oh][0], + "probability": + rel_prob, + } + else: + rel = { + "aspect": + text[offset_mapping[sh][0]:offset_mapping[st] + [1]], + "sentiment": + label_maps['id2relation'][str(p)], + "opinion": + text[offset_mapping[oh][0]:offset_mapping[ot] + [1]], + "aspect_start_index": + offset_mapping[sh][0], + "opinion_start_index": + offset_mapping[oh][0], + "probability": + rel_prob, + } + rel_list.append(rel) + batch_rel_results.append(rel_list) + return (batch_ent_results, batch_rel_results)
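For reference, the pieces added above can also be exercised outside of Taskflow with a minimal dygraph sketch like the one below. The label maps and the input sentence are illustrative placeholders (real label maps come from the student model's `label_maps.json`, and a trained checkpoint would normally be restored with `set_dict` before decoding), so the printed results only demonstrate the expected output structure, not meaningful predictions.

```python
import paddle
from paddlenlp.transformers import AutoModel, AutoTokenizer
from paddlenlp.layers import GPLinkerForRelationExtraction
from paddlenlp.taskflow.utils import gp_decode

# Illustrative label maps; real ones are produced by data_distill.py and saved
# as label_maps.json next to the student checkpoint, and must match it.
label_maps = {
    "entity2id": {"疾病": 0, "检查": 1},
    "id2entity": {"0": "疾病", "1": "检查"},
    "relation2id": {"实验室检查": 0},
    "id2relation": {"0": "实验室检查"},
}

encoder = AutoModel.from_pretrained("ernie-3.0-mini-zh")
tokenizer = AutoTokenizer.from_pretrained("ernie-3.0-mini-zh")
model = GPLinkerForRelationExtraction(encoder, label_maps)
# A trained student model would normally be restored here, e.g.:
# model.set_dict(paddle.load("checkpoint/model_best/model_state.pdparams"))
model.eval()

text = "患者出现持续发热,血常规检查显示白细胞计数升高。"
encoded = tokenizer(text,
                    max_length=256,
                    truncation=True,
                    return_attention_mask=True,
                    return_offsets_mapping=True,
                    return_token_type_ids=False)
# int64 inputs, matching the input spec GPTask declares for its inference graph
input_ids = paddle.to_tensor([encoded["input_ids"]], dtype="int64")
attention_mask = paddle.to_tensor([encoded["attention_mask"]], dtype="int64")

with paddle.no_grad():
    # A list of [entity_logits, head_logits, tail_logits],
    # each of shape (batch_size, num_labels, seq_len, seq_len)
    logits = model(input_ids, attention_mask)

ent_preds, rel_preds = gp_decode(logits,
                                 [encoded["offset_mapping"]],
                                 [text],
                                 label_maps,
                                 task_type="relation_extraction")
print(ent_preds)  # entity mentions per input text
print(rel_preds)  # (subject, predicate, object) triples per input text
```

`GPTask` wraps essentially this flow behind a static-graph predictor, with `DataCollatorGP` handling batching and the schema tree shaping the decoded results.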