
Commit 4fb4a59
move nested member function to outer scope, preventing pickler error
zhreshold committed Jul 12, 2018
1 parent 492917f commit 4fb4a59
Showing 1 changed file with 54 additions and 54 deletions.
108 changes: 54 additions & 54 deletions tests/python/unittest/test_gluon_data.py
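Context for the change: with num_workers > 0, Gluon's DataLoader dispatches batches to worker processes, and the batchify_fn travels to those workers via pickle. A bound method of a class defined in a local scope cannot be pickled, whereas a plain function at outer scope can, which is evidently the pickler error this commit prevents. A minimal sketch of that difference, independent of MXNet (make_dataset is a hypothetical stand-in for the test's nested class):

import pickle

def make_dataset():
    # stand-in for a class defined in a local scope, whose member
    # functions the old test passed as batchify_fn
    class Dummy:
        def batchify(self, data):
            return data
    return Dummy()

try:
    # pickling a bound method requires pickling its class by qualified
    # name; a local class has none that pickle can import, so this fails
    pickle.dumps(make_dataset().batchify)
except (pickle.PicklingError, AttributeError, TypeError) as err:
    print('bound method of a local class is not picklable:', err)

# a function at outer (module) scope is pickled by reference
# and round-trips to the very same object
def batchify(data):
    return data

assert pickle.loads(pickle.dumps(batchify)) is batchify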
@@ -153,69 +153,69 @@ def __getitem__(self, idx):
     def __len__(self):
         return 50
 
-    def batchify_list(self, data):
-        """
-        return list of ndarray without stack/concat/pad
-        """
-        if isinstance(data, (tuple, list)):
-            return list(data)
-        if isinstance(data, mx.nd.NDArray):
-            return [data]
-        return data
-
-    def batchify(self, data):
-        """
-        Collate data into batch. Use shared memory for stacking.
-        :param data: a list of array, with layout of 'NTC'.
-        :return: x and x's unpadded lengths if labels are not supplied,
-            otherwise x, x's unpadded lengths, y and y's unpadded lengths.
-        """
-
-        # input layout is NTC
-        keys, inputs, labels = [item[0] for item in data], [item[1] for item in data], \
-            [item[2] for item in data]
-
-        if len(data) > 1:
-            max_data_len = max([seq.shape[0] for seq in inputs])
-            max_labels_len = 0 if not labels else max([seq.shape[0] for seq in labels])
-        else:
-            max_data_len = inputs[0].shape[0]
-            max_labels_len = 0 if not labels else labels[0].shape[0]
-
-        x_lens = [item.shape[0] for item in inputs]
-        y_lens = [item.shape[0] for item in labels]
-
-        for i, seq in enumerate(inputs):
-            pad_len = max_data_len - seq.shape[0]
-            inputs[i] = np.pad(seq, ((0, pad_len), (0, 0)), 'constant', constant_values=0)
-            labels[i] = np.pad(labels[i], (0, max_labels_len - labels[i].shape[0]),
-                               'constant', constant_values=-1)
-
-        inputs = np.asarray(inputs, dtype=np.float32)
-        if labels is not None:
-            labels = np.asarray(labels, dtype=np.float32)
-        inputs = inputs.transpose((1, 0, 2))
-        labels = labels.transpose((1, 0))
-
-        return (nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)),
-                nd.array(x_lens, ctx=context.Context('cpu_shared', 0))) \
-            if labels is None else (
-                nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)),
-                nd.array(x_lens, ctx=context.Context('cpu_shared', 0)),
-                nd.array(labels, dtype=labels.dtype, ctx=context.Context('cpu_shared', 0)),
-                nd.array(y_lens, ctx=context.Context('cpu_shared', 0)))
+def batchify_list(data):
+    """
+    return list of ndarray without stack/concat/pad
+    """
+    if isinstance(data, (tuple, list)):
+        return list(data)
+    if isinstance(data, mx.nd.NDArray):
+        return [data]
+    return data
+
+def batchify(data):
+    """
+    Collate data into batch. Use shared memory for stacking.
+    :param data: a list of array, with layout of 'NTC'.
+    :return: x and x's unpadded lengths if labels are not supplied,
+        otherwise x, x's unpadded lengths, y and y's unpadded lengths.
+    """
+
+    # input layout is NTC
+    keys, inputs, labels = [item[0] for item in data], [item[1] for item in data], \
+        [item[2] for item in data]
+
+    if len(data) > 1:
+        max_data_len = max([seq.shape[0] for seq in inputs])
+        max_labels_len = 0 if not labels else max([seq.shape[0] for seq in labels])
+    else:
+        max_data_len = inputs[0].shape[0]
+        max_labels_len = 0 if not labels else labels[0].shape[0]
+
+    x_lens = [item.shape[0] for item in inputs]
+    y_lens = [item.shape[0] for item in labels]
+
+    for i, seq in enumerate(inputs):
+        pad_len = max_data_len - seq.shape[0]
+        inputs[i] = np.pad(seq, ((0, pad_len), (0, 0)), 'constant', constant_values=0)
+        labels[i] = np.pad(labels[i], (0, max_labels_len - labels[i].shape[0]),
+                           'constant', constant_values=-1)
+
+    inputs = np.asarray(inputs, dtype=np.float32)
+    if labels is not None:
+        labels = np.asarray(labels, dtype=np.float32)
+    inputs = inputs.transpose((1, 0, 2))
+    labels = labels.transpose((1, 0))
+
+    return (nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)),
+            nd.array(x_lens, ctx=context.Context('cpu_shared', 0))) \
+        if labels is None else (
+            nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)),
+            nd.array(x_lens, ctx=context.Context('cpu_shared', 0)),
+            nd.array(labels, dtype=labels.dtype, ctx=context.Context('cpu_shared', 0)),
+            nd.array(y_lens, ctx=context.Context('cpu_shared', 0)))
 
 # sanity tests
 data = Dummy(False)
-loader = DataLoader(data, batch_size=40, batchify_fn=data.batchify, num_workers=2)
+loader = DataLoader(data, batch_size=40, batchify_fn=batchify, num_workers=2)
 for epoch in range(1):
     for i, data in enumerate(loader):
         pass
 
 # random shape
 data = Dummy(True)
-loader = DataLoader(data, batch_size=40, batchify_fn=data.batchify_list, num_workers=2)
+loader = DataLoader(data, batch_size=40, batchify_fn=batchify_list, num_workers=2)
 for epoch in range(1):
     for i, data in enumerate(loader):
         pass
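For reference, a hedged usage sketch (not part of the commit) of what the outer-scope batchify above returns for a toy two-sample batch in NTC layout; the toy arrays are invented for illustration:

import numpy as np

toy = [('a', np.ones((3, 4), dtype=np.float32), np.array([1.0])),
       ('b', np.ones((5, 4), dtype=np.float32), np.array([1.0, 2.0]))]
x, x_lens, y, y_lens = batchify(toy)
print(x.shape)           # (5, 2, 4): both sequences padded to length 5, then NTC -> TNC
print(x_lens.asnumpy())  # [3. 5.]: the original, unpadded sequence lengths
print(y.shape)           # (2, 2): labels padded with -1, then transposed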
