ps_server-basic.py
import tensorflow as tf
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = ""
# A three-worker variant of the same cluster, kept for reference:
# cluster = tf.train.ClusterSpec({
#     "worker": [
#         # workers' port numbers start from 9900
#         # worker_0 is addressed as /job:worker/task:0
#         "127.0.0.1:9900",
#         # worker_1 is addressed as /job:worker/task:1
#         "127.0.0.1:9901",
#         # worker_2 is addressed as /job:worker/task:2
#         "127.0.0.1:9902"
#     ],
#     "ps": [
#         # the ps listens on port 9910
#         # it is addressed as /job:ps/task:0
#         "127.0.0.1:9910"
#     ]
# })
cluster = tf.train.ClusterSpec({
    "worker": [
        # workers' port numbers start from 9900
        # worker_0 is addressed as /job:worker/task:0
        "127.0.0.1:9900",
        # worker_1 is addressed as /job:worker/task:1
        "127.0.0.1:9901",
    ],
    "ps": [
        # the ps listens on port 9910
        # it is addressed as /job:ps/task:0
        "127.0.0.1:9910"
    ]
})
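# Note: every process in the cluster builds this identical ClusterSpec and
# then starts a tf.train.Server for its own (job_name, task_index) pair;
# the spec only describes the cluster, it does not launch anything itself.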
# Toggle: True runs this process as the parameter server,
# False runs it as worker 0.
isps = True
if isps:
    # Define the parameter server (ps). join() blocks forever,
    # serving variable reads/writes from the workers.
    server = tf.train.Server(cluster, job_name='ps', task_index=0)
    server.join()
else:
    server = tf.train.Server(cluster, job_name='worker', task_index=0)
    # replica_device_setter pins variables to the ps job and ops to this
    # worker; worker_device must be a device string, not the Server object.
    with tf.device(tf.train.replica_device_setter(
            worker_device='/job:worker/task:0', cluster=cluster)):
        w = tf.get_variable('w', (2, 2), tf.float32,
                            initializer=tf.constant_initializer(2))
        b = tf.get_variable('b', (2, 2), tf.float32,
                            initializer=tf.constant_initializer(5))
        addwb = w + b
        mutwb = w * b
        divwb = w / b

    saver = tf.train.Saver()
    # merge_all() returns None here (no summaries were defined), which
    # simply disables the Supervisor's summary service.
    summary_op = tf.summary.merge_all()
    init_op = tf.global_variables_initializer()
    sv = tf.train.Supervisor(init_op=init_op, summary_op=summary_op, saver=saver)
    with sv.managed_session(server.target) as sess:
        # Loop until the Supervisor requests a stop.
        while not sv.should_stop():
            print(sess.run([addwb, mutwb, divwb]))
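
# A minimal sketch (not part of the original script) of how the same file
# could pick its role from command-line flags instead of the hard-coded
# `isps` switch. The flag names --job_name and --task_index are assumptions
# for illustration, not something this repo defines; tf.app.flags is the
# standard TF 1.x flag helper. You would then launch one process per task:
#
#   python ps_server-basic.py --job_name=ps --task_index=0
#   python ps_server-basic.py --job_name=worker --task_index=0
#   python ps_server-basic.py --job_name=worker --task_index=1
#
# flags = tf.app.flags
# flags.DEFINE_string('job_name', 'ps', "Either 'ps' or 'worker'.")
# flags.DEFINE_integer('task_index', 0, "Index of the task within its job.")
# FLAGS = flags.FLAGS
#
# server = tf.train.Server(cluster,
#                          job_name=FLAGS.job_name,
#                          task_index=FLAGS.task_index)
# if FLAGS.job_name == 'ps':
#     server.join()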