Syncs repos after successful new Trainer experiment, Data Augmentation and Portfolio sample #11

Merged
307 commits merged on Sep 6, 2023
Changes from 250 commits

Commits (307)
c4bd743
turn argument to option
Jul 18, 2023
1660ef0
revert separate preprocess command
Jul 18, 2023
4c6152c
Revert "revert separate preprocess command"
Jul 18, 2023
8a8d52f
move cache disable inside preproc
Jul 18, 2023
7e3c2d3
add timings
Jul 18, 2023
5ba41fc
Speed improvements on preprocessing
Jul 19, 2023
1628362
Adds benchmarking utility
Jul 19, 2023
669d8e2
Cleaning. Training using preprocessing output
Jul 20, 2023
9d41a3f
Cleaning. Training using preprocessing output
Jul 21, 2023
273c705
Cleaning. Training using preprocessing output
Jul 21, 2023
a774e72
Improving some steps
Jul 21, 2023
f8a10df
Improving some steps
Jul 21, 2023
3c99d25
Improving some steps
Jul 21, 2023
f5af79b
Improving some steps
Jul 21, 2023
256e07c
Improving save_to_disk
Jul 21, 2023
ee18b75
Improving save_to_disk
Jul 21, 2023
daf94fa
Removing save_to_disk
Jul 21, 2023
7e6a2b0
Adapts training to call preprocess
Jul 21, 2023
f804517
Adapts training to call preprocess
Jul 21, 2023
18fb506
Adapts training to call preprocess
Jul 21, 2023
60d05cc
Adapts training to call preprocess
Jul 21, 2023
afe04be
Allows pretraining+saving or just pretraining before train
Jul 21, 2023
6837af7
Cleaning and rearranging
Jul 24, 2023
feb0037
Updates train.sh
Jul 24, 2023
bc3cd22
Adds allMesh_2021 jsonl version
Jul 24, 2023
c383ea1
Ignores .jsonl
Jul 24, 2023
e49366b
Fixes preprocessing unit tests
Jul 24, 2023
c6c3b5b
Adapts preprocessing tests to the style of other tests
Jul 24, 2023
c8dffc2
Fixes tests in train. Adds more tests.
Jul 24, 2023
2501082
Adds some comments
Jul 24, 2023
c963136
Reformatting
josejuanmartinez Jul 24, 2023
b015f49
Reformatting
Jul 24, 2023
13822a8
Reformatting
Jul 24, 2023
3ae6143
Reformatting
Jul 24, 2023
689f323
Reformatting
Jul 24, 2023
4242884
Reformatting
Jul 24, 2023
78aebe7
Reformatting
Jul 24, 2023
55e427b
Reformatting
Jul 24, 2023
d9fd220
Reformatting
Jul 24, 2023
4db6fcf
Updates the Readme
Jul 24, 2023
d3365d0
Merge branch '3-improve-trainer' into 3-improve-trainer-juan
josejuanmartinez Jul 24, 2023
8ddb664
Reformatting
Jul 24, 2023
ac4e025
Merge pull request #12 from MantisAI/11-add-common-terms-column
Jul 26, 2023
7c60c22
Fixes training multigpu
Jul 28, 2023
f3254e9
Fixes train.sh
Jul 28, 2023
c791f8d
Merges duplicated code
Jul 28, 2023
02fcef9
Merge remote-tracking branch 'origin/3-improve-trainer-juan' into 3-i…
Jul 28, 2023
f571f58
Adds label2id to config
Jul 28, 2023
c380c62
Fixes train.sh
josejuanmartinez Jul 28, 2023
8a8aa69
Removes exit
Jul 28, 2023
5e0e8a4
Adds wandb param to train.sh
Jul 28, 2023
2b1191d
Adds report_to none for non-wandb
Jul 28, 2023
0f53dfc
Updates train.sh
josejuanmartinez Jul 28, 2023
df49e1e
Updates README.md
Jul 28, 2023
6fae3bd
Black reformatting
Jul 28, 2023
ecf213b
Modifies training
josejuanmartinez Jul 31, 2023
c31f29f
Merges train.sh
josejuanmartinez Jul 31, 2023
9fe49cd
Saving then evaluating
Jul 31, 2023
b18dd73
Updates train.sh
josejuanmartinez Jul 31, 2023
336bb15
Updates train.sh
Jul 31, 2023
0da2c00
Black reformatting
Aug 1, 2023
0ff389a
Adding training resuming
Aug 2, 2023
e791df7
Adding training resuming
Aug 2, 2023
fc90d25
Adds resume_train.sh
josejuanmartinez Aug 2, 2023
3b43ce5
Parametrization
Aug 2, 2023
303a34a
Updates documentation, script name and adds filter by years and tags …
Aug 4, 2023
5c9876e
Adding train/test split using years
Aug 4, 2023
909cee0
Adds multibatching to filtering train/test by years
Aug 4, 2023
c08da49
Updates message
Aug 4, 2023
3d73606
Changes shards to cpu_count() by default
Aug 4, 2023
07ea25d
Adds `preprocess` example
Aug 4, 2023
d74c2f6
Adds evaluation by epoch
Aug 4, 2023
4db6fb3
Adds evaluation by epoch
Aug 4, 2023
b1a6d19
Train/test split by years
Aug 4, 2023
c2d3a90
Fixes drop out and standardizes calls
Aug 5, 2023
5e04318
Fixes years
Aug 5, 2023
b1a6098
Test size frac vs row
Aug 6, 2023
384436c
Test size frac vs row
Aug 6, 2023
f19fe0f
Test size frac vs row
Aug 6, 2023
4b53208
Test size frac vs row
Aug 6, 2023
fbf941c
Test size frac vs row
Aug 6, 2023
4108f37
Test size frac vs row
Aug 6, 2023
2a95c39
Test size frac vs row
Aug 6, 2023
1097186
Logging model params
Aug 7, 2023
d69553f
Adds hidden size to train params
Aug 7, 2023
d2a5c10
Adds hidden size to train params
Aug 7, 2023
38e9026
Adds hidden size to train params
Aug 7, 2023
f5e2568
Fixes bug with number of rows as test size
Aug 7, 2023
7b1f213
Adds multilabel_attention and freeze_backbone
Aug 8, 2023
1f14e33
Adds multilabel_attention and freeze_backbone
Aug 8, 2023
51c43a8
Adds some debug info
Aug 10, 2023
188ccc1
Removes forward logs
Aug 10, 2023
08193dd
Checking last implementation of BertMesh
Aug 10, 2023
b6268da
Roll back early implementation of BertMesh
Aug 10, 2023
127ab20
Fixing bug with number of rows
Aug 11, 2023
aeaa752
Adds OpenAI augmentation
Aug 14, 2023
9b71dd3
Parallel augmentation prototype
Aug 16, 2023
a95e582
Adds batch size to augment
Aug 16, 2023
460b3da
Fixes bug
Aug 16, 2023
c61d60d
Fixes bug
Aug 16, 2023
f7c9fc0
Fixes bug
Aug 16, 2023
504075a
Fixes bug
Aug 16, 2023
8249ab9
Adds numpy
Aug 16, 2023
1a1309f
Fixes bug
Aug 16, 2023
0aee6b5
Fixes bug
Aug 16, 2023
2858c00
Fixes bug
Aug 16, 2023
612d970
Fixes bug
Aug 16, 2023
e48cc06
Args to kwargs
Aug 16, 2023
24790ae
Args to kwargs
Aug 16, 2023
c098506
Args to kwargs
Aug 16, 2023
5ef18f4
Removes frequency penalty
Aug 16, 2023
a476046
Removes frequency penalty
Aug 16, 2023
e48bfbd
Payload response fix
Aug 16, 2023
3320855
Payload response fix
Aug 16, 2023
770717e
Payload response fix
Aug 16, 2023
c35c1d3
Payload response fix
Aug 16, 2023
7d92c1c
Payload response fix
Aug 16, 2023
4285042
Fixing uuid
Aug 16, 2023
6afce7d
Adds inspiration
Aug 16, 2023
94a8310
Debugs
Aug 16, 2023
6aadf6a
Debugs
Aug 16, 2023
36e4d3e
Asks to reformat the quotes
Aug 16, 2023
f351c6f
Asks to reformat the quotes
Aug 16, 2023
f567903
Moves to a csv format
Aug 16, 2023
c484cee
Moves to a csv format
Aug 16, 2023
7f040c8
Moves to a csv format
Aug 16, 2023
6b1759e
Moves to a csv format
Aug 16, 2023
3a45e0c
Moves to a csv format
Aug 16, 2023
2dbbdcd
Json asking to escape quotes
Aug 16, 2023
3170fc9
Json asking to escape quotes
Aug 16, 2023
841477e
Json asking to escape quotes
Aug 16, 2023
e6abb7c
Json asking to escape quotes
Aug 16, 2023
6083b1b
Json asking to escape quotes
Aug 16, 2023
6bd9f53
Json asking to escape quotes
Aug 16, 2023
3555cb5
Json asking to escape quotes
Aug 16, 2023
0275938
Printing what is generating
Aug 16, 2023
d788e4e
Printing what is generating
Aug 16, 2023
345e388
Printing what is generating
Aug 16, 2023
1669cce
Printing what is generating
Aug 16, 2023
38cfa58
Printing what is generating
Aug 16, 2023
50e1cfb
Printing what is generating
Aug 16, 2023
c607037
Printing what is generating
Aug 16, 2023
ac612b6
Printing what is generating
Aug 16, 2023
ca6643c
Adds gradient clipping and cosine optimizer
Aug 17, 2023
5636545
Adds gradient clipping and cosine optimizer
Aug 17, 2023
2407546
Adds gradient clipping and cosine optimizer
Aug 17, 2023
b89e193
Adds gradient clipping and cosine optimizer
Aug 17, 2023
f2e147d
Adds gradient clipping and cosine optimizer
Aug 17, 2023
c759962
Adds gradient clipping and cosine optimizer
Aug 17, 2023
549b817
Adds gradient clipping and cosine optimizer
Aug 17, 2023
996d167
Adds gradient clipping and cosine optimizer
Aug 17, 2023
8109105
Adds gradient clipping and cosine optimizer
Aug 17, 2023
f1a5370
Adds gradient clipping and cosine optimizer
Aug 17, 2023
c2bc2ab
Adds gradient clipping and cosine optimizer
Aug 17, 2023
c7b631a
Adds gradient clipping and cosine optimizer
Aug 17, 2023
e8e8591
Freezing bias
Aug 18, 2023
00f71d6
Freezing bias
Aug 18, 2023
21c9a66
Freezing bias
Aug 18, 2023
e3f8d9e
Removes cosine scheduler
Aug 19, 2023
a041935
Removes cosine scheduler
Aug 19, 2023
1ed9060
Removes cosine scheduler
Aug 19, 2023
ae4386b
Sending metadata
Aug 19, 2023
a54bd52
Refactors
Aug 19, 2023
7135c1e
Refactors
Aug 19, 2023
f68785b
Adds sleep
Aug 19, 2023
4c06616
Adds sleep
Aug 19, 2023
72d080b
Adds sleep
Aug 19, 2023
60fa413
Adds sleep
Aug 19, 2023
53452f8
Adds sleep
Aug 19, 2023
808e251
Sends one by one
Aug 19, 2023
dd3d763
Sends one by one
Aug 19, 2023
9a1d245
Sends one by one
Aug 19, 2023
d980d8c
Sends one by one
Aug 19, 2023
5830f70
Changes from static to global
Aug 19, 2023
29f32f9
Removes param
Aug 19, 2023
82ef9e9
Removes param
Aug 19, 2023
efb2ef5
Removes param
Aug 19, 2023
9319a9d
Checks error
Aug 19, 2023
22993e1
Fixes bug with metadata field name
Aug 19, 2023
be78b3c
Adds JsonParser
Aug 19, 2023
425c12b
Adds JsonParser
Aug 19, 2023
a45a0f1
Removes sleep
Aug 19, 2023
62f76e2
Prevents locks
Aug 19, 2023
e674dd8
Write > Append
Aug 19, 2023
a81760b
Write > Append
Aug 19, 2023
780475a
Adds different schedulers
Aug 19, 2023
7a2ebe1
Parametrizes temperature
Aug 19, 2023
94809f3
Fixes schedule name bug
Aug 19, 2023
c664a0a
Fixes schedule name bug
Aug 19, 2023
0cd4d1e
Fixes schedule name bug
Aug 19, 2023
4ac7e09
Refactors and adds augment script
Aug 19, 2023
7661575
Adds 25 concurrent calls by default
Aug 19, 2023
4a1571d
Adds 25 concurrent calls by default
Aug 19, 2023
6405184
Adds 25 concurrent calls by default
Aug 19, 2023
be3b7a1
Adds 25 concurrent calls by default
Aug 19, 2023
006c864
Adds more examples
Aug 21, 2023
c0fdc5a
Adds scheduler type
Aug 22, 2023
21b0d6c
Freezes everything except weights
Aug 23, 2023
8b92dae
Changes threshold, evaluation on tags, freezing backbone
Aug 24, 2023
52f1303
Changes threshold, evaluation on tags, freezing backbone
Aug 24, 2023
6ab4876
Changes threshold, evaluation on tags, freezing backbone
Aug 24, 2023
54a754f
Changes threshold, evaluation on tags, freezing backbone
Aug 24, 2023
1fdeada
Changes threshold, evaluation on tags, freezing backbone
Aug 24, 2023
608a9a7
Changes threshold, evaluation on tags, freezing backbone
Aug 24, 2023
063caa4
Changes threshold, evaluation on tags, freezing backbone
Aug 24, 2023
9c518d7
Changes threshold, evaluation on tags, freezing backbone
Aug 24, 2023
25c1a9b
Changes threshold, evaluation on tags, freezing backbone
Aug 24, 2023
a0bc3ed
Adds filtering by tags
Aug 24, 2023
e28b683
Adds edge case for the remaining X end texts to augment
Aug 24, 2023
93d5ad5
Removes metrics from tags absent in training
Aug 24, 2023
1d4cb0f
Removes metrics from tags absent in training
Aug 24, 2023
644d651
Rolls back filtering tags in metrics
Aug 25, 2023
ad40350
Adds back tag filtering
Aug 25, 2023
6cf244b
Adds weight_decay, correct_bias, dropout probs, attention dropout...
Aug 25, 2023
f784566
Adds weight_decay, correct_bias, dropout probs, attention dropout...
Aug 25, 2023
7da5ac4
Adds best params
josejuanmartinez Aug 25, 2023
fbdecd6
Updates resume_train_by_steps.sh
Aug 26, 2023
5758b42
Updates resume_train_by_steps.sh
Aug 26, 2023
9582a6a
Adds tags-based augmentation
Aug 28, 2023
60c9fca
Adds id2label for augmentation
Aug 28, 2023
eb2d0f3
Adds `dataset` folder
Aug 28, 2023
9815027
Decodes id back into labels
Aug 28, 2023
56ef40c
Decodes id back into labels
Aug 28, 2023
d771918
Generates examples also for not-underrepresented
Aug 28, 2023
99e1067
Generates examples also for not-underrepresented
Aug 28, 2023
fa17233
Generates examples also for not-underrepresented
Aug 28, 2023
c811f8a
Adds more columns to preprocessing
Aug 28, 2023
6d58ea9
Adds more columns to preprocessing
Aug 28, 2023
c75c36a
Adds more columns to preprocessing
Aug 28, 2023
2e0474e
Adds more columns to preprocessing
Aug 28, 2023
1c173e1
Adds more columns to preprocessing
Aug 28, 2023
726837f
Prevents crashes
Aug 28, 2023
60a3ec9
Better hyperparams
Aug 28, 2023
4f70bbc
Check if fixes tests
Aug 31, 2023
8b080eb
Tries to fix torch-cpu recent issue
Aug 31, 2023
52231e0
Refactors augmentation
Sep 1, 2023
69b56ea
Refactors augmentation
Sep 1, 2023
77761be
Refactors augmentation
Sep 1, 2023
d6247db
Fixes tests
Sep 1, 2023
8298b5a
Fixes black
Sep 1, 2023
6a440fa
Fixes black
Sep 1, 2023
37c9a2b
Fixes ruff
Sep 1, 2023
3a8fd48
Fixes ruff
Sep 1, 2023
06c5479
Fixes black
Sep 1, 2023
e1fbe2b
Fixes black
Sep 1, 2023
8ae9f7a
Merge pull request #10 from MantisAI/3-improve-trainer-juan
josejuanmartinez Sep 1, 2023
c45c5c1
Merge branch 'wellcometrust:main' into main
josejuanmartinez Sep 4, 2023
35c54f8
Modify script to include active portfolio sample
Sep 6, 2023
d7dde65
Update data
Sep 6, 2023
5c9dd2c
Merge pull request #14 from MantisAI/active-portfolio-sample
nsorros Sep 6, 2023
5 changes: 5 additions & 0 deletions .gitignore
@@ -158,3 +158,8 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/


# Folder where training outputs are stored
bertmesh_outs/
wandb/
200 changes: 137 additions & 63 deletions README.md

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions data/grants_comparison/mesh_tree_letters_list.txt
@@ -0,0 +1,4 @@
Information Sources: L
Phenomena and Processes: G
Geographicals: Z
Diseases: C
1 change: 1 addition & 0 deletions data/raw/.gitignore
@@ -1,3 +1,4 @@
/allMeSH_2021.json
/allMeSH_2021.jsonl
/desc2021.xml
/disease_tags_validation_grants.xlsx
4 changes: 4 additions & 0 deletions data/raw/allMeSH_2021.jsonl.dvc
@@ -0,0 +1,4 @@
outs:
- md5: 94f18c3918b180728a553123edb2ee32
size: 27914288461
path: allMeSH_2021.jsonl
3 changes: 3 additions & 0 deletions examples/augment.sh
@@ -0,0 +1,3 @@
grants-tagger augment mesh [FOLDER_AFTER_PREPROCESSING] [SET_YOUR_OUTPUT_FOLDER_HERE] \
--min-examples 25 \
--concurrent-calls 25
5 changes: 5 additions & 0 deletions examples/augment_specific_tags.sh
@@ -0,0 +1,5 @@
# Augments data using a file with 1 label per line and years
grants-tagger augment mesh [FOLDER_AFTER_PREPROCESSING] [SET_YOUR_OUTPUT_FOLDER_HERE] \
--tags-file-path tags_to_augment.txt \
--examples 25 \
--concurrent-calls 25
37 changes: 37 additions & 0 deletions examples/preprocess_and_train_by_epochs.sh
@@ -0,0 +1,37 @@
# Run on g5.12xlarge instance

# Without saving (on-the-fly)
SOURCE="data/raw/allMeSH_2021.jsonl"

grants-tagger train bertmesh \
"" \
$SOURCE \
--test-size 25000 \
--train-years 2016,2017,2018,2019 \
--test-years 2020,2021 \
--output_dir bertmesh_outs/pipeline_test/ \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 1 \
--multilabel_attention True \
--freeze_backbone unfreeze \
--num_train_epochs 7 \
--learning_rate 5e-5 \
--dropout 0.1 \
--hidden_size 1024 \
--warmup_steps 5000 \
--max_grad_norm 2.0 \
--scheduler_type cosine_hard_restart \
--weight_decay 0.2 \
--correct_bias True \
--threshold 0.25 \
--prune_labels_in_evaluation True \
--hidden_dropout_prob 0.2 \
--attention_probs_dropout_prob 0.2 \
--fp16 \
--torch_compile \
--evaluation_strategy epoch \
--eval_accumulation_steps 20 \
--save_strategy epoch \
--wandb_project wellcome-mesh \
--wandb_name test-train-all \
--wandb_api_key ${WANDB_API_KEY}
39 changes: 39 additions & 0 deletions examples/preprocess_and_train_by_steps.sh
@@ -0,0 +1,39 @@
# Run on g5.12xlarge instance

# Without saving (on-the-fly)
SOURCE="data/raw/allMeSH_2021.jsonl"

grants-tagger train bertmesh \
"" \
$SOURCE \
--test-size 25000 \
--train-years 2016,2017,2018,2019 \
--test-years 2020,2021 \
--output_dir bertmesh_outs/pipeline_test/ \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 1 \
--multilabel_attention True \
--freeze_backbone unfreeze \
--num_train_epochs 7 \
--learning_rate 5e-5 \
--dropout 0.1 \
--hidden_size 1024 \
--warmup_steps 5000 \
--max_grad_norm 2.0 \
--scheduler_type cosine_hard_restart \
--weight_decay 0.2 \
--correct_bias True \
--threshold 0.25 \
--prune_labels_in_evaluation True \
--hidden_dropout_prob 0.2 \
--attention_probs_dropout_prob 0.2 \
--fp16 \
--torch_compile \
--evaluation_strategy steps \
--eval_steps 50000 \
--eval_accumulation_steps 20 \
--save_strategy steps \
--save_steps 50000 \
--wandb_project wellcome-mesh \
--wandb_name test-train-all \
--wandb_api_key ${WANDB_API_KEY}
2 changes: 2 additions & 0 deletions examples/preprocess_splitting_by_fract.sh
@@ -0,0 +1,2 @@
grants-tagger preprocess mesh data/raw/allMeSH_2021.jsonl [SET_YOUR_OUTPUT_FOLDER_HERE] '' \
--test-size 0.05
2 changes: 2 additions & 0 deletions examples/preprocess_splitting_by_rows.sh
@@ -0,0 +1,2 @@
grants-tagger preprocess mesh data/raw/allMeSH_2021.jsonl [SET_YOUR_OUTPUT_FOLDER_HERE] '' \
--test-size 25000
4 changes: 4 additions & 0 deletions examples/preprocess_splitting_by_years.sh
@@ -0,0 +1,4 @@
grants-tagger preprocess mesh data/raw/allMeSH_2021.jsonl [SET_YOUR_OUTPUT_FOLDER_HERE] '' \
--test-size 25000 \
--train-years 2016,2017,2018,2019 \
--test-years 2020,2021
37 changes: 37 additions & 0 deletions examples/resume_train_by_epoch.sh
@@ -0,0 +1,37 @@
# Run on g5.12xlarge instance

# After preprocessing
SOURCE="[SET_YOUR_PREPROCESSING_FOLDER_HERE]"

# Checkpoint
CHECKPOINT="checkpoint-100000"

grants-tagger train bertmesh \
bertmesh_outs/pipeline_test/$CHECKPOINT \
$SOURCE \
--output_dir bertmesh_outs/pipeline_test/ \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 1 \
--multilabel_attention True \
--freeze_backbone unfreeze \
--num_train_epochs 3 \
--learning_rate 5e-5 \
--dropout 0.1 \
--hidden_size 1024 \
--warmup_steps 0 \
--max_grad_norm 2.0 \
--scheduler_type cosine_hard_restart \
--weight_decay 0.2 \
--correct_bias True \
--threshold 0.25 \
--prune_labels_in_evaluation True \
--hidden_dropout_prob 0.2 \
--attention_probs_dropout_prob 0.2 \
--fp16 \
--torch_compile \
--evaluation_strategy epoch \
--eval_accumulation_steps 20 \
--save_strategy epoch \
--wandb_project wellcome-mesh \
--wandb_name test-train-all \
--wandb_api_key ${WANDB_API_KEY}
39 changes: 39 additions & 0 deletions examples/resume_train_by_steps.sh
@@ -0,0 +1,39 @@
# Run on g5.12xlarge instance

# After preprocessing
SOURCE="[SET_YOUR_PREPROCESSING_FOLDER_HERE]"

# Checkpoint
CHECKPOINT="checkpoint-100000"

grants-tagger train bertmesh \
bertmesh_outs/pipeline_test/$CHECKPOINT \
$SOURCE \
--output_dir bertmesh_outs/pipeline_test/ \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 1 \
--multilabel_attention True \
--freeze_backbone unfreeze \
--num_train_epochs 3 \
--learning_rate 5e-5 \
--dropout 0.1 \
--hidden_size 1024 \
--warmup_steps 0 \
--max_grad_norm 2.0 \
--scheduler_type cosine_hard_restart \
--weight_decay 0.2 \
--correct_bias True \
--threshold 0.25 \
--prune_labels_in_evaluation True \
--hidden_dropout_prob 0.2 \
--attention_probs_dropout_prob 0.2 \
--fp16 \
--torch_compile \
--evaluation_strategy steps \
--eval_steps 10000 \
--eval_accumulation_steps 20 \
--save_strategy steps \
--save_steps 10000 \
--wandb_project wellcome-mesh \
--wandb_name test-train-all \
--wandb_api_key ${WANDB_API_KEY}
34 changes: 34 additions & 0 deletions examples/train_by_epochs.sh
@@ -0,0 +1,34 @@
# Run on g5.12xlarge instance

# After preprocessing
SOURCE="[SET_YOUR_PREPROCESSING_FOLDER_HERE]"

grants-tagger train bertmesh \
"" \
$SOURCE \
--output_dir bertmesh_outs/pipeline_test/ \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 1 \
--multilabel_attention True \
--freeze_backbone unfreeze \
--num_train_epochs 7 \
--learning_rate 5e-5 \
--dropout 0.1 \
--hidden_size 1024 \
--warmup_steps 5000 \
--max_grad_norm 2.0 \
--scheduler_type cosine_hard_restart \
--weight_decay 0.2 \
--correct_bias True \
--threshold 0.25 \
--prune_labels_in_evaluation True \
--hidden_dropout_prob 0.2 \
--attention_probs_dropout_prob 0.2 \
--fp16 \
--torch_compile \
--evaluation_strategy epoch \
--eval_accumulation_steps 20 \
--save_strategy epoch \
--wandb_project wellcome-mesh \
--wandb_name test-train-all \
--wandb_api_key ${WANDB_API_KEY}
36 changes: 36 additions & 0 deletions examples/train_by_steps.sh
@@ -0,0 +1,36 @@
# Run on g5.12xlarge instance

# After preprocessing
SOURCE="[SET_YOUR_PREPROCESSING_FOLDER_HERE]"

grants-tagger train bertmesh \
"" \
$SOURCE \
--output_dir bertmesh_outs/pipeline_test/ \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 1 \
--multilabel_attention True \
--freeze_backbone unfreeze \
--num_train_epochs 7 \
--learning_rate 5e-5 \
--dropout 0.1 \
--hidden_size 1024 \
--warmup_steps 5000 \
--max_grad_norm 2.0 \
--scheduler_type cosine_hard_restart \
--weight_decay 0.2 \
--correct_bias True \
--threshold 0.25 \
--prune_labels_in_evaluation True \
--hidden_dropout_prob 0.2 \
--attention_probs_dropout_prob 0.2 \
--fp16 \
--torch_compile \
--evaluation_strategy steps \
--eval_steps 10000 \
--eval_accumulation_steps 20 \
--save_strategy steps \
--save_steps 10000 \
--wandb_project wellcome-mesh \
--wandb_name test-train-all \
--wandb_api_key ${WANDB_API_KEY}
67 changes: 67 additions & 0 deletions grants_tagger_light/augmentation/JsonParser.py
@@ -0,0 +1,67 @@
"""
From langchain: https://raw.githubusercontent.com/langchain-ai/langchain/master/libs/langchain/langchain/output_parsers/json.py
"""

import json
import re


class JsonParser:
def __init__(self):
"""Class to parse json produced by LLMs. Inspiration taken from langchain.
It fixes quotes, it escapes separators, etc."""
pass

@staticmethod
def _replace_new_line(match: re.Match[str]) -> str:
value = match.group(2)
value = re.sub(r"\n", r"\\n", value)
value = re.sub(r"\r", r"\\r", value)
value = re.sub(r"\t", r"\\t", value)
value = re.sub('"', r"\"", value)

return match.group(1) + value + match.group(3)

@staticmethod
def _custom_parser(multiline_string: str) -> str:
"""
The LLM response for `action_input` may be a multiline
string containing unescaped newlines, tabs or quotes. This function
replaces those characters with their escaped counterparts.
(newlines in JSON must be double-escaped: `\\n`)
"""
if isinstance(multiline_string, (bytes, bytearray)):
multiline_string = multiline_string.decode()

multiline_string = re.sub(
r'("action_input"\:\s*")(.*)(")',
JsonParser._replace_new_line,
multiline_string,
flags=re.DOTALL,
)

return multiline_string

@staticmethod
def parse_json(json_string: str) -> dict:
"""
Parse a JSON string from LLM response

Args:
json_string: The Markdown string.

Returns:
The parsed JSON object as a Python dictionary.
"""
json_str = json_string

# Strip whitespace and newlines from the start and end
json_str = json_str.strip()

# handle newlines and other special characters inside the returned value
json_str = JsonParser._custom_parser(json_str)

# Parse the JSON string into a Python dictionary
parsed = json.loads(json_str)

return parsed
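
A minimal usage sketch for JsonParser (not part of the diff; the sample payload is illustrative, following the "action_input" pattern the parser targets):

# Hypothetical example: parsing an LLM response whose value contains a raw
# newline, which plain json.loads would reject with a control-character error.
from grants_tagger_light.augmentation.JsonParser import JsonParser

raw = '{"action_input": "First line\nSecond line"}'  # real newline inside the value
parsed = JsonParser.parse_json(raw)  # escapes the newline, then json.loads
print(parsed["action_input"])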
8 changes: 8 additions & 0 deletions grants_tagger_light/augmentation/__init__.py
@@ -0,0 +1,8 @@
import typer
from .augment import augment_cli

augment_app = typer.Typer()
augment_app.command(
"mesh",
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)(augment_cli)
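
For context, a hedged sketch of how this sub-app could be mounted on the root CLI; the parent app shown here is an assumption, not part of this diff, and only the `grants-tagger augment mesh ...` invocation is confirmed by the example scripts:

# Hypothetical root app; the real grants-tagger entry point may differ.
import typer
from grants_tagger_light.augmentation import augment_app

app = typer.Typer()
app.add_typer(augment_app, name="augment")  # enables `grants-tagger augment mesh ...`

if __name__ == "__main__":
    app()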