fuse batch norm for conv operator #9792
class InferenceTranspiler:
    def transpile(self, program, scope, place):
        '''
        Transpile the program to a inference program by fused batch normalization.
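InferenceTranspiler may support several kinds of transformations in the future, so the fuse batch norm op logic would be better placed in a separate function or in a class of its own.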
Done
        :type place: Place
        :return: program by fused batch normalization
        :rtype: Program
        '''
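As things stand, there are two ways to implement a program transformation:
- Modify the input program directly (the memory optimization transpiler and the distribute transpiler both appear to work this way) and return nothing, so that users know explicitly that the program will be modified. A user who wants to keep the original program will clone it before calling the transpiler.
- Return a program, in which case it is best not to modify the input program.
Whichever approach is used, it is best to ensure that the original program can still run correctly after transpiling.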
Using the first approach: the program is modified directly.
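The in-place contract chosen above can be sketched as follows. This is only an illustration under stated assumptions: `Program` here is a hypothetical stand-in that holds a list of op type names, and dropping `batch_norm` entries stands in for the real fusion, which also rewrites parameters.

```python
class Program:
    """Hypothetical stand-in: a program is an ordered list of op type names."""

    def __init__(self, ops):
        self.ops = list(ops)

    def clone(self):
        # Copy the op list so the clone is independent of the original.
        return Program(self.ops)


def transpile_in_place(program):
    """Mutate `program` itself and return None, so the side effect is explicit."""
    program.ops[:] = [t for t in program.ops if t != "batch_norm"]


original = Program(["conv2d", "batch_norm", "relu"])
backup = original.clone()      # the caller keeps a copy before transpiling
transpile_in_place(original)   # only `original` changes
```

Returning nothing signals that the argument was modified; a caller who needs the untouched program is expected to clone it first.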
current_op = self.block.ops[i]
# TODO(luotao1): consider only conv2d now. fc would be delt later.
if current_op.type in ['conv2d']:
    next_op = self.block.ops[i + 1]
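This way of getting next_op only works for a single-chain network. If the network has branches, for example:
           some op
           /     \
     conv2d_1   conv2d_2
        |           |
  batch_norm_1  batch_norm_2
           \     /
             fc
then the next_op obtained is not correct. That said, batch_norm ops generally do not appear at branch points like this.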
@Xreki Could you describe this in more detail? Why would the next_op obtained be incorrect?
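If the ops are stored in this order:
some op
conv2d_1
conv2d_2
batch_norm_1
batch_norm_2
fc
then using next_op goes wrong, because the op after conv2d_1 is conv2d_2 rather than batch_norm_1. But since bn usually follows conv2d directly, could this be handled later?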
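For a branched network like this, some op should have two next ops, and this way of taking next_op depends on the order in which the ops are stored. batch_norm ops almost never appear at such a branch point, so this does not need to be handled for now; please add a comment explaining it.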
Done
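One order-independent alternative to `ops[i + 1]` is to look up the op that consumes the conv output. The sketch below is not the PR's code: `Op`, `consumers`, and `find_fusable_bn` are hypothetical names, and ops are modelled only by their type and their input/output variable names.

```python
from collections import namedtuple

# Hypothetical stand-in for an operator: a type plus input/output variable names.
Op = namedtuple("Op", ["type", "inputs", "outputs"])


def consumers(ops, op):
    """Return every op that reads at least one output of `op`."""
    produced = set(op.outputs)
    return [o for o in ops if produced & set(o.inputs)]


def find_fusable_bn(ops, conv_op):
    """Return the batch_norm op to fuse with `conv_op`, or None.

    Fusion is only attempted when the conv output has exactly one
    consumer and that consumer is a batch_norm op.
    """
    users = consumers(ops, conv_op)
    if len(users) == 1 and users[0].type == "batch_norm":
        return users[0]
    return None


# The storage order from the discussion above.
ops = [
    Op("some_op", ["x"], ["a"]),
    Op("conv2d", ["a"], ["c1"]),
    Op("conv2d", ["a"], ["c2"]),
    Op("batch_norm", ["c1"], ["b1"]),
    Op("batch_norm", ["c2"], ["b2"]),
    Op("fc", ["b1", "b2"], ["y"]),
]

index_based = ops[1 + 1]                  # the second conv2d, not a batch_norm
flow_based = find_fusable_bn(ops, ops[1])  # the batch_norm that reads c1
```

The lookup scans all ops for each conv, which is quadratic in the worst case; that is usually acceptable for a one-off transpile pass.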
            attrs={"axis": 1})  # dim_start=1
        return bias_op

    def _fuse_param(self, current_op, bn_op, bias_op, with_bias):
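As I understand this function, if current_op is of type conv2d, it directly modifies the values of that op's parameters? If so, the original program will produce incorrect results once the values have been changed.
My suggestion: if the original parameter is named conv2d_w0, define a new variable named conv2d_fuse_bn_w0, save the modified parameter into that variable, and then rename the parameter variable of the corresponding op in the target program to conv2d_fuse_bn_w0.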
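@Xreki What does "the original program" refer to? After merging bn here, we should not need to care what the original program looked like; it is enough to ensure that the merged program is correct.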
Done. The names have all been changed to the form conv2d_w0_fuse_bn.
@NHZlX User behavior is hard to predict. If a user has cloned a backup of the program and wants to keep using it afterwards, errors may occur.
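The suggestion above, writing the fused parameters under new names so the original values stay intact, can be sketched with the standard batch norm folding formula: with k = scale / sqrt(variance + eps), the fused weight is w * k and the fused bias is (b - mean) * k + bn_bias. Everything in the sketch is an assumption for illustration: the scope is a plain dict, weights are flattened per output channel, and the `eps` default and the `_fuse_bn` suffix are chosen here rather than taken from the PR.

```python
import math


def fuse_bn_into_conv(scope, w_name, b_name, bn, eps=1e-5, suffix="_fuse_bn"):
    """Store fused conv parameters under new names and return those names.

    scope: dict mapping variable name -> values (hypothetical stand-in).
    scope[w_name]: one flat list of weights per output channel.
    scope[b_name]: one bias per output channel.
    bn: dict with per-channel lists 'scale', 'bias', 'mean', 'variance'.
    """
    w, b = scope[w_name], scope[b_name]
    fused_w, fused_b = [], []
    for c in range(len(w)):
        k = bn["scale"][c] / math.sqrt(bn["variance"][c] + eps)
        fused_w.append([v * k for v in w[c]])
        fused_b.append((b[c] - bn["mean"][c]) * k + bn["bias"][c])
    new_w, new_b = w_name + suffix, b_name + suffix
    scope[new_w], scope[new_b] = fused_w, fused_b  # originals are not touched
    return new_w, new_b


def conv(w, b, x):
    """Per-channel dot product plus bias (a 1x1 'convolution' on one pixel)."""
    return [sum(wi * xi for wi, xi in zip(row, x)) + bc for row, bc in zip(w, b)]


def batch_norm(y, bn, eps=1e-5):
    """Inference-mode batch norm applied per channel."""
    return [
        (yc - m) / math.sqrt(v + eps) * s + beta
        for yc, m, v, s, beta in zip(
            y, bn["mean"], bn["variance"], bn["scale"], bn["bias"]
        )
    ]


scope = {"conv2d_w0": [[1.0, 2.0], [0.5, -1.0]], "conv2d_b0": [0.1, -0.2]}
bn = {"scale": [2.0, 0.5], "bias": [0.3, 0.0],
      "mean": [1.0, -1.0], "variance": [4.0, 0.25]}
x = [3.0, -2.0]

new_w, new_b = fuse_bn_into_conv(scope, "conv2d_w0", "conv2d_b0", bn)
reference = batch_norm(conv(scope["conv2d_w0"], scope["conv2d_b0"], x), bn)
fused = conv(scope[new_w], scope[new_b], x)
```

Because the original variables are left in the scope, a program cloned before transpiling still computes the same result as before.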
The unit test order has been adjusted to ensure that the original program remains unchanged after the transpiler runs.
LGTM
class InferenceTranspiler:
    def transpile(self, program, scope, place):
- Please reorder the parameters to:
  def transpile(self, program, place, scope=None):
- When the user passes None for scope, the default scope can be used.
- Check the types of all three parameters.
Done
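The three requests above could look roughly like this. It is a sketch only: `Program`, `Place`, `Scope`, and `global_scope` are hypothetical stand-ins defined locally so the example runs on its own, not the framework's actual classes.

```python
class Program:
    """Hypothetical stand-in for the framework's program type."""


class Place:
    """Hypothetical stand-in for a device place."""


class Scope:
    """Hypothetical stand-in for a variable scope."""


_DEFAULT_SCOPE = Scope()


def global_scope():
    """Return the process-wide default scope (assumed helper)."""
    return _DEFAULT_SCOPE


class InferenceTranspiler:
    def transpile(self, program, place, scope=None):
        # Validate the required arguments first.
        if not isinstance(program, Program):
            raise TypeError("program should be a Program")
        if not isinstance(place, Place):
            raise TypeError("place should be a Place")
        # Fall back to the default scope when none is given.
        if scope is None:
            scope = global_scope()
        if not isinstance(scope, Scope):
            raise TypeError("scope should be a Scope")
        self.scope = scope
        # The fusion passes would run here.


transpiler = InferenceTranspiler()
transpiler.transpile(Program(), Place())  # scope omitted, default is used
```

Putting `scope` last lets the common call omit it, and checking types up front turns a misplaced argument into an immediate TypeError instead of a failure deep inside a pass.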
# TODO(luotao): use clone() method to flush the program.desc in force,
# since some large program.desc will not be flushed immediately.
# And a better solution will be considered later.
program = program.clone()
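I am not sure whether this clone is still necessary. I deliberately commented it out and ran the examples on their own, and no problem showed up. That said, one extra clone here does no harm.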
LGTM
related #9629
Note that this PR modifies the program desc from the C++ end; after discussion with @jacquesqiao, it is better to modify the program from the Python end.