-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[prompt] add prompt and text classification examples #2894
Conversation
ppt_weight_decay: float = field( | ||
default=0.0, | ||
metadata={"help": "Weight decay for the AdamW optimizer of prompt."}) | ||
ppt_adam_beta1: float = field( | ||
default=0.9, | ||
metadata={"help": "Beta1 for the AdamW optimizer of prompt."}) | ||
ppt_adam_beta2: float = field( | ||
default=0.999, | ||
metadata={"help": "Beta2 for the AdamW optimizer of prompt."}) | ||
ppt_adam_epsilon: float = field( | ||
default=1e-8, | ||
metadata={"help": "Epsilon for the AdamW optimizer of prompt."}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这些 optimizer 的参数都可以重新设置吗?
跟原始 trainer 的参数,可以都起作用,还是只有 ppt 参数起作用?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
看AdamW代码这些参数都可以按parameters group来设置。prompt_trainer.py
里边重新实现了create_optimizer
,ppt_xxx
只对prompt相关的参数起作用,原始trainer的参数只对PLM起作用。
task_type_ids=None, | ||
inputs_embeds=None): | ||
if input_ids is not None: | ||
input_shape = input_ids.shape | ||
input_embeddings = self.word_embeddings(input_ids) | ||
else: | ||
input_shape = inputs_embeds.shape[:-1] | ||
input_embeddings = inputs_embeds |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
此处是为了直接传入input_embeddings
, 这样的话,是不是所有支持 prompt的模型都需要适配?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
soft template需要PLM支持直接传入inputs_embeds
。整体模型的适配我看已经在基础能力建设的排期中了 #2356
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@guoshengCS 咱们这块目前有具体的时间节点吗?支持 inputs_embeds
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
对于 BERT/ERNIE/RoBERTa 的会在在一周内支持作为样板参考,更多支持要在月中了
else: | ||
raise TypeError('InputExample') | ||
|
||
def _convert_example(self, example): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
原实现中verbalizer的wrap_one_example
方法已删除。该方法对应的标签映射功能放在了trainer的convert_example
函数中,默认为多分类任务的直接映射,可通过convert_fn
参数传递自定义函数。
按照目前的实现这部分逻辑也可以从trainer剥离出去,放在load_local_dataset
中。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
上面的 _map_dataset
应该也可以剥离吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_map_dataset这部分是为了把template和tokenizer wrapper的预处理隐藏起来,放在外边每次都要写重复的代码。或者这个剥离出来放在哪里比较合适呢
from paddlenlp.prompt import (AutoTemplate, SoftVerbalizer, MLMTokenizerWrapper, | ||
PromptTuningArguments, PromptTrainer, | ||
PromptModelForClassification, FewShotSampler) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from paddlenlp.prompt import (AutoTemplate, SoftVerbalizer, MLMTokenizerWrapper, | |
PromptTuningArguments, PromptTrainer, | |
PromptModelForClassification, FewShotSampler) | |
from paddlenlp.prompt import (AutoTemplate, SoftVerbalizer, MLMTokenizerWrapper, | |
PromptTuningArguments, PromptTrainer, | |
PromptModelForClassification, FewShotSampler, ) |
最后加一个逗号,yapf 可自动换行
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在类似用法的代码里都增加了逗号来换行。
lr_scheduler = get_scheduler(training_args.lr_scheduler_type, | ||
training_args.ppt_learning_rate, | ||
num_warmup_steps, num_training_steps) | ||
optimizer = paddle.optimizer.AdamW( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
添加注释说明
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这块 prompt_model.verbalizer.non_head_parameters
在不同模型下都是固定的吗?
是否考虑封装到 PromptTrainer
parameters=[{
'params': prompt_model.verbalizer.non_head_parameters()
}, {
'params':
prompt_model.verbalizer.head_parameters(),
'learning_rate':
training_args.ppt_learning_rate / training_args.learning_rate
}])
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不是固定的。这里只有SoftVerbalizer才有non_head_parameters
和head_parameters
这些属性,因此没有封装到PromptTrainer。后续验证non_head_parameters
和head_parameters
分开更新是否有必要后,再决定这段代码是否保留。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这部分已经封装在PromptTrainer的create_optimizer
方法中,通过对self.verbalizer
的类型判断来区分是否使用non_head_parameters
和head_parameters
按不同学习率更新参数。
if training_args.max_steps > 0: | ||
num_training_steps = training_args.max_steps | ||
else: | ||
_train_batch_size = training_args.per_device_train_batch_size | ||
_num_train_epochs = training_args.num_train_epochs | ||
num_update_per_epoch = len(train_ds) // _train_batch_size | ||
num_update_per_epoch //= training_args.gradient_accumulation_steps | ||
num_update_per_epoch = max(num_update_per_epoch, 1) | ||
num_training_steps = num_update_per_epoch * _num_train_epochs | ||
if training_args.warmup_steps > 0: | ||
num_warmup_steps = training_args.warmup_steps | ||
else: | ||
num_warmup_steps = int(training_args.warmup_ratio * num_training_steps) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这块的逻辑封装一下?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这部分逻辑封装在了trainer_base.py中的init_num_steps
函数(staticmethod)
] | ||
export_path = os.path.join(training_args.output_dir, 'export') | ||
os.makedirs(export_path, exist_ok=True) | ||
export_model(prompt_model, input_spec, export_path, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
验证导出模型的效果
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已验证导出模型的效果。
else: | ||
raise TypeError('InputExample') | ||
|
||
def _convert_example(self, example): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
上面的 _map_dataset
应该也可以剥离吧
paddlenlp/prompt/prompt_trainer.py
Outdated
bce_criterion = nn.CrossEntropyLoss() | ||
cos_criterion = nn.CosineSimilarity(axis=0, eps=1e-6) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里改用 functional 的函数?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已改为functional函数。
args.use_fp16, args.num_threads) | ||
self._template = AutoTemplate.load_from( | ||
os.path.dirname(args.model_path_prefix), self._tokenizer) | ||
self._wrapper = MLMTokenizerWrapper(self._max_seq_length, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在训练的时候并没有 MLMTokenizerWrapper 概念,为什么在infer的有这个概念了?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
目前TokenizerWrapper
的功能已封装在Template
中,infer的部分已在最新commit中同步修改,不对外暴露这个概念。
criterion = paddle.nn.CrossEntropyLoss() | ||
|
||
# Initialize the prompt model with the above variables. | ||
prompt_model = PromptModelForClassification( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PromptModelForClassification -> PromptModelForSequenceClassification
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改类名为PromptModelForSequenceClassification
。
├── dev.txt # 验证数据集 | ||
├── test.txt # 测试数据集(可选) | ||
├── data.txt # 待预测数据(可选) | ||
└── label.txt # 分类标签集 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的labels具体怎么构造的了?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这部分labels构造需要尽可能与提示(prompt)组成通顺的句子。经测试,当前示例数据集直接按照原label训练有1-17个点的下降。其中17个点为多层次分类数据集,因为标签中有特殊符号##
,所以效果下降比较明显。已在README中增加labels构造遵循的规则。
**提示学习(Prompt Learning)** | ||
的主要思想是通过任务转换使得下游任务和预训练任务尽可能相似,充分利用预训练语言模型学习到的特征,从而降低样本需求量。除此之外,我们往往还需要在原有的输入文本上拼接一段“提示”,来引导预训练模型输出期望的结果。 | ||
|
||
我们以Ernie为例,回顾一下这类预训练语言模型的训练任务。与考试中的完形填空相似,给定一句文本,遮盖掉其中的部分字词,要求语言模型预测出这些遮盖位置原本的字词。因此,我们也将多层次分类任务转换为与完形填空相似的形式。例如影评情感分类任务,标签分为`1-正向`,`0-负向`两类,在经典的微调方式中,需要学习的参数是以`[CLS]`向量为输入,以负向/正向为输出的随机初始化的分类器。而在提示学习中,我们通过构造提示,将原有的分类任务转化为完形填空。如下图所示,通过提示`我[MASK]喜欢。`,原有`1-正向`,`0-负向`的标签被转化为了预测空格是`很`还是`不`。此时的分类器也不再是随机初始化,而是利用了这两个字的预训练向量来初始化,充分利用了预训练模型学习到的参数。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我们以Ernie为例,回顾一下这类预训练语言模型的训练任务。与考试中的完形填空相似,给定一句文本,遮盖掉其中的部分字词,要求语言模型预测出这些遮盖位置原本的字词。因此,我们也将多层次分类任务转换为与完形填空相似的形式。例如影评情感分类任务,标签分为`1-正向`,`0-负向`两类,在经典的微调方式中,需要学习的参数是以`[CLS]`向量为输入,以负向/正向为输出的随机初始化的分类器。而在提示学习中,我们通过构造提示,将原有的分类任务转化为完形填空。如下图所示,通过提示`我[MASK]喜欢。`,原有`1-正向`,`0-负向`的标签被转化为了预测空格是`很`还是`不`。此时的分类器也不再是随机初始化,而是利用了这两个字的预训练向量来初始化,充分利用了预训练模型学习到的参数。 | |
我们以Ernie为例,回顾一下这类预训练语言模型的训练任务。 | |
与考试中的完形填空相似,给定一句文本,遮盖掉其中的部分字词,要求语言模型预测出这些遮盖位置原本的字词。 | |
因此,我们也将多层次分类任务转换为与完形填空相似的形式。例如影评情感分类任务,标签分为`1-正向`,`0-负向`两类。 | |
在经典的微调方式中,需要学习的参数是以`[CLS]`向量为输入,以负向/正向为输出的随机初始化的分类器。而在提示学习中,我们通过构造提示,将原有的分类任务转化为完形填空。 | |
如下图所示,通过提示`我[MASK]喜欢。`,原有`1-正向`,`0-负向`的标签被转化为了预测空格是`很`还是`不`。此时的分类器也不再是随机初始化,而是利用了这两个字的预训练向量来初始化,充分利用了预训练模型学习到的参数。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
有点偏长了,看起来很累,建议分段
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已切分为三个段落。
|
||
对于训练/验证/测试数据集文件,每行数据表示一条样本,包括文本和标签两部分,由tab符`\t`分隔,多个标签以`,`分隔。例如 | ||
``` | ||
紫光圣果副总经理李明雷辞职 组织关系,组织关系##辞/离职 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不同层次,类别关系之间使用 ##
字符连接
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已增加相应的描述。
@dataclass | ||
class ModelArguments: | ||
model_name_or_path: str = field(default="ernie-3.0-base-zh", metadata={"help": "Build-in pretrained model name or the path to local model."}) | ||
export_type: str = field(default='paddle', metadata={"help": "The type to export. Support `paddle` and `onnx`."}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
默认是导出为 paddle 格式?我看infer.py 用的是 onnx?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
因为后续还要接入serving部署,模型格式是静态图,所以这里默认格式设为paddle,统一导出静态图,onnx部署时再进行格式转换。
PromptModelForSequenceClassification, | ||
) | ||
from utils import load_local_dataset |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PromptModelForSequenceClassification, | |
) | |
from utils import load_local_dataset | |
PromptModelForSequenceClassification, | |
) | |
from utils import load_local_dataset |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不同类型的包import之间已增加空行。
|
||
### 模型部署 | ||
|
||
#### CPU端推理样例 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这块默认 onnx 的话,需要增加 onnx runtime的安装文档之类的,
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已增加onnx runtime依赖安装相关的说明。
对于分类标签集文件,存储了数据集中所有的标签集合,每行为一个标签名。如果需要自定义标签映射用于分类器初始化,则每行需要包括标签名和相应的映射词,由`==`分隔。例如 | ||
``` | ||
news_car'=='汽车 | ||
news_culture'=='文化 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以把 tnews 的示例展示在这里,默认 折叠一下。纯粹描述有点抽象
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里例如就是tnews中的例子了。这里是指把全部标签放上去吗?
return [p for p in self.template.parameters() | ||
] + [p for p in self.verbalizer.parameters()] | ||
|
||
def get_input_spec(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
对 prompt 任务,这里的 InputSpec 基本可以固定死?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的get_input_spec
默认是SoftTemplate下的InputSpec,使用MannualTemplate时需要手动传入。
paddlenlp/prompt/prompt_utils.py
Outdated
import paddle | ||
from ..utils.log import logger |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
import paddle | |
from ..utils.log import logger | |
import paddle | |
from ..utils.log import logger |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上。
training_args.print_config(model_args, "Model") | ||
training_args.print_config(data_args, "Data") | ||
|
||
paddle.set_device(training_args.device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议固定一下随机种子
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
默认 trainer 里面应该有设置 seed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PromptTrainer继承自Trainer,在初始化时会固定seed
。
trainer.log_metrics("test", test_ret.metrics) | ||
|
||
# Export static model. | ||
if training_args.do_export: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
最终dump下来的文件,ckpt -> checkpoints
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
output_dir
默认值已改为 checkpoints
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
APIs
Description
How to use