Commit 85283aa
1 parent 91b0928
Showing 10 changed files with 645 additions and 4 deletions.
@@ -0,0 +1 @@
1,1,1,1,1,1,1,1,1,1
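This one-line file holds ten comma-separated values, presumably one loss weight per MNIST class (digits 0 through 9); with every weight set to 1 the weighted loss should behave exactly like the standard softmax loss. The net prototxt below references it as examples/mnist/label_weights.txt via softmax_param.weights_file. A minimal sketch of how such a file could be parsed (hypothetical helper, not the commit's actual code):

```cpp
// Hypothetical helper, for illustration only: read one comma-separated line
// of per-class weights, e.g. "1,1,1,1,1,1,1,1,1,1".
#include <fstream>
#include <string>
#include <vector>

std::vector<float> ReadLabelWeights(const std::string& path) {
  std::vector<float> weights;
  std::ifstream in(path.c_str());
  std::string token;
  while (std::getline(in, token, ',')) {
    weights.push_back(std::stof(token));  // one weight per class label
  }
  return weights;
}
```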
@@ -0,0 +1,25 @@
# The train/test net protocol buffer definition
net: "examples/mnist/lenet_weighted_train_test.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100
# Carry out testing every 500 training iterations.
test_interval: 500
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0005
# The learning rate policy
lr_policy: "inv"
gamma: 0.0001
power: 0.75
# Display every 100 iterations
display: 100
# The maximum number of iterations
max_iter: 10000
# snapshot intermediate results
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet"
# solver mode: CPU or GPU
solver_mode: CPU
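For reference, Caffe's "inv" learning-rate policy decays the rate at iteration t as

lr(t) = base_lr \cdot (1 + \gamma\, t)^{-power}

so with base_lr 0.01, gamma 0.0001 and power 0.75 the rate falls from 0.01 to roughly 0.006 by the final iteration 10000.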
@@ -0,0 +1,170 @@
name: "LeNet"
layer {
  name: "mnist"
  type: "Data"
  top: "data"
  top: "label"
  include {
    phase: TRAIN
  }
  transform_param {
    scale: 0.00390625
  }
  data_param {
    source: "examples/mnist/mnist_train_lmdb"
    batch_size: 64
    backend: LMDB
  }
}
layer {
  name: "mnist"
  type: "Data"
  top: "data"
  top: "label"
  include {
    phase: TEST
  }
  transform_param {
    scale: 0.00390625
  }
  data_param {
    source: "examples/mnist/mnist_test_lmdb"
    batch_size: 100
    backend: LMDB
  }
}
layer {
  name: "conv1"
  type: "Convolution"
  bottom: "data"
  top: "conv1"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  convolution_param {
    num_output: 20
    kernel_size: 5
    stride: 1
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "pool1"
  type: "Pooling"
  bottom: "conv1"
  top: "pool1"
  pooling_param {
    pool: MAX
    kernel_size: 2
    stride: 2
  }
}
layer {
  name: "conv2"
  type: "Convolution"
  bottom: "pool1"
  top: "conv2"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  convolution_param {
    num_output: 50
    kernel_size: 5
    stride: 1
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "pool2"
  type: "Pooling"
  bottom: "conv2"
  top: "pool2"
  pooling_param {
    pool: MAX
    kernel_size: 2
    stride: 2
  }
}
layer {
  name: "ip1"
  type: "InnerProduct"
  bottom: "pool2"
  top: "ip1"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  inner_product_param {
    num_output: 500
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "relu1"
  type: "ReLU"
  bottom: "ip1"
  top: "ip1"
}
layer {
  name: "ip2"
  type: "InnerProduct"
  bottom: "ip1"
  top: "ip2"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  inner_product_param {
    num_output: 10
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "accuracy"
  type: "Accuracy"
  bottom: "ip2"
  bottom: "label"
  top: "accuracy"
  include {
    phase: TEST
  }
}
layer {
  name: "loss"
  type: "WeightedSoftmaxWithLoss"
  bottom: "ip2"
  bottom: "label"
  top: "loss"
  softmax_param {
    weights_file: "examples/mnist/label_weights.txt"
  }
}
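Relative to the stock LeNet definition, the only changes are the loss layer's type (WeightedSoftmaxWithLoss instead of SoftmaxWithLoss) and the new weights_file entry in softmax_param. Presumably the layer scales each sample's log loss by the weight of its true class, i.e. something along the lines of

E = -\frac{1}{N} \sum_{n=1}^{N} w_{l_n} \log(\hat{p}_{n,l_n})

which reduces to the ordinary softmax cross-entropy loss when every weight is 1, as in label_weights.txt above.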
@@ -0,0 +1,4 @@
#!/usr/bin/env sh
set -e

./build/tools/caffe train --solver=examples/mnist/lenet_weighted_solver.prototxt $@
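The trailing $@ forwards any extra command-line arguments to the caffe binary, so standard caffe train flags such as --snapshot (resume from a saved solver state) or --weights (initialize from an existing .caffemodel) can be appended when invoking the script.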
@@ -0,0 +1,131 @@
#ifndef CAFFE_WEIGHTED_SOFTMAX_WITH_LOSS_LAYER_HPP_
#define CAFFE_WEIGHTED_SOFTMAX_WITH_LOSS_LAYER_HPP_

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

#include "caffe/layers/loss_layer.hpp"
#include "caffe/layers/softmax_layer.hpp"

namespace caffe {

/**
 * @brief Computes the multinomial logistic loss for a one-of-many
 *        classification task, passing real-valued predictions through a
 *        softmax to get a probability distribution over classes. This
 *        weighted variant additionally scales each sample's loss by a
 *        per-class weight read from the file named in
 *        softmax_param.weights_file.
 *
 * This layer should be preferred over separate
 * SoftmaxLayer + MultinomialLogisticLossLayer
 * as its gradient computation is more numerically stable.
 * At test time, this layer can be replaced simply by a SoftmaxLayer.
 *
 * @param bottom input Blob vector (length 2)
 *   -# @f$ (N \times C \times H \times W) @f$
 *      the predictions @f$ x @f$, a Blob with values in
 *      @f$ [-\infty, +\infty] @f$ indicating the predicted score for each of
 *      the @f$ K = CHW @f$ classes. This layer maps these scores to a
 *      probability distribution over classes using the softmax function
 *      @f$ \hat{p}_{nk} = \exp(x_{nk}) /
 *      \left[\sum_{k'} \exp(x_{nk'})\right] @f$ (see SoftmaxLayer).
 *   -# @f$ (N \times 1 \times 1 \times 1) @f$
 *      the labels @f$ l @f$, an integer-valued Blob with values
 *      @f$ l_n \in [0, 1, 2, ..., K - 1] @f$
 *      indicating the correct class label among the @f$ K @f$ classes
 * @param top output Blob vector (length 1)
 *   -# @f$ (1 \times 1 \times 1 \times 1) @f$
 *      the computed cross-entropy classification loss: @f$ E =
 *        \frac{-1}{N} \sum\limits_{n=1}^N \log(\hat{p}_{n,l_n})
 *      @f$, for softmax output class probabilities @f$ \hat{p} @f$
 */
template <typename Dtype>
class WeightedSoftmaxWithLossLayer : public LossLayer<Dtype> {
 public:
  /**
   * @param param provides LossParameter loss_param, with options:
   *  - ignore_label (optional)
   *    Specify a label value that should be ignored when computing the loss.
   *  - normalize (optional, default true)
   *    If true, the loss is normalized by the number of (nonignored) labels
   *    present; otherwise the loss is simply summed over spatial locations.
   */
  explicit WeightedSoftmaxWithLossLayer(const LayerParameter& param)
      : LossLayer<Dtype>(param) {}
  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  virtual inline const char* type() const { return "WeightedSoftmaxWithLoss"; }
  virtual inline int ExactNumTopBlobs() const { return -1; }
  virtual inline int MinTopBlobs() const { return 1; }
  virtual inline int MaxTopBlobs() const { return 2; }

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  /**
   * @brief Computes the softmax loss error gradient w.r.t. the predictions.
   *
   * Gradients cannot be computed with respect to the label inputs (bottom[1]),
   * so this method ignores bottom[1] and requires !propagate_down[1], crashing
   * if propagate_down[1] is set.
   *
   * @param top output Blob vector (length 1), providing the error gradient with
   *      respect to the outputs
   *   -# @f$ (1 \times 1 \times 1 \times 1) @f$
   *      This Blob's diff will simply contain the loss_weight * @f$ \lambda @f$,
   *      as @f$ \lambda @f$ is the coefficient of this layer's output
   *      @f$\ell_i@f$ in the overall Net loss
   *      @f$ E = \lambda_i \ell_i + \mbox{other loss terms}@f$; hence
   *      @f$ \frac{\partial E}{\partial \ell_i} = \lambda_i @f$.
   *      (*Assuming that this top Blob is not used as a bottom (input) by any
   *      other layer of the Net.)
   * @param propagate_down see Layer::Backward.
   *      propagate_down[1] must be false as we can't compute gradients with
   *      respect to the labels.
   * @param bottom input Blob vector (length 2)
   *   -# @f$ (N \times C \times H \times W) @f$
   *      the predictions @f$ x @f$; Backward computes diff
   *      @f$ \frac{\partial E}{\partial x} @f$
   *   -# @f$ (N \times 1 \times 1 \times 1) @f$
   *      the labels -- ignored as we can't compute their error gradients
   */
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

  /// Read the normalization mode parameter and compute the normalizer based
  /// on the blob size. If normalization_mode is VALID, the count of valid
  /// outputs will be read from valid_count, unless it is -1 in which case
  /// all outputs are assumed to be valid.
  virtual Dtype get_normalizer(
      LossParameter_NormalizationMode normalization_mode, int valid_count);

  /// The internal SoftmaxLayer used to map predictions to a distribution.
  shared_ptr<Layer<Dtype> > softmax_layer_;
  /// prob stores the output probability predictions from the SoftmaxLayer.
  Blob<Dtype> prob_;
  /// bottom vector holder used in call to the underlying SoftmaxLayer::Forward
  vector<Blob<Dtype>*> softmax_bottom_vec_;
  /// top vector holder used in call to the underlying SoftmaxLayer::Forward
  vector<Blob<Dtype>*> softmax_top_vec_;
  /// Whether to ignore instances with a certain label.
  bool has_ignore_label_;
  /// The label indicating that an instance should be ignored.
  int ignore_label_;
  /// How to normalize the output loss.
  LossParameter_NormalizationMode normalization_;

  /// Axis along which the softmax is taken, and the outer/inner counts.
  int softmax_axis_, outer_num_, inner_num_;
  /// Per-class loss weights loaded from softmax_param.weights_file.
  vector<float> label_weights;
};

}  // namespace caffe

#endif  // CAFFE_WEIGHTED_SOFTMAX_WITH_LOSS_LAYER_HPP_
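The matching .cpp implementation is among the changed files but is not rendered above. Based on the members declared in this header and on the stock SoftmaxWithLossLayer, its CPU forward pass presumably looks roughly like the following sketch, with each sample's negative log-likelihood scaled by the weight of its label (illustrative only, not the commit's actual code):

```cpp
// Illustrative sketch only: mirrors SoftmaxWithLossLayer::Forward_cpu with an
// added per-class factor taken from label_weights. Assumes the usual Caffe
// layer .cpp includes are present.
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <vector>

namespace caffe {

template <typename Dtype>
void WeightedSoftmaxWithLossLayer<Dtype>::Forward_cpu(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  // Run the internal softmax to fill prob_ with class probabilities.
  softmax_layer_->Forward(softmax_bottom_vec_, softmax_top_vec_);
  const Dtype* prob_data = prob_.cpu_data();
  const Dtype* label = bottom[1]->cpu_data();
  const int dim = prob_.count() / outer_num_;
  int count = 0;
  Dtype loss = 0;
  for (int i = 0; i < outer_num_; ++i) {
    for (int j = 0; j < inner_num_; ++j) {
      const int label_value = static_cast<int>(label[i * inner_num_ + j]);
      if (has_ignore_label_ && label_value == ignore_label_) {
        continue;
      }
      // Scale this sample's negative log-likelihood by its class weight.
      const Dtype w = label_weights[label_value];
      loss -= w * log(std::max(
          prob_data[i * dim + label_value * inner_num_ + j], Dtype(FLT_MIN)));
      ++count;
    }
  }
  top[0]->mutable_cpu_data()[0] = loss / get_normalizer(normalization_, count);
  if (top.size() == 2) {
    top[1]->ShareData(prob_);  // optionally expose the softmax probabilities
  }
}

}  // namespace caffe
```

The backward pass would presumably apply the same per-class factor to the usual softmax gradient (prob minus the one-hot label) before normalizing.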