Skip to content

Commit

Permalink
Merge pull request #11 from MantisAI/main
Browse files Browse the repository at this point in the history
Syncs repos after new Trainer successful experiment, Data Augmentation and Portfolio sample
  • Loading branch information
nsorros authored Sep 6, 2023
2 parents 09a3b44 + 5c9dd2c commit 5b633ae
Show file tree
Hide file tree
Showing 56 changed files with 2,942 additions and 896 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,8 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/


# Folder where training outputs are stored
bertmesh_outs/
wandb/
200 changes: 137 additions & 63 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion data/grants_comparison/.gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
/comparison.csv
/meshterms_list.txt
/comparison.csv
4 changes: 0 additions & 4 deletions data/grants_comparison/comparison.csv.dvc

This file was deleted.

4 changes: 4 additions & 0 deletions data/grants_comparison/mesh_tree_letters_list.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Information Sources: L
Phenomena and Processes: G
Geographicals: Z
Diseases: C
2 changes: 2 additions & 0 deletions data/raw/.gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
/allMeSH_2021.json
/allMeSH_2021.jsonl
/desc2021.xml
/disease_tags_validation_grants.xlsx
/active_grants_last_5_years.csv
4 changes: 4 additions & 0 deletions data/raw/active_grants_last_5_years.csv.dvc
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
outs:
- md5: d664be2a9000d44bb0325f364ec20e27
size: 4953477
path: active_grants_last_5_years.csv
4 changes: 4 additions & 0 deletions data/raw/allMeSH_2021.jsonl.dvc
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
outs:
- md5: 94f18c3918b180728a553123edb2ee32
size: 27914288461
path: allMeSH_2021.jsonl
3 changes: 3 additions & 0 deletions examples/augment.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Augment MeSH training data with generated examples (presumably for labels
# having fewer than --min-examples examples — confirm against augment_cli),
# using up to 25 concurrent calls.
# Replace the bracketed placeholders with real folders before running.
grants-tagger augment mesh [FOLDER_AFTER_PREPROCESSING] [SET_YOUR_OUTPUT_FOLDER_HERE] \
--min-examples 25 \
--concurrent-calls 25
5 changes: 5 additions & 0 deletions examples/augment_specific_tags.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Augments data using a file with 1 label per line and years
# --tags-file-path: plain-text file listing the labels to augment, one per line
# --examples: number of examples per tag (presumably to generate — confirm)
# Replace the bracketed placeholders with real folders before running.
grants-tagger augment mesh [FOLDER_AFTER_PREPROCESSING] [SET_YOUR_OUTPUT_FOLDER_HERE] \
--tags-file-path tags_to_augment.txt \
--examples 25 \
--concurrent-calls 25
37 changes: 37 additions & 0 deletions examples/preprocess_and_train_by_epochs.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Run on g5.12xlarge instance

# Without saving (on-the-fly): preprocessing happens in-memory from the raw
# JSONL; no intermediate dataset is written to disk.
SOURCE="data/raw/allMeSH_2021.jsonl"

# Train BERTMesh (first positional argument is the model/checkpoint path —
# empty here, unlike the resume_train_* scripts), splitting train/test by
# publication year and evaluating/saving once per epoch.
#
# Fix: HuggingFace Trainer accepts only "no", "steps" or "epoch" for
# --evaluation_strategy / --save_strategy; the previous value "epochs"
# (plural) is rejected at argument-parsing time. The sibling scripts
# (train_by_epochs.sh, resume_train_by_epoch.sh) already use "epoch".
grants-tagger train bertmesh \
    "" \
    $SOURCE \
    --test-size 25000 \
    --train-years 2016,2017,2018,2019 \
    --test-years 2020,2021 \
    --output_dir bertmesh_outs/pipeline_test/ \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 1 \
    --multilabel_attention True \
    --freeze_backbone unfreeze \
    --num_train_epochs 7 \
    --learning_rate 5e-5 \
    --dropout 0.1 \
    --hidden_size 1024 \
    --warmup_steps 5000 \
    --max_grad_norm 2.0 \
    --scheduler_type cosine_hard_restart \
    --weight_decay 0.2 \
    --correct_bias True \
    --threshold 0.25 \
    --prune_labels_in_evaluation True \
    --hidden_dropout_prob 0.2 \
    --attention_probs_dropout_prob 0.2 \
    --fp16 \
    --torch_compile \
    --evaluation_strategy epoch \
    --eval_accumulation_steps 20 \
    --save_strategy epoch \
    --wandb_project wellcome-mesh \
    --wandb_name test-train-all \
    --wandb_api_key ${WANDB_API_KEY}
39 changes: 39 additions & 0 deletions examples/preprocess_and_train_by_steps.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Run on g5.12xlarge instance

# Without saving (on-the-fly): preprocessing happens in-memory from the raw
# JSONL; no intermediate dataset is written to disk.
SOURCE="data/raw/allMeSH_2021.jsonl"

# Train BERTMesh (first positional argument is the model/checkpoint path —
# empty here, unlike the resume_train_* scripts), splitting train/test by
# publication year and evaluating/checkpointing every 50000 optimizer steps.
grants-tagger train bertmesh \
"" \
$SOURCE \
--test-size 25000 \
--train-years 2016,2017,2018,2019 \
--test-years 2020,2021 \
--output_dir bertmesh_outs/pipeline_test/ \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 1 \
--multilabel_attention True \
--freeze_backbone unfreeze \
--num_train_epochs 7 \
--learning_rate 5e-5 \
--dropout 0.1 \
--hidden_size 1024 \
--warmup_steps 5000 \
--max_grad_norm 2.0 \
--scheduler_type cosine_hard_restart \
--weight_decay 0.2 \
--correct_bias True \
--threshold 0.25 \
--prune_labels_in_evaluation True \
--hidden_dropout_prob 0.2 \
--attention_probs_dropout_prob 0.2 \
--fp16 \
--torch_compile \
--evaluation_strategy steps \
--eval_steps 50000 \
--eval_accumulation_steps 20 \
--save_strategy steps \
--save_steps 50000 \
--wandb_project wellcome-mesh \
--wandb_name test-train-all \
--wandb_api_key ${WANDB_API_KEY}
2 changes: 2 additions & 0 deletions examples/preprocess_splitting_by_fract.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#!/bin/sh
# Preprocess the raw MeSH JSONL, holding out a 5% fraction of the data as
# the test split (fractional --test-size). Replace the bracketed output
# placeholder with a real folder before running.
RAW_DATA="data/raw/allMeSH_2021.jsonl"

grants-tagger preprocess mesh "$RAW_DATA" [SET_YOUR_OUTPUT_FOLDER_HERE] '' --test-size 0.05
2 changes: 2 additions & 0 deletions examples/preprocess_splitting_by_rows.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#!/bin/sh
# Preprocess the raw MeSH JSONL, holding out 25000 rows as the test split
# (absolute --test-size). Replace the bracketed output placeholder with a
# real folder before running.
RAW_DATA="data/raw/allMeSH_2021.jsonl"

grants-tagger preprocess mesh "$RAW_DATA" [SET_YOUR_OUTPUT_FOLDER_HERE] '' --test-size 25000
4 changes: 4 additions & 0 deletions examples/preprocess_splitting_by_years.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#!/bin/sh
# Preprocess the raw MeSH JSONL, splitting train/test by publication year
# (2016-2019 train, 2020-2021 test) with a test size of 25000. Replace the
# bracketed output placeholder with a real folder before running.
RAW_DATA="data/raw/allMeSH_2021.jsonl"

grants-tagger preprocess mesh "$RAW_DATA" [SET_YOUR_OUTPUT_FOLDER_HERE] '' \
    --test-size 25000 \
    --train-years 2016,2017,2018,2019 \
    --test-years 2020,2021
37 changes: 37 additions & 0 deletions examples/resume_train_by_epoch.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Run on g5.12xlarge instance

# After preprocessing
SOURCE="[SET_YOUR_PREPROCESSING_FOLDER_HERE]"

# Checkpoint
CHECKPOINT="checkpoint-100000"

# Resume training from the saved checkpoint (passed as the first positional
# argument), evaluating and saving once per epoch. --warmup_steps is 0 here,
# presumably because warmup already ran in the original training — confirm.
grants-tagger train bertmesh \
bertmesh_outs/pipeline_test/$CHECKPOINT \
$SOURCE \
--output_dir bertmesh_outs/pipeline_test/ \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 1 \
--multilabel_attention True \
--freeze_backbone unfreeze \
--num_train_epochs 3 \
--learning_rate 5e-5 \
--dropout 0.1 \
--hidden_size 1024 \
--warmup_steps 0 \
--max_grad_norm 2.0 \
--scheduler_type cosine_hard_restart \
--weight_decay 0.2 \
--correct_bias True \
--threshold 0.25 \
--prune_labels_in_evaluation True \
--hidden_dropout_prob 0.2 \
--attention_probs_dropout_prob 0.2 \
--fp16 \
--torch_compile \
--evaluation_strategy epoch \
--eval_accumulation_steps 20 \
--save_strategy epoch \
--wandb_project wellcome-mesh \
--wandb_name test-train-all \
--wandb_api_key ${WANDB_API_KEY}
39 changes: 39 additions & 0 deletions examples/resume_train_by_steps.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Run on g5.12xlarge instance

# After preprocessing
SOURCE="[SET_YOUR_PREPROCESSING_FOLDER_HERE]"

# Checkpoint
CHECKPOINT="checkpoint-100000"

# Resume training from the saved checkpoint (passed as the first positional
# argument), evaluating and checkpointing every 10000 optimizer steps.
# --warmup_steps is 0 here, presumably because warmup already ran in the
# original training — confirm.
grants-tagger train bertmesh \
bertmesh_outs/pipeline_test/$CHECKPOINT \
$SOURCE \
--output_dir bertmesh_outs/pipeline_test/ \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 1 \
--multilabel_attention True \
--freeze_backbone unfreeze \
--num_train_epochs 3 \
--learning_rate 5e-5 \
--dropout 0.1 \
--hidden_size 1024 \
--warmup_steps 0 \
--max_grad_norm 2.0 \
--scheduler_type cosine_hard_restart \
--weight_decay 0.2 \
--correct_bias True \
--threshold 0.25 \
--prune_labels_in_evaluation True \
--hidden_dropout_prob 0.2 \
--attention_probs_dropout_prob 0.2 \
--fp16 \
--torch_compile \
--evaluation_strategy steps \
--eval_steps 10000 \
--eval_accumulation_steps 20 \
--save_strategy steps \
--save_steps 10000 \
--wandb_project wellcome-mesh \
--wandb_name test-train-all \
--wandb_api_key ${WANDB_API_KEY}
34 changes: 34 additions & 0 deletions examples/train_by_epochs.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Run on g5.12xlarge instance

# After preprocessing
SOURCE="[SET_YOUR_PREPROCESSING_FOLDER_HERE]"

# Train BERTMesh on an already-preprocessed dataset (first positional
# argument — the model/checkpoint path — is empty, i.e. not resuming),
# evaluating and saving once per epoch.
grants-tagger train bertmesh \
"" \
$SOURCE \
--output_dir bertmesh_outs/pipeline_test/ \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 1 \
--multilabel_attention True \
--freeze_backbone unfreeze \
--num_train_epochs 7 \
--learning_rate 5e-5 \
--dropout 0.1 \
--hidden_size 1024 \
--warmup_steps 5000 \
--max_grad_norm 2.0 \
--scheduler_type cosine_hard_restart \
--weight_decay 0.2 \
--correct_bias True \
--threshold 0.25 \
--prune_labels_in_evaluation True \
--hidden_dropout_prob 0.2 \
--attention_probs_dropout_prob 0.2 \
--fp16 \
--torch_compile \
--evaluation_strategy epoch \
--eval_accumulation_steps 20 \
--save_strategy epoch \
--wandb_project wellcome-mesh \
--wandb_name test-train-all \
--wandb_api_key ${WANDB_API_KEY}
36 changes: 36 additions & 0 deletions examples/train_by_steps.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Run on g5.12xlarge instance

# After preprocessing
SOURCE="[SET_YOUR_PREPROCESSING_FOLDER_HERE]"

# Train BERTMesh on an already-preprocessed dataset (first positional
# argument — the model/checkpoint path — is empty, i.e. not resuming),
# evaluating and checkpointing every 10000 optimizer steps.
grants-tagger train bertmesh \
"" \
$SOURCE \
--output_dir bertmesh_outs/pipeline_test/ \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 1 \
--multilabel_attention True \
--freeze_backbone unfreeze \
--num_train_epochs 7 \
--learning_rate 5e-5 \
--dropout 0.1 \
--hidden_size 1024 \
--warmup_steps 5000 \
--max_grad_norm 2.0 \
--scheduler_type cosine_hard_restart \
--weight_decay 0.2 \
--correct_bias True \
--threshold 0.25 \
--prune_labels_in_evaluation True \
--hidden_dropout_prob 0.2 \
--attention_probs_dropout_prob 0.2 \
--fp16 \
--torch_compile \
--evaluation_strategy steps \
--eval_steps 10000 \
--eval_accumulation_steps 20 \
--save_strategy steps \
--save_steps 10000 \
--wandb_project wellcome-mesh \
--wandb_name test-train-all \
--wandb_api_key ${WANDB_API_KEY}
67 changes: 67 additions & 0 deletions grants_tagger_light/augmentation/JsonParser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""
From langchain: https://raw.githubusercontent.com/langchain-ai/langchain/master/libs/langchain/langchain/output_parsers/json.py
"""

import json
import re


class JsonParser:
    """Parse JSON produced by LLMs, tolerating common formatting problems.

    Inspiration taken from langchain: it escapes unescaped newlines, tabs,
    carriage returns and quotes inside the ``"action_input"`` value so the
    payload can be loaded with the standard :mod:`json` module.
    """

    def __init__(self):
        # Fix: the original spelled this ``__init`` (missing the trailing
        # underscores), so it was never invoked as a constructor. All
        # functionality lives in static methods; there is no state to set up.
        pass

    @staticmethod
    def _replace_new_line(match: re.Match[str]) -> str:
        """Escape control characters and quotes inside a matched value.

        ``match`` has three groups: the ``"action_input": "`` prefix
        (group 1), the raw value (group 2) and the closing quote (group 3).
        Only the value is rewritten.
        """
        value = match.group(2)
        value = re.sub(r"\n", r"\\n", value)
        value = re.sub(r"\r", r"\\r", value)
        value = re.sub(r"\t", r"\\t", value)
        value = re.sub('"', r"\"", value)

        return match.group(1) + value + match.group(3)

    @staticmethod
    def _custom_parser(multiline_string: str) -> str:
        """
        The LLM response for `action_input` may be a multiline
        string containing unescaped newlines, tabs or quotes. This function
        replaces those characters with their escaped counterparts.
        (newlines in JSON must be double-escaped: `\\n`)
        """
        # Accept raw bytes as a convenience; decode to str before the regex.
        if isinstance(multiline_string, (bytes, bytearray)):
            multiline_string = multiline_string.decode()

        # DOTALL lets (.*) span multiple physical lines in the value.
        multiline_string = re.sub(
            r'("action_input"\:\s*")(.*)(")',
            JsonParser._replace_new_line,
            multiline_string,
            flags=re.DOTALL,
        )

        return multiline_string

    @staticmethod
    def parse_json(json_string: str) -> dict:
        """
        Parse a JSON string from an LLM response.

        Args:
            json_string: The (possibly slightly malformed) JSON string.

        Returns:
            The parsed JSON object as a Python dictionary.

        Raises:
            json.JSONDecodeError: if the string is still invalid JSON
                after the escaping fix-ups.
        """
        # Strip whitespace and newlines from the start and end.
        json_str = json_string.strip()

        # Handle newlines and other special characters inside the value.
        json_str = JsonParser._custom_parser(json_str)

        # Parse the JSON string into a Python dictionary.
        return json.loads(json_str)
8 changes: 8 additions & 0 deletions grants_tagger_light/augmentation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import typer
from .augment import augment_cli

# Typer sub-application exposing data augmentation as the "mesh" command
# (registered under the parent CLI elsewhere in the package).
augment_app = typer.Typer()
# allow_extra_args / ignore_unknown_options let unrecognised CLI flags pass
# through to augment_cli instead of being rejected by Typer.
augment_app.command(
    "mesh",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)(augment_cli)
Loading

0 comments on commit 5b633ae

Please sign in to comment.