Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add prefetcher #95

Merged
merged 19 commits into from
Sep 18, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,7 @@ ifndef DMLC_CORE
DMLC_CORE = dmlc-core
endif


ifneq ($(USE_OPENMP_ITER), 1)
export NO_OPENMP = 1
endif

ifneq ($(USE_OPENMP_ITER), 1)
ifneq ($(USE_OPENMP), 1)
export NO_OPENMP = 1
endif

Expand Down
9 changes: 6 additions & 3 deletions example/cifar10/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,18 +201,21 @@ def Update(grad, weight, mom):
mean_img="data/cifar/cifar_mean.bin",
rand_crop=True,
rand_mirror=True,
shuffle=False,
input_shape=(3,28,28),
batch_size=batch_size,
nthread=1)
nthread=4,
prefetch_capacity=6)
test_dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/test.rec",
mean_img="data/cifar/cifar_mean.bin",
rand_crop=False,
rand_mirror=False,
shuffle=False,
input_shape=(3,28,28),
batch_size=batch_size,
nthread=1)

nthread=4,
prefetch_capacity=6)

def progress(count, total, epoch, toc):
bar_len = 50
Expand Down
37 changes: 19 additions & 18 deletions example/cifar10/cifar10_multi_gpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,24 +148,25 @@ def momentum_update(key, grad, weight):
get_data.GetCifar10()

train_dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar_mean.bin",
rand_crop=True,
rand_mirror=True,
shuffle=True,
input_shape=(3,28,28),
batch_size=batch_size,
nthread=1)

val_dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/test.rec",
mean_img="data/cifar/cifar_mean.bin",
rand_crop=False,
rand_mirror=False,
input_shape=(3,28,28),
batch_size=batch_size,
nthread=1)

path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar_mean.bin",
rand_crop=True,
rand_mirror=True,
shuffle=False,
input_shape=(3,28,28),
batch_size=batch_size,
nthread=4,
prefetch_capacity=6)
test_dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/test.rec",
mean_img="data/cifar/cifar_mean.bin",
rand_crop=False,
rand_mirror=False,
shuffle=False,
input_shape=(3,28,28),
batch_size=batch_size,
nthread=4,
prefetch_capacity=6)

def progress(count, total, epoch, tic):
bar_len = 50
Expand Down
3 changes: 2 additions & 1 deletion example/mnist/mlp_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,15 @@ def Update(grad, weight):

#check data
get_data.GetMNIST_ubyte()

train_dataiter = mx.io.MNISTIter(
image="data/train-images-idx3-ubyte",
label="data/train-labels-idx1-ubyte",
input_shape=(784,),
batch_size=batch_size, shuffle=True, flat=True, silent=False, seed=10)
val_dataiter = mx.io.MNISTIter(
image="data/t10k-images-idx3-ubyte",
label="data/t10k-labels-idx1-ubyte",
input_shape=(784,),
batch_size=batch_size, shuffle=True, flat=True, silent=False)

tmp_label = mx.nd.zeros(name2shape["sm_label"])
Expand Down
2 changes: 2 additions & 0 deletions example/mnist/mlp_multi_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@ def updater(key, grad, weight):
train_dataiter = mx.io.MNISTIter(
image="data/train-images-idx3-ubyte",
label="data/train-labels-idx1-ubyte",
input_shape=(784,),
batch_size=batch_size, shuffle=True, flat=True, silent=False, seed=10)
val_dataiter = mx.io.MNISTIter(
image="data/t10k-images-idx3-ubyte",
label="data/t10k-labels-idx1-ubyte",
input_shape=(784,),
batch_size=batch_size, shuffle=True, flat=True, silent=False)

def cal_acc(out, label):
Expand Down
44 changes: 27 additions & 17 deletions include/mxnet/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
#include <vector>
#include <string>
#include <utility>
#include <queue>
#include "./base.h"
#include "./ndarray.h"

namespace mxnet {
/*!
Expand Down Expand Up @@ -59,27 +61,16 @@ struct DataInst {
* data and label, how we use them is to see the DNN implementation.
*/
struct DataBatch {
public:
/*! \brief unique id for instance, can be NULL, sometimes is useful */
unsigned *inst_index;
/*! \brief number of instance */
mshadow::index_t batch_size;
/*! \brief number of padding elements in this batch,
this is used to indicate the last elements in the batch are only padded up to match the batch, and should be discarded */
mshadow::index_t num_batch_padd;
public:
/*! \brief content of dense data, if this DataBatch is dense */
std::vector<TBlob> data;
std::vector<NDArray> data;
/*! \brief extra data to be fed to the network */
std::string extra_data;
public:
/*! \brief constructor */
DataBatch(void) {
inst_index = NULL;
batch_size = 0; num_batch_padd = 0;
}
/*! \brief giving name to the data */
void Naming(std::vector<std::string> names);
DataBatch(void) {}
/*! \brief destructor */
~DataBatch() {}
}; // struct DataBatch

/*! \brief typedef the factory function of data iterator */
Expand Down Expand Up @@ -121,10 +112,29 @@ struct DataIteratorReg
* \endcode
*/
#define MXNET_REGISTER_IO_CHAINED_ITER(name, ChainedDataIterType, HoldingDataIterType) \
static ::mxnet::IIterator<DataBatch>* __create__ ## ChainedDataIteratorType ## __() { \
static ::mxnet::IIterator<DataBatch>* __create__ ## ChainedDataIterType ## __() { \
return new HoldingDataIterType(new ChainedDataIterType); \
} \
DMLC_REGISTRY_REGISTER(::mxnet::DataIteratorReg, DataIteratorReg, name) \
.set_body(__create__ ## ChainedDataIteratorType ## __)
.set_body(__create__ ## ChainedDataIterType ## __)
/*!
* \brief Macro to register a chain of three nested iterators under one name.
*
* Defines a static factory function that returns
* new FirstIterType(new SecondIterType(new ThirdIterType)), i.e. FirstIterType
* is the outermost iterator and ThirdIterType the innermost, and registers that
* factory as the body of the DataIteratorReg entry called name.
*
* \param name name under which the iterator is registered
* \param FirstIterType outermost iterator type, constructed from a SecondIterType*
* \param SecondIterType middle iterator type, constructed from a ThirdIterType*
* \param ThirdIterType innermost iterator type, default constructed
*
* \note The factory function's identifier is derived from ThirdIterType only, so
* two registrations in one translation unit that share the same ThirdIterType
* would define the same function twice.
*
* \code
* // example of registering an imagerec iterator
* // NOTE(review): arguments are passed outermost to innermost - confirm that
* // the types below are listed in the intended nesting order
* MXNET_REGISTER_IO_THREE_CHAINED_ITER(ImageRecordIter,
* ImageRecordIter, ImageRecBatchLoader, Prefetcher)
* .describe("batched image record data iterator");
*
* \endcode
*/
#define MXNET_REGISTER_IO_THREE_CHAINED_ITER(\
name, FirstIterType, SecondIterType, ThirdIterType) \
static ::mxnet::IIterator<DataBatch>* __create__ ## ThirdIterType ## __() { \
return new FirstIterType(new SecondIterType(new ThirdIterType)); \
} \
DMLC_REGISTRY_REGISTER(::mxnet::DataIteratorReg, DataIteratorReg, name) \
.set_body(__create__ ## ThirdIterType ## __)

} // namespace mxnet
#endif // MXNET_IO_H_
4 changes: 2 additions & 2 deletions python/mxnet/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,15 @@ def getdata(self):
"""
hdl = NDArrayHandle()
check_call(_LIB.MXDataIterGetData(self.handle, ctypes.byref(hdl)))
return NDArray(hdl)
return NDArray(hdl, False)

def getlabel(self):
"""get label from batch

"""
hdl = NDArrayHandle()
check_call(_LIB.MXDataIterGetLabel(self.handle, ctypes.byref(hdl)))
return NDArray(hdl)
return NDArray(hdl, False)

def _make_io_iterator(handle):
"""Create an io iterator by handle."""
Expand Down
7 changes: 6 additions & 1 deletion python/mxnet/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class NDArray(object):
NDArray is basic ndarray/Tensor like data structure in mxnet.
"""
# pylint: disable= no-member
def __init__(self, handle):
def __init__(self, handle, writable=True):
"""initialize a new NDArray

Parameters
Expand All @@ -61,6 +61,7 @@ def __init__(self, handle):
"""
assert isinstance(handle, NDArrayHandle)
self.handle = handle
self.writable = writable

def __del__(self):
check_call(_LIB.MXNDArrayFree(self.handle))
Expand Down Expand Up @@ -555,6 +556,8 @@ def binary_ndarray_function(lhs, rhs, out=None):
if out:
if isinstance(out, NDArray) == False:
raise TypeError('out must be NDArray')
if not out.writable:
raise TypeError('out must be writable')
else:
if not accept_empty_mutate:
raise TypeError('argument out is required to call %s' % func_name)
Expand All @@ -570,6 +573,8 @@ def unary_ndarray_function(src, out=None):
if out:
if isinstance(out, NDArray) == False:
raise TypeError('out must be NDArray')
if not out.writable:
raise TypeError('out must be writable')
else:
if not accept_empty_mutate:
raise TypeError('argument out is required to call %s' % func_name)
Expand Down
12 changes: 8 additions & 4 deletions src/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -872,15 +872,19 @@ int MXDataIterNext(DataIterHandle handle, int *out) {

int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) {
API_BEGIN();
DataBatch db = static_cast<IIterator<DataBatch>* >(handle)->Value();
*out = new NDArray(db.data[1], 0);
const DataBatch& db = static_cast<IIterator<DataBatch>* >(handle)->Value();
NDArray* pndarray = new NDArray();
*pndarray = db.data[1];
*out = pndarray;
API_END();
}

int MXDataIterGetData(DataIterHandle handle, NDArrayHandle *out) {
API_BEGIN();
DataBatch db = static_cast<IIterator<DataBatch>* >(handle)->Value();
*out = new NDArray(db.data[0], 0);
const DataBatch& db = static_cast<IIterator<DataBatch>* >(handle)->Value();
NDArray* pndarray = new NDArray();
*pndarray = db.data[0];
*out = pndarray;
API_END();
}

Expand Down
45 changes: 18 additions & 27 deletions src/io/image_augmenter.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
* Copyright (c) 2015 by Contributors
* \file image_augmenter.h
* \brief image augmenter with OpenCV-based and tensor-based processing
* \author Naiyan Wang, Tianqi Chen, Tianjun Xiao
*/
#ifndef MXNET_IO_IMAGE_AUGMENTER_H_
#define MXNET_IO_IMAGE_AUGMENTER_H_

#if MXNET_USE_OPENCV
#include <opencv2/opencv.hpp>
#endif
#include <utility>
#include <string>
#include <algorithm>
Expand Down Expand Up @@ -122,6 +123,8 @@ struct ImageAugmentParam : public dmlc::Parameter<ImageAugmentParam> {
.describe("Maximum ratio of contrast variation");
DMLC_DECLARE_FIELD(max_random_illumination).set_default(0.0f)
.describe("Maximum value of illumination variation");
DMLC_DECLARE_FIELD(silent).set_default(true)
.describe("Whether to print augmentor info");
}
};

Expand All @@ -130,8 +133,10 @@ class ImageAugmenter {
public:
// constructor
ImageAugmenter(void)
: tmpres_(false),
rotateM_(2, 3, CV_32F) {
: tmpres_(false) {
#if MXNET_USE_OPENCV
rotateM_ = cv::Mat(2, 3, CV_32F);
#endif
}
virtual ~ImageAugmenter() {
}
Expand Down Expand Up @@ -164,6 +169,7 @@ class ImageAugmenter {
}
}
}
#if MXNET_USE_OPENCV
/*!
* \brief augment src image, store result into dst
* this function is not thread safe, and will only be called by one thread
Expand All @@ -174,6 +180,7 @@ class ImageAugmenter {
*/
virtual cv::Mat OpencvProcess(const cv::Mat &src,
common::RANDOM_ENGINE *prnd) {
if (!NeedOpencvProcess()) return src;
// shear
float s = NextDouble(prnd) * param_.max_shear_ratio * 2 - param_.max_shear_ratio;
// rotate
Expand Down Expand Up @@ -276,8 +283,9 @@ class ImageAugmenter {
}
return tmpres_;
}

void TensorProcess(mshadow::TensorContainer<cpu, 3> *p_data,
#endif
void TensorProcess(mshadow::Tensor<cpu, 3> *p_data,
mshadow::TensorContainer<cpu, 3> *dst_data,
common::RANDOM_ENGINE *prnd) {
// Check Newly Created mean image
if (meanfile_ready_ == false && param_.mean_img.length() != 0) {
Expand All @@ -291,7 +299,8 @@ class ImageAugmenter {
meanfile_ready_ = true;
}
}
img_.Resize(mshadow::Shape3((*p_data).shape_[0], param_.input_shape[1], param_.input_shape[2]));
img_.Resize(mshadow::Shape3((*p_data).shape_[0],
param_.input_shape[1], param_.input_shape[2]));
if (param_.input_shape[1] == 1) {
img_ = (*p_data) * param_.scale;
} else {
Expand Down Expand Up @@ -355,27 +364,7 @@ class ImageAugmenter {
}
}
}
(*p_data) = img_;
}

virtual void Process(unsigned char *dptr, size_t sz,
mshadow::TensorContainer<cpu, 3> *p_data,
common::RANDOM_ENGINE *prnd) {
cv::Mat buf(1, sz, CV_8U, dptr);
cv::Mat res = cv::imdecode(buf, 1);
if (NeedOpencvProcess())
res = this->OpencvProcess(res, prnd);
p_data->Resize(mshadow::Shape3(3, res.rows, res.cols));
for (index_t i = 0; i < p_data->size(1); ++i) {
for (index_t j = 0; j < p_data->size(2); ++j) {
cv::Vec3b bgr = res.at<cv::Vec3b>(i, j);
(*p_data)[0][i][j] = bgr[2];
(*p_data)[1][i][j] = bgr[1];
(*p_data)[2][i][j] = bgr[0];
}
}
res.release();
this->TensorProcess(p_data, prnd);
(*dst_data) = img_;
}

private:
Expand All @@ -392,11 +381,13 @@ class ImageAugmenter {
mshadow::TensorContainer<cpu, 3> meanimg_;
/*! \brief temp space */
mshadow::TensorContainer<cpu, 3> img_;
#if MXNET_USE_OPENCV
// temporal space
cv::Mat temp_;
// rotation param
cv::Mat rotateM_;
// whether the mean file is ready
#endif
bool meanfile_ready_;
// parameters
ImageAugmentParam param_;
Expand Down
Loading