fuse batch norm for conv operator #9792
Changes from 6 commits
@@ -0,0 +1,206 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import os
import shutil
from . import core

class InferenceTranspiler:
    def transpile(self, program, scope, place):
        '''
        Transpile the program to an inference program by fusing batch normalization.
[Review comment] InferenceTranspiler may support multiple kinds of transformations in the future, so it would be better to put the fuse-batch-norm logic into a separate function or class.
[Reply] Done

        The batch normalization following a convolution or fully connected
        layer can be integrated into it. Doing so gives a forward-pass
        acceleration, especially in environments such as mobile or embedded devices.

        For input X:
        - Conv process:        X = input * W + bias
        - Batch norm process:  X' = (X - mean) / std
        - Scale process:       Y = a * X' + b

        After fusing into one operation:

        Y = (input * W + bias - mean) / std * a + b
          = input * a * W / std + ((bias - mean) / std * a + b)

        The operator transformation is:
        - before:
          - conv->batch_norm->any_other_op (bias == 0)
          - conv->elementwise_add->batch_norm->any_other_op (bias != 0)
        - after:
          - conv->elementwise_add->any_other_op

        The transpile stages are:
        1. Insert an elementwise_add op when bias == 0.
        2. Fuse the batch_norm's parameters into the conv and elementwise_add operators.
        3. Remove the batch_norm ops which are not used by any other op.
        4. Adjust the input of any_other_op to be the output of the elementwise_add operator.
        5. Remove unused variables.

        :param program: program to transpile
        :type program: Program
        :param scope: inference scope
        :type scope: Scope
        :param place: inference place
        :type place: Place
        :return: program with batch normalization fused
        :rtype: Program
        '''
[Review comment] There are currently two possible ways to implement the program transformation [the enumeration of the two options is not preserved in this extract]. Whichever way is chosen, it is best to guarantee that the original program can still be executed normally after transpile.
[Reply] The first way is used: the program is modified directly.
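To make the derivation in the docstring concrete, here is a small NumPy check (illustrative only, not part of this PR's code; all names are made up) showing that the folded weight and bias reproduce the conv -> batch_norm output:

    import numpy as np

    # Illustrative check, not part of this PR: the fused weight/bias
    # reproduce conv -> batch_norm up to float tolerance.
    rng = np.random.RandomState(0)
    c = 4                                        # toy per-channel example
    x = rng.randn(c).astype(np.float32)          # stand-in for conv input
    w = rng.randn(c).astype(np.float32)          # per-channel weight stand-in
    b = rng.randn(c).astype(np.float32)          # conv bias
    mean = rng.randn(c).astype(np.float32)       # BN saved mean
    var = rng.rand(c).astype(np.float32)         # BN saved variance
    a = rng.randn(c).astype(np.float32)          # BN scale
    beta = rng.randn(c).astype(np.float32)       # BN shift
    std = np.sqrt(var + 1e-5)

    y_ref = (x * w + b - mean) / std * a + beta  # conv -> batch_norm
    w_fused = a * w / std                        # folded weight
    b_fused = (b - mean) / std * a + beta        # folded bias
    y_fused = x * w_fused + b_fused

    assert np.allclose(y_ref, y_fused, atol=1e-5)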

        self.scope = scope
        self.place = place
        self.block = program.block(0)
        self.input_map = {}  # store the input names that should be adjusted

        i = 0
        while i < len(self.block.ops):
            current_op = self.block.ops[i]
            # TODO(luotao1): consider only conv2d now. fc will be dealt with later.
            if current_op.type in ['conv2d']:
                next_op = self.block.ops[i + 1]
[Review comment] Obtaining next_op this way can be incorrect in some cases [the example cases are not preserved in this extract].
[Reply] @Xreki could you describe this in more detail? Why would the next_op obtained here be incorrect?
[Review comment] If the ops are stored in the following order [example not preserved in this extract], then using next_op goes wrong.
[Reply] However, considering that batch_norm usually comes right after conv2d, can this case be handled later?
[Review comment] For a network with branches, some op's next op may actually be two ops, and this way of taking next_op depends on the order in which the ops are stored. A batch_norm op almost never appears at such a branching point, so this can be left unhandled for now; just add a comment explaining it.
[Reply] Done

                # conv2d without bias
                if (next_op.type == 'batch_norm'):
                    # insert bias op
                    bias_op = self._insert_bias_op(i + 1, current_op, next_op)
                    # fuse batch_norm
                    self._fuse_param(current_op, next_op, bias_op, 0)
                    # remove batch_norm_op
                    self.block.remove_op(i + 2)
                    i = i + 1
                # conv2d with bias, the next_op.type is elementwise_add
                elif (next_op.type == 'elementwise_add'):
                    next_next_op = self.block.ops[i + 2]
                    if (next_next_op.type == 'batch_norm'):
                        # fuse batch_norm
                        self._fuse_param(current_op, next_next_op, next_op, 1)
                        # remove batch_norm_op
                        self.block.remove_op(i + 2)
                        i = i + 1
            i = i + 1

        self._adjust_input()
        self._remove_unused_var()
        # TODO(luotao): use the clone() method to force a flush of program.desc,
        # since a large program.desc will not be flushed immediately.
        # A better solution will be considered later.
        return program.clone()
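As a hedged usage sketch (not from this PR's diff; the module path fluid.InferenceTranspiler and the way inference_program is obtained are assumptions), the transpiler would be driven roughly like this:

    import paddle.fluid as fluid

    # Hypothetical usage sketch; names outside this PR's diff are assumptions.
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)

    # inference_program: a fluid.Program loaded elsewhere, e.g. with
    # fluid.io.load_inference_model(dirname, exe).
    t = fluid.InferenceTranspiler()
    inference_program = t.transpile(inference_program, fluid.global_scope(), place)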

    # ====================== private transpiler functions =====================
    def _insert_bias_op(self, index, current_op, bn_op):
        '''
        Construct an elementwise_add operator for adding bias
        and insert it into the program.

        :param index: insert location of bias_op
        :type index: Int
        :param current_op: current operator (conv or fc)
        :type current_op: Operator
        :param bn_op: batch norm operator
        :type bn_op: Operator
        :return: bias_op
        :rtype: Operator
        '''
        # The input of bias_op is current_op's output and the Bias of bn_op.
        # The output of bias_op is bn_op's output.
        x_var = self.block.var(current_op.output("Output")[0])
        y_var = self.block.var(bn_op.input("Bias")[0])
        out_var = self.block.var(bn_op.output("Y")[0])

        bias_op = self.block.insert_op(
            index,
            type="elementwise_add",
            inputs={"X": x_var,
                    "Y": y_var},
            outputs={"Out": out_var},
            attrs={"axis": 1})  # dim_start=1
        return bias_op
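The axis=1 attribute makes the length-C bias broadcast along the channel dimension of the NCHW conv output. An illustrative NumPy equivalent of that broadcast (not PR code):

    import numpy as np

    # Illustrative, not PR code: elementwise_add with axis=1 on an NCHW
    # tensor adds a length-C bias along the channel axis, which in NumPy
    # corresponds to reshaping the bias to (1, C, 1, 1).
    n, c, h, w = 2, 3, 4, 4
    x = np.random.randn(n, c, h, w).astype(np.float32)
    bias = np.random.randn(c).astype(np.float32)
    out = x + bias.reshape(1, c, 1, 1)  # broadcast over N, H, W
    assert out.shape == (n, c, h, w)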

    def _fuse_param(self, current_op, bn_op, bias_op, with_bias):
[Review comment] This function, as I understand it, if current_op is... [the rest of this comment is not preserved in this extract; from the replies it appears to concern modifying the original program's parameters in place].
[Reply] @Xreki what does "the original program" refer to here? After merging bn, we should not need to care what the original program looked like, as long as the merged result is correct.
[Reply] Done. The names were all changed to something like conv2d_w0_fuse_bn.
[Review comment] @NHZlX user behavior is hard to predict: once a user clones a backup of a program and wants to keep using that program later, things may go wrong.
        '''
        Fuse the batch_norm op's parameters into current_op (conv or fc).

        :param current_op: current operator (conv or fc)
        :type current_op: Operator
        :param bn_op: batch norm operator
        :type bn_op: Operator
        :param bias_op: elementwise_add operator for adding bias
        :type bias_op: Operator
        :param with_bias: If the current operator has bias, with_bias = 1; otherwise 0.
        :type with_bias: Int
        '''

        def _load_tensor(param_name):
            return self.scope.find_var(param_name[0]).get_tensor()

        def _load_param(param_name):
            return np.array(_load_tensor(param_name))

        bias_bn = _load_param(bn_op.input("Bias"))     # Bias
        scale_bn = _load_param(bn_op.input("Scale"))   # Scale
        mean_bn = _load_param(bn_op.input("Mean"))     # Mean
        var_bn = _load_param(bn_op.input("Variance"))  # Variance

        # TODO(luotao1): consider only conv2d now. fc will be dealt with later.
        current_param = _load_param(current_op.input("Filter"))
        current_tensor = _load_tensor(current_op.input("Filter"))

        std_bn = np.float32(np.sqrt(np.add(var_bn, 1e-5)))
        tmp = np.float32(np.divide(scale_bn, std_bn))

        # add the bias of batch_norm_op to conv2d
        if with_bias:
            bias = _load_param(bias_op.input("Y"))
        else:
            bias = np.zeros(bias_bn.shape)
        bias = np.float32(
            np.add(np.multiply(np.subtract(bias, mean_bn), tmp), bias_bn))
        bias_tensor = _load_tensor(bias_op.input("Y"))
        bias_tensor.set(bias, self.place)

        # re-compute the weight of conv2d
        tmp = tmp.reshape(tmp.shape[0], -1)
        dst_param = current_param.reshape((tmp.shape[0], -1))
        dst_param = np.float32(np.multiply(dst_param, tmp))
        dst_param = dst_param.reshape(current_param.shape)

        # set the updated parameters
        current_tensor.set(np.array(dst_param), self.place)

        # collect the renamed inputs
        self.input_map[bn_op.output("Y")[0]] = bias_op.output("Out")[0]
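The reshape to (C_out, -1) scales every element of each output filter by that channel's scale_bn / std_bn factor. A small illustrative check of the trick (not PR code):

    import numpy as np

    # Illustrative, not PR code: fold a per-output-channel factor into a
    # conv filter of shape (C_out, C_in, kH, kW) via the same reshape trick.
    c_out, c_in, kh, kw = 4, 3, 3, 3
    w = np.random.randn(c_out, c_in, kh, kw).astype(np.float32)
    tmp = np.random.rand(c_out).astype(np.float32)  # scale_bn / std_bn

    dst = w.reshape(c_out, -1) * tmp.reshape(c_out, 1)  # scale each filter
    w_fused = dst.reshape(w.shape)

    # Equivalent direct broadcast over the output-channel axis:
    assert np.allclose(w_fused, w * tmp.reshape(c_out, 1, 1, 1))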

    def _adjust_input(self):
        for i in range(len(self.block.ops)):
            current_op = self.block.ops[i]
            for input_arg in current_op.input_arg_names:
                if input_arg in self.input_map:
                    current_op.rename_input(input_arg,
                                            self.input_map[input_arg])

    def _remove_unused_var(self):
        '''
        Remove unused variables in the program.
        '''
        args = []
        for i in range(len(self.block.ops)):
            current_op = self.block.ops[i]
            args += current_op.input_arg_names
            args += current_op.output_arg_names
        args = list(set(args))  # unique the input and output arguments

        # copy the keys first, since vars are removed while iterating
        for var in list(self.block.vars.keys()):
            if var not in args:
                self.block.remove_var(var)
[Review comment] When scope is None, the default scope could be used.
[Reply] Done