Skip to content

Commit

Permalink
Merge pull request #105 from alexey-krasnov/build
Browse files Browse the repository at this point in the history
feat(predict_SMILES): add support for numpy array input format.
  • Loading branch information
Kohulan authored Aug 8, 2024
2 parents 6e0e483 + 403f41f commit a344261
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 25 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
# Changelog

## [2.7.0](https://github.com/Kohulan/DECIMER-Image_Transformer/compare/v2.6.0...v2.7.0) (2024-06-11)


### Features

* add entry_points ([0cc241f](https://github.com/Kohulan/DECIMER-Image_Transformer/commit/0cc241f0bb1caf28f1ff4528a5e285561a17730f))


### Bug Fixes

* one extra token prediction [#104](https://github.com/Kohulan/DECIMER-Image_Transformer/issues/104) ([6e0e483](https://github.com/Kohulan/DECIMER-Image_Transformer/commit/6e0e483285d9dbe2ed069e3a1c9e25d69dec6974))

## [2.6.0](https://github.com/Kohulan/DECIMER-Image_Transformer/compare/v2.5.0...v2.6.0) (2024-03-08)


Expand Down
80 changes: 62 additions & 18 deletions DECIMER/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from PIL import Image
from PIL import ImageEnhance
from pillow_heif import register_heif_opener
from typing import Union

import DECIMER.Efficient_Net_encoder as Efficient_Net_encoder
import DECIMER.Transformer_decoder as Transformer_decoder
Expand Down Expand Up @@ -95,26 +96,69 @@ def HEIF_to_pillow(image_path: str):
return heif_file


def remove_transparent(image_path: str):
def remove_transparent(image: Union[str, np.ndarray]) -> Image.Image:
"""
Removes the transparent layer from a PNG image with an alpha channel
Args: image_path (str): path of input image
Returns: PIL.Image
Removes the transparent layer from a PNG image with an alpha channel.
Args:
image (Union[str, np.ndarray]): Path of the input image or a numpy array representing the image.
Returns:
PIL.Image.Image: The image with transparency removed.
"""
try:
png = Image.open(image_path).convert("RGBA")
except Exception as e:
if type(e).__name__ == "UnidentifiedImageError":
png = HEIF_to_pillow(image_path)
else:
print(e)
raise Exception
def process_image(png: Image.Image) -> Image.Image:
"""
Helper function to remove transparency from a single image.
Args:
png (PIL.Image.Image): The input PIL image with transparency.
Returns:
PIL.Image.Image: The image with transparency removed.
"""
background = Image.new("RGBA", png.size, (255, 255, 255))
alpha_composite = Image.alpha_composite(background, png)
return alpha_composite

background = Image.new("RGBA", png.size, (255, 255, 255))
def handle_image_path(image_path: str) -> Image.Image:
"""
Helper function to handle image paths.
Args:
image_path (str): The path to the input image.
Returns:
PIL.Image.Image: The image with transparency removed.
"""
try:
png = Image.open(image_path).convert("RGBA")
except Exception as e:
if type(e).__name__ == "UnidentifiedImageError":
png = HEIF_to_pillow(image_path)
else:
print(e)
raise Exception
return process_image(png)

def handle_numpy_array(array: np.ndarray) -> Image.Image:
"""
Helper function to handle a numpy array.
alpha_composite = Image.alpha_composite(background, png)
Args:
array (np.ndarray): The numpy array representing the image.
return alpha_composite
Returns:
PIL.Image.Image: The image with transparency removed.
"""
png = Image.fromarray(array).convert("RGBA")
return process_image(png)

if isinstance(image, str):
return handle_image_path(image_path=image)
elif isinstance(image, np.ndarray):
return handle_numpy_array(array=image)
else:
raise ValueError("Input should be either a string (image path) or a numpy array.")


def get_bnw_image(image):
Expand Down Expand Up @@ -185,12 +229,12 @@ def increase_brightness(image):
return image


def decode_image(image_path: str):
def decode_image(image_path: Union[str, np.ndarray]):
"""Loads an image and preprocesses the input image in several steps to get
the image ready for DECIMER input.
Args:
image_path (str): path of input image
image_path (Union[str, np.ndarray]): path of input image or numpy array representing the image.
Returns:
Processed image
Expand Down Expand Up @@ -237,7 +281,7 @@ def initialize_encoder_config(
backbone_fn (method): Calls Efficient-Net V2 as backbone for encoder
image_shape (int): Shape of the input image
do_permute (bool, optional): . Defaults to False.
pretrained_weights (keras weights, optional): Use pretrainined efficient net weights or not. Defaults to None.
pretrained_weights (keras weights, optional): Use pretrained efficient net weights or not. Defaults to None.
"""
self.encoder_config = dict(
image_embedding_dim=image_embedding_dim,
Expand Down
13 changes: 7 additions & 6 deletions DECIMER/decimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import List
from typing import Tuple

import numpy as np
import pystow
import tensorflow as tf

Expand Down Expand Up @@ -122,19 +123,19 @@ def detokenize_output_add_confidence(


def predict_SMILES(
image_path: str, confidence: bool = False, hand_drawn: bool = False
image_input: [str, np.ndarray], confidence: bool = False, hand_drawn: bool = False
) -> str:
"""Predicts SMILES representation of a molecule depicted in the given image.
Args:
image_path (str): Path of chemical structure depiction image
confidence (bool): Flag to indicate whether to return confidence values along with SMILES prediction
hand_drawn (bool): Flag to indicate whether the molecule in the image is hand-drawn
image_input (str or np.ndarray): Path of chemical structure depiction image or a numpy array representing the image.
confidence (bool): Flag to indicate whether to return confidence values along with SMILES prediction.
hand_drawn (bool): Flag to indicate whether the molecule in the image is hand-drawn.
Returns:
str: SMILES representation of the molecule in the input image, optionally with confidence values
str: SMILES representation of the molecule in the input image, optionally with confidence values.
"""
chemical_structure = config.decode_image(image_path)
chemical_structure = config.decode_image(image_input)

model = DECIMER_Hand_drawn if hand_drawn else DECIMER_V2
predicted_tokens, confidence_values = model(tf.constant(chemical_structure))
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

setuptools.setup(
name="decimer",
version="2.6.0",
version="2.7.0",
author="Kohulan Rajan",
author_email="[email protected]",
maintainer="Kohulan Rajan, Otto Brinkhaus ",
Expand Down

0 comments on commit a344261

Please sign in to comment.