
fuse batch norm for conv operator #9792

Merged
merged 11 commits into from
Apr 18, 2018

Conversation

@luotao1 (Contributor) commented Apr 9, 2018

related #9629

  • solve the fuse batch norm for conv operator with and without bias
  • after fusing, the elapsed time on resnet (test_inference_image_classification) drops from 11.2s to 9.3s, about a 17% speedup on inference.

Note that this PR modifies the program desc from the C++ end; as discussed with @jacquesqiao, it is better to modify the program from the Python end.
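The fusion itself is an algebraic fold of the batch-norm statistics into the conv weights and bias. A minimal numpy sketch of that identity (a toy 1x1 conv, not Paddle code; all names here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
C_out, C_in, eps = 4, 3, 1e-5

W = rng.standard_normal((C_out, C_in))   # conv weight (1x1 kernel, flattened)
b = rng.standard_normal(C_out)           # conv bias
gamma = rng.standard_normal(C_out)       # BN scale
beta = rng.standard_normal(C_out)        # BN shift
mean = rng.standard_normal(C_out)        # BN running mean
var = rng.random(C_out) + 0.1            # BN running variance

x = rng.standard_normal(C_in)

# original: conv followed by batch_norm (inference mode)
y = W @ x + b
bn_out = gamma * (y - mean) / np.sqrt(var + eps) + beta

# fused: fold the BN statistics into the conv parameters
scale = gamma / np.sqrt(var + eps)
W_fused = W * scale[:, None]
b_fused = (b - mean) * scale + beta
fused_out = W_fused @ x + b_fused

assert np.allclose(bn_out, fused_out)
```

The fused network evaluates one op instead of two, which is where the inference speedup comes from.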

@luotao1 luotao1 added the 预测 label (formerly named Inference; includes C-API inference issues) Apr 9, 2018
@luotao1 luotao1 changed the title fuse batch norm for conv operator without bias fuse batch norm for conv operator Apr 10, 2018
@luotao1 luotao1 requested review from Xreki and NHZlX April 11, 2018 02:11
class InferenceTranspiler:
    def transpile(self, program, scope, place):
        '''
        Transpile the program into an inference program by fusing batch normalization.
Contributor:

InferenceTranspiler may support multiple kinds of transformations in the future, so it would be best to put the fuse-batch-norm logic into a separate function, or define it as a class of its own.

Contributor Author:

Done

        :type place: Place
        :return: program with fused batch normalization
        :rtype: Program
        '''
Contributor:

There are currently two possible ways to implement the program transformation:

  • Modify the input program directly (the memory optimization transpiler and the distribute transpiler both seem to take this approach), with no return value, so that the user knows explicitly that this program will be modified. A user who wants to keep a copy of the original program can clone it before calling the transpiler.
  • Return a new program; in that case, it is best not to modify the input program.

Either way, it is best to guarantee that the original program still runs correctly after transpile.
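The first convention can be sketched with a toy stand-in (not the Paddle API): the transpiler mutates its argument, and a caller who wants to keep the original clones it first.

```python
import copy

# Toy stand-in for an in-place transpiler: it mutates the dict it is given,
# dropping batch_norm ops as if they had been fused into the preceding conv.
def transpile_in_place(program):
    program["ops"] = [op for op in program["ops"] if op != "batch_norm"]

program = {"ops": ["conv2d", "batch_norm", "fc"]}
backup = copy.deepcopy(program)   # clone first to keep the original runnable
transpile_in_place(program)

assert program["ops"] == ["conv2d", "fc"]               # transformed
assert backup["ops"] == ["conv2d", "batch_norm", "fc"]  # original preserved
```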

Contributor Author:

Adopted the first approach: modify the program directly.

current_op = self.block.ops[i]
# TODO(luotao1): only conv2d is considered now; fc will be dealt with later.
if current_op.type in ['conv2d']:
    next_op = self.block.ops[i + 1]
Contributor:

This way of getting next_op only works for single-chain networks. If the network has branches, for example:

             some op
            /       \
     conv2d_1       conv2d_2
            |        |
     batch_norm_1   batch_norm_2
             \      /
                 fc

then the next_op obtained this way is incorrect. That said, a batch_norm op generally does not appear at such a fork.

Contributor:

@Xreki could you explain in more detail why the next_op obtained here is incorrect?

Contributor Author:

If the order in ops is:

some op
conv2d_1
conv2d_2
batch_norm_1
batch_norm_2
fc

then using next_op will give the wrong result. But considering that bn usually directly follows conv2d, can we handle this case later?

Contributor:

For a branched network like this, some op should have two next ops, and this way of taking next_op depends on the order in which the ops are stored. A batch_norm op almost never appears at such a fork, so we can leave it unhandled for now; just add a comment explaining this.
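The failure mode discussed above can be shown with a toy op list (hypothetical tuples, not Paddle's OpDesc): indexing ops[i + 1] picks the sibling branch, while looking the consumer up by its input variable finds the right batch_norm.

```python
# Each op is (type, input_vars, output_vars), stored in the interleaved
# order from the example in the thread.
ops = [
    ("conv2d",     ["x"],  ["c1"]),
    ("conv2d",     ["x"],  ["c2"]),
    ("batch_norm", ["c1"], ["b1"]),
    ("batch_norm", ["c2"], ["b2"]),
]

def next_by_index(i):
    # the PR's current strategy: assume the consumer is the next op in the list
    return ops[i + 1]

def next_by_var(i):
    # robust lookup: the consumer is whichever op reads this op's output
    out = ops[i][2][0]
    return next(op for op in ops if out in op[1])

assert next_by_index(0)[0] == "conv2d"                    # wrong: sibling branch
assert next_by_var(0) == ("batch_norm", ["c1"], ["b1"])   # correct consumer
```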

Contributor Author:

Done

        attrs={"axis": 1})  # dim_start=1
    return bias_op

def _fuse_param(self, current_op, bn_op, bias_op, with_bias):
Contributor:

As I understand this function, if current_op is of type conv2d, it directly modifies the values of that op's parameters? If so, after the change the original program will no longer produce correct results.
My suggestion: if the original parameter is named conv2d_w0, define a new variable named conv2d_fuse_bn_w0, save the modified parameter into that variable, and then rename the corresponding op's parameter variable in the target program to conv2d_fuse_bn_w0.
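The suggested rename can be sketched with a plain dict standing in for the parameter scope (illustrative names only; the PR ultimately used the suffix form conv2d_w0_fuse_bn):

```python
# A plain dict stands in for the parameter scope.
params = {"conv2d_w0": [1.0, 2.0, 3.0]}
scale = 0.5    # stand-in for gamma / sqrt(var + eps)

# write the fused values under a new name instead of overwriting in place
fused_name = "conv2d_w0" + "_fuse_bn"
params[fused_name] = [w * scale for w in params["conv2d_w0"]]

assert params["conv2d_w0"] == [1.0, 2.0, 3.0]   # original parameter intact
assert params[fused_name] == [0.5, 1.0, 1.5]
```

Keeping the original parameter untouched is what lets a previously cloned program keep running correctly.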

Contributor:

@Xreki what does "the original program" refer to here? After merging bn, we do not need to consider what the original program looked like; we only need to ensure that the merged program is correct.

Contributor Author:

Done. The names have all been changed to the form conv2d_w0_fuse_bn.

Contributor:

@NHZlX user behavior is hard to predict: once a user clones a backup of the program and expects to keep using that program later, things may go wrong.

@luotao1 (Contributor Author) commented Apr 13, 2018

Adjusted the order of the unit tests to guarantee that the original program remains unchanged after the transpiler runs.

@NHZlX (Contributor) commented Apr 13, 2018

LGTM



class InferenceTranspiler:
    def transpile(self, program, scope, place):
Contributor:

  • Adjust the order of the parameters to:
def transpile(self, program, place, scope=None):
  • When the user passes scope=None, use the default scope.
  • Check the types of the three parameters.

Contributor Author:

Done

# TODO(luotao): use the clone() method to forcibly flush program.desc,
# since some large program.desc will not be flushed immediately.
# A better solution will be considered later.
program = program.clone()
Contributor:

I am not sure this clone is still necessary; I deliberately commented it out and ran the example separately without any problems. But an extra clone here does no harm either.

@Xreki (Contributor) left a comment:

LGTM

Labels
预测 (formerly named Inference; includes C-API inference issues)

3 participants