Support generation code for GPT-3 in static graph. #2188
LGTM
@@ -744,10 +751,23 @@ def forward(self,
        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids)

        causal_mask = paddle.tensor.triu(
@wangxicoding Should we watch out for a training-performance regression here? Alternatively, add an `if self.training` check like the one above.
Done, refined the code.
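For context, the mask under discussion can be sketched in plain Python. This is only an illustration, not the code in this PR: the function name `causal_mask`, the additive-mask convention, and the `-1e4` fill value are assumptions.

```python
def causal_mask(seq_len, neg=-1e4):
    """Additive attention mask (assumed convention).

    Entry [i][j] is 0.0 when position i may attend to position j (j <= i)
    and `neg` when j lies strictly above the diagonal (a future position).
    """
    return [[neg if j > i else 0.0 for j in range(seq_len)]
            for i in range(seq_len)]


for row in causal_mask(3):
    print(row)
```

Building the mask costs O(seq_len^2) work per call, which is why recomputing it on every forward pass is worth a second look on the training path.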
if parallel_output:
    return logits

paddle.distributed.init_parallel_env()
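paddle.distributed.init_parallel_env()
paddle.distributed.init_parallel_env()
Is there a particular reason for calling init_parallel_env here? Could it be called at model startup instead?
Could this lead to init_parallel_env being called more than once?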
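Removed. It was added while debugging. collective._c_concat currently does not support static graph, and that bug still needs to be fixed inside the framework later.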
sorted_probs, sorted_idx = layers.argsort(probs, descending=True)
cum_sorted_probs = layers.cumsum(sorted_probs, axis=1, exclusive=True)
If the layers APIs can be replaced with their paddle.xxx equivalents, please replace them.
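- Why argsort was not replaced: paddle.argsort returns only the sorted indices, not the values. The underlying computation does produce the values, but the Python API does not return them. Reference code: https://github.com/PaddlePaddle/Paddle/blob/af79273d97b678c2eefd55b48e3ef3352c15a921/python/paddle/tensor/search.py#L114
If we switched, we would also need to call gather to fetch the values, which adds an extra computation step.
- Why layers.cumsum was not replaced: paddle.cumsum has no exclusive parameter and behaves as exclusive=False. I see that some demos in PaddleNLP use paddle.cumsum directly, but to preserve the generation quality I have not switched for now.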
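To illustrate why the exclusive flag matters for top-p filtering, here is a minimal pure-Python sketch. It is not the code from this PR: the helper names `exclusive_cumsum` and `top_p_filter`, and the keep rule (keep a token while the probability mass before it is below `top_p`), are assumptions made for illustration.

```python
def exclusive_cumsum(values):
    """Running sum that excludes the current element: [a, b, c] -> [0, a, a+b]."""
    out, total = [], 0.0
    for v in values:
        out.append(total)
        total += v
    return out


def top_p_filter(probs, top_p):
    """Zero out tokens outside the nucleus and renormalize (assumes top_p > 0)."""
    # "argsort" in descending order: token ids, most probable first.
    sorted_idx = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    # "gather": the extra step needed when argsort returns indices only.
    sorted_probs = [probs[i] for i in sorted_idx]
    cum = exclusive_cumsum(sorted_probs)
    # Keep a token if the mass *before* it is still below top_p, so the
    # most probable token (mass before it is 0) always survives.
    kept = {idx for idx, c in zip(sorted_idx, cum) if c < top_p}
    filtered = [p if i in kept else 0.0 for i, p in enumerate(probs)]
    total = sum(filtered)
    return [p / total for p in filtered]


print(top_p_filter([0.1, 0.5, 0.15, 0.25], top_p=0.7))
```

With an inclusive cumulative sum, the same rule would keep only the 0.5 token in this example, so the retained mass (0.5) would fall short of top_p = 0.7.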
@@ -12,7 +12,7 @@ rm -rf main_sharding*
task_name="gpt-mp-sharding"
rm -rf output/$task_name/log

python -u -m paddle.distributed.fleet.launch \
python3.7 -u -m paddle.distributed.fleet.launch \
python3.7 -u -m paddle.distributed.fleet.launch \
python -u -m paddle.distributed.fleet.launch \
done
#2190, a GPT-3 PR, was just merged. Watch out for conflicts.
LGTM
PR types
New features
PR changes
Models
Description
Support Sampling, TopKSampling, and TopPSampling in static graph using the While op. The generation results were verified with the gpt2-medium-en model.
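As a rough picture of loop-based generation with top-k sampling, here is a minimal pure-Python sketch. It is not the PR's implementation: the Python `while` only stands in for a graph-level While op, and `softmax`, `top_k_sample`, `generate`, and the toy model `toy_step` are hypothetical names introduced for illustration.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]


def top_k_sample(logits, k, rng):
    """Sample one token id from the k highest-scoring logits."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    probs = softmax([logits[i] for i in top])
    return rng.choices(top, weights=probs, k=1)[0]


def generate(step_fn, prompt, max_len, eos_id, k, seed=0):
    """Decode until max_len tokens exist or eos_id is sampled."""
    rng = random.Random(seed)
    tokens = list(prompt)
    while len(tokens) < max_len:
        next_id = top_k_sample(step_fn(tokens), k, rng)
        tokens.append(next_id)
        if next_id == eos_id:
            break
    return tokens


def toy_step(tokens, vocab_size=5):
    """Hypothetical stand-in for a model: favors (last token + 1) mod vocab_size."""
    target = (tokens[-1] + 1) % vocab_size
    return [5.0 if i == target else 0.0 for i in range(vocab_size)]


# With k=1 the sampling is greedy, so the run is deterministic.
print(generate(toy_step, prompt=[0], max_len=8, eos_id=4, k=1))  # [0, 1, 2, 3, 4]
```

In a static graph the loop condition and the token buffer would have to be tensors updated inside the loop body. The sketch keeps them as ordinary Python values for readability.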