Skip to content
This repository has been archived by the owner on Dec 1, 2021. It is now read-only.

Move lmnet/executor/predict.py to blueoil/cmd #838

Merged
merged 5 commits into from
Feb 14, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions blueoil/cmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ def convert(experiment_id, checkpoint, template, image_size, project_name):
help='ID of this experiment.',
required=True,
)
@click.option(
'-c',
'--config',
help='Path of config file.',
default=None,
)
@click.option(
'-p',
'--checkpoint',
Expand All @@ -147,8 +153,8 @@ def convert(experiment_id, checkpoint, template, image_size, project_name):
help="Flag of saving images. Default is True.",
default=True,
)
def predict(input, output, experiment_id, checkpoint, save_images):
run_predict(input, output, experiment_id, checkpoint, save_images)
def predict(input, output, experiment_id, config, checkpoint, save_images):
run_predict(input, output, experiment_id, config, checkpoint, save_images)

click.echo('Result files are created: {}'.format(output))

Expand Down
159 changes: 149 additions & 10 deletions blueoil/cmd/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,160 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import imghdr
import math
import os
from glob import glob

from executor.predict import run as run_predict
import click
import numpy as np
import tensorflow as tf

from blueoil import environment
from blueoil.utils.image import load_image
from blueoil.utils import config as config_util
from blueoil.utils.executor import search_restore_filename
from blueoil.utils.predict_output.writer import OutputWriter

def predict(input, output, experiment_id, checkpoint=None, save_images=True):
"""Predict input images."""
DUMMY_FILENAME = "DUMMY_FILE"

output_dir = os.environ.get("OUTPUT_DIR", "saved")

if checkpoint is None:
restore_path = None
else:
restore_path = os.path.join(
output_dir, experiment_id, "checkpoints", checkpoint
def _get_images(filenames, pre_processor, data_format):
""" """
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please write something or remove this line

images = []
raw_images = []

for filename in filenames:
if filename == DUMMY_FILENAME:
raw_image = np.zeros((64, 64, 3), dtype=np.uint8)
else:
raw_image = load_image(filename)

image = pre_processor(image=raw_image)['image']
if data_format == 'NCHW':
image = np.transpose(image, [2, 0, 1])

images.append(image)
raw_images.append(raw_image)

return np.array(images), np.array(raw_images)


def _all_image_files(directory):
all_image_files = []
for file_path in glob(os.path.join(directory, "*")):
if os.path.isfile(file_path) and imghdr.what(file_path) in {"jpeg", "png"}:
all_image_files.append(os.path.abspath(file_path))

return all_image_files


def _run(input_dir, output_dir, config, restore_path, save_images):
    """Run inference over every image in *input_dir* with a restored model.

    Args:
        input_dir: Directory containing the input images.
        output_dir: Directory the prediction outputs are written to.
        config: Experiment configuration (NETWORK_CLASS, BATCH_SIZE, ...).
        restore_path: Checkpoint path passed to ``Saver.restore``.
        save_images: Whether to also write rendered result images.

    Returns:
        List of per-batch network outputs (post-processed when a
        POST_PROCESSOR is configured).
    """
    ModelClass = config.NETWORK_CLASS
    network_kwargs = {key.lower(): val for key, val in config.NETWORK.items()}

    graph = tf.Graph()
    with graph.as_default():
        model = ModelClass(
            classes=config.CLASSES,
            is_debug=config.IS_DEBUG,
            **network_kwargs
        )

        is_training = tf.constant(False, name="is_training")

        images_placeholder, _ = model.placeholders()
        output_op = model.inference(images_placeholder, is_training)

        init_op = tf.global_variables_initializer()

        saver = tf.compat.v1.train.Saver(max_to_keep=None)

    session_config = tf.ConfigProto()
    sess = tf.Session(graph=graph, config=session_config)
    sess.run(init_op)
    saver.restore(sess, restore_path)

    all_image_files = _all_image_files(input_dir)

    step_size = int(math.ceil(len(all_image_files) / config.BATCH_SIZE))

    writer = OutputWriter(
        task=config.TASK,
        classes=config.CLASSES,
        image_size=config.IMAGE_SIZE,
        data_format=config.DATA_FORMAT
    )

    results = []
    for step in range(step_size):
        start_index = step * config.BATCH_SIZE
        end_index = (step + 1) * config.BATCH_SIZE

        image_files = all_image_files[start_index:end_index]

        # Pad the last (short) batch with dummy files so the batch matches
        # the fixed BATCH_SIZE placeholder shape. Bounded arithmetic instead
        # of a `while != BATCH_SIZE` append loop, which could never terminate
        # if the slice ever exceeded BATCH_SIZE.
        if len(image_files) < config.BATCH_SIZE:
            image_files += [DUMMY_FILENAME] * (config.BATCH_SIZE - len(image_files))

        images, raw_images = _get_images(
            image_files, config.DATASET.PRE_PROCESSOR, config.DATA_FORMAT)

        outputs = sess.run(output_op, feed_dict={images_placeholder: images})

        if config.POST_PROCESSOR:
            outputs = config.POST_PROCESSOR(outputs=outputs)["outputs"]

        results.append(outputs)

        writer.write(output_dir, outputs, raw_images, image_files, step,
                     save_material=save_images)

    return results


def run(input_dir, output_dir, experiment_id, config_file, restore_path, save_images):
    """Resolve the experiment config and checkpoint, then run prediction.

    Args:
        input_dir: Directory containing input images.
        output_dir: Directory prediction results are written to.
        experiment_id: Experiment whose saved config/checkpoints are used.
        config_file: Optional config file merged over the experiment config.
        restore_path: Explicit checkpoint path; when None, the latest
            checkpoint of the experiment is located automatically.
        save_images: Whether to save rendered prediction images.

    Raises:
        FileNotFoundError: If *input_dir* does not exist or the checkpoint
            file cannot be found.
    """
    environment.init(experiment_id)
    config = config_util.load_from_experiment()
    if config_file:
        config = config_util.merge(config, config_util.load(config_file))

    # Specific exception instead of bare Exception; FileNotFoundError is
    # still an Exception subclass, so existing broad handlers keep working.
    if not os.path.isdir(input_dir):
        raise FileNotFoundError("Input directory {} does not exist.".format(input_dir))

    if restore_path is None:
        restore_file = search_restore_filename(environment.CHECKPOINTS_DIR)
        restore_path = os.path.join(environment.CHECKPOINTS_DIR, restore_file)

    print("Restore from {}".format(restore_path))

    # TF checkpoints are stored as several files; the ".index" file is a
    # cheap existence probe for the whole checkpoint.
    if not os.path.exists("{}.index".format(restore_path)):
        raise FileNotFoundError("Restore file {} does not exist.".format(restore_path))

    print("---- start predict ----")

    _run(input_dir, output_dir, config, restore_path, save_images)

    print("---- end predict ----")


def predict(input_dir, output_dir, experiment_id, config_file=None, checkpoint=None, save_images=True):
    """Run a trained model over the images in *input_dir*.

    Predictions are saved beneath *output_dir*:
      npy:    `{output_dir}/npy/{batch number}.npy`
      json:   `{output_dir}/json/{batch number}.json`
      images: `{output_dir}/images/{some type}/{input image file name}`
    """
    saved_root = os.environ.get("OUTPUT_DIR", "saved")
    # Only resolve a concrete checkpoint path when one was requested;
    # otherwise `run` falls back to the experiment's latest checkpoint.
    restore_path = (
        os.path.join(saved_root, experiment_id, "checkpoints", checkpoint)
        if checkpoint else None
    )

    run(input_dir, output_dir, experiment_id, config_file, restore_path, save_images)
Loading