Skip to content

Commit

Permalink
fix: model input length
Browse files Browse the repository at this point in the history
  • Loading branch information
Kohulan committed Aug 13, 2024
1 parent d45654e commit 61035c5
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
1 change: 1 addition & 0 deletions STOUT/repack/repack_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pickle
import helper
import transformer_model_4_repack as nmt_model_transformer

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

print(tf.__version__)
Expand Down
6 changes: 3 additions & 3 deletions STOUT/stout.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
# model download location
model_url = "https://zenodo.org/records/12542360/files/models.zip?download=1"
model_path = str(default_path) + "/translator_forward/"

print(model_path)
# download models to a default location
if not os.path.exists(model_path):
helper.download_trained_weights(model_url, default_path)
Expand Down Expand Up @@ -91,7 +91,7 @@ def load_reverse_translation_utils() -> tuple:
inp_lang = pickle.load(
open(default_path.as_posix() + "/assets/tokenizer_target.pkl", "rb")
)
inp_max_length = 1002
inp_max_length = 602
return inp_lang, targ_lang, inp_max_length


Expand Down Expand Up @@ -162,7 +162,7 @@ def translate_reverse(iupacname: str, add_confidence: bool = False) -> str:
splitted_name = helper.split_iupac(iupacname)

decoded = helper.tokenize_input(splitted_name, inp_lang, inp_max_length)
result_predited, confidence_array = reloaded_forward(decoded)
result_predited, confidence_array = reloaded_reverse(decoded)
if add_confidence:
result = helper.detokenize_output_add_confidence(
result_predited, confidence_array, targ_lang
Expand Down

0 comments on commit 61035c5

Please sign in to comment.