-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path testPythonNoBug.py
49 lines (42 loc) · 1.55 KB
/
testPythonNoBug.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import mxnet as mx
# Print the installed MXNet version before the remaining imports run.
print(mx.__version__)
import cv2
import numpy as np
from collections import namedtuple
import time
def get_image(fname, size=(224, 224)):
    """Load an image from disk and preprocess it for the classifier.

    The image is read with OpenCV, converted from BGR to RGB, resized, and
    reordered from (height, width, channel) to channel-first
    (channel, height, width). No batch axis is added; the visible caller
    (predict) tiles the result into a batch itself.

    Args:
        fname: Path of the image file to read.
        size: Target size passed to ``cv2.resize`` as ``(width, height)``;
            defaults to 224 x 224.

    Returns:
        The preprocessed array, or ``None`` if OpenCV could not read the file.
    """
    bgr = cv2.imread(fname)
    # cv2.imread reports failure by returning None instead of raising, so the
    # check must happen before cvtColor, which would otherwise fail opaquely.
    if bgr is None:
        return None
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, size)
    # (H, W, C) -> (C, H, W); identical to swapaxes(0, 2) followed by
    # swapaxes(1, 2).
    img = np.transpose(img, (2, 0, 1))
    print(img.shape)
    return img
# Load the test image once at import time; predict() reads this global array.
img = get_image('tabby.tiff')
if img is None:
    # get_image() may return None; fail here with a clear message rather than
    # with an obscure error inside the first forward pass.
    raise FileNotFoundError("could not read test image 'tabby.tiff'")
# Lightweight container handed to mod.forward(); the input arrays go in `.data`.
Batch = namedtuple('Batch', ['data'])
def predict(fname, batchSize):
    """Run one forward pass on a batch of copies of the test image and check it.

    The module-level ``img`` is replicated ``batchSize`` times along a new
    leading axis and pushed through the module-level ``mod``. Every row of the
    first output is then asserted to have 1000 entries and to peak at class
    281 or 282.

    NOTE(review): ``fname`` is currently ignored; the image always comes from
    the global ``img``, not from this argument.
    """
    batch = np.tile(img, (batchSize, 1, 1, 1))
    # Compute the predicted probabilities for the whole batch.
    mod.forward(Batch([mx.nd.array(batch)]))
    outputs = mod.get_outputs()[0].asnumpy()
    for row in outputs:
        assert row.shape[0] == 1000
        top_class = np.argmax(row)
        # The original author's comment marks 281/282 as the cat labels.
        assert top_class in (281, 282)
# --- Model setup ------------------------------------------------------------
max_batch_size = 32
ctx = mx.cpu(0)
# Load epoch 0 of the 'squeezenet-v1.1' checkpoint (relative to the cwd).
sym, arg_params, aux_params = mx.model.load_checkpoint('squeezenet-v1.1', 0)
mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
# Bind for inference at the largest batch size; the loop below feeds batches
# of 1..max_batch_size through this same module.
# NOTE(review): mod._label_shapes is a private attribute; with
# label_names=None it is presumably None -- confirm and pass None directly.
mod.bind(for_training=False, data_shapes=[('data', (max_batch_size, 3, 224, 224))],
         label_shapes=mod._label_shapes)
mod.set_params(arg_params, aux_params, allow_missing=True)

# --- Timing loop ------------------------------------------------------------
num_iterations = 1000
print_every = 1  # report every iteration; raise this to thin the output
for i in range(num_iterations):
    # randint's `high` is exclusive, so this is uniform in [1, max_batch_size].
    random_batch_size = np.random.randint(low=1, high=max_batch_size + 1)
    # perf_counter() is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump if the system time is adjusted mid-measurement.
    start = time.perf_counter()
    predict('tabby.tiff', random_batch_size)
    elapsed_ms = (time.perf_counter() - start) * 1000
    if i % print_every == 0:
        print("Batch size {} Pred time {} ms".format(random_batch_size, elapsed_ms))