forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_futures.py
134 lines (99 loc) · 3.54 KB
/
test_futures.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import threading
import time
import torch
from torch.futures import Future
from torch.testing._internal.common_utils import TestCase, TemporaryFileName
def add_one(fut):
    """Callback for ``Future.then``: return the value of *fut* plus one.

    Args:
        fut: an object exposing ``wait()``; in this module, the future a
            ``then`` callback is invoked with.

    Returns:
        ``fut.wait() + 1``.
    """
    result = fut.wait()
    return result + 1
class TestFuture(TestCase):
    """Unit tests for ``torch.futures.Future`` and the ``collect_all`` / ``wait_all`` helpers."""

    def test_wait(self):
        """A result set before ``wait()`` is what ``wait()`` returns."""
        f = Future()
        f.set_result(torch.ones(2, 2))
        self.assertEqual(f.wait(), torch.ones(2, 2))

    def test_wait_multi_thread(self):
        """``wait()`` returns the value set later by another thread."""
        def slow_set_future(fut, value):
            # Delay so the main thread is likely already blocked in wait().
            time.sleep(0.5)
            fut.set_result(value)

        f = Future()
        t = threading.Thread(target=slow_set_future, args=(f, torch.ones(2, 2)))
        t.start()
        try:
            self.assertEqual(f.wait(), torch.ones(2, 2))
        finally:
            # Always reap the helper thread, even if the assertion fails.
            t.join()

    def test_mark_future_twice(self):
        """Setting a result on an already-completed future raises RuntimeError."""
        fut = Future()
        fut.set_result(1)
        with self.assertRaisesRegex(
            RuntimeError,
            "Future can only be marked completed once"
        ):
            fut.set_result(1)

    def test_pickle_future(self):
        """Serializing a Future with ``torch.save`` raises RuntimeError."""
        fut = Future()
        err_msg = "Can not pickle torch.futures.Future"
        with TemporaryFileName() as fname:
            with self.assertRaisesRegex(RuntimeError, err_msg):
                torch.save(fut, fname)

    def test_then(self):
        """A ``then`` callback's return value becomes the chained future's result."""
        fut = Future()
        then_fut = fut.then(lambda x: x.wait() + 1)
        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)

    def test_chained_then(self):
        """Twenty chained ``add_one`` callbacks each add one to the previous value."""
        fut = Future()
        futs = []
        last_fut = fut
        for _ in range(20):
            last_fut = last_fut.then(add_one)
            futs.append(last_fut)

        fut.set_result(torch.ones(2, 2))

        # The i-th chained future (0-based) has had i + 1 increments applied.
        for i, chained in enumerate(futs):
            self.assertEqual(chained.wait(), torch.ones(2, 2) + i + 1)

    def _test_error(self, cb, errMsg):
        """Chain *cb* onto a future completed with ``5`` and check that it fails.

        Args:
            cb: callback passed to ``Future.then``.
            errMsg: regex expected to match the ``RuntimeError`` raised by
                ``wait()`` on the chained future.
        """
        fut = Future()
        then_fut = fut.then(cb)

        fut.set_result(5)
        # The source future still yields its own value despite the failing callback.
        self.assertEqual(5, fut.wait())
        with self.assertRaisesRegex(RuntimeError, errMsg):
            then_fut.wait()

    def test_then_wrong_arg(self):
        """A callback that treats its argument as a plain value fails."""
        def wrong_arg(tensor):
            # The callback is handed the Future itself, so ``+ 1`` is unsupported.
            return tensor + 1

        self._test_error(wrong_arg, "unsupported operand type.*Future.*int")

    def test_then_no_arg(self):
        """A callback that accepts no parameters fails when invoked with one argument."""
        def no_arg():
            return True

        self._test_error(no_arg, "takes 0 positional arguments but 1 was given")

    def test_then_raise(self):
        """An exception raised inside a callback surfaces from the chained future."""
        def raise_value_error(fut):
            raise ValueError("Expected error")

        self._test_error(raise_value_error, "Expected error")

    def test_collect_all(self):
        """``collect_all`` completes once all inputs complete, yielding them in input order."""
        fut1 = Future()
        fut2 = Future()
        fut_all = torch.futures.collect_all([fut1, fut2])

        def slow_in_thread(fut, value):
            time.sleep(0.1)
            fut.set_result(value)

        t = threading.Thread(target=slow_in_thread, args=(fut1, 1))
        fut2.set_result(2)
        t.start()
        try:
            res = fut_all.wait()
            self.assertEqual(res[0].wait(), 1)
            self.assertEqual(res[1].wait(), 2)
        finally:
            # Always reap the helper thread, even if an assertion fails.
            t.join()

    def test_wait_all(self):
        """``wait_all`` returns the futures' values in order and propagates errors."""
        fut1 = Future()
        fut2 = Future()

        # All futures succeed: their values come back as a list in input order.
        fut1.set_result(1)
        fut2.set_result(2)
        res = torch.futures.wait_all([fut1, fut2])
        self.assertEqual(res, [1, 2])

        # One future fails: its error is raised out of wait_all.
        def raise_in_fut(fut):
            raise ValueError("Expected error")

        fut3 = fut1.then(raise_in_fut)
        with self.assertRaisesRegex(RuntimeError, "Expected error"):
            torch.futures.wait_all([fut3, fut2])