Skip to content

Commit

Permalink
Merge pull request #82 from Kohulan/development
Browse files Browse the repository at this point in the history
Development
  • Loading branch information
Kohulan authored Jan 5, 2024

Verified

This commit was signed with the committer’s verified signature.
vbudhram Vijay Budhram
2 parents 3db6954 + 208752a commit 8183754
Showing 2 changed files with 28 additions and 14 deletions.
16 changes: 2 additions & 14 deletions DECIMER/decimer.py
Original file line number Diff line number Diff line change
@@ -28,25 +28,13 @@
# Set path
default_path = pystow.join("DECIMER-V2")

# model download location
model_url = "https://zenodo.org/record/8300489/files/models.zip"
model_path = str(default_path) + "/DECIMER_model/"

# download models to a default location
if (
os.path.exists(model_path)
and os.stat(model_path + "/saved_model.pb").st_size != 28080309
):
shutil.rmtree(model_path)
config.download_trained_weights(model_url, default_path)
elif not os.path.exists(model_path):
config.download_trained_weights(model_url, default_path)
utils.ensure_model(default_path=default_path)

# Load the pickle files which contain the tokenizers and the max-length setting

tokenizer = pickle.load(
open(
default_path.as_posix() + "/DECIMER_model/assets/tokenizer_SMILES.pkl",
os.path.join(default_path.as_posix(), "DECIMER_model", "assets", "tokenizer_SMILES.pkl"),
"rb",
)
)
26 changes: 26 additions & 0 deletions DECIMER/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import re
import os
import DECIMER.config as config

pattern = "R([0-9]*)|X([0-9]*)|Y([0-9]*)|Z([0-9]*)"
# Raw string so that \W reaches the regex engine unchanged instead of being
# parsed as an (invalid) string escape sequence.
add_space_re = r"^(\W+)|(\W+)$"
@@ -57,3 +59,27 @@ def decoder(predictions):
.replace("§", "0")
)
return modified

def ensure_model(
    default_path: "str | os.PathLike[str]",
    model_url: str = "https://zenodo.org/record/8300489/files/models.zip",
    model_size: int = 28080309,
):
    """Ensure the DECIMER model is present locally, fetching it if needed.

    The model is looked up in ``<default_path>/DECIMER_model``. If that
    directory does not exist, or its ``saved_model.pb`` is missing or does
    not have the expected size, any existing directory is removed and
    ``config.download_trained_weights(model_url, default_path)`` is called.

    Args:
        default_path (str or os.PathLike): default path for DECIMER data.
        model_url (str): trained model url for downloading.
        model_size (int): expected size in bytes of ``saved_model.pb``.
            Pass a different value when ``model_url`` points to another
            model, otherwise the size check fails on every call.
    """
    # Imported locally so the function does not rely on a module-level
    # ``import shutil`` being present.
    import shutil

    # os.fspath accepts both plain strings and pathlib.Path objects.
    model_path = os.path.join(os.fspath(default_path), "DECIMER_model")
    saved_model = os.path.join(model_path, "saved_model.pb")

    if os.path.exists(model_path):
        try:
            is_valid = os.stat(saved_model).st_size == model_size
        except OSError:
            # A missing or unreadable saved_model.pb means the directory
            # is incomplete; treat it the same as a wrong size.
            is_valid = False
        if is_valid:
            return
        # Drop the stale or partial model before fetching a fresh copy.
        shutil.rmtree(model_path)

    config.download_trained_weights(model_url, default_path)

0 comments on commit 8183754

Please sign in to comment.