diff --git a/CHANGELOG.md b/CHANGELOG.md index 5531884..e6f4621 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ # Changelog +## [2.7.0](https://github.com/Kohulan/DECIMER-Image_Transformer/compare/v2.6.0...v2.7.0) (2024-06-11) + + +### Features + +* add entry_points ([0cc241f](https://github.com/Kohulan/DECIMER-Image_Transformer/commit/0cc241f0bb1caf28f1ff4528a5e285561a17730f)) + + +### Bug Fixes + +* one extra token prediction [#104](https://github.com/Kohulan/DECIMER-Image_Transformer/issues/104) ([6e0e483](https://github.com/Kohulan/DECIMER-Image_Transformer/commit/6e0e483285d9dbe2ed069e3a1c9e25d69dec6974)) + ## [2.6.0](https://github.com/Kohulan/DECIMER-Image_Transformer/compare/v2.5.0...v2.6.0) (2024-03-08) diff --git a/DECIMER/config.py b/DECIMER/config.py index c6e0087..3f492eb 100644 --- a/DECIMER/config.py +++ b/DECIMER/config.py @@ -11,6 +11,7 @@ from PIL import Image from PIL import ImageEnhance from pillow_heif import register_heif_opener +from typing import Union import DECIMER.Efficient_Net_encoder as Efficient_Net_encoder import DECIMER.Transformer_decoder as Transformer_decoder @@ -95,26 +96,69 @@ def HEIF_to_pillow(image_path: str): return heif_file -def remove_transparent(image_path: str): +def remove_transparent(image: Union[str, np.ndarray]) -> Image.Image: """ - Removes the transparent layer from a PNG image with an alpha channel - Args: image_path (str): path of input image - Returns: PIL.Image + Removes the transparent layer from a PNG image with an alpha channel. + + Args: + image (Union[str, np.ndarray]): Path of the input image or a numpy array representing the image. + + Returns: + PIL.Image.Image: The image with transparency removed. """ - try: - png = Image.open(image_path).convert("RGBA") - except Exception as e: - if type(e).__name__ == "UnidentifiedImageError": - png = HEIF_to_pillow(image_path) - else: - print(e) - raise Exception + def process_image(png: Image.Image) -> Image.Image: + """ + Helper function to remove transparency from a single image. + + Args: + png (PIL.Image.Image): The input PIL image with transparency. + + Returns: + PIL.Image.Image: The image with transparency removed. + """ + background = Image.new("RGBA", png.size, (255, 255, 255)) + alpha_composite = Image.alpha_composite(background, png) + return alpha_composite - background = Image.new("RGBA", png.size, (255, 255, 255)) + def handle_image_path(image_path: str) -> Image.Image: + """ + Helper function to handle image paths. + + Args: + image_path (str): The path to the input image. + + Returns: + PIL.Image.Image: The image with transparency removed. + """ + try: + png = Image.open(image_path).convert("RGBA") + except Exception as e: + if type(e).__name__ == "UnidentifiedImageError": + png = HEIF_to_pillow(image_path) + else: + print(e) + raise Exception + return process_image(png) + + def handle_numpy_array(array: np.ndarray) -> Image.Image: + """ + Helper function to handle a numpy array. - alpha_composite = Image.alpha_composite(background, png) + Args: + array (np.ndarray): The numpy array representing the image. - return alpha_composite + Returns: + PIL.Image.Image: The image with transparency removed. + """ + png = Image.fromarray(array).convert("RGBA") + return process_image(png) + + if isinstance(image, str): + return handle_image_path(image_path=image) + elif isinstance(image, np.ndarray): + return handle_numpy_array(array=image) + else: + raise ValueError("Input should be either a string (image path) or a numpy array.") def get_bnw_image(image): @@ -185,12 +229,12 @@ def increase_brightness(image): return image -def decode_image(image_path: str): +def decode_image(image_path: Union[str, np.ndarray]): """Loads an image and preprocesses the input image in several steps to get the image ready for DECIMER input. Args: - image_path (str): path of input image + image_path (Union[str, np.ndarray]): path of input image or numpy array representing the image. Returns: Processed image @@ -237,7 +281,7 @@ def initialize_encoder_config( backbone_fn (method): Calls Efficient-Net V2 as backbone for encoder image_shape (int): Shape of the input image do_permute (bool, optional): . Defaults to False. - pretrained_weights (keras weights, optional): Use pretrainined efficient net weights or not. Defaults to None. + pretrained_weights (keras weights, optional): Use pretrained efficient net weights or not. Defaults to None. """ self.encoder_config = dict( image_embedding_dim=image_embedding_dim, diff --git a/DECIMER/decimer.py b/DECIMER/decimer.py index e81cf95..a0831e1 100644 --- a/DECIMER/decimer.py +++ b/DECIMER/decimer.py @@ -5,6 +5,7 @@ from typing import List from typing import Tuple +import numpy as np import pystow import tensorflow as tf @@ -122,19 +123,19 @@ def detokenize_output_add_confidence( def predict_SMILES( - image_path: str, confidence: bool = False, hand_drawn: bool = False + image_input: [str, np.ndarray], confidence: bool = False, hand_drawn: bool = False ) -> str: """Predicts SMILES representation of a molecule depicted in the given image. Args: - image_path (str): Path of chemical structure depiction image - confidence (bool): Flag to indicate whether to return confidence values along with SMILES prediction - hand_drawn (bool): Flag to indicate whether the molecule in the image is hand-drawn + image_input (str or np.ndarray): Path of chemical structure depiction image or a numpy array representing the image. + confidence (bool): Flag to indicate whether to return confidence values along with SMILES prediction. + hand_drawn (bool): Flag to indicate whether the molecule in the image is hand-drawn. Returns: - str: SMILES representation of the molecule in the input image, optionally with confidence values + str: SMILES representation of the molecule in the input image, optionally with confidence values. """ - chemical_structure = config.decode_image(image_path) + chemical_structure = config.decode_image(image_input) model = DECIMER_Hand_drawn if hand_drawn else DECIMER_V2 predicted_tokens, confidence_values = model(tf.constant(chemical_structure)) diff --git a/setup.py b/setup.py index ec77ff8..084ee71 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ setuptools.setup( name="decimer", - version="2.6.0", + version="2.7.0", author="Kohulan Rajan", author_email="kohulan.rajan@uni-jena.de", maintainer="Kohulan Rajan, Otto Brinkhaus ",