From 0a5408ff3a5693240c5fe3caeffbed63af183ffa Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Thu, 13 Jul 2017 01:12:27 +0000 Subject: [PATCH] add data partition for libsvm iter --- src/io/iter_libsvm.cc | 16 ++++++++++++++-- tests/python/unittest/test_io.py | 20 ++++++++++++++------ 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/src/io/iter_libsvm.cc b/src/io/iter_libsvm.cc index 04dcf289a020..49c6f7fd08f4 100644 --- a/src/io/iter_libsvm.cc +++ b/src/io/iter_libsvm.cc @@ -23,6 +23,10 @@ struct LibSVMIterParam : public dmlc::Parameter { std::string label_libsvm; /*! \brief label shape */ TShape label_shape; + /*! \brief partition the data into multiple parts */ + int num_parts; + /*! \brief the index of the part will read*/ + int part_index; // declare parameters DMLC_DECLARE_PARAMETER(LibSVMIterParam) { DMLC_DECLARE_FIELD(data_libsvm) @@ -35,6 +39,10 @@ struct LibSVMIterParam : public dmlc::Parameter { index_t shape1[] = {1}; DMLC_DECLARE_FIELD(label_shape).set_default(TShape(shape1, shape1 + 1)) .describe("The shape of one label."); + DMLC_DECLARE_FIELD(num_parts).set_default(1) + .describe("partition the data into multiple parts"); + DMLC_DECLARE_FIELD(part_index).set_default(0) + .describe("the index of the part will read"); } }; @@ -47,11 +55,15 @@ class LibSVMIter: public SparseIIterator { virtual void Init(const std::vector >& kwargs) { param_.InitAllowUnknown(kwargs); CHECK_EQ(param_.data_shape.ndim(), 1) << "dimension of data_shape is expected to be 1"; + CHECK_GT(param_.num_parts, 0) << "number of parts should be positive"; + CHECK_GE(param_.part_index, 0) << "part index should be non-negative"; data_parser_.reset(dmlc::Parser::Create(param_.data_libsvm.c_str(), - 0, 1, "libsvm")); + param_.part_index, + param_.num_parts, "libsvm")); if (param_.label_libsvm != "NULL") { label_parser_.reset(dmlc::Parser::Create(param_.label_libsvm.c_str(), - 0, 1, "libsvm")); + param_.part_index, + param_.num_parts, "libsvm")); CHECK_GT(param_.label_shape.Size(), 1) << "label_shape is not expected to be (1,) when param_.label_libsvm is set."; } else { diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index 356afc19de5e..dc609cdb9abf 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -161,29 +161,37 @@ def check_libSVMIter_synthetic(): assert_almost_equal(data_train.getdata().asnumpy(), expected) i += 1 - def check_libSVMIter_news_metadata(): + def check_libSVMIter_news_data(): news_metadata = { 'name': 'news20.t', 'origin_name': 'news20.t.bz2', 'url': "http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/news20.t.bz2", - 'shape': 62060, + 'feature_dim': 62060, 'num_classes': 20, + 'num_examples': 3993, } + num_parts = 3 + batch_size = 128 + num_examples = news_metadata['num_examples'] data_dir = os.path.join(os.getcwd(), 'data') get_data(data_dir, news_metadata['name'], news_metadata['url'], news_metadata['origin_name']) path = os.path.join(data_dir, news_metadata['name']) - data_train = mx.io.LibSVMIter(data_libsvm=path, - data_shape=(news_metadata['shape'], ), - batch_size=512) + data_train = mx.io.LibSVMIter(data_libsvm=path, data_shape=(news_metadata['feature_dim'],), + batch_size=batch_size, num_parts=num_parts, part_index=0) + num_batches = 0 iterator = iter(data_train) for batch in iterator: # check the range of labels assert(np.sum(batch.label[0].asnumpy() > 20) == 0) assert(np.sum(batch.label[0].asnumpy() <= 0) == 0) + num_batches += 1 + import math + expected_num_batches = math.ceil(num_examples * 1.0 / batch_size / num_parts) + assert(num_batches == int(expected_num_batches)), (num_batches, expected_num_batches) check_libSVMIter_synthetic() - check_libSVMIter_news_metadata() + check_libSVMIter_news_data() if __name__ == "__main__": test_NDArrayIter()