Skip to content

Commit

Permalink
Updating test code to follow single iterator constraint
Browse files Browse the repository at this point in the history
ghstack-source-id: 764bbd755b2471c51c6f8a3707e8943ff9e9642f
Pull Request resolved: #386
  • Loading branch information
NivekT committed May 19, 2022
1 parent 7555779 commit b4470eb
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 18 deletions.
34 changes: 20 additions & 14 deletions test/test_iterdatapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def test_iter_key_zipper_iterdatapipe(self) -> None:

source_dp = IterableWrapper(range(10))
ref_dp = IterableWrapper(range(20))
ref_dp2 = IterableWrapper(range(20))

# Functional Test: Output should be a zip list of tuple
zip_dp = source_dp.zip_with_iter(
Expand All @@ -114,7 +115,7 @@ def test_iter_key_zipper_iterdatapipe(self) -> None:

# Functional Test: keep_key=True, and key should show up as the first element
zip_dp_w_key = source_dp.zip_with_iter(
ref_datapipe=ref_dp, key_fn=lambda x: x, ref_key_fn=lambda x: x, keep_key=True, buffer_size=10
ref_datapipe=ref_dp2, key_fn=lambda x: x, ref_key_fn=lambda x: x, keep_key=True, buffer_size=10
)
self.assertEqual([(i, (i, i)) for i in range(10)], list(zip_dp_w_key))

Expand Down Expand Up @@ -145,13 +146,13 @@ def merge_to_string(item1, item2):

# Without a custom merge function, there will be nested tuples
zip_dp2 = zip_dp.zip_with_iter(
ref_datapipe=ref_dp, key_fn=lambda x: x[0], ref_key_fn=lambda x: x, keep_key=False, buffer_size=100
ref_datapipe=ref_dp2, key_fn=lambda x: x[0], ref_key_fn=lambda x: x, keep_key=False, buffer_size=100
)
self.assertEqual([((i, i), i) for i in range(10)], list(zip_dp2))

# With a custom merge function, nesting can be prevented
zip_dp2_w_merge = zip_dp.zip_with_iter(
ref_datapipe=ref_dp,
ref_datapipe=ref_dp2,
key_fn=lambda x: x[0],
ref_key_fn=lambda x: x,
keep_key=False,
Expand Down Expand Up @@ -524,10 +525,11 @@ def test_sample_multiplexer_iterdatapipe(self) -> None:

def test_in_batch_shuffler_iterdatapipe(self) -> None:
source_dp = IterableWrapper(range(10)).batch(3)
source_dp2 = IterableWrapper(range(10)).batch(3)

# Functional Test: drop last reduces length
filtered_dp = source_dp.in_batch_shuffle()
for ret_batch, exp_batch in zip(filtered_dp, source_dp):
for ret_batch, exp_batch in zip(filtered_dp, source_dp2):
ret_batch.sort()
self.assertEqual(ret_batch, exp_batch)

Expand Down Expand Up @@ -762,28 +764,32 @@ def test_unzipper_iterdatapipe(self):
with self.assertRaises(BufferError):
list(dp2)

# Reset Test: reset the DataPipe after reading part of it
# Reset Test: DataPipe resets when a new iterator is created, even if this datapipe hasn't been read
dp1, dp2 = source_dp.unzip(sequence_length=2)
i1, i2 = iter(dp1), iter(dp2)
_ = iter(dp1)
output2 = []
for i, n2 in enumerate(i2):
output2.append(n2)
if i == 4:
i1 = iter(dp1) # Doesn't reset because i1 hasn't been read
self.assertEqual(list(range(10, 20)), output2)
with self.assertRaisesRegex(RuntimeError, r"iterator has been invalidated"):
for i, n2 in enumerate(dp2):
output2.append(n2)
if i == 4:
_ = iter(dp1) # This will reset all child DataPipes
self.assertEqual(list(range(10, 15)), output2)

# Reset Test: DataPipe resets when part of it has been read
dp1, dp2 = source_dp.unzip(sequence_length=2)
i1, i2 = iter(dp1), iter(dp2)
output1, output2 = [], []
for i, (n1, n2) in enumerate(zip(i1, i2)):
for i, (n1, n2) in enumerate(zip(dp1, dp2)):
output1.append(n1)
output2.append(n2)
if i == 4:
with warnings.catch_warnings(record=True) as wa:
i1 = iter(dp1) # Resets all child DataPipes
_ = iter(dp1) # Resets all child DataPipes
self.assertEqual(len(wa), 1)
self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
break
for i, (n1, n2) in enumerate(zip(dp1, dp2)):
output1.append(n1)
output2.append(n2)
self.assertEqual(list(range(5)) + list(range(10)), output1)
self.assertEqual(list(range(10, 15)) + list(range(10, 20)), output2)

Expand Down
11 changes: 7 additions & 4 deletions test/test_local_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,14 +231,16 @@ def fill_hash_dict():
datapipe2 = FileOpener(datapipe1, mode="b")
hash_check_dp = HashChecker(datapipe2, hash_dict)

expected_res = list(datapipe2)

# Functional Test: Ensure the DataPipe values are unchanged if the hashes are the same
for (expected_path, expected_stream), (actual_path, actual_stream) in zip(datapipe2, hash_check_dp):
for (expected_path, expected_stream), (actual_path, actual_stream) in zip(expected_res, hash_check_dp):
self.assertEqual(expected_path, actual_path)
self.assertEqual(expected_stream.read(), actual_stream.read())

# Functional Test: Ensure the rewind option works, and the stream is empty when there is no rewind
hash_check_dp_no_reset = HashChecker(datapipe2, hash_dict, rewind=False)
for (expected_path, _), (actual_path, actual_stream) in zip(datapipe2, hash_check_dp_no_reset):
for (expected_path, _), (actual_path, actual_stream) in zip(expected_res, hash_check_dp_no_reset):
self.assertEqual(expected_path, actual_path)
self.assertEqual(b"", actual_stream.read())

Expand Down Expand Up @@ -458,7 +460,7 @@ def test_xz_archive_reader_iterdatapipe(self):
self._unordered_compressed_files_comparison_helper(self.temp_files, res_after_reset)

# Reset Test: Ensure the order is consistent between iterations
for r1, r2 in zip(xz_loader_dp, xz_loader_dp):
for r1, r2 in zip(list(xz_loader_dp), list(xz_loader_dp)):
self.assertEqual(r1[0], r2[0])

# __len__ Test: doesn't have valid length
Expand Down Expand Up @@ -497,7 +499,8 @@ def test_bz2_archive_reader_iterdatapipe(self):
self._unordered_compressed_files_comparison_helper(self.temp_files, res_after_reset)

# Reset Test: Ensure the order is consistent between iterations
for r1, r2 in zip(bz2_loader_dp, bz2_loader_dp):

for r1, r2 in zip(list(bz2_loader_dp), list(bz2_loader_dp)):
self.assertEqual(r1[0], r2[0])

# __len__ Test: doesn't have valid length
Expand Down
3 changes: 3 additions & 0 deletions test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def _serialization_test_for_single_dp(self, dp, use_dill, is_dataframe=False):
_ = next(it)
test_helper_fn(dp, use_dill)
# 3. Testing for serialization after DataPipe is fully read
it = iter(dp)
_ = list(it)
test_helper_fn(dp, use_dill)

Expand All @@ -146,10 +147,12 @@ def _serialization_test_for_dp_with_children(self, dp1, dp2, use_dill):
self._serialization_test_helper(dp2, use_dill=use_dill)
# 2.5. Testing for serialization after one child DataPipe is fully read
# (Only for DataPipes with children DataPipes)
it1 = iter(dp1)
_ = list(it1) # fully read one child
self._serialization_test_helper(dp1, use_dill=use_dill)
self._serialization_test_helper(dp2, use_dill=use_dill)
# 3. Testing for serialization after DataPipe is fully read
it2 = iter(dp2)
_ = list(it2) # fully read the other child
self._serialization_test_helper(dp1, use_dill=use_dill)
self._serialization_test_helper(dp2, use_dill=use_dill)
Expand Down

0 comments on commit b4470eb

Please sign in to comment.