From 03d508b5f72c7f76c382bff49e7209c843b01b86 Mon Sep 17 00:00:00 2001 From: Kohulan Date: Tue, 5 Mar 2024 13:44:09 +0100 Subject: [PATCH] feat: improved DECIMER hand-drawn model --- DECIMER/decimer.py | 28 ++++++++++++++++++++-------- DECIMER/utils.py | 9 +++------ 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/DECIMER/decimer.py b/DECIMER/decimer.py index c6dfec8..bc50f0c 100644 --- a/DECIMER/decimer.py +++ b/DECIMER/decimer.py @@ -32,9 +32,10 @@ model_urls = { "DECIMER": "https://zenodo.org/record/8300489/files/models.zip", - "DECIMER_HandDrawn": "https://zenodo.org/records/10781330/files/DECIMER_HandDrawn_model.zip" + "DECIMER_HandDrawn": "https://zenodo.org/records/10781330/files/DECIMER_HandDrawn_model.zip", } + def get_models(model_urls: dict) -> Tuple[object, tf.saved_model, tf.saved_model]: """Download and load models from the provided URLs. @@ -54,7 +55,9 @@ def get_models(model_urls: dict) -> Tuple[object, tf.saved_model, tf.saved_model model_paths = utils.ensure_models(default_path=default_path, model_urls=model_urls) # Load tokenizers - tokenizer_path = os.path.join(model_paths["DECIMER"], "assets", "tokenizer_SMILES.pkl") + tokenizer_path = os.path.join( + model_paths["DECIMER"], "assets", "tokenizer_SMILES.pkl" + ) tokenizer = pickle.load(open(tokenizer_path, "rb")) # Load DECIMER models @@ -63,8 +66,10 @@ def get_models(model_urls: dict) -> Tuple[object, tf.saved_model, tf.saved_model return tokenizer, DECIMER_V2, DECIMER_Hand_drawn + tokenizer, DECIMER_V2, DECIMER_Hand_drawn = get_models(model_urls) + def detokenize_output(predicted_array: int) -> str: """This function takes the predited tokens from the DECIMER model and returns the decoded SMILES string. @@ -115,7 +120,10 @@ def detokenize_output_add_confidence( decoded_prediction_with_confidence.append(prediction_with_confidence_[-1]) return decoded_prediction_with_confidence -def predict_SMILES(image_path: str, confidence: bool = False, hand_drawn: bool = False) -> str: + +def predict_SMILES( + image_path: str, confidence: bool = False, hand_drawn: bool = False +) -> str: """Predicts SMILES representation of a molecule depicted in the given image. Args: @@ -127,18 +135,21 @@ def predict_SMILES(image_path: str, confidence: bool = False, hand_drawn: bool = str: SMILES representation of the molecule in the input image, optionally with confidence values """ chemical_structure = config.decode_image(image_path) - + model = DECIMER_Hand_drawn if hand_drawn else DECIMER_V2 predicted_tokens, confidence_values = model(tf.constant(chemical_structure)) - + predicted_SMILES = utils.decoder(detokenize_output(predicted_tokens)) - + if confidence: - predicted_SMILES_with_confidence = detokenize_output_add_confidence(predicted_tokens, confidence_values) + predicted_SMILES_with_confidence = detokenize_output_add_confidence( + predicted_tokens, confidence_values + ) return predicted_SMILES, predicted_SMILES_with_confidence - + return predicted_SMILES + def main(): """This function take the path of the image as user input and returns the predicted SMILES as output in CLI. @@ -155,5 +166,6 @@ def main(): SMILES = predict_SMILES(sys.argv[1]) print(SMILES) + if __name__ == "__main__": main() diff --git a/DECIMER/utils.py b/DECIMER/utils.py index 4c92500..d42fdbc 100644 --- a/DECIMER/utils.py +++ b/DECIMER/utils.py @@ -62,10 +62,8 @@ def decoder(predictions): ) return modified -def ensure_models( - default_path: str, - model_urls: dict -) -> dict: + +def ensure_models(default_path: str, model_urls: dict) -> dict: """Function to ensure models are present locally. Convenient function to ensure model downloads before usage @@ -89,9 +87,8 @@ def ensure_models( config.download_trained_weights(model_url, default_path) elif not os.path.exists(model_path): config.download_trained_weights(model_url, default_path) - + # Store the model path model_paths[model_name] = model_path return model_paths -