-
Notifications
You must be signed in to change notification settings - Fork 86
Move lmnet/executor/predict.py to blueoil/cmd #838
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,21 +13,160 @@ | |
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================= | ||
import imghdr | ||
import math | ||
import os | ||
from glob import glob | ||
|
||
from executor.predict import run as run_predict | ||
import click | ||
import numpy as np | ||
import tensorflow as tf | ||
|
||
from blueoil import environment | ||
from blueoil.utils.image import load_image | ||
from blueoil.utils import config as config_util | ||
from blueoil.utils.executor import search_restore_filename | ||
from blueoil.utils.predict_output.writer import OutputWriter | ||
|
||
# Sentinel filename used to pad the last, partial batch up to BATCH_SIZE;
# `_get_images` substitutes a black 64x64 placeholder image for it.
DUMMY_FILENAME = "DUMMY_FILE"
def _get_images(filenames, pre_processor, data_format): | ||
""" """ | ||
images = [] | ||
raw_images = [] | ||
|
||
for filename in filenames: | ||
if filename == DUMMY_FILENAME: | ||
raw_image = np.zeros((64, 64, 3), dtype=np.uint8) | ||
else: | ||
raw_image = load_image(filename) | ||
|
||
image = pre_processor(image=raw_image)['image'] | ||
if data_format == 'NCHW': | ||
image = np.transpose(image, [2, 0, 1]) | ||
|
||
images.append(image) | ||
raw_images.append(raw_image) | ||
|
||
return np.array(images), np.array(raw_images) | ||
|
||
|
||
def _all_image_files(directory): | ||
all_image_files = [] | ||
for file_path in glob(os.path.join(directory, "*")): | ||
if os.path.isfile(file_path) and imghdr.what(file_path) in {"jpeg", "png"}: | ||
all_image_files.append(os.path.abspath(file_path)) | ||
|
||
return all_image_files | ||
|
||
|
||
def _run(input_dir, output_dir, config, restore_path, save_images):
    """Run inference over all images in *input_dir* with a restored model.

    Args:
        input_dir: Directory containing the input images.
        output_dir: Directory the prediction outputs are written to.
        config: Experiment configuration (network class, batch size, task,
            pre/post processors, ...).
        restore_path: Checkpoint path passed to ``Saver.restore``.
        save_images: Whether to also write rendered result images.

    Returns:
        List of per-batch (post-processed) network outputs.
    """
    ModelClass = config.NETWORK_CLASS
    network_kwargs = {key.lower(): val for key, val in config.NETWORK.items()}

    graph = tf.Graph()
    with graph.as_default():
        model = ModelClass(
            classes=config.CLASSES,
            is_debug=config.IS_DEBUG,
            **network_kwargs
        )

        is_training = tf.constant(False, name="is_training")

        images_placeholder, _ = model.placeholders()
        output_op = model.inference(images_placeholder, is_training)

        init_op = tf.global_variables_initializer()

        saver = tf.compat.v1.train.Saver(max_to_keep=None)

    session_config = tf.ConfigProto()
    sess = tf.Session(graph=graph, config=session_config)
    sess.run(init_op)
    saver.restore(sess, restore_path)

    all_image_files = _all_image_files(input_dir)

    step_size = int(math.ceil(len(all_image_files) / config.BATCH_SIZE))

    writer = OutputWriter(
        task=config.TASK,
        classes=config.CLASSES,
        image_size=config.IMAGE_SIZE,
        data_format=config.DATA_FORMAT
    )

    results = []
    for step in range(step_size):
        start_index = step * config.BATCH_SIZE
        end_index = (step + 1) * config.BATCH_SIZE

        image_files = all_image_files[start_index:end_index]

        # Pad the (possibly partial) last batch with dummy entries in one
        # step. The previous one-at-a-time ``while len(...) != BATCH_SIZE``
        # loop was inefficient and would never terminate if the slice were
        # ever longer than BATCH_SIZE.
        if len(image_files) < config.BATCH_SIZE:
            image_files += [DUMMY_FILENAME] * (config.BATCH_SIZE - len(image_files))

        images, raw_images = _get_images(
            image_files, config.DATASET.PRE_PROCESSOR, config.DATA_FORMAT)

        outputs = sess.run(output_op, feed_dict={images_placeholder: images})

        if config.POST_PROCESSOR:
            outputs = config.POST_PROCESSOR(outputs=outputs)["outputs"]

        results.append(outputs)

        writer.write(output_dir, outputs, raw_images, image_files, step,
                     save_material=save_images)

    # Release the TF session once all batches are written.
    sess.close()

    return results
|
||
|
||
def run(input_dir, output_dir, experiment_id, config_file, restore_path, save_images):
    """Resolve config and checkpoint for *experiment_id*, then run prediction.

    Args:
        input_dir: Directory with the images to predict on.
        output_dir: Directory the results are written to.
        experiment_id: Experiment whose saved config/checkpoints are used.
        config_file: Optional config file merged over the experiment config.
        restore_path: Explicit checkpoint path; when ``None`` the latest
            checkpoint found in the experiment's checkpoint dir is used.
        save_images: Whether to also save rendered prediction images.

    Raises:
        FileNotFoundError: If *input_dir* does not exist, or the checkpoint
            index file for *restore_path* is missing.
    """
    environment.init(experiment_id)
    config = config_util.load_from_experiment()
    if config_file:
        config = config_util.merge(config, config_util.load(config_file))

    # FileNotFoundError instead of a bare Exception: more specific, and it
    # still subclasses Exception so existing callers keep working.
    if not os.path.isdir(input_dir):
        raise FileNotFoundError("Input directory {} does not exist.".format(input_dir))

    if restore_path is None:
        restore_file = search_restore_filename(environment.CHECKPOINTS_DIR)
        restore_path = os.path.join(environment.CHECKPOINTS_DIR, restore_file)

    print("Restore from {}".format(restore_path))

    # A TF checkpoint is a set of files; the ``.index`` file is a cheap
    # existence proxy for the whole checkpoint.
    if not os.path.exists("{}.index".format(restore_path)):
        raise FileNotFoundError("restore file {} does not exist.".format(restore_path))

    print("---- start predict ----")

    _run(input_dir, output_dir, config, restore_path, save_images)

    print("---- end predict ----")
|
||
|
||
def predict(input_dir, output_dir, experiment_id, config_file=None, checkpoint=None, save_images=True):
    """Make predictions from input dir images by using trained model.

    Save the predictions npy, json, images results to output dir.
    npy: `{output_dir}/npy/{batch number}.npy`
    json: `{output_dir}/json/{batch number}.json`
    images: `{output_dir}/images/{some type}/{input image file name}`
    """
    if checkpoint:
        # Explicit checkpoint: resolve it under the experiment's saved dir.
        saved_dir = os.environ.get("OUTPUT_DIR", "saved")
        restore_path = os.path.join(saved_dir, experiment_id, "checkpoints", checkpoint)
    else:
        # Let ``run`` pick the latest checkpoint for the experiment.
        restore_path = None

    run(input_dir, output_dir, experiment_id, config_file, restore_path, save_images)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please write something or remove this line