diff --git a/keras_preprocessing/image/utils.py b/keras_preprocessing/image/utils.py index bc3e6886..9330abd4 100644 --- a/keras_preprocessing/image/utils.py +++ b/keras_preprocessing/image/utils.py @@ -81,7 +81,7 @@ def load_img(path, grayscale=False, color_mode='rgb', target_size=None, """Loads an image into PIL format. # Arguments - path: Path to image file. + path: Path or io.BytesIO stream to image file. grayscale: DEPRECATED use `color_mode="grayscale"`. color_mode: The desired image format. One of "grayscale", "rgb", "rgba". "grayscale" supports 8-bit images and 32-bit signed integer images. @@ -110,33 +110,38 @@ def load_img(path, grayscale=False, color_mode='rgb', target_size=None, if pil_image is None: raise ImportError('Could not import PIL.Image. ' 'The use of `load_img` requires PIL.') - with open(path, 'rb') as f: - img = pil_image.open(io.BytesIO(f.read())) - if color_mode == 'grayscale': - # if image is not already an 8-bit, 16-bit or 32-bit grayscale image - # convert it to an 8-bit grayscale image. - if img.mode not in ('L', 'I;16', 'I'): - img = img.convert('L') - elif color_mode == 'rgba': - if img.mode != 'RGBA': - img = img.convert('RGBA') - elif color_mode == 'rgb': - if img.mode != 'RGB': - img = img.convert('RGB') - else: - raise ValueError('color_mode must be "grayscale", "rgb", or "rgba"') - if target_size is not None: - width_height_tuple = (target_size[1], target_size[0]) - if img.size != width_height_tuple: - if interpolation not in _PIL_INTERPOLATION_METHODS: - raise ValueError( - 'Invalid interpolation method {} specified. Supported ' - 'methods are {}'.format( - interpolation, - ", ".join(_PIL_INTERPOLATION_METHODS.keys()))) - resample = _PIL_INTERPOLATION_METHODS[interpolation] - img = img.resize(width_height_tuple, resample) - return img + if isinstance(path, io.BytesIO): + buf = path + else: + with open(path, 'rb') as f: + buf = io.BytesIO(f.read()) + + img = pil_image.open(buf) + if color_mode == 'grayscale': + # if image is not already an 8-bit, 16-bit or 32-bit grayscale image + # convert it to an 8-bit grayscale image. + if img.mode not in ('L', 'I;16', 'I'): + img = img.convert('L') + elif color_mode == 'rgba': + if img.mode != 'RGBA': + img = img.convert('RGBA') + elif color_mode == 'rgb': + if img.mode != 'RGB': + img = img.convert('RGB') + else: + raise ValueError('color_mode must be "grayscale", "rgb", or "rgba"') + if target_size is not None: + width_height_tuple = (target_size[1], target_size[0]) + if img.size != width_height_tuple: + if interpolation not in _PIL_INTERPOLATION_METHODS: + raise ValueError( + 'Invalid interpolation method {} specified. Supported ' + 'methods are {}'.format( + interpolation, + ", ".join(_PIL_INTERPOLATION_METHODS.keys()))) + resample = _PIL_INTERPOLATION_METHODS[interpolation] + img = img.resize(width_height_tuple, resample) + return img def list_pictures(directory, ext=('jpg', 'jpeg', 'bmp', 'png', 'ppm', 'tif', diff --git a/tests/image/dataframe_iterator_test.py b/tests/image/dataframe_iterator_test.py index cc89fa15..102970dc 100644 --- a/tests/image/dataframe_iterator_test.py +++ b/tests/image/dataframe_iterator_test.py @@ -256,7 +256,7 @@ def test_dataframe_iterator_class_mode_categorical_multi_label(all_test_images, assert isinstance(batch_y, np.ndarray) assert batch_y.shape == (len(batch_x), 2) for labels in batch_y: - assert all(l in {0, 1} for l in labels) + assert all(lbl in {0, 1} for lbl in labels) # on first 3 batches df = pd.DataFrame({ @@ -272,7 +272,7 @@ def test_dataframe_iterator_class_mode_categorical_multi_label(all_test_images, assert isinstance(batch_y, np.ndarray) assert batch_y.shape == (len(batch_x), 3) for labels in batch_y: - assert all(l in {0, 1} for l in labels) + assert all(lbl in {0, 1} for lbl in labels) assert (batch_y[0] == np.array([1, 1, 0])).all() assert (batch_y[1] == np.array([0, 1, 0])).all() assert (batch_y[2] == np.array([0, 0, 1])).all() diff --git a/tests/image/utils_test.py b/tests/image/utils_test.py index d954e1e2..53bda00d 100644 --- a/tests/image/utils_test.py +++ b/tests/image/utils_test.py @@ -1,3 +1,5 @@ +import io + import numpy as np import pytest import resource @@ -192,6 +194,14 @@ def test_load_img(tmpdir): loaded_im_array = utils.img_to_array(loaded_im, dtype='int32') assert loaded_im_array.shape == (25, 25, 1) + with open(filename_grayscale_32bit, 'rb') as f: + buf = io.BytesIO(f.read()) + loaded_im = utils.load_img(buf, + color_mode='grayscale', + target_size=(25, 25), interpolation="nearest") + loaded_im_array = utils.img_to_array(loaded_im, dtype='int32') + assert loaded_im_array.shape == (25, 25, 1) + # Check that exception is raised if interpolation not supported. loaded_im = utils.load_img(filename_rgb, interpolation="unsupported") diff --git a/tests/sequence_test.py b/tests/sequence_test.py index 36fe1d74..f2f30d7b 100644 --- a/tests/sequence_test.py +++ b/tests/sequence_test.py @@ -101,8 +101,8 @@ def test_skipgrams(): categorical=True) for couple in couples: assert couple[0] - couple[1] <= 3 - for l in labels: - assert len(l) == 2 + for lbl in labels: + assert len(lbl) == 2 def test_remove_long_seq():