Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add h5py support to NDArrayIter (#6790)
Browse files Browse the repository at this point in the history
* Support h5py groups as input to NDArrayIter

* Support shuffling indices for h5py data in NDArrayIter

* Make h5py optional

* Install h5py on linux based CI systems

Tests are not run on Windows. I couldn't find the Windows CI system
configuration or a place to declare the h5py test dependency for Windows.
  • Loading branch information
leezu authored and piiswrong committed Jul 18, 2017
1 parent add7e43 commit bd5df7c
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 29 deletions.
2 changes: 2 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ addons:
- python3-numpy
- python3-dev
- python3-nose
- python-h5py
- python3-h5py
- graphviz
- libmouse-perl
- pdl
Expand Down
74 changes: 53 additions & 21 deletions python/mxnet/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import ctypes
import logging
import threading
try:
import h5py
except ImportError:
h5py = None
import numpy as np
from .base import _LIB
from .base import c_array, c_str, mx_uint, py_str
Expand Down Expand Up @@ -465,7 +469,8 @@ def _init_data(data, allow_empty, default_name):
if data is None:
data = []

if isinstance(data, (np.ndarray, NDArray)):
if isinstance(data, (np.ndarray, NDArray, h5py.Dataset)
if h5py else (np.ndarray, NDArray)):
data = [data]
if isinstance(data, list):
if not allow_empty:
Expand All @@ -476,20 +481,20 @@ def _init_data(data, allow_empty, default_name):
data = OrderedDict( # pylint: disable=redefined-variable-type
[('_%d_%s' % (i, default_name), d) for i, d in enumerate(data)])
if not isinstance(data, dict):
raise TypeError("Input must be NDArray, numpy.ndarray, " + \
raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " + \
"a list of them or dict with them as values")
for k, v in data.items():
if not isinstance(v, NDArray):
if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray):
try:
data[k] = array(v)
except:
raise TypeError(("Invalid type '%s' for %s, " % (type(v), k)) + \
"should be NDArray or numpy.ndarray")
"should be NDArray, numpy.ndarray or h5py.Dataset")

return list(data.items())

class NDArrayIter(DataIter):
"""Returns an iterator for ``mx.nd.NDArray`` or ``numpy.ndarray``.
"""Returns an iterator for ``mx.nd.NDArray``, ``numpy.ndarray`` or ``h5py.Dataset``.
Example usage:
----------
Expand Down Expand Up @@ -562,6 +567,7 @@ class NDArrayIter(DataIter):
Batch size of data.
shuffle: bool, optional
Whether to shuffle the data.
Only supported if no h5py.Dataset inputs are used.
last_batch_handle : str, optional
How to handle the last batch. This parameter can be 'pad', 'discard' or
'roll_over'. 'roll_over' is intended for training and can cause problems
Expand All @@ -579,30 +585,29 @@ def __init__(self, data, label=None, batch_size=1, shuffle=False,
self.data = _init_data(data, allow_empty=False, default_name=data_name)
self.label = _init_data(label, allow_empty=True, default_name=label_name)

self.idx = np.arange(self.data[0][1].shape[0])
# shuffle data
if shuffle:
idx = np.arange(self.data[0][1].shape[0])
np.random.shuffle(idx)
self.data = [(k, array(v.asnumpy()[idx], v.context)) for k, v in self.data]
self.label = [(k, array(v.asnumpy()[idx], v.context)) for k, v in self.label]
np.random.shuffle(self.idx)
self.data = [(k, array(v.asnumpy()[self.idx], v.context))
if not (isinstance(v, h5py.Dataset)
if h5py else False) else (k, v)
for k, v in self.data]
self.label = [(k, array(v.asnumpy()[self.idx], v.context))
if not (isinstance(v, h5py.Dataset)
if h5py else False) else (k, v)
for k, v in self.label]

# batching
if last_batch_handle == 'discard':
new_n = self.data[0][1].shape[0] - self.data[0][1].shape[0] % batch_size
data_dict = OrderedDict(self.data)
label_dict = OrderedDict(self.label)
for k, _ in self.data:
data_dict[k] = data_dict[k][:new_n]
for k, _ in self.label:
label_dict[k] = label_dict[k][:new_n]
self.data = data_dict.items()
self.label = label_dict.items()
self.idx = self.idx[:new_n]

self.data_list = [x[1] for x in self.data] + [x[1] for x in self.label]
self.num_source = len(self.data_list)
self.num_data = self.data_list[0].shape[0]
self.num_data = self.idx.shape[0]
assert self.num_data >= batch_size, \
"batch_size need to be smaller than data size."
"batch_size needs to be smaller than data size."
self.cursor = -batch_size
self.batch_size = batch_size
self.last_batch_handle = last_batch_handle
Expand Down Expand Up @@ -648,10 +653,37 @@ def _getdata(self, data_source):
"""Load data from underlying arrays, internal use only."""
assert(self.cursor < self.num_data), "DataIter needs reset."
if self.cursor + self.batch_size <= self.num_data:
return [x[1][self.cursor:self.cursor+self.batch_size] for x in data_source]
return [
# np.ndarray or NDArray case
x[1][self.cursor:self.cursor + self.batch_size]
if isinstance(x[1], (np.ndarray, NDArray)) else
# h5py (only supports indices in increasing order)
array(x[1][sorted(self.idx[
self.cursor:self.cursor + self.batch_size])][[
list(self.idx[self.cursor:
self.cursor + self.batch_size]).index(i)
for i in sorted(self.idx[
self.cursor:self.cursor + self.batch_size])
]]) for x in data_source
]
else:
pad = self.batch_size - self.num_data + self.cursor
return [concatenate([x[1][self.cursor:], x[1][:pad]]) for x in data_source]
return [
# np.ndarray or NDArray case
concatenate([x[1][self.cursor:], x[1][:pad]])
if isinstance(x[1], (np.ndarray, NDArray)) else
# h5py (only supports indices in increasing order)
concatenate([
array(x[1][sorted(self.idx[self.cursor:])][[
list(self.idx[self.cursor:]).index(i)
for i in sorted(self.idx[self.cursor:])
]]),
array(x[1][sorted(self.idx[:pad])][[
list(self.idx[:pad]).index(i)
for i in sorted(self.idx[:pad])
]])
]) for x in data_source
]

def getdata(self):
    """Return the data arrays of the current batch.

    Thin wrapper that delegates to ``_getdata`` with the data source
    (as opposed to ``getlabel``, which passes the label source).
    """
    return self._getdata(self.data)
Expand Down
4 changes: 2 additions & 2 deletions tests/ci_build/install/ubuntu_install_python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ apt-get update && apt-get install -y python-dev python3-dev
# the version of pip shipped with Ubuntu may be too old, so install a recent version here
cd /tmp && wget https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && python2 get-pip.py

pip2 install nose pylint numpy nose-timer requests
pip3 install nose pylint numpy nose-timer requests
pip2 install nose pylint numpy nose-timer requests h5py
pip3 install nose pylint numpy nose-timer requests h5py
62 changes: 56 additions & 6 deletions tests/python/unittest/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import os, gzip
import pickle as pickle
import time
try:
import h5py
except ImportError:
h5py = None
import sys
from common import get_data

Expand Down Expand Up @@ -63,17 +67,17 @@ def test_Cifar10Rec():
assert(labelcount[i] == 5000)

def test_NDArrayIter():
datas = np.ones([1000, 2, 2])
labels = np.ones([1000, 1])
data = np.ones([1000, 2, 2])
label = np.ones([1000, 1])
for i in range(1000):
datas[i] = i / 100
labels[i] = i / 100
dataiter = mx.io.NDArrayIter(datas, labels, 128, True, last_batch_handle='pad')
data[i] = i / 100
label[i] = i / 100
dataiter = mx.io.NDArrayIter(data, label, 128, True, last_batch_handle='pad')
batchidx = 0
for batch in dataiter:
batchidx += 1
assert(batchidx == 8)
dataiter = mx.io.NDArrayIter(datas, labels, 128, False, last_batch_handle='pad')
dataiter = mx.io.NDArrayIter(data, label, 128, False, last_batch_handle='pad')
batchidx = 0
labelcount = [0 for i in range(10)]
for batch in dataiter:
Expand All @@ -88,7 +92,53 @@ def test_NDArrayIter():
else:
assert(labelcount[i] == 100)

def test_NDArrayIter_h5py():
    """Exercise NDArrayIter with h5py.Dataset inputs (no-op when h5py is absent).

    Mirrors test_NDArrayIter: with shuffle and 'pad' the 1000 rows must yield
    8 batches of 128; without shuffle, data and label must stay aligned and
    the per-label counts must match (label 0 gets the 24 padded extras).
    """
    if not h5py:
        return

    data = np.ones([1000, 2, 2])
    label = np.ones([1000, 1])
    for i in range(1000):
        data[i] = i / 100
        label[i] = i / 100

    fname = "ndarraytest.h5"
    try:
        # Mode "w" truncates any stale file left by a previous failed run, so
        # no os.remove() dance is needed beforehand.  An explicit mode is also
        # required by modern h5py, where the implicit default was removed.
        with h5py.File(fname, "w") as f:
            f.create_dataset("data", data=data)
            f.create_dataset("label", data=label)

            # shuffle + 'pad': ceil(1000 / 128) == 8 batches
            dataiter = mx.io.NDArrayIter(f["data"], f["label"], 128, True, last_batch_handle='pad')
            batchidx = 0
            for batch in dataiter:
                batchidx += 1
            assert(batchidx == 8)

            # no shuffle: verify data/label alignment and tally label values
            dataiter = mx.io.NDArrayIter(f["data"], f["label"], 128, False, last_batch_handle='pad')
            labelcount = [0 for i in range(10)]
            for batch in dataiter:
                label = batch.label[0].asnumpy().flatten()
                assert((batch.data[0].asnumpy()[:, 0, 0] == label).all())
                for i in range(label.shape[0]):
                    labelcount[int(label[i])] += 1
    finally:
        # Remove the scratch file even when an assertion above fails, so
        # reruns start from a clean state.
        try:
            os.remove(fname)
        except OSError:
            pass

    for i in range(10):
        if i == 0:
            # 8 * 128 = 1024 samples: 'pad' wraps 24 extra label-0 rows
            assert(labelcount[i] == 124)
        else:
            assert(labelcount[i] == 100)


if __name__ == "__main__":
    # Run the I/O unit tests when executed directly as a script.
    test_NDArrayIter()
    if h5py is not None:
        # Only meaningful when the optional h5py dependency is installed.
        test_NDArrayIter_h5py()
    test_MNISTIter()
    test_Cifar10Rec()

0 comments on commit bd5df7c

Please sign in to comment.