Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
fix(unittest): fix unit test for send recv
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Sep 25, 2019
1 parent 31f53bc commit 4314501
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions tests/test_raw_bytes_send.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import random
import unittest

Expand Down Expand Up @@ -42,12 +43,12 @@ def test_send_recv_raw_bytes(self):
for j in range(random.randint(10, 20)):
d = msg.request.index.docs.add()
d.raw_bytes = b'a' * random.randint(100, 1000)
raw_bytes = [d.raw_bytes for d in msg.request.index.docs]
raw_bytes = copy.deepcopy([d.raw_bytes for d in msg.request.index.docs])
c1.send_message(msg, squeeze_pb=True)
r_msg = c2.recv_message()
for d, r_d in zip(msg.request.index.docs, r_msg.request.index.docs):
for d, o_d, r_d in zip(msg.request.index.docs, raw_bytes, r_msg.request.index.docs):
self.assertEqual(d.raw_bytes, b'')
self.assertEqual(raw_bytes, r_d.raw_bytes)
self.assertEqual(o_d, r_d.raw_bytes)
print('.', end='')
print('checked %d docs' % len(msg.request.index.docs))

Expand Down Expand Up @@ -127,15 +128,15 @@ def test_benchmark2(self):

def test_benchmark3(self):
all_msgs = self.build_msgs()
all_msgs_bak = copy.deepcopy(all_msgs)

with ZmqClient(self.c1_args) as c1, ZmqClient(self.c2_args) as c2:
for m in all_msgs:
raw_bytes = [d.raw_bytes for d in m.request.index.docs]
for m, m1 in zip(all_msgs, all_msgs_bak):
c1.send_message(m, squeeze_pb=True)
r_m = c2.recv_message()
for d, r_d in zip(m.request.index.docs, r_m.request.index.docs):
for d, o_d, r_d in zip(m.request.index.docs, m1.request.index.docs, r_m.request.index.docs):
self.assertEqual(d.raw_bytes, b'')
self.assertEqual(raw_bytes, r_d.raw_bytes)
self.assertEqual(o_d.raw_bytes, r_d.raw_bytes)

def test_benchmark4(self):
all_msgs = self.build_msgs2()
Expand All @@ -154,7 +155,7 @@ def test_benchmark4(self):

def test_benchmark5(self):
all_msgs = self.build_msgs2()
all_msgs_bak = self.build_msgs2()
all_msgs_bak = copy.deepcopy(all_msgs)

with ZmqClient(self.c1_args) as c1, ZmqClient(self.c2_args) as c2:
with TimeContext('send->recv, squeeze_pb=True'):
Expand All @@ -165,4 +166,4 @@ def test_benchmark5(self):
for d, r_d in zip(m1.request.index.docs, r_m.request.index.docs):
for c, r_c in zip(d.chunks, r_d.chunks):
np.allclose(blob2array(c.embedding), blob2array(r_c.embedding))
np.allclose(blob2array(c.blob), blob2array(r_c.blob))
np.allclose(blob2array(c.blob), blob2array(r_c.blob))

0 comments on commit 4314501

Please sign in to comment.