Skip to content

Commit

Permalink
update demo
Browse files Browse the repository at this point in the history
  • Loading branch information
longcw committed Aug 30, 2017
1 parent 8f5a037 commit 303e512
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,14 @@

# pytorch net
model = torchvision.models.inception_v3(pretrained=True, transform_input=False)

model.eval()
if test_mod:
model = model.cuda()

# random input
image = np.random.randint(0, 255, input_size)
input_data = image.astype(np.float32)

# pytorch forward
input_var = Variable(torch.from_numpy(input_data))
if test_mod:
input_var = input_var.cuda()

if not test_mod:
# generate caffe model
Expand All @@ -54,6 +49,8 @@
net.forward(start=input_name)
caffe_output = net.blobs[output_name].data

model = model.cuda()
input_var = input_var.cuda()
output_var = model(input_var)
pytorch_output = output_var.data.cpu().numpy()

Expand All @@ -62,4 +59,4 @@
print(' caffe: min: {}, max: {}, mean: {}'.format(caffe_output.min(), caffe_output.max(), caffe_output.mean()))

diff = np.abs(pytorch_output - caffe_output)
print(' diff: min: {}, max: {}, mean: {}'.format(diff.min(), diff.max(), diff.mean()))
print(' diff: min: {}, max: {}, mean: {}, median: {}'.format(diff.min(), diff.max(), diff.mean(), np.median(diff)))

0 comments on commit 303e512

Please sign in to comment.