-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
113 lines (91 loc) · 3.28 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Mute tensorflow debugging information on console
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from flask import Flask, request, render_template, jsonify
from scipy.misc import imsave, imread, imresize
import numpy as np
import tensorflow as tf
import argparse
import re
import base64
import pickle
# Flask application object; the route decorators below register on it.
app = Flask(import_name=__name__)
# TensorFlow checkpoint path prefix read by recognize() on every call;
# replaced by the --bin argument when this module is run as a script.
save_dir = './ai1'
def recognize(data):
    """Run the saved TensorFlow model on *data* and return the top class index.

    The graph and weights are restored from the checkpoint prefix held in the
    module-level ``save_dir``: the graph from ``<save_dir>.meta`` and the
    variables via ``Saver.restore(sess, save_dir)``.

    Args:
        data: batch fed to the graph's ``input:0`` placeholder. ``predict``
            passes a float32 array of shape (1, 784) scaled to [0, 1]; other
            callers are assumed to match the shape the graph was built for.

    Returns:
        The argmax over axis 1 of the ``probability:0`` tensor, taken for the
        first row of the batch.
    """
    # NOTE(review): the checkpoint is re-imported from disk on every call,
    # which is slow; consider restoring it once and reusing the session.
    loaded_graph = tf.Graph()
    with tf.Session(graph=loaded_graph) as sess:
        loader = tf.train.import_meta_graph(save_dir + '.meta')
        loader.restore(sess, save_dir)
        # Inference only needs the input placeholder and the probability
        # output; no other tensors are looked up, so checkpoints without
        # e.g. a 'label:0' tensor still load.
        loaded_x = loaded_graph.get_tensor_by_name('input:0')
        loaded_prob = loaded_graph.get_tensor_by_name('probability:0')
        prob = sess.run(tf.argmax(loaded_prob, 1), feed_dict={loaded_x: data})
    return prob[0]
@app.route("/")
def index():
    """Serve the drawing page: render the ``index.html`` template for ``/``."""
    template_name = 'index.html'
    return render_template(template_name)
def _crop_blank_rows(img):
    """Return the 2-D array *img* without the rows that contain only zeros.

    Experimental and currently not called by ``predict``. Applying it to an
    image and then to its transpose drops blank rows and then blank columns.
    Note that all-zero rows in the middle of the image are removed too, not
    only those at the borders.
    """
    return img[np.any(img != 0, axis=1)]


def _parse_image(img_data):
    """Decode the base64 canvas payload in *img_data* and save it as output.png.

    Args:
        img_data: raw request body (bytes). Everything after the first
            ``base64,`` marker (up to the end of that line) is decoded.

    Raises:
        ValueError: if no ``base64,`` marker is present, or if the payload is
            not valid base64 (``binascii.Error`` subclasses ``ValueError``).
    """
    match = re.search(b'base64,(.*)', img_data)
    if match is None:
        raise ValueError('request body does not contain base64 image data')
    # Decode before opening the file so a bad payload does not truncate
    # an existing output.png.
    payload = base64.decodebytes(match.group(1))
    # NOTE(review): output.png is a fixed shared path; if requests are served
    # concurrently they can overwrite each other's image.
    with open('output.png', 'wb') as output:
        output.write(payload)


@app.route('/predict/', methods=['GET','POST'])
def predict():
    """Classify the character drawn on the canvas and return it as JSON.

    Called when the user presses the predict button. The request body is
    expected to carry the canvas as base64 image data. It is written to
    output.png, read back as 8-bit greyscale, inverted, resized to 28x28,
    flattened and scaled to [0, 1], and then passed to ``recognize``.

    Returns:
        JSON ``{"prediction": <class index as str>, "confidence": "Unknown"}``,
        or ``{"error": <message>}`` with HTTP 400 when the body holds no
        usable base64 image (for example a bare GET request).
    """
    try:
        _parse_image(request.get_data())
    except ValueError as err:
        return jsonify({'error': str(err)}), 400

    # Read the saved image back in 8-bit greyscale ('L') mode and invert it
    # (v -> 255 - v). Presumably the canvas draws dark strokes on a light
    # background while the model expects the opposite -- confirm against
    # the training data.
    img = np.invert(imread('output.png', mode='L'))
    # Debug aid: keep the inverted, not-yet-resized image on disk.
    # To experiment with cropping, apply _crop_blank_rows to img and to
    # img.T here, before resizing.
    imsave('resized.png', img)

    img = imresize(img, (28, 28))
    # One row of 28*28 float32 features, scaled to [0, 1] for the model.
    features = img.reshape(1, 28 * 28).astype('float32') / 255

    out = recognize(features)
    return jsonify({'prediction': str(out), 'confidence': 'Unknown'})
if __name__ == '__main__':
    # Parse optional command-line arguments.
    parser = argparse.ArgumentParser(description='A webapp for testing models generated from training.py on the EMNIST dataset')
    # recognize() loads '<bin>.meta' and restores variables from the '<bin>'
    # prefix, so --bin is a checkpoint path prefix, not a directory.
    parser.add_argument('--bin', type=str, default=save_dir,
                        help='Path prefix of the TensorFlow checkpoint to load; '
                             '<bin>.meta must contain the graph (default: %(default)s)')
    parser.add_argument('--host', type=str, default='0.0.0.0', help='The host to run the flask server on')
    parser.add_argument('--port', type=int, default=5000, help='The port to run the flask server on')
    args = parser.parse_args()
    # Rebind the module-level checkpoint prefix that recognize() reads on
    # each call, then start the development server.
    save_dir = args.bin
    app.run(host=args.host, port=args.port)