From c7cfb4622e1acdf20596159e9240487d97c117d5 Mon Sep 17 00:00:00 2001 From: Taketoshi Fujiwara Date: Wed, 25 Mar 2020 10:59:27 +0900 Subject: [PATCH] Improve readability of the predict script --- blueoil/cmd/predict.py | 36 +++++++++++++----------------------- 1 file changed, 13 insertions(+), 23 deletions(-) diff --git a/blueoil/cmd/predict.py b/blueoil/cmd/predict.py index 3db97e939..fcd2b52b8 100644 --- a/blueoil/cmd/predict.py +++ b/blueoil/cmd/predict.py @@ -31,7 +31,6 @@ def _get_images(filenames, pre_processor, data_format): - """ """ images = [] raw_images = [] @@ -52,17 +51,16 @@ def _get_images(filenames, pre_processor, data_format): def _all_image_files(directory): - all_image_files = [] - for file_path in glob(os.path.join(directory, "*")): - if os.path.isfile(file_path) and imghdr.what(file_path) in {"jpeg", "png"}: - all_image_files.append(os.path.abspath(file_path)) - - return all_image_files + return [ + os.path.abspath(file_path) + for file_path in glob(os.path.join(directory, "*")) + if os.path.isfile(file_path) and imghdr.what(file_path) in {"jpeg", "png"} + ] def _run(input_dir, output_dir, config, restore_path, save_images): ModelClass = config.NETWORK_CLASS - network_kwargs = dict((key.lower(), val) for key, val in config.NETWORK.items()) + network_kwargs = {key.lower(): val for key, val in config.NETWORK.items()} graph = tf.Graph() with graph.as_default(): @@ -99,34 +97,26 @@ def _run(input_dir, output_dir, config, restore_path, save_images): results = [] for step in range(step_size): - start_index = (step) * config.BATCH_SIZE + start_index = step * config.BATCH_SIZE end_index = (step + 1) * config.BATCH_SIZE image_files = all_image_files[start_index:end_index] - while len(image_files) != config.BATCH_SIZE: + if len(image_files) < config.BATCH_SIZE: # add dummy image. - image_files.append(DUMMY_FILENAME) + image_files += [DUMMY_FILENAME] * (config.BATCH_SIZE - len(image_files)) images, raw_images = _get_images( image_files, config.DATASET.PRE_PROCESSOR, config.DATA_FORMAT) - feed_dict = {images_placeholder: images} - outputs = sess.run(output_op, feed_dict=feed_dict) + outputs = sess.run(output_op, feed_dict={images_placeholder: images}) if config.POST_PROCESSOR: outputs = config.POST_PROCESSOR(outputs=outputs)["outputs"] results.append(outputs) - writer.write( - output_dir, - outputs, - raw_images, - image_files, - step, - save_material=save_images - ) + writer.write(output_dir, outputs, raw_images, image_files, step, save_material=save_images) return results @@ -138,7 +128,7 @@ def run(input_dir, output_dir, experiment_id, config_file, restore_path, save_im config = config_util.merge(config, config_util.load(config_file)) if not os.path.isdir(input_dir): - raise Exception("Input directory {} does not exist.".format(input_dir)) + raise FileNotFoundError("Input directory not found: '{}'".format(input_dir)) if restore_path is None: restore_file = search_restore_filename(environment.CHECKPOINTS_DIR) @@ -147,7 +137,7 @@ def run(input_dir, output_dir, experiment_id, config_file, restore_path, save_im print("Restore from {}".format(restore_path)) if not os.path.exists("{}.index".format(restore_path)): - raise Exception("restore file {} dont exists.".format(restore_path)) + raise FileNotFoundError("Checkpoint file not found: '{}'".format(restore_path)) print("---- start predict ----")