* Extended the Caffe-to-MXNet converter and improved the converter test (#6822)
- added support for networks that use batch normalization without a scale layer following the batch norm, i.e. gamma is fixed to 1
- extended the naming conventions recognized when implementing batch normalization in Caffe
- added support for old Caffe versions where dilation didn't exist; this is needed to convert models that depend on old Caffe
- added support for the deconvolution layer
- added support for older versions of Caffe where the kernel_size, pad and stride parameters were not iterable
- fixed a crash that occurred when a bottom layer doesn't exist in the internal top_to_layers dictionary, which can happen when the name of the input is not 'data'
- added ignore-by-design support for converting 'Crop' layers
- fixed the batch norm layer comparison to take the rescaling factor into account (a numpy sketch of the recovery follows this list)
- added a guard in the tester to swap the (RGB, BGR) input channels only when there are 3 or 4 of them, which is the same check the conversion performs
- allowed comparing layers of models with no mean file
- added support for comparing the parameters of deconvolution layers
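For context, Caffe's BatchNorm layer keeps its moving statistics pre-multiplied by a scale factor stored in a third blob, which the comparison has to undo. A minimal numpy sketch of that recovery (the blob values are made up for illustration):

    import numpy as np

    # Caffe BatchNorm blobs: blobs[0] = moving_mean * factor,
    # blobs[1] = moving_var * factor, blobs[2] = the factor itself.
    stored_mean = np.array([1.0, 2.0, 3.0])
    stored_var = np.array([0.4, 0.8, 1.2])
    rescale_factor = 2.0

    # Undo the factor before comparing against MXNet's moving_mean/moving_var.
    moving_mean = stored_mean / rescale_factor   # -> [0.5, 1.0, 1.5]
    moving_var = stored_var / rescale_factor     # -> [0.2, 0.4, 0.6]

    # With no Scale layer after the BatchNorm, gamma is fixed to 1 and beta
    # to 0, i.e. y = (x - moving_mean) / sqrt(moving_var + eps).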
arikpoz authored and piiswrong committed Jun 28, 2017
1 parent 09e08f5 commit 8c81ee4
Showing 4 changed files with 90 additions and 31 deletions.
38 changes: 27 additions & 11 deletions tools/caffe_converter/caffe_proto_utils.py
@@ -26,6 +26,9 @@ def process_network_proto(caffe_root, deploy_proto):


 class LayerRecord(object):
+    """
+    A record which describes basic layer parameters
+    """
 
     def __init__(self, layer_def):
 
@@ -35,15 +38,24 @@ def __init__(self, layer_def):

         # keep filter, stride and pad
         if layer_def.type == 'Convolution':
-            self.filter = list(layer_def.convolution_param.kernel_size)
+            if LayerRecord._is_iterable(layer_def.convolution_param.kernel_size):
+                self.filter = list(layer_def.convolution_param.kernel_size)
+            else:
+                self.filter = list([layer_def.convolution_param.kernel_size])
             if len(self.filter) == 1:
                 self.filter *= 2
-            self.pad = list(layer_def.convolution_param.pad)
+            if LayerRecord._is_iterable(layer_def.convolution_param.pad):
+                self.pad = list(layer_def.convolution_param.pad)
+            else:
+                self.pad = list([layer_def.convolution_param.pad])
             if len(self.pad) == 0:
                 self.pad = [0, 0]
             elif len(self.pad) == 1:
                 self.pad *= 2
-            self.stride = list(layer_def.convolution_param.stride)
+            if LayerRecord._is_iterable(layer_def.convolution_param.stride):
+                self.stride = list(layer_def.convolution_param.stride)
+            else:
+                self.stride = list([layer_def.convolution_param.stride])
             if len(self.stride) == 0:
                 self.stride = [1, 1]
             elif len(self.stride) == 1:
@@ -81,6 +93,9 @@ def __init__(self, layer_def):
         # list of child layers
         self.children = []
 
+    @staticmethod
+    def _is_iterable(obj):
+        return hasattr(obj, '__iter__')
+
 
 def read_network_dag(processed_deploy_prototxt):
     """
@@ -123,16 +138,17 @@ def read_network_dag(processed_deploy_prototxt):
             top_to_layers[top].append(layer.name)
 
     # find parents and children of all layers
-    for child_layer_name in layer_name_to_record.keys():
+    for child_layer_name in layer_name_to_record.keys():  # pylint: disable=too-many-nested-blocks
         child_layer_def = layer_name_to_record[child_layer_name]
         for bottom in child_layer_def.bottoms:
-            for parent_layer_name in top_to_layers[bottom]:
-                if parent_layer_name in layer_name_to_record:
-                    parent_layer_def = layer_name_to_record[parent_layer_name]
-                    if parent_layer_def not in child_layer_def.parents:
-                        child_layer_def.parents.append(parent_layer_def)
-                    if child_layer_def not in parent_layer_def.children:
-                        parent_layer_def.children.append(child_layer_def)
+            if bottom in top_to_layers:
+                for parent_layer_name in top_to_layers[bottom]:
+                    if parent_layer_name in layer_name_to_record:
+                        parent_layer_def = layer_name_to_record[parent_layer_name]
+                        if parent_layer_def not in child_layer_def.parents:
+                            child_layer_def.parents.append(parent_layer_def)
+                        if child_layer_def not in parent_layer_def.children:
+                            parent_layer_def.children.append(child_layer_def)
 
     # update filter, stride, pad for maxout "structures"
     for layer_name in layer_name_to_record.keys():
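The scalar-versus-iterable handling above is repeated once per parameter; a condensed sketch of the same normalization, using a hypothetical `_as_pair` helper that is not part of the commit (it assumes only that old Caffe exposes plain ints where newer Caffe exposes repeated containers):

    def _as_pair(value, default):
        # Normalize a Caffe conv parameter that may be a plain int (old Caffe)
        # or a repeated field (newer Caffe) into a 2-element list.
        values = list(value) if hasattr(value, '__iter__') else [value]
        if not values:
            values = [default]
        if len(values) == 1:
            values *= 2
        return values

    print(_as_pair(3, 1))    # old Caffe scalar   -> [3, 3]
    print(_as_pair([3], 1))  # repeated field     -> [3, 3]
    print(_as_pair([], 1))   # field left unset   -> [1, 1]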
39 changes: 29 additions & 10 deletions tools/caffe_converter/compare_layers.py
@@ -79,6 +79,8 @@ def convert_and_compare_caffe_to_mxnet(image_url, gpu, caffe_prototxt_path, caff

     if isinstance(caffe_mean, str):
         caffe_mean = read_caffe_mean(caffe_mean)
+    elif caffe_mean is None:
+        pass
     elif len(caffe_mean) == 3:
         # swap channels from Caffe BGR to RGB
         caffe_mean = caffe_mean[::-1]
@@ -188,7 +190,8 @@ def _process_layer_parameters(layer):
         normalized_layer_name = re.sub('[-/]', '_', layer.name)
 
         # handle weight and bias of convolution and fully-connected layers
-        if layer.name in caffe_net.params and layer.type in ['Convolution', 'InnerProduct']:
+        if layer.name in caffe_net.params and layer.type in ['Convolution', 'InnerProduct',
+                                                             'Deconvolution']:
 
             has_bias = len(caffe_net.params[layer.name]) > 1

@@ -199,8 +202,10 @@ def _process_layer_parameters(layer):
             if layer.type == 'Convolution' and compare_layers_from_nets.is_first_convolution:
                 compare_layers_from_nets.is_first_convolution = False
 
-                # swapping BGR of caffe into RGB in mxnet
-                mx_beta = mx_beta[:, ::-1, :, :]
+                # if RGB or RGBA
+                if mx_beta.shape[1] == 3 or mx_beta.shape[1] == 4:
+                    # Swapping BGR of caffe into RGB in mxnet
+                    mx_beta[:, [0, 2], :, :] = mx_beta[:, [2, 0], :, :]
 
             caf_beta = caffe_net.params[layer.name][0].data
             _compare_blob(caf_beta, mx_beta, layer.name, mx_name_weight, 'weight', '')
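The index-based swap matters for 4-channel inputs: a reversing slice such as mx_beta[:, ::-1, :, :] would also displace the alpha channel, whereas exchanging indices 0 and 2 touches only R and B. A self-contained numpy sketch of the same guard (array contents are illustrative):

    import numpy as np

    def swap_rb(blob):
        # Swap R and B channels of an NCHW blob in place, but only when the
        # blob actually has 3 (RGB) or 4 (RGBA) channels.
        if blob.shape[1] in (3, 4):
            blob[:, [0, 2], :, :] = blob[:, [2, 0], :, :]
        return blob

    rgba = np.arange(1 * 4 * 2 * 2, dtype=np.float32).reshape(1, 4, 2, 2)
    swap_rb(rgba)  # channels 0 and 2 exchanged; channels 1 and 3 untouched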
@@ -213,7 +218,13 @@ def _process_layer_parameters(layer):

         elif layer.name in caffe_net.params and layer.type == 'Scale':
 
-            bn_name = normalized_layer_name.replace('scale', 'bn')
+            if 'scale' in normalized_layer_name:
+                bn_name = normalized_layer_name.replace('scale', 'bn')
+            elif 'sc' in normalized_layer_name:
+                bn_name = normalized_layer_name.replace('sc', 'bn')
+            else:
+                assert False, 'Unknown name convention for bn/scale'
+
             beta_name = '{}_beta'.format(bn_name)
             gamma_name = '{}_gamma'.format(bn_name)

@@ -230,17 +241,19 @@ def _process_layer_parameters(layer):
             mean_name = '{}_moving_mean'.format(normalized_layer_name)
             var_name = '{}_moving_var'.format(normalized_layer_name)
 
+            caf_rescale_factor = caffe_net.params[layer.name][2].data
+
             mx_mean = aux_params[mean_name].asnumpy()
-            caf_mean = caffe_net.params[layer.name][0].data
+            caf_mean = caffe_net.params[layer.name][0].data / caf_rescale_factor
             _compare_blob(caf_mean, mx_mean, layer.name, mean_name, 'mean', '')
 
             mx_var = aux_params[var_name].asnumpy()
-            caf_var = caffe_net.params[layer.name][1].data
+            caf_var = caffe_net.params[layer.name][1].data / caf_rescale_factor
             _compare_blob(caf_var, mx_var, layer.name, var_name, 'var',
                           'expect 1e-04 change due to cudnn eps')
 
         elif layer.type in ['Input', 'Pooling', 'ReLU', 'Eltwise', 'Softmax', 'LRN', 'Concat',
-                            'Dropout']:
+                            'Dropout', 'Crop']:
             # no parameters to check for these layers
             pass

@@ -262,16 +275,22 @@ def _process_layer_output(caffe_blob_name):

         # data should change from BGR to RGB
         if caffe_blob_name == 'data':
-            # swapping BGR of caffe into RGB in mxnet
-            caf_blob = caf_blob[:, ::-1, :, :]
+
+            # if RGB or RGBA
+            if caf_blob.shape[1] == 3 or caf_blob.shape[1] == 4:
+                # Swapping BGR of caffe into RGB in mxnet
+                caf_blob[:, [0, 2], :, :] = caf_blob[:, [2, 0], :, :]
             mx_name = 'data'
 
         else:
             # get last layer name which outputs this blob name
             last_layer_name = top_to_layers[caffe_blob_name][-1]
             normalized_last_layer_name = re.sub('[-/]', '_', last_layer_name)
             mx_name = '{}_output'.format(normalized_last_layer_name)
-            mx_name = mx_name.replace('scale', 'bn')
+            if 'scale' in mx_name:
+                mx_name = mx_name.replace('scale', 'bn')
+            elif 'sc' in mx_name:
+                mx_name = mx_name.replace('sc', 'bn')
 
         if mx_name not in exe.output_dict:
             logging.error('mxnet blob %s is missing, time to extend the compare tool..', mx_name)
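The bn/scale renaming now appears in several places in the diff; a hypothetical helper (not in the commit) that condenses the convention, with illustrative ResNet-style names in the comments:

    def scale_to_bn_name(name):
        # Map a Caffe Scale layer name to its BatchNorm counterpart, e.g.
        # 'scale2a_branch2a' -> 'bn2a_branch2a' or 'sc_conv1' -> 'bn_conv1'.
        # 'scale' must be checked first, since it also contains 'sc'.
        if 'scale' in name:
            return name.replace('scale', 'bn')
        if 'sc' in name:
            return name.replace('sc', 'bn')
        raise AssertionError('Unknown name convention for bn/scale')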
29 changes: 25 additions & 4 deletions tools/caffe_converter/convert_model.py
@@ -48,8 +48,9 @@ def convert_model(prototxt_fname, caffemodel_fname, output_prefix=None):
     layers_proto = caffe_parser.get_layers(caffe_parser.read_prototxt(prototxt_fname))
 
     for layer_name, layer_type, layer_blobs in layer_iter:
-        if layer_type == 'Convolution' or layer_type == 'InnerProduct' \
-           or layer_type == 4 or layer_type == 14 or layer_type == 'PReLU':
+        if layer_type == 'Convolution' or layer_type == 'InnerProduct' \
+           or layer_type == 4 or layer_type == 14 or layer_type == 'PReLU' \
+           or layer_type == 'Deconvolution' or layer_type == 39:
             if layer_type == 'PReLU':
                 assert (len(layer_blobs) == 1)
                 wmat = layer_blobs[0].data
@@ -108,7 +109,13 @@ def convert_model(prototxt_fname, caffemodel_fname, output_prefix=None):
             first_conv = False
 
         elif layer_type == 'Scale':
-            bn_name = layer_name.replace('scale', 'bn')
+            if 'scale' in layer_name:
+                bn_name = layer_name.replace('scale', 'bn')
+            elif 'sc' in layer_name:
+                bn_name = layer_name.replace('sc', 'bn')
+            else:
+                assert False, 'Unknown name convention for bn/scale'
+
             gamma = np.array(layer_blobs[0].data)
             beta = np.array(layer_blobs[1].data)
             # beta = np.expand_dims(beta, 1)
@@ -154,9 +161,23 @@ def convert_model(prototxt_fname, caffemodel_fname, output_prefix=None):
             assert mean.flags['C_CONTIGUOUS'] is True
             print('converting batchnorm layer, mean shape = {}, var shape = {}'.format(
                 mean.shape, var.shape))
+
+            fix_gamma = layers_proto[bn_index+1].type != 'Scale'
+            if fix_gamma:
+                gamma_name = '{}_gamma'.format(bn_name)
+                gamma = np.array(np.ones(arg_shape_dic[gamma_name]))
+                beta_name = '{}_beta'.format(bn_name)
+                beta = np.array(np.zeros(arg_shape_dic[beta_name]))
+                arg_params[beta_name] = mx.nd.zeros(beta.shape)
+                arg_params[gamma_name] = mx.nd.zeros(gamma.shape)
+                arg_params[beta_name][:] = beta
+                arg_params[gamma_name][:] = gamma
+                assert gamma.flags['C_CONTIGUOUS'] is True
+                assert beta.flags['C_CONTIGUOUS'] is True
+
         else:
-            assert len(layer_blobs) == 0
             print('\tskipping layer {} of type {}'.format(layer_name, layer_type))
+            assert len(layer_blobs) == 0
 
     if output_prefix is not None:
         model = mx.mod.Module(symbol=sym, label_names=['prob_label', ])
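When fix_gamma is true (no Scale layer follows the BatchNorm), the converter synthesizes identity scale parameters so the symbol's gamma/beta arguments are still populated. A minimal sketch of the idea, with plain numpy standing in for mx.nd and a hypothetical helper name:

    import numpy as np

    def synthesize_identity_scale(bn_name, arg_shape_dic, arg_params):
        # Fill gamma with ones and beta with zeros for a BatchNorm
        # that has no following Scale layer (fix_gamma=True).
        gamma_name = '{}_gamma'.format(bn_name)
        beta_name = '{}_beta'.format(bn_name)
        arg_params[gamma_name] = np.ones(arg_shape_dic[gamma_name])
        arg_params[beta_name] = np.zeros(arg_shape_dic[beta_name])

    params = {}
    synthesize_identity_scale('bn1', {'bn1_gamma': (64,), 'bn1_beta': (64,)}, params)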
15 changes: 9 additions & 6 deletions tools/caffe_converter/convert_symbol.py
@@ -69,10 +69,11 @@ def _convert_conv_param(param):
param_string += ", stride=(%d,%d)" % (stride, stride)

dilate = 1
if isinstance(param.dilation, int):
dilate = param.dilation
else:
dilate = 1 if len(param.dilation) == 0 else param.dilation[0]
if hasattr(param, 'dilation'):
if isinstance(param.dilation, int):
dilate = param.dilation
else:
dilate = 1 if len(param.dilation) == 0 else param.dilation[0]

param_string += ", no_bias=%s" % (not param.bias_term)

@@ -189,8 +190,10 @@ def _parse_proto(prototxt_fname):
             epsilon = param.eps
             if (epsilon <= 1e-05):
                 epsilon = 1e-04
-            param_string = 'use_global_stats=%s, fix_gamma=False, eps=%f' % (
-                param.use_global_stats, epsilon)
+            # if next layer is scale, don't fix gamma
+            fix_gamma = layers[i+1].type != 'Scale'
+            param_string = 'use_global_stats=%s, fix_gamma=%s, eps=%f' % (
+                param.use_global_stats, fix_gamma, epsilon)
             need_flatten[name] = need_flatten[mapping[layer.bottom[0]]]
         if layer.type == 'Scale':
             assert layers[i-1].type == 'BatchNorm'
