[Benchmark] Optimize bert using fused_ffn and fused_attention #2523
custom_white_list=[
    "layer_norm", "softmax", "gelu",
    "fused_attention",
    "fused_feedforward"
With AMP enabled, the inputs to fused_attention and fused_feedforward are fp32. Does the computation inside these ops run in fp16 or fp32?
It runs in fp16.
And without AMP, does the operator compute internally in fp32?
Yes.
self.fuse = fuse
if self.fuse:
    self.encoder = nn.LayerList([
        FusedTransformerEncoderLayer(
If the layer type changes here, do the parameter names in the state_dict change as well?
Do you mean starting fused training from a non-fused checkpoint? That is probably not supported.
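To make the checkpoint concern concrete, here is a minimal sketch of how one might check whether a checkpoint's parameter names line up with a model's before loading it. It uses plain Python only. The helper `diff_state_dict_keys` and every key name below are hypothetical and chosen for illustration; they are not taken from this PR or from the real fused layers.

```python
def diff_state_dict_keys(checkpoint_keys, model_keys):
    """Compare parameter names in a checkpoint against those a model expects.

    Returns (missing, unexpected):
      missing    - names the model expects but the checkpoint lacks
      unexpected - names in the checkpoint that the model does not have
    """
    checkpoint_keys = set(checkpoint_keys)
    model_keys = set(model_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Hypothetical key names, for illustration only (not the real layer names).
non_fused_checkpoint_keys = [
    "embeddings.word_embeddings.weight",
    "encoder.layers.0.self_attn.q_proj.weight",
    "encoder.layers.0.linear1.weight",
]
fused_model_keys = [
    "embeddings.word_embeddings.weight",
    "encoder.0.fused_attn.qkv_weight",
    "encoder.0.ffn.linear1_weight",
]

missing, unexpected = diff_state_dict_keys(
    non_fused_checkpoint_keys, fused_model_keys
)
print("missing from checkpoint:", missing)
print("unexpected in checkpoint:", unexpected)
```

In practice the two key lists would come from the `.keys()` of the loaded checkpoint and of `model.state_dict()`. If either returned list is non-empty, the checkpoint would need an explicit key (and possibly weight layout) conversion before it could be loaded.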
LGTM
PR types
Others
PR changes
Others
Description
Optimize BERT benchmark performance by using the fused attention and fused feed-forward operators.