
[prompt] add prompt and text classification examples #2894

Merged: 55 commits, Aug 16, 2022

Conversation

LemonNoel (Contributor) commented Jul 27, 2022

PR types

New features

PR changes

APIs

Description

  • Support PET, P-tuning, WARP and RGL.
  • Add a multi-class classification example based on WARP.

[Screenshot of results: 截屏2022-08-10 21 06 56]

How to use

  • Enter the example directory
cd application/text_classification/multi_class/few_shot
  • Download the dataset
wget https://paddlenlp.bj.bcebos.com/datasets/KUAKE_QIC.tar.gz
tar -zxvf KUAKE_QIC.tar.gz
mv KUAKE_QIC data
  • Run the multi-class example
python train.py --output_dir ./ckpt/ --prompt "问的是" --max_seq_length 128  --learning_rate 6e-5 --num_train_epochs 100 --do_eval --ppt_learning_rate 3e-4 --data_dir ../data/ --do_train --logging_steps 10 --eval_steps 100 --max_steps 1000 --per_device_eval_batch_size 32 --train_sample_per_label 16
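
After these steps the example expects the data directory to look roughly like the layout below (a sketch pieced together from the directory tree discussed later in this thread; the exact files shipped with KUAKE_QIC may differ):

```
data/
├── train.txt   # training set
├── dev.txt     # validation set
├── test.txt    # test set (optional)
├── data.txt    # data to predict (optional)
└── label.txt   # classification label set
```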

@LemonNoel LemonNoel requested review from ZHUI and wawltor July 27, 2022 08:02
@LemonNoel LemonNoel self-assigned this Jul 27, 2022
Comment on lines +83 to +94
ppt_weight_decay: float = field(
    default=0.0,
    metadata={"help": "Weight decay for the AdamW optimizer of prompt."})
ppt_adam_beta1: float = field(
    default=0.9,
    metadata={"help": "Beta1 for the AdamW optimizer of prompt."})
ppt_adam_beta2: float = field(
    default=0.999,
    metadata={"help": "Beta2 for the AdamW optimizer of prompt."})
ppt_adam_epsilon: float = field(
    default=1e-8,
    metadata={"help": "Epsilon for the AdamW optimizer of prompt."})
Collaborator

Can all of these optimizer arguments be reconfigured? Do they take effect together with the original trainer arguments, or do only the ppt_* arguments apply?

Contributor Author

All of these can be set per parameter group in the AdamW code. prompt_trainer.py re-implements create_optimizer: the ppt_* arguments only apply to the prompt-related parameters, while the original trainer arguments only apply to the PLM.
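
A minimal sketch of the per-parameter-group setup described above, assuming the per-group learning_rate acts as a scale factor on the base learning rate (as the ppt_learning_rate / learning_rate ratio later in this thread suggests); the modules are stand-ins, not the actual prompt_trainer.py code:

```
import paddle

# Stand-ins for the pretrained LM and the prompt-specific parameters.
plm = paddle.nn.Linear(8, 8)
soft_prompt = paddle.nn.Embedding(4, 8)

optimizer = paddle.optimizer.AdamW(
    learning_rate=6e-5,   # trainer arguments apply to the PLM group
    weight_decay=0.01,
    parameters=[
        {"params": plm.parameters()},         # PLM group: uses the defaults above
        {"params": soft_prompt.parameters(),  # prompt group: overridden by ppt_* values
         "learning_rate": 3e-4 / 6e-5,        # ppt_learning_rate relative to the base rate
         "weight_decay": 0.0},                # ppt_weight_decay
    ])
```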

Comment on lines 82 to 89
                 task_type_ids=None,
                 inputs_embeds=None):
        if input_ids is not None:
            input_shape = input_ids.shape
            input_embeddings = self.word_embeddings(input_ids)
        else:
            input_shape = inputs_embeds.shape[:-1]
            input_embeddings = inputs_embeds
Collaborator

This is so that input_embeddings can be passed in directly. In that case, do all models that are supposed to support prompt need to be adapted?

Contributor Author

The soft template requires the PLM to support passing inputs_embeds directly. Adapting the models across the board is already on the roadmap for the core capability work, see #2356.
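
A rough sketch of why inputs_embeds is needed: soft prompt tokens have no vocabulary ids, so their embeddings must be spliced in directly (toy shapes, not the actual SoftTemplate implementation):

```
import paddle

word_embeddings = paddle.nn.Embedding(100, 16)             # stand-in for the PLM word embeddings
soft_prompt = paddle.create_parameter([4, 16], "float32")  # 4 learnable soft prompt vectors

input_ids = paddle.to_tensor([[5, 7, 9]])         # a toy tokenized input
token_embeds = word_embeddings(input_ids)         # [1, 3, 16]
prompt_embeds = paddle.unsqueeze(soft_prompt, 0)  # [1, 4, 16]

inputs_embeds = paddle.concat([prompt_embeds, token_embeds], axis=1)  # [1, 7, 16]
# model(inputs_embeds=inputs_embeds, ...)  # only works once the PLM accepts inputs_embeds
```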

Collaborator

@guoshengCS do we currently have a concrete timeline for this, i.e. supporting inputs_embeds?

Contributor

Support for BERT/ERNIE/RoBERTa will be added within a week as reference implementations; broader model support will come around the middle of the month.

        else:
            raise TypeError('InputExample')

    def _convert_example(self, example):
Contributor Author

The verbalizer's wrap_one_example method from the original implementation has been removed. Its label-mapping functionality now lives in the trainer's convert_example function, which defaults to a direct mapping for multi-class tasks; a custom function can be passed in via the convert_fn argument.

With the current implementation this logic could also be split out of the trainer and moved into load_local_dataset.
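
For illustration, a user-supplied convert_fn under this design might look like the sketch below (the example structure and hook signature are assumptions, not the actual prompt_trainer.py interface):

```
# Hypothetical label map for a multi-class dataset.
label_to_id = {"病情诊断": 0, "治疗方案": 1, "其他": 2}

def convert_fn(example):
    # Map the raw label string onto the class id expected by the verbalizer.
    example["labels"] = label_to_id[example["label"]]
    return example
```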

Collaborator

The _map_dataset above could probably be split out as well, right?

Contributor Author

_map_dataset is there to hide the preprocessing done by the template and the tokenizer wrapper; keeping it outside would mean writing the same code over and over. If it is split out, where would be a better place for it?

Comment on lines 23 to 25
from paddlenlp.prompt import (AutoTemplate, SoftVerbalizer, MLMTokenizerWrapper,
                              PromptTuningArguments, PromptTrainer,
                              PromptModelForClassification, FewShotSampler)
Collaborator

Suggested change
from paddlenlp.prompt import (AutoTemplate, SoftVerbalizer, MLMTokenizerWrapper,
                              PromptTuningArguments, PromptTrainer,
                              PromptModelForClassification, FewShotSampler)
from paddlenlp.prompt import (AutoTemplate, SoftVerbalizer, MLMTokenizerWrapper,
                              PromptTuningArguments, PromptTrainer,
                              PromptModelForClassification, FewShotSampler, )

Add a trailing comma at the end so that yapf can wrap the line automatically.

Contributor Author

Trailing commas have been added for line wrapping everywhere the code is used this way.

    lr_scheduler = get_scheduler(training_args.lr_scheduler_type,
                                 training_args.ppt_learning_rate,
                                 num_warmup_steps, num_training_steps)
    optimizer = paddle.optimizer.AdamW(
Collaborator

Please add a comment explaining this.

Collaborator

Is prompt_model.verbalizer.non_head_parameters fixed across different models?
Have you considered encapsulating this in PromptTrainer?

        parameters=[{
            'params': prompt_model.verbalizer.non_head_parameters()
        }, {
            'params':
            prompt_model.verbalizer.head_parameters(),
            'learning_rate':
            training_args.ppt_learning_rate / training_args.learning_rate
        }])

Contributor Author

It is not fixed. Only SoftVerbalizer has the non_head_parameters and head_parameters attributes, so this was not encapsulated in PromptTrainer. Once we have verified whether updating non_head_parameters and head_parameters separately is actually necessary, we will decide whether to keep this code.

Contributor Author

This has now been encapsulated in PromptTrainer's create_optimizer method, which checks the type of self.verbalizer to decide whether to use non_head_parameters and head_parameters and update them with different learning rates.
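
A standalone sketch of that branching (illustrative only; the real create_optimizer also handles template parameters, weight decay and so on):

```
from paddlenlp.prompt import SoftVerbalizer

def build_param_groups(verbalizer, learning_rate, ppt_learning_rate):
    # Only SoftVerbalizer exposes head/non-head parameters; otherwise fall back
    # to a single group updated with the default settings.
    if isinstance(verbalizer, SoftVerbalizer):
        return [
            {"params": verbalizer.non_head_parameters()},
            {"params": verbalizer.head_parameters(),
             "learning_rate": ppt_learning_rate / learning_rate},
        ]
    return [{"params": verbalizer.parameters()}]
```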

Comment on lines 110 to 122
    if training_args.max_steps > 0:
        num_training_steps = training_args.max_steps
    else:
        _train_batch_size = training_args.per_device_train_batch_size
        _num_train_epochs = training_args.num_train_epochs
        num_update_per_epoch = len(train_ds) // _train_batch_size
        num_update_per_epoch //= training_args.gradient_accumulation_steps
        num_update_per_epoch = max(num_update_per_epoch, 1)
        num_training_steps = num_update_per_epoch * _num_train_epochs
    if training_args.warmup_steps > 0:
        num_warmup_steps = training_args.warmup_steps
    else:
        num_warmup_steps = int(training_args.warmup_ratio * num_training_steps)
Collaborator

Could this logic be wrapped into a helper?

Contributor Author

This logic is now encapsulated in the init_num_steps function (a staticmethod) in trainer_base.py.
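
A standalone sketch of what that helper computes, mirroring the snippet quoted above (the actual init_num_steps in trainer_base.py may differ in signature and details):

```
def init_num_steps(args, num_samples):
    # Total number of optimizer updates.
    if args.max_steps > 0:
        num_training_steps = args.max_steps
    else:
        updates_per_epoch = num_samples // args.per_device_train_batch_size
        updates_per_epoch //= args.gradient_accumulation_steps
        updates_per_epoch = max(updates_per_epoch, 1)
        num_training_steps = updates_per_epoch * args.num_train_epochs
    # Warmup steps, either absolute or as a ratio of the total.
    if args.warmup_steps > 0:
        num_warmup_steps = args.warmup_steps
    else:
        num_warmup_steps = int(args.warmup_ratio * num_training_steps)
    return num_training_steps, num_warmup_steps
```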

]
export_path = os.path.join(training_args.output_dir, 'export')
os.makedirs(export_path, exist_ok=True)
export_model(prompt_model, input_spec, export_path,
Collaborator

Please verify the accuracy of the exported model.

Contributor Author (LemonNoel, Aug 11, 2022)

The accuracy of the exported model has been verified.

Comment on lines 274 to 275
bce_criterion = nn.CrossEntropyLoss()
cos_criterion = nn.CosineSimilarity(axis=0, eps=1e-6)
Collaborator

Could these be switched to the functional equivalents?

Contributor Author

Changed to the functional APIs.
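
For reference, the functional counterparts of the two layer classes quoted above look roughly like this (a minimal sketch, not the actual example code):

```
import paddle
import paddle.nn.functional as F

logits = paddle.randn([4, 3])
labels = paddle.to_tensor([0, 2, 1, 0])
loss = F.cross_entropy(logits, labels)             # replaces nn.CrossEntropyLoss()

a, b = paddle.randn([8]), paddle.randn([8])
sim = F.cosine_similarity(a, b, axis=0, eps=1e-6)  # replaces nn.CosineSimilarity(axis=0, eps=1e-6)
```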

@LemonNoel LemonNoel marked this pull request as ready for review August 11, 2022 05:08
args.use_fp16, args.num_threads)
self._template = AutoTemplate.load_from(
os.path.dirname(args.model_path_prefix), self._tokenizer)
self._wrapper = MLMTokenizerWrapper(self._max_seq_length,
Collaborator

There is no MLMTokenizerWrapper concept during training, so why does it appear at inference time?

Contributor Author

The TokenizerWrapper functionality is now encapsulated inside Template; the inference code was updated accordingly in the latest commit, so this concept is no longer exposed.

criterion = paddle.nn.CrossEntropyLoss()

# Initialize the prompt model with the above variables.
prompt_model = PromptModelForClassification(
Collaborator

PromptModelForClassification -> PromptModelForSequenceClassification

Contributor Author

The class has been renamed to PromptModelForSequenceClassification.

├── dev.txt # validation set
├── test.txt # test set (optional)
├── data.txt # data to be predicted (optional)
└── label.txt # classification label set
Collaborator

How exactly are the labels here constructed?

Contributor Author

The labels need to be constructed so that, combined with the prompt, they form as fluent a sentence as possible. In our tests, training directly on the original labels costs 1 to 17 accuracy points on the current example datasets; the 17-point drop is on the hierarchical classification dataset, where the special ## symbol in the labels makes the degradation especially pronounced. The rules for constructing labels have been added to the README.

The main idea of **Prompt Learning** is to reformulate the downstream task so that it is as close as possible to the pretraining task, making full use of the features the pretrained language model has learned and thereby reducing the amount of labeled data required. In addition, we usually attach a "prompt" to the original input text to steer the pretrained model toward the desired output.

Take Ernie as an example and recall how this kind of pretrained language model is trained. Similar to a cloze test, given a piece of text with some characters masked out, the language model is asked to predict the original characters at the masked positions. We therefore also cast the hierarchical classification task into a cloze-like form. Take movie review sentiment classification with the two labels `1-正向` (positive) and `0-负向` (negative): in classic fine-tuning, the parameters to learn are a randomly initialized classifier that takes the `[CLS]` vector as input and outputs negative/positive. In prompt learning, we instead construct a prompt that turns the original classification task into a cloze task. As shown in the figure below, with the prompt `我[MASK]喜欢。`, the original `1-正向` and `0-负向` labels become a prediction of whether the blank is `很` or `不`. The classifier is then no longer randomly initialized; it is initialized from the pretrained vectors of these two characters, making full use of the parameters the pretrained model has learned.
Collaborator

Suggested change
Take Ernie as an example and recall how this kind of pretrained language model is trained.
Similar to a cloze test, given a piece of text with some characters masked out, the language model is asked to predict the original characters at the masked positions.
We therefore also cast the hierarchical classification task into a cloze-like form. Take movie review sentiment classification with the two labels `1-正向` (positive) and `0-负向` (negative).
In classic fine-tuning, the parameters to learn are a randomly initialized classifier that takes the `[CLS]` vector as input and outputs negative/positive. In prompt learning, we instead construct a prompt that turns the original classification task into a cloze task.
As shown in the figure below, with the prompt `我[MASK]喜欢。`, the original `1-正向` and `0-负向` labels become a prediction of whether the blank is `很` or `不`. The classifier is then no longer randomly initialized; it is initialized from the pretrained vectors of these two characters, making full use of the parameters the pretrained model has learned.

Collaborator

This is a bit long and tiring to read; please split it into paragraphs.

Contributor Author

Split into three paragraphs.


In the train/dev/test dataset files, each line is one sample consisting of the text and its labels, separated by a tab character `\t`; multiple labels are separated by `,`. For example:
```
紫光圣果副总经理李明雷辞职 组织关系,组织关系##辞/离职
Collaborator

Different hierarchy levels of a category are joined with the ## characters; please mention that.

Contributor Author

The corresponding description has been added.

@dataclass
class ModelArguments:
    model_name_or_path: str = field(default="ernie-3.0-base-zh", metadata={"help": "Build-in pretrained model name or the path to local model."})
    export_type: str = field(default='paddle', metadata={"help": "The type to export. Support `paddle` and `onnx`."})
Collaborator

The default is to export in paddle format? I see infer.py uses onnx?

Contributor Author

Serving deployment will be added later and requires the static-graph format, so the default here is paddle: we always export a static graph, and convert the format only when deploying with onnx.

Comment on lines 28 to 30
PromptModelForSequenceClassification,
)
from utils import load_local_dataset
Collaborator

Suggested change
PromptModelForSequenceClassification,
)
from utils import load_local_dataset
PromptModelForSequenceClassification,
)

from utils import load_local_dataset

Contributor Author

Blank lines have been added between imports of different types.


### Model Deployment

#### CPU Inference Example
Collaborator

If onnx is the default here, we also need to add installation docs for the onnx runtime and the like.

Contributor Author

Instructions for installing the onnx runtime dependencies have been added.

The classification label file stores the set of all labels in the dataset, one label name per line. If a custom label-word mapping is needed to initialize the classifier, each line should contain the label name and the corresponding mapping word, separated by `==`. For example:
```
news_car==汽车
news_culture==文化
Collaborator

You could show the tnews example here, collapsed by default; a purely textual description is a bit abstract.

Contributor Author

The example here is already taken from tnews. Do you mean putting the full label set up here?

        return [p for p in self.template.parameters()
                ] + [p for p in self.verbalizer.parameters()]

    def get_input_spec(self):
Collaborator

For prompt tasks, can the InputSpec here basically be fixed?

Contributor Author (LemonNoel, Aug 15, 2022)

get_input_spec here defaults to the InputSpec for SoftTemplate; when using MannualTemplate it has to be passed in manually.
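
A sketch of what a manually supplied spec could look like in that case (field names and dtypes are illustrative assumptions, not the exact spec SoftTemplate generates):

```
from paddle.static import InputSpec

input_spec = [
    InputSpec(shape=[None, None], dtype="int64", name="input_ids"),
    InputSpec(shape=[None, None], dtype="int64", name="token_type_ids"),
    InputSpec(shape=[None, None], dtype="int64", name="mask_ids"),
]
# export_model(prompt_model, input_spec, export_path, ...)  # as in train.py above
```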

Comment on lines 21 to 22
import paddle
from ..utils.log import logger
Collaborator

Suggested change
import paddle
from ..utils.log import logger
import paddle

from ..utils.log import logger

Contributor Author

Same as above.

training_args.print_config(model_args, "Model")
training_args.print_config(data_args, "Data")

paddle.set_device(training_args.device)
Collaborator

Suggest fixing the random seed.

Collaborator

The default trainer should already set the seed.

Contributor Author

PromptTrainer inherits from Trainer and fixes the seed during initialization.

trainer.log_metrics("test", test_ret.metrics)

# Export static model.
if training_args.do_export:
Collaborator

For the files that finally get dumped, rename ckpt -> checkpoints.

Contributor Author

The default value of output_dir has been changed to checkpoints.

ZHUI (Collaborator) left a comment

LGTM

@LemonNoel LemonNoel merged commit a15da5f into PaddlePaddle:develop Aug 16, 2022
@LemonNoel LemonNoel deleted the ppt branch August 16, 2022 09:05