Skip to content

Commit

Permalink
upgrade net parameter data transformation fields automagically
Browse files Browse the repository at this point in the history
Convert DataParameter and ImageDataParameter data transformation fields
into a TransformationParameter.
  • Loading branch information
shelhamer committed Aug 22, 2014
1 parent af1d065 commit c14c8be
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 0 deletions.
7 changes: 7 additions & 0 deletions include/caffe/util/upgrade_proto.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ bool UpgradeLayerParameter(const LayerParameter& v0_layer_connection,

LayerParameter_LayerType UpgradeV0LayerType(const string& type);

// Return true iff any layer contains deprecated data transformation parameters.
bool NetNeedsDataUpgrade(const NetParameter& net_param);

// Perform all necessary transformations to upgrade old transformation fields
// into a TransformationParameter.
void UpgradeNetDataTransformation(NetParameter* net_param);

// Convert a NetParameter to NetParameterPrettyPrint used for dumping to
// proto text files.
void NetParameterToPrettyPrint(const NetParameter& param,
Expand Down
79 changes: 79 additions & 0 deletions src/caffe/util/upgrade_proto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,75 @@ LayerParameter_LayerType UpgradeV0LayerType(const string& type) {
}
}

bool NetNeedsDataUpgrade(const NetParameter& net_param) {
for (int i = 0; i < net_param.layers_size(); ++i) {
if (net_param.layers(i).type() == LayerParameter_LayerType_DATA) {
DataParameter layer_param = net_param.layers(i).data_param();
if (layer_param.has_scale()) { return true; }
if (layer_param.has_mean_file()) { return true; }
if (layer_param.has_crop_size()) { return true; }
if (layer_param.has_mirror()) { return true; }
}
if (net_param.layers(i).type() == LayerParameter_LayerType_IMAGE_DATA) {
ImageDataParameter layer_param = net_param.layers(i).image_data_param();
if (layer_param.has_scale()) { return true; }
if (layer_param.has_mean_file()) { return true; }
if (layer_param.has_crop_size()) { return true; }
if (layer_param.has_mirror()) { return true; }
}
}
return false;
}

void UpgradeNetDataTransformation(NetParameter* net_param) {
for (int i = 0; i < net_param->layers_size(); ++i) {
if (net_param->layers(i).type() == LayerParameter_LayerType_DATA) {
DataParameter* layer_param =
net_param->mutable_layers(i)->mutable_data_param();
TransformationParameter* transform_param =
layer_param->mutable_transform_param();
if (layer_param->has_scale()) {
transform_param->set_scale(layer_param->scale());
layer_param->clear_scale();
}
if (layer_param->has_mean_file()) {
transform_param->set_mean_file(layer_param->mean_file());
layer_param->clear_mean_file();
}
if (layer_param->has_crop_size()) {
transform_param->set_crop_size(layer_param->crop_size());
layer_param->clear_crop_size();
}
if (layer_param->has_mirror()) {
transform_param->set_mirror(layer_param->mirror());
layer_param->clear_mirror();
}
}
if (net_param->layers(i).type() == LayerParameter_LayerType_IMAGE_DATA) {
ImageDataParameter* layer_param =
net_param->mutable_layers(i)->mutable_image_data_param();
TransformationParameter* transform_param =
layer_param->mutable_transform_param();
if (layer_param->has_scale()) {
transform_param->set_scale(layer_param->scale());
layer_param->clear_scale();
}
if (layer_param->has_mean_file()) {
transform_param->set_mean_file(layer_param->mean_file());
layer_param->clear_mean_file();
}
if (layer_param->has_crop_size()) {
transform_param->set_crop_size(layer_param->crop_size());
layer_param->clear_crop_size();
}
if (layer_param->has_mirror()) {
transform_param->set_mirror(layer_param->mirror());
layer_param->clear_mirror();
}
}
}
}

void NetParameterToPrettyPrint(const NetParameter& param,
NetParameterPrettyPrint* pretty_param) {
pretty_param->Clear();
Expand Down Expand Up @@ -586,6 +655,16 @@ void UpgradeNetAsNeeded(const string& param_file, NetParameter* param) {
<< "prototxt and ./build/tools/upgrade_net_proto_binary for model "
<< "weights upgrade this and any other net protos to the new format.";
}
// NetParameter uses old style data transformation fields; try to upgrade it.
if (NetNeedsDataUpgrade(*param)) {
LOG(ERROR) << "Attempting to upgrade input file specified using deprecated "
<< "transformation parameters: " << param_file;
UpgradeNetDataTransformation(param);
LOG(INFO) << "Successfully upgraded file specified using deprecated "
<< "data transformation parameters.";
LOG(ERROR) << "Note that future Caffe releases will only support "
<< "transform_param messages for transformation fields.";
}
}

void ReadNetParamsFromTextFileOrDie(const string& param_file,
Expand Down
5 changes: 5 additions & 0 deletions tools/upgrade_net_proto_text.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ int main(int argc, char** argv) {
return 2;
}
bool need_upgrade = NetNeedsUpgrade(net_param);
bool need_data_upgrade = NetNeedsDataUpgrade(net_param);
bool success = true;
if (need_upgrade) {
NetParameter v0_net_param(net_param);
Expand All @@ -37,6 +38,10 @@ int main(int argc, char** argv) {
LOG(ERROR) << "File already in V1 proto format: " << argv[1];
}

if (need_data_upgrade) {
UpgradeNetDataTransformation(&net_param);
}

// Convert to a NetParameterPrettyPrint to print fields in desired
// order.
NetParameterPrettyPrint net_param_pretty;
Expand Down

0 comments on commit c14c8be

Please sign in to comment.