Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Fix lint and allow iobytes to load image #294

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 33 additions & 28 deletions keras_preprocessing/image/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def load_img(path, grayscale=False, color_mode='rgb', target_size=None,
"""Loads an image into PIL format.

# Arguments
path: Path to image file.
path: Path or io.BytesIO stream to image file.
grayscale: DEPRECATED use `color_mode="grayscale"`.
color_mode: The desired image format. One of "grayscale", "rgb", "rgba".
"grayscale" supports 8-bit images and 32-bit signed integer images.
Expand Down Expand Up @@ -110,33 +110,38 @@ def load_img(path, grayscale=False, color_mode='rgb', target_size=None,
if pil_image is None:
raise ImportError('Could not import PIL.Image. '
'The use of `load_img` requires PIL.')
with open(path, 'rb') as f:
img = pil_image.open(io.BytesIO(f.read()))
if color_mode == 'grayscale':
# if image is not already an 8-bit, 16-bit or 32-bit grayscale image
# convert it to an 8-bit grayscale image.
if img.mode not in ('L', 'I;16', 'I'):
img = img.convert('L')
elif color_mode == 'rgba':
if img.mode != 'RGBA':
img = img.convert('RGBA')
elif color_mode == 'rgb':
if img.mode != 'RGB':
img = img.convert('RGB')
else:
raise ValueError('color_mode must be "grayscale", "rgb", or "rgba"')
if target_size is not None:
width_height_tuple = (target_size[1], target_size[0])
if img.size != width_height_tuple:
if interpolation not in _PIL_INTERPOLATION_METHODS:
raise ValueError(
'Invalid interpolation method {} specified. Supported '
'methods are {}'.format(
interpolation,
", ".join(_PIL_INTERPOLATION_METHODS.keys())))
resample = _PIL_INTERPOLATION_METHODS[interpolation]
img = img.resize(width_height_tuple, resample)
return img
if isinstance(path, io.BytesIO):
buf = path
else:
with open(path, 'rb') as f:
buf = io.BytesIO(f.read())

img = pil_image.open(buf)
if color_mode == 'grayscale':
# if image is not already an 8-bit, 16-bit or 32-bit grayscale image
# convert it to an 8-bit grayscale image.
if img.mode not in ('L', 'I;16', 'I'):
img = img.convert('L')
elif color_mode == 'rgba':
if img.mode != 'RGBA':
img = img.convert('RGBA')
elif color_mode == 'rgb':
if img.mode != 'RGB':
img = img.convert('RGB')
else:
raise ValueError('color_mode must be "grayscale", "rgb", or "rgba"')
if target_size is not None:
width_height_tuple = (target_size[1], target_size[0])
if img.size != width_height_tuple:
if interpolation not in _PIL_INTERPOLATION_METHODS:
raise ValueError(
'Invalid interpolation method {} specified. Supported '
'methods are {}'.format(
interpolation,
", ".join(_PIL_INTERPOLATION_METHODS.keys())))
resample = _PIL_INTERPOLATION_METHODS[interpolation]
img = img.resize(width_height_tuple, resample)
return img


def list_pictures(directory, ext=('jpg', 'jpeg', 'bmp', 'png', 'ppm', 'tif',
Expand Down
4 changes: 2 additions & 2 deletions tests/image/dataframe_iterator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def test_dataframe_iterator_class_mode_categorical_multi_label(all_test_images,
assert isinstance(batch_y, np.ndarray)
assert batch_y.shape == (len(batch_x), 2)
for labels in batch_y:
assert all(l in {0, 1} for l in labels)
assert all(lbl in {0, 1} for lbl in labels)

# on first 3 batches
df = pd.DataFrame({
Expand All @@ -272,7 +272,7 @@ def test_dataframe_iterator_class_mode_categorical_multi_label(all_test_images,
assert isinstance(batch_y, np.ndarray)
assert batch_y.shape == (len(batch_x), 3)
for labels in batch_y:
assert all(l in {0, 1} for l in labels)
assert all(lbl in {0, 1} for lbl in labels)
assert (batch_y[0] == np.array([1, 1, 0])).all()
assert (batch_y[1] == np.array([0, 1, 0])).all()
assert (batch_y[2] == np.array([0, 0, 1])).all()
Expand Down
10 changes: 10 additions & 0 deletions tests/image/utils_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import io

import numpy as np
import pytest
import resource
Expand Down Expand Up @@ -192,6 +194,14 @@ def test_load_img(tmpdir):
loaded_im_array = utils.img_to_array(loaded_im, dtype='int32')
assert loaded_im_array.shape == (25, 25, 1)

with open(filename_grayscale_32bit, 'rb') as f:
buf = io.BytesIO(f.read())
loaded_im = utils.load_img(buf,
color_mode='grayscale',
target_size=(25, 25), interpolation="nearest")
loaded_im_array = utils.img_to_array(loaded_im, dtype='int32')
assert loaded_im_array.shape == (25, 25, 1)

# Check that exception is raised if interpolation not supported.

loaded_im = utils.load_img(filename_rgb, interpolation="unsupported")
Expand Down
4 changes: 2 additions & 2 deletions tests/sequence_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ def test_skipgrams():
categorical=True)
for couple in couples:
assert couple[0] - couple[1] <= 3
for l in labels:
assert len(l) == 2
for lbl in labels:
assert len(lbl) == 2


def test_remove_long_seq():
Expand Down