Skip to content
This repository has been archived by the owner on Dec 1, 2021. It is now read-only.

Improve readability of the predict script #942

Merged
merged 2 commits into from
Mar 27, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 13 additions & 23 deletions blueoil/cmd/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@


def _get_images(filenames, pre_processor, data_format):
""" """
images = []
raw_images = []

Expand All @@ -52,17 +51,16 @@ def _get_images(filenames, pre_processor, data_format):


def _all_image_files(directory):
all_image_files = []
for file_path in glob(os.path.join(directory, "*")):
if os.path.isfile(file_path) and imghdr.what(file_path) in {"jpeg", "png"}:
all_image_files.append(os.path.abspath(file_path))

return all_image_files
return [
os.path.abspath(file_path)
for file_path in glob(os.path.join(directory, "*"))
if os.path.isfile(file_path) and imghdr.what(file_path) in {"jpeg", "png"}
]


def _run(input_dir, output_dir, config, restore_path, save_images):
ModelClass = config.NETWORK_CLASS
network_kwargs = dict((key.lower(), val) for key, val in config.NETWORK.items())
network_kwargs = {key.lower(): val for key, val in config.NETWORK.items()}

graph = tf.Graph()
with graph.as_default():
Expand Down Expand Up @@ -99,34 +97,26 @@ def _run(input_dir, output_dir, config, restore_path, save_images):

results = []
for step in range(step_size):
start_index = (step) * config.BATCH_SIZE
start_index = step * config.BATCH_SIZE
end_index = (step + 1) * config.BATCH_SIZE

image_files = all_image_files[start_index:end_index]

while len(image_files) != config.BATCH_SIZE:
if len(image_files) < config.BATCH_SIZE:
# add dummy image.
image_files.append(DUMMY_FILENAME)
image_files += [DUMMY_FILENAME] * (config.BATCH_SIZE - len(image_files))

images, raw_images = _get_images(
image_files, config.DATASET.PRE_PROCESSOR, config.DATA_FORMAT)

feed_dict = {images_placeholder: images}
outputs = sess.run(output_op, feed_dict=feed_dict)
outputs = sess.run(output_op, feed_dict={images_placeholder: images})

if config.POST_PROCESSOR:
outputs = config.POST_PROCESSOR(outputs=outputs)["outputs"]

results.append(outputs)

writer.write(
output_dir,
outputs,
raw_images,
image_files,
step,
save_material=save_images
)
writer.write(output_dir, outputs, raw_images, image_files, step, save_material=save_images)

return results

Expand All @@ -138,7 +128,7 @@ def run(input_dir, output_dir, experiment_id, config_file, restore_path, save_im
config = config_util.merge(config, config_util.load(config_file))

if not os.path.isdir(input_dir):
raise Exception("Input directory {} does not exist.".format(input_dir))
raise FileNotFoundError("Input directory not found: '{}'".format(input_dir))

if restore_path is None:
restore_file = search_restore_filename(environment.CHECKPOINTS_DIR)
Expand All @@ -147,7 +137,7 @@ def run(input_dir, output_dir, experiment_id, config_file, restore_path, save_im
print("Restore from {}".format(restore_path))

if not os.path.exists("{}.index".format(restore_path)):
raise Exception("restore file {} dont exists.".format(restore_path))
raise FileNotFoundError("Checkpoint file not found: '{}'".format(restore_path))

print("---- start predict ----")

Expand Down