Skip to content
This repository has been archived by the owner on Dec 1, 2021. It is now read-only.

Refactor predict output #579

Merged
merged 6 commits into from
Oct 31, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 11 additions & 43 deletions lmnet/executor/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from lmnet.utils.image import load_image
from lmnet.utils import config as config_util
from lmnet.utils.executor import search_restore_filename
from lmnet.utils.predict_output import ImageFromJson, JsonOutput
from lmnet.utils.predict_output.writer import OutputWriter

DUMMY_FILENAME = "DUMMY_FILE"

Expand Down Expand Up @@ -61,33 +61,6 @@ def _all_image_files(directory):
return all_image_files


def _save_images(output_dir, filename_images, step):
    """Write every image to ``<output_dir>/images/<step>/<filename>``.

    Args:
        output_dir (str): root directory for the prediction outputs.
        filename_images: iterable of ``(filename, image)`` pairs; each image
            must expose a ``save(path)`` method.
        step: step identifier, used as the name of the sub directory.
    """
    for relative_name, picture in filename_images:
        target = os.path.join(output_dir, "images", "{}".format(step), relative_name)

        # The parent is created per file because the file name may itself
        # contain sub directories.
        parent, _ = os.path.split(target)
        os.makedirs(parent, exist_ok=True)

        picture.save(target)
        print("save image: {}".format(target))


def _save_json(output_dir, json_obj, step):
    """Write a dumped JSON string to ``<output_dir>/json/<step>.json``.

    The file is written as UTF-8 so the result does not depend on the
    platform's default text encoding.

    Args:
        output_dir (str): root directory for the prediction outputs.
        json_obj (str): already serialized JSON text.
        step: step identifier, used as the file name.
    """
    output_file_name = os.path.join(output_dir, "json", "{}.json".format(step))
    os.makedirs(os.path.dirname(output_file_name), exist_ok=True)

    with open(output_file_name, "w", encoding="utf-8") as json_file:
        json_file.write(json_obj)
    print("save json: {}".format(output_file_name))


def _save_outputs(output_dir, outputs, step):
    """Dump the raw network outputs to ``<output_dir>/npy/<step>.npy``.

    Args:
        output_dir (str): root directory for the prediction outputs.
        outputs: array-like object handed to ``np.save``.
        step: step identifier, used as the file name.
    """
    target = os.path.join(output_dir, "npy", "{}.npy".format(step))
    parent, _ = os.path.split(target)
    os.makedirs(parent, exist_ok=True)

    np.save(target, outputs)
    print("save npy: {}".format(target))


def _run(input_dir, output_dir, config, restore_path, save_images):
ModelClass = config.NETWORK_CLASS
network_kwargs = dict((key.lower(), val) for key, val in config.NETWORK.items())
Expand Down Expand Up @@ -118,17 +91,11 @@ def _run(input_dir, output_dir, config, restore_path, save_images):

step_size = int(math.ceil(len(all_image_files) / config.BATCH_SIZE))

json_output = JsonOutput(
task=config.TASK,
classes=config.CLASSES,
image_size=config.IMAGE_SIZE,
data_format=config.DATA_FORMAT,
)

image_from_json = ImageFromJson(
writer = OutputWriter(
task=config.TASK,
classes=config.CLASSES,
image_size=config.IMAGE_SIZE,
data_format=config.DATA_FORMAT
)

results = []
Expand All @@ -152,14 +119,15 @@ def _run(input_dir, output_dir, config, restore_path, save_images):
outputs = config.POST_PROCESSOR(outputs=outputs)["outputs"]

results.append(outputs)
_save_outputs(output_dir, outputs, step)

json = json_output(outputs, raw_images, image_files)
_save_json(output_dir, json, step)

if save_images:
filename_images = image_from_json(json, raw_images, image_files)
_save_images(output_dir, filename_images, step)
writer.write(
output_dir,
outputs,
raw_images,
image_files,
step,
save_material=save_images
)

return results

Expand Down
103 changes: 103 additions & 0 deletions lmnet/lmnet/utils/predict_output/writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import os

import numpy as np

from lmnet.utils.predict_output.output import ImageFromJson
from lmnet.utils.predict_output.output import JsonOutput


class OutputWriter():
def __init__(self, task, classes, image_size, data_format):
self.json_output = JsonOutput(task, classes, image_size, data_format)
self.image_from_json = ImageFromJson(task, classes, image_size)

def write(self, dest, outputs, raw_images, image_files, step, save_material=True):
"""Save predict output to disk.
numpy array, JSON, and images if you want.

Args:
dest (str): path to save file
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo in paht .

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops, thank you!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed!

outputs (np.ndarray): save ndarray
raw_images (np.ndarray): image ndarray
image_files (list[str]): list of file names.
step (int): value of training step
save_material (bool, optional): save materials or not. Defaults to True.
"""
save_npy(dest, outputs, step)

json = self.json_output(outputs, raw_images, image_files)
save_json(dest, json, step)

if save_material:
materials = self.image_from_json(json, raw_images, image_files)
save_materials(dest, materials, step)


def save_npy(dest, outputs, step):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are these functions not methods of OutputWriter?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer "Don't put in a method if you don't need it" mentality.
These functions don't need the state of OutputWriter.
IMO, This makes these functions independent and easier to test.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see.

"""Save numpy array to disk.

Args:
dest (str): path to save file
outputs (np.ndarray): save ndarray
step (int): value of training step

Raises:
PermissionError: If dest dir has no permission to write.
ValueError: If type of step is not int.
"""
if type(step) is not int:
raise ValueError("step must be integer.")

filepath = os.path.join(dest, "npy", "{}.npy".format(step))
os.makedirs(os.path.dirname(filepath), exist_ok=True)

np.save(filepath, outputs)
print("save npy: {}".format(filepath))


def save_json(dest, json, step):
    """Save JSON to disk as ``<dest>/json/<step>.json``.

    The file is always written as UTF-8 so the result does not depend on the
    platform's default text encoding.

    Args:
        dest (str): path to save file; missing directories are created.
        json (str): dumped json string
        step (int): value of training step, used as the file name.

    Raises:
        PermissionError: If dest dir has no permission to write.
        ValueError: If type of step is not int.
    """
    # Exact type check on purpose: any value whose type is not exactly
    # ``int`` is rejected before anything is written.
    if type(step) is not int:
        raise ValueError("step must be integer.")

    filepath = os.path.join(dest, "json", "{}.json".format(step))
    os.makedirs(os.path.dirname(filepath), exist_ok=True)

    with open(filepath, "w", encoding="utf-8") as f:
        f.write(json)

    print("save json: {}".format(filepath))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should not print debugging information.



def save_materials(dest, materials, step):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

save_materials looks like a very general name. Isn't save_images enough?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One day we will come to deal with tasks other than images.
Named for that time.
If not, I will fix.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. To the future.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Go to the future!

"""Save materials to disk.

Args:
dest (string): path to save file
materials (list[(str, PIL.Image)]): image data, str in tuple is filename.
step (int): value of training step

Raises:
PermissionError: If dest dir has no permission to write.
ValueError: If type of step is not int.
"""
if type(step) is not int:
raise ValueError("step must be integer.")

for filename, content in materials:
filepath = os.path.join(dest, "images", "{}".format(step), filename)
os.makedirs(os.path.dirname(filepath), exist_ok=True)

content.save(filepath)

print("save image: {}".format(filepath))
10 changes: 10 additions & 0 deletions lmnet/tests/lmnet_tests/util_tests/predict_output/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import tempfile

import pytest


@pytest.fixture
def temp_dir():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with TemporaryDirectory() as t:
    yield t.name

temp = tempfile.TemporaryDirectory()
yield temp.name
temp.cleanup()
119 changes: 119 additions & 0 deletions lmnet/tests/lmnet_tests/util_tests/predict_output/test_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import json
import os

import numpy as np
import pytest
from PIL import Image

from lmnet.common import Tasks
from lmnet.utils.predict_output.writer import OutputWriter
from lmnet.utils.predict_output.writer import save_json
from lmnet.utils.predict_output.writer import save_npy
from lmnet.utils.predict_output.writer import save_materials


def test_write(temp_dir):
    """OutputWriter.write creates the npy, json and per-class image files for one step."""
    writer = OutputWriter(
        task=Tasks.CLASSIFICATION,
        classes=("aaa", "bbb", "ccc"),
        image_size=(320, 280),
        data_format="NCHW",
    )
    one_hot_outputs = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]])
    blank_images = np.zeros((3, 320, 280, 3), dtype=np.uint8)
    file_names = ["dummy1.png", "dummy2.png", "dummy3.png"]

    writer.write(temp_dir, one_hot_outputs, blank_images, file_names, 1)

    expected_files = (
        ("npy", "1.npy"),
        ("json", "1.json"),
        ("images", "1", "aaa", "dummy3.png"),
        ("images", "1", "bbb", "dummy1.png"),
        ("images", "1", "ccc", "dummy2.png"),
    )
    for parts in expected_files:
        assert os.path.exists(os.path.join(temp_dir, *parts))


def test_save_npy(temp_dir):
    """save_npy writes ``npy/<step>.npy`` under an already existing directory."""
    array = np.array([[1, 2, 3], [4, 5, 6]])

    save_npy(temp_dir, array, step=1)

    expected_file = os.path.join(temp_dir, "npy", "1.npy")
    assert os.path.exists(expected_file)


def test_save_npy_not_existed_dir(temp_dir):
    """save_npy creates the destination directory when it does not exist yet."""
    array = np.array([[1, 2, 3], [4, 5, 6]])
    missing_dir = os.path.join(temp_dir, 'not_existed')

    save_npy(missing_dir, array, step=1)

    expected_file = os.path.join(missing_dir, "npy", "1.npy")
    assert os.path.exists(expected_file)


def test_save_npy_with_invalid_step(temp_dir):
    """save_npy raises ValueError when step is not an int."""
    array = np.array([[1, 2, 3], [4, 5, 6]])
    bad_step = {"invalid": "dict"}

    with pytest.raises(ValueError):
        save_npy(temp_dir, array, step=bad_step)


def test_save_json(temp_dir):
    """save_json writes ``json/<step>.json`` under an already existing directory."""
    payload = json.dumps({"k": "v", "list": [1, 2, 3]})

    save_json(temp_dir, payload, step=1)

    expected_file = os.path.join(temp_dir, "json", "1.json")
    assert os.path.exists(expected_file)


def test_save_json_not_existed_dir(temp_dir):
    """save_json creates the destination directory when it does not exist yet."""
    payload = json.dumps({"k": "v", "list": [1, 2, 3]})
    missing_dir = os.path.join(temp_dir, 'not_existed')

    save_json(missing_dir, payload, step=1)

    expected_file = os.path.join(missing_dir, "json", "1.json")
    assert os.path.exists(expected_file)


def test_save_json_with_invalid_step(temp_dir):
    """save_json raises ValueError when step is not an int."""
    payload = json.dumps({"k": "v", "list": [1, 2, 3]})
    bad_step = {"invalid": "dict"}

    with pytest.raises(ValueError):
        save_json(temp_dir, payload, step=bad_step)


def test_save_materials(temp_dir):
    """save_materials writes each image to ``images/<step>/<filename>``."""
    named_pixels = (
        ("image1.png", [[[0, 0, 0], [0, 0, 0]], [[255, 255, 255], [255, 255, 255]]]),
        ("image2.png", [[[0, 0, 0], [255, 255, 255]], [[255, 255, 255], [0, 0, 0]]]),
        ("image3.png", [[[255, 255, 255], [255, 255, 255]], [[0, 0, 0], [0, 0, 0]]]),
    )
    materials = [
        (name, Image.fromarray(np.array(pixels, dtype=np.uint8)))
        for name, pixels in named_pixels
    ]

    save_materials(temp_dir, materials, step=1)

    for name, _ in named_pixels:
        assert os.path.exists(os.path.join(temp_dir, "images", "1", name))


def test_save_materials_not_existed_dir(temp_dir):
    """save_materials creates the destination directory when it does not exist yet."""
    named_pixels = (
        ("image1.png", [[[0, 0, 0], [0, 0, 0]], [[255, 255, 255], [255, 255, 255]]]),
        ("image2.png", [[[0, 0, 0], [255, 255, 255]], [[255, 255, 255], [0, 0, 0]]]),
        ("image3.png", [[[255, 255, 255], [255, 255, 255]], [[0, 0, 0], [0, 0, 0]]]),
    )
    materials = [
        (name, Image.fromarray(np.array(pixels, dtype=np.uint8)))
        for name, pixels in named_pixels
    ]
    missing_dir = os.path.join(temp_dir, 'not_existed')

    save_materials(missing_dir, materials, step=1)

    for name, _ in named_pixels:
        assert os.path.exists(os.path.join(missing_dir, "images", "1", name))
2 changes: 1 addition & 1 deletion lmnet/tests/lmnet_tests/util_tests/test_predict_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import PIL.Image

from lmnet.common import Tasks
from lmnet.utils.predict_output import JsonOutput
from lmnet.utils.predict_output.output import JsonOutput


def test_classification_json():
Expand Down