-
Notifications
You must be signed in to change notification settings - Fork 86
Refactor predict output #579
Changes from 1 commit
6cfa517
75a9491
4206dbf
4d381f2
6b39933
4e5ae98
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import os | ||
|
||
import numpy as np | ||
|
||
from lmnet.utils.predict_output.output import ImageFromJson | ||
from lmnet.utils.predict_output.output import JsonOutput | ||
|
||
|
||
class OutputWriter(): | ||
def __init__(self, task, classes, image_size, data_format): | ||
self.json_output = JsonOutput(task, classes, image_size, data_format) | ||
self.image_from_json = ImageFromJson(task, classes, image_size) | ||
|
||
def write(self, dest, outputs, raw_images, image_files, step, save_material=True): | ||
"""Save predict output to disk. | ||
numpy array, JSON, and images if you want. | ||
|
||
Args: | ||
dest (str): paht to save file | ||
outputs (np.ndarray): save ndarray | ||
raw_images (np.ndarray): image ndarray | ||
image_files (list[str]): list of file names. | ||
step (int): value of training step | ||
save_material (bool, optional): save materials or not. Defaults to True. | ||
""" | ||
save_npy(dest, outputs, step) | ||
|
||
json = self.json_output(outputs, raw_images, image_files) | ||
save_json(dest, json, step) | ||
|
||
if save_material: | ||
materials = self.image_from_json(json, raw_images, image_files) | ||
save_materials(dest, materials, step) | ||
|
||
|
||
def save_npy(dest, outputs, step): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why these function are not method of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I prefer "Don't put in a method if you don't need it" mentality. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. |
||
"""Save numpy array to disk. | ||
|
||
Args: | ||
dest (str): path to save file | ||
outputs (np.ndarray): save ndarray | ||
step (int): value of training step | ||
|
||
Raises: | ||
PermissionError: If dest dir has no permission to write. | ||
ValueError: If type of step is not int. | ||
""" | ||
if type(step) is not int: | ||
raise ValueError("step must be integer.") | ||
|
||
filepath = os.path.join(dest, "npy", "{}.npy".format(step)) | ||
os.makedirs(os.path.dirname(filepath), exist_ok=True) | ||
|
||
np.save(filepath, outputs) | ||
print("save npy: {}".format(filepath)) | ||
|
||
|
||
def save_json(dest, json, step): | ||
"""Save JSON to disk. | ||
|
||
Args: | ||
dest (str): path to save file | ||
json (str): dumped json string | ||
step (int): value of training step | ||
|
||
Raises: | ||
PermissionError: If dest dir has no permission to write. | ||
ValueError: If type of step is not int. | ||
""" | ||
if type(step) is not int: | ||
raise ValueError("step must be integer.") | ||
|
||
filepath = os.path.join(dest, "json", "{}.json".format(step)) | ||
os.makedirs(os.path.dirname(filepath), exist_ok=True) | ||
|
||
with open(filepath, "w") as f: | ||
f.write(json) | ||
|
||
print("save json: {}".format(filepath)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should not |
||
|
||
|
||
def save_materials(dest, materials, step): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One day we will come to deal with tasks other than images. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. To the future. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Go to the future! |
||
"""Save materials to disk. | ||
|
||
Args: | ||
dest (string): path to save file | ||
materials (list[(str, PIL.Image)]): image data, str in tuple is filename. | ||
step (int): value of training step | ||
|
||
Raises: | ||
PermissionError: If dest dir has no permission to write. | ||
ValueError: If type of step is not int. | ||
""" | ||
if type(step) is not int: | ||
raise ValueError("step must be integer.") | ||
|
||
for filename, content in materials: | ||
filepath = os.path.join(dest, "images", "{}".format(step), filename) | ||
os.makedirs(os.path.dirname(filepath), exist_ok=True) | ||
|
||
content.save(filepath) | ||
|
||
print("save image: {}".format(filepath)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import tempfile | ||
|
||
import pytest | ||
|
||
|
||
@pytest.fixture | ||
def temp_dir(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
temp = tempfile.TemporaryDirectory() | ||
yield temp.name | ||
temp.cleanup() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
import json | ||
import os | ||
|
||
import numpy as np | ||
import pytest | ||
from PIL import Image | ||
|
||
from lmnet.common import Tasks | ||
from lmnet.utils.predict_output.writer import OutputWriter | ||
from lmnet.utils.predict_output.writer import save_json | ||
from lmnet.utils.predict_output.writer import save_npy | ||
from lmnet.utils.predict_output.writer import save_materials | ||
|
||
|
||
def test_write(temp_dir): | ||
task = Tasks.CLASSIFICATION | ||
classes = ("aaa", "bbb", "ccc") | ||
image_size = (320, 280) | ||
data_format = "NCHW" | ||
|
||
writer = OutputWriter(task, classes, image_size, data_format) | ||
outputs = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]) | ||
raw_images = np.zeros((3, 320, 280, 3), dtype=np.uint8) | ||
image_files = ["dummy1.png", "dummy2.png", "dummy3.png"] | ||
|
||
writer.write(temp_dir, outputs, raw_images, image_files, 1) | ||
|
||
assert os.path.exists(os.path.join(temp_dir, "npy", "1.npy")) | ||
assert os.path.exists(os.path.join(temp_dir, "json", "1.json")) | ||
assert os.path.exists(os.path.join(temp_dir, "images", "1", "aaa", "dummy3.png")) | ||
assert os.path.exists(os.path.join(temp_dir, "images", "1", "bbb", "dummy1.png")) | ||
assert os.path.exists(os.path.join(temp_dir, "images", "1", "ccc", "dummy2.png")) | ||
|
||
|
||
def test_save_npy(temp_dir): | ||
"""Test for save npy to existed dir""" | ||
data = np.array([[1, 2, 3], [4, 5, 6]]) | ||
save_npy(temp_dir, data, step=1) | ||
|
||
assert os.path.exists(os.path.join(temp_dir, "npy", "1.npy")) | ||
|
||
|
||
def test_save_npy_not_existed_dir(temp_dir): | ||
"""Test for save npy to not existed dir""" | ||
data = np.array([[1, 2, 3], [4, 5, 6]]) | ||
dist = os.path.join(temp_dir, 'not_existed') | ||
save_npy(dist, data, step=1) | ||
|
||
assert os.path.exists(os.path.join(dist, "npy", "1.npy")) | ||
|
||
|
||
def test_save_npy_with_invalid_step(temp_dir): | ||
"""Test for save npy with invalid step arg""" | ||
data = np.array([[1, 2, 3], [4, 5, 6]]) | ||
|
||
with pytest.raises(ValueError): | ||
save_npy(temp_dir, data, step={"invalid": "dict"}) | ||
|
||
|
||
def test_save_json(temp_dir): | ||
"""Test for save json to existed dir""" | ||
data = json.dumps({"k": "v", "list": [1, 2, 3]}) | ||
save_json(temp_dir, data, step=1) | ||
|
||
assert os.path.exists(os.path.join(temp_dir, "json", "1.json")) | ||
|
||
|
||
def test_save_json_not_existed_dir(temp_dir): | ||
"""Test for save json to not existed dir""" | ||
data = json.dumps({"k": "v", "list": [1, 2, 3]}) | ||
dist = os.path.join(temp_dir, 'not_existed') | ||
save_json(dist, data, step=1) | ||
|
||
assert os.path.exists(os.path.join(dist, "json", "1.json")) | ||
|
||
|
||
def test_save_json_with_invalid_step(temp_dir): | ||
"""Test for save json with invalid step arg""" | ||
data = json.dumps({"k": "v", "list": [1, 2, 3]}) | ||
|
||
with pytest.raises(ValueError): | ||
save_json(temp_dir, data, step={"invalid": "dict"}) | ||
|
||
|
||
def test_save_materials(temp_dir): | ||
"""Test for save materials""" | ||
image1 = [[[0, 0, 0], [0, 0, 0]], [[255, 255, 255], [255, 255, 255]]] | ||
image2 = [[[0, 0, 0], [255, 255, 255]], [[255, 255, 255], [0, 0, 0]]] | ||
image3 = [[[255, 255, 255], [255, 255, 255]], [[0, 0, 0], [0, 0, 0]]] | ||
|
||
data = [ | ||
("image1.png", Image.fromarray(np.array(image1, dtype=np.uint8))), | ||
("image2.png", Image.fromarray(np.array(image2, dtype=np.uint8))), | ||
("image3.png", Image.fromarray(np.array(image3, dtype=np.uint8))), | ||
] | ||
save_materials(temp_dir, data, step=1) | ||
|
||
assert os.path.exists(os.path.join(temp_dir, "images", "1", "image1.png")) | ||
assert os.path.exists(os.path.join(temp_dir, "images", "1", "image2.png")) | ||
assert os.path.exists(os.path.join(temp_dir, "images", "1", "image3.png")) | ||
|
||
|
||
def test_save_materials_not_existed_dir(temp_dir): | ||
"""Test for save materials to not existed dir""" | ||
image1 = [[[0, 0, 0], [0, 0, 0]], [[255, 255, 255], [255, 255, 255]]] | ||
image2 = [[[0, 0, 0], [255, 255, 255]], [[255, 255, 255], [0, 0, 0]]] | ||
image3 = [[[255, 255, 255], [255, 255, 255]], [[0, 0, 0], [0, 0, 0]]] | ||
|
||
data = [ | ||
("image1.png", Image.fromarray(np.array(image1, dtype=np.uint8))), | ||
("image2.png", Image.fromarray(np.array(image2, dtype=np.uint8))), | ||
("image3.png", Image.fromarray(np.array(image3, dtype=np.uint8))), | ||
] | ||
dist = os.path.join(temp_dir, 'not_existed') | ||
save_materials(dist, data, step=1) | ||
|
||
assert os.path.exists(os.path.join(dist, "images", "1", "image1.png")) | ||
assert os.path.exists(os.path.join(dist, "images", "1", "image2.png")) | ||
assert os.path.exists(os.path.join(dist, "images", "1", "image3.png")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo in
paht
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oops, thank you!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed!