-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #11 from MantisAI/main
Syncs repos after new Trainer successful experiment, Data Augmentation and Portfolio sample
- Loading branch information
Showing
56 changed files
with
2,942 additions
and
896 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
/comparison.csv | ||
/meshterms_list.txt | ||
/comparison.csv |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
Information Sources: L | ||
Phenomena and Processes: G | ||
Geographicals: Z | ||
Diseases: C |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
/allMeSH_2021.json | ||
/allMeSH_2021.jsonl | ||
/desc2021.xml | ||
/disease_tags_validation_grants.xlsx | ||
/active_grants_last_5_years.csv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
outs: | ||
- md5: d664be2a9000d44bb0325f364ec20e27 | ||
size: 4953477 | ||
path: active_grants_last_5_years.csv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
outs: | ||
- md5: 94f18c3918b180728a553123edb2ee32 | ||
size: 27914288461 | ||
path: allMeSH_2021.jsonl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
grants-tagger augment mesh [FOLDER_AFTER_PREPROCESSING] [SET_YOUR_OUTPUT_FOLDER_HERE] \ | ||
--min-examples 25 \ | ||
--concurrent-calls 25 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Augments data using a file with 1 label per line and years | ||
grants-tagger augment mesh [FOLDER_AFTER_PREPROCESSING] [SET_YOUR_OUTPUT_FOLDER_HERE] \ | ||
--tags-file-path tags_to_augment.txt \ | ||
--examples 25 \ | ||
--concurrent-calls 25 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Run on g5.12xlarge instance | ||
|
||
# Without saving (on-the-fly) | ||
SOURCE="data/raw/allMeSH_2021.jsonl" | ||
|
||
grants-tagger train bertmesh \ | ||
"" \ | ||
$SOURCE \ | ||
--test-size 25000 \ | ||
--train-years 2016,2017,2018,2019 \ | ||
--test-years 2020,2021 \ | ||
--output_dir bertmesh_outs/pipeline_test/ \ | ||
--per_device_train_batch_size 16 \ | ||
--per_device_eval_batch_size 1 \ | ||
--multilabel_attention True \ | ||
--freeze_backbone unfreeze \ | ||
--num_train_epochs 7 \ | ||
--learning_rate 5e-5 \ | ||
--dropout 0.1 \ | ||
--hidden_size 1024 \ | ||
--warmup_steps 5000 \ | ||
--max_grad_norm 2.0 \ | ||
--scheduler_type cosine_hard_restart \ | ||
--weight_decay 0.2 \ | ||
--correct_bias True \ | ||
--threshold 0.25 \ | ||
--prune_labels_in_evaluation True \ | ||
--hidden_dropout_prob 0.2 \ | ||
--attention_probs_dropout_prob 0.2 \ | ||
--fp16 \ | ||
--torch_compile \ | ||
--evaluation_strategy epochs \ | ||
--eval_accumulation_steps 20 \ | ||
--save_strategy epochs \ | ||
--wandb_project wellcome-mesh \ | ||
--wandb_name test-train-all \ | ||
--wandb_api_key ${WANDB_API_KEY} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Run on g5.12xlarge instance | ||
|
||
# Without saving (on-the-fly) | ||
SOURCE="data/raw/allMeSH_2021.jsonl" | ||
|
||
grants-tagger train bertmesh \ | ||
"" \ | ||
$SOURCE \ | ||
--test-size 25000 \ | ||
--train-years 2016,2017,2018,2019 \ | ||
--test-years 2020,2021 \ | ||
--output_dir bertmesh_outs/pipeline_test/ \ | ||
--per_device_train_batch_size 16 \ | ||
--per_device_eval_batch_size 1 \ | ||
--multilabel_attention True \ | ||
--freeze_backbone unfreeze \ | ||
--num_train_epochs 7 \ | ||
--learning_rate 5e-5 \ | ||
--dropout 0.1 \ | ||
--hidden_size 1024 \ | ||
--warmup_steps 5000 \ | ||
--max_grad_norm 2.0 \ | ||
--scheduler_type cosine_hard_restart \ | ||
--weight_decay 0.2 \ | ||
--correct_bias True \ | ||
--threshold 0.25 \ | ||
--prune_labels_in_evaluation True \ | ||
--hidden_dropout_prob 0.2 \ | ||
--attention_probs_dropout_prob 0.2 \ | ||
--fp16 \ | ||
--torch_compile \ | ||
--evaluation_strategy steps \ | ||
--eval_steps 50000 \ | ||
--eval_accumulation_steps 20 \ | ||
--save_strategy steps \ | ||
--save_steps 50000 \ | ||
--wandb_project wellcome-mesh \ | ||
--wandb_name test-train-all \ | ||
--wandb_api_key ${WANDB_API_KEY} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
grants-tagger preprocess mesh data/raw/allMeSH_2021.jsonl [SET_YOUR_OUTPUT_FOLDER_HERE] '' \ | ||
--test-size 0.05 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
grants-tagger preprocess mesh data/raw/allMeSH_2021.jsonl [SET_YOUR_OUTPUT_FOLDER_HERE] '' \ | ||
--test-size 25000 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
grants-tagger preprocess mesh data/raw/allMeSH_2021.jsonl [SET_YOUR_OUTPUT_FOLDER_HERE] '' \ | ||
--test-size 25000 \ | ||
--train-years 2016,2017,2018,2019 \ | ||
--test-years 2020,2021 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Run on g5.12xlarge instance | ||
|
||
# After preprocessing | ||
SOURCE="[SET_YOUR_PREPROCESSING_FOLDER_HERE]" | ||
|
||
# Checkpoint | ||
CHECKPOINT="checkpoint-100000" | ||
|
||
grants-tagger train bertmesh \ | ||
bertmesh_outs/pipeline_test/$CHECKPOINT \ | ||
$SOURCE \ | ||
--output_dir bertmesh_outs/pipeline_test/ \ | ||
--per_device_train_batch_size 16 \ | ||
--per_device_eval_batch_size 1 \ | ||
--multilabel_attention True \ | ||
--freeze_backbone unfreeze \ | ||
--num_train_epochs 3 \ | ||
--learning_rate 5e-5 \ | ||
--dropout 0.1 \ | ||
--hidden_size 1024 \ | ||
--warmup_steps 0 \ | ||
--max_grad_norm 2.0 \ | ||
--scheduler_type cosine_hard_restart \ | ||
--weight_decay 0.2 \ | ||
--correct_bias True \ | ||
--threshold 0.25 \ | ||
--prune_labels_in_evaluation True \ | ||
--hidden_dropout_prob 0.2 \ | ||
--attention_probs_dropout_prob 0.2 \ | ||
--fp16 \ | ||
--torch_compile \ | ||
--evaluation_strategy epoch \ | ||
--eval_accumulation_steps 20 \ | ||
--save_strategy epoch \ | ||
--wandb_project wellcome-mesh \ | ||
--wandb_name test-train-all \ | ||
--wandb_api_key ${WANDB_API_KEY} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Run on g5.12xlarge instance | ||
|
||
# After preprocessing | ||
SOURCE="[SET_YOUR_PREPROCESSING_FOLDER_HERE]" | ||
|
||
# Checkpoint | ||
CHECKPOINT="checkpoint-100000" | ||
|
||
grants-tagger train bertmesh \ | ||
bertmesh_outs/pipeline_test/$CHECKPOINT \ | ||
$SOURCE \ | ||
--output_dir bertmesh_outs/pipeline_test/ \ | ||
--per_device_train_batch_size 16 \ | ||
--per_device_eval_batch_size 1 \ | ||
--multilabel_attention True \ | ||
--freeze_backbone unfreeze \ | ||
--num_train_epochs 3 \ | ||
--learning_rate 5e-5 \ | ||
--dropout 0.1 \ | ||
--hidden_size 1024 \ | ||
--warmup_steps 0 \ | ||
--max_grad_norm 2.0 \ | ||
--scheduler_type cosine_hard_restart \ | ||
--weight_decay 0.2 \ | ||
--correct_bias True \ | ||
--threshold 0.25 \ | ||
--prune_labels_in_evaluation True \ | ||
--hidden_dropout_prob 0.2 \ | ||
--attention_probs_dropout_prob 0.2 \ | ||
--fp16 \ | ||
--torch_compile \ | ||
--evaluation_strategy steps \ | ||
--eval_steps 10000 \ | ||
--eval_accumulation_steps 20 \ | ||
--save_strategy steps \ | ||
--save_steps 10000 \ | ||
--wandb_project wellcome-mesh \ | ||
--wandb_name test-train-all \ | ||
--wandb_api_key ${WANDB_API_KEY} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Run on g5.12xlarge instance | ||
|
||
# After preprocessing | ||
SOURCE="[SET_YOUR_PREPROCESSING_FOLDER_HERE]" | ||
|
||
grants-tagger train bertmesh \ | ||
"" \ | ||
$SOURCE \ | ||
--output_dir bertmesh_outs/pipeline_test/ \ | ||
--per_device_train_batch_size 16 \ | ||
--per_device_eval_batch_size 1 \ | ||
--multilabel_attention True \ | ||
--freeze_backbone unfreeze \ | ||
--num_train_epochs 7 \ | ||
--learning_rate 5e-5 \ | ||
--dropout 0.1 \ | ||
--hidden_size 1024 \ | ||
--warmup_steps 5000 \ | ||
--max_grad_norm 2.0 \ | ||
--scheduler_type cosine_hard_restart \ | ||
--weight_decay 0.2 \ | ||
--correct_bias True \ | ||
--threshold 0.25 \ | ||
--prune_labels_in_evaluation True \ | ||
--hidden_dropout_prob 0.2 \ | ||
--attention_probs_dropout_prob 0.2 \ | ||
--fp16 \ | ||
--torch_compile \ | ||
--evaluation_strategy epoch \ | ||
--eval_accumulation_steps 20 \ | ||
--save_strategy epoch \ | ||
--wandb_project wellcome-mesh \ | ||
--wandb_name test-train-all \ | ||
--wandb_api_key ${WANDB_API_KEY} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Run on g5.12xlarge instance | ||
|
||
# After preprocessing | ||
SOURCE="[SET_YOUR_PREPROCESSING_FOLDER_HERE]" | ||
|
||
grants-tagger train bertmesh \ | ||
"" \ | ||
$SOURCE \ | ||
--output_dir bertmesh_outs/pipeline_test/ \ | ||
--per_device_train_batch_size 16 \ | ||
--per_device_eval_batch_size 1 \ | ||
--multilabel_attention True \ | ||
--freeze_backbone unfreeze \ | ||
--num_train_epochs 7 \ | ||
--learning_rate 5e-5 \ | ||
--dropout 0.1 \ | ||
--hidden_size 1024 \ | ||
--warmup_steps 5000 \ | ||
--max_grad_norm 2.0 \ | ||
--scheduler_type cosine_hard_restart \ | ||
--weight_decay 0.2 \ | ||
--correct_bias True \ | ||
--threshold 0.25 \ | ||
--prune_labels_in_evaluation True \ | ||
--hidden_dropout_prob 0.2 \ | ||
--attention_probs_dropout_prob 0.2 \ | ||
--fp16 \ | ||
--torch_compile \ | ||
--evaluation_strategy steps \ | ||
--eval_steps 10000 \ | ||
--eval_accumulation_steps 20 \ | ||
--save_strategy steps \ | ||
--save_steps 10000 \ | ||
--wandb_project wellcome-mesh \ | ||
--wandb_name test-train-all \ | ||
--wandb_api_key ${WANDB_API_KEY} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
""" | ||
From langchain: https://raw.githubusercontent.com/langchain-ai/langchain/master/libs/langchain/langchain/output_parsers/json.py | ||
""" | ||
|
||
import json | ||
import re | ||
|
||
|
||
class JsonParser: | ||
def __init(self): | ||
"""Class to parse json produced by LLMs. Inspiration taken from langchain. | ||
It fixes quotes, it escapes separators, etc.""" | ||
pass | ||
|
||
@staticmethod | ||
def _replace_new_line(match: re.Match[str]) -> str: | ||
value = match.group(2) | ||
value = re.sub(r"\n", r"\\n", value) | ||
value = re.sub(r"\r", r"\\r", value) | ||
value = re.sub(r"\t", r"\\t", value) | ||
value = re.sub('"', r"\"", value) | ||
|
||
return match.group(1) + value + match.group(3) | ||
|
||
@staticmethod | ||
def _custom_parser(multiline_string: str) -> str: | ||
""" | ||
The LLM response for `action_input` may be a multiline | ||
string containing unescaped newlines, tabs or quotes. This function | ||
replaces those characters with their escaped counterparts. | ||
(newlines in JSON must be double-escaped: `\\n`) | ||
""" | ||
if isinstance(multiline_string, (bytes, bytearray)): | ||
multiline_string = multiline_string.decode() | ||
|
||
multiline_string = re.sub( | ||
r'("action_input"\:\s*")(.*)(")', | ||
JsonParser._replace_new_line, | ||
multiline_string, | ||
flags=re.DOTALL, | ||
) | ||
|
||
return multiline_string | ||
|
||
@staticmethod | ||
def parse_json(json_string: str) -> dict: | ||
""" | ||
Parse a JSON string from LLM response | ||
Args: | ||
json_string: The Markdown string. | ||
Returns: | ||
The parsed JSON object as a Python dictionary. | ||
""" | ||
json_str = json_string | ||
|
||
# Strip whitespace and newlines from the start and end | ||
json_str = json_str.strip() | ||
|
||
# handle newlines and other special characters inside the returned value | ||
json_str = JsonParser._custom_parser(json_str) | ||
|
||
# Parse the JSON string into a Python dictionary | ||
parsed = json.loads(json_str) | ||
|
||
return parsed |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
import typer | ||
from .augment import augment_cli | ||
|
||
augment_app = typer.Typer() | ||
augment_app.command( | ||
"mesh", | ||
context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, | ||
)(augment_cli) |
Oops, something went wrong.