From 2797fcb37ec1f68f2aa7c9ef982b6a65b966c030 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Mon, 26 Jun 2023 17:46:22 +0300 Subject: [PATCH 001/300] remove upload stage (not needed anymore) --- pipelines/generate_grants/dvc.yaml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pipelines/generate_grants/dvc.yaml b/pipelines/generate_grants/dvc.yaml index 8718757d..6b2a09dd 100644 --- a/pipelines/generate_grants/dvc.yaml +++ b/pipelines/generate_grants/dvc.yaml @@ -7,7 +7,3 @@ stages: cmd: python ${scripts_location}/create_grants_sample.py --s3-url ${s3-url} --num-parquet-files-to-consider 10 --num-samples-per-cat 10 --pre-annotate True outs: - grants_sample.jsonl - upload: - cmd: python ${scripts_location}/upload_grants_data_to_argilla.py --path grants_sample.jsonl --project ${argilla_project_name} - deps: - - grants_sample.jsonl From 19847be58ac4e94cd595c140083c13709e442749 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 27 Jun 2023 13:00:22 +0300 Subject: [PATCH 002/300] add lockfile for grants sample --- pipelines/generate_grants/.gitignore | 1 + pipelines/generate_grants/dvc.lock | 9 +++++++++ 2 files changed, 10 insertions(+) create mode 100644 pipelines/generate_grants/.gitignore create mode 100644 pipelines/generate_grants/dvc.lock diff --git a/pipelines/generate_grants/.gitignore b/pipelines/generate_grants/.gitignore new file mode 100644 index 00000000..df930645 --- /dev/null +++ b/pipelines/generate_grants/.gitignore @@ -0,0 +1 @@ +/grants_sample.jsonl diff --git a/pipelines/generate_grants/dvc.lock b/pipelines/generate_grants/dvc.lock new file mode 100644 index 00000000..db87cd92 --- /dev/null +++ b/pipelines/generate_grants/dvc.lock @@ -0,0 +1,9 @@ +schema: '2.0' +stages: + generate: + cmd: python ../../scripts/create_grants_sample.py --s3-url s3://datalabs-data/dimensions/grants/grants + --num-parquet-files-to-consider 10 --num-samples-per-cat 10 --pre-annotate True + outs: + - path: grants_sample.jsonl + md5: 76bbfd9043e20866382ff9713cba7483 + size: 387951 From 56ebad3294bbf92e45350ffa2b6625fcba287a9a Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 27 Jun 2023 13:31:31 +0300 Subject: [PATCH 003/300] add documentation, move dataloader to own folder --- grants_tagger_light/training/dataloaders.py | 88 +++++++++++++++++++++ grants_tagger_light/training/train.py | 77 +++--------------- 2 files changed, 97 insertions(+), 68 deletions(-) create mode 100644 grants_tagger_light/training/dataloaders.py diff --git a/grants_tagger_light/training/dataloaders.py b/grants_tagger_light/training/dataloaders.py new file mode 100644 index 00000000..d1aef374 --- /dev/null +++ b/grants_tagger_light/training/dataloaders.py @@ -0,0 +1,88 @@ +import json +from transformers import AutoTokenizer +from datasets import Dataset + + +def load_grants_sample( + data_path: str, + tokenizer: AutoTokenizer, + label2id: dict, + test_size: float = 0.1, + num_proc: int = 1, +): + """ + Code that loads a grants sample. + The data should be a jsonl file where each line contains an abstract + and mesh_terms field. + The dvc pipeline in pipelines/generate_grants can be used for this. + It will populate the mesh_terms field with predictions + from Wellcome/WellcomeBertMesh. This can be used to generate a + dummy dataset (i.e. train the model on its own predictions + for development / sanity check purposes). + """ + + def _datagen(data_path: str): + """ + Loads the data from the given path. The data should be in jsonl format, + with each line containing a text and tags field. + The tags field should be a list of strings. + """ + with open(data_path, "r") as f: + for line in f: + sample = json.loads(line) + yield sample + + def _tokenize(batch): + return tokenizer( + batch["abstract"], + padding="max_length", + truncation=True, + max_length=512, + ) + + def _label_encode(batch): + import pdb + + pdb.set_trace() + batch["labels"] = [ + [label2id[tag] for tag in tags[0] if tag in label2id] + for tags in batch["mesh_terms"] + ] + return batch + + def _one_hot(batch): + batch["labels"] = [ + [1 if i in labels else 0 for i in range(len(label2id))] + for labels in batch["labels"] + ] + return batch + + dset = Dataset.from_generator(_datagen, gen_kwargs={"data_path": data_path}) + dset = dset.map( + _tokenize, + batched=True, + batch_size=32, + num_proc=num_proc, + desc="Tokenizing", + ) + + dset = dset.map( + _label_encode, + batched=True, + batch_size=32, + num_proc=num_proc, + desc="Encoding labels", + ) + + dset = dset.map( + _one_hot, + batched=True, + batch_size=32, + num_proc=num_proc, + desc="One-hot labels", + ) + + # Split into train and test + dset = dset.train_test_split(test_size=test_size) + + return dset["train"], dset["test"] diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index f88d54c2..59c739a3 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -1,82 +1,23 @@ -from transformers import AutoTokenizer, Trainer, TrainingArguments, EvalPrediction -from datasets import Dataset +from transformers import ( + AutoTokenizer, + Trainer, + TrainingArguments, + EvalPrediction, +) from grants_tagger_light.models.bert_mesh import BertMesh -import json +from grants_tagger_light.training.dataloaders import load_grants_sample import typer import numpy as np from sklearn.metrics import classification_report -def load_data( - data_path: str, - tokenizer: AutoTokenizer, - label2id: dict, - test_size: float = 0.1, - num_proc: int = 8, -): - def _datagen(data_path: str): - """ - Loads the data from the given path. The data should be in jsonl format, - with each line containing a text and tags field. - The tags field should be a list of strings. - """ - with open(data_path, "r") as f: - for line in f: - sample = json.loads(line) - yield sample - - def _tokenize(batch): - return tokenizer( - batch["text"], padding="max_length", truncation=True, max_length=512 - ) - - def _label_encode(batch): - batch["labels"] = [ - [label2id[tag] for tag in tags if tag in label2id] for tags in batch["tags"] - ] - return batch - - def _one_hot(batch): - batch["labels"] = [ - [1 if i in labels else 0 for i in range(len(label2id))] - for labels in batch["labels"] - ] - return batch - - dset = Dataset.from_generator(_datagen, gen_kwargs={"data_path": data_path}) - dset = dset.map( - _tokenize, batched=True, batch_size=32, num_proc=num_proc, desc="Tokenizing" - ) - - dset = dset.map( - _label_encode, - batched=True, - batch_size=32, - num_proc=num_proc, - desc="Encoding labels", - ) - - dset = dset.map( - _one_hot, - batched=True, - batch_size=32, - num_proc=num_proc, - desc="One-hot labels", - ) - - # Split into train and test - dset = dset.train_test_split(test_size=test_size) - - return dset["train"], dset["test"] - - def train_bertmesh(model_key: str, data_path: str, **user_args): model = BertMesh.from_pretrained(model_key, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_key) label2id = {v: k for k, v in model.id2label.items()} - train_dset, val_dset = load_data(data_path, tokenizer, label2id=label2id) + train_dset, val_dset = load_grants_sample(data_path, tokenizer, label2id=label2id) training_args = { "output_dir": "model_output", @@ -148,7 +89,7 @@ def train_bertmesh_cli( ), model_save_path: str = typer.Argument(..., help="Path to save model to"), ): - train_bertmesh(model_key, data_path, model_save_path) + train_bertmesh(model_key, data_path) if __name__ == "__main__": From efb1f6ce16fabba0db647a334e474bf680d3a512 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 27 Jun 2023 15:52:47 +0300 Subject: [PATCH 004/300] update lockfile --- poetry.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index c464b22f..4db467c9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3893,4 +3893,4 @@ test = ["zope.testing"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "405d67996eae5299b1ca3e60712389ce8b793fc64d74d936b28d747e20315ce5" +content-hash = "b5e71835a8b1d87d8ab7d899ae5da9af9e856fbb65f3be51c4ce1729e83bf04f" From c1e6034a792d1d3ba48773fe6d22b87154288c4e Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 27 Jun 2023 16:01:56 +0300 Subject: [PATCH 005/300] add arg system for trainer --- grants_tagger_light/training/__init__.py | 6 ++- grants_tagger_light/training/train.py | 32 ++++-------- grants_tagger_light/training/train_args.py | 60 ++++++++++++++++++++++ 3 files changed, 73 insertions(+), 25 deletions(-) create mode 100644 grants_tagger_light/training/train_args.py diff --git a/grants_tagger_light/training/__init__.py b/grants_tagger_light/training/__init__.py index b56c3c2a..e47c1bd0 100644 --- a/grants_tagger_light/training/__init__.py +++ b/grants_tagger_light/training/__init__.py @@ -1,6 +1,8 @@ import typer - from .train import train_bertmesh_cli train_app = typer.Typer() -train_app.command("bertmesh")(train_bertmesh_cli) +train_app.command( + "bertmesh", + context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, +)(train_bertmesh_cli) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 59c739a3..582c5b7c 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -3,15 +3,17 @@ Trainer, TrainingArguments, EvalPrediction, + HfArgumentParser, ) from grants_tagger_light.models.bert_mesh import BertMesh +from grants_tagger_light.training.train_args import BertMeshTrainingArguments from grants_tagger_light.training.dataloaders import load_grants_sample +from sklearn.metrics import classification_report import typer import numpy as np -from sklearn.metrics import classification_report -def train_bertmesh(model_key: str, data_path: str, **user_args): +def train_bertmesh(model_key: str, data_path: str, training_args: TrainingArguments): model = BertMesh.from_pretrained(model_key, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_key) @@ -19,25 +21,6 @@ def train_bertmesh(model_key: str, data_path: str, **user_args): train_dset, val_dset = load_grants_sample(data_path, tokenizer, label2id=label2id) - training_args = { - "output_dir": "model_output", - "overwrite_output_dir": True, - "num_train_epochs": 1, - "per_device_train_batch_size": 4, - "per_device_eval_batch_size": 4, - "warmup_steps": 500, - "weight_decay": 0.01, - "learning_rate": 1e-5, - "evaluation_strategy": "steps", - "eval_steps": 100, - "do_eval": True, - "label_names": ["labels"], - } - - training_args.update(user_args) - - training_args = TrainingArguments(**training_args) - def sklearn_metrics(prediction: EvalPrediction): y_pred = prediction.predictions y_true = prediction.label_ids @@ -80,6 +63,7 @@ def sklearn_metrics(prediction: EvalPrediction): @train_app.command() def train_bertmesh_cli( + ctx: typer.Context, model_key: str = typer.Argument( ..., help="Pretrained model key. Local path or HF location" ), @@ -87,9 +71,11 @@ def train_bertmesh_cli( ..., help="Path to data in jsonl format. Must contain text and tags field", ), - model_save_path: str = typer.Argument(..., help="Path to save model to"), ): - train_bertmesh(model_key, data_path) + parser = HfArgumentParser((BertMeshTrainingArguments,)) + (training_args,) = parser.parse_args_into_dataclasses(ctx.args) + + train_bertmesh(model_key, data_path, training_args) if __name__ == "__main__": diff --git a/grants_tagger_light/training/train_args.py b/grants_tagger_light/training/train_args.py new file mode 100644 index 00000000..acb075fb --- /dev/null +++ b/grants_tagger_light/training/train_args.py @@ -0,0 +1,60 @@ +from transformers import TrainingArguments +from dataclasses import dataclass, field + + +@dataclass +class BertMeshTrainingArguments(TrainingArguments): + """ + This class inherits from transformers.TrainingArguments + and implements some better defaults for convenience. + """ + + output_dir: str = field(default="bertmesh-outs/default") + overwrite_output_dir: bool = field(default=True) + + evaluation_strategy: str = field(default="epoch") # no | epoch | steps + # eval_steps: int = 1 + + save_strategy: str = field(default="epoch") # no | epoch | steps + # save_steps: int = 1 + save_total_limit: int = field(default=5) + + metric_for_best_model: str = field( + default="eval_loss" + ) # can switch later to micro-f1 + greater_is_better: bool = field(default=False) + load_best_model_at_end: bool = field(default=True) + + per_device_train_batch_size: int = field( + default=8 + ) # set to 256 in grants-tagger repo + per_device_eval_batch_size: int = field(default=16) + gradient_accumulation_steps: int = field(default=1) + group_by_length: bool = field(default=False) # TODO test this + + # Learning rate & num train epochs are taken from grants_tagger repo + # (BertMeSH paper does not specify these hparams) + num_train_epochs: int = field(default=5) + learning_rate: float = field(default=1e-4) + + seed: int = field(default=42) + data_seed: int = field(default=42) + + optim: str = field( + default="adamw_torch_fused" + ) # TODO add support for adamw_apex_fused + + fp16: bool = field(default=False) # TODO test if micro-f1 is maintained + + dataloader_num_workers: int = field(default=8) + dataloader_pin_memory: bool = field(default=True) + + gradient_checkpointing: bool = field(default=False) + + auto_find_batch_size: bool = field(default=False) # TODO test this + + torch_compile: bool = field(default=False) # TODO make compilation + torch_compile_backend: str = field(default="inductor") + torch_compile_mode: str = field( + default="default" + ) # default | reduce-overhead | max-autotune From 7e1c5f57374d3521c35f0bf5e7b49fe593221ce4 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 27 Jun 2023 16:04:32 +0300 Subject: [PATCH 006/300] turn compilation off by default (for now), reduce batch size --- grants_tagger_light/training/train_args.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/grants_tagger_light/training/train_args.py b/grants_tagger_light/training/train_args.py index acb075fb..405bf618 100644 --- a/grants_tagger_light/training/train_args.py +++ b/grants_tagger_light/training/train_args.py @@ -28,7 +28,7 @@ class BertMeshTrainingArguments(TrainingArguments): per_device_train_batch_size: int = field( default=8 ) # set to 256 in grants-tagger repo - per_device_eval_batch_size: int = field(default=16) + per_device_eval_batch_size: int = field(default=8) gradient_accumulation_steps: int = field(default=1) group_by_length: bool = field(default=False) # TODO test this @@ -54,7 +54,7 @@ class BertMeshTrainingArguments(TrainingArguments): auto_find_batch_size: bool = field(default=False) # TODO test this torch_compile: bool = field(default=False) # TODO make compilation - torch_compile_backend: str = field(default="inductor") - torch_compile_mode: str = field( - default="default" - ) # default | reduce-overhead | max-autotune + # torch_compile_backend: str = field(default="inductor") + # torch_compile_mode: str = field( + # default="default" + # ) # default | reduce-overhead | max-autotune From 08c9d0580d320d645d03e6c550a4c28f03f9bfa3 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 27 Jun 2023 16:08:52 +0300 Subject: [PATCH 007/300] update gitignore --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index 68bc17f9..da1beb48 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,7 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + + +# Folder where training outputs are stored +bertmesh-outs/ From 0c35957d53b1f22e53a21c47f62d9cd0c734f33b Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 27 Jun 2023 16:37:17 +0300 Subject: [PATCH 008/300] add loguru, prettify printing --- grants_tagger_light/training/train.py | 9 +++++-- poetry.lock | 34 ++++++++++++++++++++++++++- pyproject.toml | 1 + 3 files changed, 41 insertions(+), 3 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 582c5b7c..0601750a 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -9,8 +9,11 @@ from grants_tagger_light.training.train_args import BertMeshTrainingArguments from grants_tagger_light.training.dataloaders import load_grants_sample from sklearn.metrics import classification_report +from loguru import logger +from pprint import pformat import typer import numpy as np +import os def train_bertmesh(model_key: str, data_path: str, training_args: TrainingArguments): @@ -53,9 +56,9 @@ def sklearn_metrics(prediction: EvalPrediction): metrics = trainer.evaluate(eval_dataset=val_dset) - print(metrics) + logger.info(pformat(metrics)) - trainer.save_model(training_args.output_dir) + trainer.save_model(os.path.join(training_args.output_dir, "best")) train_app = typer.Typer() @@ -75,6 +78,8 @@ def train_bertmesh_cli( parser = HfArgumentParser((BertMeshTrainingArguments,)) (training_args,) = parser.parse_args_into_dataclasses(ctx.args) + logger.info("Training args: {}".format(pformat(training_args))) + train_bertmesh(model_key, data_path, training_args) diff --git a/poetry.lock b/poetry.lock index 4db467c9..19b33c88 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1845,6 +1845,24 @@ files = [ {file = "lit-16.0.6.tar.gz", hash = "sha256:84623c9c23b6b14763d637f4e63e6b721b3446ada40bf7001d8fee70b8e77a9a"}, ] +[[package]] +name = "loguru" +version = "0.7.0" +description = "Python logging made (stupidly) simple" +optional = false +python-versions = ">=3.5" +files = [ + {file = "loguru-0.7.0-py3-none-any.whl", hash = "sha256:b93aa30099fa6860d4727f1b81f8718e965bb96253fa190fab2077aaad6d15d3"}, + {file = "loguru-0.7.0.tar.gz", hash = "sha256:1612053ced6ae84d7959dd7d5e431a0532642237ec21f7fd83ac73fe539e03e1"}, +] + +[package.dependencies] +colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""} +win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} + +[package.extras] +dev = ["Sphinx (==5.3.0)", "colorama (==0.4.5)", "colorama (==0.4.6)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v0.990)", "pre-commit (==3.2.1)", "pytest (==6.1.2)", "pytest (==7.2.1)", "pytest-cov (==2.12.1)", "pytest-cov (==4.0.0)", "pytest-mypy-plugins (==1.10.1)", "pytest-mypy-plugins (==1.9.3)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.2.0)", "tox (==3.27.1)", "tox (==4.4.6)"] + [[package]] name = "markupsafe" version = "2.1.3" @@ -3606,6 +3624,20 @@ files = [ {file = "wcwidth-0.2.6.tar.gz", hash = "sha256:a5220780a404dbe3353789870978e472cfe477761f06ee55077256e509b156d0"}, ] +[[package]] +name = "win32-setctime" +version = "1.1.0" +description = "A small Python utility to set file creation time on Windows" +optional = false +python-versions = ">=3.5" +files = [ + {file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"}, + {file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"}, +] + +[package.extras] +dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] + [[package]] name = "wrapt" version = "1.14.1" @@ -3893,4 +3925,4 @@ test = ["zope.testing"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "b5e71835a8b1d87d8ab7d899ae5da9af9e856fbb65f3be51c4ce1729e83bf04f" +content-hash = "b2dac419b1d08941b073965ed5cc843a76399c3c6305c232e6f54717d482c76a" diff --git a/pyproject.toml b/pyproject.toml index 7f431e5d..69408bbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dvc = {extras = ["s3"], version = "^2.58.2"} torch = {version = "2.0.1", source = "torch-cpu"} transformers = "4.29.2" libpecos = "^1.0.0" +loguru = "^0.7.0" [tool.poetry.group.dev] From 9379765b178a62d73e8a8b4d0e77bf236cba6678 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 27 Jun 2023 16:53:54 +0300 Subject: [PATCH 009/300] add integration with wandb --- .gitignore | 1 + grants_tagger_light/training/train_args.py | 2 + poetry.lock | 211 ++++++++++++++++++++- pyproject.toml | 1 + 4 files changed, 214 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index da1beb48..4fa1d780 100644 --- a/.gitignore +++ b/.gitignore @@ -162,3 +162,4 @@ cython_debug/ # Folder where training outputs are stored bertmesh-outs/ +wandb/ diff --git a/grants_tagger_light/training/train_args.py b/grants_tagger_light/training/train_args.py index 405bf618..2bb1dbd0 100644 --- a/grants_tagger_light/training/train_args.py +++ b/grants_tagger_light/training/train_args.py @@ -40,6 +40,8 @@ class BertMeshTrainingArguments(TrainingArguments): seed: int = field(default=42) data_seed: int = field(default=42) + report_to: str = field(default="wandb") + optim: str = field( default="adamw_torch_fused" ) # TODO add support for adamw_apex_fused diff --git a/poetry.lock b/poetry.lock index 19b33c88..a33a6aed 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1024,6 +1024,20 @@ files = [ {file = "distro-1.8.0.tar.gz", hash = "sha256:02e111d1dc6a50abb8eed6bf31c3e48ed8b0830d1ea2a1b78c61765c2513fdd8"}, ] +[[package]] +name = "docker-pycreds" +version = "0.4.0" +description = "Python bindings for the docker credentials store API" +optional = false +python-versions = "*" +files = [ + {file = "docker-pycreds-0.4.0.tar.gz", hash = "sha256:6ce3270bcaf404cc4c3e27e4b6c70d3521deae82fb508767870fdbf772d584d4"}, + {file = "docker_pycreds-0.4.0-py2.py3-none-any.whl", hash = "sha256:7266112468627868005106ec19cd0d722702d2b7d5912a28e19b826c3d37af49"}, +] + +[package.dependencies] +six = ">=1.4.0" + [[package]] name = "dpath" version = "2.1.6" @@ -2288,6 +2302,16 @@ files = [ {file = "pathspec-0.11.1.tar.gz", hash = "sha256:2798de800fa92780e33acca925945e9a19a133b715067cf165b8866c15a31687"}, ] +[[package]] +name = "pathtools" +version = "0.1.2" +description = "File system general utilities" +optional = false +python-versions = "*" +files = [ + {file = "pathtools-0.1.2.tar.gz", hash = "sha256:7c35c5421a39bb82e58018febd90e3b6e5db34c5443aaaf742b3f33d4655f1c0"}, +] + [[package]] name = "platformdirs" version = "3.6.0" @@ -2350,6 +2374,28 @@ files = [ [package.dependencies] wcwidth = "*" +[[package]] +name = "protobuf" +version = "4.23.3" +description = "" +optional = false +python-versions = ">=3.7" +files = [ + {file = "protobuf-4.23.3-cp310-abi3-win32.whl", hash = "sha256:514b6bbd54a41ca50c86dd5ad6488afe9505901b3557c5e0f7823a0cf67106fb"}, + {file = "protobuf-4.23.3-cp310-abi3-win_amd64.whl", hash = "sha256:cc14358a8742c4e06b1bfe4be1afbdf5c9f6bd094dff3e14edb78a1513893ff5"}, + {file = "protobuf-4.23.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:2991f5e7690dab569f8f81702e6700e7364cc3b5e572725098215d3da5ccc6ac"}, + {file = "protobuf-4.23.3-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:08fe19d267608d438aa37019236db02b306e33f6b9902c3163838b8e75970223"}, + {file = "protobuf-4.23.3-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:3b01a5274ac920feb75d0b372d901524f7e3ad39c63b1a2d55043f3887afe0c1"}, + {file = "protobuf-4.23.3-cp37-cp37m-win32.whl", hash = "sha256:aca6e86a08c5c5962f55eac9b5bd6fce6ed98645d77e8bfc2b952ecd4a8e4f6a"}, + {file = "protobuf-4.23.3-cp37-cp37m-win_amd64.whl", hash = "sha256:0149053336a466e3e0b040e54d0b615fc71de86da66791c592cc3c8d18150bf8"}, + {file = "protobuf-4.23.3-cp38-cp38-win32.whl", hash = "sha256:84ea0bd90c2fdd70ddd9f3d3fc0197cc24ecec1345856c2b5ba70e4d99815359"}, + {file = "protobuf-4.23.3-cp38-cp38-win_amd64.whl", hash = "sha256:3bcbeb2bf4bb61fe960dd6e005801a23a43578200ea8ceb726d1f6bd0e562ba1"}, + {file = "protobuf-4.23.3-cp39-cp39-win32.whl", hash = "sha256:5cb9e41188737f321f4fce9a4337bf40a5414b8d03227e1d9fbc59bc3a216e35"}, + {file = "protobuf-4.23.3-cp39-cp39-win_amd64.whl", hash = "sha256:29660574cd769f2324a57fb78127cda59327eb6664381ecfe1c69731b83e8288"}, + {file = "protobuf-4.23.3-py3-none-any.whl", hash = "sha256:447b9786ac8e50ae72cae7a2eec5c5df6a9dbf9aa6f908f1b8bda6032644ea62"}, + {file = "protobuf-4.23.3.tar.gz", hash = "sha256:7a92beb30600332a52cdadbedb40d33fd7c8a0d7f549c440347bc606fb3fe34b"}, +] + [[package]] name = "psutil" version = "5.9.5" @@ -3131,6 +3177,132 @@ files = [ {file = "sentencepiece-0.1.99.tar.gz", hash = "sha256:189c48f5cb2949288f97ccdb97f0473098d9c3dcf5a3d99d4eabe719ec27297f"}, ] +[[package]] +name = "sentry-sdk" +version = "1.26.0" +description = "Python client for Sentry (https://sentry.io)" +optional = false +python-versions = "*" +files = [ + {file = "sentry-sdk-1.26.0.tar.gz", hash = "sha256:760e4fb6d01c994110507133e08ecd4bdf4d75ee4be77f296a3579796cf73134"}, + {file = "sentry_sdk-1.26.0-py2.py3-none-any.whl", hash = "sha256:0c9f858337ec3781cf4851972ef42bba8c9828aea116b0dbed8f38c5f9a1896c"}, +] + +[package.dependencies] +certifi = "*" +urllib3 = {version = ">=1.26.11", markers = "python_version >= \"3.6\""} + +[package.extras] +aiohttp = ["aiohttp (>=3.5)"] +arq = ["arq (>=0.23)"] +beam = ["apache-beam (>=2.12)"] +bottle = ["bottle (>=0.12.13)"] +celery = ["celery (>=3)"] +chalice = ["chalice (>=1.16.0)"] +django = ["django (>=1.8)"] +falcon = ["falcon (>=1.4)"] +fastapi = ["fastapi (>=0.79.0)"] +flask = ["blinker (>=1.1)", "flask (>=0.11)", "markupsafe"] +grpcio = ["grpcio (>=1.21.1)"] +httpx = ["httpx (>=0.16.0)"] +huey = ["huey (>=2)"] +loguru = ["loguru (>=0.5)"] +opentelemetry = ["opentelemetry-distro (>=0.35b0)"] +pure-eval = ["asttokens", "executing", "pure-eval"] +pymongo = ["pymongo (>=3.1)"] +pyspark = ["pyspark (>=2.4.4)"] +quart = ["blinker (>=1.1)", "quart (>=0.16.1)"] +rq = ["rq (>=0.6)"] +sanic = ["sanic (>=0.8)"] +sqlalchemy = ["sqlalchemy (>=1.2)"] +starlette = ["starlette (>=0.19.1)"] +starlite = ["starlite (>=1.48)"] +tornado = ["tornado (>=5)"] + +[[package]] +name = "setproctitle" +version = "1.3.2" +description = "A Python module to customize the process title" +optional = false +python-versions = ">=3.7" +files = [ + {file = "setproctitle-1.3.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:288943dec88e178bb2fd868adf491197cc0fc8b6810416b1c6775e686bab87fe"}, + {file = "setproctitle-1.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:630f6fe5e24a619ccf970c78e084319ee8be5be253ecc9b5b216b0f474f5ef18"}, + {file = "setproctitle-1.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c877691b90026670e5a70adfbcc735460a9f4c274d35ec5e8a43ce3f8443005"}, + {file = "setproctitle-1.3.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7a55fe05f15c10e8c705038777656fe45e3bd676d49ad9ac8370b75c66dd7cd7"}, + {file = "setproctitle-1.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ab45146c71ca6592c9cc8b354a2cc9cc4843c33efcbe1d245d7d37ce9696552d"}, + {file = "setproctitle-1.3.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e00c9d5c541a2713ba0e657e0303bf96ddddc412ef4761676adc35df35d7c246"}, + {file = "setproctitle-1.3.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:265ecbe2c6eafe82e104f994ddd7c811520acdd0647b73f65c24f51374cf9494"}, + {file = "setproctitle-1.3.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:c2c46200656280a064073447ebd363937562debef329482fd7e570c8d498f806"}, + {file = "setproctitle-1.3.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:fa2f50678f04fda7a75d0fe5dd02bbdd3b13cbe6ed4cf626e4472a7ccf47ae94"}, + {file = "setproctitle-1.3.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7f2719a398e1a2c01c2a63bf30377a34d0b6ef61946ab9cf4d550733af8f1ef1"}, + {file = "setproctitle-1.3.2-cp310-cp310-win32.whl", hash = "sha256:e425be62524dc0c593985da794ee73eb8a17abb10fe692ee43bb39e201d7a099"}, + {file = "setproctitle-1.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:e85e50b9c67854f89635a86247412f3ad66b132a4d8534ac017547197c88f27d"}, + {file = "setproctitle-1.3.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:2a97d51c17d438cf5be284775a322d57b7ca9505bb7e118c28b1824ecaf8aeaa"}, + {file = "setproctitle-1.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:587c7d6780109fbd8a627758063d08ab0421377c0853780e5c356873cdf0f077"}, + {file = "setproctitle-1.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7d17c8bd073cbf8d141993db45145a70b307385b69171d6b54bcf23e5d644de"}, + {file = "setproctitle-1.3.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e932089c35a396dc31a5a1fc49889dd559548d14cb2237adae260382a090382e"}, + {file = "setproctitle-1.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8e4f8f12258a8739c565292a551c3db62cca4ed4f6b6126664e2381acb4931bf"}, + {file = "setproctitle-1.3.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:570d255fd99c7f14d8f91363c3ea96bd54f8742275796bca67e1414aeca7d8c3"}, + {file = "setproctitle-1.3.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:a8e0881568c5e6beff91ef73c0ec8ac2a9d3ecc9edd6bd83c31ca34f770910c4"}, + {file = "setproctitle-1.3.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4bba3be4c1fabf170595b71f3af46c6d482fbe7d9e0563999b49999a31876f77"}, + {file = "setproctitle-1.3.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:37ece938110cab2bb3957e3910af8152ca15f2b6efdf4f2612e3f6b7e5459b80"}, + {file = "setproctitle-1.3.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:db684d6bbb735a80bcbc3737856385b55d53f8a44ce9b46e9a5682c5133a9bf7"}, + {file = "setproctitle-1.3.2-cp311-cp311-win32.whl", hash = "sha256:ca58cd260ea02759238d994cfae844fc8b1e206c684beb8f38877dcab8451dfc"}, + {file = "setproctitle-1.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:88486e6cce2a18a033013d17b30a594f1c5cb42520c49c19e6ade40b864bb7ff"}, + {file = "setproctitle-1.3.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:92c626edc66169a1b09e9541b9c0c9f10488447d8a2b1d87c8f0672e771bc927"}, + {file = "setproctitle-1.3.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:710e16fa3bade3b026907e4a5e841124983620046166f355bbb84be364bf2a02"}, + {file = "setproctitle-1.3.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1f29b75e86260b0ab59adb12661ef9f113d2f93a59951373eb6d68a852b13e83"}, + {file = "setproctitle-1.3.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1c8d9650154afaa86a44ff195b7b10d683c73509d085339d174e394a22cccbb9"}, + {file = "setproctitle-1.3.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0452282258dfcc01697026a8841258dd2057c4438b43914b611bccbcd048f10"}, + {file = "setproctitle-1.3.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:e49ae693306d7624015f31cb3e82708916759d592c2e5f72a35c8f4cc8aef258"}, + {file = "setproctitle-1.3.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:1ff863a20d1ff6ba2c24e22436a3daa3cd80be1dfb26891aae73f61b54b04aca"}, + {file = "setproctitle-1.3.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:55ce1e9925ce1765865442ede9dca0ba9bde10593fcd570b1f0fa25d3ec6b31c"}, + {file = "setproctitle-1.3.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:7fe9df7aeb8c64db6c34fc3b13271a363475d77bc157d3f00275a53910cb1989"}, + {file = "setproctitle-1.3.2-cp37-cp37m-win32.whl", hash = "sha256:e5c50e164cd2459bc5137c15288a9ef57160fd5cbf293265ea3c45efe7870865"}, + {file = "setproctitle-1.3.2-cp37-cp37m-win_amd64.whl", hash = "sha256:a499fff50387c1520c085a07578a000123f519e5f3eee61dd68e1d301659651f"}, + {file = "setproctitle-1.3.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5b932c3041aa924163f4aab970c2f0e6b4d9d773f4d50326e0ea1cd69240e5c5"}, + {file = "setproctitle-1.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f4bfc89bd33ebb8e4c0e9846a09b1f5a4a86f5cb7a317e75cc42fee1131b4f4f"}, + {file = "setproctitle-1.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fcd3cf4286a60fdc95451d8d14e0389a6b4f5cebe02c7f2609325eb016535963"}, + {file = "setproctitle-1.3.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5fb4f769c02f63fac90989711a3fee83919f47ae9afd4758ced5d86596318c65"}, + {file = "setproctitle-1.3.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5194b4969f82ea842a4f6af2f82cd16ebdc3f1771fb2771796e6add9835c1973"}, + {file = "setproctitle-1.3.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f0cde41857a644b7353a0060b5f94f7ba7cf593ebde5a1094da1be581ac9a31"}, + {file = "setproctitle-1.3.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:9124bedd8006b0e04d4e8a71a0945da9b67e7a4ab88fdad7b1440dc5b6122c42"}, + {file = "setproctitle-1.3.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:c8a09d570b39517de10ee5b718730e171251ce63bbb890c430c725c8c53d4484"}, + {file = "setproctitle-1.3.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:8ff3c8cb26afaed25e8bca7b9dd0c1e36de71f35a3a0706b5c0d5172587a3827"}, + {file = "setproctitle-1.3.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:589be87172b238f839e19f146b9ea47c71e413e951ef0dc6db4218ddacf3c202"}, + {file = "setproctitle-1.3.2-cp38-cp38-win32.whl", hash = "sha256:4749a2b0c9ac52f864d13cee94546606f92b981b50e46226f7f830a56a9dc8e1"}, + {file = "setproctitle-1.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:e43f315c68aa61cbdef522a2272c5a5b9b8fd03c301d3167b5e1343ef50c676c"}, + {file = "setproctitle-1.3.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:de3a540cd1817ede31f530d20e6a4935bbc1b145fd8f8cf393903b1e02f1ae76"}, + {file = "setproctitle-1.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4058564195b975ddc3f0462375c533cce310ccdd41b80ac9aed641c296c3eff4"}, + {file = "setproctitle-1.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c5d5dad7c28bdd1ec4187d818e43796f58a845aa892bb4481587010dc4d362b"}, + {file = "setproctitle-1.3.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ffc61a388a5834a97953d6444a2888c24a05f2e333f9ed49f977a87bb1ad4761"}, + {file = "setproctitle-1.3.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fa1a0fbee72b47dc339c87c890d3c03a72ea65c061ade3204f285582f2da30f"}, + {file = "setproctitle-1.3.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe8a988c7220c002c45347430993830666e55bc350179d91fcee0feafe64e1d4"}, + {file = "setproctitle-1.3.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bae283e85fc084b18ffeb92e061ff7ac5af9e183c9d1345c93e178c3e5069cbe"}, + {file = "setproctitle-1.3.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:fed18e44711c5af4b681c2b3b18f85e6f0f1b2370a28854c645d636d5305ccd8"}, + {file = "setproctitle-1.3.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:b34baef93bfb20a8ecb930e395ccd2ae3268050d8cf4fe187de5e2bd806fd796"}, + {file = "setproctitle-1.3.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:7f0bed90a216ef28b9d227d8d73e28a8c9b88c0f48a082d13ab3fa83c581488f"}, + {file = "setproctitle-1.3.2-cp39-cp39-win32.whl", hash = "sha256:4d8938249a7cea45ab7e1e48b77685d0f2bab1ebfa9dde23e94ab97968996a7c"}, + {file = "setproctitle-1.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:a47d97a75fd2d10c37410b180f67a5835cb1d8fdea2648fd7f359d4277f180b9"}, + {file = "setproctitle-1.3.2-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:dad42e676c5261eb50fdb16bdf3e2771cf8f99a79ef69ba88729aeb3472d8575"}, + {file = "setproctitle-1.3.2-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c91b9bc8985d00239f7dc08a49927a7ca1ca8a6af2c3890feec3ed9665b6f91e"}, + {file = "setproctitle-1.3.2-pp37-pypy37_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8579a43eafd246e285eb3a5b939e7158073d5087aacdd2308f23200eac2458b"}, + {file = "setproctitle-1.3.2-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:2fbd8187948284293f43533c150cd69a0e4192c83c377da837dbcd29f6b83084"}, + {file = "setproctitle-1.3.2-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:faec934cfe5fd6ac1151c02e67156c3f526e82f96b24d550b5d51efa4a5527c6"}, + {file = "setproctitle-1.3.2-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e1aafc91cbdacc9e5fe712c52077369168e6b6c346f3a9d51bf600b53eae56bb"}, + {file = "setproctitle-1.3.2-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b617f12c9be61e8f4b2857be4a4319754756845dbbbd9c3718f468bbb1e17bcb"}, + {file = "setproctitle-1.3.2-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:b2c9cb2705fc84cb8798f1ba74194f4c080aaef19d9dae843591c09b97678e98"}, + {file = "setproctitle-1.3.2-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a149a5f7f2c5a065d4e63cb0d7a4b6d3b66e6e80f12e3f8827c4f63974cbf122"}, + {file = "setproctitle-1.3.2-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e3ac25bfc4a0f29d2409650c7532d5ddfdbf29f16f8a256fc31c47d0dc05172"}, + {file = "setproctitle-1.3.2-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65d884e22037b23fa25b2baf1a3316602ed5c5971eb3e9d771a38c3a69ce6e13"}, + {file = "setproctitle-1.3.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7aa0aac1711fadffc1d51e9d00a3bea61f68443d6ac0241a224e4d622489d665"}, + {file = "setproctitle-1.3.2.tar.gz", hash = "sha256:b9fb97907c830d260fa0658ed58afd48a86b2b88aac521135c352ff7fd3477fd"}, +] + +[package.extras] +test = ["pytest"] + [[package]] name = "setuptools" version = "67.8.0" @@ -3599,6 +3771,43 @@ files = [ {file = "voluptuous-0.13.1.tar.gz", hash = "sha256:e8d31c20601d6773cb14d4c0f42aee29c6821bbd1018039aac7ac5605b489723"}, ] +[[package]] +name = "wandb" +version = "0.15.4" +description = "A CLI and library for interacting with the Weights and Biases API." +optional = false +python-versions = ">=3.6" +files = [ + {file = "wandb-0.15.4-py3-none-any.whl", hash = "sha256:9018565177e1be14d7d0dd470c583206031c6027c32a98c57fa3bb83955143d7"}, + {file = "wandb-0.15.4.tar.gz", hash = "sha256:472daaaa1a4e29a46407a85fd77aadb724c91d87dfe2c37cd82ef77be2257011"}, +] + +[package.dependencies] +appdirs = ">=1.4.3" +Click = ">=7.0,<8.0.0 || >8.0.0" +docker-pycreds = ">=0.4.0" +GitPython = ">=1.0.0,<3.1.29 || >3.1.29" +pathtools = "*" +protobuf = {version = ">=3.19.0,<4.21.0 || >4.21.0,<5", markers = "python_version > \"3.9\" or sys_platform != \"linux\""} +psutil = ">=5.0.0" +PyYAML = "*" +requests = ">=2.0.0,<3" +sentry-sdk = ">=1.0.0" +setproctitle = "*" +setuptools = "*" + +[package.extras] +async = ["httpx (>=0.22.0)"] +aws = ["boto3"] +azure = ["azure-identity", "azure-storage-blob"] +gcp = ["google-cloud-storage"] +grpc = ["grpcio (>=1.27.2)"] +kubeflow = ["google-cloud-storage", "kubernetes", "minio", "sh"] +launch = ["awscli", "boto3", "botocore", "chardet", "google-auth", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "nbconvert", "nbformat", "optuna", "typing-extensions"] +media = ["bokeh", "moviepy", "numpy", "pillow", "plotly", "rdkit-pypi", "soundfile"] +models = ["cloudpickle"] +sweeps = ["sweeps (>=0.2.0)"] + [[package]] name = "wasabi" version = "1.1.1" @@ -3925,4 +4134,4 @@ test = ["zope.testing"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "b2dac419b1d08941b073965ed5cc843a76399c3c6305c232e6f54717d482c76a" +content-hash = "ab53e38eeea4ec8285330167a2ee7fb0bef4007cc5af9718549c0eb1fd2a5ebf" diff --git a/pyproject.toml b/pyproject.toml index 69408bbe..3cecaf21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ torch = {version = "2.0.1", source = "torch-cpu"} transformers = "4.29.2" libpecos = "^1.0.0" loguru = "^0.7.0" +wandb = "^0.15.4" [tool.poetry.group.dev] From a7969d18aa2740309669b1dff54807db57d4481a Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 27 Jun 2023 16:57:21 +0300 Subject: [PATCH 010/300] remove pdb call --- grants_tagger_light/training/dataloaders.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/grants_tagger_light/training/dataloaders.py b/grants_tagger_light/training/dataloaders.py index d1aef374..af56738b 100644 --- a/grants_tagger_light/training/dataloaders.py +++ b/grants_tagger_light/training/dataloaders.py @@ -41,9 +41,6 @@ def _tokenize(batch): ) def _label_encode(batch): - import pdb - - pdb.set_trace() batch["labels"] = [ [label2id[tag] for tag in tags[0] if tag in label2id] for tags in batch["mesh_terms"] From 7ed1d724cff0749de3cad553017e1b5ae02887b4 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 27 Jun 2023 16:59:23 +0300 Subject: [PATCH 011/300] refactor --- grants_tagger_light/training/cli_args/__init__.py | 3 +++ grants_tagger_light/training/{ => cli_args}/train_args.py | 0 grants_tagger_light/training/dataloaders/__init__.py | 3 +++ grants_tagger_light/training/{ => dataloaders}/dataloaders.py | 0 grants_tagger_light/training/train.py | 2 +- 5 files changed, 7 insertions(+), 1 deletion(-) create mode 100644 grants_tagger_light/training/cli_args/__init__.py rename grants_tagger_light/training/{ => cli_args}/train_args.py (100%) create mode 100644 grants_tagger_light/training/dataloaders/__init__.py rename grants_tagger_light/training/{ => dataloaders}/dataloaders.py (100%) diff --git a/grants_tagger_light/training/cli_args/__init__.py b/grants_tagger_light/training/cli_args/__init__.py new file mode 100644 index 00000000..bc5c954c --- /dev/null +++ b/grants_tagger_light/training/cli_args/__init__.py @@ -0,0 +1,3 @@ +from .train_args import BertMeshTrainingArguments + +__all__ = ["BertMeshTrainingArguments"] diff --git a/grants_tagger_light/training/train_args.py b/grants_tagger_light/training/cli_args/train_args.py similarity index 100% rename from grants_tagger_light/training/train_args.py rename to grants_tagger_light/training/cli_args/train_args.py diff --git a/grants_tagger_light/training/dataloaders/__init__.py b/grants_tagger_light/training/dataloaders/__init__.py new file mode 100644 index 00000000..09fc8c9a --- /dev/null +++ b/grants_tagger_light/training/dataloaders/__init__.py @@ -0,0 +1,3 @@ +from .dataloaders import load_grants_sample + +__all__ = ["load_grants_sample"] diff --git a/grants_tagger_light/training/dataloaders.py b/grants_tagger_light/training/dataloaders/dataloaders.py similarity index 100% rename from grants_tagger_light/training/dataloaders.py rename to grants_tagger_light/training/dataloaders/dataloaders.py diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 0601750a..18016f06 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -6,7 +6,7 @@ HfArgumentParser, ) from grants_tagger_light.models.bert_mesh import BertMesh -from grants_tagger_light.training.train_args import BertMeshTrainingArguments +from grants_tagger_light.training.cli_args import BertMeshTrainingArguments from grants_tagger_light.training.dataloaders import load_grants_sample from sklearn.metrics import classification_report from loguru import logger From 55b6c53e94a78749f01a7c8e6a74effcd413854f Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 27 Jun 2023 17:24:28 +0300 Subject: [PATCH 012/300] add support for wandb arguments (avoid passing env variables) --- .../training/cli_args/__init__.py | 3 +- .../training/cli_args/wandb_args.py | 44 +++++++++++++++++++ grants_tagger_light/training/train.py | 10 +++-- 3 files changed, 53 insertions(+), 4 deletions(-) create mode 100644 grants_tagger_light/training/cli_args/wandb_args.py diff --git a/grants_tagger_light/training/cli_args/__init__.py b/grants_tagger_light/training/cli_args/__init__.py index bc5c954c..31771867 100644 --- a/grants_tagger_light/training/cli_args/__init__.py +++ b/grants_tagger_light/training/cli_args/__init__.py @@ -1,3 +1,4 @@ from .train_args import BertMeshTrainingArguments +from .wandb_args import WandbArguments -__all__ = ["BertMeshTrainingArguments"] +__all__ = ["BertMeshTrainingArguments", "WandbArguments"] diff --git a/grants_tagger_light/training/cli_args/wandb_args.py b/grants_tagger_light/training/cli_args/wandb_args.py new file mode 100644 index 00000000..74c6fed9 --- /dev/null +++ b/grants_tagger_light/training/cli_args/wandb_args.py @@ -0,0 +1,44 @@ +import os +from dataclasses import dataclass, field, fields + + +@dataclass +class WandbArguments: + """ + Wandb arguments for training. Will set according env variables if not set. + Each field is a lowercase version of the env variable name. + For all wandb envs, see: https://docs.wandb.ai/guides/track/environment-variables + """ + + wandb_api_key: str = field( + default=None, + metadata={"help": "Wandb API key"}, + ) + + wandb_project: str = field( + default=None, + metadata={"help": "Wandb project name"}, + ) + + wandb_name: str = field( + default=None, + metadata={"help": "Wandb run name"}, + ) + + wandb_notes: str = field( + default=None, + metadata={"help": "Wandb run notes. Markdown allowed."}, + ) + + wandb_tags: list[str] = field( + default_factory=list, + metadata={"help": "Wandb run tags. Comma separated."}, + ) + + def __post_init__(self): + # Check if env variables are set, and if not set them to this class' values + for field_ in fields(self): + env_var_name = field_.name.upper() + env_var = os.environ.get(env_var_name) + if not env_var: + os.environ[env_var_name] = str(getattr(self, field_.name)) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 18016f06..7b8685cb 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -6,7 +6,10 @@ HfArgumentParser, ) from grants_tagger_light.models.bert_mesh import BertMesh -from grants_tagger_light.training.cli_args import BertMeshTrainingArguments +from grants_tagger_light.training.cli_args import ( + BertMeshTrainingArguments, + WandbArguments, +) from grants_tagger_light.training.dataloaders import load_grants_sample from sklearn.metrics import classification_report from loguru import logger @@ -75,10 +78,11 @@ def train_bertmesh_cli( help="Path to data in jsonl format. Must contain text and tags field", ), ): - parser = HfArgumentParser((BertMeshTrainingArguments,)) - (training_args,) = parser.parse_args_into_dataclasses(ctx.args) + parser = HfArgumentParser((BertMeshTrainingArguments, WandbArguments)) + (training_args, wandb_args) = parser.parse_args_into_dataclasses(ctx.args) logger.info("Training args: {}".format(pformat(training_args))) + logger.info("Wandb args: {}".format(pformat(wandb_args))) train_bertmesh(model_key, data_path, training_args) From a258d4e7ac603eb4887cf917c02c9314f3cdbf0f Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 27 Jun 2023 17:30:09 +0300 Subject: [PATCH 013/300] add proper tag support --- grants_tagger_light/training/cli_args/wandb_args.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/grants_tagger_light/training/cli_args/wandb_args.py b/grants_tagger_light/training/cli_args/wandb_args.py index 74c6fed9..21529af7 100644 --- a/grants_tagger_light/training/cli_args/wandb_args.py +++ b/grants_tagger_light/training/cli_args/wandb_args.py @@ -36,9 +36,13 @@ class WandbArguments: ) def __post_init__(self): + # Postprocess the # Check if env variables are set, and if not set them to this class' values for field_ in fields(self): env_var_name = field_.name.upper() env_var = os.environ.get(env_var_name) if not env_var: - os.environ[env_var_name] = str(getattr(self, field_.name)) + if isinstance(getattr(self, field_.name), list): + os.environ[env_var_name] = ",".join(getattr(self, field_.name)) + else: + os.environ[env_var_name] = str(getattr(self, field_.name)) From 775544465fd342a2489b58478df5e0f6063054fe Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Wed, 28 Jun 2023 15:47:38 +0300 Subject: [PATCH 014/300] upgrade datasets to 2.13.1 --- poetry.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/poetry.lock b/poetry.lock index a33a6aed..873c01b6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -903,13 +903,13 @@ test-randomorder = ["pytest-randomly"] [[package]] name = "datasets" -version = "2.13.0" +version = "2.13.1" description = "HuggingFace community-driven open-source library of datasets" optional = false python-versions = ">=3.7.0" files = [ - {file = "datasets-2.13.0-py3-none-any.whl", hash = "sha256:26671d474990ad8fd7388e8c67cde4d72f6c1f0e87af685fc09af5d9a5992274"}, - {file = "datasets-2.13.0.tar.gz", hash = "sha256:b8c3bcf9c3d0c74f101c7645e42231de9f45206a2e742df15799da9bfa625608"}, + {file = "datasets-2.13.1-py3-none-any.whl", hash = "sha256:844d8dbc1759e0b6b8e5063af019dc95d6af07ea075002b03323a280bf8d53f6"}, + {file = "datasets-2.13.1.tar.gz", hash = "sha256:bacb7750b1a434417312b4281a55225a3f7e0163abdd12a2a3e2d700310d5221"}, ] [package.dependencies] @@ -4134,4 +4134,4 @@ test = ["zope.testing"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "ab53e38eeea4ec8285330167a2ee7fb0bef4007cc5af9718549c0eb1fd2a5ebf" +content-hash = "1cd7926938b8bd8bcfec710786ecc7b248c23c72c6fe01d87b28dc861b12f445" diff --git a/pyproject.toml b/pyproject.toml index 3cecaf21..25ae415f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ scikit-learn = "1.2.2" pandas = "1.5.3" wasabi = "1.1.1" typer = "^0.9.0" -datasets = "^2.12.0" +datasets = "2.13.1" accelerate = "^0.19.0" dvc = {extras = ["s3"], version = "^2.58.2"} torch = {version = "2.0.1", source = "torch-cpu"} From a43953b8a5b1acb407a2951a695c86a254b81c9f Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Wed, 28 Jun 2023 16:02:05 +0300 Subject: [PATCH 015/300] add pipeline to output jsonl allmesh_2021 --- data/raw/.gitignore | 1 + pipelines/allmesh_json_to_jsonl/dvc.lock | 9 +++++++ pipelines/allmesh_json_to_jsonl/dvc.yaml | 9 +++++++ scripts/mesh_to_jsonl.py | 30 ++++++++++++++++++++++++ 4 files changed, 49 insertions(+) create mode 100644 pipelines/allmesh_json_to_jsonl/dvc.lock create mode 100644 pipelines/allmesh_json_to_jsonl/dvc.yaml create mode 100644 scripts/mesh_to_jsonl.py diff --git a/data/raw/.gitignore b/data/raw/.gitignore index e17ccf76..fa95c7dc 100644 --- a/data/raw/.gitignore +++ b/data/raw/.gitignore @@ -1,3 +1,4 @@ /allMeSH_2021.json /desc2021.xml /disease_tags_validation_grants.xlsx +/allMeSH_2021.jsonl diff --git a/pipelines/allmesh_json_to_jsonl/dvc.lock b/pipelines/allmesh_json_to_jsonl/dvc.lock new file mode 100644 index 00000000..6252b40e --- /dev/null +++ b/pipelines/allmesh_json_to_jsonl/dvc.lock @@ -0,0 +1,9 @@ +schema: '2.0' +stages: + convert: + cmd: python ../../scripts/mesh_to_jsonl.py --json-path ../../data/raw/allMeSH_2021.json + --jsonl-path ../../data/raw/allMeSH_2021.jsonl + outs: + - path: ../../data/raw/allMeSH_2021.jsonl + md5: 94f18c3918b180728a553123edb2ee32 + size: 27914288461 diff --git a/pipelines/allmesh_json_to_jsonl/dvc.yaml b/pipelines/allmesh_json_to_jsonl/dvc.yaml new file mode 100644 index 00000000..ad714941 --- /dev/null +++ b/pipelines/allmesh_json_to_jsonl/dvc.yaml @@ -0,0 +1,9 @@ +vars: + - scripts_location: "../../scripts" + - json-path: "../../data/raw/allMeSH_2021.json" + - jsonl-path: "../../data/raw/allMeSH_2021.jsonl" +stages: + convert: + cmd: python ${scripts_location}/mesh_to_jsonl.py --json-path ${json-path} --jsonl-path ${jsonl-path} + outs: + - ${jsonl-path} diff --git a/scripts/mesh_to_jsonl.py b/scripts/mesh_to_jsonl.py new file mode 100644 index 00000000..870dbf89 --- /dev/null +++ b/scripts/mesh_to_jsonl.py @@ -0,0 +1,30 @@ +import json +import argparse +from tqdm import tqdm + + +def convert_json_to_jsonl(json_path, jsonl_path): + # First count the lines + with open(json_path, "r", encoding="latin1") as input_file: + num_lines = sum(1 for line in input_file) + + # Skip 1st line + with open(json_path, "r", encoding="latin1") as input_file, open( + jsonl_path, "w" + ) as output_file: + pbar = tqdm(total=num_lines) + for idx, line in enumerate(input_file): + if idx == 0: + continue + sample = json.loads(line[:-2]) + output_file.write(json.dumps(sample) + "\n") + pbar.update(1) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--json-path", type=str, default="data/raw/allMeSH_2021.json") + parser.add_argument("--jsonl-path", type=str, default="data/raw/allMeSH_2021.jsonl") + args = parser.parse_args() + + convert_json_to_jsonl(args.json_path, args.jsonl_path) From f6fc7bcaed43fb5344d7c074953e4cc4513c3d55 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Wed, 28 Jun 2023 17:00:13 +0300 Subject: [PATCH 016/300] remove pipeline and add more efficient dataset from generator --- data/raw/.gitignore | 1 - .../training/dataloaders/dataloaders.py | 123 ++++++++++++++---- pipelines/allmesh_json_to_jsonl/dvc.lock | 9 -- pipelines/allmesh_json_to_jsonl/dvc.yaml | 9 -- 4 files changed, 100 insertions(+), 42 deletions(-) delete mode 100644 pipelines/allmesh_json_to_jsonl/dvc.lock delete mode 100644 pipelines/allmesh_json_to_jsonl/dvc.yaml diff --git a/data/raw/.gitignore b/data/raw/.gitignore index fa95c7dc..e17ccf76 100644 --- a/data/raw/.gitignore +++ b/data/raw/.gitignore @@ -1,4 +1,3 @@ /allMeSH_2021.json /desc2021.xml /disease_tags_validation_grants.xlsx -/allMeSH_2021.jsonl diff --git a/grants_tagger_light/training/dataloaders/dataloaders.py b/grants_tagger_light/training/dataloaders/dataloaders.py index af56738b..b855ad8f 100644 --- a/grants_tagger_light/training/dataloaders/dataloaders.py +++ b/grants_tagger_light/training/dataloaders/dataloaders.py @@ -2,13 +2,40 @@ from transformers import AutoTokenizer from datasets import Dataset +# TODO refactor the two load funcs into a class + + +def _tokenize(batch, tokenizer: AutoTokenizer, x_col: str): + return tokenizer( + batch[x_col], + padding="max_length", + truncation=True, + max_length=512, + ) + + +def _label_encode(batch, mesh_terms_column: str, label2id: dict): + batch["labels"] = [ + [label2id[tag] for tag in tags[0] if tag in label2id] + for tags in batch[mesh_terms_column] + ] + return batch + + +def _one_hot(batch, label2id: dict): + batch["labels"] = [ + [1 if i in labels else 0 for i in range(len(label2id))] + for labels in batch["labels"] + ] + return batch + def load_grants_sample( data_path: str, tokenizer: AutoTokenizer, label2id: dict, test_size: float = 0.1, - num_proc: int = 1, + num_proc: int = 8, ): """ Code that loads a grants sample. @@ -32,28 +59,6 @@ def _datagen(data_path: str): sample = json.loads(line) yield sample - def _tokenize(batch): - return tokenizer( - batch["abstract"], - padding="max_length", - truncation=True, - max_length=512, - ) - - def _label_encode(batch): - batch["labels"] = [ - [label2id[tag] for tag in tags[0] if tag in label2id] - for tags in batch["mesh_terms"] - ] - return batch - - def _one_hot(batch): - batch["labels"] = [ - [1 if i in labels else 0 for i in range(len(label2id))] - for labels in batch["labels"] - ] - return batch - dset = Dataset.from_generator(_datagen, gen_kwargs={"data_path": data_path}) dset = dset.map( _tokenize, @@ -61,6 +66,58 @@ def _one_hot(batch): batch_size=32, num_proc=num_proc, desc="Tokenizing", + fn_kwargs={"tokenizer": tokenizer, "x_col": "abstract"}, + ) + + dset = dset.map( + _label_encode, + batched=True, + batch_size=32, + num_proc=num_proc, + desc="Encoding labels", + fn_kwargs={"mesh_terms_column": "mesh_terms", "label2id": label2id}, + ) + + dset = dset.map( + _one_hot, + batched=True, + batch_size=32, + num_proc=num_proc, + desc="One-hot labels", + fn_kwargs={"label2id": label2id}, + ) + + # Split into train and test + dset = dset.train_test_split(test_size=test_size) + + return dset["train"], dset["test"] + + +def load_mesh_json( + data_path: str, + tokenizer: AutoTokenizer, + label2id: dict, + test_size: float = 0.1, + num_proc: int = 8, +): + def _datagen(mesh_json_path: str): + with open(mesh_json_path, "r", encoding="latin1") as f: + for idx, line in enumerate(f): + # Skip 1st line + if idx == 0: + continue + sample = json.loads(line[:-2]) + yield sample + + dset = Dataset.from_generator(_datagen, gen_kwargs={"mesh_json_path": data_path}) + + dset = dset.map( + _tokenize, + batched=True, + batch_size=32, + num_proc=num_proc, + desc="Tokenizing", + fn_kwargs={"tokenizer": tokenizer}, ) dset = dset.map( @@ -69,6 +126,7 @@ def _one_hot(batch): batch_size=32, num_proc=num_proc, desc="Encoding labels", + fn_kwargs={"mesh_terms_column": "meshMajor", "label2id": label2id}, ) dset = dset.map( @@ -77,9 +135,28 @@ def _one_hot(batch): batch_size=32, num_proc=num_proc, desc="One-hot labels", + fn_kwargs={"label2id": label2id}, ) # Split into train and test dset = dset.train_test_split(test_size=test_size) return dset["train"], dset["test"] + + +if __name__ == "__main__": + from transformers import AutoModel + + model = AutoModel.from_pretrained( + "Wellcome/WellcomeBertMesh", trust_remote_code=True + ) + tokenizer = AutoTokenizer.from_pretrained("Wellcome/WellcomeBertMesh") + + dset_train, dset_val = load_mesh_json( + data_path="data/raw/allMeSH_2021.json", + tokenizer=tokenizer, + label2id=model.config.label2id, + ) + import pdb + + pdb.set_trace() diff --git a/pipelines/allmesh_json_to_jsonl/dvc.lock b/pipelines/allmesh_json_to_jsonl/dvc.lock deleted file mode 100644 index 6252b40e..00000000 --- a/pipelines/allmesh_json_to_jsonl/dvc.lock +++ /dev/null @@ -1,9 +0,0 @@ -schema: '2.0' -stages: - convert: - cmd: python ../../scripts/mesh_to_jsonl.py --json-path ../../data/raw/allMeSH_2021.json - --jsonl-path ../../data/raw/allMeSH_2021.jsonl - outs: - - path: ../../data/raw/allMeSH_2021.jsonl - md5: 94f18c3918b180728a553123edb2ee32 - size: 27914288461 diff --git a/pipelines/allmesh_json_to_jsonl/dvc.yaml b/pipelines/allmesh_json_to_jsonl/dvc.yaml deleted file mode 100644 index ad714941..00000000 --- a/pipelines/allmesh_json_to_jsonl/dvc.yaml +++ /dev/null @@ -1,9 +0,0 @@ -vars: - - scripts_location: "../../scripts" - - json-path: "../../data/raw/allMeSH_2021.json" - - jsonl-path: "../../data/raw/allMeSH_2021.jsonl" -stages: - convert: - cmd: python ${scripts_location}/mesh_to_jsonl.py --json-path ${json-path} --jsonl-path ${jsonl-path} - outs: - - ${jsonl-path} From c72a7ddca16c6871837d68fcf17d057308b1d0d6 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Wed, 28 Jun 2023 17:02:03 +0300 Subject: [PATCH 017/300] add missing arg for tokenizer --- grants_tagger_light/training/dataloaders/dataloaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grants_tagger_light/training/dataloaders/dataloaders.py b/grants_tagger_light/training/dataloaders/dataloaders.py index b855ad8f..d547b39a 100644 --- a/grants_tagger_light/training/dataloaders/dataloaders.py +++ b/grants_tagger_light/training/dataloaders/dataloaders.py @@ -117,7 +117,7 @@ def _datagen(mesh_json_path: str): batch_size=32, num_proc=num_proc, desc="Tokenizing", - fn_kwargs={"tokenizer": tokenizer}, + fn_kwargs={"tokenizer": tokenizer, "x_col": "abstractText"}, ) dset = dset.map( From 5110d743a087b8bc356164f966d99c2a2c17da7c Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Wed, 28 Jun 2023 17:40:53 +0300 Subject: [PATCH 018/300] add func in train cli --- grants_tagger_light/training/dataloaders/__init__.py | 4 ++-- grants_tagger_light/training/dataloaders/dataloaders.py | 3 --- grants_tagger_light/training/train.py | 6 ++++-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/grants_tagger_light/training/dataloaders/__init__.py b/grants_tagger_light/training/dataloaders/__init__.py index 09fc8c9a..63f7aadd 100644 --- a/grants_tagger_light/training/dataloaders/__init__.py +++ b/grants_tagger_light/training/dataloaders/__init__.py @@ -1,3 +1,3 @@ -from .dataloaders import load_grants_sample +from .dataloaders import load_grants_sample, load_mesh_json -__all__ = ["load_grants_sample"] +__all__ = ["load_grants_sample", "load_mesh_json"] diff --git a/grants_tagger_light/training/dataloaders/dataloaders.py b/grants_tagger_light/training/dataloaders/dataloaders.py index d547b39a..cec60ff4 100644 --- a/grants_tagger_light/training/dataloaders/dataloaders.py +++ b/grants_tagger_light/training/dataloaders/dataloaders.py @@ -157,6 +157,3 @@ def _datagen(mesh_json_path: str): tokenizer=tokenizer, label2id=model.config.label2id, ) - import pdb - - pdb.set_trace() diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 7b8685cb..1bd47fc3 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -10,7 +10,9 @@ BertMeshTrainingArguments, WandbArguments, ) -from grants_tagger_light.training.dataloaders import load_grants_sample +from grants_tagger_light.training.dataloaders import ( + load_mesh_json, +) from sklearn.metrics import classification_report from loguru import logger from pprint import pformat @@ -25,7 +27,7 @@ def train_bertmesh(model_key: str, data_path: str, training_args: TrainingArgume label2id = {v: k for k, v in model.id2label.items()} - train_dset, val_dset = load_grants_sample(data_path, tokenizer, label2id=label2id) + train_dset, val_dset = load_mesh_json(data_path, tokenizer, label2id=label2id) def sklearn_metrics(prediction: EvalPrediction): y_pred = prediction.predictions From d86473c0723a7e4d0e5a1c091e610d92203a5004 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Wed, 28 Jun 2023 17:53:56 +0300 Subject: [PATCH 019/300] replace load func with all mesh one --- grants_tagger_light/training/train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 1bd47fc3..ff372daf 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -27,7 +27,9 @@ def train_bertmesh(model_key: str, data_path: str, training_args: TrainingArgume label2id = {v: k for k, v in model.id2label.items()} - train_dset, val_dset = load_mesh_json(data_path, tokenizer, label2id=label2id) + train_dset, val_dset = load_mesh_json( + data_path, tokenizer=tokenizer, label2id=label2id + ) def sklearn_metrics(prediction: EvalPrediction): y_pred = prediction.predictions From 1fe6544e0a801f2910363842cf0fc2599ca0c20c Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Mon, 10 Jul 2023 13:29:55 +0300 Subject: [PATCH 020/300] add ability to subsample dataset --- .../training/dataloaders/dataloaders.py | 28 +++++++++++++++---- grants_tagger_light/training/train.py | 18 ++++++++++-- 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/grants_tagger_light/training/dataloaders/dataloaders.py b/grants_tagger_light/training/dataloaders/dataloaders.py index cec60ff4..5e27ff97 100644 --- a/grants_tagger_light/training/dataloaders/dataloaders.py +++ b/grants_tagger_light/training/dataloaders/dataloaders.py @@ -1,4 +1,5 @@ import json +import numpy as np from transformers import AutoTokenizer from datasets import Dataset @@ -36,6 +37,7 @@ def load_grants_sample( label2id: dict, test_size: float = 0.1, num_proc: int = 8, + max_samples: int = np.inf, ): """ Code that loads a grants sample. @@ -48,18 +50,26 @@ def load_grants_sample( for development / sanity check purposes). """ - def _datagen(data_path: str): + def _datagen(data_path: str, max_samples: int = np.inf): """ Loads the data from the given path. The data should be in jsonl format, with each line containing a text and tags field. The tags field should be a list of strings. """ with open(data_path, "r") as f: - for line in f: + for idx, line in f: sample = json.loads(line) + + if idx > max_samples: + break + yield sample - dset = Dataset.from_generator(_datagen, gen_kwargs={"data_path": data_path}) + dset = Dataset.from_generator( + _datagen, + gen_kwargs={"data_path": data_path, "max_samples": max_samples}, + ) + dset = dset.map( _tokenize, batched=True, @@ -99,17 +109,25 @@ def load_mesh_json( label2id: dict, test_size: float = 0.1, num_proc: int = 8, + max_samples: int = np.inf, ): - def _datagen(mesh_json_path: str): + def _datagen(mesh_json_path: str, max_samples: int = np.inf): with open(mesh_json_path, "r", encoding="latin1") as f: for idx, line in enumerate(f): # Skip 1st line if idx == 0: continue sample = json.loads(line[:-2]) + + if idx > max_samples: + break + yield sample - dset = Dataset.from_generator(_datagen, gen_kwargs={"mesh_json_path": data_path}) + dset = Dataset.from_generator( + _datagen, + gen_kwargs={"mesh_json_path": data_path, "max_samples": max_samples}, + ) dset = dset.map( _tokenize, diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index ff372daf..32baf852 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -21,14 +21,22 @@ import os -def train_bertmesh(model_key: str, data_path: str, training_args: TrainingArguments): +def train_bertmesh( + model_key: str, + data_path: str, + max_samples: int, + training_args: TrainingArguments, +): model = BertMesh.from_pretrained(model_key, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_key) label2id = {v: k for k, v in model.id2label.items()} train_dset, val_dset = load_mesh_json( - data_path, tokenizer=tokenizer, label2id=label2id + data_path, + tokenizer=tokenizer, + label2id=label2id, + max_samples=max_samples, ) def sklearn_metrics(prediction: EvalPrediction): @@ -81,6 +89,10 @@ def train_bertmesh_cli( ..., help="Path to data in jsonl format. Must contain text and tags field", ), + max_samples: int = typer.Option( + np.inf, + help="Maximum number of samples to use for training. Useful for dev/debugging", + ), ): parser = HfArgumentParser((BertMeshTrainingArguments, WandbArguments)) (training_args, wandb_args) = parser.parse_args_into_dataclasses(ctx.args) @@ -88,7 +100,7 @@ def train_bertmesh_cli( logger.info("Training args: {}".format(pformat(training_args))) logger.info("Wandb args: {}".format(pformat(wandb_args))) - train_bertmesh(model_key, data_path, training_args) + train_bertmesh(model_key, data_path, max_samples, training_args) if __name__ == "__main__": From 7295d3f7256fbd768a7500b6063789c9751bdd2e Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Mon, 10 Jul 2023 14:01:54 +0300 Subject: [PATCH 021/300] fix issue with label encoding --- .../training/dataloaders/dataloaders.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/grants_tagger_light/training/dataloaders/dataloaders.py b/grants_tagger_light/training/dataloaders/dataloaders.py index 5e27ff97..a993c2f6 100644 --- a/grants_tagger_light/training/dataloaders/dataloaders.py +++ b/grants_tagger_light/training/dataloaders/dataloaders.py @@ -16,10 +16,16 @@ def _tokenize(batch, tokenizer: AutoTokenizer, x_col: str): def _label_encode(batch, mesh_terms_column: str, label2id: dict): - batch["labels"] = [ - [label2id[tag] for tag in tags[0] if tag in label2id] - for tags in batch[mesh_terms_column] - ] + batch_labels = [] + for sample_tags in batch[mesh_terms_column]: + sample_labels = [] + for tag in sample_tags: + if tag in label2id: + sample_labels.append(label2id[tag]) + batch_labels.append(sample_labels) + + batch["labels"] = batch_labels + return batch @@ -83,7 +89,7 @@ def _datagen(data_path: str, max_samples: int = np.inf): _label_encode, batched=True, batch_size=32, - num_proc=num_proc, + num_proc=1, desc="Encoding labels", fn_kwargs={"mesh_terms_column": "mesh_terms", "label2id": label2id}, ) From ad9816bae7f11643fe60daea5595086cfa00bfda Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Mon, 10 Jul 2023 14:02:59 +0300 Subject: [PATCH 022/300] add list with mesh tree letters --- data/grants_comparison/mesh_tree_letters_list.txt | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 data/grants_comparison/mesh_tree_letters_list.txt diff --git a/data/grants_comparison/mesh_tree_letters_list.txt b/data/grants_comparison/mesh_tree_letters_list.txt new file mode 100644 index 00000000..270dd6a8 --- /dev/null +++ b/data/grants_comparison/mesh_tree_letters_list.txt @@ -0,0 +1,4 @@ +Information Sources: L +Phenomena and Processes: G +Geographicals: Z +Diseases: C From 5d50ce64b91fb2c18006afd58123846a63c07a0f Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Mon, 10 Jul 2023 14:31:19 +0300 Subject: [PATCH 023/300] postprocess wandb tags and set to None if not specified --- grants_tagger_light/training/cli_args/wandb_args.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/grants_tagger_light/training/cli_args/wandb_args.py b/grants_tagger_light/training/cli_args/wandb_args.py index 21529af7..b004353f 100644 --- a/grants_tagger_light/training/cli_args/wandb_args.py +++ b/grants_tagger_light/training/cli_args/wandb_args.py @@ -36,7 +36,9 @@ class WandbArguments: ) def __post_init__(self): - # Postprocess the + if len(self.wandb_tags) == 0: + self.wandb_tags = None + # Check if env variables are set, and if not set them to this class' values for field_ in fields(self): env_var_name = field_.name.upper() From f29db6c462a49e1721138da779ec3ad556a8faa7 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Mon, 10 Jul 2023 14:56:39 +0300 Subject: [PATCH 024/300] make train test pass --- tests/test_train.py | 56 ++++++++++++++++++++++----------------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/tests/test_train.py b/tests/test_train.py index d34b0f77..2d548b70 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -1,32 +1,23 @@ from grants_tagger_light.training.train import train_bertmesh +from transformers import TrainingArguments import tempfile import pytest -import json +import numpy as np # Note dummy data is not necessarily annotated correctly -dummy_data = [ - { - "text": "This grant is about malaria", - "tags": ["Humans", "Malaria"], - }, - { - "text": "This grant is about HIV", - "tags": ["HIV Infections", "Humans"], - }, - { - "text": "This grant is about diabetes", - "tags": ["Diabetes Mellitus", "Humans"], - }, -] +dummy_data = """{"articles":[ +{"journal":"dummyJournal","meshMajor":["COVID-19","SARS-CoV-2"],"year":"2023","abstractText":"This is an article about coronavirus."}, +{"journal":"dummyJournal","meshMajor":["COVID-19","SARS-CoV-2"],"year":"2023","abstractText":"This is an article about COVID-19."}, +{"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria"}, +""" # noqa @pytest.fixture def data_path(): with tempfile.TemporaryDirectory() as tmpdirname: - data_path = tmpdirname + "/data.jsonl" + data_path = tmpdirname + "/data.json" with open(data_path, "w") as f: - for sample in dummy_data: - f.write(json.dumps(sample) + "\n") + f.write(dummy_data) yield data_path @@ -40,13 +31,22 @@ def test_train_bertmesh(data_path, save_path): model_key = "Wellcome/WellcomeBertMesh" # 1 train step, 1 eval step, save after training - user_args = { - "output_dir": save_path, - "max_steps": 1, - "evaluation_strategy": "steps", - "eval_steps": 1, - "save_strategy": "steps", - "save_steps": 1, - } - - train_bertmesh(model_key, data_path, **user_args) + training_args = TrainingArguments( + output_dir=save_path, + max_steps=1, + per_device_train_batch_size=2, + per_device_eval_batch_size=2, + evaluation_strategy="steps", + eval_steps=1, + save_strategy="steps", + save_steps=1, + report_to="none", + no_cuda=True, + ) + + train_bertmesh( + model_key=model_key, + data_path=data_path, + max_samples=np.inf, + training_args=training_args, + ) From a34d56364aa0934edd997ee396de5590f6c9557b Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Mon, 10 Jul 2023 14:59:11 +0300 Subject: [PATCH 025/300] update gitignore --- models/.gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/models/.gitignore b/models/.gitignore index ff69c5ca..7afd2f42 100644 --- a/models/.gitignore +++ b/models/.gitignore @@ -1 +1,2 @@ /xlinear-0.2.5 +bertmesh/model/epoch-3 From dffc32ce37df540d5d0d5a1466a47f4c22447779 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Mon, 10 Jul 2023 16:18:03 +0300 Subject: [PATCH 026/300] set seed to enable reprodubility --- grants_tagger_light/training/train.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 32baf852..a9ca0e42 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -19,6 +19,10 @@ import typer import numpy as np import os +import transformers + + +transformers.set_seed(42) def train_bertmesh( From 7210b00bd7f62cc95071efb56ee74f7c0a67e0d1 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Mon, 10 Jul 2023 16:25:59 +0300 Subject: [PATCH 027/300] remove unused script --- scripts/mesh_to_jsonl.py | 30 ------------------------------ 1 file changed, 30 deletions(-) delete mode 100644 scripts/mesh_to_jsonl.py diff --git a/scripts/mesh_to_jsonl.py b/scripts/mesh_to_jsonl.py deleted file mode 100644 index 870dbf89..00000000 --- a/scripts/mesh_to_jsonl.py +++ /dev/null @@ -1,30 +0,0 @@ -import json -import argparse -from tqdm import tqdm - - -def convert_json_to_jsonl(json_path, jsonl_path): - # First count the lines - with open(json_path, "r", encoding="latin1") as input_file: - num_lines = sum(1 for line in input_file) - - # Skip 1st line - with open(json_path, "r", encoding="latin1") as input_file, open( - jsonl_path, "w" - ) as output_file: - pbar = tqdm(total=num_lines) - for idx, line in enumerate(input_file): - if idx == 0: - continue - sample = json.loads(line[:-2]) - output_file.write(json.dumps(sample) + "\n") - pbar.update(1) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--json-path", type=str, default="data/raw/allMeSH_2021.json") - parser.add_argument("--jsonl-path", type=str, default="data/raw/allMeSH_2021.jsonl") - args = parser.parse_args() - - convert_json_to_jsonl(args.json_path, args.jsonl_path) From bfb6aa82f68a2cf68668a884c4fa1852ad6481e0 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Mon, 10 Jul 2023 16:51:39 +0300 Subject: [PATCH 028/300] refactor dataloader --- .../training/dataloaders/__init__.py | 4 +- .../{dataloaders.py => mesh_json_loader.py} | 72 ------------------- 2 files changed, 2 insertions(+), 74 deletions(-) rename grants_tagger_light/training/dataloaders/{dataloaders.py => mesh_json_loader.py} (58%) diff --git a/grants_tagger_light/training/dataloaders/__init__.py b/grants_tagger_light/training/dataloaders/__init__.py index 63f7aadd..7665d628 100644 --- a/grants_tagger_light/training/dataloaders/__init__.py +++ b/grants_tagger_light/training/dataloaders/__init__.py @@ -1,3 +1,3 @@ -from .dataloaders import load_grants_sample, load_mesh_json +from .mesh_json_loader import load_mesh_json -__all__ = ["load_grants_sample", "load_mesh_json"] +__all__ = ["load_mesh_json"] diff --git a/grants_tagger_light/training/dataloaders/dataloaders.py b/grants_tagger_light/training/dataloaders/mesh_json_loader.py similarity index 58% rename from grants_tagger_light/training/dataloaders/dataloaders.py rename to grants_tagger_light/training/dataloaders/mesh_json_loader.py index a993c2f6..f1a34b2a 100644 --- a/grants_tagger_light/training/dataloaders/dataloaders.py +++ b/grants_tagger_light/training/dataloaders/mesh_json_loader.py @@ -37,78 +37,6 @@ def _one_hot(batch, label2id: dict): return batch -def load_grants_sample( - data_path: str, - tokenizer: AutoTokenizer, - label2id: dict, - test_size: float = 0.1, - num_proc: int = 8, - max_samples: int = np.inf, -): - """ - Code that loads a grants sample. - The data should be a jsonl file where each line contains an abstract - and mesh_terms field. - The dvc pipeline in pipelines/generate_grants can be used for this. - It will populate the mesh_terms field with predictions - from Wellcome/WellcomeBertMesh. This can be used to generate a - dummy dataset (i.e. train the model on its own predictions - for development / sanity check purposes). - """ - - def _datagen(data_path: str, max_samples: int = np.inf): - """ - Loads the data from the given path. The data should be in jsonl format, - with each line containing a text and tags field. - The tags field should be a list of strings. - """ - with open(data_path, "r") as f: - for idx, line in f: - sample = json.loads(line) - - if idx > max_samples: - break - - yield sample - - dset = Dataset.from_generator( - _datagen, - gen_kwargs={"data_path": data_path, "max_samples": max_samples}, - ) - - dset = dset.map( - _tokenize, - batched=True, - batch_size=32, - num_proc=num_proc, - desc="Tokenizing", - fn_kwargs={"tokenizer": tokenizer, "x_col": "abstract"}, - ) - - dset = dset.map( - _label_encode, - batched=True, - batch_size=32, - num_proc=1, - desc="Encoding labels", - fn_kwargs={"mesh_terms_column": "mesh_terms", "label2id": label2id}, - ) - - dset = dset.map( - _one_hot, - batched=True, - batch_size=32, - num_proc=num_proc, - desc="One-hot labels", - fn_kwargs={"label2id": label2id}, - ) - - # Split into train and test - dset = dset.train_test_split(test_size=test_size) - - return dset["train"], dset["test"] - - def load_mesh_json( data_path: str, tokenizer: AutoTokenizer, From 301595b9627ec3c11cf78bb403f5b5aa255d5b27 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Mon, 10 Jul 2023 17:54:21 +0300 Subject: [PATCH 029/300] add training pipeline with proper params & dependencies --- pipelines/bertmesh/dvc.yaml | 77 +++++++------------------------- pipelines/bertmesh/params.yaml | 16 ------- scripts/train_bertmesh_script.py | 25 +++++++++++ 3 files changed, 41 insertions(+), 77 deletions(-) delete mode 100644 pipelines/bertmesh/params.yaml create mode 100644 scripts/train_bertmesh_script.py diff --git a/pipelines/bertmesh/dvc.yaml b/pipelines/bertmesh/dvc.yaml index ded61706..684755e6 100644 --- a/pipelines/bertmesh/dvc.yaml +++ b/pipelines/bertmesh/dvc.yaml @@ -1,68 +1,23 @@ vars: - - data: "../../data/" - - models: "../../models/" - - results: "../../results" - - experiment_name: "transformers-bertmesh" - - scripts_folder: "../../grants_tagger/bertmesh" + - model_key: "Wellcome/WellcomeBertMesh" + - data_path: "../../data/raw/allMeSH_2021.json" + - scripts_folder: "../../scripts" + - output_dir: "../../bertmesh_outs/pipeline" stages: - prepare_data: - cmd: | - python ${scripts_folder}/prepare_data.py ${data}/processed/train_mesh2021.jsonl ${data}/processed/bertmesh/X.npy ${data}/processed/bertmesh/Y.npz ${models}/bertmesh/label_binarizer.pkl --years ${prepare_data.years} --pretrained-model ${train.pretrained_model} - python ${scripts_folder}/prepare_data.py ${data}/processed/train_mesh2021.jsonl ${data}/processed/bertmesh/X_test.npy ${data}/processed/bertmesh/Y_test.npz ${models}/bertmesh/label_binarizer.pkl --years ${prepare_data.test_years} --pretrained-model ${train.pretrained_model} - deps: - - ${scripts_folder}/prepare_data.py - params: - - prepare_data.years - - train.pretrained_model - outs: - - ${data}/processed/bertmesh/X.npy - - ${data}/processed/bertmesh/Y.npz - - ${data}/processed/bertmesh/X_test.npy - - ${data}/processed/bertmesh/Y_test.npz - - ${models}/bertmesh/label_binarizer.pkl train: - cmd: python ${scripts_folder}/train_torch.py ${data}/processed/bertmesh/X.npy ${data}/processed/bertmesh/Y.npz ${models}/bertmesh/model/ ${models}/bertmesh/label_binarizer.pkl - --train-info ${results}/bertmesh_train_info.json - --learning-rate ${train.learning_rate} --batch-size ${train.batch_size} --epochs ${train.epochs} - --pretrained-model ${train.pretrained_model} --multilabel-attention --hidden-size ${train.hidden_size} - --clip-norm ${train.clip_norm} --dropout ${train.dropout} --train-metrics-path train_metrics.json - --warmup-steps ${train.warmup_steps} --val-x-path ${data}/processed/bertmesh/X_test.npy - --val-y-path ${data}/processed/bertmesh/Y_test.npz --experiment-name ${experiment_name} + cmd: >- + python ${scripts_folder}/train_bertmesh_script.py + --model_key ${model_key} + --data_path ${data_path} + --output_dir ${output_dir} + --max_samples 10 + --per_device_train_batch_size 4 + --per_device_eval_batch_size 4 deps: - - ${scripts_folder}/train_torch.py - - ${scripts_folder}/model.py - - ${data}/processed/bertmesh/X.npy - - ${data}/processed/bertmesh/Y.npz - - ${models}/bertmesh/label_binarizer.pkl + - ${data_path} params: - - train.learning_rate - - train.epochs - - train.batch_size - - train.pretrained_model - - train.hidden_size - - train.clip_norm - - train.dropout - - train.warmup_steps + - ../../grants_tagger_light/training/cli_args/train_args.py: + - BertMeshTrainingArguments outs: - - ${models}/bertmesh/model/pytorch_model.bin - - ${models}/bertmesh/model/config.json - plots: - - train_metrics.json: - cache: false - evaluate: - cmd: python ${scripts_folder}/evaluate.py ${data}/processed/bertmesh/X_test.npy ${data}/processed/bertmesh/Y_test.npz ${models}/bertmesh/model/ --batch-size ${evaluate.batch_size} --results-path results.json --pr-curve-path pr_curve.json --experiment-name ${experiment_name} - deps: - - ${scripts_folder}/evaluate.py - - ${data}/processed/bertmesh/X_test.npy - - ${data}/processed/bertmesh/Y_test.npz - - ${models}/bertmesh/model/pytorch_model.bin - - ${models}/bertmesh/model/config.json - params: - - train.batch_size - metrics: - - results.json: - cache: false - plots: - - pr_curve.json: - cache: false + - ${output_dir}/best diff --git a/pipelines/bertmesh/params.yaml b/pipelines/bertmesh/params.yaml deleted file mode 100644 index a63dc724..00000000 --- a/pipelines/bertmesh/params.yaml +++ /dev/null @@ -1,16 +0,0 @@ -prepare_data: - years: 2016,2019 - test_years: 2020,2021 - -train: - learning_rate: 5e-5 - batch_size: 64 # with nn.DataParallel divide by 8 - epochs: 5 - pretrained_model: microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract - hidden_size: 1024 - clip_norm: 5 - dropout: 0.1 - warmup_steps: 1000 - -evaluate: - batch_size: 8 diff --git a/scripts/train_bertmesh_script.py b/scripts/train_bertmesh_script.py new file mode 100644 index 00000000..66e2243e --- /dev/null +++ b/scripts/train_bertmesh_script.py @@ -0,0 +1,25 @@ +import numpy as np +from grants_tagger_light.training.train import train_bertmesh +from grants_tagger_light.training.cli_args import BertMeshTrainingArguments +from transformers import HfArgumentParser +from dataclasses import dataclass + + +@dataclass +class TrainFuncArgs: + model_key: str + data_path: str + max_samples: int = np.inf + + +if __name__ == "__main__": + func_args, training_args = HfArgumentParser( + (TrainFuncArgs, BertMeshTrainingArguments) + ).parse_args_into_dataclasses() + + train_bertmesh( + func_args.model_key, + func_args.data_path, + func_args.max_samples, + training_args, + ) From c395c7a547e4f10e34f062f59022b0381e4ac829 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Mon, 10 Jul 2023 18:08:06 +0300 Subject: [PATCH 030/300] update output path --- .gitignore | 2 +- grants_tagger_light/training/cli_args/train_args.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 4fa1d780..22124d82 100644 --- a/.gitignore +++ b/.gitignore @@ -161,5 +161,5 @@ cython_debug/ # Folder where training outputs are stored -bertmesh-outs/ +bertmesh_outs/ wandb/ diff --git a/grants_tagger_light/training/cli_args/train_args.py b/grants_tagger_light/training/cli_args/train_args.py index 2bb1dbd0..39ebf02c 100644 --- a/grants_tagger_light/training/cli_args/train_args.py +++ b/grants_tagger_light/training/cli_args/train_args.py @@ -9,7 +9,7 @@ class BertMeshTrainingArguments(TrainingArguments): and implements some better defaults for convenience. """ - output_dir: str = field(default="bertmesh-outs/default") + output_dir: str = field(default="bertmesh_outs/default") overwrite_output_dir: bool = field(default=True) evaluation_strategy: str = field(default="epoch") # no | epoch | steps From 38730aa619c616d412b1e7201c9ab7188f07d1e9 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Mon, 10 Jul 2023 18:09:42 +0300 Subject: [PATCH 031/300] do train pipeline smoke test --- pipelines/bertmesh/dvc.lock | 19 +++++++++++++++++++ pipelines/bertmesh/dvc.yaml | 2 +- 2 files changed, 20 insertions(+), 1 deletion(-) create mode 100644 pipelines/bertmesh/dvc.lock diff --git a/pipelines/bertmesh/dvc.lock b/pipelines/bertmesh/dvc.lock new file mode 100644 index 00000000..7d4c6990 --- /dev/null +++ b/pipelines/bertmesh/dvc.lock @@ -0,0 +1,19 @@ +schema: '2.0' +stages: + train: + cmd: python ../../scripts/train_bertmesh_script.py --model_key Wellcome/WellcomeBertMesh + --data_path ../../data/raw/allMeSH_2021.json --output_dir ../../bertmesh_outs/pipeline_test + --max_samples 10 --per_device_train_batch_size 4 --per_device_eval_batch_size + 4 + deps: + - path: ../../data/raw/allMeSH_2021.json + md5: e827a6b8062d1312664dcf075c12d89f + size: 27547042745 + params: + ../../grants_tagger_light/training/cli_args/train_args.py: + BertMeshTrainingArguments: {} + outs: + - path: ../../bertmesh_outs/pipeline_test/best + md5: 734a4e2eba3208449473dfa6cbee4683.dir + size: 649161367 + nfiles: 3 diff --git a/pipelines/bertmesh/dvc.yaml b/pipelines/bertmesh/dvc.yaml index 684755e6..699455d3 100644 --- a/pipelines/bertmesh/dvc.yaml +++ b/pipelines/bertmesh/dvc.yaml @@ -2,7 +2,7 @@ vars: - model_key: "Wellcome/WellcomeBertMesh" - data_path: "../../data/raw/allMeSH_2021.json" - scripts_folder: "../../scripts" - - output_dir: "../../bertmesh_outs/pipeline" + - output_dir: "../../bertmesh_outs/pipeline_test" stages: train: From d670ee5937c805e3571492fbe0e89721fc43f23b Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 11 Jul 2023 12:25:59 +0300 Subject: [PATCH 032/300] change script that is being run in dvc pipeline --- grants_tagger_light/training/train.py | 19 ++++++++++++++++++- pipelines/bertmesh/dvc.lock | 4 ++-- pipelines/bertmesh/dvc.yaml | 4 ++-- scripts/train_bertmesh_script.py | 25 ------------------------- 4 files changed, 22 insertions(+), 30 deletions(-) delete mode 100644 scripts/train_bertmesh_script.py diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index a9ca0e42..2c69ed2f 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -108,4 +108,21 @@ def train_bertmesh_cli( if __name__ == "__main__": - train_app() + from dataclasses import dataclass + + @dataclass + class TrainFuncArgs: + model_key: str + data_path: str + max_samples: int = np.inf + + func_args, training_args = HfArgumentParser( + (TrainFuncArgs, BertMeshTrainingArguments) + ).parse_args_into_dataclasses() + + train_bertmesh( + func_args.model_key, + func_args.data_path, + func_args.max_samples, + training_args, + ) diff --git a/pipelines/bertmesh/dvc.lock b/pipelines/bertmesh/dvc.lock index 7d4c6990..06feb65a 100644 --- a/pipelines/bertmesh/dvc.lock +++ b/pipelines/bertmesh/dvc.lock @@ -1,7 +1,7 @@ schema: '2.0' stages: train: - cmd: python ../../scripts/train_bertmesh_script.py --model_key Wellcome/WellcomeBertMesh + cmd: python ../../grants_tagger_light/training/train.py --model_key Wellcome/WellcomeBertMesh --data_path ../../data/raw/allMeSH_2021.json --output_dir ../../bertmesh_outs/pipeline_test --max_samples 10 --per_device_train_batch_size 4 --per_device_eval_batch_size 4 @@ -14,6 +14,6 @@ stages: BertMeshTrainingArguments: {} outs: - path: ../../bertmesh_outs/pipeline_test/best - md5: 734a4e2eba3208449473dfa6cbee4683.dir + md5: dfc280687bcbda1eebfcf1ea68662d89.dir size: 649161367 nfiles: 3 diff --git a/pipelines/bertmesh/dvc.yaml b/pipelines/bertmesh/dvc.yaml index 699455d3..1721d4c4 100644 --- a/pipelines/bertmesh/dvc.yaml +++ b/pipelines/bertmesh/dvc.yaml @@ -1,13 +1,13 @@ vars: - model_key: "Wellcome/WellcomeBertMesh" - data_path: "../../data/raw/allMeSH_2021.json" - - scripts_folder: "../../scripts" + - script_loc: "../../grants_tagger_light/training" - output_dir: "../../bertmesh_outs/pipeline_test" stages: train: cmd: >- - python ${scripts_folder}/train_bertmesh_script.py + python ${script_loc}/train.py --model_key ${model_key} --data_path ${data_path} --output_dir ${output_dir} diff --git a/scripts/train_bertmesh_script.py b/scripts/train_bertmesh_script.py deleted file mode 100644 index 66e2243e..00000000 --- a/scripts/train_bertmesh_script.py +++ /dev/null @@ -1,25 +0,0 @@ -import numpy as np -from grants_tagger_light.training.train import train_bertmesh -from grants_tagger_light.training.cli_args import BertMeshTrainingArguments -from transformers import HfArgumentParser -from dataclasses import dataclass - - -@dataclass -class TrainFuncArgs: - model_key: str - data_path: str - max_samples: int = np.inf - - -if __name__ == "__main__": - func_args, training_args = HfArgumentParser( - (TrainFuncArgs, BertMeshTrainingArguments) - ).parse_args_into_dataclasses() - - train_bertmesh( - func_args.model_key, - func_args.data_path, - func_args.max_samples, - training_args, - ) From b28e868de44442e342b8baa7065d65e29f69c164 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 11 Jul 2023 13:23:16 +0300 Subject: [PATCH 033/300] change optimizer to adamw_torch when no gpus are detected --- grants_tagger_light/training/cli_args/train_args.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/grants_tagger_light/training/cli_args/train_args.py b/grants_tagger_light/training/cli_args/train_args.py index 39ebf02c..61820322 100644 --- a/grants_tagger_light/training/cli_args/train_args.py +++ b/grants_tagger_light/training/cli_args/train_args.py @@ -1,5 +1,6 @@ from transformers import TrainingArguments from dataclasses import dataclass, field +import torch @dataclass @@ -44,7 +45,7 @@ class BertMeshTrainingArguments(TrainingArguments): optim: str = field( default="adamw_torch_fused" - ) # TODO add support for adamw_apex_fused + ) # TODO add support for adamw_apex_fused; use adamw_anyprecision if using bf16 fp16: bool = field(default=False) # TODO test if micro-f1 is maintained @@ -60,3 +61,8 @@ class BertMeshTrainingArguments(TrainingArguments): # torch_compile_mode: str = field( # default="default" # ) # default | reduce-overhead | max-autotune + + def __post_init__(self): + super().__post_init__() + if "fused" in self.optim and not torch.cuda.is_available(): + self.optim = "adamw_torch" From 2530928af8dd4a3cfa648c4671fb39cf14bc1044 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 11 Jul 2023 14:11:35 +0300 Subject: [PATCH 034/300] rename sample tags to example tags --- grants_tagger_light/training/dataloaders/mesh_json_loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/training/dataloaders/mesh_json_loader.py b/grants_tagger_light/training/dataloaders/mesh_json_loader.py index f1a34b2a..9986ea07 100644 --- a/grants_tagger_light/training/dataloaders/mesh_json_loader.py +++ b/grants_tagger_light/training/dataloaders/mesh_json_loader.py @@ -17,9 +17,9 @@ def _tokenize(batch, tokenizer: AutoTokenizer, x_col: str): def _label_encode(batch, mesh_terms_column: str, label2id: dict): batch_labels = [] - for sample_tags in batch[mesh_terms_column]: + for example_tags in batch[mesh_terms_column]: sample_labels = [] - for tag in sample_tags: + for tag in example_tags: if tag in label2id: sample_labels.append(label2id[tag]) batch_labels.append(sample_labels) From 4fdad8469dadf46c019b651473d5d3fb814908dc Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 11 Jul 2023 14:40:57 +0300 Subject: [PATCH 035/300] load model from hub locally to be able to run without hub access --- grants_tagger_light/models/bert_mesh/model.py | 4 +++ models/.gitignore | 1 + pipelines/save_local_from_hub/dvc.lock | 10 ++++++ pipelines/save_local_from_hub/dvc.yaml | 13 ++++++++ scripts/save_local_from_hub.py | 33 +++++++++++++++++++ 5 files changed, 61 insertions(+) create mode 100644 pipelines/save_local_from_hub/dvc.lock create mode 100644 pipelines/save_local_from_hub/dvc.yaml create mode 100644 scripts/save_local_from_hub.py diff --git a/grants_tagger_light/models/bert_mesh/model.py b/grants_tagger_light/models/bert_mesh/model.py index 48613074..249e1066 100644 --- a/grants_tagger_light/models/bert_mesh/model.py +++ b/grants_tagger_light/models/bert_mesh/model.py @@ -25,6 +25,10 @@ def __init__( config, ): super().__init__(config=config) + import pdb + + pdb.set_trace() + self.config.auto_map = {"AutoModel": "model.BertMesh"} self.pretrained_model = self.config.pretrained_model self.num_labels = self.config.num_labels diff --git a/models/.gitignore b/models/.gitignore index 7afd2f42..78b3fe46 100644 --- a/models/.gitignore +++ b/models/.gitignore @@ -1,2 +1,3 @@ /xlinear-0.2.5 bertmesh/model/epoch-3 +/WellcomeBertMesh-fromhub-11-07-2023 diff --git a/pipelines/save_local_from_hub/dvc.lock b/pipelines/save_local_from_hub/dvc.lock new file mode 100644 index 00000000..3c68ed40 --- /dev/null +++ b/pipelines/save_local_from_hub/dvc.lock @@ -0,0 +1,10 @@ +schema: '2.0' +stages: + save_from_hub: + cmd: python ../../scripts/save_local_from_hub.py --key Wellcome/WellcomeBertMesh + --out_path ../../models/WellcomeBertMesh-fromhub-11-07-2023 + outs: + - path: ../../models/WellcomeBertMesh-fromhub-11-07-2023 + md5: 918d408f3f0123641a0c989b78a0efde.dir + size: 650097867 + nfiles: 6 diff --git a/pipelines/save_local_from_hub/dvc.yaml b/pipelines/save_local_from_hub/dvc.yaml new file mode 100644 index 00000000..3b2c2588 --- /dev/null +++ b/pipelines/save_local_from_hub/dvc.yaml @@ -0,0 +1,13 @@ +vars: + - model_key: "Wellcome/WellcomeBertMesh" + - out_path: "../../models/WellcomeBertMesh-fromhub-11-07-2023" + - script_loc: "../../scripts" + +stages: + save_from_hub: + cmd: >- + python ${script_loc}/save_local_from_hub.py + --key ${model_key} + --out_path ${out_path} + outs: + - ${out_path} diff --git a/scripts/save_local_from_hub.py b/scripts/save_local_from_hub.py new file mode 100644 index 00000000..421db1c5 --- /dev/null +++ b/scripts/save_local_from_hub.py @@ -0,0 +1,33 @@ +from transformers import AutoModel +from transformers import AutoTokenizer +import argparse + + +def save_to_local_from_hub(key: str, out_path: str): + model = AutoModel.from_pretrained(key, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(key) + + # Save model and tokenizer to file + model.save_pretrained(out_path) + tokenizer.save_pretrained(out_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--key", + type=str, + help="The key of the model to save locally", + default="Wellcome/WellcomeBertMesh", + ) + parser.add_argument( + "--out_path", + type=str, + help="The path to save the model to", + default="models/WellcomeBertMesh-fromhub", + ) + args = parser.parse_args() + + print(args) + + save_to_local_from_hub(args.key, args.out_path) From d174f006c4909cb8b2e009d86a04b524ded45a3f Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 11 Jul 2023 15:29:41 +0300 Subject: [PATCH 036/300] remove redundant stuff --- pipelines/save_local_from_hub/dvc.lock | 10 -------- pipelines/save_local_from_hub/dvc.yaml | 13 ---------- scripts/save_local_from_hub.py | 33 -------------------------- 3 files changed, 56 deletions(-) delete mode 100644 pipelines/save_local_from_hub/dvc.lock delete mode 100644 pipelines/save_local_from_hub/dvc.yaml delete mode 100644 scripts/save_local_from_hub.py diff --git a/pipelines/save_local_from_hub/dvc.lock b/pipelines/save_local_from_hub/dvc.lock deleted file mode 100644 index 3c68ed40..00000000 --- a/pipelines/save_local_from_hub/dvc.lock +++ /dev/null @@ -1,10 +0,0 @@ -schema: '2.0' -stages: - save_from_hub: - cmd: python ../../scripts/save_local_from_hub.py --key Wellcome/WellcomeBertMesh - --out_path ../../models/WellcomeBertMesh-fromhub-11-07-2023 - outs: - - path: ../../models/WellcomeBertMesh-fromhub-11-07-2023 - md5: 918d408f3f0123641a0c989b78a0efde.dir - size: 650097867 - nfiles: 6 diff --git a/pipelines/save_local_from_hub/dvc.yaml b/pipelines/save_local_from_hub/dvc.yaml deleted file mode 100644 index 3b2c2588..00000000 --- a/pipelines/save_local_from_hub/dvc.yaml +++ /dev/null @@ -1,13 +0,0 @@ -vars: - - model_key: "Wellcome/WellcomeBertMesh" - - out_path: "../../models/WellcomeBertMesh-fromhub-11-07-2023" - - script_loc: "../../scripts" - -stages: - save_from_hub: - cmd: >- - python ${script_loc}/save_local_from_hub.py - --key ${model_key} - --out_path ${out_path} - outs: - - ${out_path} diff --git a/scripts/save_local_from_hub.py b/scripts/save_local_from_hub.py deleted file mode 100644 index 421db1c5..00000000 --- a/scripts/save_local_from_hub.py +++ /dev/null @@ -1,33 +0,0 @@ -from transformers import AutoModel -from transformers import AutoTokenizer -import argparse - - -def save_to_local_from_hub(key: str, out_path: str): - model = AutoModel.from_pretrained(key, trust_remote_code=True) - tokenizer = AutoTokenizer.from_pretrained(key) - - # Save model and tokenizer to file - model.save_pretrained(out_path) - tokenizer.save_pretrained(out_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--key", - type=str, - help="The key of the model to save locally", - default="Wellcome/WellcomeBertMesh", - ) - parser.add_argument( - "--out_path", - type=str, - help="The path to save the model to", - default="models/WellcomeBertMesh-fromhub", - ) - args = parser.parse_args() - - print(args) - - save_to_local_from_hub(args.key, args.out_path) From 29920932ace9bdb7d9833211a417100dbacfe199 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 11 Jul 2023 15:35:25 +0300 Subject: [PATCH 037/300] add possibility to instantiate from scratch by passing empty key --- grants_tagger_light/training/train.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 2c69ed2f..8fb4a9d7 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -4,6 +4,7 @@ TrainingArguments, EvalPrediction, HfArgumentParser, + AutoConfig, ) from grants_tagger_light.models.bert_mesh import BertMesh from grants_tagger_light.training.cli_args import ( @@ -31,7 +32,13 @@ def train_bertmesh( max_samples: int, training_args: TrainingArguments, ): - model = BertMesh.from_pretrained(model_key, trust_remote_code=True) + if not model_key: + pretrained_model_key = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract" + config = AutoConfig.from_pretrained(pretrained_model_key) + model = BertMesh(config) + else: + model = BertMesh.from_pretrained(model_key, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_key) label2id = {v: k for k, v in model.id2label.items()} From b786eb4b5b97afac983cc33dbd86cb63929a084a Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 11 Jul 2023 15:37:58 +0300 Subject: [PATCH 038/300] remove redundant pdb call --- grants_tagger_light/models/bert_mesh/model.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/grants_tagger_light/models/bert_mesh/model.py b/grants_tagger_light/models/bert_mesh/model.py index 249e1066..47dc6d6f 100644 --- a/grants_tagger_light/models/bert_mesh/model.py +++ b/grants_tagger_light/models/bert_mesh/model.py @@ -25,9 +25,6 @@ def __init__( config, ): super().__init__(config=config) - import pdb - - pdb.set_trace() self.config.auto_map = {"AutoModel": "model.BertMesh"} self.pretrained_model = self.config.pretrained_model From 167088f61e6087997d1504b5133b08b85cdc41a9 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 11 Jul 2023 17:57:49 +0300 Subject: [PATCH 039/300] reproduce model from scratch in pipeline --- .../training/cli_args/__init__.py | 7 +- .../training/cli_args/bertmesh_args.py | 11 +++ .../training/dataloaders/mesh_json_loader.py | 15 +++- grants_tagger_light/training/train.py | 69 ++++++++++++++----- pipelines/bertmesh/dvc.lock | 8 +-- pipelines/bertmesh/dvc.yaml | 3 +- tests/test_train.py | 4 ++ 7 files changed, 93 insertions(+), 24 deletions(-) create mode 100644 grants_tagger_light/training/cli_args/bertmesh_args.py diff --git a/grants_tagger_light/training/cli_args/__init__.py b/grants_tagger_light/training/cli_args/__init__.py index 31771867..30ffb4a7 100644 --- a/grants_tagger_light/training/cli_args/__init__.py +++ b/grants_tagger_light/training/cli_args/__init__.py @@ -1,4 +1,9 @@ from .train_args import BertMeshTrainingArguments from .wandb_args import WandbArguments +from .bertmesh_args import BertMeshModelArguments -__all__ = ["BertMeshTrainingArguments", "WandbArguments"] +__all__ = [ + "BertMeshTrainingArguments", + "WandbArguments", + "BertMeshModelArguments", +] diff --git a/grants_tagger_light/training/cli_args/bertmesh_args.py b/grants_tagger_light/training/cli_args/bertmesh_args.py new file mode 100644 index 00000000..32797f8d --- /dev/null +++ b/grants_tagger_light/training/cli_args/bertmesh_args.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass, field + + +@dataclass +class BertMeshModelArguments: + pretrained_model_key: str = field( + default="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract" + ) + hidden_size: int = field(default=512) + dropout: float = field(default=0) + multilabel_attention: bool = field(default=False) diff --git a/grants_tagger_light/training/dataloaders/mesh_json_loader.py b/grants_tagger_light/training/dataloaders/mesh_json_loader.py index 9986ea07..cc0aff52 100644 --- a/grants_tagger_light/training/dataloaders/mesh_json_loader.py +++ b/grants_tagger_light/training/dataloaders/mesh_json_loader.py @@ -37,6 +37,15 @@ def _one_hot(batch, label2id: dict): return batch +def _get_label2id(dset): + label_set = set() + for sample in dset: + for label in sample["meshMajor"]: + label_set.add(label) + label2id = {label: idx for idx, label in enumerate(label_set)} + return label2id + + def load_mesh_json( data_path: str, tokenizer: AutoTokenizer, @@ -72,6 +81,10 @@ def _datagen(mesh_json_path: str, max_samples: int = np.inf): fn_kwargs={"tokenizer": tokenizer, "x_col": "abstractText"}, ) + # Generate label2id if None + if label2id is None: + label2id = _get_label2id(dset) + dset = dset.map( _label_encode, batched=True, @@ -93,7 +106,7 @@ def _datagen(mesh_json_path: str, max_samples: int = np.inf): # Split into train and test dset = dset.train_test_split(test_size=test_size) - return dset["train"], dset["test"] + return dset["train"], dset["test"], label2id if __name__ == "__main__": diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 8fb4a9d7..fe9db2c9 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -10,6 +10,7 @@ from grants_tagger_light.training.cli_args import ( BertMeshTrainingArguments, WandbArguments, + BertMeshModelArguments, ) from grants_tagger_light.training.dataloaders import ( load_mesh_json, @@ -31,24 +32,53 @@ def train_bertmesh( data_path: str, max_samples: int, training_args: TrainingArguments, + model_args: BertMeshModelArguments = None, ): if not model_key: - pretrained_model_key = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract" - config = AutoConfig.from_pretrained(pretrained_model_key) + assert isinstance( + model_args, BertMeshModelArguments + ), "If model_key is not provided, must provide model_args of type BertMeshModelArguments" # noqa + + logger.info("No model key provided. Training model from scratch") + + # Instantiate model from scratch + config = AutoConfig.from_pretrained(model_args.pretrained_model_key) + tokenizer = AutoTokenizer.from_pretrained(model_args.pretrained_model_key) + + train_dset, val_dset, label2id = load_mesh_json( + data_path, + tokenizer=tokenizer, + label2id=None, + max_samples=max_samples, + ) + + config.update( + { + "pretrained_model": model_args.pretrained_model_key, + "num_labels": len(label2id), + "hidden_size": model_args.hidden_size, + "dropout": model_args.dropout, + "multilabel_attention": model_args.multilabel_attention, + "id2label": {v: k for k, v in label2id.items()}, + } + ) model = BertMesh(config) + else: - model = BertMesh.from_pretrained(model_key, trust_remote_code=True) + logger.info(f"Training model from pretrained key {model_key}") - tokenizer = AutoTokenizer.from_pretrained(model_key) + # Instantiate from pretrained + model = BertMesh.from_pretrained(model_key, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_key) - label2id = {v: k for k, v in model.id2label.items()} + label2id = {v: k for k, v in model.id2label.items()} - train_dset, val_dset = load_mesh_json( - data_path, - tokenizer=tokenizer, - label2id=label2id, - max_samples=max_samples, - ) + train_dset, val_dset, _ = load_mesh_json( + data_path, + tokenizer=tokenizer, + label2id=label2id, + max_samples=max_samples, + ) def sklearn_metrics(prediction: EvalPrediction): y_pred = prediction.predictions @@ -105,13 +135,19 @@ def train_bertmesh_cli( help="Maximum number of samples to use for training. Useful for dev/debugging", ), ): - parser = HfArgumentParser((BertMeshTrainingArguments, WandbArguments)) - (training_args, wandb_args) = parser.parse_args_into_dataclasses(ctx.args) + parser = HfArgumentParser( + (BertMeshTrainingArguments, WandbArguments, BertMeshModelArguments) + ) + ( + training_args, + wandb_args, + model_args, + ) = parser.parse_args_into_dataclasses(ctx.args) logger.info("Training args: {}".format(pformat(training_args))) logger.info("Wandb args: {}".format(pformat(wandb_args))) - train_bertmesh(model_key, data_path, max_samples, training_args) + train_bertmesh(model_key, data_path, max_samples, training_args, model_args) if __name__ == "__main__": @@ -123,8 +159,8 @@ class TrainFuncArgs: data_path: str max_samples: int = np.inf - func_args, training_args = HfArgumentParser( - (TrainFuncArgs, BertMeshTrainingArguments) + func_args, training_args, model_args = HfArgumentParser( + (TrainFuncArgs, BertMeshTrainingArguments, BertMeshModelArguments) ).parse_args_into_dataclasses() train_bertmesh( @@ -132,4 +168,5 @@ class TrainFuncArgs: func_args.data_path, func_args.max_samples, training_args, + model_args, ) diff --git a/pipelines/bertmesh/dvc.lock b/pipelines/bertmesh/dvc.lock index 06feb65a..85a33960 100644 --- a/pipelines/bertmesh/dvc.lock +++ b/pipelines/bertmesh/dvc.lock @@ -1,8 +1,8 @@ schema: '2.0' stages: train: - cmd: python ../../grants_tagger_light/training/train.py --model_key Wellcome/WellcomeBertMesh - --data_path ../../data/raw/allMeSH_2021.json --output_dir ../../bertmesh_outs/pipeline_test + cmd: python ../../grants_tagger_light/training/train.py --model_key "" --data_path + ../../data/raw/allMeSH_2021.json --output_dir ../../bertmesh_outs/pipeline_test --max_samples 10 --per_device_train_batch_size 4 --per_device_eval_batch_size 4 deps: @@ -14,6 +14,6 @@ stages: BertMeshTrainingArguments: {} outs: - path: ../../bertmesh_outs/pipeline_test/best - md5: dfc280687bcbda1eebfcf1ea68662d89.dir - size: 649161367 + md5: b96ff54ecd600460dbb18b3e82d8b517.dir + size: 439905869 nfiles: 3 diff --git a/pipelines/bertmesh/dvc.yaml b/pipelines/bertmesh/dvc.yaml index 1721d4c4..f302c669 100644 --- a/pipelines/bertmesh/dvc.yaml +++ b/pipelines/bertmesh/dvc.yaml @@ -1,5 +1,4 @@ vars: - - model_key: "Wellcome/WellcomeBertMesh" - data_path: "../../data/raw/allMeSH_2021.json" - script_loc: "../../grants_tagger_light/training" - output_dir: "../../bertmesh_outs/pipeline_test" @@ -8,7 +7,7 @@ stages: train: cmd: >- python ${script_loc}/train.py - --model_key ${model_key} + --model_key "" --data_path ${data_path} --output_dir ${output_dir} --max_samples 10 diff --git a/tests/test_train.py b/tests/test_train.py index 2d548b70..53ea9c0c 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -1,4 +1,5 @@ from grants_tagger_light.training.train import train_bertmesh +from grants_tagger_light.training.cli_args import BertMeshModelArguments from transformers import TrainingArguments import tempfile import pytest @@ -44,9 +45,12 @@ def test_train_bertmesh(data_path, save_path): no_cuda=True, ) + model_args = BertMeshModelArguments() + train_bertmesh( model_key=model_key, data_path=data_path, max_samples=np.inf, training_args=training_args, + model_args=model_args, ) From a1982b6290ca1b43f5ca1b84069a0062c1a331a5 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 11 Jul 2023 19:34:05 +0300 Subject: [PATCH 040/300] add wandb args to main script invocation --- grants_tagger_light/training/train.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index fe9db2c9..22f74050 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -159,8 +159,13 @@ class TrainFuncArgs: data_path: str max_samples: int = np.inf - func_args, training_args, model_args = HfArgumentParser( - (TrainFuncArgs, BertMeshTrainingArguments, BertMeshModelArguments) + func_args, training_args, wandb_args, model_args = HfArgumentParser( + ( + TrainFuncArgs, + BertMeshTrainingArguments, + WandbArguments, + BertMeshModelArguments, + ) ).parse_args_into_dataclasses() train_bertmesh( From 93ab7d856dddc08613cb571c55d68f9adc59e658 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 11 Jul 2023 19:58:46 +0300 Subject: [PATCH 041/300] use bce with logits in mdoel for numerical stability --- grants_tagger_light/models/bert_mesh/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/grants_tagger_light/models/bert_mesh/model.py b/grants_tagger_light/models/bert_mesh/model.py index 47dc6d6f..ef5e5b6e 100644 --- a/grants_tagger_light/models/bert_mesh/model.py +++ b/grants_tagger_light/models/bert_mesh/model.py @@ -53,16 +53,16 @@ def forward(self, input_ids, labels=None, **kwargs): attention_outs = self.multilabel_attention_layer(hidden_states) outs = torch.nn.functional.relu(self.linear_1(attention_outs)) outs = self.dropout_layer(outs) - outs = torch.sigmoid(self.linear_2(outs)) + outs = self.linear_2(outs) outs = torch.flatten(outs, start_dim=1) else: cls = self.bert(input_ids=input_ids)[1] outs = torch.nn.functional.relu(self.linear_1(cls)) outs = self.dropout_layer(outs) - outs = torch.sigmoid(self.linear_out(outs)) + outs = self.linear_out(outs) if labels is not None: - loss = F.binary_cross_entropy(outs, labels.float()) + loss = F.binary_cross_entropy_with_logits(outs, labels.float()) else: loss = -1 From 38bfc4d74f331bfd25abca93ad515f65c466a52c Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Wed, 12 Jul 2023 13:20:15 +0300 Subject: [PATCH 042/300] removed unused main script invocation --- .../training/dataloaders/mesh_json_loader.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/grants_tagger_light/training/dataloaders/mesh_json_loader.py b/grants_tagger_light/training/dataloaders/mesh_json_loader.py index cc0aff52..57c5c918 100644 --- a/grants_tagger_light/training/dataloaders/mesh_json_loader.py +++ b/grants_tagger_light/training/dataloaders/mesh_json_loader.py @@ -107,18 +107,3 @@ def _datagen(mesh_json_path: str, max_samples: int = np.inf): dset = dset.train_test_split(test_size=test_size) return dset["train"], dset["test"], label2id - - -if __name__ == "__main__": - from transformers import AutoModel - - model = AutoModel.from_pretrained( - "Wellcome/WellcomeBertMesh", trust_remote_code=True - ) - tokenizer = AutoTokenizer.from_pretrained("Wellcome/WellcomeBertMesh") - - dset_train, dset_val = load_mesh_json( - data_path="data/raw/allMeSH_2021.json", - tokenizer=tokenizer, - label2id=model.config.label2id, - ) From a4db56965857ecff92923222723bd2394c70d03b Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Wed, 12 Jul 2023 14:30:51 +0300 Subject: [PATCH 043/300] avoid overflow infinity error --- grants_tagger_light/training/train.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 22f74050..59c5886d 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -131,10 +131,13 @@ def train_bertmesh_cli( help="Path to data in jsonl format. Must contain text and tags field", ), max_samples: int = typer.Option( - np.inf, + -1, help="Maximum number of samples to use for training. Useful for dev/debugging", ), ): + if max_samples == -1: + max_samples = np.inf + parser = HfArgumentParser( (BertMeshTrainingArguments, WandbArguments, BertMeshModelArguments) ) From 6f45bee5fc9cdf564b59b609bdb637afc7e58c40 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Wed, 12 Jul 2023 16:18:44 +0300 Subject: [PATCH 044/300] add ability to freeze model backbone from args --- grants_tagger_light/models/bert_mesh/model.py | 8 ++++++++ grants_tagger_light/training/cli_args/bertmesh_args.py | 1 + grants_tagger_light/training/train.py | 5 +++++ 3 files changed, 14 insertions(+) diff --git a/grants_tagger_light/models/bert_mesh/model.py b/grants_tagger_light/models/bert_mesh/model.py index ef5e5b6e..4acc82d6 100644 --- a/grants_tagger_light/models/bert_mesh/model.py +++ b/grants_tagger_light/models/bert_mesh/model.py @@ -43,6 +43,14 @@ def __init__( self.linear_out = torch.nn.Linear(self.hidden_size, self.num_labels) self.dropout_layer = torch.nn.Dropout(self.dropout) + def freeze_backbone(self): + for param in self.bert.parameters(): + param.requires_grad = False + + def unfreeze_backbone(self): + for param in self.bert.parameters(): + param.requires_grad = True + def forward(self, input_ids, labels=None, **kwargs): if type(input_ids) is list: # coming from tokenizer diff --git a/grants_tagger_light/training/cli_args/bertmesh_args.py b/grants_tagger_light/training/cli_args/bertmesh_args.py index 32797f8d..d1d50331 100644 --- a/grants_tagger_light/training/cli_args/bertmesh_args.py +++ b/grants_tagger_light/training/cli_args/bertmesh_args.py @@ -9,3 +9,4 @@ class BertMeshModelArguments: hidden_size: int = field(default=512) dropout: float = field(default=0) multilabel_attention: bool = field(default=False) + freeze_backbone: bool = field(default=False) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 59c5886d..0baf8d00 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -60,6 +60,7 @@ def train_bertmesh( "dropout": model_args.dropout, "multilabel_attention": model_args.multilabel_attention, "id2label": {v: k for k, v in label2id.items()}, + "freeze_backbone": model_args.freeze_backbone, } ) model = BertMesh(config) @@ -80,6 +81,10 @@ def train_bertmesh( max_samples=max_samples, ) + if model_args.freeze_backbone: + logger.info("Freezing backbone") + model.freeze_backbone() + def sklearn_metrics(prediction: EvalPrediction): y_pred = prediction.predictions y_true = prediction.label_ids From 393944e097c3550ffe6c6d6da25ea494c8f53300 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Wed, 12 Jul 2023 17:20:39 +0300 Subject: [PATCH 045/300] remove unnecessary columns --- grants_tagger_light/training/dataloaders/mesh_json_loader.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/grants_tagger_light/training/dataloaders/mesh_json_loader.py b/grants_tagger_light/training/dataloaders/mesh_json_loader.py index 57c5c918..16ac0a1f 100644 --- a/grants_tagger_light/training/dataloaders/mesh_json_loader.py +++ b/grants_tagger_light/training/dataloaders/mesh_json_loader.py @@ -72,6 +72,9 @@ def _datagen(mesh_json_path: str, max_samples: int = np.inf): gen_kwargs={"mesh_json_path": data_path, "max_samples": max_samples}, ) + # Remove unused columns to save space & time + dset.remove_columns(["journal", "year", "pmid", "title"]) + dset = dset.map( _tokenize, batched=True, From a5f636845916e0ade17dc959bf577bb381d43764 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Wed, 12 Jul 2023 17:32:24 +0300 Subject: [PATCH 046/300] update test --- tests/test_train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_train.py b/tests/test_train.py index 53ea9c0c..c0adc3ea 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -7,9 +7,9 @@ # Note dummy data is not necessarily annotated correctly dummy_data = """{"articles":[ -{"journal":"dummyJournal","meshMajor":["COVID-19","SARS-CoV-2"],"year":"2023","abstractText":"This is an article about coronavirus."}, -{"journal":"dummyJournal","meshMajor":["COVID-19","SARS-CoV-2"],"year":"2023","abstractText":"This is an article about COVID-19."}, -{"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria"}, +{"journal":"dummyJournal","meshMajor":["COVID-19","SARS-CoV-2"],"year":"2023","abstractText":"This is an article about coronavirus.", "title": "article1", "pmid": "pmid1"}, +{"journal":"dummyJournal","meshMajor":["COVID-19","SARS-CoV-2"],"year":"2023","abstractText":"This is an article about COVID-19.", "title": "article2", "pmid": "pmid2"}}, +{"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria", "title": "article3", "pmid": "pmid3"}}, """ # noqa From ef5699ba3342b5b3a19a285245aa8ccea4419e5f Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Wed, 12 Jul 2023 17:34:27 +0300 Subject: [PATCH 047/300] remove abstract text and mesh major cols as they are being processed --- grants_tagger_light/training/dataloaders/mesh_json_loader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/grants_tagger_light/training/dataloaders/mesh_json_loader.py b/grants_tagger_light/training/dataloaders/mesh_json_loader.py index 16ac0a1f..cb5de1df 100644 --- a/grants_tagger_light/training/dataloaders/mesh_json_loader.py +++ b/grants_tagger_light/training/dataloaders/mesh_json_loader.py @@ -82,6 +82,7 @@ def _datagen(mesh_json_path: str, max_samples: int = np.inf): num_proc=num_proc, desc="Tokenizing", fn_kwargs={"tokenizer": tokenizer, "x_col": "abstractText"}, + remove_columns=["abstractText"], ) # Generate label2id if None @@ -95,6 +96,7 @@ def _datagen(mesh_json_path: str, max_samples: int = np.inf): num_proc=num_proc, desc="Encoding labels", fn_kwargs={"mesh_terms_column": "meshMajor", "label2id": label2id}, + remove_columns=["meshMajor"], ) dset = dset.map( @@ -102,7 +104,7 @@ def _datagen(mesh_json_path: str, max_samples: int = np.inf): batched=True, batch_size=32, num_proc=num_proc, - desc="One-hot labels", + desc="One-hot encoding labels", fn_kwargs={"label2id": label2id}, ) From 4e51aef1df1fbebfba2a9895bf27dc41b91ddb61 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Wed, 12 Jul 2023 17:54:12 +0300 Subject: [PATCH 048/300] fix typo --- tests/test_train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_train.py b/tests/test_train.py index c0adc3ea..9742eac9 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -7,9 +7,9 @@ # Note dummy data is not necessarily annotated correctly dummy_data = """{"articles":[ -{"journal":"dummyJournal","meshMajor":["COVID-19","SARS-CoV-2"],"year":"2023","abstractText":"This is an article about coronavirus.", "title": "article1", "pmid": "pmid1"}, -{"journal":"dummyJournal","meshMajor":["COVID-19","SARS-CoV-2"],"year":"2023","abstractText":"This is an article about COVID-19.", "title": "article2", "pmid": "pmid2"}}, -{"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria", "title": "article3", "pmid": "pmid3"}}, +{"journal":"dummyJournal","meshMajor":["COVID-19","SARS-CoV-2"],"year":"2023","abstractText":"This is an article about coronavirus.","title":"article1","pmid":"pmid1"}, +{"journal":"dummyJournal","meshMajor":["COVID-19","SARS-CoV-2"],"year":"2023","abstractText":"This is an article about COVID-19.","title":"article2","pmid":"pmid2"}, +{"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria", "title": "article3", "pmid": "pmid3"}, """ # noqa From 95bd6c3342aa30d736e3f569489ffb99765319c4 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Thu, 13 Jul 2023 13:31:15 +0300 Subject: [PATCH 049/300] convert option to argument to avoid clash --- grants_tagger_light/training/train.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 0baf8d00..d80329f9 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -135,7 +135,7 @@ def train_bertmesh_cli( ..., help="Path to data in jsonl format. Must contain text and tags field", ), - max_samples: int = typer.Option( + max_samples: int = typer.Argument( -1, help="Maximum number of samples to use for training. Useful for dev/debugging", ), @@ -144,7 +144,11 @@ def train_bertmesh_cli( max_samples = np.inf parser = HfArgumentParser( - (BertMeshTrainingArguments, WandbArguments, BertMeshModelArguments) + ( + BertMeshTrainingArguments, + WandbArguments, + BertMeshModelArguments, + ) ) ( training_args, From 62ecfb468523e8035fa79b83af0cd89b3573e2ef Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Thu, 13 Jul 2023 13:48:49 +0300 Subject: [PATCH 050/300] move to more efficient sklearn processor --- .../training/dataloaders/mesh_json_loader.py | 40 +++++-------------- 1 file changed, 9 insertions(+), 31 deletions(-) diff --git a/grants_tagger_light/training/dataloaders/mesh_json_loader.py b/grants_tagger_light/training/dataloaders/mesh_json_loader.py index cb5de1df..4e6d4ef9 100644 --- a/grants_tagger_light/training/dataloaders/mesh_json_loader.py +++ b/grants_tagger_light/training/dataloaders/mesh_json_loader.py @@ -2,6 +2,7 @@ import numpy as np from transformers import AutoTokenizer from datasets import Dataset +from sklearn.preprocessing import MultiLabelBinarizer # TODO refactor the two load funcs into a class @@ -15,25 +16,8 @@ def _tokenize(batch, tokenizer: AutoTokenizer, x_col: str): ) -def _label_encode(batch, mesh_terms_column: str, label2id: dict): - batch_labels = [] - for example_tags in batch[mesh_terms_column]: - sample_labels = [] - for tag in example_tags: - if tag in label2id: - sample_labels.append(label2id[tag]) - batch_labels.append(sample_labels) - - batch["labels"] = batch_labels - - return batch - - -def _one_hot(batch, label2id: dict): - batch["labels"] = [ - [1 if i in labels else 0 for i in range(len(label2id))] - for labels in batch["labels"] - ] +def _binarize_labels(batch, mlb: MultiLabelBinarizer): + batch["mlb_labels"] = mlb.transform(batch["meshMajor"]) return batch @@ -73,7 +57,7 @@ def _datagen(mesh_json_path: str, max_samples: int = np.inf): ) # Remove unused columns to save space & time - dset.remove_columns(["journal", "year", "pmid", "title"]) + dset = dset.remove_columns(["journal", "year", "pmid", "title"]) dset = dset.map( _tokenize, @@ -89,23 +73,17 @@ def _datagen(mesh_json_path: str, max_samples: int = np.inf): if label2id is None: label2id = _get_label2id(dset) - dset = dset.map( - _label_encode, - batched=True, - batch_size=32, - num_proc=num_proc, - desc="Encoding labels", - fn_kwargs={"mesh_terms_column": "meshMajor", "label2id": label2id}, - remove_columns=["meshMajor"], - ) + mlb = MultiLabelBinarizer(classes=list(label2id.keys())) + mlb.fit([list(label2id.keys())]) dset = dset.map( - _one_hot, + _binarize_labels, batched=True, batch_size=32, num_proc=num_proc, desc="One-hot encoding labels", - fn_kwargs={"label2id": label2id}, + fn_kwargs={"mlb": mlb}, + remove_columns=["meshMajor"], ) # Split into train and test From b13cb5f5ede95b7855349e339e31629040d0943d Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Thu, 13 Jul 2023 14:08:26 +0300 Subject: [PATCH 051/300] rename column for compatibility --- grants_tagger_light/training/dataloaders/mesh_json_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grants_tagger_light/training/dataloaders/mesh_json_loader.py b/grants_tagger_light/training/dataloaders/mesh_json_loader.py index 4e6d4ef9..2f2a24de 100644 --- a/grants_tagger_light/training/dataloaders/mesh_json_loader.py +++ b/grants_tagger_light/training/dataloaders/mesh_json_loader.py @@ -17,7 +17,7 @@ def _tokenize(batch, tokenizer: AutoTokenizer, x_col: str): def _binarize_labels(batch, mlb: MultiLabelBinarizer): - batch["mlb_labels"] = mlb.transform(batch["meshMajor"]) + batch["labels"] = mlb.transform(batch["meshMajor"]) return batch From 8cf6b1f5eabc270d2c42fa6cd1f4ba0ab8458bd5 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Thu, 13 Jul 2023 18:50:21 +0300 Subject: [PATCH 052/300] add multilabel data collator for storage savings --- .../training/dataloaders/__init__.py | 3 +- .../training/dataloaders/mesh_json_loader.py | 26 +++++----- .../dataloaders/multilabel_collator.py | 49 +++++++++++++++++++ grants_tagger_light/training/train.py | 4 ++ 4 files changed, 69 insertions(+), 13 deletions(-) create mode 100644 grants_tagger_light/training/dataloaders/multilabel_collator.py diff --git a/grants_tagger_light/training/dataloaders/__init__.py b/grants_tagger_light/training/dataloaders/__init__.py index 7665d628..14b74fba 100644 --- a/grants_tagger_light/training/dataloaders/__init__.py +++ b/grants_tagger_light/training/dataloaders/__init__.py @@ -1,3 +1,4 @@ from .mesh_json_loader import load_mesh_json +from .multilabel_collator import MultilabelDataCollator -__all__ = ["load_mesh_json"] +__all__ = ["load_mesh_json", "MultilabelDataCollator"] diff --git a/grants_tagger_light/training/dataloaders/mesh_json_loader.py b/grants_tagger_light/training/dataloaders/mesh_json_loader.py index 2f2a24de..f975754a 100644 --- a/grants_tagger_light/training/dataloaders/mesh_json_loader.py +++ b/grants_tagger_light/training/dataloaders/mesh_json_loader.py @@ -2,7 +2,7 @@ import numpy as np from transformers import AutoTokenizer from datasets import Dataset -from sklearn.preprocessing import MultiLabelBinarizer +from loguru import logger # TODO refactor the two load funcs into a class @@ -16,9 +16,15 @@ def _tokenize(batch, tokenizer: AutoTokenizer, x_col: str): ) -def _binarize_labels(batch, mlb: MultiLabelBinarizer): - batch["labels"] = mlb.transform(batch["meshMajor"]) - return batch +def _encode_labels(sample, label2id: dict): + sample["label_ids"] = [] + for label in sample["meshMajor"]: + try: + sample["label_ids"].append(label2id[label]) + except KeyError: + logger.warning(f"Label {label} not found in label2id") + + return sample def _get_label2id(dset): @@ -73,16 +79,12 @@ def _datagen(mesh_json_path: str, max_samples: int = np.inf): if label2id is None: label2id = _get_label2id(dset) - mlb = MultiLabelBinarizer(classes=list(label2id.keys())) - mlb.fit([list(label2id.keys())]) - dset = dset.map( - _binarize_labels, - batched=True, - batch_size=32, + _encode_labels, + batched=False, num_proc=num_proc, - desc="One-hot encoding labels", - fn_kwargs={"mlb": mlb}, + desc="Encoding labels", + fn_kwargs={"label2id": label2id}, remove_columns=["meshMajor"], ) diff --git a/grants_tagger_light/training/dataloaders/multilabel_collator.py b/grants_tagger_light/training/dataloaders/multilabel_collator.py new file mode 100644 index 00000000..132d38e6 --- /dev/null +++ b/grants_tagger_light/training/dataloaders/multilabel_collator.py @@ -0,0 +1,49 @@ +from typing import List, Any, Mapping +from sklearn.preprocessing import MultiLabelBinarizer +import numpy as np +import torch + + +class MultilabelDataCollator: + def __init__(self, label2id: dict): + self.mlb = MultiLabelBinarizer(classes=list(label2id.values())) + self.mlb.fit([list(label2id.values())]) + + def __call__(self, features: List[Any]): + """ + Andrei: Inspired from implementation in + https://github.com/huggingface/transformers/blob/v4.30.0/src/transformers/data/data_collator.py#L105 + """ + + if not isinstance(features[0], Mapping): + features = [vars(f) for f in features] + first = features[0] + batch = {} + + # Special handling for labels. + # Ensure that tensor is created with the correct type + # (it should be automatically the case, but let's make sure of it.) + + labels_as_np = np.array( + [self.mlb.transform([f["label_ids"]]) for f in features] + ) + + batch["labels"] = torch.tensor(labels_as_np).squeeze(1) + + # Handling of all other possible keys. + # Again, we will use the first element to figure out which + # key/values are not None for this model. + for k, v in first.items(): + if ( + k not in ("label", "label_ids") + and v is not None + and not isinstance(v, str) + ): + if isinstance(v, torch.Tensor): + batch[k] = torch.stack([f[k] for f in features]) + elif isinstance(v, np.ndarray): + batch[k] = torch.tensor(np.stack([f[k] for f in features])) + else: + batch[k] = torch.tensor([f[k] for f in features]) + + return batch diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index d80329f9..71f87497 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -14,6 +14,7 @@ ) from grants_tagger_light.training.dataloaders import ( load_mesh_json, + MultilabelDataCollator, ) from sklearn.metrics import classification_report from loguru import logger @@ -105,11 +106,14 @@ def sklearn_metrics(prediction: EvalPrediction): return metric_dict + collator = MultilabelDataCollator(label2id=label2id) + trainer = Trainer( model=model, args=training_args, train_dataset=train_dset, eval_dataset=val_dset, + data_collator=collator, compute_metrics=sklearn_metrics, ) From edba653da50805c646eb46c25e6593b369fdd7b0 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Mon, 17 Jul 2023 12:31:23 +0300 Subject: [PATCH 053/300] change test size to 0.05 and add train script --- .../training/dataloaders/mesh_json_loader.py | 2 +- scripts/train.sh | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 scripts/train.sh diff --git a/grants_tagger_light/training/dataloaders/mesh_json_loader.py b/grants_tagger_light/training/dataloaders/mesh_json_loader.py index f975754a..6f870d0e 100644 --- a/grants_tagger_light/training/dataloaders/mesh_json_loader.py +++ b/grants_tagger_light/training/dataloaders/mesh_json_loader.py @@ -40,7 +40,7 @@ def load_mesh_json( data_path: str, tokenizer: AutoTokenizer, label2id: dict, - test_size: float = 0.1, + test_size: float = 0.05, num_proc: int = 8, max_samples: int = np.inf, ): diff --git a/scripts/train.sh b/scripts/train.sh new file mode 100644 index 00000000..f87c28d3 --- /dev/null +++ b/scripts/train.sh @@ -0,0 +1,15 @@ +# Run on p2.8xlarge instance +grants-tagger train bertmesh \ + "" \ + data/raw/allMeSH_2021.json \ + -1 \ + --output_dir bertmesh_outs/pipeline_test/ \ + --wandb_name test-train-all \ + --wandb_api_key ${WANDB_API_KEY} \ + --per_device_train_batch_size 8 \ + --per_device_eval_batch_size 8 \ + --num_train_epochs 1 \ + --evaluation_strategy steps \ + --evaluation_steps 100000 \ + --fp16 \ + --torch_compile From d5545b871f6151ab4895f1b688d32fb18c361b14 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Mon, 17 Jul 2023 12:39:50 +0300 Subject: [PATCH 054/300] make save strategy match --- scripts/train.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scripts/train.sh b/scripts/train.sh index f87c28d3..151e6dbe 100644 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -11,5 +11,7 @@ grants-tagger train bertmesh \ --num_train_epochs 1 \ --evaluation_strategy steps \ --evaluation_steps 100000 \ + --save_strategy steps \ + --save_steps 100000 \ --fp16 \ --torch_compile From 20db0c3e22ae16cc3d1938d47bf95837158488cc Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Mon, 17 Jul 2023 12:43:31 +0300 Subject: [PATCH 055/300] fix typo --- scripts/train.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/train.sh b/scripts/train.sh index 151e6dbe..be797184 100644 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -10,7 +10,7 @@ grants-tagger train bertmesh \ --per_device_eval_batch_size 8 \ --num_train_epochs 1 \ --evaluation_strategy steps \ - --evaluation_steps 100000 \ + --eval_steps 100000 \ --save_strategy steps \ --save_steps 100000 \ --fp16 \ From 28accc14ebb9918b79880ee15cc74e5337e838ee Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 18 Jul 2023 15:55:03 +0300 Subject: [PATCH 056/300] create separate preprocessing command that saves to disk (wip) --- grants_tagger_light/preprocessing/__init__.py | 8 +- .../preprocessing/preprocess_mesh.py | 221 +++++++++--------- .../training/dataloaders/__init__.py | 3 +- .../training/dataloaders/mesh_json_loader.py | 94 -------- grants_tagger_light/training/train.py | 28 +-- 5 files changed, 132 insertions(+), 222 deletions(-) delete mode 100644 grants_tagger_light/training/dataloaders/mesh_json_loader.py diff --git a/grants_tagger_light/preprocessing/__init__.py b/grants_tagger_light/preprocessing/__init__.py index 5b66021a..3a7024d1 100644 --- a/grants_tagger_light/preprocessing/__init__.py +++ b/grants_tagger_light/preprocessing/__init__.py @@ -1,8 +1,8 @@ import typer - from .preprocess_mesh import preprocess_mesh_cli preprocess_app = typer.Typer() -preprocess_app.command("bioasq-mesh")(preprocess_mesh_cli) - -__all__ = ["preprocess_app"] +preprocess_app.command( + "mesh", + context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, +)(preprocess_mesh_cli) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 5957d235..634b1644 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -1,135 +1,146 @@ -""" -Preprocess JSON Mesh data from BioASQ to JSONL -""" import json -from pathlib import Path -from typing import Optional - -import pandas as pd +import numpy as np +import os import typer -from tqdm import tqdm - -from grants_tagger_light.utils import ( - write_jsonl, -) +from transformers import AutoTokenizer +from datasets import Dataset +from loguru import logger +from grants_tagger_light.models.bert_mesh import BertMesh +# TODO refactor the two load funcs into a class -def yield_raw_data(input_path): - with open(input_path, encoding="latin-1") as f_i: - f_i.readline() # skip first line ({"articles":[) which is not valid JSON - for i, line in enumerate(f_i): - item = json.loads(line[:-2]) - yield item +preprocess_app = typer.Typer() -def process_data(item, filter_tags=None, filter_years=None): - text = item["abstractText"] - tags = item["meshMajor"] - journal = item["journal"] - year = item["year"] - if filter_tags: - tags = list(set(tags).intersection(filter_tags)) - if not tags: - return - if filter_years: - min_year, max_year = filter_years.split(",") - if year > int(max_year): - return - if year < int(min_year): - return - data = {"text": text, "tags": tags, "meta": {"journal": journal, "year": year}} - return data +def _tokenize(batch, tokenizer: AutoTokenizer, x_col: str): + return tokenizer( + batch[x_col], + padding="max_length", + truncation=True, + max_length=512, + ) -def yield_data(input_path, filter_tags, filter_years, buffer_size=10_000): - data_batch = [] - for item in tqdm(yield_raw_data(input_path), total=15_000_000): # approx 15M docs - processed_item = process_data(item, filter_tags, filter_years) +def _encode_labels(sample, label2id: dict): + sample["label_ids"] = [] + for label in sample["meshMajor"]: + try: + sample["label_ids"].append(label2id[label]) + except KeyError: + logger.warning(f"Label {label} not found in label2id") - if processed_item: - data_batch.append(processed_item) + return sample - if len(data_batch) >= buffer_size: - yield data_batch - data_batch = [] - if data_batch: - yield data_batch +def _get_label2id(dset): + label_set = set() + for sample in dset: + for label in sample["meshMajor"]: + label_set.add(label) + label2id = {label: idx for idx, label in enumerate(label_set)} + return label2id def preprocess_mesh( - raw_data_path, - processed_data_path, - mesh_tags_path=None, - filter_years=None, - n_max=None, - buffer_size=10_000, + data_path: str, + save_loc: str, + model_key: str, + test_size: float = 0.05, + num_proc: int = 8, + max_samples: int = np.inf, ): - if mesh_tags_path: - filter_tags_data = pd.read_csv(mesh_tags_path) - filter_tags = filter_tags_data["DescriptorName"].tolist() - filter_tags = set(filter_tags) + if not model_key: + label2id = None + # Use the same pretrained tokenizer as in Wellcome/WellcomeBertMesh + tokenizer = AutoTokenizer.from_pretrained( + "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract" + ) else: - filter_tags = None + # Load the model to get its label2id + tokenizer = AutoTokenizer.from_pretrained(model_key) + model = BertMesh.from_pretrained(model_key, trust_remote_code=True) + + label2id = {v: k for k, v in model.id2label.items()} + + def _datagen(mesh_json_path: str, max_samples: int = np.inf): + with open(mesh_json_path, "r", encoding="latin1") as f: + for idx, line in enumerate(f): + # Skip 1st line + if idx == 0: + continue + sample = json.loads(line[:-2]) + + if idx > max_samples: + break + + yield sample + + dset = Dataset.from_generator( + _datagen, + gen_kwargs={"mesh_json_path": data_path, "max_samples": max_samples}, + ) - if filter_years: - min_year, max_year = filter_years.split(",") - filter_years = [int(min_year), int(max_year)] + # Remove unused columns to save space & time + dset = dset.remove_columns(["journal", "year", "pmid", "title"]) + + dset = dset.map( + _tokenize, + batched=True, + batch_size=32, + num_proc=num_proc, + desc="Tokenizing", + fn_kwargs={"tokenizer": tokenizer, "x_col": "abstractText"}, + remove_columns=["abstractText"], + ) - # If only using a tiny set of data, need to reduce buffer size - if n_max is not None and n_max < buffer_size: - buffer_size = n_max + # Generate label2id if None + if label2id is None: + label2id = _get_label2id(dset) + + dset = dset.map( + _encode_labels, + batched=False, + num_proc=num_proc, + desc="Encoding labels", + fn_kwargs={"label2id": label2id}, + remove_columns=["meshMajor"], + ) - with open(processed_data_path, "w") as f: - for i, data_batch in enumerate( - yield_data( - raw_data_path, filter_tags, filter_years, buffer_size=buffer_size - ) - ): - write_jsonl(f, data_batch) - if n_max and (i + 1) * buffer_size >= n_max: - break + # Split into train and test + dset = dset.train_test_split(test_size=test_size) + # Save to disk + dset.save_to_disk(os.path.join(save_loc, "dataset")) -preprocess_mesh_app = typer.Typer() + with open(os.path.join(save_loc, "label2id.json"), "w") as f: + json.dump(label2id, f, indent=4) -@preprocess_mesh_app.command() +@preprocess_app.command() def preprocess_mesh_cli( - input_path: Optional[str] = typer.Argument(None, help="path to BioASQ JSON data"), - train_output_path: Optional[str] = typer.Argument( - None, help="path to JSONL output file that will be generated for the train set" - ), - label_binarizer_path: Optional[Path] = typer.Argument( - None, help="path to pickle file that will contain the label binarizer" - ), - test_output_path: Optional[str] = typer.Option( - None, help="path to JSONL output file that will be generated for the test set" - ), - mesh_tags_path: Optional[str] = typer.Option( - None, help="path to mesh tags to filter" + data_path: str = typer.Argument(..., help="Path to mesh.json"), + save_loc: str = typer.Argument(..., help="Path to save processed data"), + model_key: str = typer.Argument( + ..., + help="Key to use when loading tokenizer and label2id. Leave blank if training from scratch", # noqa ), - test_split: Optional[float] = typer.Option( - 0.01, help="split percentage for test data. if None no split." + test_size: float = typer.Option(0.05, help="Fraction of data to use for testing"), + num_proc: int = typer.Option( + 8, help="Number of processes to use for preprocessing" ), - filter_years: Optional[str] = typer.Option( - None, help="years to keep in form min_year,max_year with both inclusive" - ), - n_max: Optional[int] = typer.Option( - None, - help=""" - Maximum limit on the number of datapoints in the set - (including training and test)""", + max_samples: int = typer.Argument( + -1, + help="Maximum number of samples to use for preprocessing", ), ): + if max_samples == -1: + max_samples = np.inf + preprocess_mesh( - input_path, - train_output_path, - mesh_tags_path=mesh_tags_path, - filter_years=filter_years, - n_max=n_max, + data_path=data_path, + save_loc=save_loc, + model_key=model_key, + test_size=test_size, + num_proc=num_proc, + max_samples=max_samples, ) - - -if __name__ == "__main__": - typer.run(preprocess_mesh) diff --git a/grants_tagger_light/training/dataloaders/__init__.py b/grants_tagger_light/training/dataloaders/__init__.py index 14b74fba..9fef3884 100644 --- a/grants_tagger_light/training/dataloaders/__init__.py +++ b/grants_tagger_light/training/dataloaders/__init__.py @@ -1,4 +1,3 @@ -from .mesh_json_loader import load_mesh_json from .multilabel_collator import MultilabelDataCollator -__all__ = ["load_mesh_json", "MultilabelDataCollator"] +__all__ = ["MultilabelDataCollator"] diff --git a/grants_tagger_light/training/dataloaders/mesh_json_loader.py b/grants_tagger_light/training/dataloaders/mesh_json_loader.py deleted file mode 100644 index 6f870d0e..00000000 --- a/grants_tagger_light/training/dataloaders/mesh_json_loader.py +++ /dev/null @@ -1,94 +0,0 @@ -import json -import numpy as np -from transformers import AutoTokenizer -from datasets import Dataset -from loguru import logger - -# TODO refactor the two load funcs into a class - - -def _tokenize(batch, tokenizer: AutoTokenizer, x_col: str): - return tokenizer( - batch[x_col], - padding="max_length", - truncation=True, - max_length=512, - ) - - -def _encode_labels(sample, label2id: dict): - sample["label_ids"] = [] - for label in sample["meshMajor"]: - try: - sample["label_ids"].append(label2id[label]) - except KeyError: - logger.warning(f"Label {label} not found in label2id") - - return sample - - -def _get_label2id(dset): - label_set = set() - for sample in dset: - for label in sample["meshMajor"]: - label_set.add(label) - label2id = {label: idx for idx, label in enumerate(label_set)} - return label2id - - -def load_mesh_json( - data_path: str, - tokenizer: AutoTokenizer, - label2id: dict, - test_size: float = 0.05, - num_proc: int = 8, - max_samples: int = np.inf, -): - def _datagen(mesh_json_path: str, max_samples: int = np.inf): - with open(mesh_json_path, "r", encoding="latin1") as f: - for idx, line in enumerate(f): - # Skip 1st line - if idx == 0: - continue - sample = json.loads(line[:-2]) - - if idx > max_samples: - break - - yield sample - - dset = Dataset.from_generator( - _datagen, - gen_kwargs={"mesh_json_path": data_path, "max_samples": max_samples}, - ) - - # Remove unused columns to save space & time - dset = dset.remove_columns(["journal", "year", "pmid", "title"]) - - dset = dset.map( - _tokenize, - batched=True, - batch_size=32, - num_proc=num_proc, - desc="Tokenizing", - fn_kwargs={"tokenizer": tokenizer, "x_col": "abstractText"}, - remove_columns=["abstractText"], - ) - - # Generate label2id if None - if label2id is None: - label2id = _get_label2id(dset) - - dset = dset.map( - _encode_labels, - batched=False, - num_proc=num_proc, - desc="Encoding labels", - fn_kwargs={"label2id": label2id}, - remove_columns=["meshMajor"], - ) - - # Split into train and test - dset = dset.train_test_split(test_size=test_size) - - return dset["train"], dset["test"], label2id diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 71f87497..2c4086c1 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -13,17 +13,17 @@ BertMeshModelArguments, ) from grants_tagger_light.training.dataloaders import ( - load_mesh_json, MultilabelDataCollator, ) from sklearn.metrics import classification_report from loguru import logger from pprint import pformat +from datasets import load_dataset import typer import numpy as np import os import transformers - +import json transformers.set_seed(42) @@ -31,7 +31,6 @@ def train_bertmesh( model_key: str, data_path: str, - max_samples: int, training_args: TrainingArguments, model_args: BertMeshModelArguments = None, ): @@ -44,14 +43,13 @@ def train_bertmesh( # Instantiate model from scratch config = AutoConfig.from_pretrained(model_args.pretrained_model_key) - tokenizer = AutoTokenizer.from_pretrained(model_args.pretrained_model_key) + AutoTokenizer.from_pretrained(model_args.pretrained_model_key) - train_dset, val_dset, label2id = load_mesh_json( - data_path, - tokenizer=tokenizer, - label2id=None, - max_samples=max_samples, - ) + dset = load_dataset(os.path.join(data_path, "dataset")) + train_dset, val_dset = dset["train"], dset["test"] + + with open(os.path.join(data_path, "label2id.json"), "r") as f: + label2id = json.load(f) config.update( { @@ -71,16 +69,12 @@ def train_bertmesh( # Instantiate from pretrained model = BertMesh.from_pretrained(model_key, trust_remote_code=True) - tokenizer = AutoTokenizer.from_pretrained(model_key) + AutoTokenizer.from_pretrained(model_key) label2id = {v: k for k, v in model.id2label.items()} - train_dset, val_dset, _ = load_mesh_json( - data_path, - tokenizer=tokenizer, - label2id=label2id, - max_samples=max_samples, - ) + dset = load_dataset(os.path.join(data_path, "dataset")) + train_dset, val_dset = dset["train"], dset["test"] if model_args.freeze_backbone: logger.info("Freezing backbone") From 1a6fdaeaeb71435a61aeb14d14c40e398740d197 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 18 Jul 2023 15:57:35 +0300 Subject: [PATCH 057/300] disable caching --- grants_tagger_light/preprocessing/preprocess_mesh.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 634b1644..e7d6485e 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -3,12 +3,13 @@ import os import typer from transformers import AutoTokenizer -from datasets import Dataset +from datasets import Dataset, disable_caching from loguru import logger from grants_tagger_light.models.bert_mesh import BertMesh # TODO refactor the two load funcs into a class +disable_caching() preprocess_app = typer.Typer() From c4bd7439a44d788707c85b9c8ccde6bb9c3095b2 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 18 Jul 2023 15:59:34 +0300 Subject: [PATCH 058/300] turn argument to option --- grants_tagger_light/preprocessing/preprocess_mesh.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index e7d6485e..15ebb49a 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -129,7 +129,7 @@ def preprocess_mesh_cli( num_proc: int = typer.Option( 8, help="Number of processes to use for preprocessing" ), - max_samples: int = typer.Argument( + max_samples: int = typer.Option( -1, help="Maximum number of samples to use for preprocessing", ), From 1660ef092399c20e188d6716285cba853d16eccf Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 18 Jul 2023 19:44:25 +0300 Subject: [PATCH 059/300] revert separate preprocess command --- grants_tagger_light/preprocessing/__init__.py | 8 +- .../preprocessing/preprocess_mesh.py | 222 +++++++++--------- .../training/dataloaders/__init__.py | 3 +- .../training/dataloaders/mesh_json_loader.py | 94 ++++++++ grants_tagger_light/training/train.py | 28 ++- 5 files changed, 222 insertions(+), 133 deletions(-) create mode 100644 grants_tagger_light/training/dataloaders/mesh_json_loader.py diff --git a/grants_tagger_light/preprocessing/__init__.py b/grants_tagger_light/preprocessing/__init__.py index 3a7024d1..5b66021a 100644 --- a/grants_tagger_light/preprocessing/__init__.py +++ b/grants_tagger_light/preprocessing/__init__.py @@ -1,8 +1,8 @@ import typer + from .preprocess_mesh import preprocess_mesh_cli preprocess_app = typer.Typer() -preprocess_app.command( - "mesh", - context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, -)(preprocess_mesh_cli) +preprocess_app.command("bioasq-mesh")(preprocess_mesh_cli) + +__all__ = ["preprocess_app"] diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 15ebb49a..5957d235 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -1,147 +1,135 @@ +""" +Preprocess JSON Mesh data from BioASQ to JSONL +""" import json -import numpy as np -import os +from pathlib import Path +from typing import Optional + +import pandas as pd import typer -from transformers import AutoTokenizer -from datasets import Dataset, disable_caching -from loguru import logger -from grants_tagger_light.models.bert_mesh import BertMesh +from tqdm import tqdm -# TODO refactor the two load funcs into a class +from grants_tagger_light.utils import ( + write_jsonl, +) -disable_caching() -preprocess_app = typer.Typer() +def yield_raw_data(input_path): + with open(input_path, encoding="latin-1") as f_i: + f_i.readline() # skip first line ({"articles":[) which is not valid JSON + for i, line in enumerate(f_i): + item = json.loads(line[:-2]) + yield item -def _tokenize(batch, tokenizer: AutoTokenizer, x_col: str): - return tokenizer( - batch[x_col], - padding="max_length", - truncation=True, - max_length=512, - ) + +def process_data(item, filter_tags=None, filter_years=None): + text = item["abstractText"] + tags = item["meshMajor"] + journal = item["journal"] + year = item["year"] + if filter_tags: + tags = list(set(tags).intersection(filter_tags)) + if not tags: + return + if filter_years: + min_year, max_year = filter_years.split(",") + if year > int(max_year): + return + if year < int(min_year): + return + data = {"text": text, "tags": tags, "meta": {"journal": journal, "year": year}} + return data -def _encode_labels(sample, label2id: dict): - sample["label_ids"] = [] - for label in sample["meshMajor"]: - try: - sample["label_ids"].append(label2id[label]) - except KeyError: - logger.warning(f"Label {label} not found in label2id") +def yield_data(input_path, filter_tags, filter_years, buffer_size=10_000): + data_batch = [] + for item in tqdm(yield_raw_data(input_path), total=15_000_000): # approx 15M docs + processed_item = process_data(item, filter_tags, filter_years) - return sample + if processed_item: + data_batch.append(processed_item) + if len(data_batch) >= buffer_size: + yield data_batch + data_batch = [] -def _get_label2id(dset): - label_set = set() - for sample in dset: - for label in sample["meshMajor"]: - label_set.add(label) - label2id = {label: idx for idx, label in enumerate(label_set)} - return label2id + if data_batch: + yield data_batch def preprocess_mesh( - data_path: str, - save_loc: str, - model_key: str, - test_size: float = 0.05, - num_proc: int = 8, - max_samples: int = np.inf, + raw_data_path, + processed_data_path, + mesh_tags_path=None, + filter_years=None, + n_max=None, + buffer_size=10_000, ): - if not model_key: - label2id = None - # Use the same pretrained tokenizer as in Wellcome/WellcomeBertMesh - tokenizer = AutoTokenizer.from_pretrained( - "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract" - ) + if mesh_tags_path: + filter_tags_data = pd.read_csv(mesh_tags_path) + filter_tags = filter_tags_data["DescriptorName"].tolist() + filter_tags = set(filter_tags) else: - # Load the model to get its label2id - tokenizer = AutoTokenizer.from_pretrained(model_key) - model = BertMesh.from_pretrained(model_key, trust_remote_code=True) - - label2id = {v: k for k, v in model.id2label.items()} - - def _datagen(mesh_json_path: str, max_samples: int = np.inf): - with open(mesh_json_path, "r", encoding="latin1") as f: - for idx, line in enumerate(f): - # Skip 1st line - if idx == 0: - continue - sample = json.loads(line[:-2]) - - if idx > max_samples: - break - - yield sample - - dset = Dataset.from_generator( - _datagen, - gen_kwargs={"mesh_json_path": data_path, "max_samples": max_samples}, - ) + filter_tags = None - # Remove unused columns to save space & time - dset = dset.remove_columns(["journal", "year", "pmid", "title"]) - - dset = dset.map( - _tokenize, - batched=True, - batch_size=32, - num_proc=num_proc, - desc="Tokenizing", - fn_kwargs={"tokenizer": tokenizer, "x_col": "abstractText"}, - remove_columns=["abstractText"], - ) + if filter_years: + min_year, max_year = filter_years.split(",") + filter_years = [int(min_year), int(max_year)] - # Generate label2id if None - if label2id is None: - label2id = _get_label2id(dset) - - dset = dset.map( - _encode_labels, - batched=False, - num_proc=num_proc, - desc="Encoding labels", - fn_kwargs={"label2id": label2id}, - remove_columns=["meshMajor"], - ) + # If only using a tiny set of data, need to reduce buffer size + if n_max is not None and n_max < buffer_size: + buffer_size = n_max - # Split into train and test - dset = dset.train_test_split(test_size=test_size) + with open(processed_data_path, "w") as f: + for i, data_batch in enumerate( + yield_data( + raw_data_path, filter_tags, filter_years, buffer_size=buffer_size + ) + ): + write_jsonl(f, data_batch) + if n_max and (i + 1) * buffer_size >= n_max: + break - # Save to disk - dset.save_to_disk(os.path.join(save_loc, "dataset")) - with open(os.path.join(save_loc, "label2id.json"), "w") as f: - json.dump(label2id, f, indent=4) +preprocess_mesh_app = typer.Typer() -@preprocess_app.command() +@preprocess_mesh_app.command() def preprocess_mesh_cli( - data_path: str = typer.Argument(..., help="Path to mesh.json"), - save_loc: str = typer.Argument(..., help="Path to save processed data"), - model_key: str = typer.Argument( - ..., - help="Key to use when loading tokenizer and label2id. Leave blank if training from scratch", # noqa + input_path: Optional[str] = typer.Argument(None, help="path to BioASQ JSON data"), + train_output_path: Optional[str] = typer.Argument( + None, help="path to JSONL output file that will be generated for the train set" + ), + label_binarizer_path: Optional[Path] = typer.Argument( + None, help="path to pickle file that will contain the label binarizer" + ), + test_output_path: Optional[str] = typer.Option( + None, help="path to JSONL output file that will be generated for the test set" + ), + mesh_tags_path: Optional[str] = typer.Option( + None, help="path to mesh tags to filter" ), - test_size: float = typer.Option(0.05, help="Fraction of data to use for testing"), - num_proc: int = typer.Option( - 8, help="Number of processes to use for preprocessing" + test_split: Optional[float] = typer.Option( + 0.01, help="split percentage for test data. if None no split." ), - max_samples: int = typer.Option( - -1, - help="Maximum number of samples to use for preprocessing", + filter_years: Optional[str] = typer.Option( + None, help="years to keep in form min_year,max_year with both inclusive" + ), + n_max: Optional[int] = typer.Option( + None, + help=""" + Maximum limit on the number of datapoints in the set + (including training and test)""", ), ): - if max_samples == -1: - max_samples = np.inf - preprocess_mesh( - data_path=data_path, - save_loc=save_loc, - model_key=model_key, - test_size=test_size, - num_proc=num_proc, - max_samples=max_samples, + input_path, + train_output_path, + mesh_tags_path=mesh_tags_path, + filter_years=filter_years, + n_max=n_max, ) + + +if __name__ == "__main__": + typer.run(preprocess_mesh) diff --git a/grants_tagger_light/training/dataloaders/__init__.py b/grants_tagger_light/training/dataloaders/__init__.py index 9fef3884..14b74fba 100644 --- a/grants_tagger_light/training/dataloaders/__init__.py +++ b/grants_tagger_light/training/dataloaders/__init__.py @@ -1,3 +1,4 @@ +from .mesh_json_loader import load_mesh_json from .multilabel_collator import MultilabelDataCollator -__all__ = ["MultilabelDataCollator"] +__all__ = ["load_mesh_json", "MultilabelDataCollator"] diff --git a/grants_tagger_light/training/dataloaders/mesh_json_loader.py b/grants_tagger_light/training/dataloaders/mesh_json_loader.py new file mode 100644 index 00000000..6f870d0e --- /dev/null +++ b/grants_tagger_light/training/dataloaders/mesh_json_loader.py @@ -0,0 +1,94 @@ +import json +import numpy as np +from transformers import AutoTokenizer +from datasets import Dataset +from loguru import logger + +# TODO refactor the two load funcs into a class + + +def _tokenize(batch, tokenizer: AutoTokenizer, x_col: str): + return tokenizer( + batch[x_col], + padding="max_length", + truncation=True, + max_length=512, + ) + + +def _encode_labels(sample, label2id: dict): + sample["label_ids"] = [] + for label in sample["meshMajor"]: + try: + sample["label_ids"].append(label2id[label]) + except KeyError: + logger.warning(f"Label {label} not found in label2id") + + return sample + + +def _get_label2id(dset): + label_set = set() + for sample in dset: + for label in sample["meshMajor"]: + label_set.add(label) + label2id = {label: idx for idx, label in enumerate(label_set)} + return label2id + + +def load_mesh_json( + data_path: str, + tokenizer: AutoTokenizer, + label2id: dict, + test_size: float = 0.05, + num_proc: int = 8, + max_samples: int = np.inf, +): + def _datagen(mesh_json_path: str, max_samples: int = np.inf): + with open(mesh_json_path, "r", encoding="latin1") as f: + for idx, line in enumerate(f): + # Skip 1st line + if idx == 0: + continue + sample = json.loads(line[:-2]) + + if idx > max_samples: + break + + yield sample + + dset = Dataset.from_generator( + _datagen, + gen_kwargs={"mesh_json_path": data_path, "max_samples": max_samples}, + ) + + # Remove unused columns to save space & time + dset = dset.remove_columns(["journal", "year", "pmid", "title"]) + + dset = dset.map( + _tokenize, + batched=True, + batch_size=32, + num_proc=num_proc, + desc="Tokenizing", + fn_kwargs={"tokenizer": tokenizer, "x_col": "abstractText"}, + remove_columns=["abstractText"], + ) + + # Generate label2id if None + if label2id is None: + label2id = _get_label2id(dset) + + dset = dset.map( + _encode_labels, + batched=False, + num_proc=num_proc, + desc="Encoding labels", + fn_kwargs={"label2id": label2id}, + remove_columns=["meshMajor"], + ) + + # Split into train and test + dset = dset.train_test_split(test_size=test_size) + + return dset["train"], dset["test"], label2id diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 2c4086c1..71f87497 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -13,17 +13,17 @@ BertMeshModelArguments, ) from grants_tagger_light.training.dataloaders import ( + load_mesh_json, MultilabelDataCollator, ) from sklearn.metrics import classification_report from loguru import logger from pprint import pformat -from datasets import load_dataset import typer import numpy as np import os import transformers -import json + transformers.set_seed(42) @@ -31,6 +31,7 @@ def train_bertmesh( model_key: str, data_path: str, + max_samples: int, training_args: TrainingArguments, model_args: BertMeshModelArguments = None, ): @@ -43,13 +44,14 @@ def train_bertmesh( # Instantiate model from scratch config = AutoConfig.from_pretrained(model_args.pretrained_model_key) - AutoTokenizer.from_pretrained(model_args.pretrained_model_key) - - dset = load_dataset(os.path.join(data_path, "dataset")) - train_dset, val_dset = dset["train"], dset["test"] + tokenizer = AutoTokenizer.from_pretrained(model_args.pretrained_model_key) - with open(os.path.join(data_path, "label2id.json"), "r") as f: - label2id = json.load(f) + train_dset, val_dset, label2id = load_mesh_json( + data_path, + tokenizer=tokenizer, + label2id=None, + max_samples=max_samples, + ) config.update( { @@ -69,12 +71,16 @@ def train_bertmesh( # Instantiate from pretrained model = BertMesh.from_pretrained(model_key, trust_remote_code=True) - AutoTokenizer.from_pretrained(model_key) + tokenizer = AutoTokenizer.from_pretrained(model_key) label2id = {v: k for k, v in model.id2label.items()} - dset = load_dataset(os.path.join(data_path, "dataset")) - train_dset, val_dset = dset["train"], dset["test"] + train_dset, val_dset, _ = load_mesh_json( + data_path, + tokenizer=tokenizer, + label2id=label2id, + max_samples=max_samples, + ) if model_args.freeze_backbone: logger.info("Freezing backbone") From 4c6152c87925e2f2839791c2448071e592c3a82e Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 18 Jul 2023 19:52:55 +0300 Subject: [PATCH 060/300] Revert "revert separate preprocess command" This reverts commit 1660ef092399c20e188d6716285cba853d16eccf. --- grants_tagger_light/preprocessing/__init__.py | 8 +- .../preprocessing/preprocess_mesh.py | 222 +++++++++--------- .../training/dataloaders/__init__.py | 3 +- .../training/dataloaders/mesh_json_loader.py | 94 -------- grants_tagger_light/training/train.py | 28 +-- 5 files changed, 133 insertions(+), 222 deletions(-) delete mode 100644 grants_tagger_light/training/dataloaders/mesh_json_loader.py diff --git a/grants_tagger_light/preprocessing/__init__.py b/grants_tagger_light/preprocessing/__init__.py index 5b66021a..3a7024d1 100644 --- a/grants_tagger_light/preprocessing/__init__.py +++ b/grants_tagger_light/preprocessing/__init__.py @@ -1,8 +1,8 @@ import typer - from .preprocess_mesh import preprocess_mesh_cli preprocess_app = typer.Typer() -preprocess_app.command("bioasq-mesh")(preprocess_mesh_cli) - -__all__ = ["preprocess_app"] +preprocess_app.command( + "mesh", + context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, +)(preprocess_mesh_cli) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 5957d235..15ebb49a 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -1,135 +1,147 @@ -""" -Preprocess JSON Mesh data from BioASQ to JSONL -""" import json -from pathlib import Path -from typing import Optional - -import pandas as pd +import numpy as np +import os import typer -from tqdm import tqdm - -from grants_tagger_light.utils import ( - write_jsonl, -) +from transformers import AutoTokenizer +from datasets import Dataset, disable_caching +from loguru import logger +from grants_tagger_light.models.bert_mesh import BertMesh +# TODO refactor the two load funcs into a class -def yield_raw_data(input_path): - with open(input_path, encoding="latin-1") as f_i: - f_i.readline() # skip first line ({"articles":[) which is not valid JSON - for i, line in enumerate(f_i): - item = json.loads(line[:-2]) - yield item +disable_caching() +preprocess_app = typer.Typer() -def process_data(item, filter_tags=None, filter_years=None): - text = item["abstractText"] - tags = item["meshMajor"] - journal = item["journal"] - year = item["year"] - if filter_tags: - tags = list(set(tags).intersection(filter_tags)) - if not tags: - return - if filter_years: - min_year, max_year = filter_years.split(",") - if year > int(max_year): - return - if year < int(min_year): - return - data = {"text": text, "tags": tags, "meta": {"journal": journal, "year": year}} - return data +def _tokenize(batch, tokenizer: AutoTokenizer, x_col: str): + return tokenizer( + batch[x_col], + padding="max_length", + truncation=True, + max_length=512, + ) -def yield_data(input_path, filter_tags, filter_years, buffer_size=10_000): - data_batch = [] - for item in tqdm(yield_raw_data(input_path), total=15_000_000): # approx 15M docs - processed_item = process_data(item, filter_tags, filter_years) +def _encode_labels(sample, label2id: dict): + sample["label_ids"] = [] + for label in sample["meshMajor"]: + try: + sample["label_ids"].append(label2id[label]) + except KeyError: + logger.warning(f"Label {label} not found in label2id") - if processed_item: - data_batch.append(processed_item) + return sample - if len(data_batch) >= buffer_size: - yield data_batch - data_batch = [] - if data_batch: - yield data_batch +def _get_label2id(dset): + label_set = set() + for sample in dset: + for label in sample["meshMajor"]: + label_set.add(label) + label2id = {label: idx for idx, label in enumerate(label_set)} + return label2id def preprocess_mesh( - raw_data_path, - processed_data_path, - mesh_tags_path=None, - filter_years=None, - n_max=None, - buffer_size=10_000, + data_path: str, + save_loc: str, + model_key: str, + test_size: float = 0.05, + num_proc: int = 8, + max_samples: int = np.inf, ): - if mesh_tags_path: - filter_tags_data = pd.read_csv(mesh_tags_path) - filter_tags = filter_tags_data["DescriptorName"].tolist() - filter_tags = set(filter_tags) + if not model_key: + label2id = None + # Use the same pretrained tokenizer as in Wellcome/WellcomeBertMesh + tokenizer = AutoTokenizer.from_pretrained( + "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract" + ) else: - filter_tags = None + # Load the model to get its label2id + tokenizer = AutoTokenizer.from_pretrained(model_key) + model = BertMesh.from_pretrained(model_key, trust_remote_code=True) + + label2id = {v: k for k, v in model.id2label.items()} + + def _datagen(mesh_json_path: str, max_samples: int = np.inf): + with open(mesh_json_path, "r", encoding="latin1") as f: + for idx, line in enumerate(f): + # Skip 1st line + if idx == 0: + continue + sample = json.loads(line[:-2]) + + if idx > max_samples: + break + + yield sample + + dset = Dataset.from_generator( + _datagen, + gen_kwargs={"mesh_json_path": data_path, "max_samples": max_samples}, + ) - if filter_years: - min_year, max_year = filter_years.split(",") - filter_years = [int(min_year), int(max_year)] + # Remove unused columns to save space & time + dset = dset.remove_columns(["journal", "year", "pmid", "title"]) + + dset = dset.map( + _tokenize, + batched=True, + batch_size=32, + num_proc=num_proc, + desc="Tokenizing", + fn_kwargs={"tokenizer": tokenizer, "x_col": "abstractText"}, + remove_columns=["abstractText"], + ) - # If only using a tiny set of data, need to reduce buffer size - if n_max is not None and n_max < buffer_size: - buffer_size = n_max + # Generate label2id if None + if label2id is None: + label2id = _get_label2id(dset) + + dset = dset.map( + _encode_labels, + batched=False, + num_proc=num_proc, + desc="Encoding labels", + fn_kwargs={"label2id": label2id}, + remove_columns=["meshMajor"], + ) - with open(processed_data_path, "w") as f: - for i, data_batch in enumerate( - yield_data( - raw_data_path, filter_tags, filter_years, buffer_size=buffer_size - ) - ): - write_jsonl(f, data_batch) - if n_max and (i + 1) * buffer_size >= n_max: - break + # Split into train and test + dset = dset.train_test_split(test_size=test_size) + # Save to disk + dset.save_to_disk(os.path.join(save_loc, "dataset")) -preprocess_mesh_app = typer.Typer() + with open(os.path.join(save_loc, "label2id.json"), "w") as f: + json.dump(label2id, f, indent=4) -@preprocess_mesh_app.command() +@preprocess_app.command() def preprocess_mesh_cli( - input_path: Optional[str] = typer.Argument(None, help="path to BioASQ JSON data"), - train_output_path: Optional[str] = typer.Argument( - None, help="path to JSONL output file that will be generated for the train set" - ), - label_binarizer_path: Optional[Path] = typer.Argument( - None, help="path to pickle file that will contain the label binarizer" - ), - test_output_path: Optional[str] = typer.Option( - None, help="path to JSONL output file that will be generated for the test set" - ), - mesh_tags_path: Optional[str] = typer.Option( - None, help="path to mesh tags to filter" + data_path: str = typer.Argument(..., help="Path to mesh.json"), + save_loc: str = typer.Argument(..., help="Path to save processed data"), + model_key: str = typer.Argument( + ..., + help="Key to use when loading tokenizer and label2id. Leave blank if training from scratch", # noqa ), - test_split: Optional[float] = typer.Option( - 0.01, help="split percentage for test data. if None no split." + test_size: float = typer.Option(0.05, help="Fraction of data to use for testing"), + num_proc: int = typer.Option( + 8, help="Number of processes to use for preprocessing" ), - filter_years: Optional[str] = typer.Option( - None, help="years to keep in form min_year,max_year with both inclusive" - ), - n_max: Optional[int] = typer.Option( - None, - help=""" - Maximum limit on the number of datapoints in the set - (including training and test)""", + max_samples: int = typer.Option( + -1, + help="Maximum number of samples to use for preprocessing", ), ): + if max_samples == -1: + max_samples = np.inf + preprocess_mesh( - input_path, - train_output_path, - mesh_tags_path=mesh_tags_path, - filter_years=filter_years, - n_max=n_max, + data_path=data_path, + save_loc=save_loc, + model_key=model_key, + test_size=test_size, + num_proc=num_proc, + max_samples=max_samples, ) - - -if __name__ == "__main__": - typer.run(preprocess_mesh) diff --git a/grants_tagger_light/training/dataloaders/__init__.py b/grants_tagger_light/training/dataloaders/__init__.py index 14b74fba..9fef3884 100644 --- a/grants_tagger_light/training/dataloaders/__init__.py +++ b/grants_tagger_light/training/dataloaders/__init__.py @@ -1,4 +1,3 @@ -from .mesh_json_loader import load_mesh_json from .multilabel_collator import MultilabelDataCollator -__all__ = ["load_mesh_json", "MultilabelDataCollator"] +__all__ = ["MultilabelDataCollator"] diff --git a/grants_tagger_light/training/dataloaders/mesh_json_loader.py b/grants_tagger_light/training/dataloaders/mesh_json_loader.py deleted file mode 100644 index 6f870d0e..00000000 --- a/grants_tagger_light/training/dataloaders/mesh_json_loader.py +++ /dev/null @@ -1,94 +0,0 @@ -import json -import numpy as np -from transformers import AutoTokenizer -from datasets import Dataset -from loguru import logger - -# TODO refactor the two load funcs into a class - - -def _tokenize(batch, tokenizer: AutoTokenizer, x_col: str): - return tokenizer( - batch[x_col], - padding="max_length", - truncation=True, - max_length=512, - ) - - -def _encode_labels(sample, label2id: dict): - sample["label_ids"] = [] - for label in sample["meshMajor"]: - try: - sample["label_ids"].append(label2id[label]) - except KeyError: - logger.warning(f"Label {label} not found in label2id") - - return sample - - -def _get_label2id(dset): - label_set = set() - for sample in dset: - for label in sample["meshMajor"]: - label_set.add(label) - label2id = {label: idx for idx, label in enumerate(label_set)} - return label2id - - -def load_mesh_json( - data_path: str, - tokenizer: AutoTokenizer, - label2id: dict, - test_size: float = 0.05, - num_proc: int = 8, - max_samples: int = np.inf, -): - def _datagen(mesh_json_path: str, max_samples: int = np.inf): - with open(mesh_json_path, "r", encoding="latin1") as f: - for idx, line in enumerate(f): - # Skip 1st line - if idx == 0: - continue - sample = json.loads(line[:-2]) - - if idx > max_samples: - break - - yield sample - - dset = Dataset.from_generator( - _datagen, - gen_kwargs={"mesh_json_path": data_path, "max_samples": max_samples}, - ) - - # Remove unused columns to save space & time - dset = dset.remove_columns(["journal", "year", "pmid", "title"]) - - dset = dset.map( - _tokenize, - batched=True, - batch_size=32, - num_proc=num_proc, - desc="Tokenizing", - fn_kwargs={"tokenizer": tokenizer, "x_col": "abstractText"}, - remove_columns=["abstractText"], - ) - - # Generate label2id if None - if label2id is None: - label2id = _get_label2id(dset) - - dset = dset.map( - _encode_labels, - batched=False, - num_proc=num_proc, - desc="Encoding labels", - fn_kwargs={"label2id": label2id}, - remove_columns=["meshMajor"], - ) - - # Split into train and test - dset = dset.train_test_split(test_size=test_size) - - return dset["train"], dset["test"], label2id diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 71f87497..2c4086c1 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -13,17 +13,17 @@ BertMeshModelArguments, ) from grants_tagger_light.training.dataloaders import ( - load_mesh_json, MultilabelDataCollator, ) from sklearn.metrics import classification_report from loguru import logger from pprint import pformat +from datasets import load_dataset import typer import numpy as np import os import transformers - +import json transformers.set_seed(42) @@ -31,7 +31,6 @@ def train_bertmesh( model_key: str, data_path: str, - max_samples: int, training_args: TrainingArguments, model_args: BertMeshModelArguments = None, ): @@ -44,14 +43,13 @@ def train_bertmesh( # Instantiate model from scratch config = AutoConfig.from_pretrained(model_args.pretrained_model_key) - tokenizer = AutoTokenizer.from_pretrained(model_args.pretrained_model_key) + AutoTokenizer.from_pretrained(model_args.pretrained_model_key) - train_dset, val_dset, label2id = load_mesh_json( - data_path, - tokenizer=tokenizer, - label2id=None, - max_samples=max_samples, - ) + dset = load_dataset(os.path.join(data_path, "dataset")) + train_dset, val_dset = dset["train"], dset["test"] + + with open(os.path.join(data_path, "label2id.json"), "r") as f: + label2id = json.load(f) config.update( { @@ -71,16 +69,12 @@ def train_bertmesh( # Instantiate from pretrained model = BertMesh.from_pretrained(model_key, trust_remote_code=True) - tokenizer = AutoTokenizer.from_pretrained(model_key) + AutoTokenizer.from_pretrained(model_key) label2id = {v: k for k, v in model.id2label.items()} - train_dset, val_dset, _ = load_mesh_json( - data_path, - tokenizer=tokenizer, - label2id=label2id, - max_samples=max_samples, - ) + dset = load_dataset(os.path.join(data_path, "dataset")) + train_dset, val_dset = dset["train"], dset["test"] if model_args.freeze_backbone: logger.info("Freezing backbone") From 8a8d52fb9c7d281ef3586cdf6e87df09a0642a07 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 18 Jul 2023 20:00:11 +0300 Subject: [PATCH 061/300] move cache disable inside preproc --- grants_tagger_light/preprocessing/preprocess_mesh.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 15ebb49a..86cb4efc 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -9,7 +9,6 @@ # TODO refactor the two load funcs into a class -disable_caching() preprocess_app = typer.Typer() @@ -50,6 +49,8 @@ def preprocess_mesh( num_proc: int = 8, max_samples: int = np.inf, ): + disable_caching() + if not model_key: label2id = None # Use the same pretrained tokenizer as in Wellcome/WellcomeBertMesh @@ -92,6 +93,7 @@ def _datagen(mesh_json_path: str, max_samples: int = np.inf): desc="Tokenizing", fn_kwargs={"tokenizer": tokenizer, "x_col": "abstractText"}, remove_columns=["abstractText"], + load_from_cache_file=False, ) # Generate label2id if None @@ -105,6 +107,7 @@ def _datagen(mesh_json_path: str, max_samples: int = np.inf): desc="Encoding labels", fn_kwargs={"label2id": label2id}, remove_columns=["meshMajor"], + load_from_cache_file=False, ) # Split into train and test From 7e3c2d3c7d063a6bf5c9e5ad35522496fd129e64 Mon Sep 17 00:00:00 2001 From: Andrei Apostol Date: Tue, 18 Jul 2023 20:04:05 +0300 Subject: [PATCH 062/300] add timings --- grants_tagger_light/preprocessing/preprocess_mesh.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 86cb4efc..ff870769 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -2,6 +2,7 @@ import numpy as np import os import typer +import time from transformers import AutoTokenizer from datasets import Dataset, disable_caching from loguru import logger @@ -77,14 +78,16 @@ def _datagen(mesh_json_path: str, max_samples: int = np.inf): yield sample + t1 = time.time() dset = Dataset.from_generator( _datagen, gen_kwargs={"mesh_json_path": data_path, "max_samples": max_samples}, ) - + logger.info("Time taken to load dataset: {}".format(time.time() - t1)) # Remove unused columns to save space & time dset = dset.remove_columns(["journal", "year", "pmid", "title"]) + t1 = time.time() dset = dset.map( _tokenize, batched=True, @@ -95,11 +98,13 @@ def _datagen(mesh_json_path: str, max_samples: int = np.inf): remove_columns=["abstractText"], load_from_cache_file=False, ) + logger.info("Time taken to tokenize: {}".format(time.time() - t1)) # Generate label2id if None if label2id is None: label2id = _get_label2id(dset) + t1 = time.time() dset = dset.map( _encode_labels, batched=False, @@ -109,15 +114,20 @@ def _datagen(mesh_json_path: str, max_samples: int = np.inf): remove_columns=["meshMajor"], load_from_cache_file=False, ) + logger.info("Time taken to encode labels: {}".format(time.time() - t1)) # Split into train and test + t1 = time.time() dset = dset.train_test_split(test_size=test_size) + logger.info("Time taken to split into train and test: {}".format(time.time() - t1)) # Save to disk + t1 = time.time() dset.save_to_disk(os.path.join(save_loc, "dataset")) with open(os.path.join(save_loc, "label2id.json"), "w") as f: json.dump(label2id, f, indent=4) + logger.info("Time taken to save to disk: {}".format(time.time() - t1)) @preprocess_app.command() From 5ba41fc64302a1924af243e93834f85c5fcc1379 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 19 Jul 2023 20:06:34 +0100 Subject: [PATCH 063/300] Speed improvements on preprocessing --- README.md | 38 +- .../preprocessing/preprocess_mesh.py | 115 +- grants_tagger_light/training/train.py | 17 +- grants_tagger_light/utils/sharding.py | 23 + grants_tagger_light/utils/utils.py | 9 + poetry.lock | 1020 +++++++++-------- 6 files changed, 695 insertions(+), 527 deletions(-) create mode 100644 grants_tagger_light/utils/sharding.py diff --git a/README.md b/README.md index 3c624990..e5c80332 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,9 @@ # Grants Tagger Light 🔖 -Light weight repository for grant tagger model deployment and inference. +Lightweight repository for grant tagger model deployment and inference. Adapted from [the original repository](https://github.com/wellcometrust/grants_tagger) Grants tagger is a machine learning powered tool that -assigns biomedically related tags to grants proposals. +assigns biomedical related tags to grant proposals. Those tags can be custom to the organisation or based upon a preexisting ontology like MeSH. @@ -12,16 +12,16 @@ Wellcome Trust for internal use but both the models and the code will be made available in a reusable manner. This work started as a means to automate the tags of one -funding division within Wellcome but currently it has expanded +funding division within Wellcome, but currently it has expanded into the development and automation of a complete set of tags that can cover past and future directions for the organisation. Science tags refer to the custom tags for the Science funding -division. These tags are higly specific to the research Wellcome -funds so it is not advisable to use them. +division. These tags are highly specific to the research Wellcome +funds, so it is not advisable to use them. MeSH tags are subset of tags from the MeSH ontology that aim to -tags grants according to: +tag grants according to: - diseases - themes of research Those tags are generic enough to be used by other biomedical funders @@ -41,24 +41,30 @@ For GPU-support: `poetry install --with gpu` For training the model, we recommend installing the version of this package with GPU support. -For infenrece, CPU-support should suffice. +For inference, CPU-support should suffice. ## 2. Activate the environment `poetry shell` You now have access to the `grants-tagger` command line interface! +## OPTIONAL: 3. Install MantisNLP `remote` to connect to a remote AWS instances +`pip install git+https://github.com/ivyleavedtoadflax/remote.py.git` +Then add your instance +`remote config add [instance_name]` +And then connect and attach to your machine with a tunnel +`remote connect -p 1234:localhost:1234 -v` # ⌨️ Commands -| Commands | | needs dev | -| --------------- | ------------------------------------------------------------ | --------- | -| ⚙️ preprocess | preprocess data to use for training | False | -| 🔥 train | trains a new model | True | -| 📈 evaluate | evaluate performance of pretrained model | True | -| 🔖 predict | predict tags given a grant abstract using a pretrained model | False | -| 🎛 tune | tune params and threshold | True | -| ⬇️ download | download data from EPMC | False | +| Commands | Description | Needs dev | +| --------------- |--------------------------------------------------------------|-----------| +| ⚙️ preprocess | preprocess data to use for training | False | +| 🔥 train | trains a new model | True | +| 📈 evaluate | evaluate performance of pretrained model | True | +| 🔖 predict | predict tags given a grant abstract using a pretrained model | False | +| 🎛 tune | tune params and threshold | True | +| ⬇️ download | download data from EPMC | False | in square brackets the commands that are not implemented yet @@ -97,7 +103,7 @@ your own data under development. ## 🔥 Train -Train acts as the entry point command for training all models. Currently we only support +Train acts as the entry point command for training all models. Currently, we only support the BertMesh model. The command will train a model and save it to the specified path. ### bertmesh diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 15ebb49a..bb90f5ce 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -1,11 +1,13 @@ import json +import tempfile + import numpy as np -import os import typer from transformers import AutoTokenizer -from datasets import Dataset, disable_caching +from datasets import Dataset, disable_caching, load_dataset from loguru import logger from grants_tagger_light.models.bert_mesh import BertMesh +import os # TODO refactor the two load funcs into a class @@ -22,24 +24,38 @@ def _tokenize(batch, tokenizer: AutoTokenizer, x_col: str): ) -def _encode_labels(sample, label2id: dict): - sample["label_ids"] = [] - for label in sample["meshMajor"]: - try: - sample["label_ids"].append(label2id[label]) - except KeyError: - logger.warning(f"Label {label} not found in label2id") +def _map_label_to_ids(labels, label2id): + return [label2id[label] for label in labels] + + +def _encode_labels(sample, label2id): + return {'label_ids': [_map_label_to_ids(x, label2id) for x in sample['meshMajor']]} + + +def create_tmp_file(jsonl_file, lines): + with open(jsonl_file, 'r') as input_file: + with tempfile.NamedTemporaryFile(mode='w', delete=False) as tmp_file: + for _ in range(lines): + line = input_file.readline() + if not line: + break + tmp_file.write(line) + + return tmp_file.name - return sample +def _datagen(mesh_json_path: str, max_samples: int = np.inf): + with open(mesh_json_path, "r", encoding="latin1") as f: + for idx, line in enumerate(f): + # Skip 1st line + if idx == 0: + continue + sample = json.loads(line[:-2]) -def _get_label2id(dset): - label_set = set() - for sample in dset: - for label in sample["meshMajor"]: - label_set.add(label) - label2id = {label: idx for idx, label in enumerate(label_set)} - return label2id + if idx > max_samples: + break + + yield sample def preprocess_mesh( @@ -49,7 +65,9 @@ def preprocess_mesh( test_size: float = 0.05, num_proc: int = 8, max_samples: int = np.inf, + batch_size: int = 32 ): + print("Downloading tokenizer and model...") if not model_key: label2id = None # Use the same pretrained tokenizer as in Wellcome/WellcomeBertMesh @@ -63,55 +81,72 @@ def preprocess_mesh( label2id = {v: k for k, v in model.id2label.items()} - def _datagen(mesh_json_path: str, max_samples: int = np.inf): - with open(mesh_json_path, "r", encoding="latin1") as f: - for idx, line in enumerate(f): - # Skip 1st line - if idx == 0: - continue - sample = json.loads(line[:-2]) - - if idx > max_samples: - break - - yield sample - + """print("Creating generator...") dset = Dataset.from_generator( _datagen, - gen_kwargs={"mesh_json_path": data_path, "max_samples": max_samples}, - ) + gen_kwargs={"files": files, "max_samples": max_samples}, + num_proc=num_proc + )""" + if max_samples != np.inf: + data_path = create_tmp_file(data_path, max_samples) + dset = load_dataset("json", data_files=data_path, num_proc=num_proc) + if 'train' in dset: + dset = dset['train'] + + print("Removing columns...") # Remove unused columns to save space & time dset = dset.remove_columns(["journal", "year", "pmid", "title"]) + print("Tokenizing with map...") dset = dset.map( _tokenize, batched=True, - batch_size=32, + batch_size=batch_size, num_proc=num_proc, desc="Tokenizing", fn_kwargs={"tokenizer": tokenizer, "x_col": "abstractText"}, remove_columns=["abstractText"], ) + print("Getting label2id...") # Generate label2id if None if label2id is None: - label2id = _get_label2id(dset) + dset = dset.map( + lambda x: {'labels': x["meshMajor"]}, + batched=True, + batch_size=batch_size, + num_proc=num_proc, + desc="Getting labels" + ) + + # Step 1: Get the 'labels' column from the dataset + labels_column = dset['labels'] + # Step 2: Flatten the list column and compute unique values + unique_labels_set = set(label for sublist in labels_column for label in sublist) + + # Step 3: Dictionary creation + label2id = {label: idx for idx, label in enumerate(unique_labels_set)} + + print("Encoding with map...") dset = dset.map( _encode_labels, - batched=False, - num_proc=num_proc, + batched=True, + batch_size=batch_size, desc="Encoding labels", + num_proc=num_proc, fn_kwargs={"label2id": label2id}, - remove_columns=["meshMajor"], + remove_columns=["meshMajor", "labels"], ) + print("Splitting training and test...") # Split into train and test dset = dset.train_test_split(test_size=test_size) + print("Saving to disk...") # Save to disk - dset.save_to_disk(os.path.join(save_loc, "dataset")) + dset.save_to_disk(os.path.join(save_loc, "dataset"), num_proc=num_proc) with open(os.path.join(save_loc, "label2id.json"), "w") as f: json.dump(label2id, f, indent=4) @@ -133,6 +168,9 @@ def preprocess_mesh_cli( -1, help="Maximum number of samples to use for preprocessing", ), + batch_size: int = typer.Option( + 32, + help="Size of the preprocessing batch") ): if max_samples == -1: max_samples = np.inf @@ -144,4 +182,5 @@ def preprocess_mesh_cli( test_size=test_size, num_proc=num_proc, max_samples=max_samples, + batch_size=batch_size ) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 2c4086c1..563c1a84 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -25,12 +25,16 @@ import transformers import json +from grants_tagger_light.utils.sharding import Sharding +from grants_tagger_light.utils.utils import calculate_max_steps + transformers.set_seed(42) def train_bertmesh( model_key: str, data_path: str, + max_samples: int, training_args: TrainingArguments, model_args: BertMeshModelArguments = None, ): @@ -46,7 +50,7 @@ def train_bertmesh( AutoTokenizer.from_pretrained(model_args.pretrained_model_key) dset = load_dataset(os.path.join(data_path, "dataset")) - train_dset, val_dset = dset["train"], dset["test"] + train_dset, val_dset = Sharding(num_shards=100).shard(dset["train"]), dset["test"] with open(os.path.join(data_path, "label2id.json"), "r") as f: label2id = json.load(f) @@ -74,7 +78,8 @@ def train_bertmesh( label2id = {v: k for k, v in model.id2label.items()} dset = load_dataset(os.path.join(data_path, "dataset")) - train_dset, val_dset = dset["train"], dset["test"] + + train_dset, val_dset = Sharding(num_shards=100).shard(dset["train"]), dset["test"] if model_args.freeze_backbone: logger.info("Freezing backbone") @@ -102,13 +107,19 @@ def sklearn_metrics(prediction: EvalPrediction): collator = MultilabelDataCollator(label2id=label2id) + max_steps = calculate_max_steps(training_args, dset) + print("Before:\n" + str(training_args.max_steps)) + training_args.max_steps = max_steps + print("After:\n" + str(training_args.max_steps)) + print(max_steps) + trainer = Trainer( model=model, args=training_args, train_dataset=train_dset, eval_dataset=val_dset, data_collator=collator, - compute_metrics=sklearn_metrics, + compute_metrics=sklearn_metrics ) trainer.train() diff --git a/grants_tagger_light/utils/sharding.py b/grants_tagger_light/utils/sharding.py new file mode 100644 index 00000000..f6e20df3 --- /dev/null +++ b/grants_tagger_light/utils/sharding.py @@ -0,0 +1,23 @@ +from datasets import load_dataset, IterableDataset + + +class Sharding: + def __init__(self, num_shards=100): + """ + Sharding to prevent processing time issues, suggested here https://github.com/huggingface/datasets/issues/2252 + Args: + num_shards: num of shards to split the train dataset into. + """ + self.num_shards = num_shards + + @classmethod + def gen_from_shards(cls, _shards): + for shard in _shards: + for example in shard: + yield example + + def shard(self, dataset): + shards = [dataset.shard(num_shards=self.num_shards, index=index, contiguous=True) + for index in range(self.num_shards)] + + return IterableDataset.from_generator(self.gen_from_shards, gen_kwargs={"_shards": shards}) diff --git a/grants_tagger_light/utils/utils.py b/grants_tagger_light/utils/utils.py index 05aa36ae..fc8f5cfd 100644 --- a/grants_tagger_light/utils/utils.py +++ b/grants_tagger_light/utils/utils.py @@ -177,3 +177,12 @@ def create_label_binarizer(model_path: str, label_binarizer_path: str): f.write(pickle.dumps(label_binarizer)) return label_binarizer + + +def calculate_max_steps(training_args, dset): + """ This is needed when using IterableDatasets, as there is no __len__ in advance since the dataset is a + generator with yield, so it does not know when to end. + Source: https://discuss.huggingface.co/t/streaming-dataset-into-trainer-does-not-implement-len-max-steps-has-to-be-specified/32893/6""" + train_batch_size = training_args.per_device_train_batch_size + accumulation_steps = training_args.gradient_accumulation_steps + return (len(dset["train"]) / train_batch_size) / accumulation_steps diff --git a/poetry.lock b/poetry.lock index 873c01b6..91b4bc19 100644 --- a/poetry.lock +++ b/poetry.lock @@ -30,25 +30,25 @@ testing = ["datasets", "deepspeed", "evaluate", "parameterized", "pytest", "pyte [[package]] name = "aiobotocore" -version = "2.5.0" +version = "2.5.2" description = "Async client for aws services using botocore and aiohttp" optional = false python-versions = ">=3.7" files = [ - {file = "aiobotocore-2.5.0-py3-none-any.whl", hash = "sha256:9a2a022d7b78ec9a2af0de589916d2721cddbf96264401b78d7a73c1a1435f3b"}, - {file = "aiobotocore-2.5.0.tar.gz", hash = "sha256:6a5b397cddd4f81026aa91a14c7dd2650727425740a5af8ba75127ff663faf67"}, + {file = "aiobotocore-2.5.2-py3-none-any.whl", hash = "sha256:337429ffd3cc367532572d40be809a84c7b5335f3f8eca2f23e09dfaa9a9ef90"}, + {file = "aiobotocore-2.5.2.tar.gz", hash = "sha256:e7399f21570db1c287f1c0c814dd3475dfe1c8166722e2c77ce67f172cbcfa89"}, ] [package.dependencies] -aiohttp = ">=3.3.1" -aioitertools = ">=0.5.1" -boto3 = {version = ">=1.26.76,<1.26.77", optional = true, markers = "extra == \"boto3\""} -botocore = ">=1.29.76,<1.29.77" -wrapt = ">=1.10.10" +aiohttp = ">=3.3.1,<4.0.0" +aioitertools = ">=0.5.1,<1.0.0" +boto3 = {version = ">=1.26.161,<1.26.162", optional = true, markers = "extra == \"boto3\""} +botocore = ">=1.29.161,<1.29.162" +wrapt = ">=1.10.10,<2.0.0" [package.extras] -awscli = ["awscli (>=1.27.76,<1.27.77)"] -boto3 = ["boto3 (>=1.26.76,<1.26.77)"] +awscli = ["awscli (>=1.27.161,<1.27.162)"] +boto3 = ["boto3 (>=1.26.161,<1.26.162)"] [[package]] name = "aiohttp" @@ -211,6 +211,17 @@ files = [ [package.dependencies] vine = ">=5.0.0" +[[package]] +name = "annotated-types" +version = "0.5.0" +description = "Reusable constraint types to use with typing.Annotated" +optional = false +python-versions = ">=3.7" +files = [ + {file = "annotated_types-0.5.0-py3-none-any.whl", hash = "sha256:58da39888f92c276ad970249761ebea80ba544b77acddaa1a4d6cf78287d45fd"}, + {file = "annotated_types-0.5.0.tar.gz", hash = "sha256:47cdc3490d9ac1506ce92c7aaa76c579dc3509ff11e098fc867e5130ab7be802"}, +] + [[package]] name = "antlr4-python3-runtime" version = "4.9.3" @@ -223,13 +234,13 @@ files = [ [[package]] name = "anyio" -version = "3.7.0" +version = "3.7.1" description = "High level compatibility layer for multiple asynchronous event loop implementations" optional = false python-versions = ">=3.7" files = [ - {file = "anyio-3.7.0-py3-none-any.whl", hash = "sha256:eddca883c4175f14df8aedce21054bfca3adb70ffe76a9f607aef9d7fa2ea7f0"}, - {file = "anyio-3.7.0.tar.gz", hash = "sha256:275d9973793619a5374e1c89a4f4ad3f4b0a5510a2b5b939444bee8f4c4d37ce"}, + {file = "anyio-3.7.1-py3-none-any.whl", hash = "sha256:91dee416e570e92c64041bd18b900d1d6fa78dff7048769ce5ac5ddad004fbb5"}, + {file = "anyio-3.7.1.tar.gz", hash = "sha256:44a3c9aba0f5defa43261a8b3efb97891f2bd7d804e0e1f56419befa1adfc780"}, ] [package.dependencies] @@ -238,7 +249,7 @@ idna = ">=2.8" sniffio = ">=1.1" [package.extras] -doc = ["Sphinx (>=6.1.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme", "sphinxcontrib-jquery"] +doc = ["Sphinx", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme (>=1.2.2)", "sphinxcontrib-jquery"] test = ["anyio[trio]", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=4)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] trio = ["trio (<0.22)"] @@ -297,13 +308,13 @@ files = [ [[package]] name = "asyncssh" -version = "2.13.1" +version = "2.13.2" description = "AsyncSSH: Asynchronous SSHv2 client and server library" optional = false python-versions = ">= 3.6" files = [ - {file = "asyncssh-2.13.1-py3-none-any.whl", hash = "sha256:c90eb5e2b4f9a7cc6e6af01fd844563d722c0d667f8c5f51fe5b3c2a79fa0575"}, - {file = "asyncssh-2.13.1.tar.gz", hash = "sha256:ebbb83c05c0b45cf230de1ef2f06059e360f9afa5c3ddf60fc92faf7b94ff887"}, + {file = "asyncssh-2.13.2-py3-none-any.whl", hash = "sha256:c7dfe9085c0659acb2ef0d177fb12421e92a20d52b98ab83eed4a5916a1d60cc"}, + {file = "asyncssh-2.13.2.tar.gz", hash = "sha256:991e531c4bb7dbec62b754878d96a3246338aac11a28ce3c3e99018fb2f5828c"}, ] [package.dependencies] @@ -408,36 +419,33 @@ files = [ [[package]] name = "black" -version = "23.3.0" +version = "23.7.0" description = "The uncompromising code formatter." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "black-23.3.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:0945e13506be58bf7db93ee5853243eb368ace1c08a24c65ce108986eac65915"}, - {file = "black-23.3.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:67de8d0c209eb5b330cce2469503de11bca4085880d62f1628bd9972cc3366b9"}, - {file = "black-23.3.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:7c3eb7cea23904399866c55826b31c1f55bbcd3890ce22ff70466b907b6775c2"}, - {file = "black-23.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:32daa9783106c28815d05b724238e30718f34155653d4d6e125dc7daec8e260c"}, - {file = "black-23.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:35d1381d7a22cc5b2be2f72c7dfdae4072a3336060635718cc7e1ede24221d6c"}, - {file = "black-23.3.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:a8a968125d0a6a404842fa1bf0b349a568634f856aa08ffaff40ae0dfa52e7c6"}, - {file = "black-23.3.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:c7ab5790333c448903c4b721b59c0d80b11fe5e9803d8703e84dcb8da56fec1b"}, - {file = "black-23.3.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:a6f6886c9869d4daae2d1715ce34a19bbc4b95006d20ed785ca00fa03cba312d"}, - {file = "black-23.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f3c333ea1dd6771b2d3777482429864f8e258899f6ff05826c3a4fcc5ce3f70"}, - {file = "black-23.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:11c410f71b876f961d1de77b9699ad19f939094c3a677323f43d7a29855fe326"}, - {file = "black-23.3.0-cp37-cp37m-macosx_10_16_x86_64.whl", hash = "sha256:1d06691f1eb8de91cd1b322f21e3bfc9efe0c7ca1f0e1eb1db44ea367dff656b"}, - {file = "black-23.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50cb33cac881766a5cd9913e10ff75b1e8eb71babf4c7104f2e9c52da1fb7de2"}, - {file = "black-23.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:e114420bf26b90d4b9daa597351337762b63039752bdf72bf361364c1aa05925"}, - {file = "black-23.3.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:48f9d345675bb7fbc3dd85821b12487e1b9a75242028adad0333ce36ed2a6d27"}, - {file = "black-23.3.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:714290490c18fb0126baa0fca0a54ee795f7502b44177e1ce7624ba1c00f2331"}, - {file = "black-23.3.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:064101748afa12ad2291c2b91c960be28b817c0c7eaa35bec09cc63aa56493c5"}, - {file = "black-23.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:562bd3a70495facf56814293149e51aa1be9931567474993c7942ff7d3533961"}, - {file = "black-23.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:e198cf27888ad6f4ff331ca1c48ffc038848ea9f031a3b40ba36aced7e22f2c8"}, - {file = "black-23.3.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:3238f2aacf827d18d26db07524e44741233ae09a584273aa059066d644ca7b30"}, - {file = "black-23.3.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:f0bd2f4a58d6666500542b26354978218a9babcdc972722f4bf90779524515f3"}, - {file = "black-23.3.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:92c543f6854c28a3c7f39f4d9b7694f9a6eb9d3c5e2ece488c327b6e7ea9b266"}, - {file = "black-23.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a150542a204124ed00683f0db1f5cf1c2aaaa9cc3495b7a3b5976fb136090ab"}, - {file = "black-23.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:6b39abdfb402002b8a7d030ccc85cf5afff64ee90fa4c5aebc531e3ad0175ddb"}, - {file = "black-23.3.0-py3-none-any.whl", hash = "sha256:ec751418022185b0c1bb7d7736e6933d40bbb14c14a0abcf9123d1b159f98dd4"}, - {file = "black-23.3.0.tar.gz", hash = "sha256:1c7b8d606e728a41ea1ccbd7264677e494e87cf630e399262ced92d4a8dac940"}, + {file = "black-23.7.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:5c4bc552ab52f6c1c506ccae05681fab58c3f72d59ae6e6639e8885e94fe2587"}, + {file = "black-23.7.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:552513d5cd5694590d7ef6f46e1767a4df9af168d449ff767b13b084c020e63f"}, + {file = "black-23.7.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:86cee259349b4448adb4ef9b204bb4467aae74a386bce85d56ba4f5dc0da27be"}, + {file = "black-23.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:501387a9edcb75d7ae8a4412bb8749900386eaef258f1aefab18adddea1936bc"}, + {file = "black-23.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb074d8b213749fa1d077d630db0d5f8cc3b2ae63587ad4116e8a436e9bbe995"}, + {file = "black-23.7.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:b5b0ee6d96b345a8b420100b7d71ebfdd19fab5e8301aff48ec270042cd40ac2"}, + {file = "black-23.7.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:893695a76b140881531062d48476ebe4a48f5d1e9388177e175d76234ca247cd"}, + {file = "black-23.7.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:c333286dc3ddca6fdff74670b911cccedacb4ef0a60b34e491b8a67c833b343a"}, + {file = "black-23.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:831d8f54c3a8c8cf55f64d0422ee875eecac26f5f649fb6c1df65316b67c8926"}, + {file = "black-23.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:7f3bf2dec7d541b4619b8ce526bda74a6b0bffc480a163fed32eb8b3c9aed8ad"}, + {file = "black-23.7.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:f9062af71c59c004cd519e2fb8f5d25d39e46d3af011b41ab43b9c74e27e236f"}, + {file = "black-23.7.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:01ede61aac8c154b55f35301fac3e730baf0c9cf8120f65a9cd61a81cfb4a0c3"}, + {file = "black-23.7.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:327a8c2550ddc573b51e2c352adb88143464bb9d92c10416feb86b0f5aee5ff6"}, + {file = "black-23.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1c6022b86f83b632d06f2b02774134def5d4d4f1dac8bef16d90cda18ba28a"}, + {file = "black-23.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:27eb7a0c71604d5de083757fbdb245b1a4fae60e9596514c6ec497eb63f95320"}, + {file = "black-23.7.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:8417dbd2f57b5701492cd46edcecc4f9208dc75529bcf76c514864e48da867d9"}, + {file = "black-23.7.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:47e56d83aad53ca140da0af87678fb38e44fd6bc0af71eebab2d1f59b1acf1d3"}, + {file = "black-23.7.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:25cc308838fe71f7065df53aedd20327969d05671bac95b38fdf37ebe70ac087"}, + {file = "black-23.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:642496b675095d423f9b8448243336f8ec71c9d4d57ec17bf795b67f08132a91"}, + {file = "black-23.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:ad0014efc7acf0bd745792bd0d8857413652979200ab924fbf239062adc12491"}, + {file = "black-23.7.0-py3-none-any.whl", hash = "sha256:9fd59d418c60c0348505f2ddf9609c1e1de8e7493eab96198fc89d9f865e7a96"}, + {file = "black-23.7.0.tar.gz", hash = "sha256:022a582720b0d9480ed82576c920a8c1dde97cc38ff11d8d8859b3bd6ca9eedb"}, ] [package.dependencies] @@ -456,17 +464,17 @@ uvloop = ["uvloop (>=0.15.2)"] [[package]] name = "boto3" -version = "1.26.76" +version = "1.26.161" description = "The AWS SDK for Python" optional = false python-versions = ">= 3.7" files = [ - {file = "boto3-1.26.76-py3-none-any.whl", hash = "sha256:b4c2969b7677762914394b8273cc1905dfe5b71f250741c1a575487ae357e729"}, - {file = "boto3-1.26.76.tar.gz", hash = "sha256:30c7d967ed1c6b5a05643e42cae9d4d36c3f1cb6782637ddc7007a104cfd9027"}, + {file = "boto3-1.26.161-py3-none-any.whl", hash = "sha256:f66e5c9dbe7f34383bcf64fa6070771355c11a44dd75c7f1279f2f37e1c89183"}, + {file = "boto3-1.26.161.tar.gz", hash = "sha256:662731e464d14af1035f44fc6a46b0e3112ee011ac0a5ed416d205daa3e15f25"}, ] [package.dependencies] -botocore = ">=1.29.76,<1.30.0" +botocore = ">=1.29.161,<1.30.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.6.0,<0.7.0" @@ -475,13 +483,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.29.76" +version = "1.29.161" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">= 3.7" files = [ - {file = "botocore-1.29.76-py3-none-any.whl", hash = "sha256:70735b00cd529f152992231ca6757e458e5ec25db43767b3526e9a35b2f143b7"}, - {file = "botocore-1.29.76.tar.gz", hash = "sha256:c2f67b6b3f8acf2968eafca06526f07b9fb0d27bac4c68a635d51abb675134a7"}, + {file = "botocore-1.29.161-py3-none-any.whl", hash = "sha256:b906999dd53dda2ef0ef6f7f55fcc81a4b06b9f1c8a9f65c546e0b981f959f5f"}, + {file = "botocore-1.29.161.tar.gz", hash = "sha256:a50edd715eb510343e27849f36483804aae4b871590db4d4996aa53368dcac40"}, ] [package.dependencies] @@ -647,97 +655,97 @@ files = [ [[package]] name = "charset-normalizer" -version = "3.1.0" +version = "3.2.0" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." optional = false python-versions = ">=3.7.0" files = [ - {file = "charset-normalizer-3.1.0.tar.gz", hash = "sha256:34e0a2f9c370eb95597aae63bf85eb5e96826d81e3dcf88b8886012906f509b5"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e0ac8959c929593fee38da1c2b64ee9778733cdf03c482c9ff1d508b6b593b2b"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d7fc3fca01da18fbabe4625d64bb612b533533ed10045a2ac3dd194bfa656b60"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:04eefcee095f58eaabe6dc3cc2262f3bcd776d2c67005880894f447b3f2cb9c1"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:20064ead0717cf9a73a6d1e779b23d149b53daf971169289ed2ed43a71e8d3b0"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1435ae15108b1cb6fffbcea2af3d468683b7afed0169ad718451f8db5d1aff6f"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c84132a54c750fda57729d1e2599bb598f5fa0344085dbde5003ba429a4798c0"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75f2568b4189dda1c567339b48cba4ac7384accb9c2a7ed655cd86b04055c795"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:11d3bcb7be35e7b1bba2c23beedac81ee893ac9871d0ba79effc7fc01167db6c"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:891cf9b48776b5c61c700b55a598621fdb7b1e301a550365571e9624f270c203"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:5f008525e02908b20e04707a4f704cd286d94718f48bb33edddc7d7b584dddc1"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:b06f0d3bf045158d2fb8837c5785fe9ff9b8c93358be64461a1089f5da983137"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:49919f8400b5e49e961f320c735388ee686a62327e773fa5b3ce6721f7e785ce"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:22908891a380d50738e1f978667536f6c6b526a2064156203d418f4856d6e86a"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-win32.whl", hash = "sha256:12d1a39aa6b8c6f6248bb54550efcc1c38ce0d8096a146638fd4738e42284448"}, - {file = "charset_normalizer-3.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:65ed923f84a6844de5fd29726b888e58c62820e0769b76565480e1fdc3d062f8"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:9a3267620866c9d17b959a84dd0bd2d45719b817245e49371ead79ed4f710d19"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6734e606355834f13445b6adc38b53c0fd45f1a56a9ba06c2058f86893ae8017"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f8303414c7b03f794347ad062c0516cee0e15f7a612abd0ce1e25caf6ceb47df"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aaf53a6cebad0eae578f062c7d462155eada9c172bd8c4d250b8c1d8eb7f916a"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3dc5b6a8ecfdc5748a7e429782598e4f17ef378e3e272eeb1340ea57c9109f41"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e1b25e3ad6c909f398df8921780d6a3d120d8c09466720226fc621605b6f92b1"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ca564606d2caafb0abe6d1b5311c2649e8071eb241b2d64e75a0d0065107e62"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b82fab78e0b1329e183a65260581de4375f619167478dddab510c6c6fb04d9b6"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:bd7163182133c0c7701b25e604cf1611c0d87712e56e88e7ee5d72deab3e76b5"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:11d117e6c63e8f495412d37e7dc2e2fff09c34b2d09dbe2bee3c6229577818be"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:cf6511efa4801b9b38dc5546d7547d5b5c6ef4b081c60b23e4d941d0eba9cbeb"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:abc1185d79f47c0a7aaf7e2412a0eb2c03b724581139193d2d82b3ad8cbb00ac"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:cb7b2ab0188829593b9de646545175547a70d9a6e2b63bf2cd87a0a391599324"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-win32.whl", hash = "sha256:c36bcbc0d5174a80d6cccf43a0ecaca44e81d25be4b7f90f0ed7bcfbb5a00909"}, - {file = "charset_normalizer-3.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:cca4def576f47a09a943666b8f829606bcb17e2bc2d5911a46c8f8da45f56755"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:0c95f12b74681e9ae127728f7e5409cbbef9cd914d5896ef238cc779b8152373"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fca62a8301b605b954ad2e9c3666f9d97f63872aa4efcae5492baca2056b74ab"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ac0aa6cd53ab9a31d397f8303f92c42f534693528fafbdb997c82bae6e477ad9"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c3af8e0f07399d3176b179f2e2634c3ce9c1301379a6b8c9c9aeecd481da494f"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a5fc78f9e3f501a1614a98f7c54d3969f3ad9bba8ba3d9b438c3bc5d047dd28"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:628c985afb2c7d27a4800bfb609e03985aaecb42f955049957814e0491d4006d"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:74db0052d985cf37fa111828d0dd230776ac99c740e1a758ad99094be4f1803d"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:1e8fcdd8f672a1c4fc8d0bd3a2b576b152d2a349782d1eb0f6b8e52e9954731d"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:04afa6387e2b282cf78ff3dbce20f0cc071c12dc8f685bd40960cc68644cfea6"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:dd5653e67b149503c68c4018bf07e42eeed6b4e956b24c00ccdf93ac79cdff84"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d2686f91611f9e17f4548dbf050e75b079bbc2a82be565832bc8ea9047b61c8c"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-win32.whl", hash = "sha256:4155b51ae05ed47199dc5b2a4e62abccb274cee6b01da5b895099b61b1982974"}, - {file = "charset_normalizer-3.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:322102cdf1ab682ecc7d9b1c5eed4ec59657a65e1c146a0da342b78f4112db23"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:e633940f28c1e913615fd624fcdd72fdba807bf53ea6925d6a588e84e1151531"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:3a06f32c9634a8705f4ca9946d667609f52cf130d5548881401f1eb2c39b1e2c"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7381c66e0561c5757ffe616af869b916c8b4e42b367ab29fedc98481d1e74e14"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3573d376454d956553c356df45bb824262c397c6e26ce43e8203c4c540ee0acb"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e89df2958e5159b811af9ff0f92614dabf4ff617c03a4c1c6ff53bf1c399e0e1"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:78cacd03e79d009d95635e7d6ff12c21eb89b894c354bd2b2ed0b4763373693b"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de5695a6f1d8340b12a5d6d4484290ee74d61e467c39ff03b39e30df62cf83a0"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1c60b9c202d00052183c9be85e5eaf18a4ada0a47d188a83c8f5c5b23252f649"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:f645caaf0008bacf349875a974220f1f1da349c5dbe7c4ec93048cdc785a3326"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:ea9f9c6034ea2d93d9147818f17c2a0860d41b71c38b9ce4d55f21b6f9165a11"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:80d1543d58bd3d6c271b66abf454d437a438dff01c3e62fdbcd68f2a11310d4b"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:73dc03a6a7e30b7edc5b01b601e53e7fc924b04e1835e8e407c12c037e81adbd"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6f5c2e7bc8a4bf7c426599765b1bd33217ec84023033672c1e9a8b35eaeaaaf8"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-win32.whl", hash = "sha256:12a2b561af122e3d94cdb97fe6fb2bb2b82cef0cdca131646fdb940a1eda04f0"}, - {file = "charset_normalizer-3.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:3160a0fd9754aab7d47f95a6b63ab355388d890163eb03b2d2b87ab0a30cfa59"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:38e812a197bf8e71a59fe55b757a84c1f946d0ac114acafaafaf21667a7e169e"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6baf0baf0d5d265fa7944feb9f7451cc316bfe30e8df1a61b1bb08577c554f31"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8f25e17ab3039b05f762b0a55ae0b3632b2e073d9c8fc88e89aca31a6198e88f"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3747443b6a904001473370d7810aa19c3a180ccd52a7157aacc264a5ac79265e"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b116502087ce8a6b7a5f1814568ccbd0e9f6cfd99948aa59b0e241dc57cf739f"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d16fd5252f883eb074ca55cb622bc0bee49b979ae4e8639fff6ca3ff44f9f854"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21fa558996782fc226b529fdd2ed7866c2c6ec91cee82735c98a197fae39f706"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6f6c7a8a57e9405cad7485f4c9d3172ae486cfef1344b5ddd8e5239582d7355e"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ac3775e3311661d4adace3697a52ac0bab17edd166087d493b52d4f4f553f9f0"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:10c93628d7497c81686e8e5e557aafa78f230cd9e77dd0c40032ef90c18f2230"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:6f4f4668e1831850ebcc2fd0b1cd11721947b6dc7c00bf1c6bd3c929ae14f2c7"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:0be65ccf618c1e7ac9b849c315cc2e8a8751d9cfdaa43027d4f6624bd587ab7e"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:53d0a3fa5f8af98a1e261de6a3943ca631c526635eb5817a87a59d9a57ebf48f"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-win32.whl", hash = "sha256:a04f86f41a8916fe45ac5024ec477f41f886b3c435da2d4e3d2709b22ab02af1"}, - {file = "charset_normalizer-3.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:830d2948a5ec37c386d3170c483063798d7879037492540f10a475e3fd6f244b"}, - {file = "charset_normalizer-3.1.0-py3-none-any.whl", hash = "sha256:3d9098b479e78c85080c98e1e35ff40b4a31d8953102bb0fd7d1b6f8a2111a3d"}, + {file = "charset-normalizer-3.2.0.tar.gz", hash = "sha256:3bb3d25a8e6c0aedd251753a79ae98a093c7e7b471faa3aa9a93a81431987ace"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b87549028f680ca955556e3bd57013ab47474c3124dc069faa0b6545b6c9710"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7c70087bfee18a42b4040bb9ec1ca15a08242cf5867c58726530bdf3945672ed"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a103b3a7069b62f5d4890ae1b8f0597618f628b286b03d4bc9195230b154bfa9"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94aea8eff76ee6d1cdacb07dd2123a68283cb5569e0250feab1240058f53b623"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:db901e2ac34c931d73054d9797383d0f8009991e723dab15109740a63e7f902a"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b0dac0ff919ba34d4df1b6131f59ce95b08b9065233446be7e459f95554c0dc8"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:193cbc708ea3aca45e7221ae58f0fd63f933753a9bfb498a3b474878f12caaad"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09393e1b2a9461950b1c9a45d5fd251dc7c6f228acab64da1c9c0165d9c7765c"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:baacc6aee0b2ef6f3d308e197b5d7a81c0e70b06beae1f1fcacffdbd124fe0e3"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:bf420121d4c8dce6b889f0e8e4ec0ca34b7f40186203f06a946fa0276ba54029"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:c04a46716adde8d927adb9457bbe39cf473e1e2c2f5d0a16ceb837e5d841ad4f"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:aaf63899c94de41fe3cf934601b0f7ccb6b428c6e4eeb80da72c58eab077b19a"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d62e51710986674142526ab9f78663ca2b0726066ae26b78b22e0f5e571238dd"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-win32.whl", hash = "sha256:04e57ab9fbf9607b77f7d057974694b4f6b142da9ed4a199859d9d4d5c63fe96"}, + {file = "charset_normalizer-3.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:48021783bdf96e3d6de03a6e39a1171ed5bd7e8bb93fc84cc649d11490f87cea"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4957669ef390f0e6719db3613ab3a7631e68424604a7b448f079bee145da6e09"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:46fb8c61d794b78ec7134a715a3e564aafc8f6b5e338417cb19fe9f57a5a9bf2"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f779d3ad205f108d14e99bb3859aa7dd8e9c68874617c72354d7ecaec2a054ac"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f25c229a6ba38a35ae6e25ca1264621cc25d4d38dca2942a7fce0b67a4efe918"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2efb1bd13885392adfda4614c33d3b68dee4921fd0ac1d3988f8cbb7d589e72a"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f30b48dd7fa1474554b0b0f3fdfdd4c13b5c737a3c6284d3cdc424ec0ffff3a"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:246de67b99b6851627d945db38147d1b209a899311b1305dd84916f2b88526c6"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bd9b3b31adcb054116447ea22caa61a285d92e94d710aa5ec97992ff5eb7cf3"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:8c2f5e83493748286002f9369f3e6607c565a6a90425a3a1fef5ae32a36d749d"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:3170c9399da12c9dc66366e9d14da8bf7147e1e9d9ea566067bbce7bb74bd9c2"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7a4826ad2bd6b07ca615c74ab91f32f6c96d08f6fcc3902ceeedaec8cdc3bcd6"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:3b1613dd5aee995ec6d4c69f00378bbd07614702a315a2cf6c1d21461fe17c23"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9e608aafdb55eb9f255034709e20d5a83b6d60c054df0802fa9c9883d0a937aa"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-win32.whl", hash = "sha256:f2a1d0fd4242bd8643ce6f98927cf9c04540af6efa92323e9d3124f57727bfc1"}, + {file = "charset_normalizer-3.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:681eb3d7e02e3c3655d1b16059fbfb605ac464c834a0c629048a30fad2b27489"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c57921cda3a80d0f2b8aec7e25c8aa14479ea92b5b51b6876d975d925a2ea346"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41b25eaa7d15909cf3ac4c96088c1f266a9a93ec44f87f1d13d4a0e86c81b982"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f058f6963fd82eb143c692cecdc89e075fa0828db2e5b291070485390b2f1c9c"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7647ebdfb9682b7bb97e2a5e7cb6ae735b1c25008a70b906aecca294ee96cf4"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eef9df1eefada2c09a5e7a40991b9fc6ac6ef20b1372abd48d2794a316dc0449"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e03b8895a6990c9ab2cdcd0f2fe44088ca1c65ae592b8f795c3294af00a461c3"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:ee4006268ed33370957f55bf2e6f4d263eaf4dc3cfc473d1d90baff6ed36ce4a"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c4983bf937209c57240cff65906b18bb35e64ae872da6a0db937d7b4af845dd7"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:3bb7fda7260735efe66d5107fb7e6af6a7c04c7fce9b2514e04b7a74b06bf5dd"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:72814c01533f51d68702802d74f77ea026b5ec52793c791e2da806a3844a46c3"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:70c610f6cbe4b9fce272c407dd9d07e33e6bf7b4aa1b7ffb6f6ded8e634e3592"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-win32.whl", hash = "sha256:a401b4598e5d3f4a9a811f3daf42ee2291790c7f9d74b18d75d6e21dda98a1a1"}, + {file = "charset_normalizer-3.2.0-cp37-cp37m-win_amd64.whl", hash = "sha256:c0b21078a4b56965e2b12f247467b234734491897e99c1d51cee628da9786959"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:95eb302ff792e12aba9a8b8f8474ab229a83c103d74a750ec0bd1c1eea32e669"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1a100c6d595a7f316f1b6f01d20815d916e75ff98c27a01ae817439ea7726329"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6339d047dab2780cc6220f46306628e04d9750f02f983ddb37439ca47ced7149"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4b749b9cc6ee664a3300bb3a273c1ca8068c46be705b6c31cf5d276f8628a94"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a38856a971c602f98472050165cea2cdc97709240373041b69030be15047691f"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f87f746ee241d30d6ed93969de31e5ffd09a2961a051e60ae6bddde9ec3583aa"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89f1b185a01fe560bc8ae5f619e924407efca2191b56ce749ec84982fc59a32a"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e1c8a2f4c69e08e89632defbfabec2feb8a8d99edc9f89ce33c4b9e36ab63037"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2f4ac36d8e2b4cc1aa71df3dd84ff8efbe3bfb97ac41242fbcfc053c67434f46"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a386ebe437176aab38c041de1260cd3ea459c6ce5263594399880bbc398225b2"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:ccd16eb18a849fd8dcb23e23380e2f0a354e8daa0c984b8a732d9cfaba3a776d"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:e6a5bf2cba5ae1bb80b154ed68a3cfa2fa00fde979a7f50d6598d3e17d9ac20c"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:45de3f87179c1823e6d9e32156fb14c1927fcc9aba21433f088fdfb555b77c10"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-win32.whl", hash = "sha256:1000fba1057b92a65daec275aec30586c3de2401ccdcd41f8a5c1e2c87078706"}, + {file = "charset_normalizer-3.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:8b2c760cfc7042b27ebdb4a43a4453bd829a5742503599144d54a032c5dc7e9e"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:855eafa5d5a2034b4621c74925d89c5efef61418570e5ef9b37717d9c796419c"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:203f0c8871d5a7987be20c72442488a0b8cfd0f43b7973771640fc593f56321f"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e857a2232ba53ae940d3456f7533ce6ca98b81917d47adc3c7fd55dad8fab858"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e86d77b090dbddbe78867a0275cb4df08ea195e660f1f7f13435a4649e954e5"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4fb39a81950ec280984b3a44f5bd12819953dc5fa3a7e6fa7a80db5ee853952"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2dee8e57f052ef5353cf608e0b4c871aee320dd1b87d351c28764fc0ca55f9f4"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8700f06d0ce6f128de3ccdbc1acaea1ee264d2caa9ca05daaf492fde7c2a7200"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1920d4ff15ce893210c1f0c0e9d19bfbecb7983c76b33f046c13a8ffbd570252"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c1c76a1743432b4b60ab3358c937a3fe1341c828ae6194108a94c69028247f22"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f7560358a6811e52e9c4d142d497f1a6e10103d3a6881f18d04dbce3729c0e2c"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:c8063cf17b19661471ecbdb3df1c84f24ad2e389e326ccaf89e3fb2484d8dd7e"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:cd6dbe0238f7743d0efe563ab46294f54f9bc8f4b9bcf57c3c666cc5bc9d1299"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:1249cbbf3d3b04902ff081ffbb33ce3377fa6e4c7356f759f3cd076cc138d020"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-win32.whl", hash = "sha256:6c409c0deba34f147f77efaa67b8e4bb83d2f11c8806405f76397ae5b8c0d1c9"}, + {file = "charset_normalizer-3.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:7095f6fbfaa55defb6b733cfeb14efaae7a29f0b59d8cf213be4e7ca0b857b80"}, + {file = "charset_normalizer-3.2.0-py3-none-any.whl", hash = "sha256:8e098148dd37b4ce3baca71fb394c81dc5d9c7728c95df695d2dca218edf40e6"}, ] [[package]] name = "click" -version = "8.1.3" +version = "8.1.5" description = "Composable command line interface toolkit" optional = false python-versions = ">=3.7" files = [ - {file = "click-8.1.3-py3-none-any.whl", hash = "sha256:bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48"}, - {file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"}, + {file = "click-8.1.5-py3-none-any.whl", hash = "sha256:e576aa487d679441d7d30abb87e1b43d24fc53bffb8758443b1a9e1cee504548"}, + {file = "click-8.1.5.tar.gz", hash = "sha256:4be4b1af8d665c6d942909916d31a213a106800c47d0eeba73d34da3cbc11367"}, ] [package.dependencies] @@ -862,30 +870,34 @@ six = "*" [[package]] name = "cryptography" -version = "41.0.1" +version = "41.0.2" description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." optional = false python-versions = ">=3.7" files = [ - {file = "cryptography-41.0.1-cp37-abi3-macosx_10_12_universal2.whl", hash = "sha256:f73bff05db2a3e5974a6fd248af2566134d8981fd7ab012e5dd4ddb1d9a70699"}, - {file = "cryptography-41.0.1-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:1a5472d40c8f8e91ff7a3d8ac6dfa363d8e3138b961529c996f3e2df0c7a411a"}, - {file = "cryptography-41.0.1-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7fa01527046ca5facdf973eef2535a27fec4cb651e4daec4d043ef63f6ecd4ca"}, - {file = "cryptography-41.0.1-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b46e37db3cc267b4dea1f56da7346c9727e1209aa98487179ee8ebed09d21e43"}, - {file = "cryptography-41.0.1-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:d198820aba55660b4d74f7b5fd1f17db3aa5eb3e6893b0a41b75e84e4f9e0e4b"}, - {file = "cryptography-41.0.1-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:948224d76c4b6457349d47c0c98657557f429b4e93057cf5a2f71d603e2fc3a3"}, - {file = "cryptography-41.0.1-cp37-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:059e348f9a3c1950937e1b5d7ba1f8e968508ab181e75fc32b879452f08356db"}, - {file = "cryptography-41.0.1-cp37-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:b4ceb5324b998ce2003bc17d519080b4ec8d5b7b70794cbd2836101406a9be31"}, - {file = "cryptography-41.0.1-cp37-abi3-win32.whl", hash = "sha256:8f4ab7021127a9b4323537300a2acfb450124b2def3756f64dc3a3d2160ee4b5"}, - {file = "cryptography-41.0.1-cp37-abi3-win_amd64.whl", hash = "sha256:1fee5aacc7367487b4e22484d3c7e547992ed726d14864ee33c0176ae43b0d7c"}, - {file = "cryptography-41.0.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:9a6c7a3c87d595608a39980ebaa04d5a37f94024c9f24eb7d10262b92f739ddb"}, - {file = "cryptography-41.0.1-pp38-pypy38_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:5d092fdfedaec4cbbffbf98cddc915ba145313a6fdaab83c6e67f4e6c218e6f3"}, - {file = "cryptography-41.0.1-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:1a8e6c2de6fbbcc5e14fd27fb24414507cb3333198ea9ab1258d916f00bc3039"}, - {file = "cryptography-41.0.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:cb33ccf15e89f7ed89b235cff9d49e2e62c6c981a6061c9c8bb47ed7951190bc"}, - {file = "cryptography-41.0.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5f0ff6e18d13a3de56f609dd1fd11470918f770c6bd5d00d632076c727d35485"}, - {file = "cryptography-41.0.1-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:7bfc55a5eae8b86a287747053140ba221afc65eb06207bedf6e019b8934b477c"}, - {file = "cryptography-41.0.1-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:eb8163f5e549a22888c18b0d53d6bb62a20510060a22fd5a995ec8a05268df8a"}, - {file = "cryptography-41.0.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:8dde71c4169ec5ccc1087bb7521d54251c016f126f922ab2dfe6649170a3b8c5"}, - {file = "cryptography-41.0.1.tar.gz", hash = "sha256:d34579085401d3f49762d2f7d6634d6b6c2ae1242202e860f4d26b046e3a1006"}, + {file = "cryptography-41.0.2-cp37-abi3-macosx_10_12_universal2.whl", hash = "sha256:01f1d9e537f9a15b037d5d9ee442b8c22e3ae11ce65ea1f3316a41c78756b711"}, + {file = "cryptography-41.0.2-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:079347de771f9282fbfe0e0236c716686950c19dee1b76240ab09ce1624d76d7"}, + {file = "cryptography-41.0.2-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:439c3cc4c0d42fa999b83ded80a9a1fb54d53c58d6e59234cfe97f241e6c781d"}, + {file = "cryptography-41.0.2-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f14ad275364c8b4e525d018f6716537ae7b6d369c094805cae45300847e0894f"}, + {file = "cryptography-41.0.2-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:84609ade00a6ec59a89729e87a503c6e36af98ddcd566d5f3be52e29ba993182"}, + {file = "cryptography-41.0.2-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:49c3222bb8f8e800aead2e376cbef687bc9e3cb9b58b29a261210456a7783d83"}, + {file = "cryptography-41.0.2-cp37-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:d73f419a56d74fef257955f51b18d046f3506270a5fd2ac5febbfa259d6c0fa5"}, + {file = "cryptography-41.0.2-cp37-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:2a034bf7d9ca894720f2ec1d8b7b5832d7e363571828037f9e0c4f18c1b58a58"}, + {file = "cryptography-41.0.2-cp37-abi3-win32.whl", hash = "sha256:d124682c7a23c9764e54ca9ab5b308b14b18eba02722b8659fb238546de83a76"}, + {file = "cryptography-41.0.2-cp37-abi3-win_amd64.whl", hash = "sha256:9c3fe6534d59d071ee82081ca3d71eed3210f76ebd0361798c74abc2bcf347d4"}, + {file = "cryptography-41.0.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:a719399b99377b218dac6cf547b6ec54e6ef20207b6165126a280b0ce97e0d2a"}, + {file = "cryptography-41.0.2-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:182be4171f9332b6741ee818ec27daff9fb00349f706629f5cbf417bd50e66fd"}, + {file = "cryptography-41.0.2-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:7a9a3bced53b7f09da251685224d6a260c3cb291768f54954e28f03ef14e3766"}, + {file = "cryptography-41.0.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:f0dc40e6f7aa37af01aba07277d3d64d5a03dc66d682097541ec4da03cc140ee"}, + {file = "cryptography-41.0.2-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:674b669d5daa64206c38e507808aae49904c988fa0a71c935e7006a3e1e83831"}, + {file = "cryptography-41.0.2-pp38-pypy38_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:7af244b012711a26196450d34f483357e42aeddb04128885d95a69bd8b14b69b"}, + {file = "cryptography-41.0.2-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:9b6d717393dbae53d4e52684ef4f022444fc1cce3c48c38cb74fca29e1f08eaa"}, + {file = "cryptography-41.0.2-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:192255f539d7a89f2102d07d7375b1e0a81f7478925b3bc2e0549ebf739dae0e"}, + {file = "cryptography-41.0.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f772610fe364372de33d76edcd313636a25684edb94cee53fd790195f5989d14"}, + {file = "cryptography-41.0.2-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:b332cba64d99a70c1e0836902720887fb4529ea49ea7f5462cf6640e095e11d2"}, + {file = "cryptography-41.0.2-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:9a6673c1828db6270b76b22cc696f40cde9043eb90373da5c2f8f2158957f42f"}, + {file = "cryptography-41.0.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:342f3767e25876751e14f8459ad85e77e660537ca0a066e10e75df9c9e9099f0"}, + {file = "cryptography-41.0.2.tar.gz", hash = "sha256:7d230bf856164de164ecb615ccc14c7fc6de6906ddd5b491f3af90d3514c925c"}, ] [package.dependencies] @@ -1004,13 +1016,13 @@ files = [ [[package]] name = "distlib" -version = "0.3.6" +version = "0.3.7" description = "Distribution utilities" optional = false python-versions = "*" files = [ - {file = "distlib-0.3.6-py2.py3-none-any.whl", hash = "sha256:f35c4b692542ca110de7ef0bea44d73981caeb34ca0b9b6b2e6d7790dda8f80e"}, - {file = "distlib-0.3.6.tar.gz", hash = "sha256:14bad2d9b04d3a36127ac97f30b12a19268f211063d8f8ee4f47108896e11b46"}, + {file = "distlib-0.3.7-py2.py3-none-any.whl", hash = "sha256:2e24928bc811348f0feb63014e97aaae3037f2cf48712d51ae61df7fd6075057"}, + {file = "distlib-0.3.7.tar.gz", hash = "sha256:9dafe54b34a028eafd95039d5e5d4851a13734540f1331060d31c9916e7147a8"}, ] [[package]] @@ -1238,13 +1250,13 @@ tests = ["dvc[testing]", "flaky (==3.7.0)", "mypy (==0.910)", "pylint (==2.15.9) [[package]] name = "dvc-objects" -version = "0.23.0" +version = "0.23.1" description = "dvc objects" optional = false python-versions = ">=3.8" files = [ - {file = "dvc-objects-0.23.0.tar.gz", hash = "sha256:9e6a470eb910f3e2bd3156961e4d8617c51be9a464ffce5d8938cb2bc5bfddb3"}, - {file = "dvc_objects-0.23.0-py3-none-any.whl", hash = "sha256:6d4583a56d562cdc4b534790d3963a4401ffa4dde4f50685f3cc4d34c75d2e9b"}, + {file = "dvc-objects-0.23.1.tar.gz", hash = "sha256:159ec9bede7443fbcbc64d33e53071ae51bea86fc82a764ab652655c35b58776"}, + {file = "dvc_objects-0.23.1-py3-none-any.whl", hash = "sha256:118640f4cf83415cd2bf104be39712660e470a9ac061e55ac688f7ab703677c4"}, ] [package.dependencies] @@ -1299,13 +1311,13 @@ tests = ["Pygments (==2.10.0)", "collective.checkdocs (==0.2)", "dvc[testing]", [[package]] name = "dvc-studio-client" -version = "0.10.0" +version = "0.11.0" description = "Small library to post data from DVC/DVCLive to Iterative Studio" optional = false python-versions = ">=3.8" files = [ - {file = "dvc-studio-client-0.10.0.tar.gz", hash = "sha256:1d6c70d7c802e150a61e71c0205555a972898c340f25ab1f759564bd88c0f248"}, - {file = "dvc_studio_client-0.10.0-py3-none-any.whl", hash = "sha256:3b39eb242914b034c3e3634d74ded17cfa2de1d55473c833a511324bae870782"}, + {file = "dvc-studio-client-0.11.0.tar.gz", hash = "sha256:9179acc39bb9acfb54a5369142c835dc2428bd285e41281b005739ce63d9d55b"}, + {file = "dvc_studio_client-0.11.0-py3-none-any.whl", hash = "sha256:2832fe0bdf723dbe51320abcde238bd0a1e1a3befa15c1f05cc9ed2ca25fb39f"}, ] [package.dependencies] @@ -1343,13 +1355,13 @@ tests = ["celery-types (==0.15.0)", "flaky (==3.7.0)", "mypy (==0.971)", "pylint [[package]] name = "exceptiongroup" -version = "1.1.1" +version = "1.1.2" description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" files = [ - {file = "exceptiongroup-1.1.1-py3-none-any.whl", hash = "sha256:232c37c63e4f682982c8b6459f33a8981039e5fb8756b2074364e5055c498c9e"}, - {file = "exceptiongroup-1.1.1.tar.gz", hash = "sha256:d484c3090ba2889ae2928419117447a14daf3c1231d5e30d0aae34f354f01785"}, + {file = "exceptiongroup-1.1.2-py3-none-any.whl", hash = "sha256:e346e69d186172ca7cf029c8c1d16235aa0e04035e5750b4b95039e65204328f"}, + {file = "exceptiongroup-1.1.2.tar.gz", hash = "sha256:12c3e887d6485d16943a309616de20ae5582633e0a2eda17f4e10fd61c1e8af5"}, ] [package.extras] @@ -1386,100 +1398,87 @@ six = ">=1.12,<2.0" [[package]] name = "flufl-lock" -version = "7.1.1" +version = "8.0.1" description = "NFS-safe file locking with timeouts for POSIX and Windows" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "flufl.lock-7.1.1-py3-none-any.whl", hash = "sha256:96d2c0448ba9fd8fc65d5d681ed7217c8e1625149c1c880bba50559bb680a615"}, - {file = "flufl.lock-7.1.1.tar.gz", hash = "sha256:af14172b35bbc58687bd06b70d1693fd8d48cbf0ffde7e51a618c148ae24042d"}, + {file = "flufl_lock-8.0.1-py3-none-any.whl", hash = "sha256:a3df854d76173d59813fdcba91671234b59e2a14db3390793745c77a7bb92d9d"}, + {file = "flufl_lock-8.0.1.tar.gz", hash = "sha256:edb7f1f3f8b4805ef6a6a23b9a3975bfc9b7c15eb33e10b0b086d0caa2a97e04"}, ] [package.dependencies] -atpublic = ">=2.3" -psutil = ">=5.9.0" +atpublic = "*" +psutil = "*" [[package]] name = "frozenlist" -version = "1.3.3" +version = "1.4.0" description = "A list-like structure which implements collections.abc.MutableSequence" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "frozenlist-1.3.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ff8bf625fe85e119553b5383ba0fb6aa3d0ec2ae980295aaefa552374926b3f4"}, - {file = "frozenlist-1.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dfbac4c2dfcc082fcf8d942d1e49b6aa0766c19d3358bd86e2000bf0fa4a9cf0"}, - {file = "frozenlist-1.3.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b1c63e8d377d039ac769cd0926558bb7068a1f7abb0f003e3717ee003ad85530"}, - {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7fdfc24dcfce5b48109867c13b4cb15e4660e7bd7661741a391f821f23dfdca7"}, - {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2c926450857408e42f0bbc295e84395722ce74bae69a3b2aa2a65fe22cb14b99"}, - {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1841e200fdafc3d51f974d9d377c079a0694a8f06de2e67b48150328d66d5483"}, - {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f470c92737afa7d4c3aacc001e335062d582053d4dbe73cda126f2d7031068dd"}, - {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:783263a4eaad7c49983fe4b2e7b53fa9770c136c270d2d4bbb6d2192bf4d9caf"}, - {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:924620eef691990dfb56dc4709f280f40baee568c794b5c1885800c3ecc69816"}, - {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:ae4dc05c465a08a866b7a1baf360747078b362e6a6dbeb0c57f234db0ef88ae0"}, - {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:bed331fe18f58d844d39ceb398b77d6ac0b010d571cba8267c2e7165806b00ce"}, - {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:02c9ac843e3390826a265e331105efeab489ffaf4dd86384595ee8ce6d35ae7f"}, - {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9545a33965d0d377b0bc823dcabf26980e77f1b6a7caa368a365a9497fb09420"}, - {file = "frozenlist-1.3.3-cp310-cp310-win32.whl", hash = "sha256:d5cd3ab21acbdb414bb6c31958d7b06b85eeb40f66463c264a9b343a4e238642"}, - {file = "frozenlist-1.3.3-cp310-cp310-win_amd64.whl", hash = "sha256:b756072364347cb6aa5b60f9bc18e94b2f79632de3b0190253ad770c5df17db1"}, - {file = "frozenlist-1.3.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:b4395e2f8d83fbe0c627b2b696acce67868793d7d9750e90e39592b3626691b7"}, - {file = "frozenlist-1.3.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:14143ae966a6229350021384870458e4777d1eae4c28d1a7aa47f24d030e6678"}, - {file = "frozenlist-1.3.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5d8860749e813a6f65bad8285a0520607c9500caa23fea6ee407e63debcdbef6"}, - {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23d16d9f477bb55b6154654e0e74557040575d9d19fe78a161bd33d7d76808e8"}, - {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eb82dbba47a8318e75f679690190c10a5e1f447fbf9df41cbc4c3afd726d88cb"}, - {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9309869032abb23d196cb4e4db574232abe8b8be1339026f489eeb34a4acfd91"}, - {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a97b4fe50b5890d36300820abd305694cb865ddb7885049587a5678215782a6b"}, - {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c188512b43542b1e91cadc3c6c915a82a5eb95929134faf7fd109f14f9892ce4"}, - {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:303e04d422e9b911a09ad499b0368dc551e8c3cd15293c99160c7f1f07b59a48"}, - {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:0771aed7f596c7d73444c847a1c16288937ef988dc04fb9f7be4b2aa91db609d"}, - {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:66080ec69883597e4d026f2f71a231a1ee9887835902dbe6b6467d5a89216cf6"}, - {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:41fe21dc74ad3a779c3d73a2786bdf622ea81234bdd4faf90b8b03cad0c2c0b4"}, - {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f20380df709d91525e4bee04746ba612a4df0972c1b8f8e1e8af997e678c7b81"}, - {file = "frozenlist-1.3.3-cp311-cp311-win32.whl", hash = "sha256:f30f1928162e189091cf4d9da2eac617bfe78ef907a761614ff577ef4edfb3c8"}, - {file = "frozenlist-1.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:a6394d7dadd3cfe3f4b3b186e54d5d8504d44f2d58dcc89d693698e8b7132b32"}, - {file = "frozenlist-1.3.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8df3de3a9ab8325f94f646609a66cbeeede263910c5c0de0101079ad541af332"}, - {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0693c609e9742c66ba4870bcee1ad5ff35462d5ffec18710b4ac89337ff16e27"}, - {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cd4210baef299717db0a600d7a3cac81d46ef0e007f88c9335db79f8979c0d3d"}, - {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:394c9c242113bfb4b9aa36e2b80a05ffa163a30691c7b5a29eba82e937895d5e"}, - {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6327eb8e419f7d9c38f333cde41b9ae348bec26d840927332f17e887a8dcb70d"}, - {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e24900aa13212e75e5b366cb9065e78bbf3893d4baab6052d1aca10d46d944c"}, - {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:3843f84a6c465a36559161e6c59dce2f2ac10943040c2fd021cfb70d58c4ad56"}, - {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:84610c1502b2461255b4c9b7d5e9c48052601a8957cd0aea6ec7a7a1e1fb9420"}, - {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:c21b9aa40e08e4f63a2f92ff3748e6b6c84d717d033c7b3438dd3123ee18f70e"}, - {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:efce6ae830831ab6a22b9b4091d411698145cb9b8fc869e1397ccf4b4b6455cb"}, - {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:40de71985e9042ca00b7953c4f41eabc3dc514a2d1ff534027f091bc74416401"}, - {file = "frozenlist-1.3.3-cp37-cp37m-win32.whl", hash = "sha256:180c00c66bde6146a860cbb81b54ee0df350d2daf13ca85b275123bbf85de18a"}, - {file = "frozenlist-1.3.3-cp37-cp37m-win_amd64.whl", hash = "sha256:9bbbcedd75acdfecf2159663b87f1bb5cfc80e7cd99f7ddd9d66eb98b14a8411"}, - {file = "frozenlist-1.3.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:034a5c08d36649591be1cbb10e09da9f531034acfe29275fc5454a3b101ce41a"}, - {file = "frozenlist-1.3.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ba64dc2b3b7b158c6660d49cdb1d872d1d0bf4e42043ad8d5006099479a194e5"}, - {file = "frozenlist-1.3.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:47df36a9fe24054b950bbc2db630d508cca3aa27ed0566c0baf661225e52c18e"}, - {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:008a054b75d77c995ea26629ab3a0c0d7281341f2fa7e1e85fa6153ae29ae99c"}, - {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:841ea19b43d438a80b4de62ac6ab21cfe6827bb8a9dc62b896acc88eaf9cecba"}, - {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e235688f42b36be2b6b06fc37ac2126a73b75fb8d6bc66dd632aa35286238703"}, - {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca713d4af15bae6e5d79b15c10c8522859a9a89d3b361a50b817c98c2fb402a2"}, - {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ac5995f2b408017b0be26d4a1d7c61bce106ff3d9e3324374d66b5964325448"}, - {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:a4ae8135b11652b08a8baf07631d3ebfe65a4c87909dbef5fa0cdde440444ee4"}, - {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:4ea42116ceb6bb16dbb7d526e242cb6747b08b7710d9782aa3d6732bd8d27649"}, - {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:810860bb4bdce7557bc0febb84bbd88198b9dbc2022d8eebe5b3590b2ad6c842"}, - {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:ee78feb9d293c323b59a6f2dd441b63339a30edf35abcb51187d2fc26e696d13"}, - {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:0af2e7c87d35b38732e810befb9d797a99279cbb85374d42ea61c1e9d23094b3"}, - {file = "frozenlist-1.3.3-cp38-cp38-win32.whl", hash = "sha256:899c5e1928eec13fd6f6d8dc51be23f0d09c5281e40d9cf4273d188d9feeaf9b"}, - {file = "frozenlist-1.3.3-cp38-cp38-win_amd64.whl", hash = "sha256:7f44e24fa70f6fbc74aeec3e971f60a14dde85da364aa87f15d1be94ae75aeef"}, - {file = "frozenlist-1.3.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:2b07ae0c1edaa0a36339ec6cce700f51b14a3fc6545fdd32930d2c83917332cf"}, - {file = "frozenlist-1.3.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ebb86518203e12e96af765ee89034a1dbb0c3c65052d1b0c19bbbd6af8a145e1"}, - {file = "frozenlist-1.3.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5cf820485f1b4c91e0417ea0afd41ce5cf5965011b3c22c400f6d144296ccbc0"}, - {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c11e43016b9024240212d2a65043b70ed8dfd3b52678a1271972702d990ac6d"}, - {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8fa3c6e3305aa1146b59a09b32b2e04074945ffcfb2f0931836d103a2c38f936"}, - {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:352bd4c8c72d508778cf05ab491f6ef36149f4d0cb3c56b1b4302852255d05d5"}, - {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:65a5e4d3aa679610ac6e3569e865425b23b372277f89b5ef06cf2cdaf1ebf22b"}, - {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1e2c1185858d7e10ff045c496bbf90ae752c28b365fef2c09cf0fa309291669"}, - {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f163d2fd041c630fed01bc48d28c3ed4a3b003c00acd396900e11ee5316b56bb"}, - {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:05cdb16d09a0832eedf770cb7bd1fe57d8cf4eaf5aced29c4e41e3f20b30a784"}, - {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:8bae29d60768bfa8fb92244b74502b18fae55a80eac13c88eb0b496d4268fd2d"}, - {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:eedab4c310c0299961ac285591acd53dc6723a1ebd90a57207c71f6e0c2153ab"}, - {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3bbdf44855ed8f0fbcd102ef05ec3012d6a4fd7c7562403f76ce6a52aeffb2b1"}, - {file = "frozenlist-1.3.3-cp39-cp39-win32.whl", hash = "sha256:efa568b885bca461f7c7b9e032655c0c143d305bf01c30caf6db2854a4532b38"}, - {file = "frozenlist-1.3.3-cp39-cp39-win_amd64.whl", hash = "sha256:cfe33efc9cb900a4c46f91a5ceba26d6df370ffddd9ca386eb1d4f0ad97b9ea9"}, - {file = "frozenlist-1.3.3.tar.gz", hash = "sha256:58bcc55721e8a90b88332d6cd441261ebb22342e238296bb330968952fbb3a6a"}, + {file = "frozenlist-1.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:764226ceef3125e53ea2cb275000e309c0aa5464d43bd72abd661e27fffc26ab"}, + {file = "frozenlist-1.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d6484756b12f40003c6128bfcc3fa9f0d49a687e171186c2d85ec82e3758c559"}, + {file = "frozenlist-1.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9ac08e601308e41eb533f232dbf6b7e4cea762f9f84f6357136eed926c15d12c"}, + {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d081f13b095d74b67d550de04df1c756831f3b83dc9881c38985834387487f1b"}, + {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:71932b597f9895f011f47f17d6428252fc728ba2ae6024e13c3398a087c2cdea"}, + {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:981b9ab5a0a3178ff413bca62526bb784249421c24ad7381e39d67981be2c326"}, + {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e41f3de4df3e80de75845d3e743b3f1c4c8613c3997a912dbf0229fc61a8b963"}, + {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6918d49b1f90821e93069682c06ffde41829c346c66b721e65a5c62b4bab0300"}, + {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0e5c8764c7829343d919cc2dfc587a8db01c4f70a4ebbc49abde5d4b158b007b"}, + {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8d0edd6b1c7fb94922bf569c9b092ee187a83f03fb1a63076e7774b60f9481a8"}, + {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e29cda763f752553fa14c68fb2195150bfab22b352572cb36c43c47bedba70eb"}, + {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:0c7c1b47859ee2cac3846fde1c1dc0f15da6cec5a0e5c72d101e0f83dcb67ff9"}, + {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:901289d524fdd571be1c7be054f48b1f88ce8dddcbdf1ec698b27d4b8b9e5d62"}, + {file = "frozenlist-1.4.0-cp310-cp310-win32.whl", hash = "sha256:1a0848b52815006ea6596c395f87449f693dc419061cc21e970f139d466dc0a0"}, + {file = "frozenlist-1.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:b206646d176a007466358aa21d85cd8600a415c67c9bd15403336c331a10d956"}, + {file = "frozenlist-1.4.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:de343e75f40e972bae1ef6090267f8260c1446a1695e77096db6cfa25e759a95"}, + {file = "frozenlist-1.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ad2a9eb6d9839ae241701d0918f54c51365a51407fd80f6b8289e2dfca977cc3"}, + {file = "frozenlist-1.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bd7bd3b3830247580de99c99ea2a01416dfc3c34471ca1298bccabf86d0ff4dc"}, + {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bdf1847068c362f16b353163391210269e4f0569a3c166bc6a9f74ccbfc7e839"}, + {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:38461d02d66de17455072c9ba981d35f1d2a73024bee7790ac2f9e361ef1cd0c"}, + {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5a32087d720c608f42caed0ef36d2b3ea61a9d09ee59a5142d6070da9041b8f"}, + {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dd65632acaf0d47608190a71bfe46b209719bf2beb59507db08ccdbe712f969b"}, + {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:261b9f5d17cac914531331ff1b1d452125bf5daa05faf73b71d935485b0c510b"}, + {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:b89ac9768b82205936771f8d2eb3ce88503b1556324c9f903e7156669f521472"}, + {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:008eb8b31b3ea6896da16c38c1b136cb9fec9e249e77f6211d479db79a4eaf01"}, + {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e74b0506fa5aa5598ac6a975a12aa8928cbb58e1f5ac8360792ef15de1aa848f"}, + {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:490132667476f6781b4c9458298b0c1cddf237488abd228b0b3650e5ecba7467"}, + {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:76d4711f6f6d08551a7e9ef28c722f4a50dd0fc204c56b4bcd95c6cc05ce6fbb"}, + {file = "frozenlist-1.4.0-cp311-cp311-win32.whl", hash = "sha256:a02eb8ab2b8f200179b5f62b59757685ae9987996ae549ccf30f983f40602431"}, + {file = "frozenlist-1.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:515e1abc578dd3b275d6a5114030b1330ba044ffba03f94091842852f806f1c1"}, + {file = "frozenlist-1.4.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:f0ed05f5079c708fe74bf9027e95125334b6978bf07fd5ab923e9e55e5fbb9d3"}, + {file = "frozenlist-1.4.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ca265542ca427bf97aed183c1676e2a9c66942e822b14dc6e5f42e038f92a503"}, + {file = "frozenlist-1.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:491e014f5c43656da08958808588cc6c016847b4360e327a62cb308c791bd2d9"}, + {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:17ae5cd0f333f94f2e03aaf140bb762c64783935cc764ff9c82dff626089bebf"}, + {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1e78fb68cf9c1a6aa4a9a12e960a5c9dfbdb89b3695197aa7064705662515de2"}, + {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5655a942f5f5d2c9ed93d72148226d75369b4f6952680211972a33e59b1dfdc"}, + {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c11b0746f5d946fecf750428a95f3e9ebe792c1ee3b1e96eeba145dc631a9672"}, + {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e66d2a64d44d50d2543405fb183a21f76b3b5fd16f130f5c99187c3fb4e64919"}, + {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:88f7bc0fcca81f985f78dd0fa68d2c75abf8272b1f5c323ea4a01a4d7a614efc"}, + {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:5833593c25ac59ede40ed4de6d67eb42928cca97f26feea219f21d0ed0959b79"}, + {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:fec520865f42e5c7f050c2a79038897b1c7d1595e907a9e08e3353293ffc948e"}, + {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:b826d97e4276750beca7c8f0f1a4938892697a6bcd8ec8217b3312dad6982781"}, + {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:ceb6ec0a10c65540421e20ebd29083c50e6d1143278746a4ef6bcf6153171eb8"}, + {file = "frozenlist-1.4.0-cp38-cp38-win32.whl", hash = "sha256:2b8bcf994563466db019fab287ff390fffbfdb4f905fc77bc1c1d604b1c689cc"}, + {file = "frozenlist-1.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:a6c8097e01886188e5be3e6b14e94ab365f384736aa1fca6a0b9e35bd4a30bc7"}, + {file = "frozenlist-1.4.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:6c38721585f285203e4b4132a352eb3daa19121a035f3182e08e437cface44bf"}, + {file = "frozenlist-1.4.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a0c6da9aee33ff0b1a451e867da0c1f47408112b3391dd43133838339e410963"}, + {file = "frozenlist-1.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:93ea75c050c5bb3d98016b4ba2497851eadf0ac154d88a67d7a6816206f6fa7f"}, + {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f61e2dc5ad442c52b4887f1fdc112f97caeff4d9e6ebe78879364ac59f1663e1"}, + {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa384489fefeb62321b238e64c07ef48398fe80f9e1e6afeff22e140e0850eef"}, + {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:10ff5faaa22786315ef57097a279b833ecab1a0bfb07d604c9cbb1c4cdc2ed87"}, + {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:007df07a6e3eb3e33e9a1fe6a9db7af152bbd8a185f9aaa6ece10a3529e3e1c6"}, + {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f4f399d28478d1f604c2ff9119907af9726aed73680e5ed1ca634d377abb087"}, + {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c5374b80521d3d3f2ec5572e05adc94601985cc526fb276d0c8574a6d749f1b3"}, + {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:ce31ae3e19f3c902de379cf1323d90c649425b86de7bbdf82871b8a2a0615f3d"}, + {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:7211ef110a9194b6042449431e08c4d80c0481e5891e58d429df5899690511c2"}, + {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:556de4430ce324c836789fa4560ca62d1591d2538b8ceb0b4f68fb7b2384a27a"}, + {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:7645a8e814a3ee34a89c4a372011dcd817964ce8cb273c8ed6119d706e9613e3"}, + {file = "frozenlist-1.4.0-cp39-cp39-win32.whl", hash = "sha256:19488c57c12d4e8095a922f328df3f179c820c212940a498623ed39160bc3c2f"}, + {file = "frozenlist-1.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:6221d84d463fb110bdd7619b69cb43878a11d51cbb9394ae3105d082d5199167"}, + {file = "frozenlist-1.4.0.tar.gz", hash = "sha256:09163bdf0b2907454042edb19f887c6d33806adc71fbd54afc14908bfdc22251"}, ] [[package]] @@ -1548,13 +1547,13 @@ smmap = ">=3.0.1,<6" [[package]] name = "gitpython" -version = "3.1.31" +version = "3.1.32" description = "GitPython is a Python library used to interact with Git repositories" optional = false python-versions = ">=3.7" files = [ - {file = "GitPython-3.1.31-py3-none-any.whl", hash = "sha256:f04893614f6aa713a60cbbe1e6a97403ef633103cdd0ef5eb6efe0deb98dbe8d"}, - {file = "GitPython-3.1.31.tar.gz", hash = "sha256:8ce3bcf69adfdf7c7d503e78fd3b1c492af782d58893b650adb2ac8912ddd573"}, + {file = "GitPython-3.1.32-py3-none-any.whl", hash = "sha256:e3d59b1c2c6ebb9dfa7a184daf3b6dd4914237e7488a1730a6d8f6f5d0b4187f"}, + {file = "GitPython-3.1.32.tar.gz", hash = "sha256:8d9b8cb1e80b9735e8717c9362079d3ce4c6e5ddeebedd0361b228c3a67a62f6"}, ] [package.dependencies] @@ -1634,13 +1633,13 @@ socks = ["socksio (==1.*)"] [[package]] name = "huggingface-hub" -version = "0.15.1" +version = "0.16.4" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false python-versions = ">=3.7.0" files = [ - {file = "huggingface_hub-0.15.1-py3-none-any.whl", hash = "sha256:05b0fb0abbf1f625dfee864648ac3049fe225ac4371c7bafaca0c2d3a2f83445"}, - {file = "huggingface_hub-0.15.1.tar.gz", hash = "sha256:a61b7d1a7769fe10119e730277c72ab99d95c48d86a3d6da3e9f3d0f632a4081"}, + {file = "huggingface_hub-0.16.4-py3-none-any.whl", hash = "sha256:0d3df29932f334fead024afc7cb4cc5149d955238b8b5e42dcf9740d6995a349"}, + {file = "huggingface_hub-0.16.4.tar.gz", hash = "sha256:608c7d4f3d368b326d1747f91523dbd1f692871e8e2e7a4750314a2dd8b63e14"}, ] [package.dependencies] @@ -1653,15 +1652,16 @@ tqdm = ">=4.42.1" typing-extensions = ">=3.7.4.3" [package.extras] -all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "numpy", "pytest", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"] +all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"] cli = ["InquirerPy (==0.3.4)"] -dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "numpy", "pytest", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"] +dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"] fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] +inference = ["aiohttp", "pydantic"] quality = ["black (>=23.1,<24.0)", "mypy (==0.982)", "ruff (>=0.0.241)"] tensorflow = ["graphviz", "pydot", "tensorflow"] -testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "gradio", "jedi", "numpy", "pytest", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] +testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] torch = ["torch"] -typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"] +typing = ["pydantic", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"] [[package]] name = "hydra-core" @@ -1783,13 +1783,13 @@ files = [ [[package]] name = "joblib" -version = "1.2.0" +version = "1.3.1" description = "Lightweight pipelining with Python functions" optional = false python-versions = ">=3.7" files = [ - {file = "joblib-1.2.0-py3-none-any.whl", hash = "sha256:091138ed78f800342968c523bdde947e7a305b8594b910a0fea2ab83c3c6d385"}, - {file = "joblib-1.2.0.tar.gz", hash = "sha256:e1cee4a79e4af22881164f218d4311f60074197fb707e082e803b61f6d137018"}, + {file = "joblib-1.3.1-py3-none-any.whl", hash = "sha256:89cf0529520e01b3de7ac7b74a8102c90d16d54c64b5dd98cafcd14307fdf915"}, + {file = "joblib-1.3.1.tar.gz", hash = "sha256:1f937906df65329ba98013dc9692fe22a4c5e4a648112de500508b18a21b41e3"}, ] [[package]] @@ -2180,57 +2180,57 @@ PyYAML = ">=5.1.0" [[package]] name = "orjson" -version = "3.9.1" +version = "3.9.2" description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy" optional = false python-versions = ">=3.7" files = [ - {file = "orjson-3.9.1-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:c4434b7b786fdc394b95d029fb99949d7c2b05bbd4bf5cb5e3906be96ffeee3b"}, - {file = "orjson-3.9.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09faf14f74ed47e773fa56833be118e04aa534956f661eb491522970b7478e3b"}, - {file = "orjson-3.9.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:503eb86a8d53a187fe66aa80c69295a3ca35475804da89a9547e4fce5f803822"}, - {file = "orjson-3.9.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:20f2804b5a1dbd3609c086041bd243519224d47716efd7429db6c03ed28b7cc3"}, - {file = "orjson-3.9.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0fd828e0656615a711c4cc4da70f3cac142e66a6703ba876c20156a14e28e3fa"}, - {file = "orjson-3.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec53d648176f873203b9c700a0abacab33ca1ab595066e9d616f98cdc56f4434"}, - {file = "orjson-3.9.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e186ae76b0d97c505500664193ddf508c13c1e675d9b25f1f4414a7606100da6"}, - {file = "orjson-3.9.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d4edee78503016f4df30aeede0d999b3cb11fb56f47e9db0e487bce0aaca9285"}, - {file = "orjson-3.9.1-cp310-none-win_amd64.whl", hash = "sha256:a4cc5d21e68af982d9a2528ac61e604f092c60eed27aef3324969c68f182ec7e"}, - {file = "orjson-3.9.1-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:761b6efd33c49de20dd73ce64cc59da62c0dab10aa6015f582680e0663cc792c"}, - {file = "orjson-3.9.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:31229f9d0b8dc2ef7ee7e4393f2e4433a28e16582d4b25afbfccc9d68dc768f8"}, - {file = "orjson-3.9.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0b7ab18d55ecb1de543d452f0a5f8094b52282b916aa4097ac11a4c79f317b86"}, - {file = "orjson-3.9.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:db774344c39041f4801c7dfe03483df9203cbd6c84e601a65908e5552228dd25"}, - {file = "orjson-3.9.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ae47ef8c0fe89c4677db7e9e1fb2093ca6e66c3acbee5442d84d74e727edad5e"}, - {file = "orjson-3.9.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:103952c21575b9805803c98add2eaecd005580a1e746292ed2ec0d76dd3b9746"}, - {file = "orjson-3.9.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2cb0121e6f2c9da3eddf049b99b95fef0adf8480ea7cb544ce858706cdf916eb"}, - {file = "orjson-3.9.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:24d4ddaa2876e657c0fd32902b5c451fd2afc35159d66a58da7837357044b8c2"}, - {file = "orjson-3.9.1-cp311-none-win_amd64.whl", hash = "sha256:0b53b5f72cf536dd8aa4fc4c95e7e09a7adb119f8ff8ee6cc60f735d7740ad6a"}, - {file = "orjson-3.9.1-cp37-cp37m-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:d4b68d01a506242316a07f1d2f29fb0a8b36cee30a7c35076f1ef59dce0890c1"}, - {file = "orjson-3.9.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9dd4abe6c6fd352f00f4246d85228f6a9847d0cc14f4d54ee553718c225388f"}, - {file = "orjson-3.9.1-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9e20bca5e13041e31ceba7a09bf142e6d63c8a7467f5a9c974f8c13377c75af2"}, - {file = "orjson-3.9.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d8ae0467d01eb1e4bcffef4486d964bfd1c2e608103e75f7074ed34be5df48cc"}, - {file = "orjson-3.9.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:06f6ab4697fab090517f295915318763a97a12ee8186054adf21c1e6f6abbd3d"}, - {file = "orjson-3.9.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8515867713301fa065c58ec4c9053ba1a22c35113ab4acad555317b8fd802e50"}, - {file = "orjson-3.9.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:393d0697d1dfa18d27d193e980c04fdfb672c87f7765b87952f550521e21b627"}, - {file = "orjson-3.9.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d96747662d3666f79119e5d28c124e7d356c7dc195cd4b09faea4031c9079dc9"}, - {file = "orjson-3.9.1-cp37-none-win_amd64.whl", hash = "sha256:6d173d3921dd58a068c88ec22baea7dbc87a137411501618b1292a9d6252318e"}, - {file = "orjson-3.9.1-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:d1c2b0b4246c992ce2529fc610a446b945f1429445ece1c1f826a234c829a918"}, - {file = "orjson-3.9.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:19f70ba1f441e1c4bb1a581f0baa092e8b3e3ce5b2aac2e1e090f0ac097966da"}, - {file = "orjson-3.9.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:375d65f002e686212aac42680aed044872c45ee4bc656cf63d4a215137a6124a"}, - {file = "orjson-3.9.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4751cee4a7b1daeacb90a7f5adf2170ccab893c3ab7c5cea58b45a13f89b30b3"}, - {file = "orjson-3.9.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:78d9a2a4b2302d5ebc3695498ebc305c3568e5ad4f3501eb30a6405a32d8af22"}, - {file = "orjson-3.9.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46b4facc32643b2689dfc292c0c463985dac4b6ab504799cf51fc3c6959ed668"}, - {file = "orjson-3.9.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ec7c8a0f1bf35da0d5fd14f8956f3b82a9a6918a3c6963d718dfd414d6d3b604"}, - {file = "orjson-3.9.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d3a40b0fbe06ccd4d6a99e523d20b47985655bcada8d1eba485b1b32a43e4904"}, - {file = "orjson-3.9.1-cp38-none-win_amd64.whl", hash = "sha256:402f9d3edfec4560a98880224ec10eba4c5f7b4791e4bc0d4f4d8df5faf2a006"}, - {file = "orjson-3.9.1-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:49c0d78dcd34626e2e934f1192d7c052b94e0ecadc5f386fd2bda6d2e03dadf5"}, - {file = "orjson-3.9.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:125f63e56d38393daa0a1a6dc6fedefca16c538614b66ea5997c3bd3af35ef26"}, - {file = "orjson-3.9.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:08927970365d2e1f3ce4894f9ff928a7b865d53f26768f1bbdd85dd4fee3e966"}, - {file = "orjson-3.9.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f9a744e212d4780ecd67f4b6b128b2e727bee1df03e7059cddb2dfe1083e7dc4"}, - {file = "orjson-3.9.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d1dbf36db7240c61eec98c8d21545d671bce70be0730deb2c0d772e06b71af3"}, - {file = "orjson-3.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80a1e384626f76b66df615f7bb622a79a25c166d08c5d2151ffd41f24c4cc104"}, - {file = "orjson-3.9.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:15d28872fb055bf17ffca913826e618af61b2f689d2b170f72ecae1a86f80d52"}, - {file = "orjson-3.9.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:1e4d905338f9ef32c67566929dfbfbb23cc80287af8a2c38930fb0eda3d40b76"}, - {file = "orjson-3.9.1-cp39-none-win_amd64.whl", hash = "sha256:48a27da6c7306965846565cc385611d03382bbd84120008653aa2f6741e2105d"}, - {file = "orjson-3.9.1.tar.gz", hash = "sha256:db373a25ec4a4fccf8186f9a72a1b3442837e40807a736a815ab42481e83b7d0"}, + {file = "orjson-3.9.2-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:7323e4ca8322b1ecb87562f1ec2491831c086d9faa9a6c6503f489dadbed37d7"}, + {file = "orjson-3.9.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1272688ea1865f711b01ba479dea2d53e037ea00892fd04196b5875f7021d9d3"}, + {file = "orjson-3.9.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0b9a26f1d1427a9101a1e8910f2e2df1f44d3d18ad5480ba031b15d5c1cb282e"}, + {file = "orjson-3.9.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6a5ca55b0d8f25f18b471e34abaee4b175924b6cd62f59992945b25963443141"}, + {file = "orjson-3.9.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:877872db2c0f41fbe21f852ff642ca842a43bc34895b70f71c9d575df31fffb4"}, + {file = "orjson-3.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a39c2529d75373b7167bf84c814ef9b8f3737a339c225ed6c0df40736df8748"}, + {file = "orjson-3.9.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:84ebd6fdf138eb0eb4280045442331ee71c0aab5e16397ba6645f32f911bfb37"}, + {file = "orjson-3.9.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5a60a1cfcfe310547a1946506dd4f1ed0a7d5bd5b02c8697d9d5dcd8d2e9245e"}, + {file = "orjson-3.9.2-cp310-none-win_amd64.whl", hash = "sha256:c290c4f81e8fd0c1683638802c11610b2f722b540f8e5e858b6914b495cf90c8"}, + {file = "orjson-3.9.2-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:02ef014f9a605e84b675060785e37ec9c0d2347a04f1307a9d6840ab8ecd6f55"}, + {file = "orjson-3.9.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:992af54265ada1c1579500d6594ed73fe333e726de70d64919cf37f93defdd06"}, + {file = "orjson-3.9.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a40958f7af7c6d992ee67b2da4098dca8b770fc3b4b3834d540477788bfa76d3"}, + {file = "orjson-3.9.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:93864dec3e3dd058a2dbe488d11ac0345214a6a12697f53a63e34de7d28d4257"}, + {file = "orjson-3.9.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:16fdf5a82df80c544c3c91516ab3882cd1ac4f1f84eefeafa642e05cef5f6699"}, + {file = "orjson-3.9.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:275b5a18fd9ed60b2720543d3ddac170051c43d680e47d04ff5203d2c6d8ebf1"}, + {file = "orjson-3.9.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:b9aea6dcb99fcbc9f6d1dd84fca92322fda261da7fb014514bb4689c7c2097a8"}, + {file = "orjson-3.9.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7d74ae0e101d17c22ef67b741ba356ab896fc0fa64b301c2bf2bb0a4d874b190"}, + {file = "orjson-3.9.2-cp311-none-win_amd64.whl", hash = "sha256:6320b28e7bdb58c3a3a5efffe04b9edad3318d82409e84670a9b24e8035a249d"}, + {file = "orjson-3.9.2-cp37-cp37m-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:368e9cc91ecb7ac21f2aa475e1901204110cf3e714e98649c2502227d248f947"}, + {file = "orjson-3.9.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58e9e70f0dcd6a802c35887f306b555ff7a214840aad7de24901fc8bd9cf5dde"}, + {file = "orjson-3.9.2-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:00c983896c2e01c94c0ef72fd7373b2aa06d0c0eed0342c4884559f812a6835b"}, + {file = "orjson-3.9.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2ee743e8890b16c87a2f89733f983370672272b61ee77429c0a5899b2c98c1a7"}, + {file = "orjson-3.9.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b7b065942d362aad4818ff599d2f104c35a565c2cbcbab8c09ec49edba91da75"}, + {file = "orjson-3.9.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e46e9c5b404bb9e41d5555762fd410d5466b7eb1ec170ad1b1609cbebe71df21"}, + {file = "orjson-3.9.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:8170157288714678ffd64f5de33039e1164a73fd8b6be40a8a273f80093f5c4f"}, + {file = "orjson-3.9.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:e3e2f087161947dafe8319ea2cfcb9cea4bb9d2172ecc60ac3c9738f72ef2909"}, + {file = "orjson-3.9.2-cp37-none-win_amd64.whl", hash = "sha256:d7de3dbbe74109ae598692113cec327fd30c5a30ebca819b21dfa4052f7b08ef"}, + {file = "orjson-3.9.2-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:8cd4385c59bbc1433cad4a80aca65d2d9039646a9c57f8084897549b55913b17"}, + {file = "orjson-3.9.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a74036aab1a80c361039290cdbc51aa7adc7ea13f56e5ef94e9be536abd227bd"}, + {file = "orjson-3.9.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1aaa46d7d4ae55335f635eadc9be0bd9bcf742e6757209fc6dc697e390010adc"}, + {file = "orjson-3.9.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e52c67ed6bb368083aa2078ea3ccbd9721920b93d4b06c43eb4e20c4c860046"}, + {file = "orjson-3.9.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1a6cdfcf9c7dd4026b2b01fdff56986251dc0cc1e980c690c79eec3ae07b36e7"}, + {file = "orjson-3.9.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1882a70bb69595b9ec5aac0040a819e94d2833fe54901e2b32f5e734bc259a8b"}, + {file = "orjson-3.9.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:fc05e060d452145ab3c0b5420769e7356050ea311fc03cb9d79c481982917cca"}, + {file = "orjson-3.9.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:f8bc2c40d9bb26efefb10949d261a47ca196772c308babc538dd9f4b73e8d386"}, + {file = "orjson-3.9.2-cp38-none-win_amd64.whl", hash = "sha256:3164fc20a585ec30a9aff33ad5de3b20ce85702b2b2a456852c413e3f0d7ab09"}, + {file = "orjson-3.9.2-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:7a6ccadf788531595ed4728aa746bc271955448d2460ff0ef8e21eb3f2a281ba"}, + {file = "orjson-3.9.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3245d230370f571c945f69aab823c279a868dc877352817e22e551de155cb06c"}, + {file = "orjson-3.9.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:205925b179550a4ee39b8418dd4c94ad6b777d165d7d22614771c771d44f57bd"}, + {file = "orjson-3.9.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0325fe2d69512187761f7368c8cda1959bcb75fc56b8e7a884e9569112320e57"}, + {file = "orjson-3.9.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:806704cd58708acc66a064a9a58e3be25cf1c3f9f159e8757bd3f515bfabdfa1"}, + {file = "orjson-3.9.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:03fb36f187a0c19ff38f6289418863df8b9b7880cdbe279e920bef3a09d8dab1"}, + {file = "orjson-3.9.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:20925d07a97c49c6305bff1635318d9fc1804aa4ccacb5fb0deb8a910e57d97a"}, + {file = "orjson-3.9.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:eebfed53bec5674e981ebe8ed2cf00b3f7bcda62d634733ff779c264307ea505"}, + {file = "orjson-3.9.2-cp39-none-win_amd64.whl", hash = "sha256:869b961df5fcedf6c79f4096119b35679b63272362e9b745e668f0391a892d39"}, + {file = "orjson-3.9.2.tar.gz", hash = "sha256:24257c8f641979bf25ecd3e27251b5cc194cdd3a6e96004aac8446f5e63d9664"}, ] [[package]] @@ -2314,13 +2314,13 @@ files = [ [[package]] name = "platformdirs" -version = "3.6.0" +version = "3.9.1" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." optional = false python-versions = ">=3.7" files = [ - {file = "platformdirs-3.6.0-py3-none-any.whl", hash = "sha256:ffa199e3fbab8365778c4a10e1fbf1b9cd50707de826eb304b50e57ec0cc8d38"}, - {file = "platformdirs-3.6.0.tar.gz", hash = "sha256:57e28820ca8094678b807ff529196506d7a21e17156cb1cddb3e74cebce54640"}, + {file = "platformdirs-3.9.1-py3-none-any.whl", hash = "sha256:ad8291ae0ae5072f66c16945166cb11c63394c7a3ad1b1bc9828ca3162da8c2f"}, + {file = "platformdirs-3.9.1.tar.gz", hash = "sha256:1b42b450ad933e981d56e59f1b97495428c9bd60698baab9f3eb3d00d5822421"}, ] [package.extras] @@ -2329,13 +2329,13 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.3.1)", "pytest- [[package]] name = "pluggy" -version = "1.1.0" +version = "1.2.0" description = "plugin and hook calling mechanisms for python" optional = false python-versions = ">=3.7" files = [ - {file = "pluggy-1.1.0-py3-none-any.whl", hash = "sha256:d81d19a3a88d82ed06998353ce5d5c02587ef07ee2d808ae63904ab0ccef0087"}, - {file = "pluggy-1.1.0.tar.gz", hash = "sha256:c500b592c5512df35622e4faf2135aa0b7e989c7d31344194b4afb9d5e47b1bf"}, + {file = "pluggy-1.2.0-py3-none-any.whl", hash = "sha256:c2fd55a7d7a3863cba1a013e4e2414658b1d07b6bc57b3919e0c63c9abb99849"}, + {file = "pluggy-1.2.0.tar.gz", hash = "sha256:d12f0c4b579b15f5e054301bb226ee85eeeba08ffec228092f8defbaa3a4c4b3"}, ] [package.extras] @@ -2362,13 +2362,13 @@ virtualenv = ">=20.10.0" [[package]] name = "prompt-toolkit" -version = "3.0.38" +version = "3.0.39" description = "Library for building powerful interactive command lines in Python" optional = false python-versions = ">=3.7.0" files = [ - {file = "prompt_toolkit-3.0.38-py3-none-any.whl", hash = "sha256:45ea77a2f7c60418850331366c81cf6b5b9cf4c7fd34616f733c5427e6abbb1f"}, - {file = "prompt_toolkit-3.0.38.tar.gz", hash = "sha256:23ac5d50538a9a38c8bde05fecb47d0b403ecd0662857a86f886f798563d5b9b"}, + {file = "prompt_toolkit-3.0.39-py3-none-any.whl", hash = "sha256:9dffbe1d8acf91e3de75f3b544e4842382fc06c6babe903ac9acb74dc6e08d88"}, + {file = "prompt_toolkit-3.0.39.tar.gz", hash = "sha256:04505ade687dc26dc4284b1ad19a83be2f2afe83e7a828ace0c72f3a1df72aac"}, ] [package.dependencies] @@ -2376,24 +2376,24 @@ wcwidth = "*" [[package]] name = "protobuf" -version = "4.23.3" +version = "4.23.4" description = "" optional = false python-versions = ">=3.7" files = [ - {file = "protobuf-4.23.3-cp310-abi3-win32.whl", hash = "sha256:514b6bbd54a41ca50c86dd5ad6488afe9505901b3557c5e0f7823a0cf67106fb"}, - {file = "protobuf-4.23.3-cp310-abi3-win_amd64.whl", hash = "sha256:cc14358a8742c4e06b1bfe4be1afbdf5c9f6bd094dff3e14edb78a1513893ff5"}, - {file = "protobuf-4.23.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:2991f5e7690dab569f8f81702e6700e7364cc3b5e572725098215d3da5ccc6ac"}, - {file = "protobuf-4.23.3-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:08fe19d267608d438aa37019236db02b306e33f6b9902c3163838b8e75970223"}, - {file = "protobuf-4.23.3-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:3b01a5274ac920feb75d0b372d901524f7e3ad39c63b1a2d55043f3887afe0c1"}, - {file = "protobuf-4.23.3-cp37-cp37m-win32.whl", hash = "sha256:aca6e86a08c5c5962f55eac9b5bd6fce6ed98645d77e8bfc2b952ecd4a8e4f6a"}, - {file = "protobuf-4.23.3-cp37-cp37m-win_amd64.whl", hash = "sha256:0149053336a466e3e0b040e54d0b615fc71de86da66791c592cc3c8d18150bf8"}, - {file = "protobuf-4.23.3-cp38-cp38-win32.whl", hash = "sha256:84ea0bd90c2fdd70ddd9f3d3fc0197cc24ecec1345856c2b5ba70e4d99815359"}, - {file = "protobuf-4.23.3-cp38-cp38-win_amd64.whl", hash = "sha256:3bcbeb2bf4bb61fe960dd6e005801a23a43578200ea8ceb726d1f6bd0e562ba1"}, - {file = "protobuf-4.23.3-cp39-cp39-win32.whl", hash = "sha256:5cb9e41188737f321f4fce9a4337bf40a5414b8d03227e1d9fbc59bc3a216e35"}, - {file = "protobuf-4.23.3-cp39-cp39-win_amd64.whl", hash = "sha256:29660574cd769f2324a57fb78127cda59327eb6664381ecfe1c69731b83e8288"}, - {file = "protobuf-4.23.3-py3-none-any.whl", hash = "sha256:447b9786ac8e50ae72cae7a2eec5c5df6a9dbf9aa6f908f1b8bda6032644ea62"}, - {file = "protobuf-4.23.3.tar.gz", hash = "sha256:7a92beb30600332a52cdadbedb40d33fd7c8a0d7f549c440347bc606fb3fe34b"}, + {file = "protobuf-4.23.4-cp310-abi3-win32.whl", hash = "sha256:5fea3c64d41ea5ecf5697b83e41d09b9589e6f20b677ab3c48e5f242d9b7897b"}, + {file = "protobuf-4.23.4-cp310-abi3-win_amd64.whl", hash = "sha256:7b19b6266d92ca6a2a87effa88ecc4af73ebc5cfde194dc737cf8ef23a9a3b12"}, + {file = "protobuf-4.23.4-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:8547bf44fe8cec3c69e3042f5c4fb3e36eb2a7a013bb0a44c018fc1e427aafbd"}, + {file = "protobuf-4.23.4-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:fee88269a090ada09ca63551bf2f573eb2424035bcf2cb1b121895b01a46594a"}, + {file = "protobuf-4.23.4-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:effeac51ab79332d44fba74660d40ae79985901ac21bca408f8dc335a81aa597"}, + {file = "protobuf-4.23.4-cp37-cp37m-win32.whl", hash = "sha256:c3e0939433c40796ca4cfc0fac08af50b00eb66a40bbbc5dee711998fb0bbc1e"}, + {file = "protobuf-4.23.4-cp37-cp37m-win_amd64.whl", hash = "sha256:9053df6df8e5a76c84339ee4a9f5a2661ceee4a0dab019e8663c50ba324208b0"}, + {file = "protobuf-4.23.4-cp38-cp38-win32.whl", hash = "sha256:e1c915778d8ced71e26fcf43c0866d7499891bca14c4368448a82edc61fdbc70"}, + {file = "protobuf-4.23.4-cp38-cp38-win_amd64.whl", hash = "sha256:351cc90f7d10839c480aeb9b870a211e322bf05f6ab3f55fcb2f51331f80a7d2"}, + {file = "protobuf-4.23.4-cp39-cp39-win32.whl", hash = "sha256:6dd9b9940e3f17077e820b75851126615ee38643c2c5332aa7a359988820c720"}, + {file = "protobuf-4.23.4-cp39-cp39-win_amd64.whl", hash = "sha256:0a5759f5696895de8cc913f084e27fd4125e8fb0914bb729a17816a33819f474"}, + {file = "protobuf-4.23.4-py3-none-any.whl", hash = "sha256:e9d0be5bf34b275b9f87ba7407796556abeeba635455d036c7351f7c183ef8ff"}, + {file = "protobuf-4.23.4.tar.gz", hash = "sha256:ccd9430c0719dce806b93f89c91de7977304729e55377f872a92465d548329a9"}, ] [[package]] @@ -2472,55 +2472,135 @@ files = [ [[package]] name = "pydantic" -version = "1.10.9" -description = "Data validation and settings management using python type hints" +version = "2.0.3" +description = "Data validation using Python type hints" optional = false python-versions = ">=3.7" files = [ - {file = "pydantic-1.10.9-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e692dec4a40bfb40ca530e07805b1208c1de071a18d26af4a2a0d79015b352ca"}, - {file = "pydantic-1.10.9-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3c52eb595db83e189419bf337b59154bdcca642ee4b2a09e5d7797e41ace783f"}, - {file = "pydantic-1.10.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:939328fd539b8d0edf244327398a667b6b140afd3bf7e347cf9813c736211896"}, - {file = "pydantic-1.10.9-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b48d3d634bca23b172f47f2335c617d3fcb4b3ba18481c96b7943a4c634f5c8d"}, - {file = "pydantic-1.10.9-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:f0b7628fb8efe60fe66fd4adadd7ad2304014770cdc1f4934db41fe46cc8825f"}, - {file = "pydantic-1.10.9-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e1aa5c2410769ca28aa9a7841b80d9d9a1c5f223928ca8bec7e7c9a34d26b1d4"}, - {file = "pydantic-1.10.9-cp310-cp310-win_amd64.whl", hash = "sha256:eec39224b2b2e861259d6f3c8b6290d4e0fbdce147adb797484a42278a1a486f"}, - {file = "pydantic-1.10.9-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d111a21bbbfd85c17248130deac02bbd9b5e20b303338e0dbe0faa78330e37e0"}, - {file = "pydantic-1.10.9-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e9aec8627a1a6823fc62fb96480abe3eb10168fd0d859ee3d3b395105ae19a7"}, - {file = "pydantic-1.10.9-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07293ab08e7b4d3c9d7de4949a0ea571f11e4557d19ea24dd3ae0c524c0c334d"}, - {file = "pydantic-1.10.9-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7ee829b86ce984261d99ff2fd6e88f2230068d96c2a582f29583ed602ef3fc2c"}, - {file = "pydantic-1.10.9-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4b466a23009ff5cdd7076eb56aca537c745ca491293cc38e72bf1e0e00de5b91"}, - {file = "pydantic-1.10.9-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7847ca62e581e6088d9000f3c497267868ca2fa89432714e21a4fb33a04d52e8"}, - {file = "pydantic-1.10.9-cp311-cp311-win_amd64.whl", hash = "sha256:7845b31959468bc5b78d7b95ec52fe5be32b55d0d09983a877cca6aedc51068f"}, - {file = "pydantic-1.10.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:517a681919bf880ce1dac7e5bc0c3af1e58ba118fd774da2ffcd93c5f96eaece"}, - {file = "pydantic-1.10.9-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67195274fd27780f15c4c372f4ba9a5c02dad6d50647b917b6a92bf00b3d301a"}, - {file = "pydantic-1.10.9-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2196c06484da2b3fded1ab6dbe182bdabeb09f6318b7fdc412609ee2b564c49a"}, - {file = "pydantic-1.10.9-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:6257bb45ad78abacda13f15bde5886efd6bf549dd71085e64b8dcf9919c38b60"}, - {file = "pydantic-1.10.9-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:3283b574b01e8dbc982080d8287c968489d25329a463b29a90d4157de4f2baaf"}, - {file = "pydantic-1.10.9-cp37-cp37m-win_amd64.whl", hash = "sha256:5f8bbaf4013b9a50e8100333cc4e3fa2f81214033e05ac5aa44fa24a98670a29"}, - {file = "pydantic-1.10.9-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b9cd67fb763248cbe38f0593cd8611bfe4b8ad82acb3bdf2b0898c23415a1f82"}, - {file = "pydantic-1.10.9-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f50e1764ce9353be67267e7fd0da08349397c7db17a562ad036aa7c8f4adfdb6"}, - {file = "pydantic-1.10.9-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:73ef93e5e1d3c8e83f1ff2e7fdd026d9e063c7e089394869a6e2985696693766"}, - {file = "pydantic-1.10.9-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:128d9453d92e6e81e881dd7e2484e08d8b164da5507f62d06ceecf84bf2e21d3"}, - {file = "pydantic-1.10.9-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:ad428e92ab68798d9326bb3e5515bc927444a3d71a93b4a2ca02a8a5d795c572"}, - {file = "pydantic-1.10.9-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fab81a92f42d6d525dd47ced310b0c3e10c416bbfae5d59523e63ea22f82b31e"}, - {file = "pydantic-1.10.9-cp38-cp38-win_amd64.whl", hash = "sha256:963671eda0b6ba6926d8fc759e3e10335e1dc1b71ff2a43ed2efd6996634dafb"}, - {file = "pydantic-1.10.9-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:970b1bdc6243ef663ba5c7e36ac9ab1f2bfecb8ad297c9824b542d41a750b298"}, - {file = "pydantic-1.10.9-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7e1d5290044f620f80cf1c969c542a5468f3656de47b41aa78100c5baa2b8276"}, - {file = "pydantic-1.10.9-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83fcff3c7df7adff880622a98022626f4f6dbce6639a88a15a3ce0f96466cb60"}, - {file = "pydantic-1.10.9-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0da48717dc9495d3a8f215e0d012599db6b8092db02acac5e0d58a65248ec5bc"}, - {file = "pydantic-1.10.9-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:0a2aabdc73c2a5960e87c3ffebca6ccde88665616d1fd6d3db3178ef427b267a"}, - {file = "pydantic-1.10.9-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9863b9420d99dfa9c064042304868e8ba08e89081428a1c471858aa2af6f57c4"}, - {file = "pydantic-1.10.9-cp39-cp39-win_amd64.whl", hash = "sha256:e7c9900b43ac14110efa977be3da28931ffc74c27e96ee89fbcaaf0b0fe338e1"}, - {file = "pydantic-1.10.9-py3-none-any.whl", hash = "sha256:6cafde02f6699ce4ff643417d1a9223716ec25e228ddc3b436fe7e2d25a1f305"}, - {file = "pydantic-1.10.9.tar.gz", hash = "sha256:95c70da2cd3b6ddf3b9645ecaa8d98f3d80c606624b6d245558d202cd23ea3be"}, + {file = "pydantic-2.0.3-py3-none-any.whl", hash = "sha256:614eb3321eb600c81899a88fa9858b008e3c79e0d4f1b49ab1f516b4b0c27cfb"}, + {file = "pydantic-2.0.3.tar.gz", hash = "sha256:94f13e0dcf139a5125e88283fc999788d894e14ed90cf478bcc2ee50bd4fc630"}, ] [package.dependencies] -typing-extensions = ">=4.2.0" +annotated-types = ">=0.4.0" +pydantic-core = "2.3.0" +typing-extensions = ">=4.6.1" [package.extras] -dotenv = ["python-dotenv (>=0.10.4)"] -email = ["email-validator (>=1.0.3)"] +email = ["email-validator (>=2.0.0)"] + +[[package]] +name = "pydantic-core" +version = "2.3.0" +description = "" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pydantic_core-2.3.0-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:4542c98b8364b976593703a2dda97377433b102f380b61bc3a2cbc2fbdae1d1f"}, + {file = "pydantic_core-2.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9342de50824b40f55d2600f66c6f9a91a3a24851eca39145a749a3dc804ee599"}, + {file = "pydantic_core-2.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:539432f911686cb80284c30b33eaf9f4fd9a11e1111fe0dc98fdbdce69b49821"}, + {file = "pydantic_core-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38a0e7ee65c8999394d92d9c724434cb629279d19844f2b69d9bbc46dc8b8b61"}, + {file = "pydantic_core-2.3.0-cp310-cp310-manylinux_2_24_armv7l.whl", hash = "sha256:e3ed6834cc005798187a56c248a2240207cb8ffdda1c89e9afda4c3d526c2ea0"}, + {file = "pydantic_core-2.3.0-cp310-cp310-manylinux_2_24_ppc64le.whl", hash = "sha256:e72ac299a6bf732a60852d052acf3999d234686755a02ba111e85e7ebf8155b1"}, + {file = "pydantic_core-2.3.0-cp310-cp310-manylinux_2_24_s390x.whl", hash = "sha256:616b3451b05ca63b8f433c627f68046b39543faeaa4e50d8c6699a2a1e4b85a5"}, + {file = "pydantic_core-2.3.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:adcb9c8848e15c613e483e0b99767ae325af27fe0dbd866df01fe5849d06e6e1"}, + {file = "pydantic_core-2.3.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:464bf799b422be662e5e562e62beeffc9eaa907d381a9d63a2556615bbda286d"}, + {file = "pydantic_core-2.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4638ebc17de08c2f3acba557efeb6f195c88b7299d8c55c0bb4e20638bbd4d03"}, + {file = "pydantic_core-2.3.0-cp310-none-win32.whl", hash = "sha256:9ff322c7e1030543d35d83bb521b69114d3d150750528d7757544f639def9ad6"}, + {file = "pydantic_core-2.3.0-cp310-none-win_amd64.whl", hash = "sha256:4824eb018f0a4680b1e434697a9bf3f41c7799b80076d06530cbbd212e040ccc"}, + {file = "pydantic_core-2.3.0-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:0aa429578e23885b3984c49d687cd05ab06f0b908ea1711a8bf7e503b7f97160"}, + {file = "pydantic_core-2.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:20d710c1f79af930b8891bcebd84096798e4387ab64023ef41521d58f21277d3"}, + {file = "pydantic_core-2.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:309f45d4d7481d6f09cb9e35c72caa0e50add4a30bb08c04c5fe5956a0158633"}, + {file = "pydantic_core-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1bcfb7be905aa849bd882262e1df3f75b564e2f708b4b4c7ad2d3deaf5410562"}, + {file = "pydantic_core-2.3.0-cp311-cp311-manylinux_2_24_armv7l.whl", hash = "sha256:85cd9c0af34e371390e3cb2f3a470b0b40cc07568c1e966c638c49062be6352d"}, + {file = "pydantic_core-2.3.0-cp311-cp311-manylinux_2_24_ppc64le.whl", hash = "sha256:37c5028cebdf731298724070838fb3a71ef1fbd201d193d311ac2cbdbca25a23"}, + {file = "pydantic_core-2.3.0-cp311-cp311-manylinux_2_24_s390x.whl", hash = "sha256:e4208f23f12d0ad206a07a489ef4cb15722c10b62774c4460ee4123250be938e"}, + {file = "pydantic_core-2.3.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c24465dd11b65c8510f251b095fc788c7c91481c81840112fe3f76c30793a455"}, + {file = "pydantic_core-2.3.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:3cd7ee8bbfab277ab56e272221886fd33a1b5943fbf45ae9195aa6a48715a8a0"}, + {file = "pydantic_core-2.3.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0fc7e0b056b66cc536e97ef60f48b3b289f6b3b62ac225afd4b22a42434617bf"}, + {file = "pydantic_core-2.3.0-cp311-none-win32.whl", hash = "sha256:4788135db4bd83a5edc3522b11544b013be7d25b74b155e08dd3b20cd6663bbb"}, + {file = "pydantic_core-2.3.0-cp311-none-win_amd64.whl", hash = "sha256:f93c867e5e85584a28c6a6feb6f2086d717266eb5d1210d096dd717b7f4dec04"}, + {file = "pydantic_core-2.3.0-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:73f62bb7fd862d9bcd886e10612bade6fe042eda8b47e8c129892bcfb7b45e84"}, + {file = "pydantic_core-2.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4d889d498fce64bfcd8adf1a78579a7f626f825cbeb2956a24a29b35f9a1df32"}, + {file = "pydantic_core-2.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d55e38a89ec2ae17b2fa7ffeda6b70f63afab1888bd0d57aaa7b7879760acb4"}, + {file = "pydantic_core-2.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1aefebb506bc1fe355d91d25f12bcdea7f4d7c2d9f0f6716dd025543777c99a5"}, + {file = "pydantic_core-2.3.0-cp312-cp312-manylinux_2_24_armv7l.whl", hash = "sha256:6441a29f42585f085db0c04cd0557d4cbbb46fa68a0972409b1cfe9f430280c1"}, + {file = "pydantic_core-2.3.0-cp312-cp312-manylinux_2_24_ppc64le.whl", hash = "sha256:47e8f034be31390a8f525431eb5e803a78ce7e2e11b32abf5361a972e14e6b61"}, + {file = "pydantic_core-2.3.0-cp312-cp312-manylinux_2_24_s390x.whl", hash = "sha256:ad814864aba263be9c83ada44a95f72d10caabbf91589321f95c29c902bdcff0"}, + {file = "pydantic_core-2.3.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9eff3837d447fccf2ac38c259b14ab9cbde700df355a45a1f3ff244d5e78f8b6"}, + {file = "pydantic_core-2.3.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:534f3f63c000f08050c6f7f4378bf2b52d7ba9214e9d35e3f60f7ad24a4d6425"}, + {file = "pydantic_core-2.3.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ef6a222d54f742c24f6b143aab088702db3a827b224e75b9dd28b38597c595fe"}, + {file = "pydantic_core-2.3.0-cp312-none-win32.whl", hash = "sha256:4e26944e64ecc1d7b19db954c0f7b471f3b141ec8e1a9f57cfe27671525cd248"}, + {file = "pydantic_core-2.3.0-cp312-none-win_amd64.whl", hash = "sha256:019c5c41941438570dfc7d3f0ae389b2425add1775a357ce1e83ed1434f943d6"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:27c1bbfb9d84a75cf33b7f19b53c29eb7ead99b235fce52aced5507174ab8f98"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:7cb496e934b71f1ade844ab91d6ccac78a3520e5df02fdb2357f85a71e541e69"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5af2d43b1978958d91351afbcc9b4d0cfe144c46c61740e82aaac8bb39ab1a4d"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4d3097c39d7d4e8dba2ef86de171dcccad876c36d8379415ba18a5a4d0533510"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-manylinux_2_24_armv7l.whl", hash = "sha256:dd3b023f3317dbbbc775e43651ce1a31a9cea46216ad0b5be37afc18a2007699"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-manylinux_2_24_ppc64le.whl", hash = "sha256:27babb9879bf2c45ed655d02639f4c30e2b9ef1b71ce59c2305bbf7287910a18"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-manylinux_2_24_s390x.whl", hash = "sha256:2183a9e18cdc0de53bdaa1675f237259162abeb62d6ac9e527c359c1074dc55d"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c089d8e7f1b4db08b2f8e4107304eec338df046275dad432635a9be9531e2fc8"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:2f10aa5452b865818dd0137f568d443f5e93b60a27080a01aa4b7512c7ba13a3"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:f642313d559f9d9a00c4de6820124059cc3342a0d0127b18301de2c680d5ea40"}, + {file = "pydantic_core-2.3.0-cp37-none-win32.whl", hash = "sha256:45327fc57afbe3f2c3d7f54a335d5cecee8a9fdb3906a2fbed8af4092f4926df"}, + {file = "pydantic_core-2.3.0-cp37-none-win_amd64.whl", hash = "sha256:e427b66596a6441a5607dfc0085b47d36073f88da7ac48afd284263b9b99e6ce"}, + {file = "pydantic_core-2.3.0-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:0b3d781c71b8bfb621ef23b9c874933e2cd33237c1a65cc20eeb37437f8e7e18"}, + {file = "pydantic_core-2.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ad46027dbd5c1db87dc0b49becbe23093b143a20302028d387dae37ee5ef95f5"}, + {file = "pydantic_core-2.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39aa09ed7ce2a648c904f79032d16dda29e6913112af8465a7bf710eef23c7ca"}, + {file = "pydantic_core-2.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:05b4bf8c58409586a7a04c858a86ab10f28c6c1a7c33da65e0326c59d5b0ab16"}, + {file = "pydantic_core-2.3.0-cp38-cp38-manylinux_2_24_armv7l.whl", hash = "sha256:ba2b807d2b62c446120906b8580cddae1d76d3de4efbb95ccc87f5e35c75b4b2"}, + {file = "pydantic_core-2.3.0-cp38-cp38-manylinux_2_24_ppc64le.whl", hash = "sha256:ea955e4ed21f4bbb9b83fea09fc6af0bed82e69ecf6b35ec89237a0a49633033"}, + {file = "pydantic_core-2.3.0-cp38-cp38-manylinux_2_24_s390x.whl", hash = "sha256:06884c07956526ac9ebfef40fe21a11605569b8fc0e2054a375fb39c978bf48f"}, + {file = "pydantic_core-2.3.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f868e731a18b403b88aa434d960489ceeed0ddeb44ebc02389540731a67705e0"}, + {file = "pydantic_core-2.3.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:cb08fab0fc1db15c277b72e33ac74ad9c0c789413da8984a3eacb22a94b42ef4"}, + {file = "pydantic_core-2.3.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6ca34c29fbd6592de5fd39e80c1993634d704c4e7e14ba54c87b2c7c53da68fe"}, + {file = "pydantic_core-2.3.0-cp38-none-win32.whl", hash = "sha256:cd782807d35c8a41aaa7d30b5107784420eefd9fdc1c760d86007d43ae00b15d"}, + {file = "pydantic_core-2.3.0-cp38-none-win_amd64.whl", hash = "sha256:01f56d5ee70b1d39c0fd08372cc5142274070ab7181d17c86035f130eebc05b8"}, + {file = "pydantic_core-2.3.0-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:78b1ac0151271ce62bc2b33755f1043eda6a310373143a2f27e2bcd3d5fc8633"}, + {file = "pydantic_core-2.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:64bfd2c35a2c350f73ac52dc134d8775f93359c4c969280a6fe5301b5b6e7431"}, + {file = "pydantic_core-2.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:937c0fe9538f1212b62df6a68f8d78df3572fe3682d9a0dd8851eac8a4e46063"}, + {file = "pydantic_core-2.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4d965c7c4b40d1cedec9188782e98bd576f9a04868835604200c3a6e817b824f"}, + {file = "pydantic_core-2.3.0-cp39-cp39-manylinux_2_24_armv7l.whl", hash = "sha256:ad442b8585ed4a3c2d22e4bf7b465d9b7d281e055b09719a8aeb5b576422dc9b"}, + {file = "pydantic_core-2.3.0-cp39-cp39-manylinux_2_24_ppc64le.whl", hash = "sha256:4bf20c9722821fce766e685718e739deeccc60d6bc7be5029281db41f999ee0c"}, + {file = "pydantic_core-2.3.0-cp39-cp39-manylinux_2_24_s390x.whl", hash = "sha256:f3dd5333049b5b3faa739e0f40b77cc8b7a1aded2f2da0e28794c81586d7b08a"}, + {file = "pydantic_core-2.3.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0dc5f516b24d24bc9e8dd9305460899f38302b3c4f9752663b396ef9848557bf"}, + {file = "pydantic_core-2.3.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:055f7ea6b1fbb37880d66d70eefd22dd319b09c79d2cb99b1dbfeb34b653b0b2"}, + {file = "pydantic_core-2.3.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:af693a89db6d6ac97dd84dd7769b3f2bd9007b578127d0e7dda03053f4d3b34b"}, + {file = "pydantic_core-2.3.0-cp39-none-win32.whl", hash = "sha256:f60e31e3e15e8c294bf70c60f8ae4d0c3caf3af8f26466e9aa8ea4c01302749b"}, + {file = "pydantic_core-2.3.0-cp39-none-win_amd64.whl", hash = "sha256:2b79f3681481f4424d7845cc7a261d5a4baa810d656b631fa844dc9967b36a7b"}, + {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:a666134b41712e30a71afaa26deeb4da374179f769fa49784cdf0e7698880fab"}, + {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c119e9227487ad3d7c3c737d896afe548a6be554091f9745da1f4b489c40561"}, + {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:73929a2fb600a2333fce2efd92596cff5e6bf8946e20e93c067b220760064862"}, + {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:41bbc2678a5b6a19371b2cb51f30ccea71f0c14b26477d2d884fed761cea42c7"}, + {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:dcbff997f47d45bf028bda4c3036bb3101e89a3df271281d392b6175f71c71d1"}, + {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:afa8808159169368b66e4fbeafac6c6fd8f26246dc4d0dcc2caf94bd9cf1b828"}, + {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:12be3b5f54f8111ca38e6b7277f26c23ba5cb3344fae06f879a0a93dfc8b479e"}, + {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:ed5babdcd3d052ba5cf8832561f18df20778c7ccf12587b2d82f7bf3bf259a0e"}, + {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d642e5c029e2acfacf6aa0a7a3e822086b3b777c70d364742561f9ca64c1ffc"}, + {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ba3073eb38a1294e8c7902989fb80a7a147a69db2396818722bd078476586a0"}, + {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d5146a6749b1905e04e62e0ad4622f079e5582f8b3abef5fb64516c623127908"}, + {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:deeb64335f489c3c11949cbd1d1668b3f1fb2d1c6a5bf40e126ef7bf95f9fa40"}, + {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:31acc37288b8e69e4849f618c3d5cf13b58077c1a1ff9ade0b3065ba974cd385"}, + {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:e09d9f6d722de9d4c1c5f122ea9bc6b25a05f975457805af4dcab7b0128aacbf"}, + {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:ba6a8cf089222a171b8f84e6ec2d10f7a9d14f26be3a347b14775a8741810676"}, + {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ef1fd1b24e9bcddcb168437686677104e205c8e25b066e73ffdf331d3bb8792b"}, + {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eda1a89c4526826c0a87d33596a4cd15b8f58e9250f503e39af1699ba9c878e8"}, + {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a3e9a18401a28db4358da2e191508702dbf065f2664c710708cdf9552b9fa50c"}, + {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:a439fd0d45d51245bbde799726adda5bd18aed3fa2b01ab2e6a64d6d13776fa3"}, + {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:bf6a1d2c920cc9528e884850a4b2ee7629e3d362d5c44c66526d4097bbb07a1a"}, + {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e33fcbea3b63a339dd94de0fc442fefacfe681cc7027ce63f67af9f7ceec7422"}, + {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:bf3ed993bdf4754909f175ff348cf8f78d4451215b8aa338633f149ca3b1f37a"}, + {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7584171eb3115acd4aba699bc836634783f5bd5aab131e88d8eeb8a3328a4a72"}, + {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1624baa76d1740711b2048f302ae9a6d73d277c55a8c3e88b53b773ebf73a971"}, + {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:06f33f695527f5a86e090f208978f9fd252c9cfc7e869d3b679bd71f7cb2c1fa"}, + {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:7ecf0a67b212900e92f328181fed02840d74ed39553cdb38d27314e2b9c89dfa"}, + {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:45fa1e8ad6f4367ad73674ca560da8e827cc890eaf371f3ee063d6d7366a207b"}, + {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:8d0dbcc57839831ae79fd24b1b83d42bc9448d79feaf3ed3fb5cbf94ffbf3eb7"}, + {file = "pydantic_core-2.3.0.tar.gz", hash = "sha256:5cfb5ac4e82c47d5dc25b209dd4c3989e284b80109f9e08b33c895080c424b4f"}, +] + +[package.dependencies] +typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" [[package]] name = "pydot" @@ -2538,42 +2618,42 @@ pyparsing = ">=2.1.4" [[package]] name = "pygit2" -version = "1.12.1" +version = "1.12.2" description = "Python bindings for libgit2." optional = false python-versions = ">=3.8" files = [ - {file = "pygit2-1.12.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:50a155528aa611e4a217be31a9d2d8da283cfd978dbba07494cd04ea3d7c8768"}, - {file = "pygit2-1.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:248e22ccb1ea31f569373a3da3fa73d110ba2585c6326ff74b03c9579fb7b913"}, - {file = "pygit2-1.12.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e575e672c5a6cb39234b0076423a560e016d6b88cd50947c2df3bf59c5ccdf3d"}, - {file = "pygit2-1.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad9b46b52997d131b31ff46f699b074e9745c8fea8d0efb6b72ace43ab25828c"}, - {file = "pygit2-1.12.1-cp310-cp310-win32.whl", hash = "sha256:a8f495df877da04c572ecec4d532ae195680b4781dbf229bab4e801fa9ef20e9"}, - {file = "pygit2-1.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:9f1e1355c7fe2938a2bca0d6204a00c02950d13008722879e54a335b3e874006"}, - {file = "pygit2-1.12.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8a5c56b0b5dc8a317561070ef7557e180d4937d8b115c5a762d85e0109a216f3"}, - {file = "pygit2-1.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b7c9ca8bc8a722863fc873234748fef3422007d5a6ea90ba3ae338d2907d3d6e"}, - {file = "pygit2-1.12.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:71c02a11f10bc4e329ab941f0c70874d39053c8f78544aefeb506f04cedb621a"}, - {file = "pygit2-1.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2b3af334adf325b7c973417efa220fd5a9ce946b936262eceabc8ad8d46e0310"}, - {file = "pygit2-1.12.1-cp311-cp311-win32.whl", hash = "sha256:86c393962d1341893bbfa91829b3b8545e8ac7622f8b53b9a0b835b9cc1b5198"}, - {file = "pygit2-1.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:86c7e75ddc76f4e5593b47f9c2074fff242322ed9f4126116749f7c86021520a"}, - {file = "pygit2-1.12.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:939d11677f434024ea25a9137d8a525ef9f9ac474fb8b86399bc9526e6a7bff5"}, - {file = "pygit2-1.12.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:946f9215c0442995042ea512f764f7a6638d3a09f9d0484d3aeedbf8833f89e6"}, - {file = "pygit2-1.12.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd574620d3cc80df0b23bf2b7b08d8726e75a338d0fa1b67e4d6738d3ee56635"}, - {file = "pygit2-1.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:24d0adeff5c43229913f3bdae71c36e77ed19f36bd8dd6b5c141820964b1f5b3"}, - {file = "pygit2-1.12.1-cp38-cp38-win32.whl", hash = "sha256:ed8e2ef97171e994bf4d46c6c6534a3c12dd2dbbc47741e5995eaf8c2c92f71c"}, - {file = "pygit2-1.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:5318817055a3ca3906bf88344b9a6dc70c640f9b6bc236ac9e767d12bad54361"}, - {file = "pygit2-1.12.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cb9c803151ffeb0b8de52a93381108a2c6a9a446c55d659a135f52645e1650eb"}, - {file = "pygit2-1.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:47bf1e196dc23fe38018ad49b021d425edc319328169c597df45d73cf46b62ef"}, - {file = "pygit2-1.12.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:948479df72223bbcd16b2a88904dc2a3886c15a0107a7cf3b5373c8e34f52f31"}, - {file = "pygit2-1.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4bebe8b310edc2662cbffb94ef1a758252fe2e4c92bc83fac0eaf2bedf8b871"}, - {file = "pygit2-1.12.1-cp39-cp39-win32.whl", hash = "sha256:77bc0ab778ab6fe631f5f9eb831b426376a7b71426c5a913aaa9088382ef1dc9"}, - {file = "pygit2-1.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:e87b2306a266f6abca94ab37dda807033a6f40faad05c4d1e089f9e8354130a8"}, - {file = "pygit2-1.12.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:5d5e8a3b67f5d4ba8e3838c492254688997747989b184b5f1a3af4fef7f9f53e"}, - {file = "pygit2-1.12.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2500b749759f2efdfa5096c0aafeb2d92152766708f5700284427bd658e5c407"}, - {file = "pygit2-1.12.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c21759ca9cc755faa2d17180cd49af004486ca84f3166cac089a2083dcb09114"}, - {file = "pygit2-1.12.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:d73074ab64b383e3a1ab03e8070f6b195ef89b9d379ca5682c38dd9c289cc6e2"}, - {file = "pygit2-1.12.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:865c0d1925c52426455317f29c1db718187ec69ed5474faaf3e1c68ff2135767"}, - {file = "pygit2-1.12.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ebebbe9125b068337b5415565ec94c9e092c708e430851b2d02e51217bdce4a"}, - {file = "pygit2-1.12.1.tar.gz", hash = "sha256:8218922abedc88a65d5092308d533ca4c4ed634aec86a3493d3bdf1a25aeeff3"}, + {file = "pygit2-1.12.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:79fbd99d3e08ca7478150eeba28ca4d4103f564148eab8d00aba8f1e6fc60654"}, + {file = "pygit2-1.12.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:be3bb0139f464947523022a5af343a2e862c4ff250a57ec9f631449e7c0ba7c0"}, + {file = "pygit2-1.12.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4df3e5745fdf3111a6ccc905eae99f22f1a180728f714795138ca540cc2a50a"}, + {file = "pygit2-1.12.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:214bd214784fcbef7a8494d1d59e0cd3a731c0d24ce0f230dcc843322ee33b08"}, + {file = "pygit2-1.12.2-cp310-cp310-win32.whl", hash = "sha256:336c864ac961e7be8ba06e9ed8c999e4f624a8ccd90121cc4e40956d8b57acac"}, + {file = "pygit2-1.12.2-cp310-cp310-win_amd64.whl", hash = "sha256:fb9eb57b75ce586928053692a25aae2a50fef3ad36661c57c07d4902899b1df3"}, + {file = "pygit2-1.12.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f8f813d35d836c5b0d1962c387754786bcc7f1c3c8e11207b9eeb30238ac4cc7"}, + {file = "pygit2-1.12.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:25a6548930328c5247bfb7c67d29104e63b036cb5390f032d9f91f63efb70434"}, + {file = "pygit2-1.12.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a365ffca23d910381749fdbcc367db52fe808f9aa4852914dd9ef8b711384a32"}, + {file = "pygit2-1.12.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec04c27be5d5af1ceecdcc0464e07081222f91f285f156dc53b23751d146569a"}, + {file = "pygit2-1.12.2-cp311-cp311-win32.whl", hash = "sha256:546091316c9a8c37b9867ddcc6c9f7402ca4d0b9db3f349212a7b5e71988e359"}, + {file = "pygit2-1.12.2-cp311-cp311-win_amd64.whl", hash = "sha256:8bf14196cbfffbcd286f459a1d4fc660c5d5dfa8fb422e21216961df575410d6"}, + {file = "pygit2-1.12.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:7bb30ab1fdaa4c30821fed33892958b6d92d50dbd03c76f7775b4e5d62f53a2e"}, + {file = "pygit2-1.12.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e7e705aaecad85b883022e81e054fbd27d26023fc031618ee61c51516580517e"}, + {file = "pygit2-1.12.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac2b5f408eb882e79645ebb43039ac37739c3edd25d857cc97d7482a684b613f"}, + {file = "pygit2-1.12.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22e7f3ad2b7b0c80be991bb47d8a2f2535cc9bf090746eb8679231ee565fde81"}, + {file = "pygit2-1.12.2-cp38-cp38-win32.whl", hash = "sha256:5b3ab4d6302990f7adb2b015bcbda1f0715277008d0c66440497e6f8313bf9cb"}, + {file = "pygit2-1.12.2-cp38-cp38-win_amd64.whl", hash = "sha256:c74e7601cb8b8dc3d02fd32274e200a7761cffd20ee531442bf1fa115c8f99a5"}, + {file = "pygit2-1.12.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:6a4083ba093c69142e0400114a4ef75e87834637d2bbfd77b964614bf70f624f"}, + {file = "pygit2-1.12.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:926f2e48c4eaa179249d417b8382290b86b0f01dbf41d289f763576209276b9f"}, + {file = "pygit2-1.12.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:14ae27491347a0ac4bbe8347b09d752cfe7fea1121c14525415e0cca6db4a836"}, + {file = "pygit2-1.12.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f65483ab5e3563c58f60debe2acc0979fdf6fd633432fcfbddf727a9a265ba4"}, + {file = "pygit2-1.12.2-cp39-cp39-win32.whl", hash = "sha256:8da8517809635ea3da950d9cf99c6d1851352d92b6db309382db88a01c3b0bfd"}, + {file = "pygit2-1.12.2-cp39-cp39-win_amd64.whl", hash = "sha256:b9c2359b99eed8e7fac30c06e6b4ae277a6a0537d6b4b88a190828c3d7eb9ef2"}, + {file = "pygit2-1.12.2-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:685378852ef8eb081333bc80dbdfc4f1333cf4a8f3baf614c4135e02ad1ee38a"}, + {file = "pygit2-1.12.2-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cdf655e5f801990f5cad721b6ccbe7610962f0a4f1c20373dbf9c0be39374a81"}, + {file = "pygit2-1.12.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:857c5cde635d470f58803d67bfb281dc4f6336065a0253bfbed001f18e2d0767"}, + {file = "pygit2-1.12.2-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:fe35a72af61961dbb7fb4abcdaa36d5f1c85b2cd3daae94137eeb9c07215cdd3"}, + {file = "pygit2-1.12.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f443d3641762b2bb9c76400bb18beb4ba27dd35bc098a8bfae82e6a190c52ab"}, + {file = "pygit2-1.12.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c1e26649e1540b6a774f812e2fc9890320ff4d33f16db1bb02626318b5ceae2"}, + {file = "pygit2-1.12.2.tar.gz", hash = "sha256:56e85d0e66de957d599d1efb2409d39afeefd8f01009bfda0796b42a4b678358"}, ] [package.dependencies] @@ -2620,13 +2700,13 @@ diagrams = ["jinja2", "railroad-diagrams"] [[package]] name = "pytest" -version = "7.3.2" +version = "7.4.0" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.7" files = [ - {file = "pytest-7.3.2-py3-none-any.whl", hash = "sha256:cdcbd012c9312258922f8cd3f1b62a6580fdced17db6014896053d47cddf9295"}, - {file = "pytest-7.3.2.tar.gz", hash = "sha256:ee990a3cc55ba808b80795a79944756f315c67c12b56abd3ac993a7b8c17030b"}, + {file = "pytest-7.4.0-py3-none-any.whl", hash = "sha256:78bf16451a2eb8c7a2ea98e32dc119fd2aa758f1d5d66dbf0a59d69a3969df32"}, + {file = "pytest-7.4.0.tar.gz", hash = "sha256:b4bf8c45bd59934ed84001ad51e11b4ee40d40a1229d2c79f9c592b0a3f6bd8a"}, ] [package.dependencies] @@ -2690,51 +2770,51 @@ files = [ [[package]] name = "pyyaml" -version = "6.0" +version = "6.0.1" description = "YAML parser and emitter for Python" optional = false python-versions = ">=3.6" files = [ - {file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"}, - {file = "PyYAML-6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9df7ed3b3d2e0ecfe09e14741b857df43adb5a3ddadc919a2d94fbdf78fea53c"}, - {file = "PyYAML-6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77f396e6ef4c73fdc33a9157446466f1cff553d979bd00ecb64385760c6babdc"}, - {file = "PyYAML-6.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a80a78046a72361de73f8f395f1f1e49f956c6be882eed58505a15f3e430962b"}, - {file = "PyYAML-6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f84fbc98b019fef2ee9a1cb3ce93e3187a6df0b2538a651bfb890254ba9f90b5"}, - {file = "PyYAML-6.0-cp310-cp310-win32.whl", hash = "sha256:2cd5df3de48857ed0544b34e2d40e9fac445930039f3cfe4bcc592a1f836d513"}, - {file = "PyYAML-6.0-cp310-cp310-win_amd64.whl", hash = "sha256:daf496c58a8c52083df09b80c860005194014c3698698d1a57cbcfa182142a3a"}, - {file = "PyYAML-6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d4b0ba9512519522b118090257be113b9468d804b19d63c71dbcf4a48fa32358"}, - {file = "PyYAML-6.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:81957921f441d50af23654aa6c5e5eaf9b06aba7f0a19c18a538dc7ef291c5a1"}, - {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:afa17f5bc4d1b10afd4466fd3a44dc0e245382deca5b3c353d8b757f9e3ecb8d"}, - {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dbad0e9d368bb989f4515da330b88a057617d16b6a8245084f1b05400f24609f"}, - {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:432557aa2c09802be39460360ddffd48156e30721f5e8d917f01d31694216782"}, - {file = "PyYAML-6.0-cp311-cp311-win32.whl", hash = "sha256:bfaef573a63ba8923503d27530362590ff4f576c626d86a9fed95822a8255fd7"}, - {file = "PyYAML-6.0-cp311-cp311-win_amd64.whl", hash = "sha256:01b45c0191e6d66c470b6cf1b9531a771a83c1c4208272ead47a3ae4f2f603bf"}, - {file = "PyYAML-6.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:897b80890765f037df3403d22bab41627ca8811ae55e9a722fd0392850ec4d86"}, - {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50602afada6d6cbfad699b0c7bb50d5ccffa7e46a3d738092afddc1f9758427f"}, - {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:48c346915c114f5fdb3ead70312bd042a953a8ce5c7106d5bfb1a5254e47da92"}, - {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:98c4d36e99714e55cfbaaee6dd5badbc9a1ec339ebfc3b1f52e293aee6bb71a4"}, - {file = "PyYAML-6.0-cp36-cp36m-win32.whl", hash = "sha256:0283c35a6a9fbf047493e3a0ce8d79ef5030852c51e9d911a27badfde0605293"}, - {file = "PyYAML-6.0-cp36-cp36m-win_amd64.whl", hash = "sha256:07751360502caac1c067a8132d150cf3d61339af5691fe9e87803040dbc5db57"}, - {file = "PyYAML-6.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:819b3830a1543db06c4d4b865e70ded25be52a2e0631ccd2f6a47a2822f2fd7c"}, - {file = "PyYAML-6.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:473f9edb243cb1935ab5a084eb238d842fb8f404ed2193a915d1784b5a6b5fc0"}, - {file = "PyYAML-6.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0ce82d761c532fe4ec3f87fc45688bdd3a4c1dc5e0b4a19814b9009a29baefd4"}, - {file = "PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:231710d57adfd809ef5d34183b8ed1eeae3f76459c18fb4a0b373ad56bedcdd9"}, - {file = "PyYAML-6.0-cp37-cp37m-win32.whl", hash = "sha256:c5687b8d43cf58545ade1fe3e055f70eac7a5a1a0bf42824308d868289a95737"}, - {file = "PyYAML-6.0-cp37-cp37m-win_amd64.whl", hash = "sha256:d15a181d1ecd0d4270dc32edb46f7cb7733c7c508857278d3d378d14d606db2d"}, - {file = "PyYAML-6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0b4624f379dab24d3725ffde76559cff63d9ec94e1736b556dacdfebe5ab6d4b"}, - {file = "PyYAML-6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:213c60cd50106436cc818accf5baa1aba61c0189ff610f64f4a3e8c6726218ba"}, - {file = "PyYAML-6.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9fa600030013c4de8165339db93d182b9431076eb98eb40ee068700c9c813e34"}, - {file = "PyYAML-6.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:277a0ef2981ca40581a47093e9e2d13b3f1fbbeffae064c1d21bfceba2030287"}, - {file = "PyYAML-6.0-cp38-cp38-win32.whl", hash = "sha256:d4eccecf9adf6fbcc6861a38015c2a64f38b9d94838ac1810a9023a0609e1b78"}, - {file = "PyYAML-6.0-cp38-cp38-win_amd64.whl", hash = "sha256:1e4747bc279b4f613a09eb64bba2ba602d8a6664c6ce6396a4d0cd413a50ce07"}, - {file = "PyYAML-6.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:055d937d65826939cb044fc8c9b08889e8c743fdc6a32b33e2390f66013e449b"}, - {file = "PyYAML-6.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e61ceaab6f49fb8bdfaa0f92c4b57bcfbea54c09277b1b4f7ac376bfb7a7c174"}, - {file = "PyYAML-6.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d67d839ede4ed1b28a4e8909735fc992a923cdb84e618544973d7dfc71540803"}, - {file = "PyYAML-6.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cba8c411ef271aa037d7357a2bc8f9ee8b58b9965831d9e51baf703280dc73d3"}, - {file = "PyYAML-6.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:40527857252b61eacd1d9af500c3337ba8deb8fc298940291486c465c8b46ec0"}, - {file = "PyYAML-6.0-cp39-cp39-win32.whl", hash = "sha256:b5b9eccad747aabaaffbc6064800670f0c297e52c12754eb1d976c57e4f74dcb"}, - {file = "PyYAML-6.0-cp39-cp39-win_amd64.whl", hash = "sha256:b3d267842bf12586ba6c734f89d1f5b871df0273157918b0ccefa29deb05c21c"}, - {file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"}, + {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, + {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, + {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, + {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, + {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, + {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, + {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, + {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, ] [[package]] @@ -3099,13 +3179,13 @@ test = ["asv", "gmpy2", "mpmath", "pytest", "pytest-cov", "pytest-xdist", "sciki [[package]] name = "scmrepo" -version = "1.0.3" +version = "1.0.4" description = "SCM wrapper and fsspec filesystem for Git for use in DVC" optional = false python-versions = ">=3.8" files = [ - {file = "scmrepo-1.0.3-py3-none-any.whl", hash = "sha256:95a9f4da2f9bc4a1604a01b5d73c4d76432ceff17440ff5e70cfd5d53ae6c47b"}, - {file = "scmrepo-1.0.3.tar.gz", hash = "sha256:70e65e3d614040e6698c1a2c876a82ea1b2b6b85cb4e0a65e5b8c19320f3597d"}, + {file = "scmrepo-1.0.4-py3-none-any.whl", hash = "sha256:048ffd98ab72afb6d3dfb177e5cefe14652ea28377b4258d9dac098cd4461036"}, + {file = "scmrepo-1.0.4.tar.gz", hash = "sha256:d03278d6a86caa5c7f1e85918bc28ef69040a5e5fd04c97cc0eea611ea8be13c"}, ] [package.dependencies] @@ -3120,8 +3200,8 @@ pygtrie = ">=2.3.2" shortuuid = ">=0.5.0" [package.extras] -dev = ["mock (==5.0.1)", "mypy (==0.971)", "paramiko (==3.1.0)", "pylint (==2.15.0)", "pytest (==7.2.0)", "pytest-asyncio (==0.18.3)", "pytest-cov (==3.0.0)", "pytest-docker (==0.12.0)", "pytest-mock (==3.8.2)", "pytest-sugar (==0.9.5)", "pytest-test-utils (==0.0.8)", "types-certifi (==2021.10.8.3)", "types-mock (==5.0.0.6)", "types-paramiko (==3.0.0.7)"] -tests = ["mock (==5.0.1)", "mypy (==0.971)", "paramiko (==3.1.0)", "pylint (==2.15.0)", "pytest (==7.2.0)", "pytest-asyncio (==0.18.3)", "pytest-cov (==3.0.0)", "pytest-docker (==0.12.0)", "pytest-mock (==3.8.2)", "pytest-sugar (==0.9.5)", "pytest-test-utils (==0.0.8)", "types-certifi (==2021.10.8.3)", "types-mock (==5.0.0.6)", "types-paramiko (==3.0.0.7)"] +dev = ["mock (==5.0.1)", "mypy (==0.971)", "paramiko (==3.1.0)", "pylint (==2.15.0)", "pytest (==7.2.0)", "pytest-asyncio (==0.18.3)", "pytest-cov (==3.0.0)", "pytest-docker (==0.12.0)", "pytest-mock (==3.8.2)", "pytest-sugar (==0.9.5)", "pytest-test-utils (==0.0.8)", "types-certifi (==2021.10.8.3)", "types-mock (==5.0.0.6)", "types-paramiko (==3.0.0.10)"] +tests = ["mock (==5.0.1)", "mypy (==0.971)", "paramiko (==3.1.0)", "pylint (==2.15.0)", "pytest (==7.2.0)", "pytest-asyncio (==0.18.3)", "pytest-cov (==3.0.0)", "pytest-docker (==0.12.0)", "pytest-mock (==3.8.2)", "pytest-sugar (==0.9.5)", "pytest-test-utils (==0.0.8)", "types-certifi (==2021.10.8.3)", "types-mock (==5.0.0.6)", "types-paramiko (==3.0.0.10)"] [[package]] name = "sentencepiece" @@ -3179,13 +3259,13 @@ files = [ [[package]] name = "sentry-sdk" -version = "1.26.0" +version = "1.28.1" description = "Python client for Sentry (https://sentry.io)" optional = false python-versions = "*" files = [ - {file = "sentry-sdk-1.26.0.tar.gz", hash = "sha256:760e4fb6d01c994110507133e08ecd4bdf4d75ee4be77f296a3579796cf73134"}, - {file = "sentry_sdk-1.26.0-py2.py3-none-any.whl", hash = "sha256:0c9f858337ec3781cf4851972ef42bba8c9828aea116b0dbed8f38c5f9a1896c"}, + {file = "sentry-sdk-1.28.1.tar.gz", hash = "sha256:dcd88c68aa64dae715311b5ede6502fd684f70d00a7cd4858118f0ba3153a3ae"}, + {file = "sentry_sdk-1.28.1-py2.py3-none-any.whl", hash = "sha256:6bdb25bd9092478d3a817cb0d01fa99e296aea34d404eac3ca0037faa5c2aa0a"}, ] [package.dependencies] @@ -3305,13 +3385,13 @@ test = ["pytest"] [[package]] name = "setuptools" -version = "67.8.0" +version = "68.0.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.7" files = [ - {file = "setuptools-67.8.0-py3-none-any.whl", hash = "sha256:5df61bf30bb10c6f756eb19e7c9f3b473051f48db77fddbe06ff2ca307df9a6f"}, - {file = "setuptools-67.8.0.tar.gz", hash = "sha256:62642358adc77ffa87233bc4d2354c4b2682d214048f500964dbe760ccedf102"}, + {file = "setuptools-68.0.0-py3-none-any.whl", hash = "sha256:11e52c67415a381d10d6b462ced9cfb97066179f0e871399e006c4ab101fc85f"}, + {file = "setuptools-68.0.0.tar.gz", hash = "sha256:baf1fdb41c6da4cd2eae722e135500da913332ab3f2f5c7d33af9b492acb5235"}, ] [package.extras] @@ -3427,13 +3507,13 @@ widechars = ["wcwidth"] [[package]] name = "threadpoolctl" -version = "3.1.0" +version = "3.2.0" description = "threadpoolctl" optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" files = [ - {file = "threadpoolctl-3.1.0-py3-none-any.whl", hash = "sha256:8b99adda265feb6773280df41eece7b2e6561b772d21ffd52e372f999024907b"}, - {file = "threadpoolctl-3.1.0.tar.gz", hash = "sha256:a335baacfaa4400ae1f0d8e3a58d6674d2f8828e3716bb2802c44955ad391380"}, + {file = "threadpoolctl-3.2.0-py3-none-any.whl", hash = "sha256:2b7818516e423bdaebb97c723f86a7c6b0a83d3f3b0970328d66f4d9104dc032"}, + {file = "threadpoolctl-3.2.0.tar.gz", hash = "sha256:c96a0ba3bdddeaca37dc4cc7344aafad41cdb8c313f74fdfe387a867bba93355"}, ] [[package]] @@ -3693,13 +3773,13 @@ test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6. [[package]] name = "typing-extensions" -version = "4.6.3" +version = "4.7.1" description = "Backported and Experimental Type Hints for Python 3.7+" optional = false python-versions = ">=3.7" files = [ - {file = "typing_extensions-4.6.3-py3-none-any.whl", hash = "sha256:88a4153d8505aabbb4e13aacb7c486c2b4a33ca3b3f807914a9b4c844c471c26"}, - {file = "typing_extensions-4.6.3.tar.gz", hash = "sha256:d91d5919357fe7f681a9f2b5b4cb2a5f1ef0a1e9f59c4d8ff0d3491e05c0ffd5"}, + {file = "typing_extensions-4.7.1-py3-none-any.whl", hash = "sha256:440d5dd3af93b060174bf433bccd69b0babc3b15b1a8dca43789fd7f61514b36"}, + {file = "typing_extensions-4.7.1.tar.gz", hash = "sha256:b75ddc264f0ba5615db7ba217daeb99701ad295353c45f9e95963337ceeeffb2"}, ] [[package]] @@ -3742,13 +3822,13 @@ files = [ [[package]] name = "virtualenv" -version = "20.23.1" +version = "20.24.0" description = "Virtual Python Environment builder" optional = false python-versions = ">=3.7" files = [ - {file = "virtualenv-20.23.1-py3-none-any.whl", hash = "sha256:34da10f14fea9be20e0fd7f04aba9732f84e593dac291b757ce42e3368a39419"}, - {file = "virtualenv-20.23.1.tar.gz", hash = "sha256:8ff19a38c1021c742148edc4f81cb43d7f8c6816d2ede2ab72af5b84c749ade1"}, + {file = "virtualenv-20.24.0-py3-none-any.whl", hash = "sha256:18d1b37fc75cc2670625702d76849a91ebd383768b4e91382a8d51be3246049e"}, + {file = "virtualenv-20.24.0.tar.gz", hash = "sha256:e2a7cef9da880d693b933db7654367754f14e20650dc60e8ee7385571f8593a3"}, ] [package.dependencies] @@ -3773,18 +3853,18 @@ files = [ [[package]] name = "wandb" -version = "0.15.4" +version = "0.15.5" description = "A CLI and library for interacting with the Weights and Biases API." optional = false python-versions = ">=3.6" files = [ - {file = "wandb-0.15.4-py3-none-any.whl", hash = "sha256:9018565177e1be14d7d0dd470c583206031c6027c32a98c57fa3bb83955143d7"}, - {file = "wandb-0.15.4.tar.gz", hash = "sha256:472daaaa1a4e29a46407a85fd77aadb724c91d87dfe2c37cd82ef77be2257011"}, + {file = "wandb-0.15.5-py3-none-any.whl", hash = "sha256:8cfb8fdaaf0a35b636d0ca2c2c1262b0e3d835ac37f70fc3094b618f55f63f01"}, + {file = "wandb-0.15.5.tar.gz", hash = "sha256:40c1d9ae501194bff408bc9c555865ffcccf08d2d65dd413547df0c17ed20cb5"}, ] [package.dependencies] appdirs = ">=1.4.3" -Click = ">=7.0,<8.0.0 || >8.0.0" +Click = ">=7.1,<8.0.0 || >8.0.0" docker-pycreds = ">=0.4.0" GitPython = ">=1.0.0,<3.1.29 || >3.1.29" pathtools = "*" @@ -3803,7 +3883,7 @@ azure = ["azure-identity", "azure-storage-blob"] gcp = ["google-cloud-storage"] grpc = ["grpcio (>=1.27.2)"] kubeflow = ["google-cloud-storage", "kubernetes", "minio", "sh"] -launch = ["awscli", "boto3", "botocore", "chardet", "google-auth", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "nbconvert", "nbformat", "optuna", "typing-extensions"] +launch = ["awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "nbconvert", "nbformat", "optuna", "typing-extensions"] media = ["bokeh", "moviepy", "numpy", "pillow", "plotly", "rdkit-pypi", "soundfile"] models = ["cloudpickle"] sweeps = ["sweeps (>=0.2.0)"] From 1628362d23dd50bdba90d0a5526cae3c0c2f0b40 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 19 Jul 2023 22:33:22 +0100 Subject: [PATCH 064/300] Adds benchmarking utility --- .../preprocessing/preprocess_mesh.py | 86 +++++++++++++------ grants_tagger_light/utils/benchmark.py | 29 +++++++ 2 files changed, 88 insertions(+), 27 deletions(-) create mode 100644 grants_tagger_light/utils/benchmark.py diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index bb90f5ce..7294b2bc 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -1,5 +1,6 @@ import json import tempfile +import time import numpy as np import typer @@ -9,6 +10,8 @@ from grants_tagger_light.models.bert_mesh import BertMesh import os +from grants_tagger_light.utils.benchmark import Benchmark + # TODO refactor the two load funcs into a class disable_caching() @@ -65,9 +68,11 @@ def preprocess_mesh( test_size: float = 0.05, num_proc: int = 8, max_samples: int = np.inf, - batch_size: int = 32 + batch_size: int = 32, + benchmark: Benchmark = None, ): - print("Downloading tokenizer and model...") + experiment_name = 'num_proc=' + str(num_proc) + '_max_samples=' + str(max_samples) + if not model_key: label2id = None # Use the same pretrained tokenizer as in Wellcome/WellcomeBertMesh @@ -81,24 +86,23 @@ def preprocess_mesh( label2id = {v: k for k, v in model.id2label.items()} - """print("Creating generator...") - dset = Dataset.from_generator( - _datagen, - gen_kwargs={"files": files, "max_samples": max_samples}, - num_proc=num_proc - )""" + start = time.time() if max_samples != np.inf: data_path = create_tmp_file(data_path, max_samples) dset = load_dataset("json", data_files=data_path, num_proc=num_proc) if 'train' in dset: dset = dset['train'] + if benchmark: + benchmark.register(experiment_name, "Loading dataset", str(time.time() - start)) - print("Removing columns...") + start = time.time() # Remove unused columns to save space & time dset = dset.remove_columns(["journal", "year", "pmid", "title"]) + if benchmark: + benchmark.register(experiment_name, "Removing columns", str(time.time() - start)) - print("Tokenizing with map...") + start = time.time() dset = dset.map( _tokenize, batched=True, @@ -108,15 +112,17 @@ def preprocess_mesh( fn_kwargs={"tokenizer": tokenizer, "x_col": "abstractText"}, remove_columns=["abstractText"], ) + if benchmark: + benchmark.register(experiment_name, "Tokenizing", str(time.time() - start)) - print("Getting label2id...") + start = time.time() # Generate label2id if None if label2id is None: dset = dset.map( lambda x: {'labels': x["meshMajor"]}, batched=True, batch_size=batch_size, - num_proc=num_proc, + num_proc=1, # Multithreading degrades times, as benchmarking showed desc="Getting labels" ) @@ -128,29 +134,38 @@ def preprocess_mesh( # Step 3: Dictionary creation label2id = {label: idx for idx, label in enumerate(unique_labels_set)} + if benchmark: + benchmark.register(experiment_name, "label2id", str(time.time() - start)) - print("Encoding with map...") + start = time.time() dset = dset.map( _encode_labels, batched=True, batch_size=batch_size, desc="Encoding labels", - num_proc=num_proc, + num_proc=1, # Multithreading degrades times, as benchmarking showed fn_kwargs={"label2id": label2id}, remove_columns=["meshMajor", "labels"], ) + if benchmark: + benchmark.register(experiment_name, "Encoding labels", str(time.time() - start)) - print("Splitting training and test...") # Split into train and test dset = dset.train_test_split(test_size=test_size) - print("Saving to disk...") + start = time.time() # Save to disk - dset.save_to_disk(os.path.join(save_loc, "dataset"), num_proc=num_proc) + dset.save_to_disk( + os.path.join(save_loc, "dataset"), + num_proc=1 # Multithreading degrades times, as benchmarking showed + ) with open(os.path.join(save_loc, "label2id.json"), "w") as f: json.dump(label2id, f, indent=4) + if benchmark: + benchmark.register(experiment_name, "Saving", str(time.time() - start)) + @preprocess_app.command() def preprocess_mesh_cli( @@ -170,17 +185,34 @@ def preprocess_mesh_cli( ), batch_size: int = typer.Option( 32, - help="Size of the preprocessing batch") + help="Size of the preprocessing batch"), + benchmark: bool = typer.Option( + False, + help="Benchmark and create a file with the times") ): if max_samples == -1: max_samples = np.inf - preprocess_mesh( - data_path=data_path, - save_loc=save_loc, - model_key=model_key, - test_size=test_size, - num_proc=num_proc, - max_samples=max_samples, - batch_size=batch_size - ) + if benchmark: + benchmark = Benchmark('preprocessing_mesh_benchmark.csv') + for num_proc in range(1, 9): + preprocess_mesh( + data_path=data_path, + save_loc=save_loc, + model_key=model_key, + test_size=test_size, + num_proc=num_proc, + max_samples=max_samples, + batch_size=batch_size, + benchmark=benchmark + ) + benchmark.to_csv() + else: + preprocess_mesh( + data_path=data_path, + save_loc=save_loc, + model_key=model_key, + test_size=test_size, + num_proc=num_proc, + max_samples=max_samples, + batch_size=batch_size) diff --git a/grants_tagger_light/utils/benchmark.py b/grants_tagger_light/utils/benchmark.py new file mode 100644 index 00000000..bd2bf97f --- /dev/null +++ b/grants_tagger_light/utils/benchmark.py @@ -0,0 +1,29 @@ +class Benchmark: + def __init__(self, filename): + self.filename = filename + self.metrics = {} + + def register(self, experiment_name, metric, time): + if experiment_name not in self.metrics: + self.metrics[experiment_name] = dict() + self.metrics[experiment_name][metric] = time + + def to_csv(self): + headers_written = False + with open(self.filename, 'w') as f: + for experiment_name, results in self.metrics.items(): + keys = results.keys() + values = results.values() + if not headers_written: + f.write(experiment_name) + headers_written = True + for k in keys: + f.write(";") + f.write(k) + f.write('\n') + f.write(experiment_name) + for v in values: + f.write(";") + f.write(v) + f.write('\n') + From 669d8e23d782bb43f54646a0722b4afc849f1142 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 20 Jul 2023 13:26:25 +0100 Subject: [PATCH 065/300] Cleaning. Training using preprocessing output --- .../preprocessing/preprocess_mesh.py | 85 ++++--------------- grants_tagger_light/training/train.py | 43 ++++++---- grants_tagger_light/utils/benchmark.py | 29 ------- scripts/mesh_json_to_jsonl.py | 41 +++++++++ 4 files changed, 85 insertions(+), 113 deletions(-) delete mode 100644 grants_tagger_light/utils/benchmark.py create mode 100644 scripts/mesh_json_to_jsonl.py diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 7294b2bc..23f0c337 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -1,23 +1,16 @@ import json import tempfile -import time import numpy as np import typer from transformers import AutoTokenizer from datasets import Dataset, disable_caching, load_dataset -from loguru import logger from grants_tagger_light.models.bert_mesh import BertMesh import os +from loguru import logger -from grants_tagger_light.utils.benchmark import Benchmark - -# TODO refactor the two load funcs into a class - -disable_caching() preprocess_app = typer.Typer() - def _tokenize(batch, tokenizer: AutoTokenizer, x_col: str): return tokenizer( batch[x_col], @@ -35,7 +28,7 @@ def _encode_labels(sample, label2id): return {'label_ids': [_map_label_to_ids(x, label2id) for x in sample['meshMajor']]} -def create_tmp_file(jsonl_file, lines): +def create_sample_file(jsonl_file, lines): with open(jsonl_file, 'r') as input_file: with tempfile.NamedTemporaryFile(mode='w', delete=False) as tmp_file: for _ in range(lines): @@ -47,20 +40,6 @@ def create_tmp_file(jsonl_file, lines): return tmp_file.name -def _datagen(mesh_json_path: str, max_samples: int = np.inf): - with open(mesh_json_path, "r", encoding="latin1") as f: - for idx, line in enumerate(f): - # Skip 1st line - if idx == 0: - continue - sample = json.loads(line[:-2]) - - if idx > max_samples: - break - - yield sample - - def preprocess_mesh( data_path: str, save_loc: str, @@ -68,11 +47,8 @@ def preprocess_mesh( test_size: float = 0.05, num_proc: int = 8, max_samples: int = np.inf, - batch_size: int = 32, - benchmark: Benchmark = None, + batch_size: int = 32 ): - experiment_name = 'num_proc=' + str(num_proc) + '_max_samples=' + str(max_samples) - if not model_key: label2id = None # Use the same pretrained tokenizer as in Wellcome/WellcomeBertMesh @@ -86,23 +62,16 @@ def preprocess_mesh( label2id = {v: k for k, v in model.id2label.items()} - start = time.time() if max_samples != np.inf: - data_path = create_tmp_file(data_path, max_samples) + data_path = create_sample_file(data_path, max_samples) dset = load_dataset("json", data_files=data_path, num_proc=num_proc) if 'train' in dset: dset = dset['train'] - if benchmark: - benchmark.register(experiment_name, "Loading dataset", str(time.time() - start)) - start = time.time() # Remove unused columns to save space & time dset = dset.remove_columns(["journal", "year", "pmid", "title"]) - if benchmark: - benchmark.register(experiment_name, "Removing columns", str(time.time() - start)) - start = time.time() dset = dset.map( _tokenize, batched=True, @@ -112,10 +81,7 @@ def preprocess_mesh( fn_kwargs={"tokenizer": tokenizer, "x_col": "abstractText"}, remove_columns=["abstractText"], ) - if benchmark: - benchmark.register(experiment_name, "Tokenizing", str(time.time() - start)) - start = time.time() # Generate label2id if None if label2id is None: dset = dset.map( @@ -134,10 +100,7 @@ def preprocess_mesh( # Step 3: Dictionary creation label2id = {label: idx for idx, label in enumerate(unique_labels_set)} - if benchmark: - benchmark.register(experiment_name, "label2id", str(time.time() - start)) - start = time.time() dset = dset.map( _encode_labels, batched=True, @@ -147,13 +110,10 @@ def preprocess_mesh( fn_kwargs={"label2id": label2id}, remove_columns=["meshMajor", "labels"], ) - if benchmark: - benchmark.register(experiment_name, "Encoding labels", str(time.time() - start)) # Split into train and test dset = dset.train_test_split(test_size=test_size) - start = time.time() # Save to disk dset.save_to_disk( os.path.join(save_loc, "dataset"), @@ -163,13 +123,10 @@ def preprocess_mesh( with open(os.path.join(save_loc, "label2id.json"), "w") as f: json.dump(label2id, f, indent=4) - if benchmark: - benchmark.register(experiment_name, "Saving", str(time.time() - start)) - @preprocess_app.command() def preprocess_mesh_cli( - data_path: str = typer.Argument(..., help="Path to mesh.json"), + data_path: str = typer.Argument(..., help="Path to mesh.jsonl"), save_loc: str = typer.Argument(..., help="Path to save processed data"), model_key: str = typer.Argument( ..., @@ -186,33 +143,27 @@ def preprocess_mesh_cli( batch_size: int = typer.Option( 32, help="Size of the preprocessing batch"), - benchmark: bool = typer.Option( + disable_cache: bool = typer.Option( False, - help="Benchmark and create a file with the times") + help="Do you want to disable `Datasets` caching? (not recommended)") + ): + if disable_cache: + disable_caching() + if max_samples == -1: max_samples = np.inf - if benchmark: - benchmark = Benchmark('preprocessing_mesh_benchmark.csv') - for num_proc in range(1, 9): - preprocess_mesh( + if not data_path.endswith('jsonl'): + logger.error("It seems your input MeSH data is not in `jsonl` format. " + "Please, run first `scripts/mesh_json_to_jsonlpy.`") + exit(1) + + preprocess_mesh( data_path=data_path, save_loc=save_loc, model_key=model_key, test_size=test_size, num_proc=num_proc, max_samples=max_samples, - batch_size=batch_size, - benchmark=benchmark - ) - benchmark.to_csv() - else: - preprocess_mesh( - data_path=data_path, - save_loc=save_loc, - model_key=model_key, - test_size=test_size, - num_proc=num_proc, - max_samples=max_samples, - batch_size=batch_size) + batch_size=batch_size) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 563c1a84..112a60e0 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -24,6 +24,7 @@ import os import transformers import json +from datasets import load_from_disk from grants_tagger_light.utils.sharding import Sharding from grants_tagger_light.utils.utils import calculate_max_steps @@ -34,7 +35,6 @@ def train_bertmesh( model_key: str, data_path: str, - max_samples: int, training_args: TrainingArguments, model_args: BertMeshModelArguments = None, ): @@ -46,12 +46,16 @@ def train_bertmesh( logger.info("No model key provided. Training model from scratch") # Instantiate model from scratch + logger.info(f"Loading `{model_args.pretrained_model_key}` tokenizer...") config = AutoConfig.from_pretrained(model_args.pretrained_model_key) AutoTokenizer.from_pretrained(model_args.pretrained_model_key) - dset = load_dataset(os.path.join(data_path, "dataset")) + logger.info(f"Loading preprocessed dataset at {data_path}...") + dset = load_from_disk(os.path.join(data_path, 'dataset')) + logger.info(f"Sharding training dataset...") train_dset, val_dset = Sharding(num_shards=100).shard(dset["train"]), dset["test"] + logger.info(f"Loading labels and other model configurations...") with open(os.path.join(data_path, "label2id.json"), "r") as f: label2id = json.load(f) @@ -72,13 +76,17 @@ def train_bertmesh( logger.info(f"Training model from pretrained key {model_key}") # Instantiate from pretrained + logger.info(f"Loading `{model_key}` tokenizer...") model = BertMesh.from_pretrained(model_key, trust_remote_code=True) AutoTokenizer.from_pretrained(model_key) + logger.info(f"Loading labels from model {model_key}...") label2id = {v: k for k, v in model.id2label.items()} - dset = load_dataset(os.path.join(data_path, "dataset")) + logger.info(f"Loading preprocessed dataset at {data_path}...") + dset = load_from_disk(os.path.join(data_path, 'dataset')) + logger.info(f"Sharding training dataset...") train_dset, val_dset = Sharding(num_shards=100).shard(dset["train"]), dset["test"] if model_args.freeze_backbone: @@ -105,14 +113,14 @@ def sklearn_metrics(prediction: EvalPrediction): return metric_dict + logger.info(f"Collating labels...") collator = MultilabelDataCollator(label2id=label2id) + logger.info(f"Calculating max steps for IterableDatasets...") max_steps = calculate_max_steps(training_args, dset) - print("Before:\n" + str(training_args.max_steps)) training_args.max_steps = max_steps - print("After:\n" + str(training_args.max_steps)) - print(max_steps) + logger.info(f"Initializing Trainer...") trainer = Trainer( model=model, args=training_args, @@ -122,12 +130,15 @@ def sklearn_metrics(prediction: EvalPrediction): compute_metrics=sklearn_metrics ) + logger.info(f"Training...") trainer.train() + logger.info(f"Evaluating...") metrics = trainer.evaluate(eval_dataset=val_dset) logger.info(pformat(metrics)) + logger.info(f"Saving the model...") trainer.save_model(os.path.join(training_args.output_dir, "best")) @@ -142,15 +153,14 @@ def train_bertmesh_cli( ), data_path: str = typer.Argument( ..., - help="Path to data in jsonl format. Must contain text and tags field", - ), - max_samples: int = typer.Argument( - -1, - help="Maximum number of samples to use for training. Useful for dev/debugging", - ), + help="Path to PyArrow folder with preprocessed Mesh. If not available, run first " + "`grants-tagger preprocess mesh [input_jsonl] [output_pyarrrow_folder] [model_key|'']`", + ) ): - if max_samples == -1: - max_samples = np.inf + if not os.path.isdir(data_path): + logger.error("`data_path` should be a folder resulted as the output of calling to " + "`grants-tagger preprocess mesh [input_jsonl] [output_pyarrrow_folder] [model_key|'']`") + exit(1) parser = HfArgumentParser( ( @@ -168,7 +178,7 @@ def train_bertmesh_cli( logger.info("Training args: {}".format(pformat(training_args))) logger.info("Wandb args: {}".format(pformat(wandb_args))) - train_bertmesh(model_key, data_path, max_samples, training_args, model_args) + train_bertmesh(model_key, data_path, training_args, model_args) if __name__ == "__main__": @@ -192,7 +202,6 @@ class TrainFuncArgs: train_bertmesh( func_args.model_key, func_args.data_path, - func_args.max_samples, training_args, - model_args, + model_args ) diff --git a/grants_tagger_light/utils/benchmark.py b/grants_tagger_light/utils/benchmark.py deleted file mode 100644 index bd2bf97f..00000000 --- a/grants_tagger_light/utils/benchmark.py +++ /dev/null @@ -1,29 +0,0 @@ -class Benchmark: - def __init__(self, filename): - self.filename = filename - self.metrics = {} - - def register(self, experiment_name, metric, time): - if experiment_name not in self.metrics: - self.metrics[experiment_name] = dict() - self.metrics[experiment_name][metric] = time - - def to_csv(self): - headers_written = False - with open(self.filename, 'w') as f: - for experiment_name, results in self.metrics.items(): - keys = results.keys() - values = results.values() - if not headers_written: - f.write(experiment_name) - headers_written = True - for k in keys: - f.write(";") - f.write(k) - f.write('\n') - f.write(experiment_name) - for v in values: - f.write(";") - f.write(v) - f.write('\n') - diff --git a/scripts/mesh_json_to_jsonl.py b/scripts/mesh_json_to_jsonl.py new file mode 100644 index 00000000..432321e2 --- /dev/null +++ b/scripts/mesh_json_to_jsonl.py @@ -0,0 +1,41 @@ +import json +from argparse import ArgumentParser + + +def mesh_to_jsonl(input_path, output_path, input_encoding='latin1', output_encoding='latin1'): + """ + Mesh json is not optimized for parallel processing. This script aims to transform it into `jsonl` so that, + by having 1 json per line, you can evenly distribute it among all the cores. + + Args: + input_path: allMeSH_2021.json (or similar) path + output_path: path for the resulted jsonl file + input_encoding: encoding of the input json (default `latin1`) + output_encoding: encoding of the output jsonl (default `latin1`) + + Returns: + + """ + with open(output_path, 'w', encoding=output_encoding) as fw: + with open(input_path, 'r', encoding=input_encoding) as fr: + for idx, line in enumerate(fr): + print(idx, end='\r') + # Skip 1st line + if idx == 0: + continue + sample = json.loads(line[:-2]) + + fw.write(json.dumps(sample)) + fw.write('\n') + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--input_path", required=True, help="path to input allMeSH_2021.json (or equivalent)") + parser.add_argument("--output_path", required=True, help="path to input jsonl") + parser.add_argument("--input_encoding", required=False, default='latin1', help="encoding of the input json") + parser.add_argument("--output_encoding", required=False, default='latin1', help="encoding of the output jsonl") + + args = parser.parse_args() + + mesh_to_jsonl(args.input_path, args.output_path, args.input_encoding, args.output_encoding) From 9d41a3ffb7091acb16c87792dc873be4a4b0877b Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 21 Jul 2023 10:33:45 +0100 Subject: [PATCH 066/300] Cleaning. Training using preprocessing output --- README.md | 3 +++ .../preprocessing/preprocess_mesh.py | 21 ++++++++++++++----- tests/__init__.py | 0 3 files changed, 19 insertions(+), 5 deletions(-) create mode 100644 tests/__init__.py diff --git a/README.md b/README.md index e5c80332..7b2c171b 100644 --- a/README.md +++ b/README.md @@ -300,6 +300,9 @@ and you would be able to run `grants_tagger preprocess epmc_mesh ...` ## 🚦 Test +To run the test you need to have installed the `dev` dependencies first. +This is done by running `poetry install --with dev` after you are in the sell (`poetry shell`) + Run tests with `pytest`. If you want to write some additional tests, they should go in the subfolde `tests/` diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 23f0c337..de48cb46 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -11,6 +11,7 @@ preprocess_app = typer.Typer() + def _tokenize(batch, tokenizer: AutoTokenizer, x_col: str): return tokenizer( batch[x_col], @@ -65,7 +66,8 @@ def preprocess_mesh( if max_samples != np.inf: data_path = create_sample_file(data_path, max_samples) - dset = load_dataset("json", data_files=data_path, num_proc=num_proc) + dset = load_dataset("json", data_files=data_path, num_proc=1) # We only have 1 file, so no sharding is available https://huggingface.co/docs/datasets/loading#multiprocessing + # By default, any dataset loaded is set to 'train' using the previous command if 'train' in dset: dset = dset['train'] @@ -84,21 +86,29 @@ def preprocess_mesh( # Generate label2id if None if label2id is None: + logger.info("labelGetting the labels. Please, wait...") dset = dset.map( lambda x: {'labels': x["meshMajor"]}, batched=True, batch_size=batch_size, - num_proc=1, # Multithreading degrades times, as benchmarking showed + num_proc=num_proc, # Multithreading degrades times in some cases: https://github.com/huggingface/datasets/issues/1992 desc="Getting labels" ) + logger.info("Obtaining the labels. Please, wait...") # Step 1: Get the 'labels' column from the dataset labels_column = dset['labels'] # Step 2: Flatten the list column and compute unique values - unique_labels_set = set(label for sublist in labels_column for label in sublist) + # unique_labels_set = set(label for sublist in labels_column for label in sublist) + unique_labels_set = set() + + # Iterate through the lists and add elements to the set + for arr in labels_column: + unique_labels_set.update(arr.to_pylist()) # Step 3: Dictionary creation + logger.info("Creating label2id. Please, wait...") label2id = {label: idx for idx, label in enumerate(unique_labels_set)} dset = dset.map( @@ -106,18 +116,19 @@ def preprocess_mesh( batched=True, batch_size=batch_size, desc="Encoding labels", - num_proc=1, # Multithreading degrades times, as benchmarking showed + num_proc=num_proc, # Multithreading degrades times in some cases: https://github.com/huggingface/datasets/issues/1992 fn_kwargs={"label2id": label2id}, remove_columns=["meshMajor", "labels"], ) + logger.info("Preparing train/test split. Please, wait...") # Split into train and test dset = dset.train_test_split(test_size=test_size) # Save to disk dset.save_to_disk( os.path.join(save_loc, "dataset"), - num_proc=1 # Multithreading degrades times, as benchmarking showed + num_proc=num_proc # Multithreading degrades times in some cases: https://github.com/huggingface/datasets/issues/1992 ) with open(os.path.join(save_loc, "label2id.json"), "w") as f: diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b From 273c705abb72dad0320ae3ad52347dd01f409150 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 21 Jul 2023 10:35:56 +0100 Subject: [PATCH 067/300] Cleaning. Training using preprocessing output --- grants_tagger_light/preprocessing/preprocess_mesh.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index de48cb46..1a04da57 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -86,7 +86,7 @@ def preprocess_mesh( # Generate label2id if None if label2id is None: - logger.info("labelGetting the labels. Please, wait...") + logger.info("Getting the labels...") dset = dset.map( lambda x: {'labels': x["meshMajor"]}, batched=True, @@ -95,12 +95,13 @@ def preprocess_mesh( desc="Getting labels" ) - logger.info("Obtaining the labels. Please, wait...") + logger.info("Obtaining the labels...") # Step 1: Get the 'labels' column from the dataset labels_column = dset['labels'] # Step 2: Flatten the list column and compute unique values # unique_labels_set = set(label for sublist in labels_column for label in sublist) + logger.info("Obtaining unique values from the labels...") unique_labels_set = set() # Iterate through the lists and add elements to the set @@ -108,7 +109,7 @@ def preprocess_mesh( unique_labels_set.update(arr.to_pylist()) # Step 3: Dictionary creation - logger.info("Creating label2id. Please, wait...") + logger.info("Creating label2id dictionary...") label2id = {label: idx for idx, label in enumerate(unique_labels_set)} dset = dset.map( @@ -121,10 +122,11 @@ def preprocess_mesh( remove_columns=["meshMajor", "labels"], ) - logger.info("Preparing train/test split. Please, wait...") + logger.info("Preparing train/test split....") # Split into train and test dset = dset.train_test_split(test_size=test_size) + logger.info("Saving to disk...") # Save to disk dset.save_to_disk( os.path.join(save_loc, "dataset"), From a774e726f02f6b01d819dd8481b3afd43cb4deaf Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 21 Jul 2023 10:39:56 +0100 Subject: [PATCH 068/300] Improving some steps --- grants_tagger_light/preprocessing/preprocess_mesh.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 1a04da57..e83bca4e 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -106,7 +106,7 @@ def preprocess_mesh( # Iterate through the lists and add elements to the set for arr in labels_column: - unique_labels_set.update(arr.to_pylist()) + unique_labels_set.update(arr) # Step 3: Dictionary creation logger.info("Creating label2id dictionary...") From f8a10df5ed9386342d22040a881e19b88aaed73f Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 21 Jul 2023 11:04:08 +0100 Subject: [PATCH 069/300] Improving some steps --- .../preprocessing/preprocess_mesh.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index e83bca4e..27d9abd4 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -48,7 +48,7 @@ def preprocess_mesh( test_size: float = 0.05, num_proc: int = 8, max_samples: int = np.inf, - batch_size: int = 32 + batch_size: int = 256 ): if not model_key: label2id = None @@ -95,18 +95,10 @@ def preprocess_mesh( desc="Getting labels" ) - logger.info("Obtaining the labels...") - # Step 1: Get the 'labels' column from the dataset - labels_column = dset['labels'] - - # Step 2: Flatten the list column and compute unique values - # unique_labels_set = set(label for sublist in labels_column for label in sublist) logger.info("Obtaining unique values from the labels...") unique_labels_set = set() - # Iterate through the lists and add elements to the set - for arr in labels_column: - unique_labels_set.update(arr) + unique_labels_set.update([arr for arr in dset['labels']]) # Step 3: Dictionary creation logger.info("Creating label2id dictionary...") @@ -154,7 +146,7 @@ def preprocess_mesh_cli( help="Maximum number of samples to use for preprocessing", ), batch_size: int = typer.Option( - 32, + 256, help="Size of the preprocessing batch"), disable_cache: bool = typer.Option( False, From 3c99d253f24e13b5f74be017053971c2c04c3569 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 21 Jul 2023 11:22:38 +0100 Subject: [PATCH 070/300] Improving some steps --- grants_tagger_light/preprocessing/preprocess_mesh.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 27d9abd4..f64b9d05 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -95,14 +95,17 @@ def preprocess_mesh( desc="Getting labels" ) - logger.info("Obtaining unique values from the labels...") unique_labels_set = set() - unique_labels_set.update([arr for arr in dset['labels']]) + logger.info("Obtaining unique values from the labels...") + # Iterate through the lists and add elements to the set + for arr in tqdm(dset['labels']): + unique_labels_set.update(arr) - # Step 3: Dictionary creation logger.info("Creating label2id dictionary...") - label2id = {label: idx for idx, label in enumerate(unique_labels_set)} + label2id = dict() + for idx, label in enumerate(tqdm(unique_labels_set)): + label2id.update({label: idx}) dset = dset.map( _encode_labels, From f5af79ba79379f548d64af2a2249f65c2f9ee1bb Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 21 Jul 2023 11:24:37 +0100 Subject: [PATCH 071/300] Improving some steps --- grants_tagger_light/preprocessing/preprocess_mesh.py | 1 + 1 file changed, 1 insertion(+) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index f64b9d05..c30f812b 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -8,6 +8,7 @@ from grants_tagger_light.models.bert_mesh import BertMesh import os from loguru import logger +from tqdm import tqdm preprocess_app = typer.Typer() From 256e07cb3d4ddab0828245585c8c1c7a1467ec81 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 21 Jul 2023 13:36:32 +0100 Subject: [PATCH 072/300] Improving save_to_disk --- grants_tagger_light/preprocessing/preprocess_mesh.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index c30f812b..b6aa200e 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -67,7 +67,8 @@ def preprocess_mesh( if max_samples != np.inf: data_path = create_sample_file(data_path, max_samples) - dset = load_dataset("json", data_files=data_path, num_proc=1) # We only have 1 file, so no sharding is available https://huggingface.co/docs/datasets/loading#multiprocessing + # We only have 1 file, so no sharding is available https://huggingface.co/docs/datasets/loading#multiprocessing + dset = load_dataset("json", data_files=data_path, num_proc=1) # By default, any dataset loaded is set to 'train' using the previous command if 'train' in dset: dset = dset['train'] @@ -92,10 +93,11 @@ def preprocess_mesh( lambda x: {'labels': x["meshMajor"]}, batched=True, batch_size=batch_size, - num_proc=num_proc, # Multithreading degrades times in some cases: https://github.com/huggingface/datasets/issues/1992 + num_proc=num_proc, desc="Getting labels" ) + # Most efficient way to do dedup of labels unique_labels_set = set() logger.info("Obtaining unique values from the labels...") @@ -103,6 +105,7 @@ def preprocess_mesh( for arr in tqdm(dset['labels']): unique_labels_set.update(arr) + # Most efficient way to do dictionary creation logger.info("Creating label2id dictionary...") label2id = dict() for idx, label in enumerate(tqdm(unique_labels_set)): @@ -113,7 +116,7 @@ def preprocess_mesh( batched=True, batch_size=batch_size, desc="Encoding labels", - num_proc=num_proc, # Multithreading degrades times in some cases: https://github.com/huggingface/datasets/issues/1992 + num_proc=num_proc, fn_kwargs={"label2id": label2id}, remove_columns=["meshMajor", "labels"], ) @@ -126,7 +129,8 @@ def preprocess_mesh( # Save to disk dset.save_to_disk( os.path.join(save_loc, "dataset"), - num_proc=num_proc # Multithreading degrades times in some cases: https://github.com/huggingface/datasets/issues/1992 + num_proc=1, + num_shards=1 ) with open(os.path.join(save_loc, "label2id.json"), "w") as f: From ee18b7592faafea308eba26308f532cf1614897f Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 21 Jul 2023 13:39:02 +0100 Subject: [PATCH 073/300] Improving save_to_disk --- grants_tagger_light/preprocessing/preprocess_mesh.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index b6aa200e..e056b2fc 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -75,7 +75,7 @@ def preprocess_mesh( # Remove unused columns to save space & time dset = dset.remove_columns(["journal", "year", "pmid", "title"]) - + """ dset = dset.map( _tokenize, batched=True, @@ -124,7 +124,7 @@ def preprocess_mesh( logger.info("Preparing train/test split....") # Split into train and test dset = dset.train_test_split(test_size=test_size) - + """ logger.info("Saving to disk...") # Save to disk dset.save_to_disk( From daf94fafd6b404acbb5eff6a6fa98c5b9caa4079 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 21 Jul 2023 14:34:14 +0100 Subject: [PATCH 074/300] Removing save_to_disk --- grants_tagger_light/preprocessing/preprocess_mesh.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index e056b2fc..30eb1087 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -75,7 +75,7 @@ def preprocess_mesh( # Remove unused columns to save space & time dset = dset.remove_columns(["journal", "year", "pmid", "title"]) - """ + dset = dset.map( _tokenize, batched=True, @@ -124,17 +124,14 @@ def preprocess_mesh( logger.info("Preparing train/test split....") # Split into train and test dset = dset.train_test_split(test_size=test_size) - """ + logger.info("Saving to disk...") # Save to disk dset.save_to_disk( - os.path.join(save_loc, "dataset"), - num_proc=1, - num_shards=1 + os.path.join(save_loc, "dataset") ) - with open(os.path.join(save_loc, "label2id.json"), "w") as f: - json.dump(label2id, f, indent=4) + return dset, label2id @preprocess_app.command() From 7e6a2b05e99c60f8d0645cd6206d4cea6307bbe8 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 21 Jul 2023 15:20:11 +0100 Subject: [PATCH 075/300] Adapts training to call preprocess --- .../preprocessing/preprocess_mesh.py | 27 +++----- grants_tagger_light/training/train.py | 68 +++++++++++++------ 2 files changed, 57 insertions(+), 38 deletions(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 30eb1087..5e4781ae 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -44,10 +44,9 @@ def create_sample_file(jsonl_file, lines): def preprocess_mesh( data_path: str, - save_loc: str, model_key: str, test_size: float = 0.05, - num_proc: int = 8, + num_proc: int = os.cpu_count(), max_samples: int = np.inf, batch_size: int = 256 ): @@ -125,42 +124,39 @@ def preprocess_mesh( # Split into train and test dset = dset.train_test_split(test_size=test_size) + """ logger.info("Saving to disk...") + # Save to disk dset.save_to_disk( os.path.join(save_loc, "dataset") ) - + """ return dset, label2id @preprocess_app.command() def preprocess_mesh_cli( - data_path: str = typer.Argument(..., help="Path to mesh.jsonl"), - save_loc: str = typer.Argument(..., help="Path to save processed data"), + data_path: str = typer.Argument( + ..., + help="Path to mesh.jsonl" + ), model_key: str = typer.Argument( ..., help="Key to use when loading tokenizer and label2id. Leave blank if training from scratch", # noqa ), test_size: float = typer.Option(0.05, help="Fraction of data to use for testing"), num_proc: int = typer.Option( - 8, help="Number of processes to use for preprocessing" + os.cpu_count(), help="Number of processes to use for preprocessing" ), max_samples: int = typer.Option( - -1, + np.inf, help="Maximum number of samples to use for preprocessing", ), batch_size: int = typer.Option( 256, - help="Size of the preprocessing batch"), - disable_cache: bool = typer.Option( - False, - help="Do you want to disable `Datasets` caching? (not recommended)") - + help="Size of the preprocessing batch") ): - if disable_cache: - disable_caching() - if max_samples == -1: max_samples = np.inf @@ -171,7 +167,6 @@ def preprocess_mesh_cli( preprocess_mesh( data_path=data_path, - save_loc=save_loc, model_key=model_key, test_size=test_size, num_proc=num_proc, diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 112a60e0..aaecfb9a 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -7,6 +7,7 @@ AutoConfig, ) from grants_tagger_light.models.bert_mesh import BertMesh +from grants_tagger_light.preprocessing.preprocess_mesh import preprocess_mesh from grants_tagger_light.training.cli_args import ( BertMeshTrainingArguments, WandbArguments, @@ -37,6 +38,10 @@ def train_bertmesh( data_path: str, training_args: TrainingArguments, model_args: BertMeshModelArguments = None, + max_samples:int = np.inf, + test_size: float = 0.05, + num_proc: int = os.cpu_count(), + batch_size: int = 256 ): if not model_key: assert isinstance( @@ -48,17 +53,20 @@ def train_bertmesh( # Instantiate model from scratch logger.info(f"Loading `{model_args.pretrained_model_key}` tokenizer...") config = AutoConfig.from_pretrained(model_args.pretrained_model_key) - AutoTokenizer.from_pretrained(model_args.pretrained_model_key) - logger.info(f"Loading preprocessed dataset at {data_path}...") - dset = load_from_disk(os.path.join(data_path, 'dataset')) + logger.info(f"Preprocessing the dataset at {data_path}...") + # dset = load_from_disk(os.path.join(data_path, 'dataset')) + dset, label2id = preprocess_mesh( + data_path=data_path, + model_key=model_key, + test_size=test_size, + num_proc=num_proc, + max_samples=max_samples, + batch_size=batch_size) + logger.info(f"Sharding training dataset...") train_dset, val_dset = Sharding(num_shards=100).shard(dset["train"]), dset["test"] - logger.info(f"Loading labels and other model configurations...") - with open(os.path.join(data_path, "label2id.json"), "r") as f: - label2id = json.load(f) - config.update( { "pretrained_model": model_args.pretrained_model_key, @@ -78,13 +86,19 @@ def train_bertmesh( # Instantiate from pretrained logger.info(f"Loading `{model_key}` tokenizer...") model = BertMesh.from_pretrained(model_key, trust_remote_code=True) - AutoTokenizer.from_pretrained(model_key) - logger.info(f"Loading labels from model {model_key}...") - label2id = {v: k for k, v in model.id2label.items()} + # logger.info(f"Loading labels from model {model_key}...") + # label2id = {v: k for k, v in model.id2label.items()} - logger.info(f"Loading preprocessed dataset at {data_path}...") - dset = load_from_disk(os.path.join(data_path, 'dataset')) + logger.info(f"Preprocessing the dataset at {data_path}...") + # dset = load_from_disk(os.path.join(data_path, 'dataset')) + dset, label2id = preprocess_mesh( + data_path=data_path, + model_key=model_key, + test_size=test_size, + num_proc=num_proc, + max_samples=max_samples, + batch_size=batch_size) logger.info(f"Sharding training dataset...") train_dset, val_dset = Sharding(num_shards=100).shard(dset["train"]), dset["test"] @@ -153,14 +167,20 @@ def train_bertmesh_cli( ), data_path: str = typer.Argument( ..., - help="Path to PyArrow folder with preprocessed Mesh. If not available, run first " - "`grants-tagger preprocess mesh [input_jsonl] [output_pyarrrow_folder] [model_key|'']`", - ) + help="Path to mesh.jsonl" + ), + test_size: float = typer.Option(0.05, help="Fraction of data to use for testing"), + num_proc: int = typer.Option( + os.cpu_count(), help="Number of processes to use for preprocessing" + ), + max_samples: int = typer.Option( + -1, + help="Maximum number of samples to use for preprocessing", + ), + batch_size : int = typer.Option( + 256, + help="Size of the preprocessing batch") ): - if not os.path.isdir(data_path): - logger.error("`data_path` should be a folder resulted as the output of calling to " - "`grants-tagger preprocess mesh [input_jsonl] [output_pyarrrow_folder] [model_key|'']`") - exit(1) parser = HfArgumentParser( ( @@ -178,9 +198,9 @@ def train_bertmesh_cli( logger.info("Training args: {}".format(pformat(training_args))) logger.info("Wandb args: {}".format(pformat(wandb_args))) - train_bertmesh(model_key, data_path, training_args, model_args) - + train_bertmesh(model_key, data_path, training_args, model_args, max_samples, test_size, num_proc, batch_size) +""" if __name__ == "__main__": from dataclasses import dataclass @@ -203,5 +223,9 @@ class TrainFuncArgs: func_args.model_key, func_args.data_path, training_args, - model_args + model_args, + max_samples, + test_size, + batch_size ) +""" \ No newline at end of file From f804517ac21b2447e1ae8656eb474047c79f3563 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 21 Jul 2023 15:30:40 +0100 Subject: [PATCH 076/300] Adapts training to call preprocess --- grants_tagger_light/preprocessing/preprocess_mesh.py | 3 --- grants_tagger_light/training/train.py | 6 +++--- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 5e4781ae..ae4e452b 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -157,9 +157,6 @@ def preprocess_mesh_cli( 256, help="Size of the preprocessing batch") ): - if max_samples == -1: - max_samples = np.inf - if not data_path.endswith('jsonl'): logger.error("It seems your input MeSH data is not in `jsonl` format. " "Please, run first `scripts/mesh_json_to_jsonlpy.`") diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index aaecfb9a..611b546a 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -38,7 +38,7 @@ def train_bertmesh( data_path: str, training_args: TrainingArguments, model_args: BertMeshModelArguments = None, - max_samples:int = np.inf, + max_samples: int = np.inf, test_size: float = 0.05, num_proc: int = os.cpu_count(), batch_size: int = 256 @@ -174,10 +174,10 @@ def train_bertmesh_cli( os.cpu_count(), help="Number of processes to use for preprocessing" ), max_samples: int = typer.Option( - -1, + np.inf, help="Maximum number of samples to use for preprocessing", ), - batch_size : int = typer.Option( + batch_size: int = typer.Option( 256, help="Size of the preprocessing batch") ): From 18fb506a76f1135f0d76f980e9fe5d195480e1ce Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 21 Jul 2023 15:33:49 +0100 Subject: [PATCH 077/300] Adapts training to call preprocess --- grants_tagger_light/preprocessing/preprocess_mesh.py | 12 +++++++----- grants_tagger_light/training/train.py | 4 ++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index ae4e452b..f1c11e28 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -47,9 +47,14 @@ def preprocess_mesh( model_key: str, test_size: float = 0.05, num_proc: int = os.cpu_count(), - max_samples: int = np.inf, + max_samples: int = -1, batch_size: int = 256 ): + if max_samples == -1: + max_samples = np.inf + else: + data_path = create_sample_file(data_path, max_samples) + if not model_key: label2id = None # Use the same pretrained tokenizer as in Wellcome/WellcomeBertMesh @@ -63,9 +68,6 @@ def preprocess_mesh( label2id = {v: k for k, v in model.id2label.items()} - if max_samples != np.inf: - data_path = create_sample_file(data_path, max_samples) - # We only have 1 file, so no sharding is available https://huggingface.co/docs/datasets/loading#multiprocessing dset = load_dataset("json", data_files=data_path, num_proc=1) # By default, any dataset loaded is set to 'train' using the previous command @@ -150,7 +152,7 @@ def preprocess_mesh_cli( os.cpu_count(), help="Number of processes to use for preprocessing" ), max_samples: int = typer.Option( - np.inf, + -1, help="Maximum number of samples to use for preprocessing", ), batch_size: int = typer.Option( diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 611b546a..840a4672 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -38,7 +38,7 @@ def train_bertmesh( data_path: str, training_args: TrainingArguments, model_args: BertMeshModelArguments = None, - max_samples: int = np.inf, + max_samples: int = -1, test_size: float = 0.05, num_proc: int = os.cpu_count(), batch_size: int = 256 @@ -174,7 +174,7 @@ def train_bertmesh_cli( os.cpu_count(), help="Number of processes to use for preprocessing" ), max_samples: int = typer.Option( - np.inf, + -1, help="Maximum number of samples to use for preprocessing", ), batch_size: int = typer.Option( From 60d05ccb30422c78289533915fddbb1a93e71b93 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 21 Jul 2023 16:58:45 +0100 Subject: [PATCH 078/300] Adapts training to call preprocess --- grants_tagger_light/preprocessing/preprocess_mesh.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index f1c11e28..283db48d 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -125,15 +125,7 @@ def preprocess_mesh( logger.info("Preparing train/test split....") # Split into train and test dset = dset.train_test_split(test_size=test_size) - - """ - logger.info("Saving to disk...") - - # Save to disk - dset.save_to_disk( - os.path.join(save_loc, "dataset") - ) - """ + return dset, label2id From afe04be002b719f777d837e2ce2dcc64b292cfef Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 21 Jul 2023 21:49:08 +0100 Subject: [PATCH 079/300] Allows pretraining+saving or just pretraining before train --- .../preprocessing/preprocess_mesh.py | 30 +++++- grants_tagger_light/training/train.py | 93 ++++++++++++------- 2 files changed, 87 insertions(+), 36 deletions(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 283db48d..708c925f 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -48,7 +48,8 @@ def preprocess_mesh( test_size: float = 0.05, num_proc: int = os.cpu_count(), max_samples: int = -1, - batch_size: int = 256 + batch_size: int = 256, + save_to_path: str = None ): if max_samples == -1: max_samples = np.inf @@ -126,6 +127,12 @@ def preprocess_mesh( # Split into train and test dset = dset.train_test_split(test_size=test_size) + if save_to_path is not None: + logger.info("Saving to disk...") + dset.save_to_disk(os.path.join(save_to_path, 'dataset'), num_proc=num_proc) + with open(os.path.join(save_to_path, 'label2id'), 'w') as f: + json.dump(label2id, f) + return dset, label2id @@ -135,13 +142,20 @@ def preprocess_mesh_cli( ..., help="Path to mesh.jsonl" ), + save_to_path: str = typer.Argument( + ..., + help="Path to save the serialized PyArrow dataset after preprocessing" + ), model_key: str = typer.Argument( ..., help="Key to use when loading tokenizer and label2id. Leave blank if training from scratch", # noqa ), - test_size: float = typer.Option(0.05, help="Fraction of data to use for testing"), + test_size: float = typer.Option( + 0.05, + help="Fraction of data to use for testing"), num_proc: int = typer.Option( - os.cpu_count(), help="Number of processes to use for preprocessing" + os.cpu_count(), + help="Number of processes to use for preprocessing" ), max_samples: int = typer.Option( -1, @@ -151,6 +165,13 @@ def preprocess_mesh_cli( 256, help="Size of the preprocessing batch") ): + print("\033[96mRunning preprocessing will save the data as a PyArrow dataset which is a very time consuming " + "operation. If you don't need the data to be saved, you can save much time just by running:\n" + "\t`grants-tagger train bertmesh {model_key} {path_to_jsonl}`\033[0m") + + if input('Do You Want To Continue? [Y/n]') != 'Y': + exit(1) + if not data_path.endswith('jsonl'): logger.error("It seems your input MeSH data is not in `jsonl` format. " "Please, run first `scripts/mesh_json_to_jsonlpy.`") @@ -162,4 +183,5 @@ def preprocess_mesh_cli( test_size=test_size, num_proc=num_proc, max_samples=max_samples, - batch_size=batch_size) + batch_size=batch_size, + save_to_path=save_to_path) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 840a4672..2c0c58a0 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -41,7 +41,8 @@ def train_bertmesh( max_samples: int = -1, test_size: float = 0.05, num_proc: int = os.cpu_count(), - batch_size: int = 256 + batch_size: int = 256, + shards: int = -1 ): if not model_key: assert isinstance( @@ -55,17 +56,33 @@ def train_bertmesh( config = AutoConfig.from_pretrained(model_args.pretrained_model_key) logger.info(f"Preprocessing the dataset at {data_path}...") - # dset = load_from_disk(os.path.join(data_path, 'dataset')) - dset, label2id = preprocess_mesh( - data_path=data_path, - model_key=model_key, - test_size=test_size, - num_proc=num_proc, - max_samples=max_samples, - batch_size=batch_size) - - logger.info(f"Sharding training dataset...") - train_dset, val_dset = Sharding(num_shards=100).shard(dset["train"]), dset["test"] + if os.path.isdir(data_path): + logger.info(f"Loading from disk...") + dset = load_from_disk(os.path.join(data_path, 'dataset')) + with open(os.path.join(data_path, 'label2id'), 'r') as f: + label2id = json.load(f) + else: + logger.info(f"Preprocessing started!") + dset, label2id = preprocess_mesh( + data_path=data_path, + model_key=model_key, + test_size=test_size, + num_proc=num_proc, + max_samples=max_samples, + batch_size=batch_size) + + train_dset, val_dset = dset["train"], dset["test"] + train_dset_size = len(train_dset) + if max_samples > 0: + max_samples = min(max_samples, train_dset_size) + logger.info(f"Training max samples: {max_samples}.") + train_dset.filter(lambda example, idx: idx < max_samples, with_indices=True) + else: + logger.info(f"Training with all data...") + + if shards > 0: + logger.info(f"Sharding training dataset...") + train_dset = Sharding(num_shards=shards).shard(train_dset) config.update( { @@ -87,21 +104,29 @@ def train_bertmesh( logger.info(f"Loading `{model_key}` tokenizer...") model = BertMesh.from_pretrained(model_key, trust_remote_code=True) - # logger.info(f"Loading labels from model {model_key}...") - # label2id = {v: k for k, v in model.id2label.items()} - logger.info(f"Preprocessing the dataset at {data_path}...") - # dset = load_from_disk(os.path.join(data_path, 'dataset')) - dset, label2id = preprocess_mesh( - data_path=data_path, - model_key=model_key, - test_size=test_size, - num_proc=num_proc, - max_samples=max_samples, - batch_size=batch_size) - - logger.info(f"Sharding training dataset...") - train_dset, val_dset = Sharding(num_shards=100).shard(dset["train"]), dset["test"] + if os.path.isdir(data_path): + logger.info(f"Loading from disk...") + dset = load_from_disk(os.path.join(data_path, 'dataset')) + with open(os.path.join(data_path, 'label2id'), 'r') as f: + label2id = json.load(f) + else: + dset, label2id = preprocess_mesh( + data_path=data_path, + model_key=model_key, + test_size=test_size, + num_proc=num_proc, + max_samples=max_samples, + batch_size=batch_size) + + train_dset, val_dset = dset["train"], dset["test"] + train_dset_size = len(train_dset) + if max_samples > 0: + max_samples = min(max_samples, train_dset_size) + logger.info(f"Training max samples: {max_samples}.") + train_dset.filter(lambda example, idx: idx < max_samples, with_indices=True) + else: + logger.info(f"Training with all data...") if model_args.freeze_backbone: logger.info("Freezing backbone") @@ -130,9 +155,10 @@ def sklearn_metrics(prediction: EvalPrediction): logger.info(f"Collating labels...") collator = MultilabelDataCollator(label2id=label2id) - logger.info(f"Calculating max steps for IterableDatasets...") - max_steps = calculate_max_steps(training_args, dset) - training_args.max_steps = max_steps + if shards > 0: + logger.info(f"Calculating max steps for IterableDatasets shards...") + max_steps = calculate_max_steps(training_args, dset) + training_args.max_steps = max_steps logger.info(f"Initializing Trainer...") trainer = Trainer( @@ -167,7 +193,7 @@ def train_bertmesh_cli( ), data_path: str = typer.Argument( ..., - help="Path to mesh.jsonl" + help="Path to allMeSH_2021.jsonl (or similar) or to a folder after preprocessing and saving to disk" ), test_size: float = typer.Option(0.05, help="Fraction of data to use for testing"), num_proc: int = typer.Option( @@ -179,7 +205,10 @@ def train_bertmesh_cli( ), batch_size: int = typer.Option( 256, - help="Size of the preprocessing batch") + help="Size of the preprocessing batch"), + shards: int = typer.Option( + -1, + help="Number os shards to divide training IterativeDataset to (improves performance)") ): parser = HfArgumentParser( @@ -198,7 +227,7 @@ def train_bertmesh_cli( logger.info("Training args: {}".format(pformat(training_args))) logger.info("Wandb args: {}".format(pformat(wandb_args))) - train_bertmesh(model_key, data_path, training_args, model_args, max_samples, test_size, num_proc, batch_size) + train_bertmesh(model_key, data_path, training_args, model_args, max_samples, test_size, num_proc, batch_size, shards) """ if __name__ == "__main__": From 6837af7106a210b547f135169fcdef568d8ed904 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 24 Jul 2023 09:45:25 +0100 Subject: [PATCH 080/300] Cleaning and rearranging --- .../preprocessing/preprocess_mesh.py | 15 +++-- grants_tagger_light/training/train.py | 66 +++++-------------- grants_tagger_light/utils/sharding.py | 18 +++++ grants_tagger_light/utils/utils.py | 8 --- scripts/train.sh | 4 +- 5 files changed, 45 insertions(+), 66 deletions(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 708c925f..6e2a69d5 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -152,7 +152,8 @@ def preprocess_mesh_cli( ), test_size: float = typer.Option( 0.05, - help="Fraction of data to use for testing"), + help="Fraction of data to use for testing" + ), num_proc: int = typer.Option( os.cpu_count(), help="Number of processes to use for preprocessing" @@ -163,13 +164,13 @@ def preprocess_mesh_cli( ), batch_size: int = typer.Option( 256, - help="Size of the preprocessing batch") + help="Size of the preprocessing batch" + ) ): - print("\033[96mRunning preprocessing will save the data as a PyArrow dataset which is a very time consuming " - "operation. If you don't need the data to be saved, you can save much time just by running:\n" - "\t`grants-tagger train bertmesh {model_key} {path_to_jsonl}`\033[0m") - - if input('Do You Want To Continue? [Y/n]') != 'Y': + if input("\033[96mRunning preprocessing will save the data as a PyArrow dataset which is a very time consuming " + "operation. If you don't need the data to be saved, you can save much time just by running:\n" + "\t`grants-tagger train bertmesh {model_key} {path_to_jsonl}`\033[0m\n\n" + "Do You Want To Continue? [Y/n]") != 'Y': exit(1) if not data_path.endswith('jsonl'): diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 2c0c58a0..b5675942 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -19,7 +19,6 @@ from sklearn.metrics import classification_report from loguru import logger from pprint import pformat -from datasets import load_dataset import typer import numpy as np import os @@ -28,7 +27,6 @@ from datasets import load_from_disk from grants_tagger_light.utils.sharding import Sharding -from grants_tagger_light.utils.utils import calculate_max_steps transformers.set_seed(42) @@ -41,7 +39,6 @@ def train_bertmesh( max_samples: int = -1, test_size: float = 0.05, num_proc: int = os.cpu_count(), - batch_size: int = 256, shards: int = -1 ): if not model_key: @@ -69,14 +66,14 @@ def train_bertmesh( test_size=test_size, num_proc=num_proc, max_samples=max_samples, - batch_size=batch_size) + batch_size=training_args.per_device_train_batch_size) train_dset, val_dset = dset["train"], dset["test"] train_dset_size = len(train_dset) if max_samples > 0: - max_samples = min(max_samples, train_dset_size) - logger.info(f"Training max samples: {max_samples}.") - train_dset.filter(lambda example, idx: idx < max_samples, with_indices=True) + train_dset_size = min(max_samples, train_dset_size) + logger.info(f"Training max samples: {train_dset_size}.") + train_dset.filter(lambda example, idx: idx < train_dset_size, with_indices=True) else: logger.info(f"Training with all data...") @@ -117,14 +114,14 @@ def train_bertmesh( test_size=test_size, num_proc=num_proc, max_samples=max_samples, - batch_size=batch_size) + batch_size=training_args.per_device_train_batch_size) train_dset, val_dset = dset["train"], dset["test"] train_dset_size = len(train_dset) if max_samples > 0: - max_samples = min(max_samples, train_dset_size) - logger.info(f"Training max samples: {max_samples}.") - train_dset.filter(lambda example, idx: idx < max_samples, with_indices=True) + train_dset_size = min(max_samples, train_dset_size) + logger.info(f"Training max samples: {train_dset_size}.") + train_dset.filter(lambda example, idx: idx < train_dset_size, with_indices=True) else: logger.info(f"Training with all data...") @@ -157,7 +154,7 @@ def sklearn_metrics(prediction: EvalPrediction): if shards > 0: logger.info(f"Calculating max steps for IterableDatasets shards...") - max_steps = calculate_max_steps(training_args, dset) + max_steps = Sharding.calculate_max_steps(training_args, train_dset_size) training_args.max_steps = max_steps logger.info(f"Initializing Trainer...") @@ -195,17 +192,18 @@ def train_bertmesh_cli( ..., help="Path to allMeSH_2021.jsonl (or similar) or to a folder after preprocessing and saving to disk" ), - test_size: float = typer.Option(0.05, help="Fraction of data to use for testing"), + test_size: float = typer.Option( + 0.05, + help="Fraction of data to use for testing" + ), num_proc: int = typer.Option( - os.cpu_count(), help="Number of processes to use for preprocessing" + os.cpu_count(), + help="Number of processes to use for preprocessing" ), max_samples: int = typer.Option( -1, - help="Maximum number of samples to use for preprocessing", + help="Maximum number of samples to use from the json", ), - batch_size: int = typer.Option( - 256, - help="Size of the preprocessing batch"), shards: int = typer.Option( -1, help="Number os shards to divide training IterativeDataset to (improves performance)") @@ -227,34 +225,4 @@ def train_bertmesh_cli( logger.info("Training args: {}".format(pformat(training_args))) logger.info("Wandb args: {}".format(pformat(wandb_args))) - train_bertmesh(model_key, data_path, training_args, model_args, max_samples, test_size, num_proc, batch_size, shards) - -""" -if __name__ == "__main__": - from dataclasses import dataclass - - @dataclass - class TrainFuncArgs: - model_key: str - data_path: str - max_samples: int = np.inf - - func_args, training_args, wandb_args, model_args = HfArgumentParser( - ( - TrainFuncArgs, - BertMeshTrainingArguments, - WandbArguments, - BertMeshModelArguments, - ) - ).parse_args_into_dataclasses() - - train_bertmesh( - func_args.model_key, - func_args.data_path, - training_args, - model_args, - max_samples, - test_size, - batch_size - ) -""" \ No newline at end of file + train_bertmesh(model_key, data_path, training_args, model_args, max_samples, test_size, num_proc, shards) diff --git a/grants_tagger_light/utils/sharding.py b/grants_tagger_light/utils/sharding.py index f6e20df3..85a7ff22 100644 --- a/grants_tagger_light/utils/sharding.py +++ b/grants_tagger_light/utils/sharding.py @@ -21,3 +21,21 @@ def shard(self, dataset): for index in range(self.num_shards)] return IterableDataset.from_generator(self.gen_from_shards, gen_kwargs={"_shards": shards}) + + @staticmethod + def calculate_max_steps(training_args, train_dset_size): + """ This is needed when using IterableDatasets, as there is no __len__ in advance since the dataset is a + generator with yield, so it does not know when to end. + Source: https://discuss.huggingface.co/t/streaming-dataset-into-trainer-does-not-implement-len-max-steps-has-to-be-specified/32893/6 + + Example: allMeSH_2021.json has 15.559.157 rows, from which let's suppose 5% is test + > 15559157-0.05*15559157 = 14781199.15 training rows + let's suppose batch size is 8 + > 14781199.15 / 8 = 1847649.89375 + if accumulation_steps is 1, then + > 1847649.89375 / 1 = 1847649.89375 + """ + + train_batch_size = training_args.per_device_train_batch_size + accumulation_steps = training_args.gradient_accumulation_steps + return (train_dset_size / train_batch_size) / accumulation_steps diff --git a/grants_tagger_light/utils/utils.py b/grants_tagger_light/utils/utils.py index fc8f5cfd..0f57f57f 100644 --- a/grants_tagger_light/utils/utils.py +++ b/grants_tagger_light/utils/utils.py @@ -178,11 +178,3 @@ def create_label_binarizer(model_path: str, label_binarizer_path: str): return label_binarizer - -def calculate_max_steps(training_args, dset): - """ This is needed when using IterableDatasets, as there is no __len__ in advance since the dataset is a - generator with yield, so it does not know when to end. - Source: https://discuss.huggingface.co/t/streaming-dataset-into-trainer-does-not-implement-len-max-steps-has-to-be-specified/32893/6""" - train_batch_size = training_args.per_device_train_batch_size - accumulation_steps = training_args.gradient_accumulation_steps - return (len(dset["train"]) / train_batch_size) / accumulation_steps diff --git a/scripts/train.sh b/scripts/train.sh index be797184..def294b2 100644 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -1,12 +1,12 @@ # Run on p2.8xlarge instance grants-tagger train bertmesh \ "" \ - data/raw/allMeSH_2021.json \ + data/raw/allMeSH_2021.jsonl \ -1 \ --output_dir bertmesh_outs/pipeline_test/ \ --wandb_name test-train-all \ --wandb_api_key ${WANDB_API_KEY} \ - --per_device_train_batch_size 8 \ + --per_device_train_batch_size 256 \ --per_device_eval_batch_size 8 \ --num_train_epochs 1 \ --evaluation_strategy steps \ From feb00376ada0bcc4173d026c4d77f813f2259d65 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 24 Jul 2023 10:27:12 +0100 Subject: [PATCH 081/300] Updates train.sh --- grants_tagger_light/training/train.py | 7 ++++++- scripts/train.sh | 5 ++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index b5675942..f03406d5 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -157,7 +157,12 @@ def sklearn_metrics(prediction: EvalPrediction): max_steps = Sharding.calculate_max_steps(training_args, train_dset_size) training_args.max_steps = max_steps - logger.info(f"Initializing Trainer...") + logger.info(f"Initializing Trainer:\n" + f"* per_device_train_batch_size={training_args.per_device_train_batch_size}\n" + f"* max_steps = {training_args.max_steps}\n" + f"* epochs = {training_args.num_train_epochs}\n" + ) + trainer = Trainer( model=model, args=training_args, diff --git a/scripts/train.sh b/scripts/train.sh index def294b2..cb7f2492 100644 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -2,7 +2,9 @@ grants-tagger train bertmesh \ "" \ data/raw/allMeSH_2021.jsonl \ - -1 \ + --test-size 0.05 \ + --max-samples 1000 \ + --shards 100 \ --output_dir bertmesh_outs/pipeline_test/ \ --wandb_name test-train-all \ --wandb_api_key ${WANDB_API_KEY} \ @@ -15,3 +17,4 @@ grants-tagger train bertmesh \ --save_steps 100000 \ --fp16 \ --torch_compile + From bc3cd22b16ed1dd4433a2dab73e739836af8f9bb Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 24 Jul 2023 10:21:46 +0000 Subject: [PATCH 082/300] Adds allMesh_2021 jsonl version --- data/raw/allMeSH_2021.jsonl.dvc | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 data/raw/allMeSH_2021.jsonl.dvc diff --git a/data/raw/allMeSH_2021.jsonl.dvc b/data/raw/allMeSH_2021.jsonl.dvc new file mode 100644 index 00000000..f2f43ce8 --- /dev/null +++ b/data/raw/allMeSH_2021.jsonl.dvc @@ -0,0 +1,4 @@ +outs: +- md5: 94f18c3918b180728a553123edb2ee32 + size: 27914288461 + path: allMeSH_2021.jsonl From c383ea10f79e172ddc78d92428e5ca0979f055aa Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 24 Jul 2023 11:34:27 +0100 Subject: [PATCH 083/300] Ignores .jsonl --- data/raw/.gitignore | 1 + tests/test_preprocess_mesh.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/data/raw/.gitignore b/data/raw/.gitignore index e17ccf76..69e7cea9 100644 --- a/data/raw/.gitignore +++ b/data/raw/.gitignore @@ -1,3 +1,4 @@ /allMeSH_2021.json +/allMeSH_2021.jsonl /desc2021.xml /disease_tags_validation_grants.xlsx diff --git a/tests/test_preprocess_mesh.py b/tests/test_preprocess_mesh.py index 07007cdb..d474edfe 100644 --- a/tests/test_preprocess_mesh.py +++ b/tests/test_preprocess_mesh.py @@ -3,7 +3,6 @@ from grants_tagger_light.preprocessing.preprocess_mesh import ( preprocess_mesh, - process_data, ) From e49366b6fc22f25f39c7ab2aef33ec8a3793e4b4 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 24 Jul 2023 15:29:05 +0100 Subject: [PATCH 084/300] Fixes preprocessing unit tests --- .../preprocessing/preprocess_mesh.py | 2 +- scripts/jsonl_preprocessing.py | 88 ++++++++++++++++++ scripts/mesh_json_to_jsonl.py | 41 -------- tests/fixtures/mesh_fixture_head_2.json | 3 + tests/fixtures/mesh_fixture_head_2.jsonl | 2 + tests/test_preprocess_mesh.py | 93 +++++++++---------- 6 files changed, 137 insertions(+), 92 deletions(-) create mode 100644 scripts/jsonl_preprocessing.py delete mode 100644 scripts/mesh_json_to_jsonl.py create mode 100644 tests/fixtures/mesh_fixture_head_2.json create mode 100644 tests/fixtures/mesh_fixture_head_2.jsonl diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 6e2a69d5..b4494161 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -45,11 +45,11 @@ def create_sample_file(jsonl_file, lines): def preprocess_mesh( data_path: str, model_key: str, + save_to_path: str = None, test_size: float = 0.05, num_proc: int = os.cpu_count(), max_samples: int = -1, batch_size: int = 256, - save_to_path: str = None ): if max_samples == -1: max_samples = np.inf diff --git a/scripts/jsonl_preprocessing.py b/scripts/jsonl_preprocessing.py new file mode 100644 index 00000000..cf59e0ed --- /dev/null +++ b/scripts/jsonl_preprocessing.py @@ -0,0 +1,88 @@ +import json +from argparse import ArgumentParser +import numpy as np +from loguru import logger + + +def process_data(item, filter_tags: list = None, filter_years: list = None): + check_tags = filter_tags is not None + check_years = filter_years is not None + + if check_tags and 'meshMajor' not in item: + logger.warning("`meshMajor` not found in the fields. Unable to filter tags.") + check_tags = False + + if check_years and 'year' not in item: + logger.warning("`year` not found in the fields. Unable to filter tags.") + check_years = False + if check_tags: + if filter_tags is None: + filter_tags = [] + if len(filter_tags) > 0 and not any(np.isin(filter_tags, item['meshMajor'])): + return False + + if check_years: + if filter_years is None: + filter_years = [] + + # Making sure it's str and not int + filter_years = [str(y) for y in filter_years] + if len(filter_years) > 0 and not any(np.isin(filter_years, [str(item['year'])])): + return False + + return True + + +def mesh_json_to_jsonl(input_path, output_path, input_encoding='latin1', output_encoding='latin1', + filter_tags: str = '', filter_years: str = '', show_progress: bool = True): + """ + Mesh json is not optimized for parallel processing. This script aims to transform it into `jsonl` so that, + by having 1 json per line, you can evenly distribute it among all the cores. + + Args: + input_path: allMeSH_2021.json (or similar) path + output_path: path for the resulted jsonl file + input_encoding: encoding of the input json (default `latin1`) + output_encoding: encoding of the output jsonl (default `latin1`) + filter_tags: tags separated by commas(T1,T2,T3) to only include the entries with those tags + filter_years: years separated by commas(2008,2009) to only include the entries of those years + show_progress: print the number of line you are processing + Returns: + + """ + filter_tags_list = list(filter(lambda x: x.strip() != '', filter_tags.split(','))) + filter_years_list = list(filter(lambda x: x.strip() != '', filter_years.split(','))) + with open(output_path, 'w', encoding=output_encoding) as fw: + with open(input_path, 'r', encoding=input_encoding) as fr: + for idx, line in enumerate(fr): + if show_progress: + print(idx, end='\r') + # Skip 1st line + if idx == 0: + logger.info(f"Skipping first line (articles): {line}") + continue + try: + sample = json.loads(line[:-2]) + except: + logger.warning(f"Skipping line in bad json format: {line}") + continue + if process_data(sample, filter_tags_list, filter_years_list): + fw.write(json.dumps(sample)) + fw.write('\n') + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--input_path", required=True, help="path to input allMeSH_2021.json (or equivalent)") + parser.add_argument("--output_path", required=True, help="path to ioutput jsonl") + parser.add_argument("--input_encoding", required=False, default='latin1', help="encoding of the input json") + parser.add_argument("--output_encoding", required=False, default='latin1', help="encoding of the output jsonl") + parser.add_argument("--filter_tags", required=False, default='', + help="comma-separated tags to include (the rest will be discarded)") + parser.add_argument("--filter_years", required=False, default='', + help="comma-separated years to include (the rest will be discarded)") + + args = parser.parse_args() + + mesh_json_to_jsonl(args.input_path, args.output_path, args.input_encoding, args.output_encoding, args.filter_tags, + args.filter_years) diff --git a/scripts/mesh_json_to_jsonl.py b/scripts/mesh_json_to_jsonl.py deleted file mode 100644 index 432321e2..00000000 --- a/scripts/mesh_json_to_jsonl.py +++ /dev/null @@ -1,41 +0,0 @@ -import json -from argparse import ArgumentParser - - -def mesh_to_jsonl(input_path, output_path, input_encoding='latin1', output_encoding='latin1'): - """ - Mesh json is not optimized for parallel processing. This script aims to transform it into `jsonl` so that, - by having 1 json per line, you can evenly distribute it among all the cores. - - Args: - input_path: allMeSH_2021.json (or similar) path - output_path: path for the resulted jsonl file - input_encoding: encoding of the input json (default `latin1`) - output_encoding: encoding of the output jsonl (default `latin1`) - - Returns: - - """ - with open(output_path, 'w', encoding=output_encoding) as fw: - with open(input_path, 'r', encoding=input_encoding) as fr: - for idx, line in enumerate(fr): - print(idx, end='\r') - # Skip 1st line - if idx == 0: - continue - sample = json.loads(line[:-2]) - - fw.write(json.dumps(sample)) - fw.write('\n') - - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("--input_path", required=True, help="path to input allMeSH_2021.json (or equivalent)") - parser.add_argument("--output_path", required=True, help="path to input jsonl") - parser.add_argument("--input_encoding", required=False, default='latin1', help="encoding of the input json") - parser.add_argument("--output_encoding", required=False, default='latin1', help="encoding of the output jsonl") - - args = parser.parse_args() - - mesh_to_jsonl(args.input_path, args.output_path, args.input_encoding, args.output_encoding) diff --git a/tests/fixtures/mesh_fixture_head_2.json b/tests/fixtures/mesh_fixture_head_2.json new file mode 100644 index 00000000..c12f1dff --- /dev/null +++ b/tests/fixtures/mesh_fixture_head_2.json @@ -0,0 +1,3 @@ +{"articles":[ +{"journal":"F1000Research","meshMajor":["Adaptive Clinical Trials as Topic","Betacoronavirus","COVID-19","Coronavirus Infections","Humans","Pandemics","Pneumonia, Viral","Research Design","SARS-CoV-2"],"year":"2020","abstractText":"Global health pandemics, such as coronavirus disease 2019 (COVID-19), require efficient and well-conducted trials to determine effective interventions, such as treatments and vaccinations. Early work focused on rapid sequencing of severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2), subsequent in-vitro and in-silico work, along with greater understanding of the different clinical phases of the infection, have helped identify a catalogue of potential therapeutic agents requiring assessment. In a pandemic, there is a need to quickly identify efficacious treatments, and reject those that are non-beneficial or even harmful, using randomised clinical trials. Whilst each potential treatment could be investigated across multiple, separate, competing two-arm trials, this is a very inefficient process. Despite the very large numbers of interventional trials for COVID-19, the vast majority have not used efficient trial designs. Well conducted, adaptive platform trials utilising a multi-arm multi-stage (MAMS) approach provide a solution to overcome limitations of traditional designs. The multi-arm element allows multiple different treatments to be investigated simultaneously against a shared, standard-of-care control arm. The multi-stage element uses interim analyses to assess accumulating data from the trial and ensure that only treatments showing promise continue to recruitment during the next stage of the trial. The ability to test many treatments at once and drop insufficiently active interventions significantly speeds up the rate at which answers can be achieved. This article provides an overview of the benefits of MAMS designs and successes of trials, which have used this approach to COVID-19. We also discuss international collaboration between trial teams, including prospective agreement to synthesise trial results, and identify the most effective interventions. We believe that international collaboration will help provide faster answers for patients, clinicians, and health care systems around the world, including for future waves of COVID-19, and enable preparedness for future global health pandemics.","pmid":"33149899","title":"Adaptive platform trials using multi-arm, multi-stage protocols: getting fast answers in pandemic settings."}, +{"journal":"BMC public health","meshMajor":["Adult","COVID-19","Coronavirus Infections","Disease Outbreaks","Female","Humans","Infectious Disease Medicine","Male","Middle Aged","Pandemics","Physicians","Pneumonia, Viral","Psychological Distress","Republic of Korea"],"year":"2020","abstractText":"BACKGROUND: This study aimed to investigate psychological distress among infectious disease (ID) physicians during the coronavirus disease�2019 (COVID-19) outbreak in the Republic of Korea.METHODS: Using an online-based survey link sent via text message and email, we conducted a survey from April 21 to 25, 2020, targeting all ID physicians currently working in ID (n\u2009=\u2009265). The questionnaire was based on the Maslach Burnout Inventory-Human Services Survey and the Depression, Anxiety, and Stress Scales, and information was collected on factors protecting against psychological distress and difficulties in relation to COVID-19.RESULTS: Of 265 ID physicians, 115 (43.3%) responded, showing burnout (97, 90.4%), depression (20, 17.4%), anxiety (23, 20.0%), and stress (5, 4.3%). There were no differences in terms of distress between ID physicians who were directly involved in the care of patients with COVID-19 or not. Greater than 50% of physicians valued their work and felt recognized by others, whereas <\u200910% indicated that sufficient human and financial support and private time had been provided during the outbreak. The most challenging issues concerned a lack of attending physicians caring for COVID-19 patients or infection control practitioners, a shortage of personal protective equipment or airborne infection isolation rooms, pressure for research, and lack of guidelines for COVID-19 management.CONCLUSIONS: During the COVID-19 outbreak in the Republic of Korea, most respondents reported psychological distress. Preparing strategies to secure human resources are crucial to prepare effectively for future epidemics and pandemics.","pmid":"33246426","title":"Psychological distress among infectious disease physicians during the response to the COVID-19 outbreak in the Republic of Korea."}]} \ No newline at end of file diff --git a/tests/fixtures/mesh_fixture_head_2.jsonl b/tests/fixtures/mesh_fixture_head_2.jsonl new file mode 100644 index 00000000..8f613803 --- /dev/null +++ b/tests/fixtures/mesh_fixture_head_2.jsonl @@ -0,0 +1,2 @@ +{"journal":"F1000Research","meshMajor":["Adaptive Clinical Trials as Topic","Betacoronavirus","COVID-19","Coronavirus Infections","Humans","Pandemics","Pneumonia, Viral","Research Design","SARS-CoV-2"],"year":"2020","abstractText":"Global health pandemics, such as coronavirus disease 2019 (COVID-19), require efficient and well-conducted trials to determine effective interventions, such as treatments and vaccinations. Early work focused on rapid sequencing of severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2), subsequent in-vitro and in-silico work, along with greater understanding of the different clinical phases of the infection, have helped identify a catalogue of potential therapeutic agents requiring assessment. In a pandemic, there is a need to quickly identify efficacious treatments, and reject those that are non-beneficial or even harmful, using randomised clinical trials. Whilst each potential treatment could be investigated across multiple, separate, competing two-arm trials, this is a very inefficient process. Despite the very large numbers of interventional trials for COVID-19, the vast majority have not used efficient trial designs. Well conducted, adaptive platform trials utilising a multi-arm multi-stage (MAMS) approach provide a solution to overcome limitations of traditional designs. The multi-arm element allows multiple different treatments to be investigated simultaneously against a shared, standard-of-care control arm. The multi-stage element uses interim analyses to assess accumulating data from the trial and ensure that only treatments showing promise continue to recruitment during the next stage of the trial. The ability to test many treatments at once and drop insufficiently active interventions significantly speeds up the rate at which answers can be achieved. This article provides an overview of the benefits of MAMS designs and successes of trials, which have used this approach to COVID-19. We also discuss international collaboration between trial teams, including prospective agreement to synthesise trial results, and identify the most effective interventions. We believe that international collaboration will help provide faster answers for patients, clinicians, and health care systems around the world, including for future waves of COVID-19, and enable preparedness for future global health pandemics.","pmid":"33149899","title":"Adaptive platform trials using multi-arm, multi-stage protocols: getting fast answers in pandemic settings."} +{"journal":"BMC public health","meshMajor":["Adult","COVID-19","Coronavirus Infections","Disease Outbreaks","Female","Humans","Infectious Disease Medicine","Male","Middle Aged","Pandemics","Physicians","Pneumonia, Viral","Psychological Distress","Republic of Korea"],"year":"2020","abstractText":"BACKGROUND: This study aimed to investigate psychological distress among infectious disease (ID) physicians during the coronavirus disease�2019 (COVID-19) outbreak in the Republic of Korea.METHODS: Using an online-based survey link sent via text message and email, we conducted a survey from April 21 to 25, 2020, targeting all ID physicians currently working in ID (n\u2009=\u2009265). The questionnaire was based on the Maslach Burnout Inventory-Human Services Survey and the Depression, Anxiety, and Stress Scales, and information was collected on factors protecting against psychological distress and difficulties in relation to COVID-19.RESULTS: Of 265 ID physicians, 115 (43.3%) responded, showing burnout (97, 90.4%), depression (20, 17.4%), anxiety (23, 20.0%), and stress (5, 4.3%). There were no differences in terms of distress between ID physicians who were directly involved in the care of patients with COVID-19 or not. Greater than 50% of physicians valued their work and felt recognized by others, whereas <\u200910% indicated that sufficient human and financial support and private time had been provided during the outbreak. The most challenging issues concerned a lack of attending physicians caring for COVID-19 patients or infection control practitioners, a shortage of personal protective equipment or airborne infection isolation rooms, pressure for research, and lack of guidelines for COVID-19 management.CONCLUSIONS: During the COVID-19 outbreak in the Republic of Korea, most respondents reported psychological distress. Preparing strategies to secure human resources are crucial to prepare effectively for future epidemics and pandemics.","pmid":"33246426","title":"Psychological distress among infectious disease physicians during the response to the COVID-19 outbreak in the Republic of Korea."} \ No newline at end of file diff --git a/tests/test_preprocess_mesh.py b/tests/test_preprocess_mesh.py index d474edfe..3588d955 100644 --- a/tests/test_preprocess_mesh.py +++ b/tests/test_preprocess_mesh.py @@ -1,89 +1,82 @@ import json import tempfile +from pathlib import Path from grants_tagger_light.preprocessing.preprocess_mesh import ( preprocess_mesh, ) +from scripts.jsonl_preprocessing import process_data, mesh_json_to_jsonl +from loguru import logger -def test_preprocess_mesh(): - item = { - "abstractText": "This is an abstract", - "meshMajor": ["T1", "T2"], - "journal": "Journal", - "year": 2018, - } - with tempfile.NamedTemporaryFile(mode="r+") as input_tmp: - input_tmp.write("\n" + json.dumps(item) + ", ") - input_tmp.seek(0) - output_tmp = tempfile.NamedTemporaryFile(mode="r+") - preprocess_mesh(input_tmp.name, output_tmp.name) - output_tmp.seek(0) - expected_processed_item = { - "text": "This is an abstract", - "tags": ["T1", "T2"], - "meta": {"journal": "Journal", "year": 2018}, - } - assert output_tmp.read() == json.dumps(expected_processed_item) + "\n" +FIXTURES_DIR = Path(__file__).resolve().parent / 'fixtures' -def test_process_data(): +def test_process_data_with_filter_tags(): item = { "abstractText": "This is an abstract", "meshMajor": ["T1", "T2"], "journal": "Journal", "year": 2018, } - expected_processed_item = { - "text": "This is an abstract", - "tags": ["T1", "T2"], - "meta": {"journal": "Journal", "year": 2018}, - } - processed_item = process_data(item) - assert processed_item == expected_processed_item + assert process_data(item, filter_tags=["T1"]) is True -def test_process_data_with_filter_tags(): +def test_process_data_with_missing_filter_tag(): item = { "abstractText": "This is an abstract", "meshMajor": ["T1", "T2"], "journal": "Journal", "year": 2018, } - expected_processed_item = { - "text": "This is an abstract", - "tags": ["T1"], - "meta": {"journal": "Journal", "year": 2018}, - } - processed_item = process_data(item, filter_tags=["T1"]) - assert processed_item == expected_processed_item + assert process_data(item, filter_tags=["T3"]) is False -def test_process_data_with_missing_filter_tag(): +def test_process_data_with_filter_years(): item = { "abstractText": "This is an abstract", "meshMajor": ["T1", "T2"], "journal": "Journal", "year": 2018, } - processed_item = process_data(item, filter_tags=["T3"]) - assert processed_item is None + assert process_data(item, filter_years=["2019", "2020"]) is False + item["year"] = 2020 + assert process_data(item, filter_years=["2019", "2020"]) is True -def test_process_data_with_filter_years(): +def test_process_data_with_filter_years_and_tags(): item = { "abstractText": "This is an abstract", "meshMajor": ["T1", "T2"], "journal": "Journal", - "year": 2018, + "year": 2020, } - processed_item = process_data(item, filter_years="2019,2020") - assert processed_item is None - item["year"] = 2020 - expected_processed_item = { - "text": "This is an abstract", - "tags": ["T1", "T2"], - "meta": {"journal": "Journal", "year": 2020}, - } - processed_item = process_data(item, filter_years="2019,2020") - assert processed_item == expected_processed_item + assert process_data(item, filter_years=["2019", "2020"], filter_tags=["T1"]) is True + assert process_data(item, filter_years=["2018"], filter_tags=["T1"]) is False + assert process_data(item, filter_years=["2020"], filter_tags=["T3"]) is False + + +def test_json_to_jsonl(): + output_tmp = tempfile.NamedTemporaryFile(mode="w") + mesh_json_to_jsonl(f'{FIXTURES_DIR}/mesh_fixture_head_2.json', output_tmp.name, show_progress=False) + + with open(output_tmp.name, 'r') as f: + result = [json.loads(jline) for jline in f.read().splitlines()] + assert len(result) == 2 + + output_tmp.close() + + +def test_preprocess_mesh(): + dset, label2id = preprocess_mesh(data_path=f"{FIXTURES_DIR}/mesh_fixture_head_2.jsonl", + model_key='', + num_proc=2, + batch_size=1, + test_size=0.5) + assert "train" in dset + assert "test" in dset + assert len(dset["train"]) == 1 + assert len(dset["test"]) == 1 + logger.info("label2id") + logger.info(label2id) + assert len(list(label2id.keys())) == 18 From c6c3b5bcb411ff897fa21e5b2f52dd5b2ef2c388 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 24 Jul 2023 15:55:51 +0100 Subject: [PATCH 085/300] Adapts preprocessing tests to the style of other tests --- tests/fixtures/mesh_fixture_head_2.json | 3 -- tests/fixtures/mesh_fixture_head_2.jsonl | 2 -- tests/test_preprocess_mesh.py | 39 ++++++++++++++++++------ 3 files changed, 29 insertions(+), 15 deletions(-) delete mode 100644 tests/fixtures/mesh_fixture_head_2.json delete mode 100644 tests/fixtures/mesh_fixture_head_2.jsonl diff --git a/tests/fixtures/mesh_fixture_head_2.json b/tests/fixtures/mesh_fixture_head_2.json deleted file mode 100644 index c12f1dff..00000000 --- a/tests/fixtures/mesh_fixture_head_2.json +++ /dev/null @@ -1,3 +0,0 @@ -{"articles":[ -{"journal":"F1000Research","meshMajor":["Adaptive Clinical Trials as Topic","Betacoronavirus","COVID-19","Coronavirus Infections","Humans","Pandemics","Pneumonia, Viral","Research Design","SARS-CoV-2"],"year":"2020","abstractText":"Global health pandemics, such as coronavirus disease 2019 (COVID-19), require efficient and well-conducted trials to determine effective interventions, such as treatments and vaccinations. Early work focused on rapid sequencing of severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2), subsequent in-vitro and in-silico work, along with greater understanding of the different clinical phases of the infection, have helped identify a catalogue of potential therapeutic agents requiring assessment. In a pandemic, there is a need to quickly identify efficacious treatments, and reject those that are non-beneficial or even harmful, using randomised clinical trials. Whilst each potential treatment could be investigated across multiple, separate, competing two-arm trials, this is a very inefficient process. Despite the very large numbers of interventional trials for COVID-19, the vast majority have not used efficient trial designs. Well conducted, adaptive platform trials utilising a multi-arm multi-stage (MAMS) approach provide a solution to overcome limitations of traditional designs. The multi-arm element allows multiple different treatments to be investigated simultaneously against a shared, standard-of-care control arm. The multi-stage element uses interim analyses to assess accumulating data from the trial and ensure that only treatments showing promise continue to recruitment during the next stage of the trial. The ability to test many treatments at once and drop insufficiently active interventions significantly speeds up the rate at which answers can be achieved. This article provides an overview of the benefits of MAMS designs and successes of trials, which have used this approach to COVID-19. We also discuss international collaboration between trial teams, including prospective agreement to synthesise trial results, and identify the most effective interventions. We believe that international collaboration will help provide faster answers for patients, clinicians, and health care systems around the world, including for future waves of COVID-19, and enable preparedness for future global health pandemics.","pmid":"33149899","title":"Adaptive platform trials using multi-arm, multi-stage protocols: getting fast answers in pandemic settings."}, -{"journal":"BMC public health","meshMajor":["Adult","COVID-19","Coronavirus Infections","Disease Outbreaks","Female","Humans","Infectious Disease Medicine","Male","Middle Aged","Pandemics","Physicians","Pneumonia, Viral","Psychological Distress","Republic of Korea"],"year":"2020","abstractText":"BACKGROUND: This study aimed to investigate psychological distress among infectious disease (ID) physicians during the coronavirus disease�2019 (COVID-19) outbreak in the Republic of Korea.METHODS: Using an online-based survey link sent via text message and email, we conducted a survey from April 21 to 25, 2020, targeting all ID physicians currently working in ID (n\u2009=\u2009265). The questionnaire was based on the Maslach Burnout Inventory-Human Services Survey and the Depression, Anxiety, and Stress Scales, and information was collected on factors protecting against psychological distress and difficulties in relation to COVID-19.RESULTS: Of 265 ID physicians, 115 (43.3%) responded, showing burnout (97, 90.4%), depression (20, 17.4%), anxiety (23, 20.0%), and stress (5, 4.3%). There were no differences in terms of distress between ID physicians who were directly involved in the care of patients with COVID-19 or not. Greater than 50% of physicians valued their work and felt recognized by others, whereas <\u200910% indicated that sufficient human and financial support and private time had been provided during the outbreak. The most challenging issues concerned a lack of attending physicians caring for COVID-19 patients or infection control practitioners, a shortage of personal protective equipment or airborne infection isolation rooms, pressure for research, and lack of guidelines for COVID-19 management.CONCLUSIONS: During the COVID-19 outbreak in the Republic of Korea, most respondents reported psychological distress. Preparing strategies to secure human resources are crucial to prepare effectively for future epidemics and pandemics.","pmid":"33246426","title":"Psychological distress among infectious disease physicians during the response to the COVID-19 outbreak in the Republic of Korea."}]} \ No newline at end of file diff --git a/tests/fixtures/mesh_fixture_head_2.jsonl b/tests/fixtures/mesh_fixture_head_2.jsonl deleted file mode 100644 index 8f613803..00000000 --- a/tests/fixtures/mesh_fixture_head_2.jsonl +++ /dev/null @@ -1,2 +0,0 @@ -{"journal":"F1000Research","meshMajor":["Adaptive Clinical Trials as Topic","Betacoronavirus","COVID-19","Coronavirus Infections","Humans","Pandemics","Pneumonia, Viral","Research Design","SARS-CoV-2"],"year":"2020","abstractText":"Global health pandemics, such as coronavirus disease 2019 (COVID-19), require efficient and well-conducted trials to determine effective interventions, such as treatments and vaccinations. Early work focused on rapid sequencing of severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2), subsequent in-vitro and in-silico work, along with greater understanding of the different clinical phases of the infection, have helped identify a catalogue of potential therapeutic agents requiring assessment. In a pandemic, there is a need to quickly identify efficacious treatments, and reject those that are non-beneficial or even harmful, using randomised clinical trials. Whilst each potential treatment could be investigated across multiple, separate, competing two-arm trials, this is a very inefficient process. Despite the very large numbers of interventional trials for COVID-19, the vast majority have not used efficient trial designs. Well conducted, adaptive platform trials utilising a multi-arm multi-stage (MAMS) approach provide a solution to overcome limitations of traditional designs. The multi-arm element allows multiple different treatments to be investigated simultaneously against a shared, standard-of-care control arm. The multi-stage element uses interim analyses to assess accumulating data from the trial and ensure that only treatments showing promise continue to recruitment during the next stage of the trial. The ability to test many treatments at once and drop insufficiently active interventions significantly speeds up the rate at which answers can be achieved. This article provides an overview of the benefits of MAMS designs and successes of trials, which have used this approach to COVID-19. We also discuss international collaboration between trial teams, including prospective agreement to synthesise trial results, and identify the most effective interventions. We believe that international collaboration will help provide faster answers for patients, clinicians, and health care systems around the world, including for future waves of COVID-19, and enable preparedness for future global health pandemics.","pmid":"33149899","title":"Adaptive platform trials using multi-arm, multi-stage protocols: getting fast answers in pandemic settings."} -{"journal":"BMC public health","meshMajor":["Adult","COVID-19","Coronavirus Infections","Disease Outbreaks","Female","Humans","Infectious Disease Medicine","Male","Middle Aged","Pandemics","Physicians","Pneumonia, Viral","Psychological Distress","Republic of Korea"],"year":"2020","abstractText":"BACKGROUND: This study aimed to investigate psychological distress among infectious disease (ID) physicians during the coronavirus disease�2019 (COVID-19) outbreak in the Republic of Korea.METHODS: Using an online-based survey link sent via text message and email, we conducted a survey from April 21 to 25, 2020, targeting all ID physicians currently working in ID (n\u2009=\u2009265). The questionnaire was based on the Maslach Burnout Inventory-Human Services Survey and the Depression, Anxiety, and Stress Scales, and information was collected on factors protecting against psychological distress and difficulties in relation to COVID-19.RESULTS: Of 265 ID physicians, 115 (43.3%) responded, showing burnout (97, 90.4%), depression (20, 17.4%), anxiety (23, 20.0%), and stress (5, 4.3%). There were no differences in terms of distress between ID physicians who were directly involved in the care of patients with COVID-19 or not. Greater than 50% of physicians valued their work and felt recognized by others, whereas <\u200910% indicated that sufficient human and financial support and private time had been provided during the outbreak. The most challenging issues concerned a lack of attending physicians caring for COVID-19 patients or infection control practitioners, a shortage of personal protective equipment or airborne infection isolation rooms, pressure for research, and lack of guidelines for COVID-19 management.CONCLUSIONS: During the COVID-19 outbreak in the Republic of Korea, most respondents reported psychological distress. Preparing strategies to secure human resources are crucial to prepare effectively for future epidemics and pandemics.","pmid":"33246426","title":"Psychological distress among infectious disease physicians during the response to the COVID-19 outbreak in the Republic of Korea."} \ No newline at end of file diff --git a/tests/test_preprocess_mesh.py b/tests/test_preprocess_mesh.py index 3588d955..4b6c46c3 100644 --- a/tests/test_preprocess_mesh.py +++ b/tests/test_preprocess_mesh.py @@ -1,15 +1,36 @@ import json import tempfile -from pathlib import Path from grants_tagger_light.preprocessing.preprocess_mesh import ( preprocess_mesh, ) from scripts.jsonl_preprocessing import process_data, mesh_json_to_jsonl -from loguru import logger +import pytest -FIXTURES_DIR = Path(__file__).resolve().parent / 'fixtures' +jsonl_data = """{"journal":"dummyJournal","meshMajor":["COVID-19","SARS-CoV-2"],"year":"2023","abstractText":"This is an article about coronavirus.","title":"article1","pmid":"pmid1"} +{"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria", "title": "article3", "pmid": "pmid3"}""" + +json_data = """{"articles":[ +{"journal":"dummyJournal","meshMajor":["COVID-19","SARS-CoV-2"],"year":"2023","abstractText":"This is an article about coronavirus.","title":"article1","pmid":"pmid1"}, +{"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria", "title": "article3", "pmid": "pmid3"}, +""" + +@pytest.fixture +def json_data_path(): + with tempfile.TemporaryDirectory() as tmpdirname: + data_path = tmpdirname + "/data.json" + with open(data_path, "w") as f: + f.write(json_data) + yield data_path + +@pytest.fixture +def jsonl_data_path(): + with tempfile.TemporaryDirectory() as tmpdirname: + data_path = tmpdirname + "/data.jsonl" + with open(data_path, "w") as f: + f.write(jsonl_data) + yield data_path def test_process_data_with_filter_tags(): @@ -56,9 +77,9 @@ def test_process_data_with_filter_years_and_tags(): assert process_data(item, filter_years=["2020"], filter_tags=["T3"]) is False -def test_json_to_jsonl(): +def test_json_to_jsonl(json_data_path): output_tmp = tempfile.NamedTemporaryFile(mode="w") - mesh_json_to_jsonl(f'{FIXTURES_DIR}/mesh_fixture_head_2.json', output_tmp.name, show_progress=False) + mesh_json_to_jsonl(json_data_path, output_tmp.name, show_progress=False) with open(output_tmp.name, 'r') as f: result = [json.loads(jline) for jline in f.read().splitlines()] @@ -67,8 +88,8 @@ def test_json_to_jsonl(): output_tmp.close() -def test_preprocess_mesh(): - dset, label2id = preprocess_mesh(data_path=f"{FIXTURES_DIR}/mesh_fixture_head_2.jsonl", +def test_preprocess_mesh(jsonl_data_path): + dset, label2id = preprocess_mesh(data_path=jsonl_data_path, model_key='', num_proc=2, batch_size=1, @@ -77,6 +98,4 @@ def test_preprocess_mesh(): assert "test" in dset assert len(dset["train"]) == 1 assert len(dset["test"]) == 1 - logger.info("label2id") - logger.info(label2id) - assert len(list(label2id.keys())) == 18 + assert len(list(label2id.keys())) == 3 From c8dffc27c1e1e34b26a10adc93c8091b40f7731f Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 24 Jul 2023 16:11:49 +0100 Subject: [PATCH 086/300] Fixes tests in train. Adds more tests. --- .../preprocessing/preprocess_mesh.py | 9 +++---- tests/test_preprocess_mesh.py | 2 ++ tests/test_train.py | 24 ++++++++++++------- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index b4494161..24f1d503 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -51,9 +51,7 @@ def preprocess_mesh( max_samples: int = -1, batch_size: int = 256, ): - if max_samples == -1: - max_samples = np.inf - else: + if max_samples != -1: data_path = create_sample_file(data_path, max_samples) if not model_key: @@ -88,6 +86,7 @@ def preprocess_mesh( remove_columns=["abstractText"], ) + columns_to_remove = ["meshMajor"] # Generate label2id if None if label2id is None: logger.info("Getting the labels...") @@ -113,6 +112,8 @@ def preprocess_mesh( for idx, label in enumerate(tqdm(unique_labels_set)): label2id.update({label: idx}) + columns_to_remove.append("labels") + dset = dset.map( _encode_labels, batched=True, @@ -120,7 +121,7 @@ def preprocess_mesh( desc="Encoding labels", num_proc=num_proc, fn_kwargs={"label2id": label2id}, - remove_columns=["meshMajor", "labels"], + remove_columns=columns_to_remove, ) logger.info("Preparing train/test split....") diff --git a/tests/test_preprocess_mesh.py b/tests/test_preprocess_mesh.py index 4b6c46c3..0217da7a 100644 --- a/tests/test_preprocess_mesh.py +++ b/tests/test_preprocess_mesh.py @@ -16,6 +16,7 @@ {"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria", "title": "article3", "pmid": "pmid3"}, """ + @pytest.fixture def json_data_path(): with tempfile.TemporaryDirectory() as tmpdirname: @@ -24,6 +25,7 @@ def json_data_path(): f.write(json_data) yield data_path + @pytest.fixture def jsonl_data_path(): with tempfile.TemporaryDirectory() as tmpdirname: diff --git a/tests/test_train.py b/tests/test_train.py index 9742eac9..2c4ca6d8 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -6,11 +6,8 @@ import numpy as np # Note dummy data is not necessarily annotated correctly -dummy_data = """{"articles":[ -{"journal":"dummyJournal","meshMajor":["COVID-19","SARS-CoV-2"],"year":"2023","abstractText":"This is an article about coronavirus.","title":"article1","pmid":"pmid1"}, -{"journal":"dummyJournal","meshMajor":["COVID-19","SARS-CoV-2"],"year":"2023","abstractText":"This is an article about COVID-19.","title":"article2","pmid":"pmid2"}, -{"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria", "title": "article3", "pmid": "pmid3"}, -""" # noqa +dummy_data = """{"journal":"dummyJournal","meshMajor":["COVID-19","SARS-CoV-2"],"year":"2023","abstractText":"This is an article about coronavirus.","title":"article1","pmid":"pmid1"} +{"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria", "title": "article3", "pmid": "pmid3"}""" @pytest.fixture @@ -28,9 +25,7 @@ def save_path(): yield tmpdirname + "/model" -def test_train_bertmesh(data_path, save_path): - model_key = "Wellcome/WellcomeBertMesh" - +def _train_bertmesh_from_model_key(data_path, save_path, model_key): # 1 train step, 1 eval step, save after training training_args = TrainingArguments( output_dir=save_path, @@ -50,7 +45,18 @@ def test_train_bertmesh(data_path, save_path): train_bertmesh( model_key=model_key, data_path=data_path, - max_samples=np.inf, + max_samples=-1, training_args=training_args, model_args=model_args, + num_proc=2, + test_size=0.5 ) + + +def test_train_bertmesh_from_model_key(data_path, save_path): + _train_bertmesh_from_model_key(data_path, save_path, "Wellcome/WellcomeBertMesh") + + +def test_train_bertmesh_from_scratch(data_path, save_path): + _train_bertmesh_from_model_key(data_path, save_path, "") + From 250108280c24e6564e99f2fcc2d211c3e9a295c1 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 24 Jul 2023 17:21:39 +0100 Subject: [PATCH 087/300] Adds some comments --- grants_tagger_light/preprocessing/preprocess_mesh.py | 2 ++ grants_tagger_light/training/train.py | 7 ++++--- tests/test_train.py | 1 - 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 24f1d503..97885c84 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -128,6 +128,8 @@ def preprocess_mesh( # Split into train and test dset = dset.train_test_split(test_size=test_size) + # If running from Training, by default it will be None so that we don't spend time on serializing the data + # to disk if we are going to load it afterwards if save_to_path is not None: logger.info("Saving to disk...") dset.save_to_disk(os.path.join(save_to_path, 'dataset'), num_proc=num_proc) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index f03406d5..a76207d6 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -54,12 +54,12 @@ def train_bertmesh( logger.info(f"Preprocessing the dataset at {data_path}...") if os.path.isdir(data_path): - logger.info(f"Loading from disk...") + logger.info(f"Folder found, which means you preprocessed and save the data before. Loading from disk...") dset = load_from_disk(os.path.join(data_path, 'dataset')) with open(os.path.join(data_path, 'label2id'), 'r') as f: label2id = json.load(f) else: - logger.info(f"Preprocessing started!") + logger.info(f"Preprocessing the data on the fly...") dset, label2id = preprocess_mesh( data_path=data_path, model_key=model_key, @@ -103,11 +103,12 @@ def train_bertmesh( logger.info(f"Preprocessing the dataset at {data_path}...") if os.path.isdir(data_path): - logger.info(f"Loading from disk...") + logger.info(f"Folder found, which means you preprocessed and save the data before. Loading from disk...") dset = load_from_disk(os.path.join(data_path, 'dataset')) with open(os.path.join(data_path, 'label2id'), 'r') as f: label2id = json.load(f) else: + logger.info(f"Preprocessing the data on the fly...") dset, label2id = preprocess_mesh( data_path=data_path, model_key=model_key, diff --git a/tests/test_train.py b/tests/test_train.py index 2c4ca6d8..0c41b390 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -3,7 +3,6 @@ from transformers import TrainingArguments import tempfile import pytest -import numpy as np # Note dummy data is not necessarily annotated correctly dummy_data = """{"journal":"dummyJournal","meshMajor":["COVID-19","SARS-CoV-2"],"year":"2023","abstractText":"This is an article about coronavirus.","title":"article1","pmid":"pmid1"} From c963136bfa062d6671c9b0cff39902aa65ad7e33 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 24 Jul 2023 17:28:16 +0100 Subject: [PATCH 088/300] Reformatting --- .../preprocessing/preprocess_mesh.py | 20 +++++++++++-------- grants_tagger_light/training/train.py | 6 +++--- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 97885c84..bbdeddee 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -1,15 +1,17 @@ import json import tempfile -import numpy as np import typer from transformers import AutoTokenizer -from datasets import Dataset, disable_caching, load_dataset +from datasets import load_dataset from grants_tagger_light.models.bert_mesh import BertMesh import os from loguru import logger from tqdm import tqdm +# from datasets import disable_caching +# disable_caching() + preprocess_app = typer.Typer() @@ -128,8 +130,8 @@ def preprocess_mesh( # Split into train and test dset = dset.train_test_split(test_size=test_size) - # If running from Training, by default it will be None so that we don't spend time on serializing the data - # to disk if we are going to load it afterwards + # If running from Training, by default it will be None so that we don't spend time + # on serializing the data # to disk if we are going to load it afterwards if save_to_path is not None: logger.info("Saving to disk...") dset.save_to_disk(os.path.join(save_to_path, 'dataset'), num_proc=num_proc) @@ -151,7 +153,8 @@ def preprocess_mesh_cli( ), model_key: str = typer.Argument( ..., - help="Key to use when loading tokenizer and label2id. Leave blank if training from scratch", # noqa + help="Key to use when loading tokenizer and label2id. " + "Leave blank if training from scratch", # noqa ), test_size: float = typer.Option( 0.05, @@ -170,9 +173,10 @@ def preprocess_mesh_cli( help="Size of the preprocessing batch" ) ): - if input("\033[96mRunning preprocessing will save the data as a PyArrow dataset which is a very time consuming " - "operation. If you don't need the data to be saved, you can save much time just by running:\n" - "\t`grants-tagger train bertmesh {model_key} {path_to_jsonl}`\033[0m\n\n" + if input("\033[96mRunning preprocessing will save the data as a PyArrow dataset " + "which is a very time consuming operation. If you don't need the data " + "to be saved, you can save much time just by running:\n" + "> `grants-tagger train bertmesh {model_key} {path_to_jsonl}`\033[0m\n\n" "Do You Want To Continue? [Y/n]") != 'Y': exit(1) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index a76207d6..11c86dcb 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -1,5 +1,4 @@ from transformers import ( - AutoTokenizer, Trainer, TrainingArguments, EvalPrediction, @@ -54,12 +53,13 @@ def train_bertmesh( logger.info(f"Preprocessing the dataset at {data_path}...") if os.path.isdir(data_path): - logger.info(f"Folder found, which means you preprocessed and save the data before. Loading from disk...") + logger.info("Folder found, which means you preprocessed and save the data before. " + "Loading from disk...") dset = load_from_disk(os.path.join(data_path, 'dataset')) with open(os.path.join(data_path, 'label2id'), 'r') as f: label2id = json.load(f) else: - logger.info(f"Preprocessing the data on the fly...") + logger.info("Preprocessing the data on the fly...") dset, label2id = preprocess_mesh( data_path=data_path, model_key=model_key, From b015f49f75392c679422678291d290ce8efc731f Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 24 Jul 2023 17:32:55 +0100 Subject: [PATCH 089/300] Reformatting --- grants_tagger_light/training/train.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 11c86dcb..99343e32 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -43,7 +43,8 @@ def train_bertmesh( if not model_key: assert isinstance( model_args, BertMeshModelArguments - ), "If model_key is not provided, must provide model_args of type BertMeshModelArguments" # noqa + ), "If model_key is not provided, " \ + "must provide model_args of type BertMeshModelArguments" # noqa logger.info("No model key provided. Training model from scratch") @@ -53,8 +54,8 @@ def train_bertmesh( logger.info(f"Preprocessing the dataset at {data_path}...") if os.path.isdir(data_path): - logger.info("Folder found, which means you preprocessed and save the data before. " - "Loading from disk...") + logger.info("Folder found, which means you preprocessed and " + "save the data before. Loading from disk...") dset = load_from_disk(os.path.join(data_path, 'dataset')) with open(os.path.join(data_path, 'label2id'), 'r') as f: label2id = json.load(f) @@ -75,10 +76,10 @@ def train_bertmesh( logger.info(f"Training max samples: {train_dset_size}.") train_dset.filter(lambda example, idx: idx < train_dset_size, with_indices=True) else: - logger.info(f"Training with all data...") + logger.info("Training with all data...") if shards > 0: - logger.info(f"Sharding training dataset...") + logger.info("Sharding training dataset...") train_dset = Sharding(num_shards=shards).shard(train_dset) config.update( @@ -103,12 +104,13 @@ def train_bertmesh( logger.info(f"Preprocessing the dataset at {data_path}...") if os.path.isdir(data_path): - logger.info(f"Folder found, which means you preprocessed and save the data before. Loading from disk...") + logger.info("Folder found, which means you preprocessed and " + "save the data before. Loading from disk...") dset = load_from_disk(os.path.join(data_path, 'dataset')) with open(os.path.join(data_path, 'label2id'), 'r') as f: label2id = json.load(f) else: - logger.info(f"Preprocessing the data on the fly...") + logger.info("Preprocessing the data on the fly...") dset, label2id = preprocess_mesh( data_path=data_path, model_key=model_key, @@ -121,10 +123,10 @@ def train_bertmesh( train_dset_size = len(train_dset) if max_samples > 0: train_dset_size = min(max_samples, train_dset_size) - logger.info(f"Training max samples: {train_dset_size}.") + logger.info("Training max samples: {train_dset_size}.") train_dset.filter(lambda example, idx: idx < train_dset_size, with_indices=True) else: - logger.info(f"Training with all data...") + logger.info("Training with all data...") if model_args.freeze_backbone: logger.info("Freezing backbone") @@ -150,7 +152,7 @@ def sklearn_metrics(prediction: EvalPrediction): return metric_dict - logger.info(f"Collating labels...") + logger.info("Collating labels...") collator = MultilabelDataCollator(label2id=label2id) if shards > 0: From 13822a8c1332debbac44d79fdc84e197dcaf687b Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 24 Jul 2023 17:36:32 +0100 Subject: [PATCH 090/300] Reformatting --- grants_tagger_light/training/train.py | 48 ++++++++++++++++++--------- 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 99343e32..3801105e 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -74,7 +74,8 @@ def train_bertmesh( if max_samples > 0: train_dset_size = min(max_samples, train_dset_size) logger.info(f"Training max samples: {train_dset_size}.") - train_dset.filter(lambda example, idx: idx < train_dset_size, with_indices=True) + train_dset.filter(lambda example, idx: idx < train_dset_size, + with_indices=True) else: logger.info("Training with all data...") @@ -123,8 +124,9 @@ def train_bertmesh( train_dset_size = len(train_dset) if max_samples > 0: train_dset_size = min(max_samples, train_dset_size) - logger.info("Training max samples: {train_dset_size}.") - train_dset.filter(lambda example, idx: idx < train_dset_size, with_indices=True) + logger.info(f"Training max samples: {train_dset_size}.") + train_dset.filter(lambda example, idx: idx < train_dset_size, + with_indices=True) else: logger.info("Training with all data...") @@ -136,12 +138,15 @@ def sklearn_metrics(prediction: EvalPrediction): y_pred = prediction.predictions y_true = prediction.label_ids - # TODO make thresh configurable or return metrics for multiple thresholds + # TODO make thresh configurable or return metrics + # for multiple thresholds # e.g. 0.5:0.95:0.05 y_pred = np.int64(y_pred > 0.5) - report = classification_report(y_true, y_pred, output_dict=True) + report = classification_report(y_true, + y_pred, + output_dict=True) metric_dict = { "micro_avg": report["micro avg"], @@ -156,12 +161,15 @@ def sklearn_metrics(prediction: EvalPrediction): collator = MultilabelDataCollator(label2id=label2id) if shards > 0: - logger.info(f"Calculating max steps for IterableDatasets shards...") - max_steps = Sharding.calculate_max_steps(training_args, train_dset_size) + logger.info(f"Calculating max steps for " + f"IterableDatasets shards...") + max_steps = Sharding.calculate_max_steps(training_args, + train_dset_size) training_args.max_steps = max_steps logger.info(f"Initializing Trainer:\n" - f"* per_device_train_batch_size={training_args.per_device_train_batch_size}\n" + f"* per_device_train_batch_size=" + f"{training_args.per_device_train_batch_size}\n" f"* max_steps = {training_args.max_steps}\n" f"* epochs = {training_args.num_train_epochs}\n" ) @@ -175,15 +183,15 @@ def sklearn_metrics(prediction: EvalPrediction): compute_metrics=sklearn_metrics ) - logger.info(f"Training...") + logger.info("Training...") trainer.train() - logger.info(f"Evaluating...") + logger.info("Evaluating...") metrics = trainer.evaluate(eval_dataset=val_dset) logger.info(pformat(metrics)) - logger.info(f"Saving the model...") + logger.info("Saving the model...") trainer.save_model(os.path.join(training_args.output_dir, "best")) @@ -194,11 +202,13 @@ def sklearn_metrics(prediction: EvalPrediction): def train_bertmesh_cli( ctx: typer.Context, model_key: str = typer.Argument( - ..., help="Pretrained model key. Local path or HF location" + ..., help="Pretrained model key. " + "Local path or HF location" ), data_path: str = typer.Argument( ..., - help="Path to allMeSH_2021.jsonl (or similar) or to a folder after preprocessing and saving to disk" + help="Path to allMeSH_2021.jsonl (or similar) " + "or to a folder after preprocessing and saving to disk" ), test_size: float = typer.Option( 0.05, @@ -214,7 +224,8 @@ def train_bertmesh_cli( ), shards: int = typer.Option( -1, - help="Number os shards to divide training IterativeDataset to (improves performance)") + help="Number os shards to divide training " + "IterativeDataset to (improves performance)") ): parser = HfArgumentParser( @@ -233,4 +244,11 @@ def train_bertmesh_cli( logger.info("Training args: {}".format(pformat(training_args))) logger.info("Wandb args: {}".format(pformat(wandb_args))) - train_bertmesh(model_key, data_path, training_args, model_args, max_samples, test_size, num_proc, shards) + train_bertmesh(model_key, + data_path, + training_args, + model_args, + max_samples, + test_size, + num_proc, + shards) From 3ae614339bcf6788bd8e1b29acbaa7e3a91cd0e4 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 24 Jul 2023 17:40:58 +0100 Subject: [PATCH 091/300] Reformatting --- .../preprocessing/preprocess_mesh.py | 81 ++++++------ grants_tagger_light/training/train.py | 116 +++++++++--------- scripts/jsonl_preprocessing.py | 79 ++++++++---- 3 files changed, 156 insertions(+), 120 deletions(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index bbdeddee..a739cb2d 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -29,12 +29,12 @@ def _map_label_to_ids(labels, label2id): def _encode_labels(sample, label2id): - return {'label_ids': [_map_label_to_ids(x, label2id) for x in sample['meshMajor']]} + return {"label_ids": [_map_label_to_ids(x, label2id) for x in sample["meshMajor"]]} def create_sample_file(jsonl_file, lines): - with open(jsonl_file, 'r') as input_file: - with tempfile.NamedTemporaryFile(mode='w', delete=False) as tmp_file: + with open(jsonl_file, "r") as input_file: + with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp_file: for _ in range(lines): line = input_file.readline() if not line: @@ -72,8 +72,8 @@ def preprocess_mesh( # We only have 1 file, so no sharding is available https://huggingface.co/docs/datasets/loading#multiprocessing dset = load_dataset("json", data_files=data_path, num_proc=1) # By default, any dataset loaded is set to 'train' using the previous command - if 'train' in dset: - dset = dset['train'] + if "train" in dset: + dset = dset["train"] # Remove unused columns to save space & time dset = dset.remove_columns(["journal", "year", "pmid", "title"]) @@ -93,11 +93,11 @@ def preprocess_mesh( if label2id is None: logger.info("Getting the labels...") dset = dset.map( - lambda x: {'labels': x["meshMajor"]}, + lambda x: {"labels": x["meshMajor"]}, batched=True, batch_size=batch_size, num_proc=num_proc, - desc="Getting labels" + desc="Getting labels", ) # Most efficient way to do dedup of labels @@ -105,7 +105,7 @@ def preprocess_mesh( logger.info("Obtaining unique values from the labels...") # Iterate through the lists and add elements to the set - for arr in tqdm(dset['labels']): + for arr in tqdm(dset["labels"]): unique_labels_set.update(arr) # Most efficient way to do dictionary creation @@ -134,8 +134,8 @@ def preprocess_mesh( # on serializing the data # to disk if we are going to load it afterwards if save_to_path is not None: logger.info("Saving to disk...") - dset.save_to_disk(os.path.join(save_to_path, 'dataset'), num_proc=num_proc) - with open(os.path.join(save_to_path, 'label2id'), 'w') as f: + dset.save_to_disk(os.path.join(save_to_path, "dataset"), num_proc=num_proc) + with open(os.path.join(save_to_path, "label2id"), "w") as f: json.dump(label2id, f) return dset, label2id @@ -143,53 +143,50 @@ def preprocess_mesh( @preprocess_app.command() def preprocess_mesh_cli( - data_path: str = typer.Argument( - ..., - help="Path to mesh.jsonl" - ), + data_path: str = typer.Argument(..., help="Path to mesh.jsonl"), save_to_path: str = typer.Argument( - ..., - help="Path to save the serialized PyArrow dataset after preprocessing" + ..., help="Path to save the serialized PyArrow dataset after preprocessing" ), model_key: str = typer.Argument( ..., help="Key to use when loading tokenizer and label2id. " - "Leave blank if training from scratch", # noqa - ), - test_size: float = typer.Option( - 0.05, - help="Fraction of data to use for testing" + "Leave blank if training from scratch", # noqa ), + test_size: float = typer.Option(0.05, help="Fraction of data to use for testing"), num_proc: int = typer.Option( - os.cpu_count(), - help="Number of processes to use for preprocessing" + os.cpu_count(), help="Number of processes to use for preprocessing" ), max_samples: int = typer.Option( -1, help="Maximum number of samples to use for preprocessing", ), - batch_size: int = typer.Option( - 256, - help="Size of the preprocessing batch" - ) + batch_size: int = typer.Option(256, help="Size of the preprocessing batch"), ): - if input("\033[96mRunning preprocessing will save the data as a PyArrow dataset " - "which is a very time consuming operation. If you don't need the data " - "to be saved, you can save much time just by running:\n" - "> `grants-tagger train bertmesh {model_key} {path_to_jsonl}`\033[0m\n\n" - "Do You Want To Continue? [Y/n]") != 'Y': + if ( + input( + "\033[96mRunning preprocessing will save the data as a PyArrow dataset " + "which is a very time consuming operation. If you don't need the data " + "to be saved, you can save much time just by running:\n" + "> `grants-tagger train bertmesh {model_key} {path_to_jsonl}`\033[0m\n\n" + "Do You Want To Continue? [Y/n]" + ) + != "Y" + ): exit(1) - if not data_path.endswith('jsonl'): - logger.error("It seems your input MeSH data is not in `jsonl` format. " - "Please, run first `scripts/mesh_json_to_jsonlpy.`") + if not data_path.endswith("jsonl"): + logger.error( + "It seems your input MeSH data is not in `jsonl` format. " + "Please, run first `scripts/mesh_json_to_jsonlpy.`" + ) exit(1) preprocess_mesh( - data_path=data_path, - model_key=model_key, - test_size=test_size, - num_proc=num_proc, - max_samples=max_samples, - batch_size=batch_size, - save_to_path=save_to_path) + data_path=data_path, + model_key=model_key, + test_size=test_size, + num_proc=num_proc, + max_samples=max_samples, + batch_size=batch_size, + save_to_path=save_to_path, + ) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 3801105e..db8144dc 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -38,13 +38,13 @@ def train_bertmesh( max_samples: int = -1, test_size: float = 0.05, num_proc: int = os.cpu_count(), - shards: int = -1 + shards: int = -1, ): if not model_key: - assert isinstance( - model_args, BertMeshModelArguments - ), "If model_key is not provided, " \ - "must provide model_args of type BertMeshModelArguments" # noqa + assert isinstance(model_args, BertMeshModelArguments), ( + "If model_key is not provided, " + "must provide model_args of type BertMeshModelArguments" + ) # noqa logger.info("No model key provided. Training model from scratch") @@ -54,28 +54,32 @@ def train_bertmesh( logger.info(f"Preprocessing the dataset at {data_path}...") if os.path.isdir(data_path): - logger.info("Folder found, which means you preprocessed and " - "save the data before. Loading from disk...") - dset = load_from_disk(os.path.join(data_path, 'dataset')) - with open(os.path.join(data_path, 'label2id'), 'r') as f: + logger.info( + "Folder found, which means you preprocessed and " + "save the data before. Loading from disk..." + ) + dset = load_from_disk(os.path.join(data_path, "dataset")) + with open(os.path.join(data_path, "label2id"), "r") as f: label2id = json.load(f) else: logger.info("Preprocessing the data on the fly...") dset, label2id = preprocess_mesh( - data_path=data_path, - model_key=model_key, - test_size=test_size, - num_proc=num_proc, - max_samples=max_samples, - batch_size=training_args.per_device_train_batch_size) + data_path=data_path, + model_key=model_key, + test_size=test_size, + num_proc=num_proc, + max_samples=max_samples, + batch_size=training_args.per_device_train_batch_size, + ) train_dset, val_dset = dset["train"], dset["test"] train_dset_size = len(train_dset) if max_samples > 0: train_dset_size = min(max_samples, train_dset_size) logger.info(f"Training max samples: {train_dset_size}.") - train_dset.filter(lambda example, idx: idx < train_dset_size, - with_indices=True) + train_dset.filter( + lambda example, idx: idx < train_dset_size, with_indices=True + ) else: logger.info("Training with all data...") @@ -105,10 +109,12 @@ def train_bertmesh( logger.info(f"Preprocessing the dataset at {data_path}...") if os.path.isdir(data_path): - logger.info("Folder found, which means you preprocessed and " - "save the data before. Loading from disk...") - dset = load_from_disk(os.path.join(data_path, 'dataset')) - with open(os.path.join(data_path, 'label2id'), 'r') as f: + logger.info( + "Folder found, which means you preprocessed and " + "save the data before. Loading from disk..." + ) + dset = load_from_disk(os.path.join(data_path, "dataset")) + with open(os.path.join(data_path, "label2id"), "r") as f: label2id = json.load(f) else: logger.info("Preprocessing the data on the fly...") @@ -118,15 +124,17 @@ def train_bertmesh( test_size=test_size, num_proc=num_proc, max_samples=max_samples, - batch_size=training_args.per_device_train_batch_size) + batch_size=training_args.per_device_train_batch_size, + ) train_dset, val_dset = dset["train"], dset["test"] train_dset_size = len(train_dset) if max_samples > 0: train_dset_size = min(max_samples, train_dset_size) logger.info(f"Training max samples: {train_dset_size}.") - train_dset.filter(lambda example, idx: idx < train_dset_size, - with_indices=True) + train_dset.filter( + lambda example, idx: idx < train_dset_size, with_indices=True + ) else: logger.info("Training with all data...") @@ -144,9 +152,7 @@ def sklearn_metrics(prediction: EvalPrediction): y_pred = np.int64(y_pred > 0.5) - report = classification_report(y_true, - y_pred, - output_dict=True) + report = classification_report(y_true, y_pred, output_dict=True) metric_dict = { "micro_avg": report["micro avg"], @@ -161,18 +167,17 @@ def sklearn_metrics(prediction: EvalPrediction): collator = MultilabelDataCollator(label2id=label2id) if shards > 0: - logger.info(f"Calculating max steps for " - f"IterableDatasets shards...") - max_steps = Sharding.calculate_max_steps(training_args, - train_dset_size) + logger.info(f"Calculating max steps for " f"IterableDatasets shards...") + max_steps = Sharding.calculate_max_steps(training_args, train_dset_size) training_args.max_steps = max_steps - logger.info(f"Initializing Trainer:\n" - f"* per_device_train_batch_size=" - f"{training_args.per_device_train_batch_size}\n" - f"* max_steps = {training_args.max_steps}\n" - f"* epochs = {training_args.num_train_epochs}\n" - ) + logger.info( + f"Initializing Trainer:\n" + f"* per_device_train_batch_size=" + f"{training_args.per_device_train_batch_size}\n" + f"* max_steps = {training_args.max_steps}\n" + f"* epochs = {training_args.num_train_epochs}\n" + ) trainer = Trainer( model=model, @@ -180,7 +185,7 @@ def sklearn_metrics(prediction: EvalPrediction): train_dataset=train_dset, eval_dataset=val_dset, data_collator=collator, - compute_metrics=sklearn_metrics + compute_metrics=sklearn_metrics, ) logger.info("Training...") @@ -202,21 +207,16 @@ def sklearn_metrics(prediction: EvalPrediction): def train_bertmesh_cli( ctx: typer.Context, model_key: str = typer.Argument( - ..., help="Pretrained model key. " - "Local path or HF location" + ..., help="Pretrained model key. " "Local path or HF location" ), data_path: str = typer.Argument( ..., help="Path to allMeSH_2021.jsonl (or similar) " - "or to a folder after preprocessing and saving to disk" - ), - test_size: float = typer.Option( - 0.05, - help="Fraction of data to use for testing" + "or to a folder after preprocessing and saving to disk", ), + test_size: float = typer.Option(0.05, help="Fraction of data to use for testing"), num_proc: int = typer.Option( - os.cpu_count(), - help="Number of processes to use for preprocessing" + os.cpu_count(), help="Number of processes to use for preprocessing" ), max_samples: int = typer.Option( -1, @@ -225,9 +225,9 @@ def train_bertmesh_cli( shards: int = typer.Option( -1, help="Number os shards to divide training " - "IterativeDataset to (improves performance)") + "IterativeDataset to (improves performance)", + ), ): - parser = HfArgumentParser( ( BertMeshTrainingArguments, @@ -244,11 +244,13 @@ def train_bertmesh_cli( logger.info("Training args: {}".format(pformat(training_args))) logger.info("Wandb args: {}".format(pformat(wandb_args))) - train_bertmesh(model_key, - data_path, - training_args, - model_args, - max_samples, - test_size, - num_proc, - shards) + train_bertmesh( + model_key, + data_path, + training_args, + model_args, + max_samples, + test_size, + num_proc, + shards, + ) diff --git a/scripts/jsonl_preprocessing.py b/scripts/jsonl_preprocessing.py index cf59e0ed..b2b17956 100644 --- a/scripts/jsonl_preprocessing.py +++ b/scripts/jsonl_preprocessing.py @@ -8,17 +8,17 @@ def process_data(item, filter_tags: list = None, filter_years: list = None): check_tags = filter_tags is not None check_years = filter_years is not None - if check_tags and 'meshMajor' not in item: + if check_tags and "meshMajor" not in item: logger.warning("`meshMajor` not found in the fields. Unable to filter tags.") check_tags = False - if check_years and 'year' not in item: + if check_years and "year" not in item: logger.warning("`year` not found in the fields. Unable to filter tags.") check_years = False if check_tags: if filter_tags is None: filter_tags = [] - if len(filter_tags) > 0 and not any(np.isin(filter_tags, item['meshMajor'])): + if len(filter_tags) > 0 and not any(np.isin(filter_tags, item["meshMajor"])): return False if check_years: @@ -27,14 +27,23 @@ def process_data(item, filter_tags: list = None, filter_years: list = None): # Making sure it's str and not int filter_years = [str(y) for y in filter_years] - if len(filter_years) > 0 and not any(np.isin(filter_years, [str(item['year'])])): + if len(filter_years) > 0 and not any( + np.isin(filter_years, [str(item["year"])]) + ): return False return True -def mesh_json_to_jsonl(input_path, output_path, input_encoding='latin1', output_encoding='latin1', - filter_tags: str = '', filter_years: str = '', show_progress: bool = True): +def mesh_json_to_jsonl( + input_path, + output_path, + input_encoding="latin1", + output_encoding="latin1", + filter_tags: str = "", + filter_years: str = "", + show_progress: bool = True, +): """ Mesh json is not optimized for parallel processing. This script aims to transform it into `jsonl` so that, by having 1 json per line, you can evenly distribute it among all the cores. @@ -50,13 +59,13 @@ def mesh_json_to_jsonl(input_path, output_path, input_encoding='latin1', output_ Returns: """ - filter_tags_list = list(filter(lambda x: x.strip() != '', filter_tags.split(','))) - filter_years_list = list(filter(lambda x: x.strip() != '', filter_years.split(','))) - with open(output_path, 'w', encoding=output_encoding) as fw: - with open(input_path, 'r', encoding=input_encoding) as fr: + filter_tags_list = list(filter(lambda x: x.strip() != "", filter_tags.split(","))) + filter_years_list = list(filter(lambda x: x.strip() != "", filter_years.split(","))) + with open(output_path, "w", encoding=output_encoding) as fw: + with open(input_path, "r", encoding=input_encoding) as fr: for idx, line in enumerate(fr): if show_progress: - print(idx, end='\r') + print(idx, end="\r") # Skip 1st line if idx == 0: logger.info(f"Skipping first line (articles): {line}") @@ -68,21 +77,49 @@ def mesh_json_to_jsonl(input_path, output_path, input_encoding='latin1', output_ continue if process_data(sample, filter_tags_list, filter_years_list): fw.write(json.dumps(sample)) - fw.write('\n') + fw.write("\n") if __name__ == "__main__": parser = ArgumentParser() - parser.add_argument("--input_path", required=True, help="path to input allMeSH_2021.json (or equivalent)") + parser.add_argument( + "--input_path", + required=True, + help="path to input allMeSH_2021.json (or equivalent)", + ) parser.add_argument("--output_path", required=True, help="path to ioutput jsonl") - parser.add_argument("--input_encoding", required=False, default='latin1', help="encoding of the input json") - parser.add_argument("--output_encoding", required=False, default='latin1', help="encoding of the output jsonl") - parser.add_argument("--filter_tags", required=False, default='', - help="comma-separated tags to include (the rest will be discarded)") - parser.add_argument("--filter_years", required=False, default='', - help="comma-separated years to include (the rest will be discarded)") + parser.add_argument( + "--input_encoding", + required=False, + default="latin1", + help="encoding of the input json", + ) + parser.add_argument( + "--output_encoding", + required=False, + default="latin1", + help="encoding of the output jsonl", + ) + parser.add_argument( + "--filter_tags", + required=False, + default="", + help="comma-separated tags to include (the rest will be discarded)", + ) + parser.add_argument( + "--filter_years", + required=False, + default="", + help="comma-separated years to include (the rest will be discarded)", + ) args = parser.parse_args() - mesh_json_to_jsonl(args.input_path, args.output_path, args.input_encoding, args.output_encoding, args.filter_tags, - args.filter_years) + mesh_json_to_jsonl( + args.input_path, + args.output_path, + args.input_encoding, + args.output_encoding, + args.filter_tags, + args.filter_years, + ) From 689f323f51b1380a13aedb2c723afd73e25132d1 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 24 Jul 2023 17:44:32 +0100 Subject: [PATCH 092/300] Reformatting --- grants_tagger_light/training/train.py | 2 +- grants_tagger_light/utils/sharding.py | 16 ++++++++++------ grants_tagger_light/utils/utils.py | 1 - tests/test_preprocess_mesh.py | 10 ++++------ tests/test_train.py | 3 +-- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index db8144dc..2f45b040 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -167,7 +167,7 @@ def sklearn_metrics(prediction: EvalPrediction): collator = MultilabelDataCollator(label2id=label2id) if shards > 0: - logger.info(f"Calculating max steps for " f"IterableDatasets shards...") + logger.info("Calculating max steps for IterableDatasets shards...") max_steps = Sharding.calculate_max_steps(training_args, train_dset_size) training_args.max_steps = max_steps diff --git a/grants_tagger_light/utils/sharding.py b/grants_tagger_light/utils/sharding.py index 85a7ff22..d8046b36 100644 --- a/grants_tagger_light/utils/sharding.py +++ b/grants_tagger_light/utils/sharding.py @@ -1,4 +1,4 @@ -from datasets import load_dataset, IterableDataset +from datasets import IterableDataset class Sharding: @@ -17,14 +17,18 @@ def gen_from_shards(cls, _shards): yield example def shard(self, dataset): - shards = [dataset.shard(num_shards=self.num_shards, index=index, contiguous=True) - for index in range(self.num_shards)] + shards = [ + dataset.shard(num_shards=self.num_shards, index=index, contiguous=True) + for index in range(self.num_shards) + ] - return IterableDataset.from_generator(self.gen_from_shards, gen_kwargs={"_shards": shards}) + return IterableDataset.from_generator( + self.gen_from_shards, gen_kwargs={"_shards": shards} + ) @staticmethod def calculate_max_steps(training_args, train_dset_size): - """ This is needed when using IterableDatasets, as there is no __len__ in advance since the dataset is a + """This is needed when using IterableDatasets, as there is no __len__ in advance since the dataset is a generator with yield, so it does not know when to end. Source: https://discuss.huggingface.co/t/streaming-dataset-into-trainer-does-not-implement-len-max-steps-has-to-be-specified/32893/6 @@ -34,7 +38,7 @@ def calculate_max_steps(training_args, train_dset_size): > 14781199.15 / 8 = 1847649.89375 if accumulation_steps is 1, then > 1847649.89375 / 1 = 1847649.89375 - """ + """ train_batch_size = training_args.per_device_train_batch_size accumulation_steps = training_args.gradient_accumulation_steps diff --git a/grants_tagger_light/utils/utils.py b/grants_tagger_light/utils/utils.py index 0f57f57f..05aa36ae 100644 --- a/grants_tagger_light/utils/utils.py +++ b/grants_tagger_light/utils/utils.py @@ -177,4 +177,3 @@ def create_label_binarizer(model_path: str, label_binarizer_path: str): f.write(pickle.dumps(label_binarizer)) return label_binarizer - diff --git a/tests/test_preprocess_mesh.py b/tests/test_preprocess_mesh.py index 0217da7a..f256787e 100644 --- a/tests/test_preprocess_mesh.py +++ b/tests/test_preprocess_mesh.py @@ -83,7 +83,7 @@ def test_json_to_jsonl(json_data_path): output_tmp = tempfile.NamedTemporaryFile(mode="w") mesh_json_to_jsonl(json_data_path, output_tmp.name, show_progress=False) - with open(output_tmp.name, 'r') as f: + with open(output_tmp.name, "r") as f: result = [json.loads(jline) for jline in f.read().splitlines()] assert len(result) == 2 @@ -91,11 +91,9 @@ def test_json_to_jsonl(json_data_path): def test_preprocess_mesh(jsonl_data_path): - dset, label2id = preprocess_mesh(data_path=jsonl_data_path, - model_key='', - num_proc=2, - batch_size=1, - test_size=0.5) + dset, label2id = preprocess_mesh( + data_path=jsonl_data_path, model_key="", num_proc=2, batch_size=1, test_size=0.5 + ) assert "train" in dset assert "test" in dset assert len(dset["train"]) == 1 diff --git a/tests/test_train.py b/tests/test_train.py index 0c41b390..8134e1a0 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -48,7 +48,7 @@ def _train_bertmesh_from_model_key(data_path, save_path, model_key): training_args=training_args, model_args=model_args, num_proc=2, - test_size=0.5 + test_size=0.5, ) @@ -58,4 +58,3 @@ def test_train_bertmesh_from_model_key(data_path, save_path): def test_train_bertmesh_from_scratch(data_path, save_path): _train_bertmesh_from_model_key(data_path, save_path, "") - From 42428847d8c6e8e417cca1cbdf8d7ec8d72696af Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 24 Jul 2023 17:52:55 +0100 Subject: [PATCH 093/300] Reformatting --- grants_tagger_light/utils/sharding.py | 7 ++++--- scripts/jsonl_preprocessing.py | 9 +++++---- tests/test_preprocess_mesh.py | 4 ++-- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/grants_tagger_light/utils/sharding.py b/grants_tagger_light/utils/sharding.py index d8046b36..baec7423 100644 --- a/grants_tagger_light/utils/sharding.py +++ b/grants_tagger_light/utils/sharding.py @@ -28,17 +28,18 @@ def shard(self, dataset): @staticmethod def calculate_max_steps(training_args, train_dset_size): - """This is needed when using IterableDatasets, as there is no __len__ in advance since the dataset is a + """This is needed when using IterableDatasets, + as there is no __len__ in advance since the dataset is a generator with yield, so it does not know when to end. Source: https://discuss.huggingface.co/t/streaming-dataset-into-trainer-does-not-implement-len-max-steps-has-to-be-specified/32893/6 - Example: allMeSH_2021.json has 15.559.157 rows, from which let's suppose 5% is test + Example: allMeSH_2021.json has 15.559.157 rows, with 5% for test > 15559157-0.05*15559157 = 14781199.15 training rows let's suppose batch size is 8 > 14781199.15 / 8 = 1847649.89375 if accumulation_steps is 1, then > 1847649.89375 / 1 = 1847649.89375 - """ + """ # noqa train_batch_size = training_args.per_device_train_batch_size accumulation_steps = training_args.gradient_accumulation_steps diff --git a/scripts/jsonl_preprocessing.py b/scripts/jsonl_preprocessing.py index b2b17956..ce855706 100644 --- a/scripts/jsonl_preprocessing.py +++ b/scripts/jsonl_preprocessing.py @@ -45,7 +45,8 @@ def mesh_json_to_jsonl( show_progress: bool = True, ): """ - Mesh json is not optimized for parallel processing. This script aims to transform it into `jsonl` so that, + Mesh json is not optimized for parallel processing. + This script aims to transform it into `jsonl` so that, by having 1 json per line, you can evenly distribute it among all the cores. Args: @@ -53,8 +54,8 @@ def mesh_json_to_jsonl( output_path: path for the resulted jsonl file input_encoding: encoding of the input json (default `latin1`) output_encoding: encoding of the output jsonl (default `latin1`) - filter_tags: tags separated by commas(T1,T2,T3) to only include the entries with those tags - filter_years: years separated by commas(2008,2009) to only include the entries of those years + filter_tags: tags separated by commas(T1,T2,T3) to only include the entries with those + filter_years: years separated by commas(2008,2009) to only include the entries of those show_progress: print the number of line you are processing Returns: @@ -72,7 +73,7 @@ def mesh_json_to_jsonl( continue try: sample = json.loads(line[:-2]) - except: + except json.JSONDecodeError: logger.warning(f"Skipping line in bad json format: {line}") continue if process_data(sample, filter_tags_list, filter_years_list): diff --git a/tests/test_preprocess_mesh.py b/tests/test_preprocess_mesh.py index f256787e..65fc444f 100644 --- a/tests/test_preprocess_mesh.py +++ b/tests/test_preprocess_mesh.py @@ -9,12 +9,12 @@ jsonl_data = """{"journal":"dummyJournal","meshMajor":["COVID-19","SARS-CoV-2"],"year":"2023","abstractText":"This is an article about coronavirus.","title":"article1","pmid":"pmid1"} -{"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria", "title": "article3", "pmid": "pmid3"}""" +{"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria", "title": "article3", "pmid": "pmid3"}""" # noqa json_data = """{"articles":[ {"journal":"dummyJournal","meshMajor":["COVID-19","SARS-CoV-2"],"year":"2023","abstractText":"This is an article about coronavirus.","title":"article1","pmid":"pmid1"}, {"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria", "title": "article3", "pmid": "pmid3"}, -""" +""" # noqa @pytest.fixture From 78aebe7e40c30be7f51ea81cdac90a351bff5c7a Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 24 Jul 2023 17:54:00 +0100 Subject: [PATCH 094/300] Reformatting --- scripts/jsonl_preprocessing.py | 2 +- tests/test_train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/jsonl_preprocessing.py b/scripts/jsonl_preprocessing.py index ce855706..81022824 100644 --- a/scripts/jsonl_preprocessing.py +++ b/scripts/jsonl_preprocessing.py @@ -59,7 +59,7 @@ def mesh_json_to_jsonl( show_progress: print the number of line you are processing Returns: - """ + """ # noqa filter_tags_list = list(filter(lambda x: x.strip() != "", filter_tags.split(","))) filter_years_list = list(filter(lambda x: x.strip() != "", filter_years.split(","))) with open(output_path, "w", encoding=output_encoding) as fw: diff --git a/tests/test_train.py b/tests/test_train.py index 8134e1a0..009cf9ac 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -6,7 +6,7 @@ # Note dummy data is not necessarily annotated correctly dummy_data = """{"journal":"dummyJournal","meshMajor":["COVID-19","SARS-CoV-2"],"year":"2023","abstractText":"This is an article about coronavirus.","title":"article1","pmid":"pmid1"} -{"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria", "title": "article3", "pmid": "pmid3"}""" +{"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria", "title": "article3", "pmid": "pmid3"}""" # noqa @pytest.fixture From 55e427b57be7a92120bb02d7e4d1445e22832233 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 24 Jul 2023 17:56:28 +0100 Subject: [PATCH 095/300] Reformatting --- grants_tagger_light/utils/sharding.py | 2 +- scripts/jsonl_preprocessing.py | 2 +- tests/test_preprocess_mesh.py | 4 ++-- tests/test_train.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/grants_tagger_light/utils/sharding.py b/grants_tagger_light/utils/sharding.py index baec7423..c4c8c3c3 100644 --- a/grants_tagger_light/utils/sharding.py +++ b/grants_tagger_light/utils/sharding.py @@ -39,7 +39,7 @@ def calculate_max_steps(training_args, train_dset_size): > 14781199.15 / 8 = 1847649.89375 if accumulation_steps is 1, then > 1847649.89375 / 1 = 1847649.89375 - """ # noqa + """ # noqa train_batch_size = training_args.per_device_train_batch_size accumulation_steps = training_args.gradient_accumulation_steps diff --git a/scripts/jsonl_preprocessing.py b/scripts/jsonl_preprocessing.py index 81022824..3ad5148a 100644 --- a/scripts/jsonl_preprocessing.py +++ b/scripts/jsonl_preprocessing.py @@ -59,7 +59,7 @@ def mesh_json_to_jsonl( show_progress: print the number of line you are processing Returns: - """ # noqa + """ # noqa filter_tags_list = list(filter(lambda x: x.strip() != "", filter_tags.split(","))) filter_years_list = list(filter(lambda x: x.strip() != "", filter_years.split(","))) with open(output_path, "w", encoding=output_encoding) as fw: diff --git a/tests/test_preprocess_mesh.py b/tests/test_preprocess_mesh.py index 65fc444f..f9d46295 100644 --- a/tests/test_preprocess_mesh.py +++ b/tests/test_preprocess_mesh.py @@ -9,12 +9,12 @@ jsonl_data = """{"journal":"dummyJournal","meshMajor":["COVID-19","SARS-CoV-2"],"year":"2023","abstractText":"This is an article about coronavirus.","title":"article1","pmid":"pmid1"} -{"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria", "title": "article3", "pmid": "pmid3"}""" # noqa +{"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria", "title": "article3", "pmid": "pmid3"}""" # noqa json_data = """{"articles":[ {"journal":"dummyJournal","meshMajor":["COVID-19","SARS-CoV-2"],"year":"2023","abstractText":"This is an article about coronavirus.","title":"article1","pmid":"pmid1"}, {"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria", "title": "article3", "pmid": "pmid3"}, -""" # noqa +""" # noqa @pytest.fixture diff --git a/tests/test_train.py b/tests/test_train.py index 009cf9ac..42e49159 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -6,7 +6,7 @@ # Note dummy data is not necessarily annotated correctly dummy_data = """{"journal":"dummyJournal","meshMajor":["COVID-19","SARS-CoV-2"],"year":"2023","abstractText":"This is an article about coronavirus.","title":"article1","pmid":"pmid1"} -{"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria", "title": "article3", "pmid": "pmid3"}""" # noqa +{"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria", "title": "article3", "pmid": "pmid3"}""" # noqa @pytest.fixture From d9fd22031df2411d418d92c14d9be787cb340297 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 24 Jul 2023 18:18:20 +0100 Subject: [PATCH 096/300] Reformatting --- README.md | 146 ++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 91 insertions(+), 55 deletions(-) diff --git a/README.md b/README.md index 7b2c171b..20f700c7 100644 --- a/README.md +++ b/README.md @@ -58,48 +58,16 @@ And then connect and attach to your machine with a tunnel # ⌨️ Commands | Commands | Description | Needs dev | -| --------------- |--------------------------------------------------------------|-----------| -| ⚙️ preprocess | preprocess data to use for training | False | -| 🔥 train | trains a new model | True | +|-----------------|--------------------------------------------------------------|-----------| +| 🔥 train | preprocesses the data and trains a new model | True | +| ⚙ preprocess | (Optional) preprocess and save the data outside training | False | | 📈 evaluate | evaluate performance of pretrained model | True | | 🔖 predict | predict tags given a grant abstract using a pretrained model | False | | 🎛 tune | tune params and threshold | True | -| ⬇️ download | download data from EPMC | False | - -in square brackets the commands that are not implemented yet - -## ⚙️ Preprocess - -Preprocess creates a JSONL datafile with `text`, `tags` and `meta` as keys. -Text and tags are used for training whereas meta can be useful during annotation -or to analyse predictions and performance. Each dataset needs its own -preprocessing so the current preprocess works with the bioasq-mesh one. -If you want to use a different dataset see section on bringing -your own data under development. +| ⬇ download | download data from EPMC | False | -#### bioasq-mesh -``` - - Usage: grants-tagger preprocess bioasq-mesh [OPTIONS] [INPUT_PATH] - [TRAIN_OUTPUT_PATH] - [LABEL_BINARIZER_PATH] - -╭─ Arguments ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮ -│ input_path [INPUT_PATH] path to BioASQ JSON data [default: None] │ -│ train_output_path [TRAIN_OUTPUT_PATH] path to JSONL output file that will be generated for the train set [default: None] │ -│ label_binarizer_path [LABEL_BINARIZER_PATH] path to pickle file that will contain the label binarizer [default: None] │ -╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ -╭─ Options ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮ -│ --test-output-path TEXT path to JSONL output file that will be generated for the test set [default: None] │ -│ --mesh-tags-path TEXT path to mesh tags to filter [default: None] │ -│ --test-split FLOAT split percentage for test data. if None no split. [default: 0.01] │ -│ --filter-years TEXT years to keep in form min_year,max_year with both inclusive [default: None] │ -│ --config PATH path to config files that defines arguments [default: None] │ -│ --n-max INTEGER Maximum limit on the number of datapoints in the set (including training and test) [default: None] │ -│ --help Show this message and exit. │ -╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ -``` +in square brackets the commands that are not implemented yet ## 🔥 Train @@ -108,21 +76,86 @@ the BertMesh model. The command will train a model and save it to the specified ### bertmesh ``` - Usage: grants-tagger train bertmesh [OPTIONS] MODEL_KEY DATA_PATH - MODEL_SAVE_PATH -╭─ Arguments ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮ -│ * model_key TEXT Pretrained model key. Local path or HF location [default: None] [required] │ -│ * data_path TEXT Path to data in jsonl format. Must contain text and tags field [default: None] [required] │ -│ * model_save_path TEXT Path to save model to [default: None] [required] │ -╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ -╭─ Options ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮ -│ --help Show this message and exit. │ -╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +╭─ Arguments ──────────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ * model_key TEXT Pretrained model key. Local path or HF location [default: None] [required] │ +│ * data_path TEXT Path to allMeSH_2021.jsonl (or similar) or to a folder after preprocessing and saving │ +│ to disk │ +│ [default: None] │ +│ [required] │ +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +╭─ Options ────────────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ --test-size FLOAT Fraction of data to use for testing [default: 0.05] │ +│ --num-proc INTEGER Number of processes to use for preprocessing [default: 8] │ +│ --max-samples INTEGER Maximum number of samples to use from the json [default: -1] │ +│ --shards INTEGER Number os shards to divide training IterativeDataset to (improves performance) │ +│ [default: -1] │ +│ --help Show this message and exit. │ +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +``` +`model_key` possible values are: +- A HF location for a pretrained / finetuned model +- "" to load a model by default and train from scratch: `microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract` + +Besides those arguments, feel free to add any other TrainingArgument from Hugging Face or Wand DB. Examples: +```commandline +grants-tagger train bertmesh \ + "" \ + data/raw/allMeSH_2021.jsonl \ + --test-size 0.05 \ + --max-samples 1000 \ + --shards 100 \ + --output_dir bertmesh_outs/pipeline_test/ \ + --wandb_name test-train-all \ + --wandb_api_key ${WANDB_API_KEY} \ + --per_device_train_batch_size 256 \ + --per_device_eval_batch_size 8 \ + --num_train_epochs 1 \ + --evaluation_strategy steps \ + --eval_steps 100000 \ + --save_strategy steps \ + --save_steps 100000 \ + --fp16 \ + --torch_compile ``` +## ⚙️ Preprocess + +This process is optional to run, since it will be managed by the `grants-tagger train bertmesh` process. +If you run it manually, it will store the data in local. + +Preprocess creates a JSONL datafile with `text`, `tags` and `meta` as keys. +Text and tags are used for training whereas meta can be useful during annotation +or to analyse predictions and performance. Each dataset needs its own +preprocessing so the current preprocess works with the `allMeSH_2021` one. +If you want to use a different dataset see section on bringing +your own data under development. + + +### mesh +``` + Usage: grants-tagger preprocess mesh [OPTIONS] DATA_PATH SAVE_TO_PATH + MODEL_KEY + +╭─ Arguments ──────────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ * data_path TEXT Path to mesh.jsonl [default: None] [required] │ +│ * save_to_path TEXT Path to save the serialized PyArrow dataset after preprocessing [default: None] │ +│ [required] │ +│ * model_key TEXT Key to use when loading tokenizer and label2id. Leave blank if training from │ +│ scratch │ +│ [default: None] │ +│ [required] │ +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +╭─ Options ────────────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ --test-size FLOAT Fraction of data to use for testing [default: 0.05] │ +│ --num-proc INTEGER Number of processes to use for preprocessing [default: 8] │ +│ --max-samples INTEGER Maximum number of samples to use for preprocessing [default: -1] │ +│ --batch-size INTEGER Size of the preprocessing batch [default: 256] │ +│ --help Show this message and exit. │ +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯``` +``` ## 📈 Evaluate @@ -143,7 +176,6 @@ not made into production this is the way to evaluate. The plan is to extend evaluate to all models when train starts training explicit model approaches. ``` - Usage: grants-tagger evaluate model [OPTIONS] MODEL_PATH DATA_PATH ╭─ Arguments ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮ @@ -163,7 +195,6 @@ evaluate to all models when train starts training explicit model approaches. ### grants Evaluate an xlinear model on grants data. ``` - Usage: grants-tagger evaluate grants [OPTIONS] MODEL_PATH DATA_PATH LABEL_BINARIZER_PATH @@ -188,7 +219,6 @@ Predict assigns tags on a given abstract text that you can pass as argument. ``` - Usage: grants-tagger predict [OPTIONS] TEXT MODEL_PATH ╭─ Arguments ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮ @@ -201,7 +231,6 @@ Predict assigns tags on a given abstract text that you can pass as argument. │ --threshold FLOAT [default: 0.5] │ │ --help Show this message and exit. │ ╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ - ``` ## 🎛 Tune @@ -209,7 +238,6 @@ Optimise the threshold used for tag decisions. ### threshold ``` - Usage: grants-tagger tune threshold [OPTIONS] DATA_PATH MODEL_PATH LABEL_BINARIZER_PATH THRESHOLDS_PATH @@ -230,12 +258,14 @@ Optimise the threshold used for tag decisions. ## ⬇️ Download -This commands enables you to download mesh data from EPMC +The project has references to `dvc` big files. You can just do `dvc pull` and retrieve those, +including `allMeSH_2021.json` and `allMeSH_2021.jsonl` to train `bertmesh`. + +Also, this commands enables you to download mesh data from EPMC ### epmc-mesh ``` - Usage: grants-tagger download epmc-mesh [OPTIONS] DOWNLOAD_PATH ╭─ Arguments ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮ @@ -254,12 +284,18 @@ Install development dependencies via: ## 📋 Env variables +Variable | Required for | Description +--------------------- |--------------| ---------- +WANDB_API_KEY | train | key to dump the results to Weights&Biases +AWS_ACCESS_KEY_ID | train | access key to pull data from dvc on S3 +AWS_SECRET_ACCESS_KEY | train | secret key to pull data from dvc on S3 + If you want to participate to BIOASQ competition you need to set some variables. Variable | Required for | Description --------------------- | ------------------ | ---------- BIOASQ_USERNAME | bioasq | username with which registered in BioASQ -BIOASQ_PASSWORD | bioasq | password --//-- +BIOASQ_PASSWORD | bioasq | password If you use [direnv](https://direnv.net) then you can use it to populate your `.envrc` which will export the variables automatically, otherwise From 4db6fcf1e0a077bedad85f2fba53722043d52191 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 24 Jul 2023 18:27:29 +0100 Subject: [PATCH 097/300] Updates the Readme --- README.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 20f700c7..b5851cfa 100644 --- a/README.md +++ b/README.md @@ -90,15 +90,22 @@ the BertMesh model. The command will train a model and save it to the specified │ --num-proc INTEGER Number of processes to use for preprocessing [default: 8] │ │ --max-samples INTEGER Maximum number of samples to use from the json [default: -1] │ │ --shards INTEGER Number os shards to divide training IterativeDataset to (improves performance) │ -│ [default: -1] │ +│ [default: -1, meaning no shards]. Recommended: 100 │ │ --help Show this message and exit. │ ╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ ``` +#### About `model_key` `model_key` possible values are: - A HF location for a pretrained / finetuned model - "" to load a model by default and train from scratch: `microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract` +#### About `sharding` +`sharding` was proposed by [Hugging Face](https://github.com/huggingface/datasets/issues/2252#issuecomment-825596467) +to improve performance on big datasets. To enable it: +- set shards to something bigger than 1 (Recommended: 100) + +#### Other arguments Besides those arguments, feel free to add any other TrainingArgument from Hugging Face or Wand DB. Examples: ```commandline grants-tagger train bertmesh \ From 8ddb664834a8bb95d2c86f488bb77dcfae2c8ba4 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 24 Jul 2023 18:33:32 +0100 Subject: [PATCH 098/300] Reformatting --- grants_tagger_light/preprocessing/preprocess_mesh.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 31c37a71..09a1ab9e 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -54,7 +54,6 @@ def preprocess_mesh( max_samples: int = -1, batch_size: int = 256, ): - if max_samples != -1: data_path = create_sample_file(data_path, max_samples) @@ -77,7 +76,6 @@ def preprocess_mesh( if "train" in dset: dset = dset["train"] - # Remove unused columns to save space & time dset = dset.remove_columns(["journal", "year", "pmid", "title"]) @@ -131,7 +129,6 @@ def preprocess_mesh( num_proc=num_proc, fn_kwargs={"label2id": label2id}, remove_columns=columns_to_remove, - ) logger.info("Time taken to encode labels: {}".format(time.time() - t1)) From 7c60c22149cf24bc7a8128c1f17484a3d8b45b10 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 28 Jul 2023 07:33:44 +0000 Subject: [PATCH 099/300] Fixes training multigpu --- grants_tagger_light/training/train.py | 4 +- grants_tagger_light/utils/sharding.py | 31 ++++- poetry.lock | 173 ++++++-------------------- scripts/train.sh | 13 +- 4 files changed, 80 insertions(+), 141 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 2f45b040..8b94379c 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -74,6 +74,7 @@ def train_bertmesh( train_dset, val_dset = dset["train"], dset["test"] train_dset_size = len(train_dset) + logger.info(f"Training dataset size: {train_dset_size}") if max_samples > 0: train_dset_size = min(max_samples, train_dset_size) logger.info(f"Training max samples: {train_dset_size}.") @@ -129,6 +130,7 @@ def train_bertmesh( train_dset, val_dset = dset["train"], dset["test"] train_dset_size = len(train_dset) + logger.info(f"Training dataset size: {train_dset_size}") if max_samples > 0: train_dset_size = min(max_samples, train_dset_size) logger.info(f"Training max samples: {train_dset_size}.") @@ -187,7 +189,7 @@ def sklearn_metrics(prediction: EvalPrediction): data_collator=collator, compute_metrics=sklearn_metrics, ) - + logger.info(training_args) logger.info("Training...") trainer.train() diff --git a/grants_tagger_light/utils/sharding.py b/grants_tagger_light/utils/sharding.py index c4c8c3c3..c7250a3a 100644 --- a/grants_tagger_light/utils/sharding.py +++ b/grants_tagger_light/utils/sharding.py @@ -1,4 +1,6 @@ from datasets import IterableDataset +from loguru import logger +import math class Sharding: @@ -41,6 +43,31 @@ def calculate_max_steps(training_args, train_dset_size): > 1847649.89375 / 1 = 1847649.89375 """ # noqa + # From https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_tf.html train_batch_size = training_args.per_device_train_batch_size - accumulation_steps = training_args.gradient_accumulation_steps - return (train_dset_size / train_batch_size) / accumulation_steps + logger.info(f"train_dset_size: {train_dset_size}") + logger.info(f"train_batch_size: {train_batch_size}") + + num_update_steps_per_epoch = train_dset_size / train_batch_size + num_update_steps_per_epoch = math.ceil(num_update_steps_per_epoch) + logger.info(f"num_update_steps_per_epoch: {num_update_steps_per_epoch}") + + accumulation_steps = max(training_args.gradient_accumulation_steps, 1) + logger.info(f"accumulation steps: {accumulation_steps}") + + num_update_steps_per_epoch /= accumulation_steps + num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) + logger.info(f"num_update_steps_per_epoch: {num_update_steps_per_epoch}") + + epochs = max(training_args.num_train_epochs, 1) + logger.info(f"epochs: {epochs}") + + gpus = max(training_args._n_gpu, 1) + logger.info(f"gpus using HF: {gpus}") + + total_steps = num_update_steps_per_epoch * epochs + logger.info(f"total_steps: {total_steps}") + + total_steps_per_gpu = math.ceil(total_steps / gpus) + logger.info(f"total_steps_per_gpu: {total_steps_per_gpu}") + return total_steps_per_gpu diff --git a/poetry.lock b/poetry.lock index 91b4bc19..72efc107 100644 --- a/poetry.lock +++ b/poetry.lock @@ -211,17 +211,6 @@ files = [ [package.dependencies] vine = ">=5.0.0" -[[package]] -name = "annotated-types" -version = "0.5.0" -description = "Reusable constraint types to use with typing.Annotated" -optional = false -python-versions = ">=3.7" -files = [ - {file = "annotated_types-0.5.0-py3-none-any.whl", hash = "sha256:58da39888f92c276ad970249761ebea80ba544b77acddaa1a4d6cf78287d45fd"}, - {file = "annotated_types-0.5.0.tar.gz", hash = "sha256:47cdc3490d9ac1506ce92c7aaa76c579dc3509ff11e098fc867e5130ab7be802"}, -] - [[package]] name = "antlr4-python3-runtime" version = "4.9.3" @@ -2472,135 +2461,55 @@ files = [ [[package]] name = "pydantic" -version = "2.0.3" -description = "Data validation using Python type hints" +version = "1.10.12" +description = "Data validation and settings management using python type hints" optional = false python-versions = ">=3.7" files = [ - {file = "pydantic-2.0.3-py3-none-any.whl", hash = "sha256:614eb3321eb600c81899a88fa9858b008e3c79e0d4f1b49ab1f516b4b0c27cfb"}, - {file = "pydantic-2.0.3.tar.gz", hash = "sha256:94f13e0dcf139a5125e88283fc999788d894e14ed90cf478bcc2ee50bd4fc630"}, + {file = "pydantic-1.10.12-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a1fcb59f2f355ec350073af41d927bf83a63b50e640f4dbaa01053a28b7a7718"}, + {file = "pydantic-1.10.12-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b7ccf02d7eb340b216ec33e53a3a629856afe1c6e0ef91d84a4e6f2fb2ca70fe"}, + {file = "pydantic-1.10.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8fb2aa3ab3728d950bcc885a2e9eff6c8fc40bc0b7bb434e555c215491bcf48b"}, + {file = "pydantic-1.10.12-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:771735dc43cf8383959dc9b90aa281f0b6092321ca98677c5fb6125a6f56d58d"}, + {file = "pydantic-1.10.12-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:ca48477862372ac3770969b9d75f1bf66131d386dba79506c46d75e6b48c1e09"}, + {file = "pydantic-1.10.12-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a5e7add47a5b5a40c49b3036d464e3c7802f8ae0d1e66035ea16aa5b7a3923ed"}, + {file = "pydantic-1.10.12-cp310-cp310-win_amd64.whl", hash = "sha256:e4129b528c6baa99a429f97ce733fff478ec955513630e61b49804b6cf9b224a"}, + {file = "pydantic-1.10.12-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b0d191db0f92dfcb1dec210ca244fdae5cbe918c6050b342d619c09d31eea0cc"}, + {file = "pydantic-1.10.12-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:795e34e6cc065f8f498c89b894a3c6da294a936ee71e644e4bd44de048af1405"}, + {file = "pydantic-1.10.12-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:69328e15cfda2c392da4e713443c7dbffa1505bc9d566e71e55abe14c97ddc62"}, + {file = "pydantic-1.10.12-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2031de0967c279df0d8a1c72b4ffc411ecd06bac607a212892757db7462fc494"}, + {file = "pydantic-1.10.12-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:ba5b2e6fe6ca2b7e013398bc7d7b170e21cce322d266ffcd57cca313e54fb246"}, + {file = "pydantic-1.10.12-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2a7bac939fa326db1ab741c9d7f44c565a1d1e80908b3797f7f81a4f86bc8d33"}, + {file = "pydantic-1.10.12-cp311-cp311-win_amd64.whl", hash = "sha256:87afda5539d5140cb8ba9e8b8c8865cb5b1463924d38490d73d3ccfd80896b3f"}, + {file = "pydantic-1.10.12-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:549a8e3d81df0a85226963611950b12d2d334f214436a19537b2efed61b7639a"}, + {file = "pydantic-1.10.12-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:598da88dfa127b666852bef6d0d796573a8cf5009ffd62104094a4fe39599565"}, + {file = "pydantic-1.10.12-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba5c4a8552bff16c61882db58544116d021d0b31ee7c66958d14cf386a5b5350"}, + {file = "pydantic-1.10.12-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c79e6a11a07da7374f46970410b41d5e266f7f38f6a17a9c4823db80dadf4303"}, + {file = "pydantic-1.10.12-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ab26038b8375581dc832a63c948f261ae0aa21f1d34c1293469f135fa92972a5"}, + {file = "pydantic-1.10.12-cp37-cp37m-win_amd64.whl", hash = "sha256:e0a16d274b588767602b7646fa05af2782576a6cf1022f4ba74cbb4db66f6ca8"}, + {file = "pydantic-1.10.12-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6a9dfa722316f4acf4460afdf5d41d5246a80e249c7ff475c43a3a1e9d75cf62"}, + {file = "pydantic-1.10.12-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a73f489aebd0c2121ed974054cb2759af8a9f747de120acd2c3394cf84176ccb"}, + {file = "pydantic-1.10.12-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b30bcb8cbfccfcf02acb8f1a261143fab622831d9c0989707e0e659f77a18e0"}, + {file = "pydantic-1.10.12-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2fcfb5296d7877af406ba1547dfde9943b1256d8928732267e2653c26938cd9c"}, + {file = "pydantic-1.10.12-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:2f9a6fab5f82ada41d56b0602606a5506aab165ca54e52bc4545028382ef1c5d"}, + {file = "pydantic-1.10.12-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:dea7adcc33d5d105896401a1f37d56b47d443a2b2605ff8a969a0ed5543f7e33"}, + {file = "pydantic-1.10.12-cp38-cp38-win_amd64.whl", hash = "sha256:1eb2085c13bce1612da8537b2d90f549c8cbb05c67e8f22854e201bde5d98a47"}, + {file = "pydantic-1.10.12-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ef6c96b2baa2100ec91a4b428f80d8f28a3c9e53568219b6c298c1125572ebc6"}, + {file = "pydantic-1.10.12-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6c076be61cd0177a8433c0adcb03475baf4ee91edf5a4e550161ad57fc90f523"}, + {file = "pydantic-1.10.12-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d5a58feb9a39f481eda4d5ca220aa8b9d4f21a41274760b9bc66bfd72595b86"}, + {file = "pydantic-1.10.12-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5f805d2d5d0a41633651a73fa4ecdd0b3d7a49de4ec3fadf062fe16501ddbf1"}, + {file = "pydantic-1.10.12-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:1289c180abd4bd4555bb927c42ee42abc3aee02b0fb2d1223fb7c6e5bef87dbe"}, + {file = "pydantic-1.10.12-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5d1197e462e0364906cbc19681605cb7c036f2475c899b6f296104ad42b9f5fb"}, + {file = "pydantic-1.10.12-cp39-cp39-win_amd64.whl", hash = "sha256:fdbdd1d630195689f325c9ef1a12900524dceb503b00a987663ff4f58669b93d"}, + {file = "pydantic-1.10.12-py3-none-any.whl", hash = "sha256:b749a43aa51e32839c9d71dc67eb1e4221bb04af1033a32e3923d46f9effa942"}, + {file = "pydantic-1.10.12.tar.gz", hash = "sha256:0fe8a415cea8f340e7a9af9c54fc71a649b43e8ca3cc732986116b3cb135d303"}, ] [package.dependencies] -annotated-types = ">=0.4.0" -pydantic-core = "2.3.0" -typing-extensions = ">=4.6.1" +typing-extensions = ">=4.2.0" [package.extras] -email = ["email-validator (>=2.0.0)"] - -[[package]] -name = "pydantic-core" -version = "2.3.0" -description = "" -optional = false -python-versions = ">=3.7" -files = [ - {file = "pydantic_core-2.3.0-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:4542c98b8364b976593703a2dda97377433b102f380b61bc3a2cbc2fbdae1d1f"}, - {file = "pydantic_core-2.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9342de50824b40f55d2600f66c6f9a91a3a24851eca39145a749a3dc804ee599"}, - {file = "pydantic_core-2.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:539432f911686cb80284c30b33eaf9f4fd9a11e1111fe0dc98fdbdce69b49821"}, - {file = "pydantic_core-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38a0e7ee65c8999394d92d9c724434cb629279d19844f2b69d9bbc46dc8b8b61"}, - {file = "pydantic_core-2.3.0-cp310-cp310-manylinux_2_24_armv7l.whl", hash = "sha256:e3ed6834cc005798187a56c248a2240207cb8ffdda1c89e9afda4c3d526c2ea0"}, - {file = "pydantic_core-2.3.0-cp310-cp310-manylinux_2_24_ppc64le.whl", hash = "sha256:e72ac299a6bf732a60852d052acf3999d234686755a02ba111e85e7ebf8155b1"}, - {file = "pydantic_core-2.3.0-cp310-cp310-manylinux_2_24_s390x.whl", hash = "sha256:616b3451b05ca63b8f433c627f68046b39543faeaa4e50d8c6699a2a1e4b85a5"}, - {file = "pydantic_core-2.3.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:adcb9c8848e15c613e483e0b99767ae325af27fe0dbd866df01fe5849d06e6e1"}, - {file = "pydantic_core-2.3.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:464bf799b422be662e5e562e62beeffc9eaa907d381a9d63a2556615bbda286d"}, - {file = "pydantic_core-2.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4638ebc17de08c2f3acba557efeb6f195c88b7299d8c55c0bb4e20638bbd4d03"}, - {file = "pydantic_core-2.3.0-cp310-none-win32.whl", hash = "sha256:9ff322c7e1030543d35d83bb521b69114d3d150750528d7757544f639def9ad6"}, - {file = "pydantic_core-2.3.0-cp310-none-win_amd64.whl", hash = "sha256:4824eb018f0a4680b1e434697a9bf3f41c7799b80076d06530cbbd212e040ccc"}, - {file = "pydantic_core-2.3.0-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:0aa429578e23885b3984c49d687cd05ab06f0b908ea1711a8bf7e503b7f97160"}, - {file = "pydantic_core-2.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:20d710c1f79af930b8891bcebd84096798e4387ab64023ef41521d58f21277d3"}, - {file = "pydantic_core-2.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:309f45d4d7481d6f09cb9e35c72caa0e50add4a30bb08c04c5fe5956a0158633"}, - {file = "pydantic_core-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1bcfb7be905aa849bd882262e1df3f75b564e2f708b4b4c7ad2d3deaf5410562"}, - {file = "pydantic_core-2.3.0-cp311-cp311-manylinux_2_24_armv7l.whl", hash = "sha256:85cd9c0af34e371390e3cb2f3a470b0b40cc07568c1e966c638c49062be6352d"}, - {file = "pydantic_core-2.3.0-cp311-cp311-manylinux_2_24_ppc64le.whl", hash = "sha256:37c5028cebdf731298724070838fb3a71ef1fbd201d193d311ac2cbdbca25a23"}, - {file = "pydantic_core-2.3.0-cp311-cp311-manylinux_2_24_s390x.whl", hash = "sha256:e4208f23f12d0ad206a07a489ef4cb15722c10b62774c4460ee4123250be938e"}, - {file = "pydantic_core-2.3.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c24465dd11b65c8510f251b095fc788c7c91481c81840112fe3f76c30793a455"}, - {file = "pydantic_core-2.3.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:3cd7ee8bbfab277ab56e272221886fd33a1b5943fbf45ae9195aa6a48715a8a0"}, - {file = "pydantic_core-2.3.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0fc7e0b056b66cc536e97ef60f48b3b289f6b3b62ac225afd4b22a42434617bf"}, - {file = "pydantic_core-2.3.0-cp311-none-win32.whl", hash = "sha256:4788135db4bd83a5edc3522b11544b013be7d25b74b155e08dd3b20cd6663bbb"}, - {file = "pydantic_core-2.3.0-cp311-none-win_amd64.whl", hash = "sha256:f93c867e5e85584a28c6a6feb6f2086d717266eb5d1210d096dd717b7f4dec04"}, - {file = "pydantic_core-2.3.0-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:73f62bb7fd862d9bcd886e10612bade6fe042eda8b47e8c129892bcfb7b45e84"}, - {file = "pydantic_core-2.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4d889d498fce64bfcd8adf1a78579a7f626f825cbeb2956a24a29b35f9a1df32"}, - {file = "pydantic_core-2.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d55e38a89ec2ae17b2fa7ffeda6b70f63afab1888bd0d57aaa7b7879760acb4"}, - {file = "pydantic_core-2.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1aefebb506bc1fe355d91d25f12bcdea7f4d7c2d9f0f6716dd025543777c99a5"}, - {file = "pydantic_core-2.3.0-cp312-cp312-manylinux_2_24_armv7l.whl", hash = "sha256:6441a29f42585f085db0c04cd0557d4cbbb46fa68a0972409b1cfe9f430280c1"}, - {file = "pydantic_core-2.3.0-cp312-cp312-manylinux_2_24_ppc64le.whl", hash = "sha256:47e8f034be31390a8f525431eb5e803a78ce7e2e11b32abf5361a972e14e6b61"}, - {file = "pydantic_core-2.3.0-cp312-cp312-manylinux_2_24_s390x.whl", hash = "sha256:ad814864aba263be9c83ada44a95f72d10caabbf91589321f95c29c902bdcff0"}, - {file = "pydantic_core-2.3.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9eff3837d447fccf2ac38c259b14ab9cbde700df355a45a1f3ff244d5e78f8b6"}, - {file = "pydantic_core-2.3.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:534f3f63c000f08050c6f7f4378bf2b52d7ba9214e9d35e3f60f7ad24a4d6425"}, - {file = "pydantic_core-2.3.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ef6a222d54f742c24f6b143aab088702db3a827b224e75b9dd28b38597c595fe"}, - {file = "pydantic_core-2.3.0-cp312-none-win32.whl", hash = "sha256:4e26944e64ecc1d7b19db954c0f7b471f3b141ec8e1a9f57cfe27671525cd248"}, - {file = "pydantic_core-2.3.0-cp312-none-win_amd64.whl", hash = "sha256:019c5c41941438570dfc7d3f0ae389b2425add1775a357ce1e83ed1434f943d6"}, - {file = "pydantic_core-2.3.0-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:27c1bbfb9d84a75cf33b7f19b53c29eb7ead99b235fce52aced5507174ab8f98"}, - {file = "pydantic_core-2.3.0-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:7cb496e934b71f1ade844ab91d6ccac78a3520e5df02fdb2357f85a71e541e69"}, - {file = "pydantic_core-2.3.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5af2d43b1978958d91351afbcc9b4d0cfe144c46c61740e82aaac8bb39ab1a4d"}, - {file = "pydantic_core-2.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4d3097c39d7d4e8dba2ef86de171dcccad876c36d8379415ba18a5a4d0533510"}, - {file = "pydantic_core-2.3.0-cp37-cp37m-manylinux_2_24_armv7l.whl", hash = "sha256:dd3b023f3317dbbbc775e43651ce1a31a9cea46216ad0b5be37afc18a2007699"}, - {file = "pydantic_core-2.3.0-cp37-cp37m-manylinux_2_24_ppc64le.whl", hash = "sha256:27babb9879bf2c45ed655d02639f4c30e2b9ef1b71ce59c2305bbf7287910a18"}, - {file = "pydantic_core-2.3.0-cp37-cp37m-manylinux_2_24_s390x.whl", hash = "sha256:2183a9e18cdc0de53bdaa1675f237259162abeb62d6ac9e527c359c1074dc55d"}, - {file = "pydantic_core-2.3.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c089d8e7f1b4db08b2f8e4107304eec338df046275dad432635a9be9531e2fc8"}, - {file = "pydantic_core-2.3.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:2f10aa5452b865818dd0137f568d443f5e93b60a27080a01aa4b7512c7ba13a3"}, - {file = "pydantic_core-2.3.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:f642313d559f9d9a00c4de6820124059cc3342a0d0127b18301de2c680d5ea40"}, - {file = "pydantic_core-2.3.0-cp37-none-win32.whl", hash = "sha256:45327fc57afbe3f2c3d7f54a335d5cecee8a9fdb3906a2fbed8af4092f4926df"}, - {file = "pydantic_core-2.3.0-cp37-none-win_amd64.whl", hash = "sha256:e427b66596a6441a5607dfc0085b47d36073f88da7ac48afd284263b9b99e6ce"}, - {file = "pydantic_core-2.3.0-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:0b3d781c71b8bfb621ef23b9c874933e2cd33237c1a65cc20eeb37437f8e7e18"}, - {file = "pydantic_core-2.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ad46027dbd5c1db87dc0b49becbe23093b143a20302028d387dae37ee5ef95f5"}, - {file = "pydantic_core-2.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39aa09ed7ce2a648c904f79032d16dda29e6913112af8465a7bf710eef23c7ca"}, - {file = "pydantic_core-2.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:05b4bf8c58409586a7a04c858a86ab10f28c6c1a7c33da65e0326c59d5b0ab16"}, - {file = "pydantic_core-2.3.0-cp38-cp38-manylinux_2_24_armv7l.whl", hash = "sha256:ba2b807d2b62c446120906b8580cddae1d76d3de4efbb95ccc87f5e35c75b4b2"}, - {file = "pydantic_core-2.3.0-cp38-cp38-manylinux_2_24_ppc64le.whl", hash = "sha256:ea955e4ed21f4bbb9b83fea09fc6af0bed82e69ecf6b35ec89237a0a49633033"}, - {file = "pydantic_core-2.3.0-cp38-cp38-manylinux_2_24_s390x.whl", hash = "sha256:06884c07956526ac9ebfef40fe21a11605569b8fc0e2054a375fb39c978bf48f"}, - {file = "pydantic_core-2.3.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f868e731a18b403b88aa434d960489ceeed0ddeb44ebc02389540731a67705e0"}, - {file = "pydantic_core-2.3.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:cb08fab0fc1db15c277b72e33ac74ad9c0c789413da8984a3eacb22a94b42ef4"}, - {file = "pydantic_core-2.3.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6ca34c29fbd6592de5fd39e80c1993634d704c4e7e14ba54c87b2c7c53da68fe"}, - {file = "pydantic_core-2.3.0-cp38-none-win32.whl", hash = "sha256:cd782807d35c8a41aaa7d30b5107784420eefd9fdc1c760d86007d43ae00b15d"}, - {file = "pydantic_core-2.3.0-cp38-none-win_amd64.whl", hash = "sha256:01f56d5ee70b1d39c0fd08372cc5142274070ab7181d17c86035f130eebc05b8"}, - {file = "pydantic_core-2.3.0-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:78b1ac0151271ce62bc2b33755f1043eda6a310373143a2f27e2bcd3d5fc8633"}, - {file = "pydantic_core-2.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:64bfd2c35a2c350f73ac52dc134d8775f93359c4c969280a6fe5301b5b6e7431"}, - {file = "pydantic_core-2.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:937c0fe9538f1212b62df6a68f8d78df3572fe3682d9a0dd8851eac8a4e46063"}, - {file = "pydantic_core-2.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4d965c7c4b40d1cedec9188782e98bd576f9a04868835604200c3a6e817b824f"}, - {file = "pydantic_core-2.3.0-cp39-cp39-manylinux_2_24_armv7l.whl", hash = "sha256:ad442b8585ed4a3c2d22e4bf7b465d9b7d281e055b09719a8aeb5b576422dc9b"}, - {file = "pydantic_core-2.3.0-cp39-cp39-manylinux_2_24_ppc64le.whl", hash = "sha256:4bf20c9722821fce766e685718e739deeccc60d6bc7be5029281db41f999ee0c"}, - {file = "pydantic_core-2.3.0-cp39-cp39-manylinux_2_24_s390x.whl", hash = "sha256:f3dd5333049b5b3faa739e0f40b77cc8b7a1aded2f2da0e28794c81586d7b08a"}, - {file = "pydantic_core-2.3.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0dc5f516b24d24bc9e8dd9305460899f38302b3c4f9752663b396ef9848557bf"}, - {file = "pydantic_core-2.3.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:055f7ea6b1fbb37880d66d70eefd22dd319b09c79d2cb99b1dbfeb34b653b0b2"}, - {file = "pydantic_core-2.3.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:af693a89db6d6ac97dd84dd7769b3f2bd9007b578127d0e7dda03053f4d3b34b"}, - {file = "pydantic_core-2.3.0-cp39-none-win32.whl", hash = "sha256:f60e31e3e15e8c294bf70c60f8ae4d0c3caf3af8f26466e9aa8ea4c01302749b"}, - {file = "pydantic_core-2.3.0-cp39-none-win_amd64.whl", hash = "sha256:2b79f3681481f4424d7845cc7a261d5a4baa810d656b631fa844dc9967b36a7b"}, - {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:a666134b41712e30a71afaa26deeb4da374179f769fa49784cdf0e7698880fab"}, - {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c119e9227487ad3d7c3c737d896afe548a6be554091f9745da1f4b489c40561"}, - {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:73929a2fb600a2333fce2efd92596cff5e6bf8946e20e93c067b220760064862"}, - {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:41bbc2678a5b6a19371b2cb51f30ccea71f0c14b26477d2d884fed761cea42c7"}, - {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:dcbff997f47d45bf028bda4c3036bb3101e89a3df271281d392b6175f71c71d1"}, - {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:afa8808159169368b66e4fbeafac6c6fd8f26246dc4d0dcc2caf94bd9cf1b828"}, - {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:12be3b5f54f8111ca38e6b7277f26c23ba5cb3344fae06f879a0a93dfc8b479e"}, - {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:ed5babdcd3d052ba5cf8832561f18df20778c7ccf12587b2d82f7bf3bf259a0e"}, - {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d642e5c029e2acfacf6aa0a7a3e822086b3b777c70d364742561f9ca64c1ffc"}, - {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ba3073eb38a1294e8c7902989fb80a7a147a69db2396818722bd078476586a0"}, - {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d5146a6749b1905e04e62e0ad4622f079e5582f8b3abef5fb64516c623127908"}, - {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:deeb64335f489c3c11949cbd1d1668b3f1fb2d1c6a5bf40e126ef7bf95f9fa40"}, - {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:31acc37288b8e69e4849f618c3d5cf13b58077c1a1ff9ade0b3065ba974cd385"}, - {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:e09d9f6d722de9d4c1c5f122ea9bc6b25a05f975457805af4dcab7b0128aacbf"}, - {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:ba6a8cf089222a171b8f84e6ec2d10f7a9d14f26be3a347b14775a8741810676"}, - {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ef1fd1b24e9bcddcb168437686677104e205c8e25b066e73ffdf331d3bb8792b"}, - {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eda1a89c4526826c0a87d33596a4cd15b8f58e9250f503e39af1699ba9c878e8"}, - {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a3e9a18401a28db4358da2e191508702dbf065f2664c710708cdf9552b9fa50c"}, - {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:a439fd0d45d51245bbde799726adda5bd18aed3fa2b01ab2e6a64d6d13776fa3"}, - {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:bf6a1d2c920cc9528e884850a4b2ee7629e3d362d5c44c66526d4097bbb07a1a"}, - {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e33fcbea3b63a339dd94de0fc442fefacfe681cc7027ce63f67af9f7ceec7422"}, - {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:bf3ed993bdf4754909f175ff348cf8f78d4451215b8aa338633f149ca3b1f37a"}, - {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7584171eb3115acd4aba699bc836634783f5bd5aab131e88d8eeb8a3328a4a72"}, - {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1624baa76d1740711b2048f302ae9a6d73d277c55a8c3e88b53b773ebf73a971"}, - {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:06f33f695527f5a86e090f208978f9fd252c9cfc7e869d3b679bd71f7cb2c1fa"}, - {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:7ecf0a67b212900e92f328181fed02840d74ed39553cdb38d27314e2b9c89dfa"}, - {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:45fa1e8ad6f4367ad73674ca560da8e827cc890eaf371f3ee063d6d7366a207b"}, - {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:8d0dbcc57839831ae79fd24b1b83d42bc9448d79feaf3ed3fb5cbf94ffbf3eb7"}, - {file = "pydantic_core-2.3.0.tar.gz", hash = "sha256:5cfb5ac4e82c47d5dc25b209dd4c3989e284b80109f9e08b33c895080c424b4f"}, -] - -[package.dependencies] -typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" +dotenv = ["python-dotenv (>=0.10.4)"] +email = ["email-validator (>=1.0.3)"] [[package]] name = "pydot" diff --git a/scripts/train.sh b/scripts/train.sh index cb7f2492..f092ad59 100644 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -1,20 +1,21 @@ # Run on p2.8xlarge instance -grants-tagger train bertmesh \ +CUDA_VISIBLE_DEVICES=0 grants-tagger train bertmesh \ "" \ - data/raw/allMeSH_2021.jsonl \ + kk/1.json \ + --report_to None --test-size 0.05 \ - --max-samples 1000 \ - --shards 100 \ + --shards 250 \ --output_dir bertmesh_outs/pipeline_test/ \ --wandb_name test-train-all \ --wandb_api_key ${WANDB_API_KEY} \ - --per_device_train_batch_size 256 \ + --per_device_train_batch_size 12 \ --per_device_eval_batch_size 8 \ --num_train_epochs 1 \ --evaluation_strategy steps \ --eval_steps 100000 \ --save_strategy steps \ - --save_steps 100000 \ + --save_steps 100000 + --ddp_backend gloo --fp16 \ --torch_compile From f3254e9408b2ada50da72b99df5e0c82c281062c Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 28 Jul 2023 07:47:31 +0000 Subject: [PATCH 100/300] Fixes train.sh --- scripts/train.sh | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/scripts/train.sh b/scripts/train.sh index f092ad59..55bafa69 100644 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -1,5 +1,5 @@ # Run on p2.8xlarge instance -CUDA_VISIBLE_DEVICES=0 grants-tagger train bertmesh \ +grants-tagger train bertmesh \ "" \ kk/1.json \ --report_to None @@ -8,14 +8,12 @@ CUDA_VISIBLE_DEVICES=0 grants-tagger train bertmesh \ --output_dir bertmesh_outs/pipeline_test/ \ --wandb_name test-train-all \ --wandb_api_key ${WANDB_API_KEY} \ - --per_device_train_batch_size 12 \ + --per_device_train_batch_size 32 \ --per_device_eval_batch_size 8 \ --num_train_epochs 1 \ --evaluation_strategy steps \ --eval_steps 100000 \ --save_strategy steps \ --save_steps 100000 - --ddp_backend gloo --fp16 \ --torch_compile - From c791f8d0cb645f2077c384f6ea5968e262de944a Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 28 Jul 2023 09:05:29 +0100 Subject: [PATCH 101/300] Merges duplicated code --- grants_tagger_light/training/train.py | 114 +++++++++----------------- 1 file changed, 38 insertions(+), 76 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 8b94379c..479a33c6 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -40,6 +40,44 @@ def train_bertmesh( num_proc: int = os.cpu_count(), shards: int = -1, ): + logger.info(f"Preprocessing the dataset at {data_path}...") + if os.path.isdir(data_path): + logger.info( + "Folder found, which means you preprocessed and " + "save the data before. Loading from disk..." + ) + dset = load_from_disk(os.path.join(data_path, "dataset")) + with open(os.path.join(data_path, "label2id"), "r") as f: + label2id = json.load(f) + else: + logger.info("Preprocessing the data on the fly...") + dset, label2id = preprocess_mesh( + data_path=data_path, + model_key=model_key, + test_size=test_size, + num_proc=num_proc, + max_samples=max_samples, + batch_size=training_args.per_device_train_batch_size, + ) + + id2label = + + train_dset, val_dset = dset["train"], dset["test"] + train_dset_size = len(train_dset) + logger.info(f"Training dataset size: {train_dset_size}") + if max_samples > 0: + train_dset_size = min(max_samples, train_dset_size) + logger.info(f"Training max samples: {train_dset_size}.") + train_dset.filter( + lambda example, idx: idx < train_dset_size, with_indices=True + ) + else: + logger.info("Training with all data...") + + if shards > 0: + logger.info("Sharding training dataset...") + train_dset = Sharding(num_shards=shards).shard(train_dset) + if not model_key: assert isinstance(model_args, BertMeshModelArguments), ( "If model_key is not provided, " @@ -52,42 +90,6 @@ def train_bertmesh( logger.info(f"Loading `{model_args.pretrained_model_key}` tokenizer...") config = AutoConfig.from_pretrained(model_args.pretrained_model_key) - logger.info(f"Preprocessing the dataset at {data_path}...") - if os.path.isdir(data_path): - logger.info( - "Folder found, which means you preprocessed and " - "save the data before. Loading from disk..." - ) - dset = load_from_disk(os.path.join(data_path, "dataset")) - with open(os.path.join(data_path, "label2id"), "r") as f: - label2id = json.load(f) - else: - logger.info("Preprocessing the data on the fly...") - dset, label2id = preprocess_mesh( - data_path=data_path, - model_key=model_key, - test_size=test_size, - num_proc=num_proc, - max_samples=max_samples, - batch_size=training_args.per_device_train_batch_size, - ) - - train_dset, val_dset = dset["train"], dset["test"] - train_dset_size = len(train_dset) - logger.info(f"Training dataset size: {train_dset_size}") - if max_samples > 0: - train_dset_size = min(max_samples, train_dset_size) - logger.info(f"Training max samples: {train_dset_size}.") - train_dset.filter( - lambda example, idx: idx < train_dset_size, with_indices=True - ) - else: - logger.info("Training with all data...") - - if shards > 0: - logger.info("Sharding training dataset...") - train_dset = Sharding(num_shards=shards).shard(train_dset) - config.update( { "pretrained_model": model_args.pretrained_model_key, @@ -108,38 +110,6 @@ def train_bertmesh( logger.info(f"Loading `{model_key}` tokenizer...") model = BertMesh.from_pretrained(model_key, trust_remote_code=True) - logger.info(f"Preprocessing the dataset at {data_path}...") - if os.path.isdir(data_path): - logger.info( - "Folder found, which means you preprocessed and " - "save the data before. Loading from disk..." - ) - dset = load_from_disk(os.path.join(data_path, "dataset")) - with open(os.path.join(data_path, "label2id"), "r") as f: - label2id = json.load(f) - else: - logger.info("Preprocessing the data on the fly...") - dset, label2id = preprocess_mesh( - data_path=data_path, - model_key=model_key, - test_size=test_size, - num_proc=num_proc, - max_samples=max_samples, - batch_size=training_args.per_device_train_batch_size, - ) - - train_dset, val_dset = dset["train"], dset["test"] - train_dset_size = len(train_dset) - logger.info(f"Training dataset size: {train_dset_size}") - if max_samples > 0: - train_dset_size = min(max_samples, train_dset_size) - logger.info(f"Training max samples: {train_dset_size}.") - train_dset.filter( - lambda example, idx: idx < train_dset_size, with_indices=True - ) - else: - logger.info("Training with all data...") - if model_args.freeze_backbone: logger.info("Freezing backbone") model.freeze_backbone() @@ -173,14 +143,6 @@ def sklearn_metrics(prediction: EvalPrediction): max_steps = Sharding.calculate_max_steps(training_args, train_dset_size) training_args.max_steps = max_steps - logger.info( - f"Initializing Trainer:\n" - f"* per_device_train_batch_size=" - f"{training_args.per_device_train_batch_size}\n" - f"* max_steps = {training_args.max_steps}\n" - f"* epochs = {training_args.num_train_epochs}\n" - ) - trainer = Trainer( model=model, args=training_args, From f571f58bbc2a230aad8fb92b34002b2e1897eb99 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 28 Jul 2023 09:08:36 +0100 Subject: [PATCH 102/300] Adds label2id to config --- grants_tagger_light/training/train.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 479a33c6..0045011f 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -40,6 +40,13 @@ def train_bertmesh( num_proc: int = os.cpu_count(), shards: int = -1, ): + if not model_key: + assert isinstance(model_args, BertMeshModelArguments), ( + "If model_key is not provided, " + "must provide model_args of type BertMeshModelArguments" + ) # noqa + exit(-1) + logger.info(f"Preprocessing the dataset at {data_path}...") if os.path.isdir(data_path): logger.info( @@ -60,8 +67,6 @@ def train_bertmesh( batch_size=training_args.per_device_train_batch_size, ) - id2label = - train_dset, val_dset = dset["train"], dset["test"] train_dset_size = len(train_dset) logger.info(f"Training dataset size: {train_dset_size}") @@ -79,11 +84,6 @@ def train_bertmesh( train_dset = Sharding(num_shards=shards).shard(train_dset) if not model_key: - assert isinstance(model_args, BertMeshModelArguments), ( - "If model_key is not provided, " - "must provide model_args of type BertMeshModelArguments" - ) # noqa - logger.info("No model key provided. Training model from scratch") # Instantiate model from scratch @@ -97,6 +97,7 @@ def train_bertmesh( "hidden_size": model_args.hidden_size, "dropout": model_args.dropout, "multilabel_attention": model_args.multilabel_attention, + "label2id": label2id, "id2label": {v: k for k, v in label2id.items()}, "freeze_backbone": model_args.freeze_backbone, } From c380c627c7c0ecdf8a72afaf9baac508ba097790 Mon Sep 17 00:00:00 2001 From: Juan Martinez Date: Fri, 28 Jul 2023 08:12:43 +0000 Subject: [PATCH 103/300] Fixes train.sh --- scripts/train.sh | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/scripts/train.sh b/scripts/train.sh index 55bafa69..09139275 100644 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -2,18 +2,19 @@ grants-tagger train bertmesh \ "" \ kk/1.json \ - --report_to None + --report_to none \ --test-size 0.05 \ --shards 250 \ --output_dir bertmesh_outs/pipeline_test/ \ - --wandb_name test-train-all \ - --wandb_api_key ${WANDB_API_KEY} \ --per_device_train_batch_size 32 \ --per_device_eval_batch_size 8 \ --num_train_epochs 1 \ --evaluation_strategy steps \ --eval_steps 100000 \ --save_strategy steps \ - --save_steps 100000 + --save_steps 100000 \ --fp16 \ - --torch_compile + --torch_compile \ + --report_to none + # --wandb_name test-train-all \ + # --wandb_api_key ${WANDB_API_KEY} \ From 8a8aa694453e816a230da077541d9c532c3d55c6 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 28 Jul 2023 09:13:18 +0100 Subject: [PATCH 104/300] Removes exit --- grants_tagger_light/training/train.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 0045011f..3651062f 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -45,7 +45,6 @@ def train_bertmesh( "If model_key is not provided, " "must provide model_args of type BertMeshModelArguments" ) # noqa - exit(-1) logger.info(f"Preprocessing the dataset at {data_path}...") if os.path.isdir(data_path): From 5e0e8a4f22e9c001a9859f45888a8f7c81e100b4 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 28 Jul 2023 09:24:30 +0100 Subject: [PATCH 105/300] Adds wandb param to train.sh --- scripts/train.sh | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/scripts/train.sh b/scripts/train.sh index 09139275..0b61fe75 100644 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -15,6 +15,7 @@ grants-tagger train bertmesh \ --save_steps 100000 \ --fp16 \ --torch_compile \ - --report_to none - # --wandb_name test-train-all \ - # --wandb_api_key ${WANDB_API_KEY} \ + --report_to none \ + --wandb_project wellcome-mesh \ + --wandb_name test-train-all \ + --wandb_api_key ${WANDB_API_KEY} From 2b1191d09a00a7d83174c02a3677455cef5ae552 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 28 Jul 2023 09:27:45 +0100 Subject: [PATCH 106/300] Adds report_to none for non-wandb --- scripts/train.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/scripts/train.sh b/scripts/train.sh index 0b61fe75..a39182c7 100644 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -2,7 +2,6 @@ grants-tagger train bertmesh \ "" \ kk/1.json \ - --report_to none \ --test-size 0.05 \ --shards 250 \ --output_dir bertmesh_outs/pipeline_test/ \ @@ -15,7 +14,7 @@ grants-tagger train bertmesh \ --save_steps 100000 \ --fp16 \ --torch_compile \ - --report_to none \ --wandb_project wellcome-mesh \ --wandb_name test-train-all \ --wandb_api_key ${WANDB_API_KEY} + # --report_to none \ From 0f53dfc2f3950145d05e7e913d9ab3ac45c7e608 Mon Sep 17 00:00:00 2001 From: Juan Martinez Date: Fri, 28 Jul 2023 11:02:45 +0000 Subject: [PATCH 107/300] Updates train.sh --- scripts/train.sh | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/scripts/train.sh b/scripts/train.sh index a39182c7..3e55bcfa 100644 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -2,19 +2,19 @@ grants-tagger train bertmesh \ "" \ kk/1.json \ - --test-size 0.05 \ + --test-size 0.005 \ --shards 250 \ --output_dir bertmesh_outs/pipeline_test/ \ --per_device_train_batch_size 32 \ - --per_device_eval_batch_size 8 \ --num_train_epochs 1 \ - --evaluation_strategy steps \ - --eval_steps 100000 \ --save_strategy steps \ - --save_steps 100000 \ + --save_steps 50000 \ --fp16 \ --torch_compile \ --wandb_project wellcome-mesh \ --wandb_name test-train-all \ - --wandb_api_key ${WANDB_API_KEY} - # --report_to none \ + --wandb_api_key ${WANDB_API_KEY} \ + --per_device_eval_batch_size 8 \ + --eval_steps 50000 \ + --evaluation_strategy steps + #--report_to none From df49e1e77b325292bd7800fe7b9ad47203beebe0 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 28 Jul 2023 12:04:48 +0100 Subject: [PATCH 108/300] Updates README.md --- README.md | 27 ++++++++++++--------------- scripts/train.sh | 2 +- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index b5851cfa..b8234915 100644 --- a/README.md +++ b/README.md @@ -111,21 +111,21 @@ Besides those arguments, feel free to add any other TrainingArgument from Huggin grants-tagger train bertmesh \ "" \ data/raw/allMeSH_2021.jsonl \ - --test-size 0.05 \ - --max-samples 1000 \ - --shards 100 \ + --test-size 0.005 \ + --shards 250 \ --output_dir bertmesh_outs/pipeline_test/ \ - --wandb_name test-train-all \ - --wandb_api_key ${WANDB_API_KEY} \ - --per_device_train_batch_size 256 \ - --per_device_eval_batch_size 8 \ + --per_device_train_batch_size 32 \ --num_train_epochs 1 \ - --evaluation_strategy steps \ - --eval_steps 100000 \ --save_strategy steps \ - --save_steps 100000 \ + --save_steps 50000 \ --fp16 \ - --torch_compile + --torch_compile \ + --wandb_project wellcome-mesh \ + --wandb_name test-train-all \ + --wandb_api_key ${WANDB_API_KEY} \ + --per_device_eval_batch_size 8 \ + --eval_steps 50000 \ + --evaluation_strategy steps ``` ## ⚙️ Preprocess @@ -133,10 +133,7 @@ grants-tagger train bertmesh \ This process is optional to run, since it will be managed by the `grants-tagger train bertmesh` process. If you run it manually, it will store the data in local. -Preprocess creates a JSONL datafile with `text`, `tags` and `meta` as keys. -Text and tags are used for training whereas meta can be useful during annotation -or to analyse predictions and performance. Each dataset needs its own -preprocessing so the current preprocess works with the `allMeSH_2021` one. +Each dataset needs its own preprocessing so the current preprocess works with the `allMeSH_2021` one. If you want to use a different dataset see section on bringing your own data under development. diff --git a/scripts/train.sh b/scripts/train.sh index 3e55bcfa..5bb4f2c9 100644 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -1,7 +1,7 @@ # Run on p2.8xlarge instance grants-tagger train bertmesh \ "" \ - kk/1.json \ + data/raw/allMeSH_2021.jsonl \ --test-size 0.005 \ --shards 250 \ --output_dir bertmesh_outs/pipeline_test/ \ From 6fae3bd04205b23e5b33b8edd3d0c6b22a15dd77 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 28 Jul 2023 12:06:48 +0100 Subject: [PATCH 109/300] Black reformatting --- grants_tagger_light/training/train.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 3651062f..a40d70ac 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -72,9 +72,7 @@ def train_bertmesh( if max_samples > 0: train_dset_size = min(max_samples, train_dset_size) logger.info(f"Training max samples: {train_dset_size}.") - train_dset.filter( - lambda example, idx: idx < train_dset_size, with_indices=True - ) + train_dset.filter(lambda example, idx: idx < train_dset_size, with_indices=True) else: logger.info("Training with all data...") From ecf213b42953be7b067ff6692becccf95a15d511 Mon Sep 17 00:00:00 2001 From: Juan Martinez Date: Mon, 31 Jul 2023 13:13:49 +0000 Subject: [PATCH 110/300] Modifies training --- grants_tagger_light/training/train.py | 2 ++ scripts/train.sh | 16 +++++++++------- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 3651062f..26006f63 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -159,6 +159,8 @@ def sklearn_metrics(prediction: EvalPrediction): metrics = trainer.evaluate(eval_dataset=val_dset) logger.info(pformat(metrics)) + with open(os.path.join(training_args.output_dir, "metrics"), 'w') as f: + f.write(pformat(metrics)) logger.info("Saving the model...") trainer.save_model(os.path.join(training_args.output_dir, "best")) diff --git a/scripts/train.sh b/scripts/train.sh index 3e55bcfa..277d0174 100644 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -2,19 +2,21 @@ grants-tagger train bertmesh \ "" \ kk/1.json \ - --test-size 0.005 \ + --test-size 0.05 \ --shards 250 \ --output_dir bertmesh_outs/pipeline_test/ \ --per_device_train_batch_size 32 \ --num_train_epochs 1 \ - --save_strategy steps \ - --save_steps 50000 \ --fp16 \ --torch_compile \ + --evaluation_strategy no \ + --save_strategy no \ --wandb_project wellcome-mesh \ --wandb_name test-train-all \ - --wandb_api_key ${WANDB_API_KEY} \ - --per_device_eval_batch_size 8 \ - --eval_steps 50000 \ - --evaluation_strategy steps + --wandb_api_key ${WANDB_API_KEY} + #--save_strategy steps + #--save_steps 50000 + #--per_device_eval_batch_size 8 \ + #--eval_steps 50000 \ + #--evaluation_strategy steps #--report_to none From 9fe49cd05a25dabffa8386d6bcc7eaa4ab713fe0 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 31 Jul 2023 14:48:08 +0100 Subject: [PATCH 111/300] Saving then evaluating --- grants_tagger_light/training/train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 187372e2..71f5e250 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -153,15 +153,15 @@ def sklearn_metrics(prediction: EvalPrediction): logger.info("Training...") trainer.train() + logger.info("Saving the model...") + trainer.save_model(os.path.join(training_args.output_dir, "best")) + logger.info("Evaluating...") metrics = trainer.evaluate(eval_dataset=val_dset) logger.info(pformat(metrics)) with open(os.path.join(training_args.output_dir, "metrics"), 'w') as f: - f.write(pformat(metrics)) - - logger.info("Saving the model...") - trainer.save_model(os.path.join(training_args.output_dir, "best")) + f.write(pformat(metrics)) train_app = typer.Typer() From b18dd732cc1fa721a03ce3848c6d961d949414fc Mon Sep 17 00:00:00 2001 From: Juan Martinez Date: Mon, 31 Jul 2023 16:44:24 +0000 Subject: [PATCH 112/300] Updates train.sh --- scripts/train.sh | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/scripts/train.sh b/scripts/train.sh index 7e545e43..4bb93513 100644 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -1,19 +1,23 @@ # Run on p2.8xlarge instance grants-tagger train bertmesh \ "" \ - data/raw/allMeSH_2021.jsonl \ - --test-size 0.005 \ - --shards 250 \ + kk \ + --test-size 0.0025 \ + --shards 48 \ --output_dir bertmesh_outs/pipeline_test/ \ --per_device_train_batch_size 32 \ + --per_device_eval_batch_size 1 \ --num_train_epochs 1 \ --fp16 \ --torch_compile \ - --evaluation_strategy no \ - --save_strategy no \ + --evaluation_strategy steps \ + --save_strategy steps \ + --eval_steps 50000 \ + --save_steps 50000 \ --wandb_project wellcome-mesh \ --wandb_name test-train-all \ - --wandb_api_key ${WANDB_API_KEY} + --wandb_api_key ${WANDB_API_KEY} \ + --eval_accumulation_steps 20 #--save_strategy steps #--save_steps 50000 #--per_device_eval_batch_size 8 \ From 336bb15c61e88932e344c6d932a61ba94310c568 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 31 Jul 2023 17:45:50 +0100 Subject: [PATCH 113/300] Updates train.sh --- scripts/train.sh | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/scripts/train.sh b/scripts/train.sh index 4bb93513..7ae6b1ef 100644 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -1,7 +1,7 @@ # Run on p2.8xlarge instance grants-tagger train bertmesh \ "" \ - kk \ + data/raw/allMeSH_2021.jsonl \ --test-size 0.0025 \ --shards 48 \ --output_dir bertmesh_outs/pipeline_test/ \ @@ -11,16 +11,10 @@ grants-tagger train bertmesh \ --fp16 \ --torch_compile \ --evaluation_strategy steps \ - --save_strategy steps \ --eval_steps 50000 \ + --eval_accumulation_steps 20 \ + --save_strategy steps \ --save_steps 50000 \ --wandb_project wellcome-mesh \ --wandb_name test-train-all \ - --wandb_api_key ${WANDB_API_KEY} \ - --eval_accumulation_steps 20 - #--save_strategy steps - #--save_steps 50000 - #--per_device_eval_batch_size 8 \ - #--eval_steps 50000 \ - #--evaluation_strategy steps - #--report_to none + --wandb_api_key ${WANDB_API_KEY} From 0da2c00c1f8223596751bddcf5990d0f89ce6658 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Tue, 1 Aug 2023 12:58:02 +0100 Subject: [PATCH 114/300] Black reformatting --- grants_tagger_light/training/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 71f5e250..5cab8f1b 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -160,7 +160,7 @@ def sklearn_metrics(prediction: EvalPrediction): metrics = trainer.evaluate(eval_dataset=val_dset) logger.info(pformat(metrics)) - with open(os.path.join(training_args.output_dir, "metrics"), 'w') as f: + with open(os.path.join(training_args.output_dir, "metrics"), "w") as f: f.write(pformat(metrics)) From 0ff389af3a03046b3a3d19fe6076affa879e5ed4 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 2 Aug 2023 11:28:14 +0100 Subject: [PATCH 115/300] Adding training resuming --- grants_tagger_light/training/train.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 5cab8f1b..fd229e5f 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -39,6 +39,7 @@ def train_bertmesh( test_size: float = 0.05, num_proc: int = os.cpu_count(), shards: int = -1, + from_checkpoint: str = None ): if not model_key: assert isinstance(model_args, BertMeshModelArguments), ( @@ -150,8 +151,13 @@ def sklearn_metrics(prediction: EvalPrediction): compute_metrics=sklearn_metrics, ) logger.info(training_args) - logger.info("Training...") - trainer.train() + + if from_checkpoint is None: + logger.info("Training...") + trainer.train() + else: + logger.info(f"Resuming training from checkpoint: {from_checkpoint}") + trainer.train(from_checkpoint) logger.info("Saving the model...") trainer.save_model(os.path.join(training_args.output_dir, "best")) @@ -191,6 +197,10 @@ def train_bertmesh_cli( help="Number os shards to divide training " "IterativeDataset to (improves performance)", ), + from_checkpoint: str = typer.Option( + None, + help="Name of the checkpoint to resume training" + ) ): parser = HfArgumentParser( ( @@ -217,4 +227,5 @@ def train_bertmesh_cli( test_size, num_proc, shards, + from_checkpoint ) From e791df7ee732ff14b7bb2d03c955d2e15db072fa Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 2 Aug 2023 14:54:39 +0100 Subject: [PATCH 116/300] Adding training resuming --- grants_tagger_light/training/train.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index fd229e5f..194d1b31 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -81,7 +81,10 @@ def train_bertmesh( logger.info("Sharding training dataset...") train_dset = Sharding(num_shards=shards).shard(train_dset) - if not model_key: + if from_checkpoint: + logger.info("Loading from checkpoint...") + model = BertMesh.from_pretrained(from_checkpoint).to("cuda") + elif not model_key: logger.info("No model key provided. Training model from scratch") # Instantiate model from scratch @@ -103,10 +106,7 @@ def train_bertmesh( model = BertMesh(config) else: - logger.info(f"Training model from pretrained key {model_key}") - - # Instantiate from pretrained - logger.info(f"Loading `{model_key}` tokenizer...") + logger.info(f"Training from pretrained key {model_key}") model = BertMesh.from_pretrained(model_key, trust_remote_code=True) if model_args.freeze_backbone: From fc90d250d9c6a6ddb1e85325af784d20e0844bd5 Mon Sep 17 00:00:00 2001 From: Juan Martinez Date: Wed, 2 Aug 2023 14:18:56 +0000 Subject: [PATCH 117/300] Adds resume_train.sh --- grants_tagger_light/training/train.py | 8 ++------ scripts/resume_train.sh | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+), 6 deletions(-) create mode 100644 scripts/resume_train.sh diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 194d1b31..46adf258 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -81,10 +81,7 @@ def train_bertmesh( logger.info("Sharding training dataset...") train_dset = Sharding(num_shards=shards).shard(train_dset) - if from_checkpoint: - logger.info("Loading from checkpoint...") - model = BertMesh.from_pretrained(from_checkpoint).to("cuda") - elif not model_key: + if not model_key: logger.info("No model key provided. Training model from scratch") # Instantiate model from scratch @@ -154,10 +151,9 @@ def sklearn_metrics(prediction: EvalPrediction): if from_checkpoint is None: logger.info("Training...") - trainer.train() else: logger.info(f"Resuming training from checkpoint: {from_checkpoint}") - trainer.train(from_checkpoint) + trainer.train() logger.info("Saving the model...") trainer.save_model(os.path.join(training_args.output_dir, "best")) diff --git a/scripts/resume_train.sh b/scripts/resume_train.sh new file mode 100644 index 00000000..57614b1a --- /dev/null +++ b/scripts/resume_train.sh @@ -0,0 +1,20 @@ +# Run on p2.8xlarge instance +grants-tagger train bertmesh \ + bertmesh_outs/pipeline_test/checkpoint-100000 \ + kk \ + --output_dir bertmesh_outs/pipeline_test_from_100000/ \ + --ignore_data_skip=True \ + --shards 48 \ + --per_device_train_batch_size 32 \ + --per_device_eval_batch_size 1 \ + --num_train_epochs 2 \ + --fp16 \ + --torch_compile \ + --evaluation_strategy steps \ + --eval_steps 50000 \ + --eval_accumulation_steps 20 \ + --save_strategy steps \ + --save_steps 50000 \ + --wandb_project wellcome-mesh \ + --wandb_name test-train-all \ + --wandb_api_key ${WANDB_API_KEY} From 3b43ce56891f789cb6c90f849e791aff400c7798 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 2 Aug 2023 15:25:32 +0100 Subject: [PATCH 118/300] Parametrization --- scripts/resume_train.sh | 18 ++++++++++++++---- scripts/train.sh | 11 +++++++++-- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/scripts/resume_train.sh b/scripts/resume_train.sh index 57614b1a..4a39c264 100644 --- a/scripts/resume_train.sh +++ b/scripts/resume_train.sh @@ -1,8 +1,18 @@ -# Run on p2.8xlarge instance +# Run on g5.12xlarge instance + +# Without preprocessing (on-the-fly) +SOURCE="data/raw/allMeSH_2021.jsonl" + +# After preprocessing first +# SOURCE="output_folder_from_preprocessing" + +# Checkpoint +CHECKPOINT="checkpoint-100000" + grants-tagger train bertmesh \ - bertmesh_outs/pipeline_test/checkpoint-100000 \ - kk \ - --output_dir bertmesh_outs/pipeline_test_from_100000/ \ + bertmesh_outs/pipeline_test/$CHECKPOINT \ + $SOURCE \ + --output_dir bertmesh_outs/pipeline_test_from_$CHECKPOINT/ \ --ignore_data_skip=True \ --shards 48 \ --per_device_train_batch_size 32 \ diff --git a/scripts/train.sh b/scripts/train.sh index 7ae6b1ef..88546eff 100644 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -1,7 +1,14 @@ -# Run on p2.8xlarge instance +# Run on g5.12xlargeinstance + +# Without preprocessing (on-the-fly) +SOURCE="data/raw/allMeSH_2021.jsonl" + +# After preprocessing first +# SOURCE="output_folder_from_preprocessing" + grants-tagger train bertmesh \ "" \ - data/raw/allMeSH_2021.jsonl \ + $SOURCE \ --test-size 0.0025 \ --shards 48 \ --output_dir bertmesh_outs/pipeline_test/ \ From 303a34a7963b987b52cc09496399a9fc206a24d5 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 4 Aug 2023 13:27:25 +0100 Subject: [PATCH 119/300] Updates documentation, script name and adds filter by years and tags to preprocessing_mesh.py --- README.md | 85 ++++++++++++------- .../preprocessing/preprocess_mesh.py | 35 +++++--- ...preprocessing.py => mesh_json_to_jsonl.py} | 1 + tests/test_preprocess_mesh.py | 2 +- 4 files changed, 76 insertions(+), 47 deletions(-) rename scripts/{jsonl_preprocessing.py => mesh_json_to_jsonl.py} (99%) diff --git a/README.md b/README.md index b8234915..51b3259f 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,56 @@ And then connect and attach to your machine with a tunnel in square brackets the commands that are not implemented yet +## ⚙️Preprocess + +This process is optional to run, since it will be managed by the `Train` process. +- If you run it manually, it will store the data in local first, which can help if you need finetune in the future, +rerun, etc. +- If not, the project will preprocess and then run, without any extra I/O operations on disk, +which may add latency depending on the infrastructure. + +It requires data in `jsonl` format for parallelization purposes. In `data/raw` you can find `allMesH_2021.jsonl` +already prepared for the preprocessing step. + +If your data is in `json` format, trasnform it to `jsonl` with tools as `jq` or using Python. +You can use an example of `allMeSH_2021.json` conversion to `jsonl` in `scripts/mesh_json_to_jsonl.py`: + +```bash +python scripts/mesh_json_to_jsonl.py --input_path data/raw/allMeSH_2021.json --output_path data/raw/test.jsonl --filter_years 2020,2021 +``` + +Each dataset needs its own preprocessing so the current preprocess works with the `allMeSH_2021.jsonl` one. + +If you want to use a different dataset see section on bringing +your own data under development. + + +### Preprocessing allMeSH + +``` + Usage: grants-tagger preprocess mesh [OPTIONS] DATA_PATH SAVE_TO_PATH + MODEL_KEY + +╭─ Arguments ──────────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ * data_path TEXT Path to mesh.jsonl [default: None] [required] │ +│ * save_to_path TEXT Path to save the serialized PyArrow dataset after preprocessing [default: None] │ +│ [required] │ +│ * model_key TEXT Key to use when loading tokenizer and label2id. Leave blank if training from │ +│ scratch │ +│ [default: None] │ +│ [required] │ +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +╭─ Options ────────────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ --test-size FLOAT Fraction of data to use for testing [default: 0.05] │ +│ --num-proc INTEGER Number of processes to use for preprocessing [default: 8] │ +│ --max-samples INTEGER Maximum number of samples to use for preprocessing [default: -1] │ +│ --batch-size INTEGER Size of the preprocessing batch [default: 256] │ +│ --years TEXT Comma-separated years you want to included (e.g: 2020,2021) [default: None] │ +│ --tags TEXT Comma-separated tags you want to included (e.g: Pandemics,COVID19) [default: None] │ +│ --help Show this message and exit. │ +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯``` +``` + ## 🔥 Train Train acts as the entry point command for training all models. Currently, we only support @@ -98,12 +148,12 @@ the BertMesh model. The command will train a model and save it to the specified #### About `model_key` `model_key` possible values are: - A HF location for a pretrained / finetuned model -- "" to load a model by default and train from scratch: `microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract` +- "" to load a model by default and train from scratch (`microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract`) #### About `sharding` `sharding` was proposed by [Hugging Face](https://github.com/huggingface/datasets/issues/2252#issuecomment-825596467) to improve performance on big datasets. To enable it: -- set shards to something bigger than 1 (Recommended: 100) +- set shards to something bigger than 1 (Recommended: same number as cpu cores) #### Other arguments Besides those arguments, feel free to add any other TrainingArgument from Hugging Face or Wand DB. Examples: @@ -128,39 +178,8 @@ grants-tagger train bertmesh \ --evaluation_strategy steps ``` -## ⚙️ Preprocess - -This process is optional to run, since it will be managed by the `grants-tagger train bertmesh` process. -If you run it manually, it will store the data in local. - -Each dataset needs its own preprocessing so the current preprocess works with the `allMeSH_2021` one. -If you want to use a different dataset see section on bringing -your own data under development. -### mesh -``` - Usage: grants-tagger preprocess mesh [OPTIONS] DATA_PATH SAVE_TO_PATH - MODEL_KEY - -╭─ Arguments ──────────────────────────────────────────────────────────────────────────────────────────────────────╮ -│ * data_path TEXT Path to mesh.jsonl [default: None] [required] │ -│ * save_to_path TEXT Path to save the serialized PyArrow dataset after preprocessing [default: None] │ -│ [required] │ -│ * model_key TEXT Key to use when loading tokenizer and label2id. Leave blank if training from │ -│ scratch │ -│ [default: None] │ -│ [required] │ -╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ -╭─ Options ────────────────────────────────────────────────────────────────────────────────────────────────────────╮ -│ --test-size FLOAT Fraction of data to use for testing [default: 0.05] │ -│ --num-proc INTEGER Number of processes to use for preprocessing [default: 8] │ -│ --max-samples INTEGER Maximum number of samples to use for preprocessing [default: -1] │ -│ --batch-size INTEGER Size of the preprocessing batch [default: 256] │ -│ --help Show this message and exit. │ -╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯``` -``` - ## 📈 Evaluate Evaluate enables evaluation of the performance of various approaches including diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 09a1ab9e..e4ad04d9 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -9,9 +9,8 @@ import os from loguru import logger from tqdm import tqdm +import numpy as np -# from datasets import disable_caching -# disable_caching() preprocess_app = typer.Typer() @@ -53,8 +52,12 @@ def preprocess_mesh( num_proc: int = os.cpu_count(), max_samples: int = -1, batch_size: int = 256, + tags: str = None, + years: str = None ): + if max_samples != -1: + logger.info(f"Filtering examples to {max_samples}") data_path = create_sample_file(data_path, max_samples) if not model_key: @@ -76,6 +79,17 @@ def preprocess_mesh( if "train" in dset: dset = dset["train"] + if years is not None: + logger.info(f"Filtering years: {years}") + filter_years_list = list(filter(lambda x: x.strip() != "", years.split(","))) + filter_years_list = [str(y) for y in filter_years_list] + dset = dset.filter(lambda x: any(np.isin(filter_years_list, [str(x["year"])]))) + + if tags is not None: + logger.info(f"Filtering tags: {tags}") + filter_tags_list = list(filter(lambda x: x.strip() != "", tags.split(","))) + dset = dset.filter(lambda x: any(np.isin(filter_tags_list, x["meshMajor"]))) + # Remove unused columns to save space & time dset = dset.remove_columns(["journal", "year", "pmid", "title"]) @@ -169,18 +183,11 @@ def preprocess_mesh_cli( help="Maximum number of samples to use for preprocessing", ), batch_size: int = typer.Option(256, help="Size of the preprocessing batch"), + tags: str = typer.Option(None, help="Comma-separated tags you want to include in the dataset " + "(the rest will be discarded)"), + years: str = typer.Option(None, help="Comma-separated yeasr you want to include in the dataset " + "(the rest will be discarded)"), ): - if ( - input( - "\033[96mRunning preprocessing will save the data as a PyArrow dataset " - "which is a very time consuming operation. If you don't need the data " - "to be saved, you can save much time just by running:\n" - "> `grants-tagger train bertmesh {model_key} {path_to_jsonl}`\033[0m\n\n" - "Do You Want To Continue? [Y/n]" - ) - != "Y" - ): - exit(1) if not data_path.endswith("jsonl"): logger.error( @@ -197,4 +204,6 @@ def preprocess_mesh_cli( max_samples=max_samples, batch_size=batch_size, save_to_path=save_to_path, + tags=tags, + years=years ) diff --git a/scripts/jsonl_preprocessing.py b/scripts/mesh_json_to_jsonl.py similarity index 99% rename from scripts/jsonl_preprocessing.py rename to scripts/mesh_json_to_jsonl.py index 3ad5148a..11d8efa3 100644 --- a/scripts/jsonl_preprocessing.py +++ b/scripts/mesh_json_to_jsonl.py @@ -15,6 +15,7 @@ def process_data(item, filter_tags: list = None, filter_years: list = None): if check_years and "year" not in item: logger.warning("`year` not found in the fields. Unable to filter tags.") check_years = False + if check_tags: if filter_tags is None: filter_tags = [] diff --git a/tests/test_preprocess_mesh.py b/tests/test_preprocess_mesh.py index f9d46295..a984b9a6 100644 --- a/tests/test_preprocess_mesh.py +++ b/tests/test_preprocess_mesh.py @@ -4,7 +4,7 @@ from grants_tagger_light.preprocessing.preprocess_mesh import ( preprocess_mesh, ) -from scripts.jsonl_preprocessing import process_data, mesh_json_to_jsonl +from scripts.mesh_json_to_jsonl import process_data, mesh_json_to_jsonl import pytest From 5c9876e12fd57ee5a1d86d7274fd48b705a77ebf Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 4 Aug 2023 14:29:41 +0100 Subject: [PATCH 120/300] Adding train/test split using years --- .../preprocessing/preprocess_mesh.py | 66 +++++++++++++++---- 1 file changed, 52 insertions(+), 14 deletions(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index e4ad04d9..8670795a 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -1,4 +1,5 @@ import json +import math import tempfile import typer @@ -10,7 +11,7 @@ from loguru import logger from tqdm import tqdm import numpy as np - +from datasets.dataset_dict import DatasetDict preprocess_app = typer.Typer() @@ -53,8 +54,12 @@ def preprocess_mesh( max_samples: int = -1, batch_size: int = 256, tags: str = None, - years: str = None + train_years: list = None, + test_years: list = None ): + if test_size > 1: + logger.info(f"Test size found not as a fraction, but as a number of rows. Transforming {test_size} to integer") + test_size = int(test_size) if max_samples != -1: logger.info(f"Filtering examples to {max_samples}") @@ -79,11 +84,15 @@ def preprocess_mesh( if "train" in dset: dset = dset["train"] - if years is not None: - logger.info(f"Filtering years: {years}") - filter_years_list = list(filter(lambda x: x.strip() != "", years.split(","))) - filter_years_list = [str(y) for y in filter_years_list] - dset = dset.filter(lambda x: any(np.isin(filter_years_list, [str(x["year"])]))) + years = list() + if train_years is not None and len(train_years) > 0: + years.extend(train_years) + if test_years is not None and len(test_years) > 0: + years.extend(test_years) + + if len(years) > 0: + logger.info(f"Removing all years which are not in {years}") + dset = dset.filter(lambda x: any(np.isin(years, [str(x["year"])]))) if tags is not None: logger.info(f"Filtering tags: {tags}") @@ -91,7 +100,7 @@ def preprocess_mesh( dset = dset.filter(lambda x: any(np.isin(filter_tags_list, x["meshMajor"]))) # Remove unused columns to save space & time - dset = dset.remove_columns(["journal", "year", "pmid", "title"]) + dset = dset.remove_columns(["journal", "pmid", "title"]) t1 = time.time() dset = dset.map( @@ -149,7 +158,16 @@ def preprocess_mesh( logger.info("Preparing train/test split....") # Split into train and test t1 = time.time() - dset = dset.train_test_split(test_size=test_size) + if len(years) > 0: + logger.info("Splitting the dataset by training and test years") + train_dset = dset.filter(lambda x: any(np.isin(train_years, [str(x["year"])]))) + test_dset = dset.filter(lambda x: any(np.isin(test_years, [str(x["year"])]))) + + dset = DatasetDict({'train': train_dset, 'test': test_dset.train_test_split(test_size)['test']}) + else: + logger.info(f"Splitting the dataset randomly by a fraction of {test_size}") + dset = dset.train_test_split(test_size=test_size) + logger.info("Time taken to split into train and test: {}".format(time.time() - t1)) # If running from Training, by default it will be None so that we don't spend time @@ -174,7 +192,7 @@ def preprocess_mesh_cli( help="Key to use when loading tokenizer and label2id. " "Leave blank if training from scratch", # noqa ), - test_size: float = typer.Option(0.05, help="Fraction of data to use for testing"), + test_size: float = typer.Option(0.05, help="Fraction of data to use for testing or number of rows"), num_proc: int = typer.Option( os.cpu_count(), help="Number of processes to use for preprocessing" ), @@ -185,8 +203,8 @@ def preprocess_mesh_cli( batch_size: int = typer.Option(256, help="Size of the preprocessing batch"), tags: str = typer.Option(None, help="Comma-separated tags you want to include in the dataset " "(the rest will be discarded)"), - years: str = typer.Option(None, help="Comma-separated yeasr you want to include in the dataset " - "(the rest will be discarded)"), + train_years: str = typer.Option(None, help="Comma-separated years you want to include in the training dataset"), + test_years: str = typer.Option(None, help="Comma-separated years you want to include in the test dataset"), ): if not data_path.endswith("jsonl"): @@ -194,7 +212,26 @@ def preprocess_mesh_cli( "It seems your input MeSH data is not in `jsonl` format. " "Please, run first `scripts/mesh_json_to_jsonlpy.`" ) - exit(1) + exit(-1) + + train_years_list = [] + test_years_list = [] + + if train_years is not None: + if test_years is None: + logger.error("--train-years require --test-years") + exit(-1) + filter_years_list = list(filter(lambda x: x.strip() != "", train_years.split(","))) + train_years_list = [str(y) for y in filter_years_list] + logger.info(f"Training years to be considered: {train_years_list}") + + if test_years is not None: + if train_years is None: + logger.error("--test-years require --train-years") + exit(-1) + filter_years_list = list(filter(lambda x: x.strip() != "", test_years.split(","))) + test_years_list = [str(y) for y in filter_years_list] + logger.info(f"Test years to be considered: {test_years_list}") preprocess_mesh( data_path=data_path, @@ -205,5 +242,6 @@ def preprocess_mesh_cli( batch_size=batch_size, save_to_path=save_to_path, tags=tags, - years=years + train_years=train_years_list, + test_years=test_years_list ) From 909cee00e9eae153c68aa67c6e08efd0457c2f09 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 4 Aug 2023 15:50:07 +0100 Subject: [PATCH 121/300] Adds multibatching to filtering train/test by years --- README.md | 2 +- .../preprocessing/preprocess_mesh.py | 18 ++++++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 51b3259f..2e3ae992 100644 --- a/README.md +++ b/README.md @@ -93,7 +93,7 @@ If you want to use a different dataset see section on bringing your own data under development. -### Preprocessing allMeSH +### Preprocessing bertmesh ``` Usage: grants-tagger preprocess mesh [OPTIONS] DATA_PATH SAVE_TO_PATH diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 8670795a..7ce047d3 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -33,6 +33,10 @@ def _encode_labels(sample, label2id): return {"label_ids": [_map_label_to_ids(x, label2id) for x in sample["meshMajor"]]} +def _filter_rows_by_years(sample, years): + return [x in years for x in sample["year"]] + + def create_sample_file(jsonl_file, lines): with open(jsonl_file, "r") as input_file: with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp_file: @@ -160,8 +164,18 @@ def preprocess_mesh( t1 = time.time() if len(years) > 0: logger.info("Splitting the dataset by training and test years") - train_dset = dset.filter(lambda x: any(np.isin(train_years, [str(x["year"])]))) - test_dset = dset.filter(lambda x: any(np.isin(test_years, [str(x["year"])]))) + train_dset = dset.filter(_filter_rows_by_years, + batched=True, + batch_size=batch_size, + desc=f"Creating training dataset with years {train_years}", + num_proc=num_proc, + fn_kwargs={"years": train_years}) + test_dset = dset.filter(_filter_rows_by_years, + batched=True, + batch_size=batch_size, + desc=f"Creating test dataset with years {test_years}", + num_proc=num_proc, + fn_kwargs={"years": test_years}) dset = DatasetDict({'train': train_dset, 'test': test_dset.train_test_split(test_size)['test']}) else: From c08da49b38fdd6440a2685b68997fb70f5543d63 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 4 Aug 2023 15:56:39 +0100 Subject: [PATCH 122/300] Updates message --- grants_tagger_light/training/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 46adf258..b9979696 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -50,8 +50,8 @@ def train_bertmesh( logger.info(f"Preprocessing the dataset at {data_path}...") if os.path.isdir(data_path): logger.info( - "Folder found, which means you preprocessed and " - "save the data before. Loading from disk..." + "Train/test data found in a folder, which means you preprocessed and " + "save the data before. Loading that split from disk..." ) dset = load_from_disk(os.path.join(data_path, "dataset")) with open(os.path.join(data_path, "label2id"), "r") as f: From 3d736061f3035a4cc6a6d08bb2732bcd6e70943c Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 4 Aug 2023 15:58:06 +0100 Subject: [PATCH 123/300] Changes shards to cpu_count() by default --- grants_tagger_light/training/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index b9979696..aea6e843 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -38,7 +38,7 @@ def train_bertmesh( max_samples: int = -1, test_size: float = 0.05, num_proc: int = os.cpu_count(), - shards: int = -1, + shards: int = os.cpu_count(), from_checkpoint: str = None ): if not model_key: @@ -189,7 +189,7 @@ def train_bertmesh_cli( help="Maximum number of samples to use from the json", ), shards: int = typer.Option( - -1, + os.cpu_count(), help="Number os shards to divide training " "IterativeDataset to (improves performance)", ), From 07ea25db53bca17c51ff2c62bce0c8f3623ff8a9 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 4 Aug 2023 16:02:30 +0100 Subject: [PATCH 124/300] Adds `preprocess` example --- scripts/preprocess.sh | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 scripts/preprocess.sh diff --git a/scripts/preprocess.sh b/scripts/preprocess.sh new file mode 100644 index 00000000..0cda3229 --- /dev/null +++ b/scripts/preprocess.sh @@ -0,0 +1,4 @@ +grants-tagger preprocess mesh data/raw/allMeSH_2021.jsonl ./kk '' \ + --train-years 2016,2017,2018,2019 \ + --test-years 2020,2021 \ + --test-size 10000 \ No newline at end of file From d74c2f6b0b84cd1c98ce90e50590edaeb6dd11cc Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 4 Aug 2023 16:05:20 +0100 Subject: [PATCH 125/300] Adds evaluation by epoch --- scripts/train_by_epochs.sh | 26 +++++++++++++++++++++++++ scripts/{train.sh => train_by_steps.sh} | 0 2 files changed, 26 insertions(+) create mode 100644 scripts/train_by_epochs.sh rename scripts/{train.sh => train_by_steps.sh} (100%) diff --git a/scripts/train_by_epochs.sh b/scripts/train_by_epochs.sh new file mode 100644 index 00000000..bee24ae7 --- /dev/null +++ b/scripts/train_by_epochs.sh @@ -0,0 +1,26 @@ +# Run on g5.12xlargeinstance + +# Without preprocessing (on-the-fly) +SOURCE="data/raw/allMeSH_2021.jsonl" + +# After preprocessing first +# SOURCE="output_folder_from_preprocessing" + +grants-tagger train bertmesh \ + "" \ + $SOURCE \ + --output_dir bertmesh_outs/pipeline_test/ \ + --per_device_train_batch_size 32 \ + --per_device_eval_batch_size 1 \ + --num_train_epochs 5 \ + --learning_rate 5e-5 \ + - dropout 1.0 \ + --warmup_steps 1000 \ + --fp16 \ + --torch_compile \ + --evaluation_strategy epochs \ + --eval_accumulation_steps 20 \ + --save_strategy epochs \ + --wandb_project wellcome-mesh \ + --wandb_name test-train-all \ + --wandb_api_key ${WANDB_API_KEY} diff --git a/scripts/train.sh b/scripts/train_by_steps.sh similarity index 100% rename from scripts/train.sh rename to scripts/train_by_steps.sh From 4db6fb3d66e29fa8139ebfe0af435fa98b617a0e Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 4 Aug 2023 16:06:46 +0100 Subject: [PATCH 126/300] Adds evaluation by epoch --- scripts/train_by_epochs.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/train_by_epochs.sh b/scripts/train_by_epochs.sh index bee24ae7..43032c48 100644 --- a/scripts/train_by_epochs.sh +++ b/scripts/train_by_epochs.sh @@ -14,13 +14,13 @@ grants-tagger train bertmesh \ --per_device_eval_batch_size 1 \ --num_train_epochs 5 \ --learning_rate 5e-5 \ - - dropout 1.0 \ + --dropout 1.0 \ --warmup_steps 1000 \ --fp16 \ --torch_compile \ - --evaluation_strategy epochs \ + --evaluation_strategy epoch \ --eval_accumulation_steps 20 \ - --save_strategy epochs \ + --save_strategy epoch \ --wandb_project wellcome-mesh \ --wandb_name test-train-all \ --wandb_api_key ${WANDB_API_KEY} From b1a6d1961a94b4ce8fcde8db937406f002117c54 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 4 Aug 2023 16:47:17 +0100 Subject: [PATCH 127/300] Train/test split by years --- README.md | 28 ++++++++++----- .../preprocessing/preprocess_mesh.py | 26 +++++--------- grants_tagger_light/training/train.py | 36 +++++++++++++------ .../utils/years_tags_parser.py | 19 ++++++++++ ...sume_train.sh => resume_train_by_steps.sh} | 1 - scripts/train_by_epochs.sh | 3 ++ scripts/train_by_steps.sh | 6 ++-- 7 files changed, 80 insertions(+), 39 deletions(-) create mode 100644 grants_tagger_light/utils/years_tags_parser.py rename scripts/{resume_train.sh => resume_train_by_steps.sh} (97%) diff --git a/README.md b/README.md index 2e3ae992..0bfeb820 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,7 @@ in square brackets the commands that are not implemented yet ## ⚙️Preprocess -This process is optional to run, since it will be managed by the `Train` process. +This process is optional to run, since it can be directly managed by the `Train` process. - If you run it manually, it will store the data in local first, which can help if you need finetune in the future, rerun, etc. - If not, the project will preprocess and then run, without any extra I/O operations on disk, @@ -109,12 +109,16 @@ your own data under development. │ [required] │ ╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ ╭─ Options ────────────────────────────────────────────────────────────────────────────────────────────────────────╮ -│ --test-size FLOAT Fraction of data to use for testing [default: 0.05] │ +│ --test-size FLOAT Fraction of data to use for testing (if less than 1) or number of rows │ +│ [default: 0.05] │ │ --num-proc INTEGER Number of processes to use for preprocessing [default: 8] │ │ --max-samples INTEGER Maximum number of samples to use for preprocessing [default: -1] │ │ --batch-size INTEGER Size of the preprocessing batch [default: 256] │ -│ --years TEXT Comma-separated years you want to included (e.g: 2020,2021) [default: None] │ -│ --tags TEXT Comma-separated tags you want to included (e.g: Pandemics,COVID19) [default: None] │ +│ --train-years TEXT Comma-separated years you want to include in training (e.g: 2020,2021) │ +│ [default: None, meaning all years] │ +│ --test-years TEXT Comma-separated years you want to include in test (e.g: 2020,2021) │ +│ [default: None, meaning all years] │ +│ --tags TEXT Comma-separated tags you want to included (e.g: Pandemics,COVID19) │ │ --help Show this message and exit. │ ╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯``` ``` @@ -134,14 +138,22 @@ the BertMesh model. The command will train a model and save it to the specified │ to disk │ │ [default: None] │ │ [required] │ +│ --shards INTEGER Number os shards to divide training IterativeDataset to (improves performance) │ +│ [default: -1, meaning no shards]. Recommended: os.cpu_count() │ +│ --num-proc INTEGER Number of processes to use for preprocessing [default: os.cpu_count()] │ +│ --help Show this message and exit. │ ╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ + +If you are running directly training without calling preprocess, you can specify the same parameters as preprocess: + ╭─ Options ────────────────────────────────────────────────────────────────────────────────────────────────────────╮ │ --test-size FLOAT Fraction of data to use for testing [default: 0.05] │ -│ --num-proc INTEGER Number of processes to use for preprocessing [default: 8] │ │ --max-samples INTEGER Maximum number of samples to use from the json [default: -1] │ -│ --shards INTEGER Number os shards to divide training IterativeDataset to (improves performance) │ -│ [default: -1, meaning no shards]. Recommended: 100 │ -│ --help Show this message and exit. │ +│ --train-years TEXT Comma-separated years you want to include in training (e.g: 2020,2021) │ +│ [default: None, meaning all years] │ +│ --test-years TEXT Comma-separated years you want to include in test (e.g: 2020,2021) │ +│ [default: None, meaning all years] │ +│ --tags TEXT Comma-separated tags you want to included (e.g: Pandemics,COVID19) │ ╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ ``` diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 7ce047d3..950f86d4 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -13,6 +13,8 @@ import numpy as np from datasets.dataset_dict import DatasetDict +from grants_tagger_light.utils.years_tags_parser import parse_tags, parse_years + preprocess_app = typer.Typer() @@ -57,7 +59,7 @@ def preprocess_mesh( num_proc: int = os.cpu_count(), max_samples: int = -1, batch_size: int = 256, - tags: str = None, + tags: list = None, train_years: list = None, test_years: list = None ): @@ -98,10 +100,9 @@ def preprocess_mesh( logger.info(f"Removing all years which are not in {years}") dset = dset.filter(lambda x: any(np.isin(years, [str(x["year"])]))) - if tags is not None: - logger.info(f"Filtering tags: {tags}") - filter_tags_list = list(filter(lambda x: x.strip() != "", tags.split(","))) - dset = dset.filter(lambda x: any(np.isin(filter_tags_list, x["meshMajor"]))) + if len(tags) > 0: + logger.info(f"Removing all tags which are not in {tags}") + dset = dset.filter(lambda x: any(np.isin(tags, x["meshMajor"]))) # Remove unused columns to save space & time dset = dset.remove_columns(["journal", "pmid", "title"]) @@ -228,24 +229,15 @@ def preprocess_mesh_cli( ) exit(-1) - train_years_list = [] - test_years_list = [] - if train_years is not None: if test_years is None: logger.error("--train-years require --test-years") exit(-1) - filter_years_list = list(filter(lambda x: x.strip() != "", train_years.split(","))) - train_years_list = [str(y) for y in filter_years_list] - logger.info(f"Training years to be considered: {train_years_list}") if test_years is not None: if train_years is None: logger.error("--test-years require --train-years") exit(-1) - filter_years_list = list(filter(lambda x: x.strip() != "", test_years.split(","))) - test_years_list = [str(y) for y in filter_years_list] - logger.info(f"Test years to be considered: {test_years_list}") preprocess_mesh( data_path=data_path, @@ -255,7 +247,7 @@ def preprocess_mesh_cli( max_samples=max_samples, batch_size=batch_size, save_to_path=save_to_path, - tags=tags, - train_years=train_years_list, - test_years=test_years_list + tags=parse_tags(tags), + train_years=parse_years(train_years), + test_years=parse_years(test_years) ) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index aea6e843..0d4b145d 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -26,6 +26,7 @@ from datasets import load_from_disk from grants_tagger_light.utils.sharding import Sharding +from grants_tagger_light.utils.years_tags_parser import parse_years, parse_tags transformers.set_seed(42) @@ -39,7 +40,10 @@ def train_bertmesh( test_size: float = 0.05, num_proc: int = os.cpu_count(), shards: int = os.cpu_count(), - from_checkpoint: str = None + from_checkpoint: str = None, + tags: list = None, + train_years: list = None, + test_years: list = None ): if not model_key: assert isinstance(model_args, BertMeshModelArguments), ( @@ -65,6 +69,9 @@ def train_bertmesh( num_proc=num_proc, max_samples=max_samples, batch_size=training_args.per_device_train_batch_size, + tags=tags, + train_years=train_years, + test_years=test_years ) train_dset, val_dset = dset["train"], dset["test"] @@ -196,7 +203,11 @@ def train_bertmesh_cli( from_checkpoint: str = typer.Option( None, help="Name of the checkpoint to resume training" - ) + ), + tags: str = typer.Option(None, help="Comma-separated tags you want to include in the dataset " + "(the rest will be discarded)"), + train_years: str = typer.Option(None, help="Comma-separated years you want to include in the training dataset"), + test_years: str = typer.Option(None, help="Comma-separated years you want to include in the test dataset") ): parser = HfArgumentParser( ( @@ -215,13 +226,16 @@ def train_bertmesh_cli( logger.info("Wandb args: {}".format(pformat(wandb_args))) train_bertmesh( - model_key, - data_path, - training_args, - model_args, - max_samples, - test_size, - num_proc, - shards, - from_checkpoint + model_key=model_key, + data_path=data_path, + training_args=training_args, + model_args=model_args, + max_samples=max_samples, + test_size=test_size, + num_proc=num_proc, + shards=shards, + from_checkpoint=from_checkpoint, + tags=parse_tags(tags), + train_years=parse_years(train_years), + test_years=parse_years(test_years) ) diff --git a/grants_tagger_light/utils/years_tags_parser.py b/grants_tagger_light/utils/years_tags_parser.py new file mode 100644 index 00000000..12c65894 --- /dev/null +++ b/grants_tagger_light/utils/years_tags_parser.py @@ -0,0 +1,19 @@ +from loguru import logger + + +def parse_tags(tags): + tags_list = [] + if tags is not None: + filter_tags_list = list(filter(lambda x: x.strip() != "", tags.split(","))) + tags_list = [str(y) for y in filter_tags_list] + logger.info(f"Tags to be considered: {tags_list}") + return tags_list + + +def parse_years(years): + years_list = [] + if years is not None: + filter_years_list = list(filter(lambda x: x.strip() != "", years.split(","))) + years_list = [str(y) for y in filter_years_list] + logger.info(f"Years to be considered: {years_list}") + return years_list diff --git a/scripts/resume_train.sh b/scripts/resume_train_by_steps.sh similarity index 97% rename from scripts/resume_train.sh rename to scripts/resume_train_by_steps.sh index 4a39c264..da0c52c4 100644 --- a/scripts/resume_train.sh +++ b/scripts/resume_train_by_steps.sh @@ -14,7 +14,6 @@ grants-tagger train bertmesh \ $SOURCE \ --output_dir bertmesh_outs/pipeline_test_from_$CHECKPOINT/ \ --ignore_data_skip=True \ - --shards 48 \ --per_device_train_batch_size 32 \ --per_device_eval_batch_size 1 \ --num_train_epochs 2 \ diff --git a/scripts/train_by_epochs.sh b/scripts/train_by_epochs.sh index 43032c48..7a1e8151 100644 --- a/scripts/train_by_epochs.sh +++ b/scripts/train_by_epochs.sh @@ -9,7 +9,10 @@ SOURCE="data/raw/allMeSH_2021.jsonl" grants-tagger train bertmesh \ "" \ $SOURCE \ + --test-size 10000 \ --output_dir bertmesh_outs/pipeline_test/ \ + --train-years 2021 \ + --test-years 2020 \ --per_device_train_batch_size 32 \ --per_device_eval_batch_size 1 \ --num_train_epochs 5 \ diff --git a/scripts/train_by_steps.sh b/scripts/train_by_steps.sh index 88546eff..5b0b904d 100644 --- a/scripts/train_by_steps.sh +++ b/scripts/train_by_steps.sh @@ -9,12 +9,14 @@ SOURCE="data/raw/allMeSH_2021.jsonl" grants-tagger train bertmesh \ "" \ $SOURCE \ - --test-size 0.0025 \ - --shards 48 \ + --test-size 10000 \ --output_dir bertmesh_outs/pipeline_test/ \ --per_device_train_batch_size 32 \ --per_device_eval_batch_size 1 \ --num_train_epochs 1 \ + --learning_rate 5e-5 \ + --dropout 1.0 \ + --warmup_steps 1000 \ --fp16 \ --torch_compile \ --evaluation_strategy steps \ From c2d3a901196d1d31e45de66a868e928875771156 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 5 Aug 2023 13:06:07 +0100 Subject: [PATCH 128/300] Fixes drop out and standardizes calls --- scripts/resume_train_by_epoch.sh | 30 ++++++++++++++++++++++++++++++ scripts/resume_train_by_steps.sh | 7 +++++-- scripts/train_by_epochs.sh | 5 +++-- scripts/train_by_steps.sh | 7 +++++-- 4 files changed, 43 insertions(+), 6 deletions(-) create mode 100644 scripts/resume_train_by_epoch.sh diff --git a/scripts/resume_train_by_epoch.sh b/scripts/resume_train_by_epoch.sh new file mode 100644 index 00000000..a198ff7b --- /dev/null +++ b/scripts/resume_train_by_epoch.sh @@ -0,0 +1,30 @@ +# Run on g5.12xlarge instance + +# Without preprocessing (on-the-fly) +SOURCE="data/raw/allMeSH_2021.jsonl" + +# After preprocessing first +# SOURCE="output_folder_from_preprocessing" + +# Checkpoint +CHECKPOINT="checkpoint-100000" + +grants-tagger train bertmesh \ + bertmesh_outs/pipeline_test/$CHECKPOINT \ + $SOURCE \ + --output_dir bertmesh_outs/pipeline_test_from_$CHECKPOINT/ \ + --ignore_data_skip=True \ + --per_device_train_batch_size 32 \ + --per_device_eval_batch_size 1 \ + --num_train_epochs 5 \ + --learning_rate 5e-5 \ + --dropout 0.1 \ + --warmup_steps 1000 \ + --fp16 \ + --torch_compile \ + --evaluation_strategy epoch \ + --eval_accumulation_steps 20 \ + --save_strategy epoch \ + --wandb_project wellcome-mesh \ + --wandb_name test-train-all \ + --wandb_api_key ${WANDB_API_KEY} \ No newline at end of file diff --git a/scripts/resume_train_by_steps.sh b/scripts/resume_train_by_steps.sh index da0c52c4..40770fd3 100644 --- a/scripts/resume_train_by_steps.sh +++ b/scripts/resume_train_by_steps.sh @@ -16,7 +16,10 @@ grants-tagger train bertmesh \ --ignore_data_skip=True \ --per_device_train_batch_size 32 \ --per_device_eval_batch_size 1 \ - --num_train_epochs 2 \ + --num_train_epochs 1 \ + --learning_rate 5e-5 \ + --dropout 0.1 \ + --warmup_steps 1000 \ --fp16 \ --torch_compile \ --evaluation_strategy steps \ @@ -26,4 +29,4 @@ grants-tagger train bertmesh \ --save_steps 50000 \ --wandb_project wellcome-mesh \ --wandb_name test-train-all \ - --wandb_api_key ${WANDB_API_KEY} + --wandb_api_key ${WANDB_API_KEY} \ No newline at end of file diff --git a/scripts/train_by_epochs.sh b/scripts/train_by_epochs.sh index 7a1e8151..388e1b7e 100644 --- a/scripts/train_by_epochs.sh +++ b/scripts/train_by_epochs.sh @@ -3,8 +3,9 @@ # Without preprocessing (on-the-fly) SOURCE="data/raw/allMeSH_2021.jsonl" -# After preprocessing first +# If you have already preprocessed the data, you will have a folder. Use the folder instead. # SOURCE="output_folder_from_preprocessing" +# In that case, `test-size`, `train-years` and `test-years` will be ignored. grants-tagger train bertmesh \ "" \ @@ -17,7 +18,7 @@ grants-tagger train bertmesh \ --per_device_eval_batch_size 1 \ --num_train_epochs 5 \ --learning_rate 5e-5 \ - --dropout 1.0 \ + --dropout 0.1 \ --warmup_steps 1000 \ --fp16 \ --torch_compile \ diff --git a/scripts/train_by_steps.sh b/scripts/train_by_steps.sh index 5b0b904d..44a8772e 100644 --- a/scripts/train_by_steps.sh +++ b/scripts/train_by_steps.sh @@ -3,19 +3,22 @@ # Without preprocessing (on-the-fly) SOURCE="data/raw/allMeSH_2021.jsonl" -# After preprocessing first +# If you have already preprocessed the data, you will have a folder. Use the folder instead. # SOURCE="output_folder_from_preprocessing" +# In that case, `test-size`, `train-years` and `test-years` will be ignored. grants-tagger train bertmesh \ "" \ $SOURCE \ --test-size 10000 \ --output_dir bertmesh_outs/pipeline_test/ \ + --train-years 2021 \ + --test-years 2020 \ --per_device_train_batch_size 32 \ --per_device_eval_batch_size 1 \ --num_train_epochs 1 \ --learning_rate 5e-5 \ - --dropout 1.0 \ + --dropout 0.1 \ --warmup_steps 1000 \ --fp16 \ --torch_compile \ From 5e04318eeb6822115a7b45b4773cf688401c6e86 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 5 Aug 2023 13:06:49 +0100 Subject: [PATCH 129/300] Fixes years --- scripts/train_by_epochs.sh | 4 ++-- scripts/train_by_steps.sh | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/train_by_epochs.sh b/scripts/train_by_epochs.sh index 388e1b7e..1fedc7f3 100644 --- a/scripts/train_by_epochs.sh +++ b/scripts/train_by_epochs.sh @@ -12,8 +12,8 @@ grants-tagger train bertmesh \ $SOURCE \ --test-size 10000 \ --output_dir bertmesh_outs/pipeline_test/ \ - --train-years 2021 \ - --test-years 2020 \ + --train-years 2016,2017,2018,2019 \ + --test-years 2020,2021 \ --per_device_train_batch_size 32 \ --per_device_eval_batch_size 1 \ --num_train_epochs 5 \ diff --git a/scripts/train_by_steps.sh b/scripts/train_by_steps.sh index 44a8772e..51ec22a6 100644 --- a/scripts/train_by_steps.sh +++ b/scripts/train_by_steps.sh @@ -12,8 +12,8 @@ grants-tagger train bertmesh \ $SOURCE \ --test-size 10000 \ --output_dir bertmesh_outs/pipeline_test/ \ - --train-years 2021 \ - --test-years 2020 \ + --train-years 2016,2017,2018,2019 \ + --test-years 2020,2021 \ --per_device_train_batch_size 32 \ --per_device_eval_batch_size 1 \ --num_train_epochs 1 \ From b1a609859c0bfad3b2de30fb28fc04b9051a80e5 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sun, 6 Aug 2023 12:19:51 +0100 Subject: [PATCH 130/300] Test size frac vs row --- grants_tagger_light/preprocessing/preprocess_mesh.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 950f86d4..bdb74ec4 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -63,9 +63,11 @@ def preprocess_mesh( train_years: list = None, test_years: list = None ): - if test_size > 1: - logger.info(f"Test size found not as a fraction, but as a number of rows. Transforming {test_size} to integer") + if test_size > 1.0: test_size = int(test_size) + logger.info(f"Test size found as number of rows: {test_size}") + else: + logger.info(f"Test size found as fraction: {test_size}") if max_samples != -1: logger.info(f"Filtering examples to {max_samples}") From 384436c58094a55e52ca16995525d2ac8a9b6f90 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sun, 6 Aug 2023 12:24:03 +0100 Subject: [PATCH 131/300] Test size frac vs row --- scripts/{preprocess.sh => preprocess_splitting_by_years.sh} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename scripts/{preprocess.sh => preprocess_splitting_by_years.sh} (87%) diff --git a/scripts/preprocess.sh b/scripts/preprocess_splitting_by_years.sh similarity index 87% rename from scripts/preprocess.sh rename to scripts/preprocess_splitting_by_years.sh index 0cda3229..c470f9ac 100644 --- a/scripts/preprocess.sh +++ b/scripts/preprocess_splitting_by_years.sh @@ -1,4 +1,4 @@ grants-tagger preprocess mesh data/raw/allMeSH_2021.jsonl ./kk '' \ --train-years 2016,2017,2018,2019 \ --test-years 2020,2021 \ - --test-size 10000 \ No newline at end of file + --test-size 1 \ No newline at end of file From f19fe0f628c3d33d68677453e99d9c4d744b6297 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sun, 6 Aug 2023 12:39:10 +0100 Subject: [PATCH 132/300] Test size frac vs row --- .../preprocessing/preprocess_mesh.py | 31 +++++++++++++------ grants_tagger_light/training/train.py | 4 +-- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index bdb74ec4..b11a4f19 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -27,6 +27,18 @@ def _tokenize(batch, tokenizer: AutoTokenizer, x_col: str): ) +def _parse_test_size(test_dset, test_size): + if test_size is None or test_size == 1.0: + test_size = len(test_dset) + logger.info(f"Test size set to number of rows of the whole test dataset: {test_size}") + elif test_size > 1.0: + test_size = int(test_size) + logger.info(f"Test size found as number of rows: {test_size}") + else: + logger.info(f"Test size found as fraction: {test_size}") + return test_size + + def _map_label_to_ids(labels, label2id): return [label2id[label] for label in labels] @@ -55,7 +67,7 @@ def preprocess_mesh( data_path: str, model_key: str, save_to_path: str = None, - test_size: float = 0.05, + test_size: float = None, num_proc: int = os.cpu_count(), max_samples: int = -1, batch_size: int = 256, @@ -63,12 +75,6 @@ def preprocess_mesh( train_years: list = None, test_years: list = None ): - if test_size > 1.0: - test_size = int(test_size) - logger.info(f"Test size found as number of rows: {test_size}") - else: - logger.info(f"Test size found as fraction: {test_size}") - if max_samples != -1: logger.info(f"Filtering examples to {max_samples}") data_path = create_sample_file(data_path, max_samples) @@ -180,9 +186,16 @@ def preprocess_mesh( num_proc=num_proc, fn_kwargs={"years": test_years}) + test_size = _parse_test_size(test_dset, test_size) + logger.info(f"Splitting the dataset by years with a test_size of {test_size}") dset = DatasetDict({'train': train_dset, 'test': test_dset.train_test_split(test_size)['test']}) else: - logger.info(f"Splitting the dataset randomly by a fraction of {test_size}") + if test_size is None: + test_size = 0.05 + if test_size >= 1.0: + logger.info(f"Splitting the dataset randomly (not by years) by number of rows: {test_size}") + else: + logger.info(f"Splitting the dataset randomly (not by years) by frac: {test_size}") dset = dset.train_test_split(test_size=test_size) logger.info("Time taken to split into train and test: {}".format(time.time() - t1)) @@ -209,7 +222,7 @@ def preprocess_mesh_cli( help="Key to use when loading tokenizer and label2id. " "Leave blank if training from scratch", # noqa ), - test_size: float = typer.Option(0.05, help="Fraction of data to use for testing or number of rows"), + test_size: float = typer.Option(None, help="Fraction of data to use for testing in (0,1] or number of rows"), num_proc: int = typer.Option( os.cpu_count(), help="Number of processes to use for preprocessing" ), diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 0d4b145d..d0d9e215 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -37,7 +37,7 @@ def train_bertmesh( training_args: TrainingArguments, model_args: BertMeshModelArguments = None, max_samples: int = -1, - test_size: float = 0.05, + test_size: float = None, num_proc: int = os.cpu_count(), shards: int = os.cpu_count(), from_checkpoint: str = None, @@ -187,7 +187,7 @@ def train_bertmesh_cli( help="Path to allMeSH_2021.jsonl (or similar) " "or to a folder after preprocessing and saving to disk", ), - test_size: float = typer.Option(0.05, help="Fraction of data to use for testing"), + test_size: float = typer.Option(None, help="Fraction of data to use for testing (0,1] or number of rows"), num_proc: int = typer.Option( os.cpu_count(), help="Number of processes to use for preprocessing" ), From 4b5320896439f3d8c3cba3ecbafd18926c5fcf8f Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sun, 6 Aug 2023 12:41:53 +0100 Subject: [PATCH 133/300] Test size frac vs row --- .../preprocessing/preprocess_mesh.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index b11a4f19..61d834d2 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -29,13 +29,18 @@ def _tokenize(batch, tokenizer: AutoTokenizer, x_col: str): def _parse_test_size(test_dset, test_size): if test_size is None or test_size == 1.0: - test_size = len(test_dset) - logger.info(f"Test size set to number of rows of the whole test dataset: {test_size}") + if test_dset is not None: + test_size = len(test_dset) + logger.info(f"Test size not gounf, set to number of rows of the whole test dataset: {test_size}") + else: + test_size = 0.05 + logger.info(f"Test size not found, setting to a fraction of all dataset: {test_size}") elif test_size > 1.0: test_size = int(test_size) logger.info(f"Test size found as number of rows: {test_size}") else: logger.info(f"Test size found as fraction: {test_size}") + return test_size @@ -190,12 +195,7 @@ def preprocess_mesh( logger.info(f"Splitting the dataset by years with a test_size of {test_size}") dset = DatasetDict({'train': train_dset, 'test': test_dset.train_test_split(test_size)['test']}) else: - if test_size is None: - test_size = 0.05 - if test_size >= 1.0: - logger.info(f"Splitting the dataset randomly (not by years) by number of rows: {test_size}") - else: - logger.info(f"Splitting the dataset randomly (not by years) by frac: {test_size}") + test_size = _parse_test_size(None, test_size) dset = dset.train_test_split(test_size=test_size) logger.info("Time taken to split into train and test: {}".format(time.time() - t1)) From fbf941ca593b73e50f2f53232b962234cac54214 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sun, 6 Aug 2023 12:42:27 +0100 Subject: [PATCH 134/300] Test size frac vs row --- scripts/preprocess_splitting_by_years.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/scripts/preprocess_splitting_by_years.sh b/scripts/preprocess_splitting_by_years.sh index c470f9ac..6e2b0949 100644 --- a/scripts/preprocess_splitting_by_years.sh +++ b/scripts/preprocess_splitting_by_years.sh @@ -1,4 +1,3 @@ grants-tagger preprocess mesh data/raw/allMeSH_2021.jsonl ./kk '' \ --train-years 2016,2017,2018,2019 \ - --test-years 2020,2021 \ - --test-size 1 \ No newline at end of file + --test-years 2020,2021 \ No newline at end of file From 4108f37d66670e5a1cece909d99603f862e6f603 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sun, 6 Aug 2023 12:44:38 +0100 Subject: [PATCH 135/300] Test size frac vs row --- scripts/preprocess_splitting_by_fract.sh | 2 ++ scripts/preprocess_splitting_by_rows.sh | 2 ++ 2 files changed, 4 insertions(+) create mode 100644 scripts/preprocess_splitting_by_fract.sh create mode 100644 scripts/preprocess_splitting_by_rows.sh diff --git a/scripts/preprocess_splitting_by_fract.sh b/scripts/preprocess_splitting_by_fract.sh new file mode 100644 index 00000000..9cf761d3 --- /dev/null +++ b/scripts/preprocess_splitting_by_fract.sh @@ -0,0 +1,2 @@ +grants-tagger preprocess mesh data/raw/allMeSH_2021.jsonl ./kk '' \ + --test-size 0.05 \ No newline at end of file diff --git a/scripts/preprocess_splitting_by_rows.sh b/scripts/preprocess_splitting_by_rows.sh new file mode 100644 index 00000000..eb1e5dae --- /dev/null +++ b/scripts/preprocess_splitting_by_rows.sh @@ -0,0 +1,2 @@ +grants-tagger preprocess mesh data/raw/allMeSH_2021.jsonl ./kk '' \ + --test-size 10000 \ No newline at end of file From 2a95c395a4d68744102a79436a66171bfeee52d6 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sun, 6 Aug 2023 12:51:50 +0100 Subject: [PATCH 136/300] Test size frac vs row --- .../preprocessing/preprocess_mesh.py | 32 +++++++------------ 1 file changed, 11 insertions(+), 21 deletions(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 61d834d2..1aed5fa5 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -27,23 +27,6 @@ def _tokenize(batch, tokenizer: AutoTokenizer, x_col: str): ) -def _parse_test_size(test_dset, test_size): - if test_size is None or test_size == 1.0: - if test_dset is not None: - test_size = len(test_dset) - logger.info(f"Test size not gounf, set to number of rows of the whole test dataset: {test_size}") - else: - test_size = 0.05 - logger.info(f"Test size not found, setting to a fraction of all dataset: {test_size}") - elif test_size > 1.0: - test_size = int(test_size) - logger.info(f"Test size found as number of rows: {test_size}") - else: - logger.info(f"Test size found as fraction: {test_size}") - - return test_size - - def _map_label_to_ids(labels, label2id): return [label2id[label] for label in labels] @@ -191,11 +174,18 @@ def preprocess_mesh( num_proc=num_proc, fn_kwargs={"years": test_years}) - test_size = _parse_test_size(test_dset, test_size) - logger.info(f"Splitting the dataset by years with a test_size of {test_size}") - dset = DatasetDict({'train': train_dset, 'test': test_dset.train_test_split(test_size)['test']}) + if test_size is None or test_size == 1.0: + test_size = len(test_dset) + logger.info(f"Using the whole dataset of {test_size} rows") + dset = DatasetDict({'train': train_dset, 'test': test_dset}) + else: + logger.info(f"Using a test_size frac or number of rows of of {test_size}") + dset = DatasetDict({'train': train_dset, 'test': test_dset.train_test_split(test_size)['test']}) + else: - test_size = _parse_test_size(None, test_size) + if test_size is None: + test_size = 0.05 + logger.info(f"Test size not found. Setting it to a frac of the whole dataset equal to {test_size}") dset = dset.train_test_split(test_size=test_size) logger.info("Time taken to split into train and test: {}".format(time.time() - t1)) From 1097186a81a05826fd67e19bde03262763978d51 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 7 Aug 2023 10:43:52 +0100 Subject: [PATCH 137/300] Logging model params --- grants_tagger_light/training/train.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index d0d9e215..b2b8bb70 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -95,8 +95,7 @@ def train_bertmesh( logger.info(f"Loading `{model_args.pretrained_model_key}` tokenizer...") config = AutoConfig.from_pretrained(model_args.pretrained_model_key) - config.update( - { + overwritten_params = { "pretrained_model": model_args.pretrained_model_key, "num_labels": len(label2id), "hidden_size": model_args.hidden_size, @@ -106,7 +105,9 @@ def train_bertmesh( "id2label": {v: k for k, v in label2id.items()}, "freeze_backbone": model_args.freeze_backbone, } - ) + logger.info(f"Updating model params:\n{overwritten_params}") + + config.update(overwritten_params) model = BertMesh(config) else: @@ -222,6 +223,7 @@ def train_bertmesh_cli( model_args, ) = parser.parse_args_into_dataclasses(ctx.args) + logger.info("Model args: {}".format(pformat(model_args))) logger.info("Training args: {}".format(pformat(training_args))) logger.info("Wandb args: {}".format(pformat(wandb_args))) From d69553fde77689428441c632876e7817c4f582f1 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 7 Aug 2023 10:46:28 +0100 Subject: [PATCH 138/300] Adds hidden size to train params --- scripts/resume_train_by_epoch.sh | 1 + scripts/resume_train_by_steps.sh | 1 + scripts/train_by_epochs.sh | 3 ++- scripts/train_by_steps.sh | 3 ++- 4 files changed, 6 insertions(+), 2 deletions(-) diff --git a/scripts/resume_train_by_epoch.sh b/scripts/resume_train_by_epoch.sh index a198ff7b..9f1a22f6 100644 --- a/scripts/resume_train_by_epoch.sh +++ b/scripts/resume_train_by_epoch.sh @@ -19,6 +19,7 @@ grants-tagger train bertmesh \ --num_train_epochs 5 \ --learning_rate 5e-5 \ --dropout 0.1 \ + --hidden_size 1024 \ --warmup_steps 1000 \ --fp16 \ --torch_compile \ diff --git a/scripts/resume_train_by_steps.sh b/scripts/resume_train_by_steps.sh index 40770fd3..846b16f4 100644 --- a/scripts/resume_train_by_steps.sh +++ b/scripts/resume_train_by_steps.sh @@ -19,6 +19,7 @@ grants-tagger train bertmesh \ --num_train_epochs 1 \ --learning_rate 5e-5 \ --dropout 0.1 \ + --hidden_size 1024 \ --warmup_steps 1000 \ --fp16 \ --torch_compile \ diff --git a/scripts/train_by_epochs.sh b/scripts/train_by_epochs.sh index 1fedc7f3..0f8171be 100644 --- a/scripts/train_by_epochs.sh +++ b/scripts/train_by_epochs.sh @@ -5,7 +5,7 @@ SOURCE="data/raw/allMeSH_2021.jsonl" # If you have already preprocessed the data, you will have a folder. Use the folder instead. # SOURCE="output_folder_from_preprocessing" -# In that case, `test-size`, `train-years` and `test-years` will be ignored. +# In that case, `test-size`, `train-years` and `test-years` will be taken from the preprocessed folder grants-tagger train bertmesh \ "" \ @@ -19,6 +19,7 @@ grants-tagger train bertmesh \ --num_train_epochs 5 \ --learning_rate 5e-5 \ --dropout 0.1 \ + --hidden_size 1024 \ --warmup_steps 1000 \ --fp16 \ --torch_compile \ diff --git a/scripts/train_by_steps.sh b/scripts/train_by_steps.sh index 51ec22a6..bdb3c208 100644 --- a/scripts/train_by_steps.sh +++ b/scripts/train_by_steps.sh @@ -5,7 +5,7 @@ SOURCE="data/raw/allMeSH_2021.jsonl" # If you have already preprocessed the data, you will have a folder. Use the folder instead. # SOURCE="output_folder_from_preprocessing" -# In that case, `test-size`, `train-years` and `test-years` will be ignored. +# In that case, `test-size`, `train-years` and `test-years` will be taken from the preprocessed folder grants-tagger train bertmesh \ "" \ @@ -19,6 +19,7 @@ grants-tagger train bertmesh \ --num_train_epochs 1 \ --learning_rate 5e-5 \ --dropout 0.1 \ + --hidden_size 1024 \ --warmup_steps 1000 \ --fp16 \ --torch_compile \ From d2a5c10de8c2f1c46191919bfdf09f3af4881241 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 7 Aug 2023 10:51:10 +0100 Subject: [PATCH 139/300] Adds hidden size to train params --- grants_tagger_light/training/train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index b2b8bb70..c56cd198 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -95,7 +95,7 @@ def train_bertmesh( logger.info(f"Loading `{model_args.pretrained_model_key}` tokenizer...") config = AutoConfig.from_pretrained(model_args.pretrained_model_key) - overwritten_params = { + config.update({ "pretrained_model": model_args.pretrained_model_key, "num_labels": len(label2id), "hidden_size": model_args.hidden_size, @@ -104,10 +104,10 @@ def train_bertmesh( "label2id": label2id, "id2label": {v: k for k, v in label2id.items()}, "freeze_backbone": model_args.freeze_backbone, - } - logger.info(f"Updating model params:\n{overwritten_params}") + }) + logger.info(f"Hidden size: {config['hidden_size']}") + logger.info(f"Dropout: {config['dropout']}") - config.update(overwritten_params) model = BertMesh(config) else: From 38e9026907c8579fbeac87afde686a1adedd8a99 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 7 Aug 2023 10:52:42 +0100 Subject: [PATCH 140/300] Adds hidden size to train params --- grants_tagger_light/training/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index c56cd198..9cecef49 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -105,8 +105,8 @@ def train_bertmesh( "id2label": {v: k for k, v in label2id.items()}, "freeze_backbone": model_args.freeze_backbone, }) - logger.info(f"Hidden size: {config['hidden_size']}") - logger.info(f"Dropout: {config['dropout']}") + logger.info(f"Hidden size: {config.hidden_size}") + logger.info(f"Dropout: {config.dropout}") model = BertMesh(config) From f5e256831f22258e1159c5d307a82eeab5a19b61 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 7 Aug 2023 22:44:23 +0100 Subject: [PATCH 141/300] Fixes bug with number of rows as test size --- grants_tagger_light/preprocessing/preprocess_mesh.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 1aed5fa5..e99d00e3 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -179,6 +179,8 @@ def preprocess_mesh( logger.info(f"Using the whole dataset of {test_size} rows") dset = DatasetDict({'train': train_dset, 'test': test_dset}) else: + if test_size > 1.0: + test_size = int(test_size) logger.info(f"Using a test_size frac or number of rows of of {test_size}") dset = DatasetDict({'train': train_dset, 'test': test_dset.train_test_split(test_size)['test']}) From 7b1f213733c95afa3d80c985c6a6b2ae39c87ff1 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Tue, 8 Aug 2023 11:00:49 +0100 Subject: [PATCH 142/300] Adds multilabel_attention and freeze_backbone --- grants_tagger_light/training/train.py | 27 ++++++++++++++++++++------- scripts/train_by_steps.sh | 4 +++- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 9cecef49..c89ecce3 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -89,7 +89,7 @@ def train_bertmesh( train_dset = Sharding(num_shards=shards).shard(train_dset) if not model_key: - logger.info("No model key provided. Training model from scratch") + logger.info(f"Model key not found. Training from scratch {model_args.pretrained_model_key}") # Instantiate model from scratch logger.info(f"Loading `{model_args.pretrained_model_key}` tokenizer...") @@ -107,6 +107,8 @@ def train_bertmesh( }) logger.info(f"Hidden size: {config.hidden_size}") logger.info(f"Dropout: {config.dropout}") + logger.info(f"Multilabel Attention: {config.multilabel_attention}") + logger.info(f"Freeze Backbone: {config.freeze_backbone}") model = BertMesh(config) @@ -188,9 +190,12 @@ def train_bertmesh_cli( help="Path to allMeSH_2021.jsonl (or similar) " "or to a folder after preprocessing and saving to disk", ), - test_size: float = typer.Option(None, help="Fraction of data to use for testing (0,1] or number of rows"), + test_size: float = typer.Option( + None, + help="Fraction of data to use for testing (0,1] or number of rows"), num_proc: int = typer.Option( - os.cpu_count(), help="Number of processes to use for preprocessing" + os.cpu_count(), + help="Number of processes to use for preprocessing" ), max_samples: int = typer.Option( -1, @@ -205,10 +210,18 @@ def train_bertmesh_cli( None, help="Name of the checkpoint to resume training" ), - tags: str = typer.Option(None, help="Comma-separated tags you want to include in the dataset " - "(the rest will be discarded)"), - train_years: str = typer.Option(None, help="Comma-separated years you want to include in the training dataset"), - test_years: str = typer.Option(None, help="Comma-separated years you want to include in the test dataset") + tags: str = typer.Option( + None, + help="Comma-separated tags you want to include in the dataset " + "(the rest will be discarded)"), + train_years: str = typer.Option( + None, + help="Comma-separated years you want to include in the training dataset" + ), + test_years: str = typer.Option( + None, + help="Comma-separated years you want to include in the test dataset" + ) ): parser = HfArgumentParser( ( diff --git a/scripts/train_by_steps.sh b/scripts/train_by_steps.sh index bdb3c208..bdf62720 100644 --- a/scripts/train_by_steps.sh +++ b/scripts/train_by_steps.sh @@ -16,7 +16,9 @@ grants-tagger train bertmesh \ --test-years 2020,2021 \ --per_device_train_batch_size 32 \ --per_device_eval_batch_size 1 \ - --num_train_epochs 1 \ + --multilabel_attention True \ + --freeze_backbone False \ + --num_train_epochs 5 \ --learning_rate 5e-5 \ --dropout 0.1 \ --hidden_size 1024 \ From 1f14e33de4b13e8f4b19d1341c19a7b05590b199 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Tue, 8 Aug 2023 11:04:18 +0100 Subject: [PATCH 143/300] Adds multilabel_attention and freeze_backbone --- scripts/resume_train_by_epoch.sh | 2 ++ scripts/resume_train_by_steps.sh | 4 +++- scripts/train_by_epochs.sh | 2 ++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/scripts/resume_train_by_epoch.sh b/scripts/resume_train_by_epoch.sh index 9f1a22f6..4556518f 100644 --- a/scripts/resume_train_by_epoch.sh +++ b/scripts/resume_train_by_epoch.sh @@ -16,6 +16,8 @@ grants-tagger train bertmesh \ --ignore_data_skip=True \ --per_device_train_batch_size 32 \ --per_device_eval_batch_size 1 \ + --multilabel_attention True \ + --freeze_backbone False \ --num_train_epochs 5 \ --learning_rate 5e-5 \ --dropout 0.1 \ diff --git a/scripts/resume_train_by_steps.sh b/scripts/resume_train_by_steps.sh index 846b16f4..a27de4dc 100644 --- a/scripts/resume_train_by_steps.sh +++ b/scripts/resume_train_by_steps.sh @@ -16,7 +16,9 @@ grants-tagger train bertmesh \ --ignore_data_skip=True \ --per_device_train_batch_size 32 \ --per_device_eval_batch_size 1 \ - --num_train_epochs 1 \ + --multilabel_attention True \ + --freeze_backbone False \ + --num_train_epochs 5 \ --learning_rate 5e-5 \ --dropout 0.1 \ --hidden_size 1024 \ diff --git a/scripts/train_by_epochs.sh b/scripts/train_by_epochs.sh index 0f8171be..ee64da00 100644 --- a/scripts/train_by_epochs.sh +++ b/scripts/train_by_epochs.sh @@ -16,6 +16,8 @@ grants-tagger train bertmesh \ --test-years 2020,2021 \ --per_device_train_batch_size 32 \ --per_device_eval_batch_size 1 \ + --multilabel_attention True \ + --freeze_backbone False \ --num_train_epochs 5 \ --learning_rate 5e-5 \ --dropout 0.1 \ From 51c43a82a157682872b92b803494308469850509 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 10 Aug 2023 12:44:52 +0100 Subject: [PATCH 144/300] Adds some debug info --- grants_tagger_light/models/bert_mesh/model.py | 25 ++++++++++++++++--- grants_tagger_light/training/train.py | 3 +++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/grants_tagger_light/models/bert_mesh/model.py b/grants_tagger_light/models/bert_mesh/model.py index 4acc82d6..65451505 100644 --- a/grants_tagger_light/models/bert_mesh/model.py +++ b/grants_tagger_light/models/bert_mesh/model.py @@ -2,6 +2,7 @@ from transformers.modeling_outputs import SequenceClassifierOutput import torch import torch.nn.functional as F +from loguru import logger class MultiLabelAttention(torch.nn.Module): @@ -25,23 +26,37 @@ def __init__( config, ): super().__init__(config=config) - self.config.auto_map = {"AutoModel": "model.BertMesh"} self.pretrained_model = self.config.pretrained_model + logger.info(f"Pretrained model: {self.pretrained_model}") self.num_labels = self.config.num_labels + logger.info(f"Num labels: {self.num_labels}") + self.hidden_size = getattr(self.config, "hidden_size", 512) + logger.info(f"Hidden Size: {self.hidden_size}") + self.dropout = getattr(self.config, "dropout", 0.1) + logger.info(f"Dropout: {self.dropout}") + self.multilabel_attention = getattr(self.config, "multilabel_attention", False) + + logger.info(f"Multilabel attention: {self.multilabel_attention}") + self.id2label = self.config.id2label self.bert = AutoModel.from_pretrained(self.pretrained_model) # 768 self.multilabel_attention_layer = MultiLabelAttention( 768, self.num_labels ) # num_labels, 768 - self.linear_1 = torch.nn.Linear(768, self.hidden_size) # num_labels, 512 - self.linear_2 = torch.nn.Linear(self.hidden_size, 1) # num_labels, 1 + logger.info(f"multilabel_attention_layer: {self.multilabel_attention_layer}") + self.linear_1 = torch.nn.Linear(768, self.hidden_size) # 768, 1024 + logger.info(f"linear_1: {self.linear_1}") + self.linear_2 = torch.nn.Linear(self.hidden_size, 1) # 1024, 1 + logger.info(f"linear_2: {self.linear_2}") self.linear_out = torch.nn.Linear(self.hidden_size, self.num_labels) + logger.info(f"linear_out: {self.linear_out}") self.dropout_layer = torch.nn.Dropout(self.dropout) + logger.info(f"dropout_layer: {self.dropout_layer}") def freeze_backbone(self): for param in self.bert.parameters(): @@ -57,6 +72,7 @@ def forward(self, input_ids, labels=None, **kwargs): input_ids = torch.tensor(input_ids) if self.multilabel_attention: + logger.info(f"Forward: multilabel_attention!") hidden_states = self.bert(input_ids=input_ids)[0] attention_outs = self.multilabel_attention_layer(hidden_states) outs = torch.nn.functional.relu(self.linear_1(attention_outs)) @@ -64,6 +80,7 @@ def forward(self, input_ids, labels=None, **kwargs): outs = self.linear_2(outs) outs = torch.flatten(outs, start_dim=1) else: + logger.info(f"Forward: non-multilabel_attention!") cls = self.bert(input_ids=input_ids)[1] outs = torch.nn.functional.relu(self.linear_1(cls)) outs = self.dropout_layer(outs) @@ -73,7 +90,7 @@ def forward(self, input_ids, labels=None, **kwargs): loss = F.binary_cross_entropy_with_logits(outs, labels.float()) else: loss = -1 - + logger.info(f"loss: {loss}") return SequenceClassifierOutput( loss=loss, logits=outs, diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index c89ecce3..f41f2134 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -109,6 +109,7 @@ def train_bertmesh( logger.info(f"Dropout: {config.dropout}") logger.info(f"Multilabel Attention: {config.multilabel_attention}") logger.info(f"Freeze Backbone: {config.freeze_backbone}") + logger.info(f"Num labels: {config.num_labels}") model = BertMesh(config) @@ -119,6 +120,8 @@ def train_bertmesh( if model_args.freeze_backbone: logger.info("Freezing backbone") model.freeze_backbone() + else: + model.unfreeze_backbone() def sklearn_metrics(prediction: EvalPrediction): y_pred = prediction.predictions From 188ccc1cc84fc23f7e6f6b56a838f3ae1a4dcfe0 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 10 Aug 2023 12:53:46 +0100 Subject: [PATCH 145/300] Removes forward logs --- grants_tagger_light/models/bert_mesh/model.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/grants_tagger_light/models/bert_mesh/model.py b/grants_tagger_light/models/bert_mesh/model.py index 65451505..cdff897a 100644 --- a/grants_tagger_light/models/bert_mesh/model.py +++ b/grants_tagger_light/models/bert_mesh/model.py @@ -72,7 +72,6 @@ def forward(self, input_ids, labels=None, **kwargs): input_ids = torch.tensor(input_ids) if self.multilabel_attention: - logger.info(f"Forward: multilabel_attention!") hidden_states = self.bert(input_ids=input_ids)[0] attention_outs = self.multilabel_attention_layer(hidden_states) outs = torch.nn.functional.relu(self.linear_1(attention_outs)) @@ -80,7 +79,6 @@ def forward(self, input_ids, labels=None, **kwargs): outs = self.linear_2(outs) outs = torch.flatten(outs, start_dim=1) else: - logger.info(f"Forward: non-multilabel_attention!") cls = self.bert(input_ids=input_ids)[1] outs = torch.nn.functional.relu(self.linear_1(cls)) outs = self.dropout_layer(outs) @@ -90,7 +88,6 @@ def forward(self, input_ids, labels=None, **kwargs): loss = F.binary_cross_entropy_with_logits(outs, labels.float()) else: loss = -1 - logger.info(f"loss: {loss}") return SequenceClassifierOutput( loss=loss, logits=outs, From 08193dda25b94328b40e3462bb15c82e867ffc59 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 10 Aug 2023 14:50:48 +0100 Subject: [PATCH 146/300] Checking last implementation of BertMesh --- grants_tagger_light/models/bert_mesh/model.py | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/grants_tagger_light/models/bert_mesh/model.py b/grants_tagger_light/models/bert_mesh/model.py index cdff897a..a3077f20 100644 --- a/grants_tagger_light/models/bert_mesh/model.py +++ b/grants_tagger_light/models/bert_mesh/model.py @@ -67,7 +67,7 @@ def unfreeze_backbone(self): param.requires_grad = True def forward(self, input_ids, labels=None, **kwargs): - if type(input_ids) is list: + """if type(input_ids) is list: # coming from tokenizer input_ids = torch.tensor(input_ids) @@ -82,12 +82,29 @@ def forward(self, input_ids, labels=None, **kwargs): cls = self.bert(input_ids=input_ids)[1] outs = torch.nn.functional.relu(self.linear_1(cls)) outs = self.dropout_layer(outs) - outs = self.linear_out(outs) + outs = self.linear_out(outs)""" + + if type(input_ids) is list: + # coming from tokenizer + input_ids = torch.tensor(input_ids) + if self.multilabel_attention: + hidden_states = self.bert(input_ids=input_ids)[0] + attention_outs = self.multilabel_attention_layer(hidden_states) + outs = torch.nn.functional.relu(self.linear_1(attention_outs)) + outs = self.dropout_layer(outs) + outs = torch.sigmoid(self.linear_2(outs)) + outs = torch.flatten(outs, start_dim=1) + else: + cls = self.bert(input_ids=input_ids)[1] + outs = torch.nn.functional.relu(self.linear_1(cls)) + outs = self.dropout_layer(outs) + outs = torch.sigmoid(self.linear_out(outs)) if labels is not None: - loss = F.binary_cross_entropy_with_logits(outs, labels.float()) + loss = F.binary_cross_entropy(outs, labels.float()) else: loss = -1 + return SequenceClassifierOutput( loss=loss, logits=outs, From b6268da3529bb121e394cbbbd714519b67812307 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 10 Aug 2023 15:20:56 +0100 Subject: [PATCH 147/300] Roll back early implementation of BertMesh --- grants_tagger_light/models/bert_mesh/model.py | 23 +++---------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/grants_tagger_light/models/bert_mesh/model.py b/grants_tagger_light/models/bert_mesh/model.py index a3077f20..cdff897a 100644 --- a/grants_tagger_light/models/bert_mesh/model.py +++ b/grants_tagger_light/models/bert_mesh/model.py @@ -67,7 +67,7 @@ def unfreeze_backbone(self): param.requires_grad = True def forward(self, input_ids, labels=None, **kwargs): - """if type(input_ids) is list: + if type(input_ids) is list: # coming from tokenizer input_ids = torch.tensor(input_ids) @@ -82,29 +82,12 @@ def forward(self, input_ids, labels=None, **kwargs): cls = self.bert(input_ids=input_ids)[1] outs = torch.nn.functional.relu(self.linear_1(cls)) outs = self.dropout_layer(outs) - outs = self.linear_out(outs)""" - - if type(input_ids) is list: - # coming from tokenizer - input_ids = torch.tensor(input_ids) - if self.multilabel_attention: - hidden_states = self.bert(input_ids=input_ids)[0] - attention_outs = self.multilabel_attention_layer(hidden_states) - outs = torch.nn.functional.relu(self.linear_1(attention_outs)) - outs = self.dropout_layer(outs) - outs = torch.sigmoid(self.linear_2(outs)) - outs = torch.flatten(outs, start_dim=1) - else: - cls = self.bert(input_ids=input_ids)[1] - outs = torch.nn.functional.relu(self.linear_1(cls)) - outs = self.dropout_layer(outs) - outs = torch.sigmoid(self.linear_out(outs)) + outs = self.linear_out(outs) if labels is not None: - loss = F.binary_cross_entropy(outs, labels.float()) + loss = F.binary_cross_entropy_with_logits(outs, labels.float()) else: loss = -1 - return SequenceClassifierOutput( loss=loss, logits=outs, From 127ab20130c3859bfc689082c69f99d1fbb62e53 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 11 Aug 2023 11:00:58 +0100 Subject: [PATCH 148/300] Fixing bug with number of rows --- grants_tagger_light/preprocessing/preprocess_mesh.py | 2 ++ grants_tagger_light/training/train.py | 1 + 2 files changed, 3 insertions(+) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index e99d00e3..ca072df5 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -188,6 +188,8 @@ def preprocess_mesh( if test_size is None: test_size = 0.05 logger.info(f"Test size not found. Setting it to a frac of the whole dataset equal to {test_size}") + elif test_size > 1.0: + test_size = int(test_size) dset = dset.train_test_split(test_size=test_size) logger.info("Time taken to split into train and test: {}".format(time.time() - t1)) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index f41f2134..b1d44580 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -121,6 +121,7 @@ def train_bertmesh( logger.info("Freezing backbone") model.freeze_backbone() else: + logger.info("Unfreezing backbone") model.unfreeze_backbone() def sklearn_metrics(prediction: EvalPrediction): From aeaa75254ccd606d6be5d07c9472274071f7d159 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 14 Aug 2023 21:15:23 +0100 Subject: [PATCH 149/300] Adds OpenAI augmentation --- grants_tagger_light/augmentation/__init__.py | 8 + grants_tagger_light/augmentation/augment.py | 169 ++++++++++++++++++ .../augmentation/augment_openai.py | 52 ++++++ .../augmentation/prompt.template | 17 ++ grants_tagger_light/cli.py | 5 +- .../preprocessing/preprocess_mesh.py | 35 ++-- poetry.lock | 24 ++- pyproject.toml | 1 + 8 files changed, 298 insertions(+), 13 deletions(-) create mode 100644 grants_tagger_light/augmentation/__init__.py create mode 100644 grants_tagger_light/augmentation/augment.py create mode 100644 grants_tagger_light/augmentation/augment_openai.py create mode 100644 grants_tagger_light/augmentation/prompt.template diff --git a/grants_tagger_light/augmentation/__init__.py b/grants_tagger_light/augmentation/__init__.py new file mode 100644 index 00000000..81ed5b19 --- /dev/null +++ b/grants_tagger_light/augmentation/__init__.py @@ -0,0 +1,8 @@ +import typer +from .augment import augment_cli + +augment_app = typer.Typer() +augment_app.command( + "mesh", + context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, +)(augment_cli) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py new file mode 100644 index 00000000..84720dae --- /dev/null +++ b/grants_tagger_light/augmentation/augment.py @@ -0,0 +1,169 @@ +import json +import multiprocessing +import os +import random + +import typer +from loguru import logger +from datasets import load_dataset +import numpy as np +import datetime + + +from grants_tagger_light.augmentation.augment_openai import AugmentOpenAI +from grants_tagger_light.utils.years_tags_parser import parse_years + +augment_app = typer.Typer() + + +def count_elements_in_sublist(sublist): + element_count = {} + for element in sublist: + if element in element_count: + element_count[element] += 1 + else: + element_count[element] = 1 + return element_count + + +def merge_dicts(dict_list): + merged_dict = {} + for d in dict_list: + for key, value in d.items(): + if key in merged_dict: + merged_dict[key] += value + else: + merged_dict[key] = value + return merged_dict + + +def augment( + data_path: str, + save_to_path: str, + model_key: str = 'gpt-3.5-turbo', + num_proc: int = os.cpu_count(), + train_years: list = None, + test_years: list = None, + min_examples: int = 25, + prompt_template: str = 'grants_tagger_light/augmentation/prompt.template', + few_shot_examples: int = 5 +): + if model_key.strip().lower().startswith('gpt-3.5-turbo') or \ + model_key.strip().lower().startswith('text-davinci') or \ + model_key.strip().lower().startswith('gpt-4'): + augment_engine = AugmentOpenAI(prompt_template_path=prompt_template, model_key=model_key) + else: + raise NotImplementedError(f"{model_key} not implemented as an augmentation framework") + + # We only have 1 file, so no sharding is available https://huggingface.co/docs/datasets/loading#multiprocessing + dset = load_dataset("json", data_files=data_path, num_proc=1) + # By default, any dataset loaded is set to 'train' using the previous command + if "train" in dset: + dset = dset["train"] + + if train_years is not None and len(train_years) > 0: + dset = dset.filter(lambda x: any(np.isin(train_years, [str(x["year"])])), num_proc=num_proc) + if test_years is not None and len(test_years) > 0: + dset = dset.filter(lambda x: not any(np.isin(test_years, [str(x["year"])])), num_proc=num_proc) + + logger.info("Obtaining count values from the labels...") + pool = multiprocessing.Pool(processes=num_proc) + element_counts_list = pool.map(count_elements_in_sublist, dset['meshMajor']) + pool.close() + pool.join() + + merged_element_counts = merge_dicts(element_counts_list) + sorted_merged_element_counts = sorted(merged_element_counts.items(), key=lambda x: x[1], reverse=True) + sorted_merged_element_counts_dict = dict(sorted_merged_element_counts) + + with open(f"{save_to_path}.count", 'w') as f: + f.write(json.dumps(sorted_merged_element_counts_dict, indent=2)) + + tags_to_augment_counts = {k: v for k, v in sorted_merged_element_counts_dict.items() if v < min_examples} + tags_to_augment = [k for k, v in sorted_merged_element_counts_dict.items() if v < min_examples] + + biggest_tags_to_augment = [f"{k}({sorted_merged_element_counts_dict[k]})" for k in tags_to_augment[:5]] + smallest_tags_to_augment = [f"{k}({sorted_merged_element_counts_dict[k]})" for k in tags_to_augment[-5:]] + logger.info(f"Augmenting a total of {len(tags_to_augment)} tags, from {biggest_tags_to_augment} to " + f"{smallest_tags_to_augment}") + + logger.info(f"Collecting existing examples of those tags to send in the prompt") + dset = dset.filter(lambda x: any(np.isin(tags_to_augment, x["meshMajor"])), num_proc=num_proc) + + counter = 0 + with open(save_to_path, 'w') as f: + for t in tags_to_augment: + if tags_to_augment_counts[t] < min_examples: + tmp_dset = dset.filter(lambda x: any(np.isin([t], x["meshMajor"])), num_proc=num_proc) + missing = min_examples - len(tmp_dset) + logger.info(f"Generating {missing} examples for class {t}") + for a in augment_engine.generate(t, tmp_dset, n=missing, few_shot_examples=few_shot_examples): + if a is None: + break + f.write(json.dumps({ + "journal": model_key, + "meshMajor": a['tags'], + "year": [random.choice(train_years) if len(train_years) > 0 else datetime.date.today().year], + "abstractText": a['abstract'], + "pmid": str(counter), + "title": a['title'] + })) + f.write('\n') + f.flush() + counter += 1 + + +@augment_app.command() +def augment_cli( + data_path: str = typer.Argument( + ..., + help="Path to mesh.jsonl"), + save_to_path: str = typer.Argument( + ..., help="Path to save the serialized PyArrow dataset after preprocessing" + ), + model_key: str = typer.Option( + "gpt-3.5-turbo", + help="LLM to use data augmentation. By now, only `openai` is supported" + ), + num_proc: int = typer.Option( + os.cpu_count(), + help="Number of processes to use for data augmentation" + ), + train_years: str = typer.Option( + None, + help="If set, Comma-separated years you want to include in the data augmentation process" + ), + test_years: str = typer.Option( + None, + help="If set, Comma-separated years you want to exclude in the data augmentation process" + ), + min_examples: int = typer.Option( + 25, + help="If set, Comma-separated years you want to exclude in the data augmentation process" + ), + prompt_template: str = typer.Option( + 'grants_tagger_light/augmentation/prompt.template', + help="File to use as a prompt. Make sure to ask the LLM to return a dict with two fields: `abstract` and `tags`" + ), + few_shot_examples: int = typer.Option( + 5, + help="If available, try to send this number of examples to the LLM so that it can generate better abstracts" + ), +): + if not data_path.endswith("jsonl"): + logger.error( + "It seems your input MeSH data is not in `jsonl` format. " + "Please, run first `scripts/mesh_json_to_jsonlpy.`" + ) + exit(-1) + + augment(data_path, + save_to_path, + model_key=model_key, + num_proc=num_proc, + train_years=parse_years(train_years), + test_years=parse_years(test_years), + min_examples=min_examples, + prompt_template=prompt_template, + few_shot_examples=few_shot_examples + ) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py new file mode 100644 index 00000000..9894fa76 --- /dev/null +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -0,0 +1,52 @@ +import json +import os +from loguru import logger +import openai + + +class AugmentOpenAI: + def __init__(self, prompt_template_path, model_key='gpt-3.5-turbo'): + if 'OPENAI_API_KEY' not in os.environ: + logger.error("OPENAI_API_KEY not found in env vars. Please define it before running this program.") + with open(prompt_template_path, 'r') as f: + self.prompt_template = f.read() + self.model_key = model_key + + def generate(self, featured_tag, dataset, n=1, few_shot_examples=10, temperature=1.5, top_p=1, frequence_penalty=0, + presence_penalty=0): + size = min(len(dataset), few_shot_examples) + dataset = dataset[:size] + abstracts = "\n".join(dataset['abstractText']) + tags = [] + for x in dataset['meshMajor']: + tags.extend(x) + mesh_tags = ",".join(list(set(tags))) + + prompt = self.prompt_template.replace('{FEATURED_TAG}', featured_tag) + prompt = prompt.replace('{ABSTRACTS}', abstracts) + prompt = prompt.replace('{MESH_TAGS}', mesh_tags) + + response = openai.ChatCompletion.create( + model=self.model_key, + messages=[ + {"role": "user", "content": prompt}], + n=n, + temperature=temperature, + top_p=top_p, + frequency_penalty=frequence_penalty, + presence_penalty=presence_penalty + ) + for r in response['choices']: + if 'message' in r: + if 'content' in r['message']: + print(r['message']['content']) + try: + r_json = json.loads(r['message']['content']) + a = r_json['abstract'] + # Make sure it does not hallucinate and adds anything new which may not be a MeSH tag + t = [x for x in r_json['tags'] if x in tags] + tl = r_json['title'] + yield {'abstract': a, 'tags': t, 'title': tl} + except Exception as e: + logger.info("OpenAI did not return a proper json format.") + yield None diff --git a/grants_tagger_light/augmentation/prompt.template b/grants_tagger_light/augmentation/prompt.template new file mode 100644 index 00000000..ac4376be --- /dev/null +++ b/grants_tagger_light/augmentation/prompt.template @@ -0,0 +1,17 @@ +You are in charge of doing Data Augmentation. I will provide: +1) a REQUIRED MESH TAG; +2) a series of ABSTRACTS to use as a source for augmentation; +3) a series of MESH TAGS. + +You need to produce one new article in a json, with the following fields: +1) The field 'abstract', with a NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. +2) The field 'title', with a small title summarizing the NEW ABSTRACT; +3) The field 'tags', with list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. + +REQUIRED MESH TAG: {FEATURED_TAG} + +ABSTRACTS: +{ABSTRACTS} + +MESH TAGS: +{MESH_TAGS} \ No newline at end of file diff --git a/grants_tagger_light/cli.py b/grants_tagger_light/cli.py index 2a9585a9..62b9cf2a 100644 --- a/grants_tagger_light/cli.py +++ b/grants_tagger_light/cli.py @@ -2,6 +2,7 @@ import typer +from grants_tagger_light.augmentation import augment_app from grants_tagger_light.download_epmc import download_epmc_cli from grants_tagger_light.evaluation import evaluate_app from grants_tagger_light.predict import predict_cli @@ -12,11 +13,13 @@ logger = logging.getLogger(__name__) -app = typer.Typer() +app = typer.Typer(pretty_exceptions_enable=False) app.add_typer(preprocess_app, name="preprocess") +app.add_typer(augment_app, name="augment") app.add_typer(evaluate_app, name="evaluate") + app.command("predict")(predict_cli) tune_app = typer.Typer() diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index ca072df5..344faf63 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -1,5 +1,4 @@ import json -import math import tempfile import typer @@ -94,11 +93,11 @@ def preprocess_mesh( if len(years) > 0: logger.info(f"Removing all years which are not in {years}") - dset = dset.filter(lambda x: any(np.isin(years, [str(x["year"])]))) + dset = dset.filter(lambda x: any(np.isin(years, [str(x["year"])])), num_proc=num_proc) if len(tags) > 0: logger.info(f"Removing all tags which are not in {tags}") - dset = dset.filter(lambda x: any(np.isin(tags, x["meshMajor"]))) + dset = dset.filter(lambda x: any(np.isin(tags, x["meshMajor"])), num_proc=num_proc) # Remove unused columns to save space & time dset = dset.remove_columns(["journal", "pmid", "title"]) @@ -207,28 +206,42 @@ def preprocess_mesh( @preprocess_app.command() def preprocess_mesh_cli( - data_path: str = typer.Argument(..., help="Path to mesh.jsonl"), + data_path: str = typer.Argument( + ..., + help="Path to mesh.jsonl"), save_to_path: str = typer.Argument( - ..., help="Path to save the serialized PyArrow dataset after preprocessing" + ..., + help="Path to save the serialized PyArrow dataset after preprocessing" ), model_key: str = typer.Argument( ..., help="Key to use when loading tokenizer and label2id. " "Leave blank if training from scratch", # noqa ), - test_size: float = typer.Option(None, help="Fraction of data to use for testing in (0,1] or number of rows"), + test_size: float = typer.Option( + None, + help="Fraction of data to use for testing in (0,1] or number of rows"), num_proc: int = typer.Option( - os.cpu_count(), help="Number of processes to use for preprocessing" + os.cpu_count(), + help="Number of processes to use for preprocessing" ), max_samples: int = typer.Option( -1, help="Maximum number of samples to use for preprocessing", ), - batch_size: int = typer.Option(256, help="Size of the preprocessing batch"), - tags: str = typer.Option(None, help="Comma-separated tags you want to include in the dataset " + batch_size: int = typer.Option( + 256, + help="Size of the preprocessing batch"), + tags: str = typer.Option( + None, + help="Comma-separated tags you want to include in the dataset " "(the rest will be discarded)"), - train_years: str = typer.Option(None, help="Comma-separated years you want to include in the training dataset"), - test_years: str = typer.Option(None, help="Comma-separated years you want to include in the test dataset"), + train_years: str = typer.Option( + None, + help="Comma-separated years you want to include in the training dataset"), + test_years: str = typer.Option( + None, + help="Comma-separated years you want to include in the test dataset"), ): if not data_path.endswith("jsonl"): diff --git a/poetry.lock b/poetry.lock index 72efc107..7cad6917 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2167,6 +2167,28 @@ files = [ antlr4-python3-runtime = "==4.9.*" PyYAML = ">=5.1.0" +[[package]] +name = "openai" +version = "0.27.8" +description = "Python client library for the OpenAI API" +optional = false +python-versions = ">=3.7.1" +files = [ + {file = "openai-0.27.8-py3-none-any.whl", hash = "sha256:e0a7c2f7da26bdbe5354b03c6d4b82a2f34bd4458c7a17ae1a7092c3e397e03c"}, + {file = "openai-0.27.8.tar.gz", hash = "sha256:2483095c7db1eee274cebac79e315a986c4e55207bb4fa7b82d185b3a2ed9536"}, +] + +[package.dependencies] +aiohttp = "*" +requests = ">=2.20" +tqdm = "*" + +[package.extras] +datalib = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] +dev = ["black (>=21.6b0,<22.0)", "pytest (==6.*)", "pytest-asyncio", "pytest-mock"] +embeddings = ["matplotlib", "numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "plotly", "scikit-learn (>=1.0.2)", "scipy", "tenacity (>=8.0.1)"] +wandb = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "wandb"] + [[package]] name = "orjson" version = "3.9.2" @@ -4123,4 +4145,4 @@ test = ["zope.testing"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "1cd7926938b8bd8bcfec710786ecc7b248c23c72c6fe01d87b28dc861b12f445" +content-hash = "918a9fa5785312d012f5e08f71753e37f8ad175efb277ccb0de1fb4040b65a44" diff --git a/pyproject.toml b/pyproject.toml index 25ae415f..f8d942a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ transformers = "4.29.2" libpecos = "^1.0.0" loguru = "^0.7.0" wandb = "^0.15.4" +openai = "0.27.8" [tool.poetry.group.dev] From 9b71dd3795a0adfcda5f8560c1c78a601a9f0ddb Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 17:52:32 +0100 Subject: [PATCH 150/300] Parallel augmentation prototype --- README.md | 2 +- grants_tagger_light/augmentation/augment.py | 82 ++++++++++++------- .../augmentation/augment_openai.py | 81 +++++++++++------- grants_tagger_light/training/train.py | 2 +- poetry.lock | 46 ++++++++++- pyproject.toml | 1 + 6 files changed, 152 insertions(+), 62 deletions(-) diff --git a/README.md b/README.md index 0bfeb820..8289bf13 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ Lightweight repository for grant tagger model deployment and inference. Adapted from [the original repository](https://github.com/wellcometrust/grants_tagger) Grants tagger is a machine learning powered tool that -assigns biomedical related tags to grant proposals. +assigns biomedical-related tags to grant proposals. Those tags can be custom to the organisation or based upon a preexisting ontology like MeSH. diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 84720dae..ce9025e1 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -16,7 +16,7 @@ augment_app = typer.Typer() -def count_elements_in_sublist(sublist): +def _count_elements_in_sublist(sublist): element_count = {} for element in sublist: if element in element_count: @@ -26,7 +26,7 @@ def count_elements_in_sublist(sublist): return element_count -def merge_dicts(dict_list): +def _merge_dicts(dict_list): merged_dict = {} for d in dict_list: for key, value in d.items(): @@ -37,6 +37,26 @@ def merge_dicts(dict_list): return merged_dict +def _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, + augmentation_engine): + logger.info(f"Generating {missing} examples for class {tag}") + counter = 0 + with open(save_to_path, 'a') as f: + for a in augmentation_engine.generate(collect_concurrent_calls, dset, few_shot_examples=few_shot_examples): + if a is None: + break + f.write(json.dumps({ + "journal": model_key, + "meshMajor": a['tags'], + "year": [random.choice(train_years) if len(train_years) > 0 else datetime.date.today().year], + "abstractText": a['abstract'], + "pmid": f"augmented_{tag}_{counter}", + "title": a['title'] + })) + f.write('\n') + f.flush() + counter += 1 + def augment( data_path: str, save_to_path: str, @@ -44,14 +64,15 @@ def augment( num_proc: int = os.cpu_count(), train_years: list = None, test_years: list = None, - min_examples: int = 25, + min_examples: int = 15, prompt_template: str = 'grants_tagger_light/augmentation/prompt.template', - few_shot_examples: int = 5 + few_shot_examples: int = 5, + concurrent_calls: int = 5 ): if model_key.strip().lower().startswith('gpt-3.5-turbo') or \ model_key.strip().lower().startswith('text-davinci') or \ model_key.strip().lower().startswith('gpt-4'): - augment_engine = AugmentOpenAI(prompt_template_path=prompt_template, model_key=model_key) + augmentation_engine = AugmentOpenAI(prompt_template_path=prompt_template, model_key=model_key) else: raise NotImplementedError(f"{model_key} not implemented as an augmentation framework") @@ -68,11 +89,11 @@ def augment( logger.info("Obtaining count values from the labels...") pool = multiprocessing.Pool(processes=num_proc) - element_counts_list = pool.map(count_elements_in_sublist, dset['meshMajor']) + element_counts_list = pool.map(_count_elements_in_sublist, dset['meshMajor']) pool.close() pool.join() - merged_element_counts = merge_dicts(element_counts_list) + merged_element_counts = _merge_dicts(element_counts_list) sorted_merged_element_counts = sorted(merged_element_counts.items(), key=lambda x: x[1], reverse=True) sorted_merged_element_counts_dict = dict(sorted_merged_element_counts) @@ -89,28 +110,24 @@ def augment( logger.info(f"Collecting existing examples of those tags to send in the prompt") dset = dset.filter(lambda x: any(np.isin(tags_to_augment, x["meshMajor"])), num_proc=num_proc) - - counter = 0 - with open(save_to_path, 'w') as f: - for t in tags_to_augment: + dset = dset.map( + lambda _, y: {'idx': y}, + with_indices=True, + batched=True, + batch_size=batch_size, + desc="Encoding labels", + num_proc=num_proc, + ) + print(dset['idx']) + collect_concurrent_calls = [] + for t in tags_to_augment: + if len(collect_concurrent_calls) >= concurrent_calls: + _generate(collect_concurrent_calls, tag, dset, few_shot_examples, save_to_path, augmentation_engine) + else: if tags_to_augment_counts[t] < min_examples: - tmp_dset = dset.filter(lambda x: any(np.isin([t], x["meshMajor"])), num_proc=num_proc) - missing = min_examples - len(tmp_dset) - logger.info(f"Generating {missing} examples for class {t}") - for a in augment_engine.generate(t, tmp_dset, n=missing, few_shot_examples=few_shot_examples): - if a is None: - break - f.write(json.dumps({ - "journal": model_key, - "meshMajor": a['tags'], - "year": [random.choice(train_years) if len(train_years) > 0 else datetime.date.today().year], - "abstractText": a['abstract'], - "pmid": str(counter), - "title": a['title'] - })) - f.write('\n') - f.flush() - counter += 1 + missing = min_examples - tags_to_augment_counts[tag] + collect_concurrent_calls.append((tags_to_augment_counts[t], missing)) + @augment_app.command() @@ -138,7 +155,7 @@ def augment_cli( help="If set, Comma-separated years you want to exclude in the data augmentation process" ), min_examples: int = typer.Option( - 25, + 15, help="If set, Comma-separated years you want to exclude in the data augmentation process" ), prompt_template: str = typer.Option( @@ -149,6 +166,10 @@ def augment_cli( 5, help="If available, try to send this number of examples to the LLM so that it can generate better abstracts" ), + concurrent_calls: int = typer.Option( + 5, + help="Concurrent calls with 1 tag each to the different model" + ), ): if not data_path.endswith("jsonl"): logger.error( @@ -165,5 +186,6 @@ def augment_cli( test_years=parse_years(test_years), min_examples=min_examples, prompt_template=prompt_template, - few_shot_examples=few_shot_examples + few_shot_examples=few_shot_examples, + concurrent_calls=concurrent_calls ) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 9894fa76..5d89c8e6 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -2,6 +2,7 @@ import os from loguru import logger import openai +from openai_multi_client import OpenAIMultiClient class AugmentOpenAI: @@ -11,42 +12,64 @@ def __init__(self, prompt_template_path, model_key='gpt-3.5-turbo'): with open(prompt_template_path, 'r') as f: self.prompt_template = f.read() self.model_key = model_key + self.api = OpenAIMultiClient(endpoint="chats", data_template={"model": self.model_key}) - def generate(self, featured_tag, dataset, n=1, few_shot_examples=10, temperature=1.5, top_p=1, frequence_penalty=0, - presence_penalty=0): - size = min(len(dataset), few_shot_examples) - dataset = dataset[:size] - abstracts = "\n".join(dataset['abstractText']) + def _create_message(self, featured_tag, featured_tag_dset): + abstracts = "\n".join(featured_tag_dset['abstractText']) tags = [] for x in dataset['meshMajor']: tags.extend(x) + mesh_tags = ",".join(list(set(tags))) prompt = self.prompt_template.replace('{FEATURED_TAG}', featured_tag) prompt = prompt.replace('{ABSTRACTS}', abstracts) prompt = prompt.replace('{MESH_TAGS}', mesh_tags) - response = openai.ChatCompletion.create( - model=self.model_key, - messages=[ - {"role": "user", "content": prompt}], - n=n, - temperature=temperature, - top_p=top_p, - frequency_penalty=frequence_penalty, - presence_penalty=presence_penalty - ) - for r in response['choices']: - if 'message' in r: - if 'content' in r['message']: - print(r['message']['content']) - try: - r_json = json.loads(r['message']['content']) - a = r_json['abstract'] - # Make sure it does not hallucinate and adds anything new which may not be a MeSH tag - t = [x for x in r_json['tags'] if x in tags] - tl = r_json['title'] - yield {'abstract': a, 'tags': t, 'title': tl} - except Exception as e: - logger.info("OpenAI did not return a proper json format.") - yield None + return [{"role": "user", "content": prompt}] + + def _make_requests(self, collect_concurrent_calls, dset, few_shot_examples=10, temperature=1.5, top_p=1, + frequence_penalty=0, presence_penalty=0): + for num in range(len(collect_concurrent_calls)): + t = collect_concurrent_calls[num][0] + n = collect_concurrent_calls[num][1] + # RAG: I select similar articles to provide them to the LLM + tmp_dset = dset.filter(lambda x: any(np.isin([t], x["meshMajor"])), num_proc=num_proc) + # I remove them from the dataset to process to make it smaller and quicker over time + dset = dset.filter(lambda example, idx: idx not in tmp_dset['idx'], with_indices=True, + num_proc=num_proc) + size = min(len(tmp_dset), few_shot_examples) + tmp_dset = tmp_dset[:size] + + self.api.request(data={ + "model": self.model_key, + "n": n, + "temperature": temperature, + "top_p": top_p, + "frequence_penalty": frequence_penalty, + "presence_penalty": presence_penalty, + "messages": self._create_message(t, tmp_dset) + }, metadata={'num': num}) + + def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, temperature=1.5, top_p=1, + frequence_penalty=0, presence_penalty=0): + + self.api.run_request_function(self._make_requests, collect_concurrent_calls, dset, few_shot_examples, + temperature, top_p, frequence_penalty, presence_penalty) + + for response in self.api: + for r in response['choices']: + if 'message' in r: + if 'content' in r['message']: + print(r['message']['content']) + try: + r_json = json.loads(r['message']['content']) + a = r_json['abstract'] + # Make sure it does not hallucinate and adds anything new which may not be a MeSH tag + t = [x for x in r_json['tags'] if x in tags] + tl = r_json['title'] + yield {'abstract': a, 'tags': t, 'title': tl} + except Exception as e: + logger.info("OpenAI did not return a proper json format.") + yield None + diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index b1d44580..69b8c890 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -80,7 +80,7 @@ def train_bertmesh( if max_samples > 0: train_dset_size = min(max_samples, train_dset_size) logger.info(f"Training max samples: {train_dset_size}.") - train_dset.filter(lambda example, idx: idx < train_dset_size, with_indices=True) + train_dset.filter(lambda example, idx: idx < train_dset_size, with_indices=True, num_proc=num_proc) else: logger.info("Training with all data...") diff --git a/poetry.lock b/poetry.lock index 7cad6917..a98613d6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -183,6 +183,20 @@ files = [ {file = "aioitertools-0.11.0.tar.gz", hash = "sha256:42c68b8dd3a69c2bf7f2233bf7df4bb58b557bca5252ac02ed5187bbc67d6831"}, ] +[[package]] +name = "aioprocessing" +version = "2.0.1" +description = "A Python 3.5+ library that integrates the multiprocessing module with asyncio." +optional = false +python-versions = ">=3.5" +files = [ + {file = "aioprocessing-2.0.1-py3-none-any.whl", hash = "sha256:8fcac4b0108b72eb9df76e06a9d7e05720ee1e8330829d3fd53fa059879be586"}, + {file = "aioprocessing-2.0.1.tar.gz", hash = "sha256:fe01c7b1a38c78168611d3040e73d93036c3b7c8a649d636dc9ed7a3bc9b1ba2"}, +] + +[package.extras] +dill = ["multiprocess"] + [[package]] name = "aiosignal" version = "1.3.1" @@ -2189,6 +2203,22 @@ dev = ["black (>=21.6b0,<22.0)", "pytest (==6.*)", "pytest-asyncio", "pytest-moc embeddings = ["matplotlib", "numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "plotly", "scikit-learn (>=1.0.2)", "scipy", "tenacity (>=8.0.1)"] wandb = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "wandb"] +[[package]] +name = "openai-multi-client" +version = "0.1.1" +description = "A parallel client for OpenAI API (and more)" +optional = false +python-versions = ">=3.7" +files = [ + {file = "openai_multi_client-0.1.1-py3-none-any.whl", hash = "sha256:894bb336ea63cf555693a80988e830a2c4dfbe669ab2076100073a854c7d7835"}, + {file = "openai_multi_client-0.1.1.tar.gz", hash = "sha256:0433b7fa717f522ab403ba4b689ff09220e843aa2c2c1251a6f9f14545600e49"}, +] + +[package.dependencies] +aioprocessing = "*" +openai = "*" +tenacity = "*" + [[package]] name = "orjson" version = "3.9.2" @@ -3436,6 +3466,20 @@ files = [ [package.extras] widechars = ["wcwidth"] +[[package]] +name = "tenacity" +version = "8.2.3" +description = "Retry code until it succeeds" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tenacity-8.2.3-py3-none-any.whl", hash = "sha256:ce510e327a630c9e1beaf17d42e6ffacc88185044ad85cf74c0a8887c6a0f88c"}, + {file = "tenacity-8.2.3.tar.gz", hash = "sha256:5398ef0d78e63f40007c1fb4c0bff96e1911394d2fa8d194f77619c05ff6cc8a"}, +] + +[package.extras] +doc = ["reno", "sphinx", "tornado (>=4.5)"] + [[package]] name = "threadpoolctl" version = "3.2.0" @@ -4145,4 +4189,4 @@ test = ["zope.testing"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "918a9fa5785312d012f5e08f71753e37f8ad175efb277ccb0de1fb4040b65a44" +content-hash = "cb53df7c0ffa68f5c40fe32b94d716f6ee3d944bb795e67d653389a0c0070d93" diff --git a/pyproject.toml b/pyproject.toml index f8d942a9..574a313a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ libpecos = "^1.0.0" loguru = "^0.7.0" wandb = "^0.15.4" openai = "0.27.8" +openai-multi-client = "^0.1.1" [tool.poetry.group.dev] From a95e582d7134bf37802384c70bff5249b63e183e Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 18:30:32 +0100 Subject: [PATCH 151/300] Adds batch size to augment --- grants_tagger_light/augmentation/augment.py | 6 ++++++ grants_tagger_light/training/train.py | 8 ++++++++ 2 files changed, 14 insertions(+) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index ce9025e1..dece8d83 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -62,6 +62,7 @@ def augment( save_to_path: str, model_key: str = 'gpt-3.5-turbo', num_proc: int = os.cpu_count(), + batch_size: int = 64, train_years: list = None, test_years: list = None, min_examples: int = 15, @@ -146,6 +147,10 @@ def augment_cli( os.cpu_count(), help="Number of processes to use for data augmentation" ), + batch_size: int = typer.Option( + 64, + help="Preprocessing batch size (for dataset, filter, map, ...)" + ), train_years: str = typer.Option( None, help="If set, Comma-separated years you want to include in the data augmentation process" @@ -182,6 +187,7 @@ def augment_cli( save_to_path, model_key=model_key, num_proc=num_proc, + batch_size=batch_size, train_years=parse_years(train_years), test_years=parse_years(test_years), min_examples=min_examples, diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 69b8c890..95c0089a 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -153,6 +153,12 @@ def sklearn_metrics(prediction: EvalPrediction): max_steps = Sharding.calculate_max_steps(training_args, train_dset_size) training_args.max_steps = max_steps + # Instantiate the AdamW optimizer with a constant learning rate + optimizer = AdamW(model.parameters(), lr=training_args.learning_rate) # Set your desired learning rate + + # Create a learning rate scheduler + scheduler = get_constant_scheduler(optimizer, num_warmup_steps=0, num_training_steps=training_args.max_steps) + trainer = Trainer( model=model, args=training_args, @@ -160,6 +166,8 @@ def sklearn_metrics(prediction: EvalPrediction): eval_dataset=val_dset, data_collator=collator, compute_metrics=sklearn_metrics, + optimizer=optimizer, + scheduler=scheduler, ) logger.info(training_args) From 460b3da083a18917f254261962da868facb0b15b Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 18:35:25 +0100 Subject: [PATCH 152/300] Fixes bug --- grants_tagger_light/augmentation/augment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index dece8d83..507d1a30 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -123,10 +123,10 @@ def augment( collect_concurrent_calls = [] for t in tags_to_augment: if len(collect_concurrent_calls) >= concurrent_calls: - _generate(collect_concurrent_calls, tag, dset, few_shot_examples, save_to_path, augmentation_engine) + _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, augmentation_engine) else: if tags_to_augment_counts[t] < min_examples: - missing = min_examples - tags_to_augment_counts[tag] + missing = min_examples - tags_to_augment_counts[t] collect_concurrent_calls.append((tags_to_augment_counts[t], missing)) From c61d60db99429018350c99728313242ed5d908e6 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 18:47:17 +0100 Subject: [PATCH 153/300] Fixes bug --- grants_tagger_light/augmentation/augment.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 507d1a30..78291bb3 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -8,6 +8,7 @@ from datasets import load_dataset import numpy as np import datetime +import uuid from grants_tagger_light.augmentation.augment_openai import AugmentOpenAI @@ -38,8 +39,8 @@ def _merge_dicts(dict_list): def _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, - augmentation_engine): - logger.info(f"Generating {missing} examples for class {tag}") + augmentation_engine, train_years): + logger.info(f"Generating missing examples for classes...") counter = 0 with open(save_to_path, 'a') as f: for a in augmentation_engine.generate(collect_concurrent_calls, dset, few_shot_examples=few_shot_examples): @@ -50,13 +51,14 @@ def _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, "meshMajor": a['tags'], "year": [random.choice(train_years) if len(train_years) > 0 else datetime.date.today().year], "abstractText": a['abstract'], - "pmid": f"augmented_{tag}_{counter}", + "pmid": uuid.uuid4.hex(), "title": a['title'] })) f.write('\n') f.flush() counter += 1 + def augment( data_path: str, save_to_path: str, @@ -123,14 +125,13 @@ def augment( collect_concurrent_calls = [] for t in tags_to_augment: if len(collect_concurrent_calls) >= concurrent_calls: - _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, augmentation_engine) + _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, augmentation_engine, train_years) else: if tags_to_augment_counts[t] < min_examples: missing = min_examples - tags_to_augment_counts[t] collect_concurrent_calls.append((tags_to_augment_counts[t], missing)) - @augment_app.command() def augment_cli( data_path: str = typer.Argument( @@ -194,4 +195,4 @@ def augment_cli( prompt_template=prompt_template, few_shot_examples=few_shot_examples, concurrent_calls=concurrent_calls - ) + ) From f7c9fc016686e17fb445c8be037bd6b4b221745b Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 18:51:53 +0100 Subject: [PATCH 154/300] Fixes bug --- grants_tagger_light/augmentation/augment.py | 8 +++++--- grants_tagger_light/augmentation/augment_openai.py | 6 +++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 78291bb3..67bef7cb 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -39,11 +39,12 @@ def _merge_dicts(dict_list): def _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, - augmentation_engine, train_years): + augmentation_engine, train_years, num_proc): logger.info(f"Generating missing examples for classes...") counter = 0 with open(save_to_path, 'a') as f: - for a in augmentation_engine.generate(collect_concurrent_calls, dset, few_shot_examples=few_shot_examples): + for a in augmentation_engine.generate(collect_concurrent_calls, dset, few_shot_examples=few_shot_examples, + num_proc=num_proc): if a is None: break f.write(json.dumps({ @@ -125,7 +126,8 @@ def augment( collect_concurrent_calls = [] for t in tags_to_augment: if len(collect_concurrent_calls) >= concurrent_calls: - _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, augmentation_engine, train_years) + _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, augmentation_engine, + train_years, num_proc) else: if tags_to_augment_counts[t] < min_examples: missing = min_examples - tags_to_augment_counts[t] diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 5d89c8e6..b3d6d8cf 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -29,7 +29,7 @@ def _create_message(self, featured_tag, featured_tag_dset): return [{"role": "user", "content": prompt}] def _make_requests(self, collect_concurrent_calls, dset, few_shot_examples=10, temperature=1.5, top_p=1, - frequence_penalty=0, presence_penalty=0): + frequence_penalty=0, presence_penalty=0, num_proc=os.cpu.count()): for num in range(len(collect_concurrent_calls)): t = collect_concurrent_calls[num][0] n = collect_concurrent_calls[num][1] @@ -52,10 +52,10 @@ def _make_requests(self, collect_concurrent_calls, dset, few_shot_examples=10, t }, metadata={'num': num}) def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, temperature=1.5, top_p=1, - frequence_penalty=0, presence_penalty=0): + frequence_penalty=0, presence_penalty=0, num_proc=os.cpu.count()): self.api.run_request_function(self._make_requests, collect_concurrent_calls, dset, few_shot_examples, - temperature, top_p, frequence_penalty, presence_penalty) + temperature, top_p, frequence_penalty, presence_penalty, num_proc) for response in self.api: for r in response['choices']: From 504075a58b4bd9fee53f8e56b01c22173a05a431 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 18:53:07 +0100 Subject: [PATCH 155/300] Fixes bug --- grants_tagger_light/augmentation/augment_openai.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index b3d6d8cf..01686940 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -29,7 +29,7 @@ def _create_message(self, featured_tag, featured_tag_dset): return [{"role": "user", "content": prompt}] def _make_requests(self, collect_concurrent_calls, dset, few_shot_examples=10, temperature=1.5, top_p=1, - frequence_penalty=0, presence_penalty=0, num_proc=os.cpu.count()): + frequence_penalty=0, presence_penalty=0, num_proc=os.cpu_count()): for num in range(len(collect_concurrent_calls)): t = collect_concurrent_calls[num][0] n = collect_concurrent_calls[num][1] @@ -52,7 +52,7 @@ def _make_requests(self, collect_concurrent_calls, dset, few_shot_examples=10, t }, metadata={'num': num}) def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, temperature=1.5, top_p=1, - frequence_penalty=0, presence_penalty=0, num_proc=os.cpu.count()): + frequence_penalty=0, presence_penalty=0, num_proc=os.cpu_count()): self.api.run_request_function(self._make_requests, collect_concurrent_calls, dset, few_shot_examples, temperature, top_p, frequence_penalty, presence_penalty, num_proc) From 8249ab97a886e84122b8a60d489d3bc90ffcc6ec Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 18:55:28 +0100 Subject: [PATCH 156/300] Adds numpy --- grants_tagger_light/augmentation/augment_openai.py | 1 + 1 file changed, 1 insertion(+) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 01686940..f8389c18 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -3,6 +3,7 @@ from loguru import logger import openai from openai_multi_client import OpenAIMultiClient +import numpy as np class AugmentOpenAI: From 1a1309fda6d4ae63806f918efebe1984888ba708 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 18:56:58 +0100 Subject: [PATCH 157/300] Fixes bug --- grants_tagger_light/augmentation/augment_openai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index f8389c18..c015bf56 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -18,7 +18,7 @@ def __init__(self, prompt_template_path, model_key='gpt-3.5-turbo'): def _create_message(self, featured_tag, featured_tag_dset): abstracts = "\n".join(featured_tag_dset['abstractText']) tags = [] - for x in dataset['meshMajor']: + for x in featured_tag_dset['meshMajor']: tags.extend(x) mesh_tags = ",".join(list(set(tags))) From 0aee6b56da8e4e36b6a2d34b6b493d7093cda2ff Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 18:58:11 +0100 Subject: [PATCH 158/300] Fixes bug --- grants_tagger_light/augmentation/augment.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 67bef7cb..2cb8fd68 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -44,7 +44,7 @@ def _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, counter = 0 with open(save_to_path, 'a') as f: for a in augmentation_engine.generate(collect_concurrent_calls, dset, few_shot_examples=few_shot_examples, - num_proc=num_proc): + num_proc=num_proc): if a is None: break f.write(json.dumps({ @@ -122,7 +122,6 @@ def augment( desc="Encoding labels", num_proc=num_proc, ) - print(dset['idx']) collect_concurrent_calls = [] for t in tags_to_augment: if len(collect_concurrent_calls) >= concurrent_calls: From 2858c00a0a6f5b82952a4f8228db761ec46b28c1 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 18:58:30 +0100 Subject: [PATCH 159/300] Fixes bug --- grants_tagger_light/augmentation/augment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 2cb8fd68..856569d2 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -119,7 +119,7 @@ def augment( with_indices=True, batched=True, batch_size=batch_size, - desc="Encoding labels", + desc="Creating idx", num_proc=num_proc, ) collect_concurrent_calls = [] From 612d970530b694fc820d8b6eb08bc288446cef42 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 19:00:44 +0100 Subject: [PATCH 160/300] Fixes bug --- grants_tagger_light/augmentation/augment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 856569d2..f04e83d0 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -130,7 +130,7 @@ def augment( else: if tags_to_augment_counts[t] < min_examples: missing = min_examples - tags_to_augment_counts[t] - collect_concurrent_calls.append((tags_to_augment_counts[t], missing)) + collect_concurrent_calls.append((t, missing)) @augment_app.command() From e48cc067a0ae49388319ba9b05f9607c0360c85b Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 19:08:21 +0100 Subject: [PATCH 161/300] Args to kwargs --- grants_tagger_light/augmentation/augment_openai.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index c015bf56..47ba1793 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -55,8 +55,15 @@ def _make_requests(self, collect_concurrent_calls, dset, few_shot_examples=10, t def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, temperature=1.5, top_p=1, frequence_penalty=0, presence_penalty=0, num_proc=os.cpu_count()): - self.api.run_request_function(self._make_requests, collect_concurrent_calls, dset, few_shot_examples, - temperature, top_p, frequence_penalty, presence_penalty, num_proc) + self.api.run_request_function(self._make_requests, + collect_concurrent_calls=collect_concurrent_calls, + dset=dset, + few_shot_examples=few_shot_examples, + temperature=temperature, + top_p=top_p, + frequence_penalty=frequence_penalty, + presence_penalty=presence_penalty, + num_proc=num_proc) for response in self.api: for r in response['choices']: From 24790ae8b393241b0a5ee8ca80d8adfae5053a94 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 19:11:42 +0100 Subject: [PATCH 162/300] Args to kwargs --- .../augmentation/augment_openai.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 47ba1793..67deb19e 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -29,8 +29,22 @@ def _create_message(self, featured_tag, featured_tag_dset): return [{"role": "user", "content": prompt}] - def _make_requests(self, collect_concurrent_calls, dset, few_shot_examples=10, temperature=1.5, top_p=1, - frequence_penalty=0, presence_penalty=0, num_proc=os.cpu_count()): + def _make_requests(self, *kwargs): + temperature = kwargs['temperature'] + print(temperature) + + top_p = kwargs['top_p'] + print(top_p) + + frequence_penalty = kwargs['frequence_penalty'] + print(frequence_penalty) + + presence_penalty = kwargs['presence_penalty'] + print(presence_penalty) + + num_proc = kwargs['num_proc'] + print(num_proc) + for num in range(len(collect_concurrent_calls)): t = collect_concurrent_calls[num][0] n = collect_concurrent_calls[num][1] From c098506e90f7959dbfd9a29419530971078d7a0e Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 19:13:15 +0100 Subject: [PATCH 163/300] Args to kwargs --- .../augmentation/augment_openai.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 67deb19e..506b4524 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -29,22 +29,20 @@ def _create_message(self, featured_tag, featured_tag_dset): return [{"role": "user", "content": prompt}] - def _make_requests(self, *kwargs): - temperature = kwargs['temperature'] + def _make_requests(self, + collect_concurrent_calls, + dset, + few_shot_examples, + temperature, + top_p, + frequence_penalty, + presence_penalty, + num_proc): print(temperature) - - top_p = kwargs['top_p'] print(top_p) - - frequence_penalty = kwargs['frequence_penalty'] print(frequence_penalty) - - presence_penalty = kwargs['presence_penalty'] print(presence_penalty) - - num_proc = kwargs['num_proc'] print(num_proc) - for num in range(len(collect_concurrent_calls)): t = collect_concurrent_calls[num][0] n = collect_concurrent_calls[num][1] From 5ef18f4b907b01772bd4b260142dcd7d89c91efb Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 19:15:27 +0100 Subject: [PATCH 164/300] Removes frequence penalty --- grants_tagger_light/augmentation/augment_openai.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 506b4524..424d53d3 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -38,11 +38,6 @@ def _make_requests(self, frequence_penalty, presence_penalty, num_proc): - print(temperature) - print(top_p) - print(frequence_penalty) - print(presence_penalty) - print(num_proc) for num in range(len(collect_concurrent_calls)): t = collect_concurrent_calls[num][0] n = collect_concurrent_calls[num][1] @@ -59,7 +54,7 @@ def _make_requests(self, "n": n, "temperature": temperature, "top_p": top_p, - "frequence_penalty": frequence_penalty, + # "frequence_penalty": frequence_penalty, "presence_penalty": presence_penalty, "messages": self._create_message(t, tmp_dset) }, metadata={'num': num}) From a4760466f87e824d5ed14f8fa0c30632ccb2e2de Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 19:17:06 +0100 Subject: [PATCH 165/300] Removes frequence penalty --- grants_tagger_light/augmentation/augment_openai.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 424d53d3..bf092f02 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -35,7 +35,6 @@ def _make_requests(self, few_shot_examples, temperature, top_p, - frequence_penalty, presence_penalty, num_proc): for num in range(len(collect_concurrent_calls)): @@ -54,13 +53,12 @@ def _make_requests(self, "n": n, "temperature": temperature, "top_p": top_p, - # "frequence_penalty": frequence_penalty, "presence_penalty": presence_penalty, "messages": self._create_message(t, tmp_dset) }, metadata={'num': num}) def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, temperature=1.5, top_p=1, - frequence_penalty=0, presence_penalty=0, num_proc=os.cpu_count()): + presence_penalty=0, num_proc=os.cpu_count()): self.api.run_request_function(self._make_requests, collect_concurrent_calls=collect_concurrent_calls, @@ -68,7 +66,6 @@ def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, tempera few_shot_examples=few_shot_examples, temperature=temperature, top_p=top_p, - frequence_penalty=frequence_penalty, presence_penalty=presence_penalty, num_proc=num_proc) From e48bfbd32ecb895724371f2b95298dda01078a65 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 19:18:50 +0100 Subject: [PATCH 166/300] Payload response fix --- grants_tagger_light/augmentation/augment_openai.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index bf092f02..557441bc 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -69,8 +69,9 @@ def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, tempera presence_penalty=presence_penalty, num_proc=num_proc) - for response in self.api: - for r in response['choices']: + for result in self.api: + num = result.metadata['num'] + for r in result.response['choices']: if 'message' in r: if 'content' in r['message']: print(r['message']['content']) From 33208556247a707d22e2efa35f0eb9bcd6022489 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 19:28:50 +0100 Subject: [PATCH 167/300] Payload response fix --- grants_tagger_light/augmentation/augment_openai.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 557441bc..35c33a5d 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -74,15 +74,17 @@ def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, tempera for r in result.response['choices']: if 'message' in r: if 'content' in r['message']: - print(r['message']['content']) try: r_json = json.loads(r['message']['content']) a = r_json['abstract'] # Make sure it does not hallucinate and adds anything new which may not be a MeSH tag t = [x for x in r_json['tags'] if x in tags] tl = r_json['title'] + print("YIELD!!!!!!!!!!!!!!") yield {'abstract': a, 'tags': t, 'title': tl} except Exception as e: + + print("ERROR!!!!!!!!!!!!!!") logger.info("OpenAI did not return a proper json format.") yield None From 770717e5a1f43ded7c7170a366aec3626708492b Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 19:30:36 +0100 Subject: [PATCH 168/300] Payload response fix --- grants_tagger_light/augmentation/augment_openai.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 35c33a5d..158ee635 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -83,8 +83,10 @@ def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, tempera print("YIELD!!!!!!!!!!!!!!") yield {'abstract': a, 'tags': t, 'title': tl} except Exception as e: - print("ERROR!!!!!!!!!!!!!!") + print(e) + print(r['message']['content']) + exit(-1) logger.info("OpenAI did not return a proper json format.") yield None From c35c1d37b79da82ea920060f4ab8a63b6335560e Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 19:34:38 +0100 Subject: [PATCH 169/300] Payload response fix --- grants_tagger_light/augmentation/augment_openai.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 158ee635..146502cd 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -77,8 +77,7 @@ def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, tempera try: r_json = json.loads(r['message']['content']) a = r_json['abstract'] - # Make sure it does not hallucinate and adds anything new which may not be a MeSH tag - t = [x for x in r_json['tags'] if x in tags] + t = r_json['tags'] tl = r_json['title'] print("YIELD!!!!!!!!!!!!!!") yield {'abstract': a, 'tags': t, 'title': tl} From 7d92c1c0f8f916cb5dc4437f66804b02bcd1178b Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 19:36:06 +0100 Subject: [PATCH 170/300] Payload response fix --- grants_tagger_light/augmentation/augment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index f04e83d0..64a5f09c 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -39,7 +39,7 @@ def _merge_dicts(dict_list): def _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, - augmentation_engine, train_years, num_proc): + augmentation_engine, train_years, num_proc, model_key): logger.info(f"Generating missing examples for classes...") counter = 0 with open(save_to_path, 'a') as f: @@ -126,7 +126,7 @@ def augment( for t in tags_to_augment: if len(collect_concurrent_calls) >= concurrent_calls: _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, augmentation_engine, - train_years, num_proc) + train_years, num_proc, model_key) else: if tags_to_augment_counts[t] < min_examples: missing = min_examples - tags_to_augment_counts[t] From 4285042aa92ffe953cecdd20b79195dfa6beb207 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 19:38:33 +0100 Subject: [PATCH 171/300] Fixing uuid --- grants_tagger_light/augmentation/augment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 64a5f09c..018fa809 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -52,7 +52,7 @@ def _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, "meshMajor": a['tags'], "year": [random.choice(train_years) if len(train_years) > 0 else datetime.date.today().year], "abstractText": a['abstract'], - "pmid": uuid.uuid4.hex(), + "pmid": uuid.uuid4().hex, "title": a['title'] })) f.write('\n') From 6afce7dc42e55f1b86d07f5a746c3b6625e91376 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 19:46:30 +0100 Subject: [PATCH 172/300] Adds inspiration --- grants_tagger_light/augmentation/augment.py | 3 ++- grants_tagger_light/augmentation/augment_openai.py | 3 ++- grants_tagger_light/augmentation/prompt.template | 2 ++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 018fa809..9d8a85f1 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -53,7 +53,8 @@ def _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, "year": [random.choice(train_years) if len(train_years) > 0 else datetime.date.today().year], "abstractText": a['abstract'], "pmid": uuid.uuid4().hex, - "title": a['title'] + "title": a['title'], + "inspiration": a['inspiration'] })) f.write('\n') f.flush() diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 146502cd..e4b9c18d 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -79,8 +79,9 @@ def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, tempera a = r_json['abstract'] t = r_json['tags'] tl = r_json['title'] + i = r_json['inspiration'] print("YIELD!!!!!!!!!!!!!!") - yield {'abstract': a, 'tags': t, 'title': tl} + yield {'abstract': a, 'tags': t, 'title': tl, 'inspiration': i} except Exception as e: print("ERROR!!!!!!!!!!!!!!") print(e) diff --git a/grants_tagger_light/augmentation/prompt.template b/grants_tagger_light/augmentation/prompt.template index ac4376be..498bf36f 100644 --- a/grants_tagger_light/augmentation/prompt.template +++ b/grants_tagger_light/augmentation/prompt.template @@ -7,7 +7,9 @@ You need to produce one new article in a json, with the following fields: 1) The field 'abstract', with a NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. 2) The field 'title', with a small title summarizing the NEW ABSTRACT; 3) The field 'tags', with list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. +4) The field 'inspiration', returning the ABSTRACTS I sent to you. +Only return a json, don't add any explanation or any other information besides the `abstract`, `title`, `tags` and `inspiration` fields. REQUIRED MESH TAG: {FEATURED_TAG} ABSTRACTS: From 94a8310cfdc52059a2eef5812f35adfff715f727 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 19:54:09 +0100 Subject: [PATCH 173/300] Debugs --- grants_tagger_light/augmentation/augment.py | 2 +- grants_tagger_light/augmentation/augment_openai.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 9d8a85f1..f3338920 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -44,7 +44,7 @@ def _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, counter = 0 with open(save_to_path, 'a') as f: for a in augmentation_engine.generate(collect_concurrent_calls, dset, few_shot_examples=few_shot_examples, - num_proc=num_proc): + num_proc=num_proc, save_to_path=save_to_path) if a is None: break f.write(json.dumps({ diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index e4b9c18d..bf4900f8 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -58,7 +58,7 @@ def _make_requests(self, }, metadata={'num': num}) def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, temperature=1.5, top_p=1, - presence_penalty=0, num_proc=os.cpu_count()): + presence_penalty=0, num_proc=os.cpu_count(), save_to_path='err'): self.api.run_request_function(self._make_requests, collect_concurrent_calls=collect_concurrent_calls, @@ -83,9 +83,10 @@ def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, tempera print("YIELD!!!!!!!!!!!!!!") yield {'abstract': a, 'tags': t, 'title': tl, 'inspiration': i} except Exception as e: + with open(f'{save_to_path}.err', 'w') as f: + f.write(r['message']['content']) + f.write("\n") print("ERROR!!!!!!!!!!!!!!") - print(e) - print(r['message']['content']) exit(-1) logger.info("OpenAI did not return a proper json format.") yield None From 6aadf6aa7e9075217ce2324cf785c749caf3bdf9 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 19:54:40 +0100 Subject: [PATCH 174/300] Debugs --- grants_tagger_light/augmentation/augment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index f3338920..7cabe504 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -44,7 +44,7 @@ def _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, counter = 0 with open(save_to_path, 'a') as f: for a in augmentation_engine.generate(collect_concurrent_calls, dset, few_shot_examples=few_shot_examples, - num_proc=num_proc, save_to_path=save_to_path) + num_proc=num_proc, save_to_path=save_to_path): if a is None: break f.write(json.dumps({ From 36e4d3e84c65ebc9cd68ab3f63b5c9b3aad4ce3f Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 20:02:33 +0100 Subject: [PATCH 175/300] Asks to reformat the quotes --- grants_tagger_light/augmentation/prompt.template | 2 ++ 1 file changed, 2 insertions(+) diff --git a/grants_tagger_light/augmentation/prompt.template b/grants_tagger_light/augmentation/prompt.template index 498bf36f..b1a0f286 100644 --- a/grants_tagger_light/augmentation/prompt.template +++ b/grants_tagger_light/augmentation/prompt.template @@ -10,6 +10,8 @@ You need to produce one new article in a json, with the following fields: 4) The field 'inspiration', returning the ABSTRACTS I sent to you. Only return a json, don't add any explanation or any other information besides the `abstract`, `title`, `tags` and `inspiration` fields. +In the json, in the value (not in the field names), replace all the quotes (") with single quotes ('). + REQUIRED MESH TAG: {FEATURED_TAG} ABSTRACTS: From f351c6f65e8a207e67cc03f4c47bea23d7d4d7de Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 20:05:39 +0100 Subject: [PATCH 176/300] Asks to reformat the quotes --- grants_tagger_light/augmentation/prompt.template | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/grants_tagger_light/augmentation/prompt.template b/grants_tagger_light/augmentation/prompt.template index b1a0f286..751e339b 100644 --- a/grants_tagger_light/augmentation/prompt.template +++ b/grants_tagger_light/augmentation/prompt.template @@ -4,12 +4,14 @@ You are in charge of doing Data Augmentation. I will provide: 3) a series of MESH TAGS. You need to produce one new article in a json, with the following fields: -1) The field 'abstract', with a NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. -2) The field 'title', with a small title summarizing the NEW ABSTRACT; -3) The field 'tags', with list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. -4) The field 'inspiration', returning the ABSTRACTS I sent to you. +1) The field "abstract", with a NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. +2) The field "title", with a small title summarizing the NEW ABSTRACT; +3) The field "tags", with list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. +4) The field "inspiration", returning the ABSTRACTS I sent to you. -Only return a json, don't add any explanation or any other information besides the `abstract`, `title`, `tags` and `inspiration` fields. +Only return a json, don't add any explanation or any other information besides the "abstract", "title", "tags" and "inspiration" fields. + +After producing the json, go over all the values of the fields "abstract", "title", "tags" and "inspiration", and escape all the quotes in them. In the json, in the value (not in the field names), replace all the quotes (") with single quotes ('). REQUIRED MESH TAG: {FEATURED_TAG} From f5679039f0f0dc1f75a00c980d3e6db9ec7fe47f Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 20:10:01 +0100 Subject: [PATCH 177/300] Moves to a csv format --- .../augmentation/augment_openai.py | 10 +++++----- grants_tagger_light/augmentation/prompt.template | 15 ++++++--------- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index bf4900f8..eea5f3d1 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -75,11 +75,11 @@ def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, tempera if 'message' in r: if 'content' in r['message']: try: - r_json = json.loads(r['message']['content']) - a = r_json['abstract'] - t = r_json['tags'] - tl = r_json['title'] - i = r_json['inspiration'] + pieces = r['message']['content'].split('@@@@') + a = pieces[0] + t = pieces[1] + tl = pieces[2] + i = pieces[3] print("YIELD!!!!!!!!!!!!!!") yield {'abstract': a, 'tags': t, 'title': tl, 'inspiration': i} except Exception as e: diff --git a/grants_tagger_light/augmentation/prompt.template b/grants_tagger_light/augmentation/prompt.template index 751e339b..8dd3d2cf 100644 --- a/grants_tagger_light/augmentation/prompt.template +++ b/grants_tagger_light/augmentation/prompt.template @@ -3,16 +3,13 @@ You are in charge of doing Data Augmentation. I will provide: 2) a series of ABSTRACTS to use as a source for augmentation; 3) a series of MESH TAGS. -You need to produce one new article in a json, with the following fields: -1) The field "abstract", with a NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. -2) The field "title", with a small title summarizing the NEW ABSTRACT; -3) The field "tags", with list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. -4) The field "inspiration", returning the ABSTRACTS I sent to you. +You need to return the following information, one after another, separated by '@@@@': +1) The NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. +2) A small title summarizing the NEW ABSTRACT; +3) The list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. +4) ABSTRACTS I sent to you. -Only return a json, don't add any explanation or any other information besides the "abstract", "title", "tags" and "inspiration" fields. - -After producing the json, go over all the values of the fields "abstract", "title", "tags" and "inspiration", and escape all the quotes in them. -In the json, in the value (not in the field names), replace all the quotes (") with single quotes ('). +Only return those four points, one after another, nothing more, nothing less. REQUIRED MESH TAG: {FEATURED_TAG} From c484ceec5dc2e5b1af6af91d3f799607cc6e0571 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 20:13:34 +0100 Subject: [PATCH 178/300] Moves to a csv format --- grants_tagger_light/augmentation/prompt.template | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/grants_tagger_light/augmentation/prompt.template b/grants_tagger_light/augmentation/prompt.template index 8dd3d2cf..e5e8c10c 100644 --- a/grants_tagger_light/augmentation/prompt.template +++ b/grants_tagger_light/augmentation/prompt.template @@ -3,12 +3,15 @@ You are in charge of doing Data Augmentation. I will provide: 2) a series of ABSTRACTS to use as a source for augmentation; 3) a series of MESH TAGS. -You need to return the following information, one after another, separated by '@@@@': +You need to return the following information: 1) The NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. -2) A small title summarizing the NEW ABSTRACT; -3) The list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. +2) A small TITLE summarizing the NEW ABSTRACT; +3) A comma-separated list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. 4) ABSTRACTS I sent to you. +The output format should be the information separated by '@@@@'. Example: NEW ABSTRACT@@@@TITLE@@@@TAGS@@@@ABSTRACTS. +Don't include the headers NEW ABSTRACT, TITLE, TAGS or ABSTRACTS. Just the information. + Only return those four points, one after another, nothing more, nothing less. REQUIRED MESH TAG: {FEATURED_TAG} From 7f040c8e1baa3e1f9ba3d600b0652906e9d76e79 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 20:16:29 +0100 Subject: [PATCH 179/300] Moves to a csv format --- grants_tagger_light/augmentation/prompt.template | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/grants_tagger_light/augmentation/prompt.template b/grants_tagger_light/augmentation/prompt.template index e5e8c10c..c4f35680 100644 --- a/grants_tagger_light/augmentation/prompt.template +++ b/grants_tagger_light/augmentation/prompt.template @@ -3,14 +3,13 @@ You are in charge of doing Data Augmentation. I will provide: 2) a series of ABSTRACTS to use as a source for augmentation; 3) a series of MESH TAGS. -You need to return the following information: +You need to return the following information in a csv format: 1) The NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. 2) A small TITLE summarizing the NEW ABSTRACT; 3) A comma-separated list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. 4) ABSTRACTS I sent to you. -The output format should be the information separated by '@@@@'. Example: NEW ABSTRACT@@@@TITLE@@@@TAGS@@@@ABSTRACTS. -Don't include the headers NEW ABSTRACT, TITLE, TAGS or ABSTRACTS. Just the information. +The csv output format should contain 4 columns. Use the separator '@@@@' for the csv. Don't include a header in the csv. Only return those four points, one after another, nothing more, nothing less. From 6b1759ea8ea6693c2f269ca3edd6b4808dd1859a Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 20:19:01 +0100 Subject: [PATCH 180/300] Moves to a csv format --- grants_tagger_light/augmentation/prompt.template | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/grants_tagger_light/augmentation/prompt.template b/grants_tagger_light/augmentation/prompt.template index c4f35680..dd9cf546 100644 --- a/grants_tagger_light/augmentation/prompt.template +++ b/grants_tagger_light/augmentation/prompt.template @@ -3,13 +3,17 @@ You are in charge of doing Data Augmentation. I will provide: 2) a series of ABSTRACTS to use as a source for augmentation; 3) a series of MESH TAGS. -You need to return the following information in a csv format: -1) The NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. -2) A small TITLE summarizing the NEW ABSTRACT; -3) A comma-separated list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. -4) ABSTRACTS I sent to you. +You need to return a .csv, with each line containing: +1) First column, the NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. +2) Second column, a small TITLE summarizing the NEW ABSTRACT; +3) Third column, a comma-separated list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. +4) For column, the concatenated ABSTRACTS I sent to you. -The csv output format should contain 4 columns. Use the separator '@@@@' for the csv. Don't include a header in the csv. +The csv output format should contain only those 4 columns. + +Use the separator '@@@@' for the csv. + +Don't include a header in the csv. Only return those four points, one after another, nothing more, nothing less. From 3a45e0c299fd8af4d7cbd0a8491dba229ef25ebb Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 20:23:24 +0100 Subject: [PATCH 181/300] Moves to a csv format --- .../augmentation/augment_openai.py | 10 +++++----- .../augmentation/prompt.template | 18 +++++++----------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index eea5f3d1..c48d6c2f 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -75,11 +75,11 @@ def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, tempera if 'message' in r: if 'content' in r['message']: try: - pieces = r['message']['content'].split('@@@@') - a = pieces[0] - t = pieces[1] - tl = pieces[2] - i = pieces[3] + pieces = json.loads(r['message']['content']) + a = pieces['abstract'] + t = pieces['tags'] + tl = pieces['title'] + i = pieces['inspiration'] print("YIELD!!!!!!!!!!!!!!") yield {'abstract': a, 'tags': t, 'title': tl, 'inspiration': i} except Exception as e: diff --git a/grants_tagger_light/augmentation/prompt.template b/grants_tagger_light/augmentation/prompt.template index dd9cf546..8b577737 100644 --- a/grants_tagger_light/augmentation/prompt.template +++ b/grants_tagger_light/augmentation/prompt.template @@ -3,19 +3,15 @@ You are in charge of doing Data Augmentation. I will provide: 2) a series of ABSTRACTS to use as a source for augmentation; 3) a series of MESH TAGS. -You need to return a .csv, with each line containing: -1) First column, the NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. -2) Second column, a small TITLE summarizing the NEW ABSTRACT; -3) Third column, a comma-separated list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. -4) For column, the concatenated ABSTRACTS I sent to you. +You need to return a json file containing: +1) Field "abstract": the NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. +2) Field "title": a small TITLE summarizing the NEW ABSTRACT; +3) Field "tags": a comma-separated list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. +4) Field "inspiration", a list of the ABSTRACTS I sent to you, one after another. -The csv output format should contain only those 4 columns. +Don't include anything else besides those 4 fields in the json. +Make sure the in the values of the json, there are no quotes. If there are, escape them so there is no problem parsing the json. -Use the separator '@@@@' for the csv. - -Don't include a header in the csv. - -Only return those four points, one after another, nothing more, nothing less. REQUIRED MESH TAG: {FEATURED_TAG} From 2dbbdcda71279a40bb00f8e79358367fe12bc74c Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 20:31:35 +0100 Subject: [PATCH 182/300] Json asking to escape quotes --- grants_tagger_light/augmentation/prompt.template | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/grants_tagger_light/augmentation/prompt.template b/grants_tagger_light/augmentation/prompt.template index 8b577737..5baaed49 100644 --- a/grants_tagger_light/augmentation/prompt.template +++ b/grants_tagger_light/augmentation/prompt.template @@ -4,14 +4,13 @@ You are in charge of doing Data Augmentation. I will provide: 3) a series of MESH TAGS. You need to return a json file containing: -1) Field "abstract": the NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. -2) Field "title": a small TITLE summarizing the NEW ABSTRACT; -3) Field "tags": a comma-separated list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. -4) Field "inspiration", a list of the ABSTRACTS I sent to you, one after another. +1) Field "abstract": the NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. Make sure the generated text has all the quotes escaped except the initial and final ones. +2) Field "title": a small TITLE summarizing the NEW ABSTRACT. Make sure the generated title has all the quotes escaped except the initial and final ones. +3) Field "tags": a list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. Make sure each tag have all the quotes escaped except the initial and final ones. +4) Field "inspiration", a list of the ABSTRACTS I sent to you, one after another. Make sure each element of the list has all the quotes escaped except the initial and final ones. Don't include anything else besides those 4 fields in the json. -Make sure the in the values of the json, there are no quotes. If there are, escape them so there is no problem parsing the json. - +Make sure the in the values of the json, all the quotes are escaped excepted the initial and final ones. REQUIRED MESH TAG: {FEATURED_TAG} From 3170fc9de267b3ad0d1c05eff503050236cddbd1 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 20:34:07 +0100 Subject: [PATCH 183/300] Json asking to escape quotes --- grants_tagger_light/augmentation/prompt.template | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/grants_tagger_light/augmentation/prompt.template b/grants_tagger_light/augmentation/prompt.template index 5baaed49..a0189fa5 100644 --- a/grants_tagger_light/augmentation/prompt.template +++ b/grants_tagger_light/augmentation/prompt.template @@ -4,13 +4,13 @@ You are in charge of doing Data Augmentation. I will provide: 3) a series of MESH TAGS. You need to return a json file containing: -1) Field "abstract": the NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. Make sure the generated text has all the quotes escaped except the initial and final ones. -2) Field "title": a small TITLE summarizing the NEW ABSTRACT. Make sure the generated title has all the quotes escaped except the initial and final ones. -3) Field "tags": a list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. Make sure each tag have all the quotes escaped except the initial and final ones. -4) Field "inspiration", a list of the ABSTRACTS I sent to you, one after another. Make sure each element of the list has all the quotes escaped except the initial and final ones. +1) Field "abstract": the NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. Make sure the generated text has all the quotes escaped (with one \) except the initial and final ones. +2) Field "title": a small TITLE summarizing the NEW ABSTRACT. Make sure the generated title has all the quotes escaped (with one \) except the initial and final ones. +3) Field "tags": a list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. Make sure each tag have all the quotes escaped (with one \) except the initial and final ones. +4) Field "inspiration", a list of the ABSTRACTS I sent to you, one after another. Make sure each element of the list has all the quotes escaped (with one \) except the initial and final ones. Don't include anything else besides those 4 fields in the json. -Make sure the in the values of the json, all the quotes are escaped excepted the initial and final ones. +Make sure the in the values of the json, all the quotes are escaped with one \, excepted the initial and final ones, which should not be escaped. REQUIRED MESH TAG: {FEATURED_TAG} From 841477ef6c51737ecba4d6e6eeddc710bf8ccfa6 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 20:36:50 +0100 Subject: [PATCH 184/300] Json asking to escape quotes --- grants_tagger_light/augmentation/prompt.template | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/grants_tagger_light/augmentation/prompt.template b/grants_tagger_light/augmentation/prompt.template index a0189fa5..ea95c7d7 100644 --- a/grants_tagger_light/augmentation/prompt.template +++ b/grants_tagger_light/augmentation/prompt.template @@ -4,10 +4,10 @@ You are in charge of doing Data Augmentation. I will provide: 3) a series of MESH TAGS. You need to return a json file containing: -1) Field "abstract": the NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. Make sure the generated text has all the quotes escaped (with one \) except the initial and final ones. -2) Field "title": a small TITLE summarizing the NEW ABSTRACT. Make sure the generated title has all the quotes escaped (with one \) except the initial and final ones. -3) Field "tags": a list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. Make sure each tag have all the quotes escaped (with one \) except the initial and final ones. -4) Field "inspiration", a list of the ABSTRACTS I sent to you, one after another. Make sure each element of the list has all the quotes escaped (with one \) except the initial and final ones. +1) Field "abstract": the NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. Remove all quotes from the generated NEW ABSTRACT. +2) Field "title": a small TITLE summarizing the NEW ABSTRACT. Remove all quotes from the generated TITLE. +3) Field "tags": a list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. Remove all quotes from the generated TAGS. +4) Field "inspiration", a list of the ABSTRACTS I sent to you, one after another. Remove all quotes from the ABSTRACTS. Don't include anything else besides those 4 fields in the json. Make sure the in the values of the json, all the quotes are escaped with one \, excepted the initial and final ones, which should not be escaped. From e6abb7c8cf37933999b556260992f26fde904277 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 20:39:12 +0100 Subject: [PATCH 185/300] Json asking to escape quotes --- grants_tagger_light/augmentation/prompt.template | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/grants_tagger_light/augmentation/prompt.template b/grants_tagger_light/augmentation/prompt.template index ea95c7d7..7adc5af1 100644 --- a/grants_tagger_light/augmentation/prompt.template +++ b/grants_tagger_light/augmentation/prompt.template @@ -7,10 +7,9 @@ You need to return a json file containing: 1) Field "abstract": the NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. Remove all quotes from the generated NEW ABSTRACT. 2) Field "title": a small TITLE summarizing the NEW ABSTRACT. Remove all quotes from the generated TITLE. 3) Field "tags": a list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. Remove all quotes from the generated TAGS. -4) Field "inspiration", a list of the ABSTRACTS I sent to you, one after another. Remove all quotes from the ABSTRACTS. +4) Field "inspiration", a text concatenating all the ABSTRACTS I sent to you, one after another. Remove all quotes from the ABSTRACTS. Don't include anything else besides those 4 fields in the json. -Make sure the in the values of the json, all the quotes are escaped with one \, excepted the initial and final ones, which should not be escaped. REQUIRED MESH TAG: {FEATURED_TAG} From 6083b1bd78740862ebe6492c0d235e4f0d1307ed Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 20:42:40 +0100 Subject: [PATCH 186/300] Json asking to escape quotes --- grants_tagger_light/augmentation/prompt.template | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/augmentation/prompt.template b/grants_tagger_light/augmentation/prompt.template index 7adc5af1..0d35c975 100644 --- a/grants_tagger_light/augmentation/prompt.template +++ b/grants_tagger_light/augmentation/prompt.template @@ -7,9 +7,9 @@ You need to return a json file containing: 1) Field "abstract": the NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. Remove all quotes from the generated NEW ABSTRACT. 2) Field "title": a small TITLE summarizing the NEW ABSTRACT. Remove all quotes from the generated TITLE. 3) Field "tags": a list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. Remove all quotes from the generated TAGS. -4) Field "inspiration", a text concatenating all the ABSTRACTS I sent to you, one after another. Remove all quotes from the ABSTRACTS. +4) Field "inspiration": a text concatenating all the ABSTRACTS I sent to you, one after another. Remove all quotes from the ABSTRACTS. -Don't include anything else besides those 4 fields in the json. +Don't include anything else besides the four mentioned fields. "abstract", "title", "tags" and "inspiration". REQUIRED MESH TAG: {FEATURED_TAG} From 6bd9f5384a260e5286038a72eaeecf98395bea52 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 20:46:04 +0100 Subject: [PATCH 187/300] Json asking to escape quotes --- grants_tagger_light/augmentation/prompt.template | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/grants_tagger_light/augmentation/prompt.template b/grants_tagger_light/augmentation/prompt.template index 0d35c975..d2109529 100644 --- a/grants_tagger_light/augmentation/prompt.template +++ b/grants_tagger_light/augmentation/prompt.template @@ -11,6 +11,10 @@ You need to return a json file containing: Don't include anything else besides the four mentioned fields. "abstract", "title", "tags" and "inspiration". +Make sure each tag starts and finishes with a double quote. + +Make sure there is only starting and ending quotes in both keys and values. Remove any other quotes. + REQUIRED MESH TAG: {FEATURED_TAG} ABSTRACTS: From 3555cb576ea26368f8cc7ed319e49d19f89e976f Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 20:49:27 +0100 Subject: [PATCH 188/300] Json asking to escape quotes --- grants_tagger_light/augmentation/prompt.template | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/augmentation/prompt.template b/grants_tagger_light/augmentation/prompt.template index d2109529..cf3f6a39 100644 --- a/grants_tagger_light/augmentation/prompt.template +++ b/grants_tagger_light/augmentation/prompt.template @@ -7,13 +7,13 @@ You need to return a json file containing: 1) Field "abstract": the NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. Remove all quotes from the generated NEW ABSTRACT. 2) Field "title": a small TITLE summarizing the NEW ABSTRACT. Remove all quotes from the generated TITLE. 3) Field "tags": a list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. Remove all quotes from the generated TAGS. -4) Field "inspiration": a text concatenating all the ABSTRACTS I sent to you, one after another. Remove all quotes from the ABSTRACTS. +4) Field "inspiration": a text showing the first abstract in ABSTRACT I sent to you. Remove all quotes from the ABSTRACTS. Don't include anything else besides the four mentioned fields. "abstract", "title", "tags" and "inspiration". Make sure each tag starts and finishes with a double quote. - Make sure there is only starting and ending quotes in both keys and values. Remove any other quotes. +Make sure the json is well formed and can be parsed. REQUIRED MESH TAG: {FEATURED_TAG} From 027593879d805ce44c969e3b6f2ead7c61ef4ed6 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 20:54:07 +0100 Subject: [PATCH 189/300] Printing what is generating --- grants_tagger_light/augmentation/augment.py | 2 +- grants_tagger_light/augmentation/augment_openai.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 7cabe504..36b656fe 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -40,7 +40,6 @@ def _merge_dicts(dict_list): def _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, augmentation_engine, train_years, num_proc, model_key): - logger.info(f"Generating missing examples for classes...") counter = 0 with open(save_to_path, 'a') as f: for a in augmentation_engine.generate(collect_concurrent_calls, dset, few_shot_examples=few_shot_examples, @@ -128,6 +127,7 @@ def augment( if len(collect_concurrent_calls) >= concurrent_calls: _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, augmentation_engine, train_years, num_proc, model_key) + collect_concurrent_calls = [] else: if tags_to_augment_counts[t] < min_examples: missing = min_examples - tags_to_augment_counts[t] diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index c48d6c2f..4f6c1bf6 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -40,6 +40,7 @@ def _make_requests(self, for num in range(len(collect_concurrent_calls)): t = collect_concurrent_calls[num][0] n = collect_concurrent_calls[num][1] + logger.info(f"Augmenting {t} with {n} examples") # RAG: I select similar articles to provide them to the LLM tmp_dset = dset.filter(lambda x: any(np.isin([t], x["meshMajor"])), num_proc=num_proc) # I remove them from the dataset to process to make it smaller and quicker over time From d788e4ecaf2a2eca327a52d08d7b4cd5f38a50c8 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 20:56:09 +0100 Subject: [PATCH 190/300] Printing what is generating --- grants_tagger_light/augmentation/augment_openai.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 4f6c1bf6..39fd0bb2 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -81,14 +81,11 @@ def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, tempera t = pieces['tags'] tl = pieces['title'] i = pieces['inspiration'] - print("YIELD!!!!!!!!!!!!!!") yield {'abstract': a, 'tags': t, 'title': tl, 'inspiration': i} except Exception as e: with open(f'{save_to_path}.err', 'w') as f: f.write(r['message']['content']) f.write("\n") - print("ERROR!!!!!!!!!!!!!!") - exit(-1) - logger.info("OpenAI did not return a proper json format.") + logger.info("OpenAI did not return a proper json format...") yield None From 345e3883d56c77aff9ac87831913a85b3d0b7129 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 20:59:37 +0100 Subject: [PATCH 191/300] Printing what is generating --- grants_tagger_light/augmentation/augment.py | 3 ++- grants_tagger_light/augmentation/augment_openai.py | 8 +++++++- grants_tagger_light/augmentation/prompt.template | 3 ++- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 36b656fe..dd6bf67b 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -53,7 +53,8 @@ def _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, "abstractText": a['abstract'], "pmid": uuid.uuid4().hex, "title": a['title'], - "inspiration": a['inspiration'] + "inspiration": a['inspiration'], + "all_inspiration_tags": a['all_inspiration_tags'] })) f.write('\n') f.flush() diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 39fd0bb2..63892a88 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -81,7 +81,13 @@ def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, tempera t = pieces['tags'] tl = pieces['title'] i = pieces['inspiration'] - yield {'abstract': a, 'tags': t, 'title': tl, 'inspiration': i} + ait = pieces['all_inspiration_tags'] + yield {'abstract': a, + 'tags': t, + 'title': tl, + 'inspiration': i, + 'all_inspiration_tags': ait + } except Exception as e: with open(f'{save_to_path}.err', 'w') as f: f.write(r['message']['content']) diff --git a/grants_tagger_light/augmentation/prompt.template b/grants_tagger_light/augmentation/prompt.template index cf3f6a39..cedf704a 100644 --- a/grants_tagger_light/augmentation/prompt.template +++ b/grants_tagger_light/augmentation/prompt.template @@ -8,8 +8,9 @@ You need to return a json file containing: 2) Field "title": a small TITLE summarizing the NEW ABSTRACT. Remove all quotes from the generated TITLE. 3) Field "tags": a list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. Remove all quotes from the generated TAGS. 4) Field "inspiration": a text showing the first abstract in ABSTRACT I sent to you. Remove all quotes from the ABSTRACTS. +5) Field "all_inspiration_tags": the unique MESH TAGS I sent you. Remove all quotes from the tags. -Don't include anything else besides the four mentioned fields. "abstract", "title", "tags" and "inspiration". +Don't include anything else besides the four mentioned fields. "abstract", "title", "tags", "inspiration" and "inspiration_tags". Make sure each tag starts and finishes with a double quote. Make sure there is only starting and ending quotes in both keys and values. Remove any other quotes. From 1669cceb5dae90fdce977b8f9e46b6d925170164 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 21:05:47 +0100 Subject: [PATCH 192/300] Printing what is generating --- grants_tagger_light/augmentation/augment.py | 2 +- grants_tagger_light/augmentation/augment_openai.py | 4 ++-- grants_tagger_light/augmentation/prompt.template | 9 +++++---- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index dd6bf67b..8ea7b71e 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -53,7 +53,7 @@ def _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, "abstractText": a['abstract'], "pmid": uuid.uuid4().hex, "title": a['title'], - "inspiration": a['inspiration'], + "inspiration_example": a['inspiration_example'], "all_inspiration_tags": a['all_inspiration_tags'] })) f.write('\n') diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 63892a88..c9cf804c 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -80,12 +80,12 @@ def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, tempera a = pieces['abstract'] t = pieces['tags'] tl = pieces['title'] - i = pieces['inspiration'] + i = pieces['inspiration_example'] ait = pieces['all_inspiration_tags'] yield {'abstract': a, 'tags': t, 'title': tl, - 'inspiration': i, + 'inspiration_example': i, 'all_inspiration_tags': ait } except Exception as e: diff --git a/grants_tagger_light/augmentation/prompt.template b/grants_tagger_light/augmentation/prompt.template index cedf704a..065fb95a 100644 --- a/grants_tagger_light/augmentation/prompt.template +++ b/grants_tagger_light/augmentation/prompt.template @@ -6,15 +6,16 @@ You are in charge of doing Data Augmentation. I will provide: You need to return a json file containing: 1) Field "abstract": the NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. Remove all quotes from the generated NEW ABSTRACT. 2) Field "title": a small TITLE summarizing the NEW ABSTRACT. Remove all quotes from the generated TITLE. -3) Field "tags": a list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. Remove all quotes from the generated TAGS. -4) Field "inspiration": a text showing the first abstract in ABSTRACT I sent to you. Remove all quotes from the ABSTRACTS. -5) Field "all_inspiration_tags": the unique MESH TAGS I sent you. Remove all quotes from the tags. +3) Field "tags": a list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. Don't create new tags not contained in MESH TAGS. Don't modify them, write exactly as they are in MESH TAGS. Remove all quotes from the generated TAGS. +4) Field "inspiration_example": a text showing the first abstract in ABSTRACT I sent to you. Remove all quotes from the ABSTRACTS. +5) Field "all_inspiration_tags": the MESH TAGS I sent you. Remove all quotes from the tags. -Don't include anything else besides the four mentioned fields. "abstract", "title", "tags", "inspiration" and "inspiration_tags". +Don't include anything else besides the four mentioned fields. "abstract", "title", "tags", "inspiration_example" and "inspiration_tags". Make sure each tag starts and finishes with a double quote. Make sure there is only starting and ending quotes in both keys and values. Remove any other quotes. Make sure the json is well formed and can be parsed. +Make sure each tag in field "tags" exists in MESH TAGS. Otherwise, remove it. REQUIRED MESH TAG: {FEATURED_TAG} From 38cfa580fefa47af89e1cc13f206abf5729a3bcd Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 21:07:53 +0100 Subject: [PATCH 193/300] Printing what is generating --- grants_tagger_light/augmentation/augment_openai.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index c9cf804c..2e19bbc6 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -90,7 +90,11 @@ def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, tempera } except Exception as e: with open(f'{save_to_path}.err', 'w') as f: - f.write(r['message']['content']) + err = {'tags': [x[0] for x in collect_concurrent_calls], + 'tags_missing_examples': [x[1] for x in collect_concurrent_calls], + 'response_from_llm': r['message']['content'] + } + f.write(json.dumps(err)) f.write("\n") logger.info("OpenAI did not return a proper json format...") yield None From 50e1cfbed26451f609ad7e40a3b9d2a1f24d30c5 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 21:13:12 +0100 Subject: [PATCH 194/300] Printing what is generating --- grants_tagger_light/augmentation/augment_openai.py | 7 ------- grants_tagger_light/augmentation/prompt.template | 8 ++++---- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 2e19bbc6..3f13c1e6 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -89,13 +89,6 @@ def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, tempera 'all_inspiration_tags': ait } except Exception as e: - with open(f'{save_to_path}.err', 'w') as f: - err = {'tags': [x[0] for x in collect_concurrent_calls], - 'tags_missing_examples': [x[1] for x in collect_concurrent_calls], - 'response_from_llm': r['message']['content'] - } - f.write(json.dumps(err)) - f.write("\n") logger.info("OpenAI did not return a proper json format...") yield None diff --git a/grants_tagger_light/augmentation/prompt.template b/grants_tagger_light/augmentation/prompt.template index 065fb95a..4dc9812b 100644 --- a/grants_tagger_light/augmentation/prompt.template +++ b/grants_tagger_light/augmentation/prompt.template @@ -6,16 +6,16 @@ You are in charge of doing Data Augmentation. I will provide: You need to return a json file containing: 1) Field "abstract": the NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. Remove all quotes from the generated NEW ABSTRACT. 2) Field "title": a small TITLE summarizing the NEW ABSTRACT. Remove all quotes from the generated TITLE. -3) Field "tags": a list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS included before. Don't create new tags not contained in MESH TAGS. Don't modify them, write exactly as they are in MESH TAGS. Remove all quotes from the generated TAGS. +3) Field "tags": a list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS I've sent you. Don't create new tags not contained in MESH TAGS. Don't modify them, write exactly as they are in MESH TAGS. Remove all quotes from the generated TAGS. 4) Field "inspiration_example": a text showing the first abstract in ABSTRACT I sent to you. Remove all quotes from the ABSTRACTS. -5) Field "all_inspiration_tags": the MESH TAGS I sent you. Remove all quotes from the tags. +5) Field "all_inspiration_tags": the exact list of MESH TAGS I sent you. Remove all quotes from the tags. -Don't include anything else besides the four mentioned fields. "abstract", "title", "tags", "inspiration_example" and "inspiration_tags". +Don't include anything else besides the four mentioned fields. "abstract", "title", "tags", "inspiration_example" and "all_inspiration_tags". Make sure each tag starts and finishes with a double quote. Make sure there is only starting and ending quotes in both keys and values. Remove any other quotes. Make sure the json is well formed and can be parsed. -Make sure each tag in field "tags" exists in MESH TAGS. Otherwise, remove it. +Make sure each tag in field "tags" exists in "all_inspiration_tags". Remove those which do not exist. REQUIRED MESH TAG: {FEATURED_TAG} From c607037150cd44c7306f83f2c4ce80fb93c323c6 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 21:19:33 +0100 Subject: [PATCH 195/300] Printing what is generating --- grants_tagger_light/augmentation/prompt.template | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/augmentation/prompt.template b/grants_tagger_light/augmentation/prompt.template index 4dc9812b..261a8942 100644 --- a/grants_tagger_light/augmentation/prompt.template +++ b/grants_tagger_light/augmentation/prompt.template @@ -8,14 +8,14 @@ You need to return a json file containing: 2) Field "title": a small TITLE summarizing the NEW ABSTRACT. Remove all quotes from the generated TITLE. 3) Field "tags": a list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS I've sent you. Don't create new tags not contained in MESH TAGS. Don't modify them, write exactly as they are in MESH TAGS. Remove all quotes from the generated TAGS. 4) Field "inspiration_example": a text showing the first abstract in ABSTRACT I sent to you. Remove all quotes from the ABSTRACTS. -5) Field "all_inspiration_tags": the exact list of MESH TAGS I sent you. Remove all quotes from the tags. +5) Field "all_inspiration_tags": the MESH TAGS Don't include anything else besides the four mentioned fields. "abstract", "title", "tags", "inspiration_example" and "all_inspiration_tags". Make sure each tag starts and finishes with a double quote. Make sure there is only starting and ending quotes in both keys and values. Remove any other quotes. Make sure the json is well formed and can be parsed. -Make sure each tag in field "tags" exists in "all_inspiration_tags". Remove those which do not exist. +Make sure each tag in field the list of "tags" is contained in the list of "all_inspiration_tags". REQUIRED MESH TAG: {FEATURED_TAG} From ac612b6642cf398687c2382e5af833606e1679f2 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 16 Aug 2023 21:27:09 +0100 Subject: [PATCH 196/300] Printing what is generating --- grants_tagger_light/augmentation/prompt.template | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/grants_tagger_light/augmentation/prompt.template b/grants_tagger_light/augmentation/prompt.template index 261a8942..f481065b 100644 --- a/grants_tagger_light/augmentation/prompt.template +++ b/grants_tagger_light/augmentation/prompt.template @@ -6,9 +6,9 @@ You are in charge of doing Data Augmentation. I will provide: You need to return a json file containing: 1) Field "abstract": the NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. Remove all quotes from the generated NEW ABSTRACT. 2) Field "title": a small TITLE summarizing the NEW ABSTRACT. Remove all quotes from the generated TITLE. -3) Field "tags": a list of TAGS of your NEW ABSTRACT, where those TAGS should be part of any of the MESH TAGS I've sent you. Don't create new tags not contained in MESH TAGS. Don't modify them, write exactly as they are in MESH TAGS. Remove all quotes from the generated TAGS. -4) Field "inspiration_example": a text showing the first abstract in ABSTRACT I sent to you. Remove all quotes from the ABSTRACTS. -5) Field "all_inspiration_tags": the MESH TAGS +3) Field "tags": all the MESH TAGS which are relevant to the "abstract" you generated. Remove those not relevant. Don't create new tags not contained in MESH TAGS. Don't modify them, write exactly as they are in MESH TAGS. Remove all quotes from the generated tags. +4) Field "inspiration_example": the first abstract in ABSTRACT I sent to you. Remove all quotes from the ABSTRACTS. +5) Field "all_inspiration_tags": a list of first, the REQUIRED MESH TAG, and then, all the rest of MESH TAGS. Don't include anything else besides the four mentioned fields. "abstract", "title", "tags", "inspiration_example" and "all_inspiration_tags". From ca6643c434a8b12f6369e3ed6053ff87e0e97c0d Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 17 Aug 2023 08:57:00 +0100 Subject: [PATCH 197/300] Adds gradient clipping and cosine optimizer --- grants_tagger_light/training/train.py | 5 ++++- scripts/resume_train_by_epoch.sh | 1 + scripts/resume_train_by_steps.sh | 1 + scripts/train_by_epochs.sh | 1 + scripts/train_by_steps.sh | 1 + 5 files changed, 8 insertions(+), 1 deletion(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 95c0089a..702c7cad 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -27,6 +27,7 @@ from grants_tagger_light.utils.sharding import Sharding from grants_tagger_light.utils.years_tags_parser import parse_years, parse_tags +from transformers import get_cosine_schedule_with_warmup transformers.set_seed(42) @@ -157,7 +158,9 @@ def sklearn_metrics(prediction: EvalPrediction): optimizer = AdamW(model.parameters(), lr=training_args.learning_rate) # Set your desired learning rate # Create a learning rate scheduler - scheduler = get_constant_scheduler(optimizer, num_warmup_steps=0, num_training_steps=training_args.max_steps) + scheduler = get_cosine_schedule_with_warmup(optimizer, + num_warmup_steps=training_args.warmup_steps, + num_training_steps=training_args.max_steps) trainer = Trainer( model=model, diff --git a/scripts/resume_train_by_epoch.sh b/scripts/resume_train_by_epoch.sh index 4556518f..a5eb9cc6 100644 --- a/scripts/resume_train_by_epoch.sh +++ b/scripts/resume_train_by_epoch.sh @@ -23,6 +23,7 @@ grants-tagger train bertmesh \ --dropout 0.1 \ --hidden_size 1024 \ --warmup_steps 1000 \ + --max_grad_norm 5.0 \ --fp16 \ --torch_compile \ --evaluation_strategy epoch \ diff --git a/scripts/resume_train_by_steps.sh b/scripts/resume_train_by_steps.sh index a27de4dc..456fa7d1 100644 --- a/scripts/resume_train_by_steps.sh +++ b/scripts/resume_train_by_steps.sh @@ -23,6 +23,7 @@ grants-tagger train bertmesh \ --dropout 0.1 \ --hidden_size 1024 \ --warmup_steps 1000 \ + --max_grad_norm 5.0 \ --fp16 \ --torch_compile \ --evaluation_strategy steps \ diff --git a/scripts/train_by_epochs.sh b/scripts/train_by_epochs.sh index ee64da00..8d3102ad 100644 --- a/scripts/train_by_epochs.sh +++ b/scripts/train_by_epochs.sh @@ -23,6 +23,7 @@ grants-tagger train bertmesh \ --dropout 0.1 \ --hidden_size 1024 \ --warmup_steps 1000 \ + --max_grad_norm 5.0 \ --fp16 \ --torch_compile \ --evaluation_strategy epoch \ diff --git a/scripts/train_by_steps.sh b/scripts/train_by_steps.sh index bdf62720..3bdab802 100644 --- a/scripts/train_by_steps.sh +++ b/scripts/train_by_steps.sh @@ -23,6 +23,7 @@ grants-tagger train bertmesh \ --dropout 0.1 \ --hidden_size 1024 \ --warmup_steps 1000 \ + --max_grad_norm 5.0 \ --fp16 \ --torch_compile \ --evaluation_strategy steps \ From 5636545fd146387c5d0e3d88e47adccf453be490 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 17 Aug 2023 22:40:06 +0100 Subject: [PATCH 198/300] Adds gradient clipping and cosine optimizer --- grants_tagger_light/training/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 702c7cad..268d6fe4 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -4,6 +4,8 @@ EvalPrediction, HfArgumentParser, AutoConfig, + AdamW, + get_cosine_schedule_with_warmup ) from grants_tagger_light.models.bert_mesh import BertMesh from grants_tagger_light.preprocessing.preprocess_mesh import preprocess_mesh @@ -27,7 +29,6 @@ from grants_tagger_light.utils.sharding import Sharding from grants_tagger_light.utils.years_tags_parser import parse_years, parse_tags -from transformers import get_cosine_schedule_with_warmup transformers.set_seed(42) From 24075466839adbedbd2093dd087003b326b91514 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 17 Aug 2023 22:52:02 +0100 Subject: [PATCH 199/300] Adds gradient clipping and cosine optimizer --- grants_tagger_light/training/custom_trainer.py | 18 ++++++++++++++++++ grants_tagger_light/training/train.py | 16 ++++------------ 2 files changed, 22 insertions(+), 12 deletions(-) create mode 100644 grants_tagger_light/training/custom_trainer.py diff --git a/grants_tagger_light/training/custom_trainer.py b/grants_tagger_light/training/custom_trainer.py new file mode 100644 index 00000000..4f85ebe7 --- /dev/null +++ b/grants_tagger_light/training/custom_trainer.py @@ -0,0 +1,18 @@ +from transformers import Trainer, AdamW, get_cosine_schedule_with_warmup + + +class CustomTrainer(Trainer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def create_optimizer_and_scheduler(self, num_training_steps): + no_decay = ["bias", "LayerNorm.weight"] + + # Instantiate the AdamW optimizer with a constant learning rate + self.optimizer = AdamW(self.model.parameters(), + lr=self.args.learning_rate) # Set your desired learning rate + + # Create a learning rate scheduler + self.lr_scheduler = get_cosine_schedule_with_warmup(self.optimizer, + num_warmup_steps=self.args.warmup_steps, + num_training_steps=self.args.max_steps) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 268d6fe4..980c6531 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -3,9 +3,7 @@ TrainingArguments, EvalPrediction, HfArgumentParser, - AutoConfig, - AdamW, - get_cosine_schedule_with_warmup + AutoConfig ) from grants_tagger_light.models.bert_mesh import BertMesh from grants_tagger_light.preprocessing.preprocess_mesh import preprocess_mesh @@ -30,6 +28,8 @@ from grants_tagger_light.utils.sharding import Sharding from grants_tagger_light.utils.years_tags_parser import parse_years, parse_tags +from custom_trainer import CustomTrainer + transformers.set_seed(42) @@ -155,15 +155,7 @@ def sklearn_metrics(prediction: EvalPrediction): max_steps = Sharding.calculate_max_steps(training_args, train_dset_size) training_args.max_steps = max_steps - # Instantiate the AdamW optimizer with a constant learning rate - optimizer = AdamW(model.parameters(), lr=training_args.learning_rate) # Set your desired learning rate - - # Create a learning rate scheduler - scheduler = get_cosine_schedule_with_warmup(optimizer, - num_warmup_steps=training_args.warmup_steps, - num_training_steps=training_args.max_steps) - - trainer = Trainer( + trainer = CustomTrainer( model=model, args=training_args, train_dataset=train_dset, From b89e19356cba3a77935897d10e7a250885675e65 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 17 Aug 2023 22:52:42 +0100 Subject: [PATCH 200/300] Adds gradient clipping and cosine optimizer --- grants_tagger_light/training/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 980c6531..f2127ef7 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -28,7 +28,7 @@ from grants_tagger_light.utils.sharding import Sharding from grants_tagger_light.utils.years_tags_parser import parse_years, parse_tags -from custom_trainer import CustomTrainer +from grants_tagger_light.custom_trainer import CustomTrainer transformers.set_seed(42) From f2e147d31940f8b2620c07fe059362f52bc12f62 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 17 Aug 2023 22:53:29 +0100 Subject: [PATCH 201/300] Adds gradient clipping and cosine optimizer --- grants_tagger_light/training/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index f2127ef7..a54efaa3 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -28,7 +28,7 @@ from grants_tagger_light.utils.sharding import Sharding from grants_tagger_light.utils.years_tags_parser import parse_years, parse_tags -from grants_tagger_light.custom_trainer import CustomTrainer +from grants_tagger_light.training.custom_trainer import CustomTrainer transformers.set_seed(42) From c759962605634f9bdadcbb5889797faf2519081b Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 17 Aug 2023 22:54:33 +0100 Subject: [PATCH 202/300] Adds gradient clipping and cosine optimizer --- grants_tagger_light/training/train.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index a54efaa3..fbc6e1f4 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -161,9 +161,7 @@ def sklearn_metrics(prediction: EvalPrediction): train_dataset=train_dset, eval_dataset=val_dset, data_collator=collator, - compute_metrics=sklearn_metrics, - optimizer=optimizer, - scheduler=scheduler, + compute_metrics=sklearn_metrics ) logger.info(training_args) From 549b8178bdb1370d3f65b0e1bef562e7c3acefb6 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 17 Aug 2023 22:55:41 +0100 Subject: [PATCH 203/300] Adds gradient clipping and cosine optimizer --- grants_tagger_light/training/custom_trainer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/grants_tagger_light/training/custom_trainer.py b/grants_tagger_light/training/custom_trainer.py index 4f85ebe7..e831472a 100644 --- a/grants_tagger_light/training/custom_trainer.py +++ b/grants_tagger_light/training/custom_trainer.py @@ -1,18 +1,17 @@ from transformers import Trainer, AdamW, get_cosine_schedule_with_warmup - +from loguru import logger class CustomTrainer(Trainer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def create_optimizer_and_scheduler(self, num_training_steps): - no_decay = ["bias", "LayerNorm.weight"] - # Instantiate the AdamW optimizer with a constant learning rate self.optimizer = AdamW(self.model.parameters(), lr=self.args.learning_rate) # Set your desired learning rate - + logger.info(f"Optimizer: {self.optimizer}") # Create a learning rate scheduler self.lr_scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=self.args.max_steps) + logger.info(f"Scheduler: {self.lr_scheduler}") From 996d1677852eb0d6e84161c70414652a4ada7d2d Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 17 Aug 2023 23:01:14 +0100 Subject: [PATCH 204/300] Adds gradient clipping and cosine optimizer --- grants_tagger_light/training/train.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index fbc6e1f4..1809461d 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -155,13 +155,22 @@ def sklearn_metrics(prediction: EvalPrediction): max_steps = Sharding.calculate_max_steps(training_args, train_dset_size) training_args.max_steps = max_steps - trainer = CustomTrainer( + optimizer = AdamW(model.parameters(), + lr=training_args.learning_rate) # Set your desired learning rate + # Create a learning rate scheduler + scheduler = get_cosine_schedule_with_warmup(optimizer, + num_warmup_steps=training_args.warmup_steps, + num_training_steps=training_args.max_steps) + logger.info(f"Scheduler: {self.lr_scheduler}") + + trainer = Trainer( model=model, args=training_args, train_dataset=train_dset, eval_dataset=val_dset, data_collator=collator, - compute_metrics=sklearn_metrics + compute_metrics=sklearn_metrics, + optimizers=(optimizer, scheduler) ) logger.info(training_args) From 810910575f137b587a6138158735473f7d98fdb5 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 17 Aug 2023 23:01:56 +0100 Subject: [PATCH 205/300] Adds gradient clipping and cosine optimizer --- grants_tagger_light/training/custom_trainer.py | 1 + grants_tagger_light/training/train.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/grants_tagger_light/training/custom_trainer.py b/grants_tagger_light/training/custom_trainer.py index e831472a..cefa456c 100644 --- a/grants_tagger_light/training/custom_trainer.py +++ b/grants_tagger_light/training/custom_trainer.py @@ -1,6 +1,7 @@ from transformers import Trainer, AdamW, get_cosine_schedule_with_warmup from loguru import logger + class CustomTrainer(Trainer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 1809461d..4d15a6ca 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -3,7 +3,9 @@ TrainingArguments, EvalPrediction, HfArgumentParser, - AutoConfig + AutoConfig, + AdamW, + et_cosine_schedule_with_warmup ) from grants_tagger_light.models.bert_mesh import BertMesh from grants_tagger_light.preprocessing.preprocess_mesh import preprocess_mesh From f1a5370d0b14ad89732c04f02bacfd3dff44482e Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 17 Aug 2023 23:02:22 +0100 Subject: [PATCH 206/300] Adds gradient clipping and cosine optimizer --- grants_tagger_light/training/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 4d15a6ca..1451cf0d 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -5,7 +5,7 @@ HfArgumentParser, AutoConfig, AdamW, - et_cosine_schedule_with_warmup + get_cosine_schedule_with_warmup ) from grants_tagger_light.models.bert_mesh import BertMesh from grants_tagger_light.preprocessing.preprocess_mesh import preprocess_mesh From c2bc2abb6c6a37d54471c47f9dcde9c7662314f8 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 17 Aug 2023 23:03:10 +0100 Subject: [PATCH 207/300] Adds gradient clipping and cosine optimizer --- grants_tagger_light/training/train.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 1451cf0d..71d6c109 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -163,7 +163,6 @@ def sklearn_metrics(prediction: EvalPrediction): scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps) - logger.info(f"Scheduler: {self.lr_scheduler}") trainer = Trainer( model=model, From c7b631aac156fe00f429794d18239916fa3169bb Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 17 Aug 2023 23:06:57 +0100 Subject: [PATCH 208/300] Adds gradient clipping and cosine optimizer --- grants_tagger_light/training/train.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 71d6c109..69a006a6 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -164,6 +164,9 @@ def sklearn_metrics(prediction: EvalPrediction): num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps) + training_args.optim = optimizer + training_args.lr_scheduler_type = scheduler + trainer = Trainer( model=model, args=training_args, @@ -171,7 +174,8 @@ def sklearn_metrics(prediction: EvalPrediction): eval_dataset=val_dset, data_collator=collator, compute_metrics=sklearn_metrics, - optimizers=(optimizer, scheduler) + optimizers=(optimizer, scheduler), + ) logger.info(training_args) From e8e85916074a42308550b19ce6a6c76a32df7bb9 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 18 Aug 2023 08:42:36 +0100 Subject: [PATCH 209/300] Freezing bias --- grants_tagger_light/models/bert_mesh/model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/models/bert_mesh/model.py b/grants_tagger_light/models/bert_mesh/model.py index cdff897a..3df0b405 100644 --- a/grants_tagger_light/models/bert_mesh/model.py +++ b/grants_tagger_light/models/bert_mesh/model.py @@ -63,8 +63,10 @@ def freeze_backbone(self): param.requires_grad = False def unfreeze_backbone(self): - for param in self.bert.parameters(): - param.requires_grad = True + for name, param in self.bert.named_parameters(): + if 'bias' in name.lower(): + logger.info(f"Freezing {name}") + param.requires_grad = True def forward(self, input_ids, labels=None, **kwargs): if type(input_ids) is list: From 00f71d6cdb13cfd11d72ccb295cd80feeae34e44 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 18 Aug 2023 08:46:08 +0100 Subject: [PATCH 210/300] Freezing bias --- grants_tagger_light/models/bert_mesh/model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/grants_tagger_light/models/bert_mesh/model.py b/grants_tagger_light/models/bert_mesh/model.py index 3df0b405..a64bf904 100644 --- a/grants_tagger_light/models/bert_mesh/model.py +++ b/grants_tagger_light/models/bert_mesh/model.py @@ -65,8 +65,10 @@ def freeze_backbone(self): def unfreeze_backbone(self): for name, param in self.bert.named_parameters(): if 'bias' in name.lower(): - logger.info(f"Freezing {name}") param.requires_grad = True + else: + param.requires_grad = False + logger.info(f"Unfreezing {name}") def forward(self, input_ids, labels=None, **kwargs): if type(input_ids) is list: From 21c9a662ff6e8c3c52915ba6dd30792c23cfcd41 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 18 Aug 2023 08:46:46 +0100 Subject: [PATCH 211/300] Freezing bias --- grants_tagger_light/models/bert_mesh/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grants_tagger_light/models/bert_mesh/model.py b/grants_tagger_light/models/bert_mesh/model.py index a64bf904..b05cb458 100644 --- a/grants_tagger_light/models/bert_mesh/model.py +++ b/grants_tagger_light/models/bert_mesh/model.py @@ -66,9 +66,9 @@ def unfreeze_backbone(self): for name, param in self.bert.named_parameters(): if 'bias' in name.lower(): param.requires_grad = True + logger.info(f"Unfreezing {name}") else: param.requires_grad = False - logger.info(f"Unfreezing {name}") def forward(self, input_ids, labels=None, **kwargs): if type(input_ids) is list: From e3f8d9edea8b697696e64da1af280602ec2b162d Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 11:22:47 +0100 Subject: [PATCH 212/300] Removes cosine scheduler --- grants_tagger_light/training/custom_trainer.py | 18 ------------------ grants_tagger_light/training/train.py | 13 ++++++------- 2 files changed, 6 insertions(+), 25 deletions(-) delete mode 100644 grants_tagger_light/training/custom_trainer.py diff --git a/grants_tagger_light/training/custom_trainer.py b/grants_tagger_light/training/custom_trainer.py deleted file mode 100644 index cefa456c..00000000 --- a/grants_tagger_light/training/custom_trainer.py +++ /dev/null @@ -1,18 +0,0 @@ -from transformers import Trainer, AdamW, get_cosine_schedule_with_warmup -from loguru import logger - - -class CustomTrainer(Trainer): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def create_optimizer_and_scheduler(self, num_training_steps): - # Instantiate the AdamW optimizer with a constant learning rate - self.optimizer = AdamW(self.model.parameters(), - lr=self.args.learning_rate) # Set your desired learning rate - logger.info(f"Optimizer: {self.optimizer}") - # Create a learning rate scheduler - self.lr_scheduler = get_cosine_schedule_with_warmup(self.optimizer, - num_warmup_steps=self.args.warmup_steps, - num_training_steps=self.args.max_steps) - logger.info(f"Scheduler: {self.lr_scheduler}") diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 69a006a6..5b120c3c 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -30,8 +30,6 @@ from grants_tagger_light.utils.sharding import Sharding from grants_tagger_light.utils.years_tags_parser import parse_years, parse_tags -from grants_tagger_light.training.custom_trainer import CustomTrainer - transformers.set_seed(42) @@ -159,13 +157,14 @@ def sklearn_metrics(prediction: EvalPrediction): optimizer = AdamW(model.parameters(), lr=training_args.learning_rate) # Set your desired learning rate + # Create a learning rate scheduler - scheduler = get_cosine_schedule_with_warmup(optimizer, - num_warmup_steps=training_args.warmup_steps, - num_training_steps=training_args.max_steps) + #scheduler = get_cosine_schedule_with_warmup(optimizer, + # num_warmup_steps=training_args.warmup_steps, + # num_training_steps=training_args.max_steps) training_args.optim = optimizer - training_args.lr_scheduler_type = scheduler + # training_args.lr_scheduler_type = scheduler trainer = Trainer( model=model, @@ -174,7 +173,7 @@ def sklearn_metrics(prediction: EvalPrediction): eval_dataset=val_dset, data_collator=collator, compute_metrics=sklearn_metrics, - optimizers=(optimizer, scheduler), + optimizers=(optimizer, None), ) logger.info(training_args) From a04193528d9ab4e6e714435c40e783cd240ef950 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 11:25:11 +0100 Subject: [PATCH 213/300] Removes cosine scheduler --- grants_tagger_light/preprocessing/preprocess_mesh.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 344faf63..81492f9b 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -68,7 +68,6 @@ def preprocess_mesh( if not model_key: label2id = None - # Use the same pretrained tokenizer as in Wellcome/WellcomeBertMesh tokenizer = AutoTokenizer.from_pretrained( "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract" ) From 1ed906068431f5790cb94dcb219739fbcfe608e0 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 11:57:10 +0100 Subject: [PATCH 214/300] Removes cosine scheduler --- grants_tagger_light/augmentation/augment.py | 6 ++-- .../augmentation/augment_openai.py | 18 ++++++++---- .../augmentation/prompt.template | 29 ++++++++++--------- 3 files changed, 31 insertions(+), 22 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 8ea7b71e..c0b48249 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -46,6 +46,7 @@ def _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, num_proc=num_proc, save_to_path=save_to_path): if a is None: break + f.write(json.dumps({ "journal": model_key, "meshMajor": a['tags'], @@ -53,8 +54,9 @@ def _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, "abstractText": a['abstract'], "pmid": uuid.uuid4().hex, "title": a['title'], - "inspiration_example": a['inspiration_example'], - "all_inspiration_tags": a['all_inspiration_tags'] + "inspiration_examples": a['inspiration_examples'], + "all_inspiration_tags": a['all_inspiration_tags'], + "required_examples": a['required_examples'] })) f.write('\n') f.flush() diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 3f13c1e6..7f472233 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -56,7 +56,11 @@ def _make_requests(self, "top_p": top_p, "presence_penalty": presence_penalty, "messages": self._create_message(t, tmp_dset) - }, metadata={'num': num}) + }, metadata={ + 'all_inspiration_tags': t, + 'required_examples': n, + 'existing_examples_used': size + }) def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, temperature=1.5, top_p=1, presence_penalty=0, num_proc=os.cpu_count(), save_to_path='err'): @@ -71,7 +75,10 @@ def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, tempera num_proc=num_proc) for result in self.api: - num = result.metadata['num'] + print(result) + ait = result.metadata['all_inspiration_tags'] + ex = result.metadata['existing_examples_used'] + req = result.metadata['required_examples'] for r in result.response['choices']: if 'message' in r: if 'content' in r['message']: @@ -80,13 +87,12 @@ def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, tempera a = pieces['abstract'] t = pieces['tags'] tl = pieces['title'] - i = pieces['inspiration_example'] - ait = pieces['all_inspiration_tags'] yield {'abstract': a, 'tags': t, 'title': tl, - 'inspiration_example': i, - 'all_inspiration_tags': ait + 'inspiration_examples': ex, + 'all_inspiration_tags': ait, + 'required_examples': req } except Exception as e: logger.info("OpenAI did not return a proper json format...") diff --git a/grants_tagger_light/augmentation/prompt.template b/grants_tagger_light/augmentation/prompt.template index f481065b..77a66e5e 100644 --- a/grants_tagger_light/augmentation/prompt.template +++ b/grants_tagger_light/augmentation/prompt.template @@ -1,23 +1,24 @@ You are in charge of doing Data Augmentation. I will provide: -1) a REQUIRED MESH TAG; -2) a series of ABSTRACTS to use as a source for augmentation; -3) a series of MESH TAGS. +- ABSTRACTS to use as a source for augmentation; +- a REQUIRED MESH TAG, the main label to use as a topic for augmentation; +- a series of MESH TAGS, secondary labels of the ABSTRACTS. You need to return a json file containing: -1) Field "abstract": the NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. Remove all quotes from the generated NEW ABSTRACT. -2) Field "title": a small TITLE summarizing the NEW ABSTRACT. Remove all quotes from the generated TITLE. -3) Field "tags": all the MESH TAGS which are relevant to the "abstract" you generated. Remove those not relevant. Don't create new tags not contained in MESH TAGS. Don't modify them, write exactly as they are in MESH TAGS. Remove all quotes from the generated tags. -4) Field "inspiration_example": the first abstract in ABSTRACT I sent to you. Remove all quotes from the ABSTRACTS. -5) Field "all_inspiration_tags": a list of first, the REQUIRED MESH TAG, and then, all the rest of MESH TAGS. +- Field "abstract": the NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. Remove all quotes from the generated NEW ABSTRACT. +- Field "title": a small TITLE summarizing the NEW ABSTRACT. Remove all quotes from the generated TITLE. +- Field "tags": all the MESH TAGS which are relevant to the "abstract" you generated, taken from MESH TAGS. Remove those not relevant. Don't create new tags not contained in MESH TAGS. Don't modify them, write exactly as they are in MESH TAGS. Remove all quotes from the generated tags. -Don't include anything else besides the four mentioned fields. "abstract", "title", "tags", "inspiration_example" and "all_inspiration_tags". +There are several conditions for the output you will produce: +- Return a well-formed json only including the fields "abstract", "title", "tags" and "tags". +- Make sure each tag starts and finishes with a double quote. +- Make sure there is only starting and ending quotes in both keys and values. Remove any other quotes. +- Make sure the json is well formed and can be parsed. +- Make sure each tag in "tags" is contained in MESH TAGS. -Make sure each tag starts and finishes with a double quote. -Make sure there is only starting and ending quotes in both keys and values. Remove any other quotes. -Make sure the json is well formed and can be parsed. -Make sure each tag in field the list of "tags" is contained in the list of "all_inspiration_tags". +======================= -REQUIRED MESH TAG: {FEATURED_TAG} +REQUIRED MESH TAG: +{FEATURED_TAG} ABSTRACTS: {ABSTRACTS} From ae4386b229ebd0940697dde152b4d063e665bf51 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 12:07:41 +0100 Subject: [PATCH 215/300] Sending metadata --- grants_tagger_light/augmentation/augment_openai.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 7f472233..f6a2d2a4 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -59,7 +59,8 @@ def _make_requests(self, }, metadata={ 'all_inspiration_tags': t, 'required_examples': n, - 'existing_examples_used': size + 'existing_examples_used': size, + 'first_existing_example': tmp_dset['abstractText'][0] }) def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, temperature=1.5, top_p=1, @@ -75,7 +76,8 @@ def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, tempera num_proc=num_proc) for result in self.api: - print(result) + with open('kk.kk', 'w') as f: + f.write(str(result)) ait = result.metadata['all_inspiration_tags'] ex = result.metadata['existing_examples_used'] req = result.metadata['required_examples'] From a54bd52b4c800f61dc1f015c35464b2bad1df43a Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 12:33:42 +0100 Subject: [PATCH 216/300] Refactors --- grants_tagger_light/augmentation/augment.py | 8 +- .../augmentation/augment_openai.py | 106 ++++++++++++------ 2 files changed, 76 insertions(+), 38 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index c0b48249..65fcbbbe 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -40,7 +40,11 @@ def _merge_dicts(dict_list): def _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, augmentation_engine, train_years, num_proc, model_key): - counter = 0 + year = [random.choice(train_years) if train_years is not None and isinstance(train_years, list) + else datetime.date.year] + augmentation_engine.generate(collect_concurrent_calls, dset, save_to_path, year, model_key, + few_shot_examples=few_shot_examples, num_proc=num_proc) + """counter = 0 with open(save_to_path, 'a') as f: for a in augmentation_engine.generate(collect_concurrent_calls, dset, few_shot_examples=few_shot_examples, num_proc=num_proc, save_to_path=save_to_path): @@ -60,7 +64,7 @@ def _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, })) f.write('\n') f.flush() - counter += 1 + counter += 1""" def augment( diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index f6a2d2a4..780e6b97 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -1,5 +1,8 @@ import json import os +import random +import uuid + from loguru import logger import openai from openai_multi_client import OpenAIMultiClient @@ -15,13 +18,18 @@ def __init__(self, prompt_template_path, model_key='gpt-3.5-turbo'): self.model_key = model_key self.api = OpenAIMultiClient(endpoint="chats", data_template={"model": self.model_key}) - def _create_message(self, featured_tag, featured_tag_dset): - abstracts = "\n".join(featured_tag_dset['abstractText']) + + @staticmethod + def _get_secondary_tags(featured_tag_dset): tags = [] for x in featured_tag_dset['meshMajor']: tags.extend(x) + return ",".join(list(set(tags))) - mesh_tags = ",".join(list(set(tags))) + + def _create_message(self, featured_tag, featured_tag_dset): + abstracts = "\n".join(featured_tag_dset['abstractText']) + mesh_tags = self._get_secondary_tags(featured_tag_dset) prompt = self.prompt_template.replace('{FEATURED_TAG}', featured_tag) prompt = prompt.replace('{ABSTRACTS}', abstracts) @@ -29,6 +37,46 @@ def _create_message(self, featured_tag, featured_tag_dset): return [{"role": "user", "content": prompt}] + @staticmethod + def _process_response(result): + with open('kk.kk', 'w') as f: + f.write(str(result)) + + if result.failed: + logger.warning(f'Failed to get augmentation from {}') + return + + with open(result.metadata['save_to_path'], 'w') as f: + + for r in result.response['choices']: + if 'message' in r: + if 'content' in r['message']: + try: + pieces = json.loads(r['message']['content']) + a = pieces['abstract'] + t = pieces['tags'] + tl = pieces['title'] + + f.write(json.dumps({ + "journal": result.metadata['model_key'], + "meshMajor": t, + "year": [ + result.metadata['year'] + ], + "abstractText": a['abstract'], + "pmid": uuid.uuid4().hex, + "title": tl, + "first_existing_example": result.metadata['first_existing_example'], + "all_inspiration_tags": result.metadata['all_inspiration_tags'], + "required_examples": result.metadata['required_examples'], + "featured_tag": result.metadata['featured_tag'] + })) + f.write('\n') + f.flush() + except Exception as e: + logger.info("OpenAI did not return a proper json format...") + + def _make_requests(self, collect_concurrent_calls, dset, @@ -36,7 +84,10 @@ def _make_requests(self, temperature, top_p, presence_penalty, - num_proc): + num_proc, + year, + model_key, + save_to_path): for num in range(len(collect_concurrent_calls)): t = collect_concurrent_calls[num][0] n = collect_concurrent_calls[num][1] @@ -48,7 +99,7 @@ def _make_requests(self, num_proc=num_proc) size = min(len(tmp_dset), few_shot_examples) tmp_dset = tmp_dset[:size] - + ait = self._get_secondary_tags(dset) self.api.request(data={ "model": self.model_key, "n": n, @@ -57,14 +108,19 @@ def _make_requests(self, "presence_penalty": presence_penalty, "messages": self._create_message(t, tmp_dset) }, metadata={ - 'all_inspiration_tags': t, + 'featured_tag': t, + 'all_inspiration_tags': ait, 'required_examples': n, 'existing_examples_used': size, - 'first_existing_example': tmp_dset['abstractText'][0] - }) + 'first_existing_example': tmp_dset['abstractText'][0], + 'year': year, + 'model_key': model_key, + 'save_to_path': save_to_path + }, callback=self._process_response + ) - def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, temperature=1.5, top_p=1, - presence_penalty=0, num_proc=os.cpu_count(), save_to_path='err'): + def generate(self, collect_concurrent_calls, dset, save_to_path, year, model_key, few_shot_examples=10, + temperature=1.5, top_p=1, presence_penalty=0, num_proc=os.cpu_count() ): self.api.run_request_function(self._make_requests, collect_concurrent_calls=collect_concurrent_calls, @@ -73,30 +129,8 @@ def generate(self, collect_concurrent_calls, dset, few_shot_examples=10, tempera temperature=temperature, top_p=top_p, presence_penalty=presence_penalty, - num_proc=num_proc) - - for result in self.api: - with open('kk.kk', 'w') as f: - f.write(str(result)) - ait = result.metadata['all_inspiration_tags'] - ex = result.metadata['existing_examples_used'] - req = result.metadata['required_examples'] - for r in result.response['choices']: - if 'message' in r: - if 'content' in r['message']: - try: - pieces = json.loads(r['message']['content']) - a = pieces['abstract'] - t = pieces['tags'] - tl = pieces['title'] - yield {'abstract': a, - 'tags': t, - 'title': tl, - 'inspiration_examples': ex, - 'all_inspiration_tags': ait, - 'required_examples': req - } - except Exception as e: - logger.info("OpenAI did not return a proper json format...") - yield None + num_proc=num_proc, + year=year, + model_key=model_key, + save_to_path=save_to_path) From 7135c1e4016936c5d8c020f868fedca4f17c2bc6 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 12:35:38 +0100 Subject: [PATCH 217/300] Refactors --- grants_tagger_light/augmentation/augment_openai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 780e6b97..3300d886 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -43,7 +43,7 @@ def _process_response(result): f.write(str(result)) if result.failed: - logger.warning(f'Failed to get augmentation from {}') + logger.warning(f"Failed to get augmentation for {result.metadata['featured_tag']}") return with open(result.metadata['save_to_path'], 'w') as f: From f68785b05c77f93b773b3b7da7edfa3e5aef70cf Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 12:41:01 +0100 Subject: [PATCH 218/300] Adds sleep --- grants_tagger_light/augmentation/augment.py | 33 ++++++------------- .../augmentation/augment_openai.py | 1 - 2 files changed, 10 insertions(+), 24 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 65fcbbbe..9c73c6e8 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -2,6 +2,7 @@ import multiprocessing import os import random +import time import typer from loguru import logger @@ -44,27 +45,6 @@ def _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, else datetime.date.year] augmentation_engine.generate(collect_concurrent_calls, dset, save_to_path, year, model_key, few_shot_examples=few_shot_examples, num_proc=num_proc) - """counter = 0 - with open(save_to_path, 'a') as f: - for a in augmentation_engine.generate(collect_concurrent_calls, dset, few_shot_examples=few_shot_examples, - num_proc=num_proc, save_to_path=save_to_path): - if a is None: - break - - f.write(json.dumps({ - "journal": model_key, - "meshMajor": a['tags'], - "year": [random.choice(train_years) if len(train_years) > 0 else datetime.date.today().year], - "abstractText": a['abstract'], - "pmid": uuid.uuid4().hex, - "title": a['title'], - "inspiration_examples": a['inspiration_examples'], - "all_inspiration_tags": a['all_inspiration_tags'], - "required_examples": a['required_examples'] - })) - f.write('\n') - f.flush() - counter += 1""" def augment( @@ -78,7 +58,8 @@ def augment( min_examples: int = 15, prompt_template: str = 'grants_tagger_light/augmentation/prompt.template', few_shot_examples: int = 5, - concurrent_calls: int = 5 + concurrent_calls: int = 5, + sleep: int = 5 ): if model_key.strip().lower().startswith('gpt-3.5-turbo') or \ model_key.strip().lower().startswith('text-davinci') or \ @@ -135,6 +116,7 @@ def augment( _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, augmentation_engine, train_years, num_proc, model_key) collect_concurrent_calls = [] + time.sleep(sleep) else: if tags_to_augment_counts[t] < min_examples: missing = min_examples - tags_to_augment_counts[t] @@ -185,6 +167,10 @@ def augment_cli( 5, help="Concurrent calls with 1 tag each to the different model" ), + sleep: int = typer.Option( + 10, + help="Time to wait before each concurrent call" + ), ): if not data_path.endswith("jsonl"): logger.error( @@ -203,5 +189,6 @@ def augment_cli( min_examples=min_examples, prompt_template=prompt_template, few_shot_examples=few_shot_examples, - concurrent_calls=concurrent_calls + concurrent_calls=concurrent_calls, + sleep=sleep ) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 3300d886..a1705b6c 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -76,7 +76,6 @@ def _process_response(result): except Exception as e: logger.info("OpenAI did not return a proper json format...") - def _make_requests(self, collect_concurrent_calls, dset, From 4c06616ef0e7b3a682853fd79836b210b1f0a6ce Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 12:43:19 +0100 Subject: [PATCH 219/300] Adds sleep --- grants_tagger_light/augmentation/augment_openai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index a1705b6c..c1673036 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -41,7 +41,7 @@ def _create_message(self, featured_tag, featured_tag_dset): def _process_response(result): with open('kk.kk', 'w') as f: f.write(str(result)) - + print(result) if result.failed: logger.warning(f"Failed to get augmentation for {result.metadata['featured_tag']}") return From 72d080b063e39c974a8449bda2367b677ec800d8 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 12:49:53 +0100 Subject: [PATCH 220/300] Adds sleep --- .../augmentation/augment_openai.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index c1673036..6539f557 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -99,14 +99,17 @@ def _make_requests(self, size = min(len(tmp_dset), few_shot_examples) tmp_dset = tmp_dset[:size] ait = self._get_secondary_tags(dset) - self.api.request(data={ + + data = { "model": self.model_key, "n": n, "temperature": temperature, "top_p": top_p, "presence_penalty": presence_penalty, "messages": self._create_message(t, tmp_dset) - }, metadata={ + } + + metadata = { 'featured_tag': t, 'all_inspiration_tags': ait, 'required_examples': n, @@ -115,11 +118,15 @@ def _make_requests(self, 'year': year, 'model_key': model_key, 'save_to_path': save_to_path - }, callback=self._process_response + } + + print(data) + print(metadata) + self.api.request(data=data, metadata=metadata, callback=self._process_response ) def generate(self, collect_concurrent_calls, dset, save_to_path, year, model_key, few_shot_examples=10, - temperature=1.5, top_p=1, presence_penalty=0, num_proc=os.cpu_count() ): + temperature=1.5, top_p=1, presence_penalty=0, num_proc=os.cpu_count()): self.api.run_request_function(self._make_requests, collect_concurrent_calls=collect_concurrent_calls, From 60fa413d0e1727d1c34b91af3404872ceca70c00 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 13:15:10 +0100 Subject: [PATCH 221/300] Adds sleep --- .../augmentation/augment_openai.py | 73 +++++++++---------- .../augmentation/prompt.template | 31 ++------ 2 files changed, 42 insertions(+), 62 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 6539f557..1014005f 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -26,14 +26,9 @@ def _get_secondary_tags(featured_tag_dset): tags.extend(x) return ",".join(list(set(tags))) - - def _create_message(self, featured_tag, featured_tag_dset): - abstracts = "\n".join(featured_tag_dset['abstractText']) - mesh_tags = self._get_secondary_tags(featured_tag_dset) - - prompt = self.prompt_template.replace('{FEATURED_TAG}', featured_tag) - prompt = prompt.replace('{ABSTRACTS}', abstracts) - prompt = prompt.replace('{MESH_TAGS}', mesh_tags) + def _create_message(self, abstract, tag): + prompt = self.prompt_template.replace('{TOPIC}', tag) + prompt = prompt.replace('{ABSTRACTS}', abstract) return [{"role": "user", "content": prompt}] @@ -54,20 +49,18 @@ def _process_response(result): try: pieces = json.loads(r['message']['content']) a = pieces['abstract'] - t = pieces['tags'] tl = pieces['title'] f.write(json.dumps({ "journal": result.metadata['model_key'], - "meshMajor": t, + "meshMajor": result.metadata['tags'], "year": [ result.metadata['year'] ], - "abstractText": a['abstract'], + "abstractText": a, "pmid": uuid.uuid4().hex, "title": tl, - "first_existing_example": result.metadata['first_existing_example'], - "all_inspiration_tags": result.metadata['all_inspiration_tags'], + "existing_example": result.metadata['example'], "required_examples": result.metadata['required_examples'], "featured_tag": result.metadata['featured_tag'] })) @@ -98,32 +91,34 @@ def _make_requests(self, num_proc=num_proc) size = min(len(tmp_dset), few_shot_examples) tmp_dset = tmp_dset[:size] - ait = self._get_secondary_tags(dset) - - data = { - "model": self.model_key, - "n": n, - "temperature": temperature, - "top_p": top_p, - "presence_penalty": presence_penalty, - "messages": self._create_message(t, tmp_dset) - } - - metadata = { - 'featured_tag': t, - 'all_inspiration_tags': ait, - 'required_examples': n, - 'existing_examples_used': size, - 'first_existing_example': tmp_dset['abstractText'][0], - 'year': year, - 'model_key': model_key, - 'save_to_path': save_to_path - } - - print(data) - print(metadata) - self.api.request(data=data, metadata=metadata, callback=self._process_response - ) + + for i in range(n): + selected_row = tmp_dset[random.randint(0, len(tmp_dset)-1)] + abstract = selected_row['abstractText'] + tags = selected_row['meshMajor'] + data = { + "model": self.model_key, + "n": n, + "temperature": temperature, + "top_p": top_p, + "presence_penalty": presence_penalty, + "messages": self._create_message(abstract, t) + } + + metadata = { + 'featured_tag': t, + 'tags': tags, + 'required_examples': n, + 'existing_examples_used': size, + 'existing_example': abstract, + 'year': year, + 'model_key': model_key, + 'save_to_path': save_to_path + } + + print(data) + print(metadata) + self.api.request(data=data, metadata=metadata, callback=self._process_response) def generate(self, collect_concurrent_calls, dset, save_to_path, year, model_key, few_shot_examples=10, temperature=1.5, top_p=1, presence_penalty=0, num_proc=os.cpu_count()): diff --git a/grants_tagger_light/augmentation/prompt.template b/grants_tagger_light/augmentation/prompt.template index 77a66e5e..60ecbd15 100644 --- a/grants_tagger_light/augmentation/prompt.template +++ b/grants_tagger_light/augmentation/prompt.template @@ -1,27 +1,12 @@ -You are in charge of doing Data Augmentation. I will provide: -- ABSTRACTS to use as a source for augmentation; -- a REQUIRED MESH TAG, the main label to use as a topic for augmentation; -- a series of MESH TAGS, secondary labels of the ABSTRACTS. - -You need to return a json file containing: -- Field "abstract": the NEW ABSTRACT featuring the REQUIRED MESH TAG, which must use the information from the ABSTRACTS but be completely new. It should contain minimum 200 words. Remove all quotes from the generated NEW ABSTRACT. -- Field "title": a small TITLE summarizing the NEW ABSTRACT. Remove all quotes from the generated TITLE. -- Field "tags": all the MESH TAGS which are relevant to the "abstract" you generated, taken from MESH TAGS. Remove those not relevant. Don't create new tags not contained in MESH TAGS. Don't modify them, write exactly as they are in MESH TAGS. Remove all quotes from the generated tags. - -There are several conditions for the output you will produce: -- Return a well-formed json only including the fields "abstract", "title", "tags" and "tags". -- Make sure each tag starts and finishes with a double quote. -- Make sure there is only starting and ending quotes in both keys and values. Remove any other quotes. -- Make sure the json is well formed and can be parsed. -- Make sure each tag in "tags" is contained in MESH TAGS. +You are in charge of doing Data Augmentation. I will provide an ABSTRACT and a TOPIC and you will create a json with two fields: +1. 'abstract': a variation of that ABSTRACT talking about TOPIC. Some creativity is allowed. Remove all the quotes except the starting and ending of the field. +2. 'title': a sentence summarizing the abstract. Remove all the quotes except the starting and ending of the field. +Make sure the json is well formed. ======================= -REQUIRED MESH TAG: -{FEATURED_TAG} - -ABSTRACTS: -{ABSTRACTS} +TOPIC +{TOPIC} -MESH TAGS: -{MESH_TAGS} \ No newline at end of file +ABSTRACT: +{ABSTRACT} \ No newline at end of file From 53452f81ff5e521d2cebf11270c66bf1736cade4 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 13:17:14 +0100 Subject: [PATCH 222/300] Adds sleep --- grants_tagger_light/augmentation/augment_openai.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 1014005f..e1591a29 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -93,9 +93,9 @@ def _make_requests(self, tmp_dset = tmp_dset[:size] for i in range(n): - selected_row = tmp_dset[random.randint(0, len(tmp_dset)-1)] - abstract = selected_row['abstractText'] - tags = selected_row['meshMajor'] + selected_row = random.randint(0, len(tmp_dset)-1) + abstract = tmp_dset['abstractText'][selected_row] + tags = tmp_dset['meshMajor'][selected_row] data = { "model": self.model_key, "n": n, From 808e2515b01384c5f69fef42dcac392d29271cb0 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 13:21:23 +0100 Subject: [PATCH 223/300] Sends one by one --- grants_tagger_light/augmentation/augment_openai.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index e1591a29..554df9e5 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -28,12 +28,13 @@ def _get_secondary_tags(featured_tag_dset): def _create_message(self, abstract, tag): prompt = self.prompt_template.replace('{TOPIC}', tag) - prompt = prompt.replace('{ABSTRACTS}', abstract) + prompt = prompt.replace('{ABSTRACT}', abstract) return [{"role": "user", "content": prompt}] @staticmethod def _process_response(result): + print("Response!!!") with open('kk.kk', 'w') as f: f.write(str(result)) print(result) @@ -94,7 +95,11 @@ def _make_requests(self, for i in range(n): selected_row = random.randint(0, len(tmp_dset)-1) + print(f"Selected row: {selected_row}") + print(tmp_dset['abstractText']) + print(len(tmp_dset['abstractText'])) abstract = tmp_dset['abstractText'][selected_row] + print(tmp_dset['abstractText']) tags = tmp_dset['meshMajor'][selected_row] data = { "model": self.model_key, From dd3d76369204e92eeb7bf3dbc0122e8815f4e913 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 13:24:36 +0100 Subject: [PATCH 224/300] Sends one by one --- grants_tagger_light/augmentation/augment_openai.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 554df9e5..1bbce7a3 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -96,10 +96,8 @@ def _make_requests(self, for i in range(n): selected_row = random.randint(0, len(tmp_dset)-1) print(f"Selected row: {selected_row}") - print(tmp_dset['abstractText']) - print(len(tmp_dset['abstractText'])) + print(f"Size of dataset: {len(tmp_dset)}") abstract = tmp_dset['abstractText'][selected_row] - print(tmp_dset['abstractText']) tags = tmp_dset['meshMajor'][selected_row] data = { "model": self.model_key, From 9a1d245e60d29672d6b9e15614250120a63be7be Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 13:27:02 +0100 Subject: [PATCH 225/300] Sends one by one --- grants_tagger_light/augmentation/augment_openai.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 1bbce7a3..38620aa0 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -90,13 +90,11 @@ def _make_requests(self, # I remove them from the dataset to process to make it smaller and quicker over time dset = dset.filter(lambda example, idx: idx not in tmp_dset['idx'], with_indices=True, num_proc=num_proc) - size = min(len(tmp_dset), few_shot_examples) - tmp_dset = tmp_dset[:size] for i in range(n): selected_row = random.randint(0, len(tmp_dset)-1) print(f"Selected row: {selected_row}") - print(f"Size of dataset: {len(tmp_dset)}") + print(f"Size of dataset: {len(tmp_dset['abstractText'])}") abstract = tmp_dset['abstractText'][selected_row] tags = tmp_dset['meshMajor'][selected_row] data = { From d980d8cbbff8fbea1283cc8fd630ea769743999d Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 13:28:31 +0100 Subject: [PATCH 226/300] Sends one by one --- grants_tagger_light/augmentation/augment_openai.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 38620aa0..cd773735 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -110,7 +110,6 @@ def _make_requests(self, 'featured_tag': t, 'tags': tags, 'required_examples': n, - 'existing_examples_used': size, 'existing_example': abstract, 'year': year, 'model_key': model_key, From 5830f70eef897d8c3e76d20d98016784b277f113 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 13:32:38 +0100 Subject: [PATCH 227/300] Changes from static to global --- .../augmentation/augment_openai.py | 87 +++++++++---------- 1 file changed, 39 insertions(+), 48 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index cd773735..916e8c0d 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -9,6 +9,43 @@ import numpy as np +def process_response(result): + print("Response!!!") + with open('kk.kk', 'w') as f: + f.write(str(result)) + print(result) + if result.failed: + logger.warning(f"Failed to get augmentation for {result.metadata['featured_tag']}") + return + + with open(result.metadata['save_to_path'], 'w') as f: + + for r in result.response['choices']: + if 'message' in r: + if 'content' in r['message']: + try: + pieces = json.loads(r['message']['content']) + a = pieces['abstract'] + tl = pieces['title'] + + f.write(json.dumps({ + "journal": result.metadata['model_key'], + "meshMajor": result.metadata['tags'], + "year": [ + result.metadata['year'] + ], + "abstractText": a, + "pmid": uuid.uuid4().hex, + "title": tl, + "existing_example": result.metadata['example'], + "required_examples": result.metadata['required_examples'], + "featured_tag": result.metadata['featured_tag'] + })) + f.write('\n') + f.flush() + except Exception as e: + logger.info("OpenAI did not return a proper json format...") + class AugmentOpenAI: def __init__(self, prompt_template_path, model_key='gpt-3.5-turbo'): if 'OPENAI_API_KEY' not in os.environ: @@ -18,57 +55,13 @@ def __init__(self, prompt_template_path, model_key='gpt-3.5-turbo'): self.model_key = model_key self.api = OpenAIMultiClient(endpoint="chats", data_template={"model": self.model_key}) - - @staticmethod - def _get_secondary_tags(featured_tag_dset): - tags = [] - for x in featured_tag_dset['meshMajor']: - tags.extend(x) - return ",".join(list(set(tags))) - def _create_message(self, abstract, tag): prompt = self.prompt_template.replace('{TOPIC}', tag) prompt = prompt.replace('{ABSTRACT}', abstract) return [{"role": "user", "content": prompt}] - @staticmethod - def _process_response(result): - print("Response!!!") - with open('kk.kk', 'w') as f: - f.write(str(result)) - print(result) - if result.failed: - logger.warning(f"Failed to get augmentation for {result.metadata['featured_tag']}") - return - - with open(result.metadata['save_to_path'], 'w') as f: - - for r in result.response['choices']: - if 'message' in r: - if 'content' in r['message']: - try: - pieces = json.loads(r['message']['content']) - a = pieces['abstract'] - tl = pieces['title'] - - f.write(json.dumps({ - "journal": result.metadata['model_key'], - "meshMajor": result.metadata['tags'], - "year": [ - result.metadata['year'] - ], - "abstractText": a, - "pmid": uuid.uuid4().hex, - "title": tl, - "existing_example": result.metadata['example'], - "required_examples": result.metadata['required_examples'], - "featured_tag": result.metadata['featured_tag'] - })) - f.write('\n') - f.flush() - except Exception as e: - logger.info("OpenAI did not return a proper json format...") + def _make_requests(self, collect_concurrent_calls, @@ -116,9 +109,7 @@ def _make_requests(self, 'save_to_path': save_to_path } - print(data) - print(metadata) - self.api.request(data=data, metadata=metadata, callback=self._process_response) + self.api.request(data=data, metadata=metadata, callback=process_response) def generate(self, collect_concurrent_calls, dset, save_to_path, year, model_key, few_shot_examples=10, temperature=1.5, top_p=1, presence_penalty=0, num_proc=os.cpu_count()): From 29f32f95dfb390b1a337f5519714826bbee065f2 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 13:40:36 +0100 Subject: [PATCH 228/300] Removes param --- grants_tagger_light/augmentation/augment.py | 15 +-- .../augmentation/augment_openai.py | 100 +++++++++--------- 2 files changed, 56 insertions(+), 59 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 9c73c6e8..23e61b73 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -39,12 +39,10 @@ def _merge_dicts(dict_list): return merged_dict -def _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, +def _generate(collect_concurrent_calls, dset, save_to_path, augmentation_engine, train_years, num_proc, model_key): - year = [random.choice(train_years) if train_years is not None and isinstance(train_years, list) - else datetime.date.year] - augmentation_engine.generate(collect_concurrent_calls, dset, save_to_path, year, model_key, - few_shot_examples=few_shot_examples, num_proc=num_proc) + augmentation_engine.generate(collect_concurrent_calls, dset, save_to_path, train_years, model_key, + num_proc=num_proc) def augment( @@ -57,7 +55,6 @@ def augment( test_years: list = None, min_examples: int = 15, prompt_template: str = 'grants_tagger_light/augmentation/prompt.template', - few_shot_examples: int = 5, concurrent_calls: int = 5, sleep: int = 5 ): @@ -113,7 +110,7 @@ def augment( collect_concurrent_calls = [] for t in tags_to_augment: if len(collect_concurrent_calls) >= concurrent_calls: - _generate(collect_concurrent_calls, dset, few_shot_examples, save_to_path, augmentation_engine, + _generate(collect_concurrent_calls, dset, save_to_path, augmentation_engine, train_years, num_proc, model_key) collect_concurrent_calls = [] time.sleep(sleep) @@ -159,10 +156,6 @@ def augment_cli( 'grants_tagger_light/augmentation/prompt.template', help="File to use as a prompt. Make sure to ask the LLM to return a dict with two fields: `abstract` and `tags`" ), - few_shot_examples: int = typer.Option( - 5, - help="If available, try to send this number of examples to the LLM so that it can generate better abstracts" - ), concurrent_calls: int = typer.Option( 5, help="Concurrent calls with 1 tag each to the different model" diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 916e8c0d..b77c5eb0 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -1,3 +1,4 @@ +import datetime import json import os import random @@ -9,43 +10,6 @@ import numpy as np -def process_response(result): - print("Response!!!") - with open('kk.kk', 'w') as f: - f.write(str(result)) - print(result) - if result.failed: - logger.warning(f"Failed to get augmentation for {result.metadata['featured_tag']}") - return - - with open(result.metadata['save_to_path'], 'w') as f: - - for r in result.response['choices']: - if 'message' in r: - if 'content' in r['message']: - try: - pieces = json.loads(r['message']['content']) - a = pieces['abstract'] - tl = pieces['title'] - - f.write(json.dumps({ - "journal": result.metadata['model_key'], - "meshMajor": result.metadata['tags'], - "year": [ - result.metadata['year'] - ], - "abstractText": a, - "pmid": uuid.uuid4().hex, - "title": tl, - "existing_example": result.metadata['example'], - "required_examples": result.metadata['required_examples'], - "featured_tag": result.metadata['featured_tag'] - })) - f.write('\n') - f.flush() - except Exception as e: - logger.info("OpenAI did not return a proper json format...") - class AugmentOpenAI: def __init__(self, prompt_template_path, model_key='gpt-3.5-turbo'): if 'OPENAI_API_KEY' not in os.environ: @@ -61,19 +25,58 @@ def _create_message(self, abstract, tag): return [{"role": "user", "content": prompt}] - + @staticmethod + def _process_response(result): + print("Response!!!") + with open('kk.kk', 'w') as f: + f.write(str(result)) + print(result) + if result.failed: + logger.warning(f"Failed to get augmentation for {result.metadata['featured_tag']}") + return + + with open(result.metadata['save_to_path'], 'w') as f: + + for r in result.response['choices']: + if 'message' in r: + if 'content' in r['message']: + try: + pieces = json.loads(r['message']['content']) + a = pieces['abstract'] + tl = pieces['title'] + + f.write(json.dumps({ + "journal": result.metadata['model_key'], + "meshMajor": result.metadata['tags'], + "year": [ + result.metadata['year'] + ], + "abstractText": a, + "pmid": uuid.uuid4().hex, + "title": tl, + "existing_example": result.metadata['example'], + "required_examples": result.metadata['required_examples'], + "featured_tag": result.metadata['featured_tag'] + })) + f.write('\n') + f.flush() + except Exception as e: + logger.info("OpenAI did not return a proper json format...") def _make_requests(self, collect_concurrent_calls, dset, - few_shot_examples, temperature, top_p, presence_penalty, num_proc, - year, + train_years, model_key, save_to_path): + + year = [random.choice(train_years) if train_years is not None and isinstance(train_years, list) + else datetime.date.year] + for num in range(len(collect_concurrent_calls)): t = collect_concurrent_calls[num][0] n = collect_concurrent_calls[num][1] @@ -84,10 +87,11 @@ def _make_requests(self, dset = dset.filter(lambda example, idx: idx not in tmp_dset['idx'], with_indices=True, num_proc=num_proc) + abstracts_num = [i for i in range(len(tmp_dset))] + random.shuffle(abstracts_num) + for i in range(n): - selected_row = random.randint(0, len(tmp_dset)-1) - print(f"Selected row: {selected_row}") - print(f"Size of dataset: {len(tmp_dset['abstractText'])}") + selected_row = abstracts_num[i % len(tmp_dset)] abstract = tmp_dset['abstractText'][selected_row] tags = tmp_dset['meshMajor'][selected_row] data = { @@ -109,20 +113,20 @@ def _make_requests(self, 'save_to_path': save_to_path } - self.api.request(data=data, metadata=metadata, callback=process_response) + self.api.request(data=data, metadata=metadata, callback=self._process_response) - def generate(self, collect_concurrent_calls, dset, save_to_path, year, model_key, few_shot_examples=10, + def generate(self, collect_concurrent_calls, dset, save_to_path, train_years, model_key, temperature=1.5, top_p=1, presence_penalty=0, num_proc=os.cpu_count()): - self.api.run_request_function(self._make_requests, collect_concurrent_calls=collect_concurrent_calls, dset=dset, - few_shot_examples=few_shot_examples, temperature=temperature, top_p=top_p, presence_penalty=presence_penalty, num_proc=num_proc, - year=year, + year=train_years, model_key=model_key, save_to_path=save_to_path) + self.api.pull_all() + From 82ef9e9fe61ef3bc9bc865be630690903b240383 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 13:41:04 +0100 Subject: [PATCH 229/300] Removes param --- grants_tagger_light/augmentation/augment.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 23e61b73..6b0d1a9f 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -181,7 +181,6 @@ def augment_cli( test_years=parse_years(test_years), min_examples=min_examples, prompt_template=prompt_template, - few_shot_examples=few_shot_examples, concurrent_calls=concurrent_calls, sleep=sleep ) From efb2ef545bc7f30b8ac233046e11842a5f1bab7f Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 13:42:24 +0100 Subject: [PATCH 230/300] Removes param --- grants_tagger_light/augmentation/augment_openai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index b77c5eb0..c79b6635 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -124,7 +124,7 @@ def generate(self, collect_concurrent_calls, dset, save_to_path, train_years, mo top_p=top_p, presence_penalty=presence_penalty, num_proc=num_proc, - year=train_years, + train_years=train_years, model_key=model_key, save_to_path=save_to_path) From 9319a9dcfb30c62d40675f91c7aa56fbca3dff34 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 13:45:26 +0100 Subject: [PATCH 231/300] Checks error --- grants_tagger_light/augmentation/augment.py | 3 --- grants_tagger_light/augmentation/augment_openai.py | 4 +--- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 6b0d1a9f..4f0d66d4 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -1,15 +1,12 @@ import json import multiprocessing import os -import random import time import typer from loguru import logger from datasets import load_dataset import numpy as np -import datetime -import uuid from grants_tagger_light.augmentation.augment_openai import AugmentOpenAI diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index c79b6635..8be9ad09 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -27,10 +27,8 @@ def _create_message(self, abstract, tag): @staticmethod def _process_response(result): - print("Response!!!") with open('kk.kk', 'w') as f: f.write(str(result)) - print(result) if result.failed: logger.warning(f"Failed to get augmentation for {result.metadata['featured_tag']}") return @@ -61,7 +59,7 @@ def _process_response(result): f.write('\n') f.flush() except Exception as e: - logger.info("OpenAI did not return a proper json format...") + logger.info(f"Error processing output: {e}") def _make_requests(self, collect_concurrent_calls, From 22993e11ed8bf5554f584aa1eef388693e6f5b29 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 13:47:26 +0100 Subject: [PATCH 232/300] Fixes bug with metadata field name --- grants_tagger_light/augmentation/augment_openai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 8be9ad09..e59dfeef 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -52,7 +52,7 @@ def _process_response(result): "abstractText": a, "pmid": uuid.uuid4().hex, "title": tl, - "existing_example": result.metadata['example'], + "existing_example": result.metadata['existing_example'], "required_examples": result.metadata['required_examples'], "featured_tag": result.metadata['featured_tag'] })) From be78b3c77201e42ee9b6322e12eb036b00d21171 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 14:08:57 +0100 Subject: [PATCH 233/300] Adds JsonParser --- .../augmentation/JsonParser.py | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 grants_tagger_light/augmentation/JsonParser.py diff --git a/grants_tagger_light/augmentation/JsonParser.py b/grants_tagger_light/augmentation/JsonParser.py new file mode 100644 index 00000000..71c2a1e1 --- /dev/null +++ b/grants_tagger_light/augmentation/JsonParser.py @@ -0,0 +1,63 @@ +""" +From langchain: https://raw.githubusercontent.com/langchain-ai/langchain/master/libs/langchain/langchain/output_parsers/json.py +""" + +import json +import re + + +class JsonParser: + + @staticmethod + def _replace_new_line(match: re.Match[str]) -> str: + value = match.group(2) + value = re.sub(r"\n", r"\\n", value) + value = re.sub(r"\r", r"\\r", value) + value = re.sub(r"\t", r"\\t", value) + value = re.sub('"', r"\"", value) + + return match.group(1) + value + match.group(3) + + @staticmethod + def _custom_parser(multiline_string: str) -> str: + """ + The LLM response for `action_input` may be a multiline + string containing unescaped newlines, tabs or quotes. This function + replaces those characters with their escaped counterparts. + (newlines in JSON must be double-escaped: `\\n`) + """ + if isinstance(multiline_string, (bytes, bytearray)): + multiline_string = multiline_string.decode() + + multiline_string = re.sub( + r'("action_input"\:\s*")(.*)(")', + JsonParser._replace_new_line, + multiline_string, + flags=re.DOTALL, + ) + + return multiline_string + + @staticmethod + def parse_json(json_string: str) -> dict: + """ + Parse a JSON string from LLM response + + Args: + json_string: The Markdown string. + + Returns: + The parsed JSON object as a Python dictionary. + """ + json_str = json_string + + # Strip whitespace and newlines from the start and end + json_str = json_str.strip() + + # handle newlines and other special characters inside the returned value + json_str = JsonParser._custom_parser(json_str) + + # Parse the JSON string into a Python dictionary + parsed = json.loads(json_str) + + return parsed From 425c12b3cebe1546a2d96d42a8d7edfe95d872a2 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 14:10:13 +0100 Subject: [PATCH 234/300] Adds JsonParser --- grants_tagger_light/augmentation/augment_openai.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index e59dfeef..cd5ec694 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -9,6 +9,8 @@ from openai_multi_client import OpenAIMultiClient import numpy as np +from grants_tagger_light.augmentation.JsonParser import JsonParser + class AugmentOpenAI: def __init__(self, prompt_template_path, model_key='gpt-3.5-turbo'): @@ -39,9 +41,9 @@ def _process_response(result): if 'message' in r: if 'content' in r['message']: try: - pieces = json.loads(r['message']['content']) - a = pieces['abstract'] - tl = pieces['title'] + json_response = JsonParser.parse_json(r['message']['content']) + a = json_response['abstract'] + tl = json_response['title'] f.write(json.dumps({ "journal": result.metadata['model_key'], From a45a0f1f14f269ecdcbd55881fcdd230b3819a10 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 14:16:04 +0100 Subject: [PATCH 235/300] Removes sleep --- grants_tagger_light/augmentation/augment.py | 12 +----- .../augmentation/augment_openai.py | 42 ++++++++++--------- 2 files changed, 25 insertions(+), 29 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 4f0d66d4..e08e07be 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -1,7 +1,6 @@ import json import multiprocessing import os -import time import typer from loguru import logger @@ -52,8 +51,7 @@ def augment( test_years: list = None, min_examples: int = 15, prompt_template: str = 'grants_tagger_light/augmentation/prompt.template', - concurrent_calls: int = 5, - sleep: int = 5 + concurrent_calls: int = 5 ): if model_key.strip().lower().startswith('gpt-3.5-turbo') or \ model_key.strip().lower().startswith('text-davinci') or \ @@ -110,7 +108,6 @@ def augment( _generate(collect_concurrent_calls, dset, save_to_path, augmentation_engine, train_years, num_proc, model_key) collect_concurrent_calls = [] - time.sleep(sleep) else: if tags_to_augment_counts[t] < min_examples: missing = min_examples - tags_to_augment_counts[t] @@ -157,10 +154,6 @@ def augment_cli( 5, help="Concurrent calls with 1 tag each to the different model" ), - sleep: int = typer.Option( - 10, - help="Time to wait before each concurrent call" - ), ): if not data_path.endswith("jsonl"): logger.error( @@ -178,6 +171,5 @@ def augment_cli( test_years=parse_years(test_years), min_examples=min_examples, prompt_template=prompt_template, - concurrent_calls=concurrent_calls, - sleep=sleep + concurrent_calls=concurrent_calls ) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index cd5ec694..6df60646 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -42,26 +42,30 @@ def _process_response(result): if 'content' in r['message']: try: json_response = JsonParser.parse_json(r['message']['content']) - a = json_response['abstract'] - tl = json_response['title'] - - f.write(json.dumps({ - "journal": result.metadata['model_key'], - "meshMajor": result.metadata['tags'], - "year": [ - result.metadata['year'] - ], - "abstractText": a, - "pmid": uuid.uuid4().hex, - "title": tl, - "existing_example": result.metadata['existing_example'], - "required_examples": result.metadata['required_examples'], - "featured_tag": result.metadata['featured_tag'] - })) - f.write('\n') - f.flush() except Exception as e: - logger.info(f"Error processing output: {e}") + logger.info(f"Error processing output: {e}. Skipping...") + continue + + a = json_response['abstract'] + tl = json_response['title'] + + f.write(json.dumps({ + "journal": result.metadata['model_key'], + "meshMajor": result.metadata['tags'], + "year": [ + result.metadata['year'] + ], + "abstractText": a, + "pmid": uuid.uuid4().hex, + "title": tl, + "existing_example": result.metadata['existing_example'], + "required_examples": result.metadata['required_examples'], + "featured_tag": result.metadata['featured_tag'] + })) + f.write('\n') + f.flush() + + logger.info(f"Data received successfully for {result.metadata['featured_tag']}") def _make_requests(self, collect_concurrent_calls, From 62f76e27e4e7cd9d9f95baaf88c6eb834abecca3 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 14:55:01 +0100 Subject: [PATCH 236/300] Prevents locks --- grants_tagger_light/augmentation/augment.py | 23 +++++++++------------ 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index e08e07be..743044ee 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -35,12 +35,6 @@ def _merge_dicts(dict_list): return merged_dict -def _generate(collect_concurrent_calls, dset, save_to_path, - augmentation_engine, train_years, num_proc, model_key): - augmentation_engine.generate(collect_concurrent_calls, dset, save_to_path, train_years, model_key, - num_proc=num_proc) - - def augment( data_path: str, save_to_path: str, @@ -53,11 +47,7 @@ def augment( prompt_template: str = 'grants_tagger_light/augmentation/prompt.template', concurrent_calls: int = 5 ): - if model_key.strip().lower().startswith('gpt-3.5-turbo') or \ - model_key.strip().lower().startswith('text-davinci') or \ - model_key.strip().lower().startswith('gpt-4'): - augmentation_engine = AugmentOpenAI(prompt_template_path=prompt_template, model_key=model_key) - else: + if model_key.strip().lower() not in ['gpt-3.5-turbo', 'text-davinci', 'gpt-4']: raise NotImplementedError(f"{model_key} not implemented as an augmentation framework") # We only have 1 file, so no sharding is available https://huggingface.co/docs/datasets/loading#multiprocessing @@ -105,8 +95,15 @@ def augment( collect_concurrent_calls = [] for t in tags_to_augment: if len(collect_concurrent_calls) >= concurrent_calls: - _generate(collect_concurrent_calls, dset, save_to_path, augmentation_engine, - train_years, num_proc, model_key) + AugmentOpenAI(prompt_template_path=prompt_template, model_key=model_key).generate( + collect_concurrent_calls, + dset, + save_to_path, + train_years, + model_key, + temperature=1.5, + num_proc=num_proc, + ) collect_concurrent_calls = [] else: if tags_to_augment_counts[t] < min_examples: From e674dd845a7533bcc348d5f1e32678eedfc2d97d Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 14:56:49 +0100 Subject: [PATCH 237/300] Write > Append --- grants_tagger_light/augmentation/augment_openai.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 6df60646..1b32249a 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -29,14 +29,11 @@ def _create_message(self, abstract, tag): @staticmethod def _process_response(result): - with open('kk.kk', 'w') as f: - f.write(str(result)) if result.failed: logger.warning(f"Failed to get augmentation for {result.metadata['featured_tag']}") return - with open(result.metadata['save_to_path'], 'w') as f: - + with open(result.metadata['save_to_path'], 'a') as f: for r in result.response['choices']: if 'message' in r: if 'content' in r['message']: From a81760b26b6ac8653941f84ebbf78fd0421ce5b5 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 15:04:56 +0100 Subject: [PATCH 238/300] Write > Append --- grants_tagger_light/augmentation/augment_openai.py | 4 +--- grants_tagger_light/augmentation/prompt.template | 14 +++++++------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 1b32249a..ab1bfe46 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -49,9 +49,7 @@ def _process_response(result): f.write(json.dumps({ "journal": result.metadata['model_key'], "meshMajor": result.metadata['tags'], - "year": [ - result.metadata['year'] - ], + "year": result.metadata['year'], "abstractText": a, "pmid": uuid.uuid4().hex, "title": tl, diff --git a/grants_tagger_light/augmentation/prompt.template b/grants_tagger_light/augmentation/prompt.template index 60ecbd15..45eb75d3 100644 --- a/grants_tagger_light/augmentation/prompt.template +++ b/grants_tagger_light/augmentation/prompt.template @@ -1,12 +1,12 @@ -You are in charge of doing Data Augmentation. I will provide an ABSTRACT and a TOPIC and you will create a json with two fields: -1. 'abstract': a variation of that ABSTRACT talking about TOPIC. Some creativity is allowed. Remove all the quotes except the starting and ending of the field. -2. 'title': a sentence summarizing the abstract. Remove all the quotes except the starting and ending of the field. +You will act as a Data Augmentation engine. I will provide an ABSTRACT and a TOPIC and you will create a json with two fields: +1. 'abstract': a new abstract, created using Data Augmentation, using the ABSTRACT I've sent to you as an inspiration. Make sure the abstract you generate is quite different to the ABSTRACT I've sent to you but keeping the TOPIC. +2. 'title': a sentence summarizing the abstract you have created. Make sure the json is well formed. ======================= -TOPIC -{TOPIC} - ABSTRACT: -{ABSTRACT} \ No newline at end of file +{ABSTRACT} + +TOPIC: +{TOPIC} \ No newline at end of file From 780475ad294cb53f82e3ee093bfde8cc92e273b8 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 21:05:16 +0100 Subject: [PATCH 239/300] Adds different schedulers --- .../augmentation/JsonParser.py | 4 ++ grants_tagger_light/training/train.py | 56 +++++++++++++++---- 2 files changed, 49 insertions(+), 11 deletions(-) diff --git a/grants_tagger_light/augmentation/JsonParser.py b/grants_tagger_light/augmentation/JsonParser.py index 71c2a1e1..4e41c0e8 100644 --- a/grants_tagger_light/augmentation/JsonParser.py +++ b/grants_tagger_light/augmentation/JsonParser.py @@ -7,6 +7,10 @@ class JsonParser: + def __init(self): + """Class to parse json produced by LLMs. Inspiration taken from langchain. It fixes quotes, + it escapes separators, etc.""" + pass @staticmethod def _replace_new_line(match: re.Match[str]) -> str: diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 5b120c3c..710c64ec 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -5,7 +5,10 @@ HfArgumentParser, AutoConfig, AdamW, - get_cosine_schedule_with_warmup + get_cosine_schedule_with_warmup, + get_constant_scheduler_with_warmup, + get_cosine_with_hard_restarts_schedule_with_warmup, + get_linear_schedule_with_warmup ) from grants_tagger_light.models.bert_mesh import BertMesh from grants_tagger_light.preprocessing.preprocess_mesh import preprocess_mesh @@ -45,7 +48,8 @@ def train_bertmesh( from_checkpoint: str = None, tags: list = None, train_years: list = None, - test_years: list = None + test_years: list = None, + scheduler_type: str = 'cosine_hard_restart' ): if not model_key: assert isinstance(model_args, BertMeshModelArguments), ( @@ -155,16 +159,41 @@ def sklearn_metrics(prediction: EvalPrediction): max_steps = Sharding.calculate_max_steps(training_args, train_dset_size) training_args.max_steps = max_steps - optimizer = AdamW(model.parameters(), - lr=training_args.learning_rate) # Set your desired learning rate + optimizer = AdamW(model.parameters(), lr=training_args.learning_rate) + + if training_args.warmup_steps is None: + training_args.warmup_steps = 0 + + if scheduler_type is None or scheduler_type.lower().strip() == '': + scheduler_type = 'linear' + + if scheduler_type.lower().strip() == 'cosine': + scheduler = get_cosine_schedule_with_warmup(optimizer, + num_warmup_steps=training_args.warmup_steps, + num_training_steps=training_args.max_steps) + elif scheduler_type.lower().strip() == 'constant': + scheduler = get_constant_scheduler_with_warmup(optimizer, + num_warmup_steps=training_args.warmup_steps) + elif scheduler_type.lower().strip() == 'cosine_hard_restart': + scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, + num_warmup_steps=training_args.warmup_steps, + num_training_steps=training_args.max_steps, + num_cycles=training_args.epochs) + elif scheduler_type.lower().strip() == 'linear': + scheduler = get_linear_schedule_with_warmup(optimizer, + num_warmup_steps=training_args.warmup_steps, + num_training_steps=training_args.max_steps) + else: + logger.warning(f"{scheduler_type} not recognized. Falling back to `linear`") + scheduler = get_linear_schedule_with_warmup(optimizer, + num_warmup_steps=training_args.warmup_steps, + num_training_steps=training_args.max_steps) - # Create a learning rate scheduler - #scheduler = get_cosine_schedule_with_warmup(optimizer, - # num_warmup_steps=training_args.warmup_steps, - # num_training_steps=training_args.max_steps) + logger.info(f"Optimizer: {optimizer}") + logger.info(f"Scheduler: {scheduler}") training_args.optim = optimizer - # training_args.lr_scheduler_type = scheduler + training_args.lr_scheduler_type = scheduler trainer = Trainer( model=model, @@ -173,7 +202,7 @@ def sklearn_metrics(prediction: EvalPrediction): eval_dataset=val_dset, data_collator=collator, compute_metrics=sklearn_metrics, - optimizers=(optimizer, None), + optimizers=(optimizer, scheduler), ) logger.info(training_args) @@ -240,6 +269,10 @@ def train_bertmesh_cli( test_years: str = typer.Option( None, help="Comma-separated years you want to include in the test dataset" + ), + scheduler_type: str = typer.Option( + 'cosine_hard_restart', + help="One of the following lr schedulers: `cosine`, `linear`, `constant`, `cosine_hard_restart`" ) ): parser = HfArgumentParser( @@ -271,5 +304,6 @@ def train_bertmesh_cli( from_checkpoint=from_checkpoint, tags=parse_tags(tags), train_years=parse_years(train_years), - test_years=parse_years(test_years) + test_years=parse_years(test_years), + scheduler_type=scheduler_type ) From 7a2ebe1e16344f428042655d77bfb8830d871f47 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 21:08:03 +0100 Subject: [PATCH 240/300] Parametrizes temperature --- grants_tagger_light/augmentation/augment.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 743044ee..3efd764f 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -45,7 +45,8 @@ def augment( test_years: list = None, min_examples: int = 15, prompt_template: str = 'grants_tagger_light/augmentation/prompt.template', - concurrent_calls: int = 5 + concurrent_calls: int = 5, + temperature: float = 1.5 ): if model_key.strip().lower() not in ['gpt-3.5-turbo', 'text-davinci', 'gpt-4']: raise NotImplementedError(f"{model_key} not implemented as an augmentation framework") @@ -101,7 +102,7 @@ def augment( save_to_path, train_years, model_key, - temperature=1.5, + temperature=temperature, num_proc=num_proc, ) collect_concurrent_calls = [] @@ -148,9 +149,13 @@ def augment_cli( help="File to use as a prompt. Make sure to ask the LLM to return a dict with two fields: `abstract` and `tags`" ), concurrent_calls: int = typer.Option( - 5, + 25, help="Concurrent calls with 1 tag each to the different model" ), + temperature: float = typer.Option( + 1.5, + help="A value between -2 and 2. The bigger - the more creative." + ), ): if not data_path.endswith("jsonl"): logger.error( @@ -159,6 +164,12 @@ def augment_cli( ) exit(-1) + if float(temperature) > 2.0 or float(temperature) < -2.0: + logger.error( + "Temperature should be in the range [-2, 2]" + ) + exit(-1) + augment(data_path, save_to_path, model_key=model_key, @@ -168,5 +179,6 @@ def augment_cli( test_years=parse_years(test_years), min_examples=min_examples, prompt_template=prompt_template, - concurrent_calls=concurrent_calls + concurrent_calls=concurrent_calls, + temperature=temperature ) From 94809f333faa6cb62c2e8577fd09571e7a640749 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 21:09:58 +0100 Subject: [PATCH 241/300] Fixes schedule name bug --- grants_tagger_light/training/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 710c64ec..edfc1486 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -6,7 +6,7 @@ AutoConfig, AdamW, get_cosine_schedule_with_warmup, - get_constant_scheduler_with_warmup, + get_constant_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup, get_linear_schedule_with_warmup ) @@ -172,7 +172,7 @@ def sklearn_metrics(prediction: EvalPrediction): num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps) elif scheduler_type.lower().strip() == 'constant': - scheduler = get_constant_scheduler_with_warmup(optimizer, + scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=training_args.warmup_steps) elif scheduler_type.lower().strip() == 'cosine_hard_restart': scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, From c664a0af78981e821fd82311f9c328ae692397b0 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 21:12:42 +0100 Subject: [PATCH 242/300] Fixes schedule name bug --- grants_tagger_light/training/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index edfc1486..73895b87 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -173,12 +173,12 @@ def sklearn_metrics(prediction: EvalPrediction): num_training_steps=training_args.max_steps) elif scheduler_type.lower().strip() == 'constant': scheduler = get_constant_schedule_with_warmup(optimizer, - num_warmup_steps=training_args.warmup_steps) + num_warmup_steps=training_args.warmup_steps) elif scheduler_type.lower().strip() == 'cosine_hard_restart': scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps, - num_cycles=training_args.epochs) + num_cycles=training_args.num_train_epochs) elif scheduler_type.lower().strip() == 'linear': scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=training_args.warmup_steps, From 0cd4d1e33f7896a4070ce48ac6947762b88af664 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 21:15:14 +0100 Subject: [PATCH 243/300] Fixes schedule name bug --- grants_tagger_light/training/train.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 73895b87..f091feb6 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -164,9 +164,6 @@ def sklearn_metrics(prediction: EvalPrediction): if training_args.warmup_steps is None: training_args.warmup_steps = 0 - if scheduler_type is None or scheduler_type.lower().strip() == '': - scheduler_type = 'linear' - if scheduler_type.lower().strip() == 'cosine': scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=training_args.warmup_steps, @@ -184,13 +181,14 @@ def sklearn_metrics(prediction: EvalPrediction): num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps) else: - logger.warning(f"{scheduler_type} not recognized. Falling back to `linear`") + logger.warning(f"{scheduler_type}: not found or not valid. Falling back to `linear`") + scheduler_type = 'linear' scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps) logger.info(f"Optimizer: {optimizer}") - logger.info(f"Scheduler: {scheduler}") + logger.info(f"Scheduler: {scheduler_type}") training_args.optim = optimizer training_args.lr_scheduler_type = scheduler From 4ac7e098cf2ed07bdf7188ff0643f1818af09c55 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 21:17:58 +0100 Subject: [PATCH 244/300] Refactors and adds augment script --- examples/augment.sh | 1 + {scripts => examples}/preprocess_splitting_by_fract.sh | 0 {scripts => examples}/preprocess_splitting_by_rows.sh | 0 {scripts => examples}/preprocess_splitting_by_years.sh | 0 {scripts => examples}/resume_train_by_epoch.sh | 0 {scripts => examples}/resume_train_by_steps.sh | 0 {scripts => examples}/train_by_epochs.sh | 0 {scripts => examples}/train_by_steps.sh | 0 8 files changed, 1 insertion(+) create mode 100644 examples/augment.sh rename {scripts => examples}/preprocess_splitting_by_fract.sh (100%) rename {scripts => examples}/preprocess_splitting_by_rows.sh (100%) rename {scripts => examples}/preprocess_splitting_by_years.sh (100%) rename {scripts => examples}/resume_train_by_epoch.sh (100%) rename {scripts => examples}/resume_train_by_steps.sh (100%) rename {scripts => examples}/train_by_epochs.sh (100%) rename {scripts => examples}/train_by_steps.sh (100%) diff --git a/examples/augment.sh b/examples/augment.sh new file mode 100644 index 00000000..c8bc4482 --- /dev/null +++ b/examples/augment.sh @@ -0,0 +1 @@ +grants-tagger augment mesh data/raw/allMeSH_2021.jsonl more_data.jsonl --train-years 2017 \ No newline at end of file diff --git a/scripts/preprocess_splitting_by_fract.sh b/examples/preprocess_splitting_by_fract.sh similarity index 100% rename from scripts/preprocess_splitting_by_fract.sh rename to examples/preprocess_splitting_by_fract.sh diff --git a/scripts/preprocess_splitting_by_rows.sh b/examples/preprocess_splitting_by_rows.sh similarity index 100% rename from scripts/preprocess_splitting_by_rows.sh rename to examples/preprocess_splitting_by_rows.sh diff --git a/scripts/preprocess_splitting_by_years.sh b/examples/preprocess_splitting_by_years.sh similarity index 100% rename from scripts/preprocess_splitting_by_years.sh rename to examples/preprocess_splitting_by_years.sh diff --git a/scripts/resume_train_by_epoch.sh b/examples/resume_train_by_epoch.sh similarity index 100% rename from scripts/resume_train_by_epoch.sh rename to examples/resume_train_by_epoch.sh diff --git a/scripts/resume_train_by_steps.sh b/examples/resume_train_by_steps.sh similarity index 100% rename from scripts/resume_train_by_steps.sh rename to examples/resume_train_by_steps.sh diff --git a/scripts/train_by_epochs.sh b/examples/train_by_epochs.sh similarity index 100% rename from scripts/train_by_epochs.sh rename to examples/train_by_epochs.sh diff --git a/scripts/train_by_steps.sh b/examples/train_by_steps.sh similarity index 100% rename from scripts/train_by_steps.sh rename to examples/train_by_steps.sh From 7661575ae322184c29797cefa45cd3d6ce0a0636 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 21:22:09 +0100 Subject: [PATCH 245/300] Adds 25 concurrent calls by default --- examples/augment.sh | 4 +++- grants_tagger_light/augmentation/augment.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/augment.sh b/examples/augment.sh index c8bc4482..e1b16a70 100644 --- a/examples/augment.sh +++ b/examples/augment.sh @@ -1 +1,3 @@ -grants-tagger augment mesh data/raw/allMeSH_2021.jsonl more_data.jsonl --train-years 2017 \ No newline at end of file +grants-tagger augment mesh data/raw/allMeSH_2021.jsonl more_data.jsonl \ + --train-years 2017 \ + --concurrent-calls 25 \ No newline at end of file diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 3efd764f..dca91f80 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -45,7 +45,7 @@ def augment( test_years: list = None, min_examples: int = 15, prompt_template: str = 'grants_tagger_light/augmentation/prompt.template', - concurrent_calls: int = 5, + concurrent_calls: int = 25, temperature: float = 1.5 ): if model_key.strip().lower() not in ['gpt-3.5-turbo', 'text-davinci', 'gpt-4']: From 4a1571d65f52d8ddb45b257ec0809d1a27fd79d9 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 21:25:46 +0100 Subject: [PATCH 246/300] Adds 25 concurrent calls by default --- grants_tagger_light/augmentation/augment.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index dca91f80..47def3a1 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -96,6 +96,7 @@ def augment( collect_concurrent_calls = [] for t in tags_to_augment: if len(collect_concurrent_calls) >= concurrent_calls: + logger.info(f"Sending {len(collect_concurrent_calls)} to LLM") AugmentOpenAI(prompt_template_path=prompt_template, model_key=model_key).generate( collect_concurrent_calls, dset, @@ -107,6 +108,8 @@ def augment( ) collect_concurrent_calls = [] else: + logger.info(f"Accumulating {t}") + if tags_to_augment_counts[t] < min_examples: missing = min_examples - tags_to_augment_counts[t] collect_concurrent_calls.append((t, missing)) From 6405184b7be60b6134248475584f073bb5bd1916 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 21:30:16 +0100 Subject: [PATCH 247/300] Adds 25 concurrent calls by default --- grants_tagger_light/augmentation/augment.py | 3 --- grants_tagger_light/augmentation/augment_openai.py | 2 ++ 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 47def3a1..dca91f80 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -96,7 +96,6 @@ def augment( collect_concurrent_calls = [] for t in tags_to_augment: if len(collect_concurrent_calls) >= concurrent_calls: - logger.info(f"Sending {len(collect_concurrent_calls)} to LLM") AugmentOpenAI(prompt_template_path=prompt_template, model_key=model_key).generate( collect_concurrent_calls, dset, @@ -108,8 +107,6 @@ def augment( ) collect_concurrent_calls = [] else: - logger.info(f"Accumulating {t}") - if tags_to_augment_counts[t] < min_examples: missing = min_examples - tags_to_augment_counts[t] collect_concurrent_calls.append((t, missing)) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index ab1bfe46..13e2a8d0 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -78,6 +78,7 @@ def _make_requests(self, for num in range(len(collect_concurrent_calls)): t = collect_concurrent_calls[num][0] + logger.info(f"Sending request for {t}") n = collect_concurrent_calls[num][1] logger.info(f"Augmenting {t} with {n} examples") # RAG: I select similar articles to provide them to the LLM @@ -113,6 +114,7 @@ def _make_requests(self, } self.api.request(data=data, metadata=metadata, callback=self._process_response) + logger.info(f"Waiting response for {t}") def generate(self, collect_concurrent_calls, dset, save_to_path, train_years, model_key, temperature=1.5, top_p=1, presence_penalty=0, num_proc=os.cpu_count()): From be3b7a145f9e1f7ec332714879eb0945567f3c32 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 19 Aug 2023 21:36:22 +0100 Subject: [PATCH 248/300] Adds 25 concurrent calls by default --- grants_tagger_light/augmentation/augment_openai.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 13e2a8d0..f3e9cc70 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -78,7 +78,6 @@ def _make_requests(self, for num in range(len(collect_concurrent_calls)): t = collect_concurrent_calls[num][0] - logger.info(f"Sending request for {t}") n = collect_concurrent_calls[num][1] logger.info(f"Augmenting {t} with {n} examples") # RAG: I select similar articles to provide them to the LLM @@ -96,7 +95,7 @@ def _make_requests(self, tags = tmp_dset['meshMajor'][selected_row] data = { "model": self.model_key, - "n": n, + "n": 1, "temperature": temperature, "top_p": top_p, "presence_penalty": presence_penalty, @@ -114,7 +113,6 @@ def _make_requests(self, } self.api.request(data=data, metadata=metadata, callback=self._process_response) - logger.info(f"Waiting response for {t}") def generate(self, collect_concurrent_calls, dset, save_to_path, train_years, model_key, temperature=1.5, top_p=1, presence_penalty=0, num_proc=os.cpu_count()): From 006c864ef616a15c8c73910e37d140bd37d71ba2 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 21 Aug 2023 09:28:08 +0100 Subject: [PATCH 249/300] Adds more examples --- examples/augment.sh | 1 - examples/augment_specific_years.sh | 3 +++ examples/preprocess_and_train_by_epochs.sh | 31 ++++++++++++++++++++++ examples/preprocess_and_train_by_steps.sh | 30 +++++++++++++++++++++ examples/train_by_epochs.sh | 11 +++----- examples/train_by_steps.sh | 1 + 6 files changed, 68 insertions(+), 9 deletions(-) create mode 100644 examples/augment_specific_years.sh create mode 100644 examples/preprocess_and_train_by_epochs.sh create mode 100644 examples/preprocess_and_train_by_steps.sh diff --git a/examples/augment.sh b/examples/augment.sh index e1b16a70..caf3ecb3 100644 --- a/examples/augment.sh +++ b/examples/augment.sh @@ -1,3 +1,2 @@ grants-tagger augment mesh data/raw/allMeSH_2021.jsonl more_data.jsonl \ - --train-years 2017 \ --concurrent-calls 25 \ No newline at end of file diff --git a/examples/augment_specific_years.sh b/examples/augment_specific_years.sh new file mode 100644 index 00000000..e1b16a70 --- /dev/null +++ b/examples/augment_specific_years.sh @@ -0,0 +1,3 @@ +grants-tagger augment mesh data/raw/allMeSH_2021.jsonl more_data.jsonl \ + --train-years 2017 \ + --concurrent-calls 25 \ No newline at end of file diff --git a/examples/preprocess_and_train_by_epochs.sh b/examples/preprocess_and_train_by_epochs.sh new file mode 100644 index 00000000..e706d05f --- /dev/null +++ b/examples/preprocess_and_train_by_epochs.sh @@ -0,0 +1,31 @@ +# Run on g5.12xlargeinstance + +# Without preprocessing (on-the-fly) +SOURCE="data/raw/allMeSH_2021.jsonl" + +grants-tagger train bertmesh \ + "" \ + $SOURCE \ + --test-size 10000 \ + --output_dir bertmesh_outs/pipeline_test/ \ + --train-years 2016,2017,2018,2019 \ + --test-years 2020,2021 \ + --per_device_train_batch_size 32 \ + --per_device_eval_batch_size 1 \ + --multilabel_attention True \ + --freeze_backbone False \ + --num_train_epochs 5 \ + --learning_rate 5e-5 \ + --dropout 0.1 \ + --hidden_size 1024 \ + --warmup_steps 1000 \ + --max_grad_norm 5.0 \ + --scheduler-type cosine \ + --fp16 \ + --torch_compile \ + --evaluation_strategy epoch \ + --eval_accumulation_steps 20 \ + --save_strategy epoch \ + --wandb_project wellcome-mesh \ + --wandb_name test-train-all \ + --wandb_api_key ${WANDB_API_KEY} diff --git a/examples/preprocess_and_train_by_steps.sh b/examples/preprocess_and_train_by_steps.sh new file mode 100644 index 00000000..dd65a405 --- /dev/null +++ b/examples/preprocess_and_train_by_steps.sh @@ -0,0 +1,30 @@ +# Run on g5.12xlargeinstance + +# In that case, `test-size`, `train-years` and `test-years` will be taken from the preprocessed folder +SOURCE="output_folder_from_preprocessing" + +grants-tagger train bertmesh \ + "" \ + $SOURCE \ + --output_dir bertmesh_outs/pipeline_test/ \ + --per_device_train_batch_size 32 \ + --per_device_eval_batch_size 1 \ + --multilabel_attention True \ + --freeze_backbone False \ + --num_train_epochs 5 \ + --learning_rate 5e-5 \ + --dropout 0.1 \ + --hidden_size 1024 \ + --warmup_steps 1000 \ + --max_grad_norm 5.0 \ + --scheduler-type cosine \ + --fp16 \ + --torch_compile \ + --evaluation_strategy steps \ + --eval_steps 50000 \ + --eval_accumulation_steps 20 \ + --save_strategy steps \ + --save_steps 50000 \ + --wandb_project wellcome-mesh \ + --wandb_name test-train-all \ + --wandb_api_key ${WANDB_API_KEY} diff --git a/examples/train_by_epochs.sh b/examples/train_by_epochs.sh index 8d3102ad..c4131ae8 100644 --- a/examples/train_by_epochs.sh +++ b/examples/train_by_epochs.sh @@ -1,19 +1,13 @@ # Run on g5.12xlargeinstance -# Without preprocessing (on-the-fly) -SOURCE="data/raw/allMeSH_2021.jsonl" +# In this case, `test-size`, `train-years` and `test-years` will be taken from the preprocessed folder +SOURCE="output_folder_from_preprocessing" -# If you have already preprocessed the data, you will have a folder. Use the folder instead. -# SOURCE="output_folder_from_preprocessing" -# In that case, `test-size`, `train-years` and `test-years` will be taken from the preprocessed folder grants-tagger train bertmesh \ "" \ $SOURCE \ - --test-size 10000 \ --output_dir bertmesh_outs/pipeline_test/ \ - --train-years 2016,2017,2018,2019 \ - --test-years 2020,2021 \ --per_device_train_batch_size 32 \ --per_device_eval_batch_size 1 \ --multilabel_attention True \ @@ -24,6 +18,7 @@ grants-tagger train bertmesh \ --hidden_size 1024 \ --warmup_steps 1000 \ --max_grad_norm 5.0 \ + --scheduler-type cosine \ --fp16 \ --torch_compile \ --evaluation_strategy epoch \ diff --git a/examples/train_by_steps.sh b/examples/train_by_steps.sh index 3bdab802..5a7dd409 100644 --- a/examples/train_by_steps.sh +++ b/examples/train_by_steps.sh @@ -24,6 +24,7 @@ grants-tagger train bertmesh \ --hidden_size 1024 \ --warmup_steps 1000 \ --max_grad_norm 5.0 \ + --scheduler-type cosine \ --fp16 \ --torch_compile \ --evaluation_strategy steps \ From c0fdc5a581ef248e0bc0706878152a64e1c1ea6e Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Tue, 22 Aug 2023 10:27:20 +0100 Subject: [PATCH 250/300] Adds scheduler type --- examples/resume_train_by_epoch.sh | 3 ++- examples/resume_train_by_steps.sh | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/resume_train_by_epoch.sh b/examples/resume_train_by_epoch.sh index a5eb9cc6..de00f044 100644 --- a/examples/resume_train_by_epoch.sh +++ b/examples/resume_train_by_epoch.sh @@ -13,7 +13,7 @@ grants-tagger train bertmesh \ bertmesh_outs/pipeline_test/$CHECKPOINT \ $SOURCE \ --output_dir bertmesh_outs/pipeline_test_from_$CHECKPOINT/ \ - --ignore_data_skip=True \ + --ignore_data_skip True \ --per_device_train_batch_size 32 \ --per_device_eval_batch_size 1 \ --multilabel_attention True \ @@ -24,6 +24,7 @@ grants-tagger train bertmesh \ --hidden_size 1024 \ --warmup_steps 1000 \ --max_grad_norm 5.0 \ + --scheduler-type cosine \ --fp16 \ --torch_compile \ --evaluation_strategy epoch \ diff --git a/examples/resume_train_by_steps.sh b/examples/resume_train_by_steps.sh index 456fa7d1..bf26f17a 100644 --- a/examples/resume_train_by_steps.sh +++ b/examples/resume_train_by_steps.sh @@ -13,7 +13,7 @@ grants-tagger train bertmesh \ bertmesh_outs/pipeline_test/$CHECKPOINT \ $SOURCE \ --output_dir bertmesh_outs/pipeline_test_from_$CHECKPOINT/ \ - --ignore_data_skip=True \ + --ignore_data_skip True \ --per_device_train_batch_size 32 \ --per_device_eval_batch_size 1 \ --multilabel_attention True \ @@ -24,6 +24,7 @@ grants-tagger train bertmesh \ --hidden_size 1024 \ --warmup_steps 1000 \ --max_grad_norm 5.0 \ + --scheduler-type cosine \ --fp16 \ --torch_compile \ --evaluation_strategy steps \ From 21b0d6c98121b40be398c6f1f35bc997686c5be7 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Wed, 23 Aug 2023 19:20:28 +0100 Subject: [PATCH 251/300] Freezes everything except weights --- grants_tagger_light/models/bert_mesh/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/grants_tagger_light/models/bert_mesh/model.py b/grants_tagger_light/models/bert_mesh/model.py index b05cb458..32aaec69 100644 --- a/grants_tagger_light/models/bert_mesh/model.py +++ b/grants_tagger_light/models/bert_mesh/model.py @@ -64,11 +64,11 @@ def freeze_backbone(self): def unfreeze_backbone(self): for name, param in self.bert.named_parameters(): - if 'bias' in name.lower(): + if 'weight' in name.lower(): + param.requires_grad = False + else: param.requires_grad = True logger.info(f"Unfreezing {name}") - else: - param.requires_grad = False def forward(self, input_ids, labels=None, **kwargs): if type(input_ids) is list: From 8b92dae8bcb2c33d4a791c004b4585361d66f2cf Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 24 Aug 2023 10:37:28 +0100 Subject: [PATCH 252/300] Changes threshold, evaluation on tags, freezing backbone --- grants_tagger_light/models/bert_mesh/model.py | 10 +++-- .../training/cli_args/bertmesh_args.py | 2 +- grants_tagger_light/training/train.py | 39 +++++++++++++------ 3 files changed, 36 insertions(+), 15 deletions(-) diff --git a/grants_tagger_light/models/bert_mesh/model.py b/grants_tagger_light/models/bert_mesh/model.py index 32aaec69..71f8b5b4 100644 --- a/grants_tagger_light/models/bert_mesh/model.py +++ b/grants_tagger_light/models/bert_mesh/model.py @@ -62,10 +62,14 @@ def freeze_backbone(self): for param in self.bert.parameters(): param.requires_grad = False - def unfreeze_backbone(self): + def unfreeze_backbone(self, only_bias=False): for name, param in self.bert.named_parameters(): - if 'weight' in name.lower(): - param.requires_grad = False + if only_bias: + if 'bias' in name.lower(): + logger.info(f"Unfreezing {name}") + param.requires_grad = True + else: + param.requires_grad = False else: param.requires_grad = True logger.info(f"Unfreezing {name}") diff --git a/grants_tagger_light/training/cli_args/bertmesh_args.py b/grants_tagger_light/training/cli_args/bertmesh_args.py index d1d50331..d94ecdef 100644 --- a/grants_tagger_light/training/cli_args/bertmesh_args.py +++ b/grants_tagger_light/training/cli_args/bertmesh_args.py @@ -9,4 +9,4 @@ class BertMeshModelArguments: hidden_size: int = field(default=512) dropout: float = field(default=0) multilabel_attention: bool = field(default=False) - freeze_backbone: bool = field(default=False) + freeze_backbone: str = field(default=None) # unfreeze, unfreeze_bias, freeze diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index f091feb6..4d4f57e8 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -49,7 +49,8 @@ def train_bertmesh( tags: list = None, train_years: list = None, test_years: list = None, - scheduler_type: str = 'cosine_hard_restart' + scheduler_type: str = 'cosine_hard_restart', + threshold: int = 0.25 ): if not model_key: assert isinstance(model_args, BertMeshModelArguments), ( @@ -81,6 +82,13 @@ def train_bertmesh( ) train_dset, val_dset = dset["train"], dset["test"] + + metric_labels = [] + for x in train_dset['meshMajor']: + metric_labels.extend(x) + + logger.info(f"For metric purposes, only considering labels present in `training`: {metric_labels[:15]}") + train_dset_size = len(train_dset) logger.info(f"Training dataset size: {train_dset_size}") if max_samples > 0: @@ -123,22 +131,26 @@ def train_bertmesh( logger.info(f"Training from pretrained key {model_key}") model = BertMesh.from_pretrained(model_key, trust_remote_code=True) - if model_args.freeze_backbone: - logger.info("Freezing backbone") - model.freeze_backbone() - else: - logger.info("Unfreezing backbone") - model.unfreeze_backbone() + if model_args.freeze_backbone is not None: + if model_args.freeze_backbone.lower().strip() == 'unfreeze': + logger.info("Unfreezing weights&biases in the backbone") + model.unfreeze_backbone() + elif model_args.freeze_backbone.lower().strip() == 'unfreeze_bias': + logger.info("Unfreezing only biases in the backbone") + model.unfreeze_backbone(only_bias=True) + elif model_args.freeze_backbone.lower().strip() == 'freeze': + logger.info("Freezing backbone") + model.freeze_backbone() def sklearn_metrics(prediction: EvalPrediction): y_pred = prediction.predictions - y_true = prediction.label_ids + y_true = [x for x in prediction.label_ids if x in metric_labels] # TODO make thresh configurable or return metrics # for multiple thresholds # e.g. 0.5:0.95:0.05 - y_pred = np.int64(y_pred > 0.5) + y_pred = np.int64(y_pred > threshold) report = classification_report(y_true, y_pred, output_dict=True) @@ -271,7 +283,11 @@ def train_bertmesh_cli( scheduler_type: str = typer.Option( 'cosine_hard_restart', help="One of the following lr schedulers: `cosine`, `linear`, `constant`, `cosine_hard_restart`" - ) + ), + threshold: int = typer.Option( + 0.25, + help="Threshold to considered a class as a positive" + ), ): parser = HfArgumentParser( ( @@ -303,5 +319,6 @@ def train_bertmesh_cli( tags=parse_tags(tags), train_years=parse_years(train_years), test_years=parse_years(test_years), - scheduler_type=scheduler_type + scheduler_type=scheduler_type, + threshold=threshold ) From 52f1303338c70ed4fa020329930470096a645389 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 24 Aug 2023 10:39:54 +0100 Subject: [PATCH 253/300] Changes threshold, evaluation on tags, freezing backbone --- examples/resume_train_by_epoch.sh | 5 +++-- examples/resume_train_by_steps.sh | 5 +++-- examples/train_by_epochs.sh | 5 +++-- examples/train_by_steps.sh | 3 ++- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/examples/resume_train_by_epoch.sh b/examples/resume_train_by_epoch.sh index de00f044..4c35d68c 100644 --- a/examples/resume_train_by_epoch.sh +++ b/examples/resume_train_by_epoch.sh @@ -14,10 +14,10 @@ grants-tagger train bertmesh \ $SOURCE \ --output_dir bertmesh_outs/pipeline_test_from_$CHECKPOINT/ \ --ignore_data_skip True \ - --per_device_train_batch_size 32 \ + --per_device_train_batch_size 16 \ --per_device_eval_batch_size 1 \ --multilabel_attention True \ - --freeze_backbone False \ + --freeze_backbone unfreeze_bias \ --num_train_epochs 5 \ --learning_rate 5e-5 \ --dropout 0.1 \ @@ -25,6 +25,7 @@ grants-tagger train bertmesh \ --warmup_steps 1000 \ --max_grad_norm 5.0 \ --scheduler-type cosine \ + --threshold 0.25 \ --fp16 \ --torch_compile \ --evaluation_strategy epoch \ diff --git a/examples/resume_train_by_steps.sh b/examples/resume_train_by_steps.sh index bf26f17a..6f43ecb5 100644 --- a/examples/resume_train_by_steps.sh +++ b/examples/resume_train_by_steps.sh @@ -14,10 +14,10 @@ grants-tagger train bertmesh \ $SOURCE \ --output_dir bertmesh_outs/pipeline_test_from_$CHECKPOINT/ \ --ignore_data_skip True \ - --per_device_train_batch_size 32 \ + --per_device_train_batch_size 16 \ --per_device_eval_batch_size 1 \ --multilabel_attention True \ - --freeze_backbone False \ + --freeze_backbone unfreeze_bias \ --num_train_epochs 5 \ --learning_rate 5e-5 \ --dropout 0.1 \ @@ -25,6 +25,7 @@ grants-tagger train bertmesh \ --warmup_steps 1000 \ --max_grad_norm 5.0 \ --scheduler-type cosine \ + --threshold 0.25 \ --fp16 \ --torch_compile \ --evaluation_strategy steps \ diff --git a/examples/train_by_epochs.sh b/examples/train_by_epochs.sh index c4131ae8..9980de0f 100644 --- a/examples/train_by_epochs.sh +++ b/examples/train_by_epochs.sh @@ -8,10 +8,10 @@ grants-tagger train bertmesh \ "" \ $SOURCE \ --output_dir bertmesh_outs/pipeline_test/ \ - --per_device_train_batch_size 32 \ + --per_device_train_batch_size 16 \ --per_device_eval_batch_size 1 \ --multilabel_attention True \ - --freeze_backbone False \ + --freeze_backbone unfreeze_bias \ --num_train_epochs 5 \ --learning_rate 5e-5 \ --dropout 0.1 \ @@ -19,6 +19,7 @@ grants-tagger train bertmesh \ --warmup_steps 1000 \ --max_grad_norm 5.0 \ --scheduler-type cosine \ + --threshold 0.25 \ --fp16 \ --torch_compile \ --evaluation_strategy epoch \ diff --git a/examples/train_by_steps.sh b/examples/train_by_steps.sh index 5a7dd409..be382476 100644 --- a/examples/train_by_steps.sh +++ b/examples/train_by_steps.sh @@ -14,7 +14,7 @@ grants-tagger train bertmesh \ --output_dir bertmesh_outs/pipeline_test/ \ --train-years 2016,2017,2018,2019 \ --test-years 2020,2021 \ - --per_device_train_batch_size 32 \ + --per_device_train_batch_size 16 \ --per_device_eval_batch_size 1 \ --multilabel_attention True \ --freeze_backbone False \ @@ -25,6 +25,7 @@ grants-tagger train bertmesh \ --warmup_steps 1000 \ --max_grad_norm 5.0 \ --scheduler-type cosine \ + --threshold 0.25 \ --fp16 \ --torch_compile \ --evaluation_strategy steps \ From 6ab48765c3eff609d59475ab762fe605ab775259 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 24 Aug 2023 10:41:53 +0100 Subject: [PATCH 254/300] Changes threshold, evaluation on tags, freezing backbone --- grants_tagger_light/training/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 4d4f57e8..dddbebce 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -50,7 +50,7 @@ def train_bertmesh( train_years: list = None, test_years: list = None, scheduler_type: str = 'cosine_hard_restart', - threshold: int = 0.25 + threshold: float = 0.25 ): if not model_key: assert isinstance(model_args, BertMeshModelArguments), ( @@ -284,9 +284,9 @@ def train_bertmesh_cli( 'cosine_hard_restart', help="One of the following lr schedulers: `cosine`, `linear`, `constant`, `cosine_hard_restart`" ), - threshold: int = typer.Option( + threshold: float = typer.Option( 0.25, - help="Threshold to considered a class as a positive" + help="Threshold (0, 1) to considered a class as a positive" ), ): parser = HfArgumentParser( From 54a754f9a7c1eab45e0751947272ca813fc497e7 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 24 Aug 2023 10:42:56 +0100 Subject: [PATCH 255/300] Changes threshold, evaluation on tags, freezing backbone --- grants_tagger_light/training/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index dddbebce..f1a6a83f 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -84,7 +84,7 @@ def train_bertmesh( train_dset, val_dset = dset["train"], dset["test"] metric_labels = [] - for x in train_dset['meshMajor']: + for x in train_dset['label_ids']: metric_labels.extend(x) logger.info(f"For metric purposes, only considering labels present in `training`: {metric_labels[:15]}") From 1fdeada0c74d24148468f8ce48bbc168050a43fe Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 24 Aug 2023 11:32:20 +0100 Subject: [PATCH 256/300] Changes threshold, evaluation on tags, freezing backbone --- grants_tagger_light/training/train.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index f1a6a83f..86aef6ee 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -143,14 +143,20 @@ def train_bertmesh( model.freeze_backbone() def sklearn_metrics(prediction: EvalPrediction): + # This is a batch, so it's an array (rows) of array (labels) + # Array of arrays with probas [[5.4e-5 1.3e-3...] [5.4e-5 1.3e-3...] ... ] y_pred = prediction.predictions - y_true = [x for x in prediction.label_ids if x in metric_labels] + # Transformed to 0-1 if bigger than threshold [[0 1 0...] [0 0 1...] ... ] + y_pred = np.int64(y_pred > threshold) - # TODO make thresh configurable or return metrics - # for multiple thresholds - # e.g. 0.5:0.95:0.05 + # Array of arrays with 0/1 [[0 0 1 ...] [0 1 0 ...] ... ] + y_true = prediction.label_ids - y_pred = np.int64(y_pred > threshold) + # I will remove those tags which where not in training dataset, and only in test + for label_id in metric_labels: + for i in range(y_pred): + y_pred[i].pop(label_id) # removing prediction for row `i` for label `label_id` + y_true[i].pop(label_id) # removing expected for row `i` for label `label_id` report = classification_report(y_true, y_pred, output_dict=True) From 608a9a7b9b220d21833e799c0f32033e207bd2b1 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 24 Aug 2023 11:43:23 +0100 Subject: [PATCH 257/300] Changes threshold, evaluation on tags, freezing backbone --- grants_tagger_light/training/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 86aef6ee..666a3dd6 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -154,7 +154,7 @@ def sklearn_metrics(prediction: EvalPrediction): # I will remove those tags which where not in training dataset, and only in test for label_id in metric_labels: - for i in range(y_pred): + for i in range(len(y_pred)): y_pred[i].pop(label_id) # removing prediction for row `i` for label `label_id` y_true[i].pop(label_id) # removing expected for row `i` for label `label_id` From 063caa4f7625ff24efff22afa9f0f390c2cc6e33 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 24 Aug 2023 11:56:07 +0100 Subject: [PATCH 258/300] Changes threshold, evaluation on tags, freezing backbone --- grants_tagger_light/training/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 666a3dd6..047fd88d 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -155,8 +155,8 @@ def sklearn_metrics(prediction: EvalPrediction): # I will remove those tags which where not in training dataset, and only in test for label_id in metric_labels: for i in range(len(y_pred)): - y_pred[i].pop(label_id) # removing prediction for row `i` for label `label_id` - y_true[i].pop(label_id) # removing expected for row `i` for label `label_id` + y_pred[i] = np.delete(y_pred[i], label_id, 0) # removing prediction for row `i` for label `label_id` + y_true[i] = np.delete(y_true[i], label_id, 0) # removing expected for row `i` for label `label_id` report = classification_report(y_true, y_pred, output_dict=True) From 9c518d738a09053a0264d3130c0e05dbbd666605 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 24 Aug 2023 12:08:28 +0100 Subject: [PATCH 259/300] Changes threshold, evaluation on tags, freezing backbone --- grants_tagger_light/training/train.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 047fd88d..f74b41dc 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -152,13 +152,15 @@ def sklearn_metrics(prediction: EvalPrediction): # Array of arrays with 0/1 [[0 0 1 ...] [0 1 0 ...] ... ] y_true = prediction.label_ids + filtered_y_pred = [] + filtered_y_true = [] # I will remove those tags which where not in training dataset, and only in test for label_id in metric_labels: for i in range(len(y_pred)): - y_pred[i] = np.delete(y_pred[i], label_id, 0) # removing prediction for row `i` for label `label_id` - y_true[i] = np.delete(y_true[i], label_id, 0) # removing expected for row `i` for label `label_id` + filtered_y_pred.append(np.delete(y_pred[i], label_id, 0)) # removing prediction for row `i` for label `label_id` + filtered_y_true.append(np.delete(y_true[i], label_id, 0)) # removing expected for row `i` for label `label_id` - report = classification_report(y_true, y_pred, output_dict=True) + report = classification_report(filtered_y_pred, filtered_y_true, output_dict=True) metric_dict = { "micro_avg": report["micro avg"], From 25c1a9b094431e2f52ccc4a1355dc810a27e9dd4 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 24 Aug 2023 12:41:33 +0100 Subject: [PATCH 260/300] Changes threshold, evaluation on tags, freezing backbone --- grants_tagger_light/training/train.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index f74b41dc..20c3853f 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -155,10 +155,26 @@ def sklearn_metrics(prediction: EvalPrediction): filtered_y_pred = [] filtered_y_true = [] # I will remove those tags which where not in training dataset, and only in test - for label_id in metric_labels: - for i in range(len(y_pred)): - filtered_y_pred.append(np.delete(y_pred[i], label_id, 0)) # removing prediction for row `i` for label `label_id` - filtered_y_true.append(np.delete(y_true[i], label_id, 0)) # removing expected for row `i` for label `label_id` + + predictions_batch_size = len(y_pred) + true_batch_size = len(y_true) + assert predictions_batch_size == true_batch_size + batch_size = predictions_batch_size + + num_predictions_per_row = len(y_pred[0]) + num_true_per_row = len(y_true[0]) + assert num_predictions_per_row == num_true_per_row + total_labels = num_predictions_per_row + + for row_num in range(batch_size): + yp = [] + yt = [] + for label_num in range(total_labels): + if label_num in metric_labels: + yp.append(y_pred[row_num][label_num]) + yt.append(y_true[row_num][label_num]) + filtered_y_pred.append(yp) + filtered_y_true.append(yt) report = classification_report(filtered_y_pred, filtered_y_true, output_dict=True) From a0bc3edc4d288dc45d1bba288c1fe8551de4aa7c Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 24 Aug 2023 15:52:27 +0100 Subject: [PATCH 261/300] Adds filtering by tags --- grants_tagger_light/augmentation/augment.py | 22 ++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index dca91f80..61e8ba9b 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -45,8 +45,9 @@ def augment( test_years: list = None, min_examples: int = 15, prompt_template: str = 'grants_tagger_light/augmentation/prompt.template', - concurrent_calls: int = 25, - temperature: float = 1.5 + concurrent_calls: int = os.cpu_count()*2, + temperature: float = 1.5, + tags_file_path: str = None, ): if model_key.strip().lower() not in ['gpt-3.5-turbo', 'text-davinci', 'gpt-4']: raise NotImplementedError(f"{model_key} not implemented as an augmentation framework") @@ -71,6 +72,12 @@ def augment( merged_element_counts = _merge_dicts(element_counts_list) sorted_merged_element_counts = sorted(merged_element_counts.items(), key=lambda x: x[1], reverse=True) sorted_merged_element_counts_dict = dict(sorted_merged_element_counts) + if tags_file_path is not None: + with open(tags_file_path, 'r') as f: + tags = f.read().split('\n') + logger.info(f"Tags file path found. Filtering tags (examples found: {tags[:15]}...)") + sorted_merged_element_counts_dict = {k: v for k, v in sorted_merged_element_counts_dict.items() + if k in tags} with open(f"{save_to_path}.count", 'w') as f: f.write(json.dumps(sorted_merged_element_counts_dict, indent=2)) @@ -142,20 +149,24 @@ def augment_cli( ), min_examples: int = typer.Option( 15, - help="If set, Comma-separated years you want to exclude in the data augmentation process" + help="Minimum number of examples to require. Less than that will trigger data augmentation." ), prompt_template: str = typer.Option( 'grants_tagger_light/augmentation/prompt.template', help="File to use as a prompt. Make sure to ask the LLM to return a dict with two fields: `abstract` and `tags`" ), concurrent_calls: int = typer.Option( - 25, + os.cpu_count()*2, help="Concurrent calls with 1 tag each to the different model" ), temperature: float = typer.Option( 1.5, help="A value between -2 and 2. The bigger - the more creative." ), + tags_file_path: str = typer.Option( + None, + help="Text file containing one line per tag to be considered. The rest will be discarded." + ) ): if not data_path.endswith("jsonl"): logger.error( @@ -180,5 +191,6 @@ def augment_cli( min_examples=min_examples, prompt_template=prompt_template, concurrent_calls=concurrent_calls, - temperature=temperature + temperature=temperature, + tags_file_path=tags_file_path ) From e28b683473956ba5d3f8273272b5bfc77c0985df Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 24 Aug 2023 16:40:18 +0100 Subject: [PATCH 262/300] Adds edge case for the remaining X end texts to augment --- grants_tagger_light/augmentation/augment.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 61e8ba9b..fc59b474 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -118,6 +118,17 @@ def augment( missing = min_examples - tags_to_augment_counts[t] collect_concurrent_calls.append((t, missing)) + if len(collect_concurrent_calls) > 0: + AugmentOpenAI(prompt_template_path=prompt_template, model_key=model_key).generate( + collect_concurrent_calls, + dset, + save_to_path, + train_years, + model_key, + temperature=temperature, + num_proc=num_proc, + ) + @augment_app.command() def augment_cli( From 93d5ad5194d51cb7a5c93c6e632070922097437f Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 24 Aug 2023 18:37:40 +0100 Subject: [PATCH 263/300] Removes metrics from tags absent in training --- grants_tagger_light/training/train.py | 27 ++++----------------------- 1 file changed, 4 insertions(+), 23 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 20c3853f..60b54b7f 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -152,29 +152,10 @@ def sklearn_metrics(prediction: EvalPrediction): # Array of arrays with 0/1 [[0 0 1 ...] [0 1 0 ...] ... ] y_true = prediction.label_ids - filtered_y_pred = [] - filtered_y_true = [] - # I will remove those tags which where not in training dataset, and only in test - - predictions_batch_size = len(y_pred) - true_batch_size = len(y_true) - assert predictions_batch_size == true_batch_size - batch_size = predictions_batch_size - - num_predictions_per_row = len(y_pred[0]) - num_true_per_row = len(y_true[0]) - assert num_predictions_per_row == num_true_per_row - total_labels = num_predictions_per_row - - for row_num in range(batch_size): - yp = [] - yt = [] - for label_num in range(total_labels): - if label_num in metric_labels: - yp.append(y_pred[row_num][label_num]) - yt.append(y_true[row_num][label_num]) - filtered_y_pred.append(yp) - filtered_y_true.append(yt) + mask = np.ones(y_pred.shape, dtype=bool) + mask[np.arange(y_pred.shape[0])[:, np.newaxis], metric_labels] = False + filtered_y_pred = y_pred[mask].reshape(y_pred.shape[0], -1) + filtered_y_true = y_true[mask].reshape(y_true.shape[0], -1) report = classification_report(filtered_y_pred, filtered_y_true, output_dict=True) From 1d4cb0fc64067b1b33a06dd0a3845cf536dfa0eb Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 24 Aug 2023 23:34:40 +0100 Subject: [PATCH 264/300] Removes metrics from tags absent in training --- grants_tagger_light/training/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 60b54b7f..3bcaf521 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -152,8 +152,8 @@ def sklearn_metrics(prediction: EvalPrediction): # Array of arrays with 0/1 [[0 0 1 ...] [0 1 0 ...] ... ] y_true = prediction.label_ids - mask = np.ones(y_pred.shape, dtype=bool) - mask[np.arange(y_pred.shape[0])[:, np.newaxis], metric_labels] = False + mask = np.zeros(y_pred.shape, dtype=bool) + mask[np.arange(y_pred.shape[0])[:, np.newaxis], metric_labels] = True filtered_y_pred = y_pred[mask].reshape(y_pred.shape[0], -1) filtered_y_true = y_true[mask].reshape(y_true.shape[0], -1) From 644d651bd0186f23d4c50ba7964b804066267509 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 25 Aug 2023 10:17:36 +0100 Subject: [PATCH 265/300] Rolls back filtering tags in metrics --- grants_tagger_light/training/train.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 3bcaf521..6275ee70 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -152,12 +152,13 @@ def sklearn_metrics(prediction: EvalPrediction): # Array of arrays with 0/1 [[0 0 1 ...] [0 1 0 ...] ... ] y_true = prediction.label_ids - mask = np.zeros(y_pred.shape, dtype=bool) - mask[np.arange(y_pred.shape[0])[:, np.newaxis], metric_labels] = True - filtered_y_pred = y_pred[mask].reshape(y_pred.shape[0], -1) - filtered_y_true = y_true[mask].reshape(y_true.shape[0], -1) + # mask = np.zeros(y_pred.shape, dtype=bool) + # mask[np.arange(y_pred.shape[0])[:, np.newaxis], metric_labels] = True + # filtered_y_pred = y_pred[mask].reshape(y_pred.shape[0], -1) + # filtered_y_true = y_true[mask].reshape(y_true.shape[0], -1) + # report = classification_report(filtered_y_pred, filtered_y_true, output_dict=True) - report = classification_report(filtered_y_pred, filtered_y_true, output_dict=True) + report = classification_report(y_pred, y_true, output_dict=True) metric_dict = { "micro_avg": report["micro avg"], From ad40350191562e79a6218eed207014f037c0a948 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 25 Aug 2023 13:05:46 +0100 Subject: [PATCH 266/300] Adds back tag filtering --- grants_tagger_light/training/train.py | 38 ++++++++++++++------------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 6275ee70..e1381b5e 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -87,8 +87,6 @@ def train_bertmesh( for x in train_dset['label_ids']: metric_labels.extend(x) - logger.info(f"For metric purposes, only considering labels present in `training`: {metric_labels[:15]}") - train_dset_size = len(train_dset) logger.info(f"Training dataset size: {train_dset_size}") if max_samples > 0: @@ -131,18 +129,22 @@ def train_bertmesh( logger.info(f"Training from pretrained key {model_key}") model = BertMesh.from_pretrained(model_key, trust_remote_code=True) - if model_args.freeze_backbone is not None: - if model_args.freeze_backbone.lower().strip() == 'unfreeze': - logger.info("Unfreezing weights&biases in the backbone") - model.unfreeze_backbone() - elif model_args.freeze_backbone.lower().strip() == 'unfreeze_bias': - logger.info("Unfreezing only biases in the backbone") - model.unfreeze_backbone(only_bias=True) - elif model_args.freeze_backbone.lower().strip() == 'freeze': - logger.info("Freezing backbone") - model.freeze_backbone() + if model_args.freeze_backbone is None: + model_args.freeze_backbone = 'freeze' + + if model_args.freeze_backbone.lower().strip() == 'unfreeze': + logger.info("Unfreezing weights&biases in the backbone") + model.unfreeze_backbone() + elif model_args.freeze_backbone.lower().strip() == 'unfreeze_bias': + logger.info("Unfreezing only biases in the backbone") + model.unfreeze_backbone(only_bias=True) + elif model_args.freeze_backbone.lower().strip() == 'freeze': + logger.info("Freezing backbone") + model.freeze_backbone() def sklearn_metrics(prediction: EvalPrediction): + logger.info(f"Threshold: {threshold}") + logger.info(f"For metric purposes, only considering labels present in `training`: {metric_labels[:15]}") # This is a batch, so it's an array (rows) of array (labels) # Array of arrays with probas [[5.4e-5 1.3e-3...] [5.4e-5 1.3e-3...] ... ] y_pred = prediction.predictions @@ -152,13 +154,13 @@ def sklearn_metrics(prediction: EvalPrediction): # Array of arrays with 0/1 [[0 0 1 ...] [0 1 0 ...] ... ] y_true = prediction.label_ids - # mask = np.zeros(y_pred.shape, dtype=bool) - # mask[np.arange(y_pred.shape[0])[:, np.newaxis], metric_labels] = True - # filtered_y_pred = y_pred[mask].reshape(y_pred.shape[0], -1) - # filtered_y_true = y_true[mask].reshape(y_true.shape[0], -1) - # report = classification_report(filtered_y_pred, filtered_y_true, output_dict=True) + # report = classification_report(y_pred, y_true, output_dict=True) - report = classification_report(y_pred, y_true, output_dict=True) + mask = np.zeros(y_pred.shape, dtype=bool) + mask[np.arange(y_pred.shape[0])[:, np.newaxis], metric_labels] = True + filtered_y_pred = y_pred[mask].reshape(y_pred.shape[0], -1) + filtered_y_true = y_true[mask].reshape(y_true.shape[0], -1) + report = classification_report(filtered_y_pred, filtered_y_true, output_dict=True) metric_dict = { "micro_avg": report["micro avg"], From 6cf244b5303777cd5ed43b30cd9b5bad5391db34 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 25 Aug 2023 17:03:50 +0100 Subject: [PATCH 267/300] Adds weight_decay, correct_bias, dropout probs, attention dropout... --- .../training/cli_args/bertmesh_args.py | 2 + grants_tagger_light/training/train.py | 48 +++++++++++++++---- 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/grants_tagger_light/training/cli_args/bertmesh_args.py b/grants_tagger_light/training/cli_args/bertmesh_args.py index d94ecdef..f96426ac 100644 --- a/grants_tagger_light/training/cli_args/bertmesh_args.py +++ b/grants_tagger_light/training/cli_args/bertmesh_args.py @@ -10,3 +10,5 @@ class BertMeshModelArguments: dropout: float = field(default=0) multilabel_attention: bool = field(default=False) freeze_backbone: str = field(default=None) # unfreeze, unfreeze_bias, freeze + hidden_dropout_prob: float = field(default=0.1) + attention_probs_dropout_prob: float = field(default=0.1) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index e1381b5e..3e1f240e 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -50,7 +50,10 @@ def train_bertmesh( train_years: list = None, test_years: list = None, scheduler_type: str = 'cosine_hard_restart', - threshold: float = 0.25 + threshold: float = 0.25, + weight_decay: float =0.1, + correct_bias: bool= True, + prune_labels_in_evaluation: bool =True ): if not model_key: assert isinstance(model_args, BertMeshModelArguments), ( @@ -116,12 +119,16 @@ def train_bertmesh( "label2id": label2id, "id2label": {v: k for k, v in label2id.items()}, "freeze_backbone": model_args.freeze_backbone, + "hidden_dropout_prob": model_args.hidden_dropout_prob, + "attention_probs_dropout_prob": model_args.attention_probs_dropout_prob, }) logger.info(f"Hidden size: {config.hidden_size}") logger.info(f"Dropout: {config.dropout}") logger.info(f"Multilabel Attention: {config.multilabel_attention}") logger.info(f"Freeze Backbone: {config.freeze_backbone}") logger.info(f"Num labels: {config.num_labels}") + logger.info(f"hidden_dropout_prob: {config.hidden_dropout_prob}") + logger.info(f"attention_probs_dropout_prob: {config.attention_probs_dropout_prob}") model = BertMesh(config) @@ -144,7 +151,6 @@ def train_bertmesh( def sklearn_metrics(prediction: EvalPrediction): logger.info(f"Threshold: {threshold}") - logger.info(f"For metric purposes, only considering labels present in `training`: {metric_labels[:15]}") # This is a batch, so it's an array (rows) of array (labels) # Array of arrays with probas [[5.4e-5 1.3e-3...] [5.4e-5 1.3e-3...] ... ] y_pred = prediction.predictions @@ -156,10 +162,17 @@ def sklearn_metrics(prediction: EvalPrediction): # report = classification_report(y_pred, y_true, output_dict=True) - mask = np.zeros(y_pred.shape, dtype=bool) - mask[np.arange(y_pred.shape[0])[:, np.newaxis], metric_labels] = True - filtered_y_pred = y_pred[mask].reshape(y_pred.shape[0], -1) - filtered_y_true = y_true[mask].reshape(y_true.shape[0], -1) + if prune_labels_in_evaluation: + logger.info(f"For metric purposes, only considering labels present in `training`: {metric_labels[:15]}") + mask = np.zeros(y_pred.shape, dtype=bool) + mask[np.arange(y_pred.shape[0])[:, np.newaxis], metric_labels] = True + + filtered_y_pred = y_pred[mask].reshape(y_pred.shape[0], -1) + filtered_y_true = y_true[mask].reshape(y_true.shape[0], -1) + else: + filtered_y_pred = y_pred + filtered_y_true = y_true + report = classification_report(filtered_y_pred, filtered_y_true, output_dict=True) metric_dict = { @@ -179,7 +192,11 @@ def sklearn_metrics(prediction: EvalPrediction): max_steps = Sharding.calculate_max_steps(training_args, train_dset_size) training_args.max_steps = max_steps - optimizer = AdamW(model.parameters(), lr=training_args.learning_rate) + optimizer = AdamW( + model.parameters(), + lr=training_args.learning_rate, + weight_decay=weight_decay, + correct_bias=correct_bias) if training_args.warmup_steps is None: training_args.warmup_steps = 0 @@ -296,6 +313,18 @@ def train_bertmesh_cli( 0.25, help="Threshold (0, 1) to considered a class as a positive" ), + weight_decay: float = typer.Option( + 0.1, + help="Optimizer weight decay. Default: 0.1" + ), + correct_bias: bool = typer.Option( + True, + help="Optimizer bias correction. Default: True" + ), + prune_labels_in_evaluation: bool = typer.Option( + True, + help="Remove before evaluation all the labels not present in training data. Default: True" + ), ): parser = HfArgumentParser( ( @@ -328,5 +357,8 @@ def train_bertmesh_cli( train_years=parse_years(train_years), test_years=parse_years(test_years), scheduler_type=scheduler_type, - threshold=threshold + threshold=threshold, + weight_decay=weight_decay, + correct_bias=correct_bias, + prune_labels_in_evaluation=prune_labels_in_evaluation ) From f7845666d267304f48c6ceb261e7512165f70693 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 25 Aug 2023 17:20:38 +0100 Subject: [PATCH 268/300] Adds weight_decay, correct_bias, dropout probs, attention dropout... --- .../training/cli_args/train_args.py | 6 ++ grants_tagger_light/training/train.py | 59 +++++-------------- 2 files changed, 20 insertions(+), 45 deletions(-) diff --git a/grants_tagger_light/training/cli_args/train_args.py b/grants_tagger_light/training/cli_args/train_args.py index 61820322..2a8467f1 100644 --- a/grants_tagger_light/training/cli_args/train_args.py +++ b/grants_tagger_light/training/cli_args/train_args.py @@ -62,6 +62,12 @@ class BertMeshTrainingArguments(TrainingArguments): # default="default" # ) # default | reduce-overhead | max-autotune + correct_bias: bool = field(default=True) + weight_decay: float = field(default=0.1) + prune_labels_in_evaluation: bool = field(default=False) + threshold: float = field(default=0.5) + scheduler_type: str = field(default="cosine") + def __post_init__(self): super().__post_init__() if "fused" in self.optim and not torch.cuda.is_available(): diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 3e1f240e..a6437415 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -48,12 +48,7 @@ def train_bertmesh( from_checkpoint: str = None, tags: list = None, train_years: list = None, - test_years: list = None, - scheduler_type: str = 'cosine_hard_restart', - threshold: float = 0.25, - weight_decay: float =0.1, - correct_bias: bool= True, - prune_labels_in_evaluation: bool =True + test_years: list = None ): if not model_key: assert isinstance(model_args, BertMeshModelArguments), ( @@ -150,19 +145,19 @@ def train_bertmesh( model.freeze_backbone() def sklearn_metrics(prediction: EvalPrediction): - logger.info(f"Threshold: {threshold}") + logger.info(f"Threshold: {training_args.threshold}") # This is a batch, so it's an array (rows) of array (labels) # Array of arrays with probas [[5.4e-5 1.3e-3...] [5.4e-5 1.3e-3...] ... ] y_pred = prediction.predictions # Transformed to 0-1 if bigger than threshold [[0 1 0...] [0 0 1...] ... ] - y_pred = np.int64(y_pred > threshold) + y_pred = np.int64(y_pred > training_args.threshold) # Array of arrays with 0/1 [[0 0 1 ...] [0 1 0 ...] ... ] y_true = prediction.label_ids # report = classification_report(y_pred, y_true, output_dict=True) - if prune_labels_in_evaluation: + if training_args.prune_labels_in_evaluation: logger.info(f"For metric purposes, only considering labels present in `training`: {metric_labels[:15]}") mask = np.zeros(y_pred.shape, dtype=bool) mask[np.arange(y_pred.shape[0])[:, np.newaxis], metric_labels] = True @@ -195,37 +190,36 @@ def sklearn_metrics(prediction: EvalPrediction): optimizer = AdamW( model.parameters(), lr=training_args.learning_rate, - weight_decay=weight_decay, - correct_bias=correct_bias) + weight_decay=training_args.weight_decay, + correct_bias=training_args.correct_bias) if training_args.warmup_steps is None: training_args.warmup_steps = 0 - if scheduler_type.lower().strip() == 'cosine': + if training_args.scheduler_type.lower().strip() == 'cosine': scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps) - elif scheduler_type.lower().strip() == 'constant': + elif training_args.scheduler_type.lower().strip() == 'constant': scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=training_args.warmup_steps) - elif scheduler_type.lower().strip() == 'cosine_hard_restart': + elif training_args.scheduler_type.lower().strip() == 'cosine_hard_restart': scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps, num_cycles=training_args.num_train_epochs) - elif scheduler_type.lower().strip() == 'linear': + elif training_args.scheduler_type.lower().strip() == 'linear': scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps) else: - logger.warning(f"{scheduler_type}: not found or not valid. Falling back to `linear`") - scheduler_type = 'linear' + logger.warning(f"{training_args.scheduler_type}: not found or not valid. Falling back to `linear`") scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps) logger.info(f"Optimizer: {optimizer}") - logger.info(f"Scheduler: {scheduler_type}") + logger.info(f"Scheduler: {training_args.scheduler_type}") training_args.optim = optimizer training_args.lr_scheduler_type = scheduler @@ -304,27 +298,7 @@ def train_bertmesh_cli( test_years: str = typer.Option( None, help="Comma-separated years you want to include in the test dataset" - ), - scheduler_type: str = typer.Option( - 'cosine_hard_restart', - help="One of the following lr schedulers: `cosine`, `linear`, `constant`, `cosine_hard_restart`" - ), - threshold: float = typer.Option( - 0.25, - help="Threshold (0, 1) to considered a class as a positive" - ), - weight_decay: float = typer.Option( - 0.1, - help="Optimizer weight decay. Default: 0.1" - ), - correct_bias: bool = typer.Option( - True, - help="Optimizer bias correction. Default: True" - ), - prune_labels_in_evaluation: bool = typer.Option( - True, - help="Remove before evaluation all the labels not present in training data. Default: True" - ), + ) ): parser = HfArgumentParser( ( @@ -355,10 +329,5 @@ def train_bertmesh_cli( from_checkpoint=from_checkpoint, tags=parse_tags(tags), train_years=parse_years(train_years), - test_years=parse_years(test_years), - scheduler_type=scheduler_type, - threshold=threshold, - weight_decay=weight_decay, - correct_bias=correct_bias, - prune_labels_in_evaluation=prune_labels_in_evaluation + test_years=parse_years(test_years) ) From 7da5ac43fc5a03d591c36dbba438871b7e1d57d4 Mon Sep 17 00:00:00 2001 From: Juan Martinez Date: Fri, 25 Aug 2023 19:30:14 +0000 Subject: [PATCH 269/300] Adds best params --- examples/train_by_steps.sh | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/examples/train_by_steps.sh b/examples/train_by_steps.sh index be382476..34e1c7a2 100644 --- a/examples/train_by_steps.sh +++ b/examples/train_by_steps.sh @@ -1,38 +1,37 @@ # Run on g5.12xlargeinstance -# Without preprocessing (on-the-fly) -SOURCE="data/raw/allMeSH_2021.jsonl" - # If you have already preprocessed the data, you will have a folder. Use the folder instead. -# SOURCE="output_folder_from_preprocessing" +SOURCE="output_folder_from_preprocessing" # In that case, `test-size`, `train-years` and `test-years` will be taken from the preprocessed folder grants-tagger train bertmesh \ "" \ $SOURCE \ - --test-size 10000 \ --output_dir bertmesh_outs/pipeline_test/ \ - --train-years 2016,2017,2018,2019 \ - --test-years 2020,2021 \ --per_device_train_batch_size 16 \ --per_device_eval_batch_size 1 \ --multilabel_attention True \ - --freeze_backbone False \ + --freeze_backbone unfreeze_bias \ --num_train_epochs 5 \ --learning_rate 5e-5 \ --dropout 0.1 \ --hidden_size 1024 \ - --warmup_steps 1000 \ - --max_grad_norm 5.0 \ - --scheduler-type cosine \ + --warmup_steps 5000 \ + --max_grad_norm 1.0 \ + --scheduler_type cosine \ + --weight_decay 0.2 \ + --correct_bias False \ --threshold 0.25 \ + --prune_labels_in_evaluation True \ + --hidden_dropout_prob 0.2 \ + --attention_probs_dropout_prob 0.2 \ --fp16 \ --torch_compile \ --evaluation_strategy steps \ - --eval_steps 50000 \ + --eval_steps 10000 \ --eval_accumulation_steps 20 \ --save_strategy steps \ - --save_steps 50000 \ + --save_steps 10000 \ --wandb_project wellcome-mesh \ --wandb_name test-train-all \ --wandb_api_key ${WANDB_API_KEY} From fbdecd61506b75fa071f554f40d379df0e315e6d Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 26 Aug 2023 19:40:23 +0100 Subject: [PATCH 270/300] Updates resume_train_by_steps.sh --- examples/resume_train_by_steps.sh | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/examples/resume_train_by_steps.sh b/examples/resume_train_by_steps.sh index 6f43ecb5..fc21daa8 100644 --- a/examples/resume_train_by_steps.sh +++ b/examples/resume_train_by_steps.sh @@ -12,8 +12,7 @@ CHECKPOINT="checkpoint-100000" grants-tagger train bertmesh \ bertmesh_outs/pipeline_test/$CHECKPOINT \ $SOURCE \ - --output_dir bertmesh_outs/pipeline_test_from_$CHECKPOINT/ \ - --ignore_data_skip True \ + --output_dir bertmesh_outs/pipeline_test/ \ --per_device_train_batch_size 16 \ --per_device_eval_batch_size 1 \ --multilabel_attention True \ @@ -22,17 +21,22 @@ grants-tagger train bertmesh \ --learning_rate 5e-5 \ --dropout 0.1 \ --hidden_size 1024 \ - --warmup_steps 1000 \ - --max_grad_norm 5.0 \ - --scheduler-type cosine \ + --warmup_steps 5000 \ + --max_grad_norm 1.0 \ + --scheduler_type cosine \ + --weight_decay 0.2 \ + --correct_bias False \ --threshold 0.25 \ + --prune_labels_in_evaluation True \ + --hidden_dropout_prob 0.2 \ + --attention_probs_dropout_prob 0.2 \ --fp16 \ --torch_compile \ --evaluation_strategy steps \ - --eval_steps 50000 \ + --eval_steps 10000 \ --eval_accumulation_steps 20 \ --save_strategy steps \ - --save_steps 50000 \ + --save_steps 10000 \ --wandb_project wellcome-mesh \ --wandb_name test-train-all \ --wandb_api_key ${WANDB_API_KEY} \ No newline at end of file From 5758b423f76c61d529881f700a3917f633389bc8 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Sat, 26 Aug 2023 19:40:55 +0100 Subject: [PATCH 271/300] Updates resume_train_by_steps.sh --- examples/resume_train_by_steps.sh | 2 +- examples/train_by_steps.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/resume_train_by_steps.sh b/examples/resume_train_by_steps.sh index fc21daa8..39c6baa5 100644 --- a/examples/resume_train_by_steps.sh +++ b/examples/resume_train_by_steps.sh @@ -23,7 +23,7 @@ grants-tagger train bertmesh \ --hidden_size 1024 \ --warmup_steps 5000 \ --max_grad_norm 1.0 \ - --scheduler_type cosine \ + --scheduler_type cosine_hard_restart \ --weight_decay 0.2 \ --correct_bias False \ --threshold 0.25 \ diff --git a/examples/train_by_steps.sh b/examples/train_by_steps.sh index 34e1c7a2..f4c664b9 100644 --- a/examples/train_by_steps.sh +++ b/examples/train_by_steps.sh @@ -18,7 +18,7 @@ grants-tagger train bertmesh \ --hidden_size 1024 \ --warmup_steps 5000 \ --max_grad_norm 1.0 \ - --scheduler_type cosine \ + --scheduler_type cosine_hard_restart \ --weight_decay 0.2 \ --correct_bias False \ --threshold 0.25 \ From 9582a6a70e09ebd9a6fa2896e06f5dc7d1f84e5c Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 28 Aug 2023 11:17:52 +0100 Subject: [PATCH 272/300] Adds tags-based augmentation --- examples/augment.sh | 2 +- examples/augment_specific_tags.sh | 5 +++ examples/augment_specific_years.sh | 3 -- grants_tagger_light/augmentation/augment.py | 35 +++++-------------- .../augmentation/augment_openai.py | 8 ++--- 5 files changed, 16 insertions(+), 37 deletions(-) create mode 100644 examples/augment_specific_tags.sh delete mode 100644 examples/augment_specific_years.sh diff --git a/examples/augment.sh b/examples/augment.sh index caf3ecb3..cb85617a 100644 --- a/examples/augment.sh +++ b/examples/augment.sh @@ -1,2 +1,2 @@ -grants-tagger augment mesh data/raw/allMeSH_2021.jsonl more_data.jsonl \ +grants-tagger augment mesh [FOLDER_AFTER_PREPROCESSING] [OUTPUT_FOLDER] \ --concurrent-calls 25 \ No newline at end of file diff --git a/examples/augment_specific_tags.sh b/examples/augment_specific_tags.sh new file mode 100644 index 00000000..94300bb3 --- /dev/null +++ b/examples/augment_specific_tags.sh @@ -0,0 +1,5 @@ +# Augments data using a file with 1 label per line and years +grants-tagger augment mesh [FOLDER_AFTER_PREPROCESSING] [OUTPUT_FOLDER] \ + --tags-file-path [YOUR_TAGS_FILE] \ + --min-examples 25 \ + --concurrent-calls 25 \ No newline at end of file diff --git a/examples/augment_specific_years.sh b/examples/augment_specific_years.sh deleted file mode 100644 index e1b16a70..00000000 --- a/examples/augment_specific_years.sh +++ /dev/null @@ -1,3 +0,0 @@ -grants-tagger augment mesh data/raw/allMeSH_2021.jsonl more_data.jsonl \ - --train-years 2017 \ - --concurrent-calls 25 \ No newline at end of file diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index fc59b474..711622c3 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -11,6 +11,8 @@ from grants_tagger_light.augmentation.augment_openai import AugmentOpenAI from grants_tagger_light.utils.years_tags_parser import parse_years +from datasets import load_from_disk + augment_app = typer.Typer() @@ -41,8 +43,6 @@ def augment( model_key: str = 'gpt-3.5-turbo', num_proc: int = os.cpu_count(), batch_size: int = 64, - train_years: list = None, - test_years: list = None, min_examples: int = 15, prompt_template: str = 'grants_tagger_light/augmentation/prompt.template', concurrent_calls: int = os.cpu_count()*2, @@ -52,16 +52,10 @@ def augment( if model_key.strip().lower() not in ['gpt-3.5-turbo', 'text-davinci', 'gpt-4']: raise NotImplementedError(f"{model_key} not implemented as an augmentation framework") - # We only have 1 file, so no sharding is available https://huggingface.co/docs/datasets/loading#multiprocessing - dset = load_dataset("json", data_files=data_path, num_proc=1) - # By default, any dataset loaded is set to 'train' using the previous command - if "train" in dset: - dset = dset["train"] - - if train_years is not None and len(train_years) > 0: - dset = dset.filter(lambda x: any(np.isin(train_years, [str(x["year"])])), num_proc=num_proc) - if test_years is not None and len(test_years) > 0: - dset = dset.filter(lambda x: not any(np.isin(test_years, [str(x["year"])])), num_proc=num_proc) + try: + dset = load_from_disk(data_path) + except Exception: + dset = load_from_disk(os.path.join(data_path, "dataset")) logger.info("Obtaining count values from the labels...") pool = multiprocessing.Pool(processes=num_proc) @@ -107,7 +101,6 @@ def augment( collect_concurrent_calls, dset, save_to_path, - train_years, model_key, temperature=temperature, num_proc=num_proc, @@ -123,7 +116,6 @@ def augment( collect_concurrent_calls, dset, save_to_path, - train_years, model_key, temperature=temperature, num_proc=num_proc, @@ -150,14 +142,6 @@ def augment_cli( 64, help="Preprocessing batch size (for dataset, filter, map, ...)" ), - train_years: str = typer.Option( - None, - help="If set, Comma-separated years you want to include in the data augmentation process" - ), - test_years: str = typer.Option( - None, - help="If set, Comma-separated years you want to exclude in the data augmentation process" - ), min_examples: int = typer.Option( 15, help="Minimum number of examples to require. Less than that will trigger data augmentation." @@ -179,10 +163,9 @@ def augment_cli( help="Text file containing one line per tag to be considered. The rest will be discarded." ) ): - if not data_path.endswith("jsonl"): + if not os.path.isdir(data_path): logger.error( - "It seems your input MeSH data is not in `jsonl` format. " - "Please, run first `scripts/mesh_json_to_jsonlpy.`" + "The data path should be a folder with saved data from `preprocessing` step." ) exit(-1) @@ -197,8 +180,6 @@ def augment_cli( model_key=model_key, num_proc=num_proc, batch_size=batch_size, - train_years=parse_years(train_years), - test_years=parse_years(test_years), min_examples=min_examples, prompt_template=prompt_template, concurrent_calls=concurrent_calls, diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index f3e9cc70..ce7f0490 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -69,12 +69,10 @@ def _make_requests(self, top_p, presence_penalty, num_proc, - train_years, model_key, save_to_path): - year = [random.choice(train_years) if train_years is not None and isinstance(train_years, list) - else datetime.date.year] + year = datetime.date.year for num in range(len(collect_concurrent_calls)): t = collect_concurrent_calls[num][0] @@ -114,7 +112,7 @@ def _make_requests(self, self.api.request(data=data, metadata=metadata, callback=self._process_response) - def generate(self, collect_concurrent_calls, dset, save_to_path, train_years, model_key, + def generate(self, collect_concurrent_calls, dset, save_to_path, model_key, temperature=1.5, top_p=1, presence_penalty=0, num_proc=os.cpu_count()): self.api.run_request_function(self._make_requests, collect_concurrent_calls=collect_concurrent_calls, @@ -123,9 +121,7 @@ def generate(self, collect_concurrent_calls, dset, save_to_path, train_years, mo top_p=top_p, presence_penalty=presence_penalty, num_proc=num_proc, - train_years=train_years, model_key=model_key, save_to_path=save_to_path) self.api.pull_all() - From 60c9fcad4a02948e97ad2ca0674aca97ee17bfda Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 28 Aug 2023 12:47:47 +0100 Subject: [PATCH 273/300] Adds id2label for augmentation --- grants_tagger_light/augmentation/augment.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 711622c3..f360d782 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -4,7 +4,6 @@ import typer from loguru import logger -from datasets import load_dataset import numpy as np @@ -52,10 +51,20 @@ def augment( if model_key.strip().lower() not in ['gpt-3.5-turbo', 'text-davinci', 'gpt-4']: raise NotImplementedError(f"{model_key} not implemented as an augmentation framework") - try: - dset = load_from_disk(data_path) - except Exception: - dset = load_from_disk(os.path.join(data_path, "dataset")) + dset = load_from_disk(data_path) + + with open(os.path.join(data_path, "label2id"), "r") as f: + label2id = json.load(f) + id2label = {v: k for k, v in label2id.items()} + + dset = dset.map( + lambda x: {'meshMajor': [id2label[y] for y in x['label_ids']]}, + with_indices=False, + batched=True, + batch_size=batch_size, + desc="Adding label names", + num_proc=num_proc, + ) logger.info("Obtaining count values from the labels...") pool = multiprocessing.Pool(processes=num_proc) From eb2d0f3770479200507b0dd224e7f3579b542c7e Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 28 Aug 2023 12:54:05 +0100 Subject: [PATCH 274/300] Adds `dataset` folder --- grants_tagger_light/augmentation/augment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index f360d782..c50ccca6 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -51,7 +51,7 @@ def augment( if model_key.strip().lower() not in ['gpt-3.5-turbo', 'text-davinci', 'gpt-4']: raise NotImplementedError(f"{model_key} not implemented as an augmentation framework") - dset = load_from_disk(data_path) + dset = load_from_disk(os.path.join(data_path, "dataset")) with open(os.path.join(data_path, "label2id"), "r") as f: label2id = json.load(f) From 98150278294e770a9b169b4b75d9f66612e86d97 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 28 Aug 2023 13:01:05 +0100 Subject: [PATCH 275/300] Decodes id back into labels --- grants_tagger_light/augmentation/augment.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index c50ccca6..146dee80 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -15,6 +15,14 @@ augment_app = typer.Typer() +def _map_id_to_labels(ids, id2label): + return [id2label[i] for i in ids] + + +def _restore_meshmajor(sample, id2label): + return {"meshMajor": [_map_id_to_labels(x, id2label) for x in sample["label_ids"]]} + + def _count_elements_in_sublist(sublist): element_count = {} for element in sublist: @@ -58,12 +66,12 @@ def augment( id2label = {v: k for k, v in label2id.items()} dset = dset.map( - lambda x: {'meshMajor': [id2label[y] for y in x['label_ids']]}, - with_indices=False, + _restore_meshmajor, batched=True, batch_size=batch_size, - desc="Adding label names", + desc="Decoding labels", num_proc=num_proc, + fn_kwargs={"id2label": id2label} ) logger.info("Obtaining count values from the labels...") From 56ef40cadd768a29d7a5553d3ff73c015c40ce34 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 28 Aug 2023 13:10:12 +0100 Subject: [PATCH 276/300] Decodes id back into labels --- grants_tagger_light/augmentation/augment.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 146dee80..dac58c2f 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -60,6 +60,8 @@ def augment( raise NotImplementedError(f"{model_key} not implemented as an augmentation framework") dset = load_from_disk(os.path.join(data_path, "dataset")) + if "train" in dset: + dset = dset["train"] with open(os.path.join(data_path, "label2id"), "r") as f: label2id = json.load(f) From d77191840b7929679eb11f16fd08509386633c0b Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 28 Aug 2023 14:07:30 +0100 Subject: [PATCH 277/300] Generates examples also for not-underrepresented --- examples/augment_specific_tags.sh | 2 +- grants_tagger_light/augmentation/augment.py | 38 +++++++++++++------ .../augmentation/augment_openai.py | 14 +++++-- 3 files changed, 38 insertions(+), 16 deletions(-) diff --git a/examples/augment_specific_tags.sh b/examples/augment_specific_tags.sh index 94300bb3..1c504d61 100644 --- a/examples/augment_specific_tags.sh +++ b/examples/augment_specific_tags.sh @@ -1,5 +1,5 @@ # Augments data using a file with 1 label per line and years grants-tagger augment mesh [FOLDER_AFTER_PREPROCESSING] [OUTPUT_FOLDER] \ --tags-file-path [YOUR_TAGS_FILE] \ - --min-examples 25 \ + --examples 25 \ --concurrent-calls 25 \ No newline at end of file diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index dac58c2f..93afd4ad 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -50,7 +50,8 @@ def augment( model_key: str = 'gpt-3.5-turbo', num_proc: int = os.cpu_count(), batch_size: int = 64, - min_examples: int = 15, + min_examples: int = None, + examples: int = 25, prompt_template: str = 'grants_tagger_light/augmentation/prompt.template', concurrent_calls: int = os.cpu_count()*2, temperature: float = 1.5, @@ -92,16 +93,22 @@ def augment( sorted_merged_element_counts_dict = {k: v for k, v in sorted_merged_element_counts_dict.items() if k in tags} + if min_examples is not None: + sorted_merged_element_counts_dict = {k: v for k, v in sorted_merged_element_counts_dict.items() + if v < min_examples} + with open(f"{save_to_path}.count", 'w') as f: f.write(json.dumps(sorted_merged_element_counts_dict, indent=2)) - tags_to_augment_counts = {k: v for k, v in sorted_merged_element_counts_dict.items() if v < min_examples} - tags_to_augment = [k for k, v in sorted_merged_element_counts_dict.items() if v < min_examples] + tags_to_augment = sorted_merged_element_counts_dict.keys() + + biggest_tags_to_augment = [f"{k}({sorted_merged_element_counts_dict[k]})" + for k in tags_to_augment[:5]] + smallest_tags_to_augment = [f"{k}({sorted_merged_element_counts_dict[k]})" + for k in tags_to_augment[-5:]] - biggest_tags_to_augment = [f"{k}({sorted_merged_element_counts_dict[k]})" for k in tags_to_augment[:5]] - smallest_tags_to_augment = [f"{k}({sorted_merged_element_counts_dict[k]})" for k in tags_to_augment[-5:]] - logger.info(f"Augmenting a total of {len(tags_to_augment)} tags, from {biggest_tags_to_augment} to " - f"{smallest_tags_to_augment}") + logger.info(f"Augmenting a total of {len(tags_to_augment)} tags, " + f"from {biggest_tags_to_augment} to {smallest_tags_to_augment}") logger.info(f"Collecting existing examples of those tags to send in the prompt") dset = dset.filter(lambda x: any(np.isin(tags_to_augment, x["meshMajor"])), num_proc=num_proc) @@ -126,9 +133,7 @@ def augment( ) collect_concurrent_calls = [] else: - if tags_to_augment_counts[t] < min_examples: - missing = min_examples - tags_to_augment_counts[t] - collect_concurrent_calls.append((t, missing)) + collect_concurrent_calls.append((t, examples)) if len(collect_concurrent_calls) > 0: AugmentOpenAI(prompt_template_path=prompt_template, model_key=model_key).generate( @@ -162,9 +167,13 @@ def augment_cli( help="Preprocessing batch size (for dataset, filter, map, ...)" ), min_examples: int = typer.Option( - 15, + None, help="Minimum number of examples to require. Less than that will trigger data augmentation." ), + examples: int = typer.Option( + 25, + help="Examples to generate per each tag." + ), prompt_template: str = typer.Option( 'grants_tagger_light/augmentation/prompt.template', help="File to use as a prompt. Make sure to ask the LLM to return a dict with two fields: `abstract` and `tags`" @@ -188,6 +197,12 @@ def augment_cli( ) exit(-1) + if tags_file_path is None and min_examples is None: + logger.error( + "To understand which tags need to be augmented, set either --min-examples or --tags-file-path" + ) + exit(-1) + if float(temperature) > 2.0 or float(temperature) < -2.0: logger.error( "Temperature should be in the range [-2, 2]" @@ -200,6 +215,7 @@ def augment_cli( num_proc=num_proc, batch_size=batch_size, min_examples=min_examples, + examples=examples, prompt_template=prompt_template, concurrent_calls=concurrent_calls, temperature=temperature, diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index ce7f0490..13260245 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -1,5 +1,6 @@ import datetime import json +import math import os import random import uuid @@ -77,7 +78,6 @@ def _make_requests(self, for num in range(len(collect_concurrent_calls)): t = collect_concurrent_calls[num][0] n = collect_concurrent_calls[num][1] - logger.info(f"Augmenting {t} with {n} examples") # RAG: I select similar articles to provide them to the LLM tmp_dset = dset.filter(lambda x: any(np.isin([t], x["meshMajor"])), num_proc=num_proc) # I remove them from the dataset to process to make it smaller and quicker over time @@ -87,13 +87,19 @@ def _make_requests(self, abstracts_num = [i for i in range(len(tmp_dset))] random.shuffle(abstracts_num) - for i in range(n): - selected_row = abstracts_num[i % len(tmp_dset)] + required_examples = n + existing_examples = min(required_examples, len(tmp_dset)) + + n_per_example = math.ceil(required_examples / existing_examples) + + for i in range(existing_examples): + logger.info(f"Augmenting {t} with {n_per_example} examples using 1 already existing example") + selected_row = abstracts_num[i] abstract = tmp_dset['abstractText'][selected_row] tags = tmp_dset['meshMajor'][selected_row] data = { "model": self.model_key, - "n": 1, + "n": n_per_example, "temperature": temperature, "top_p": top_p, "presence_penalty": presence_penalty, From 99e1067864cc80fc6214d4f7fb2f52cf4669c160 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 28 Aug 2023 14:12:01 +0100 Subject: [PATCH 278/300] Generates examples also for not-underrepresented --- grants_tagger_light/augmentation/augment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 93afd4ad..15e2e777 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -100,7 +100,7 @@ def augment( with open(f"{save_to_path}.count", 'w') as f: f.write(json.dumps(sorted_merged_element_counts_dict, indent=2)) - tags_to_augment = sorted_merged_element_counts_dict.keys() + tags_to_augment = list(sorted_merged_element_counts_dict.keys()) biggest_tags_to_augment = [f"{k}({sorted_merged_element_counts_dict[k]})" for k in tags_to_augment[:5]] From fa1723327422eb9bf7193132331ae623e300ca75 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 28 Aug 2023 18:11:10 +0100 Subject: [PATCH 279/300] Generates examples also for not-underrepresented --- grants_tagger_light/augmentation/augment_openai.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 13260245..1df8f552 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -78,11 +78,11 @@ def _make_requests(self, for num in range(len(collect_concurrent_calls)): t = collect_concurrent_calls[num][0] n = collect_concurrent_calls[num][1] - # RAG: I select similar articles to provide them to the LLM + # RAG: I select similar articles to provide them to the LLM, maximum ´n´ (I don't need more) tmp_dset = dset.filter(lambda x: any(np.isin([t], x["meshMajor"])), num_proc=num_proc) # I remove them from the dataset to process to make it smaller and quicker over time - dset = dset.filter(lambda example, idx: idx not in tmp_dset['idx'], with_indices=True, - num_proc=num_proc) + # dset = dset.filter(lambda example, idx: idx not in tmp_dset['idx'], with_indices=True, + # num_proc=num_proc) abstracts_num = [i for i in range(len(tmp_dset))] random.shuffle(abstracts_num) From c811f8a1e5b5b6aa0107d60c1461bb2543cf12c5 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 28 Aug 2023 18:24:34 +0100 Subject: [PATCH 280/300] Adds more columns to preprocessing --- grants_tagger_light/augmentation/augment.py | 22 ------------------- .../augmentation/augment_openai.py | 8 ++----- .../preprocessing/preprocess_mesh.py | 14 ++++++------ grants_tagger_light/training/train.py | 6 +++-- 4 files changed, 13 insertions(+), 37 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 15e2e777..e8a35075 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -8,21 +8,12 @@ from grants_tagger_light.augmentation.augment_openai import AugmentOpenAI -from grants_tagger_light.utils.years_tags_parser import parse_years from datasets import load_from_disk augment_app = typer.Typer() -def _map_id_to_labels(ids, id2label): - return [id2label[i] for i in ids] - - -def _restore_meshmajor(sample, id2label): - return {"meshMajor": [_map_id_to_labels(x, id2label) for x in sample["label_ids"]]} - - def _count_elements_in_sublist(sublist): element_count = {} for element in sublist: @@ -64,19 +55,6 @@ def augment( if "train" in dset: dset = dset["train"] - with open(os.path.join(data_path, "label2id"), "r") as f: - label2id = json.load(f) - id2label = {v: k for k, v in label2id.items()} - - dset = dset.map( - _restore_meshmajor, - batched=True, - batch_size=batch_size, - desc="Decoding labels", - num_proc=num_proc, - fn_kwargs={"id2label": id2label} - ) - logger.info("Obtaining count values from the labels...") pool = multiprocessing.Pool(processes=num_proc) element_counts_list = pool.map(_count_elements_in_sublist, dset['meshMajor']) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 1df8f552..3d248c94 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -84,9 +84,6 @@ def _make_requests(self, # dset = dset.filter(lambda example, idx: idx not in tmp_dset['idx'], with_indices=True, # num_proc=num_proc) - abstracts_num = [i for i in range(len(tmp_dset))] - random.shuffle(abstracts_num) - required_examples = n existing_examples = min(required_examples, len(tmp_dset)) @@ -94,9 +91,8 @@ def _make_requests(self, for i in range(existing_examples): logger.info(f"Augmenting {t} with {n_per_example} examples using 1 already existing example") - selected_row = abstracts_num[i] - abstract = tmp_dset['abstractText'][selected_row] - tags = tmp_dset['meshMajor'][selected_row] + abstract = tmp_dset['abstractText'][i] + tags = tmp_dset['meshMajor'][i] data = { "model": self.model_key, "n": n_per_example, diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 81492f9b..7de2acb3 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -68,6 +68,7 @@ def preprocess_mesh( if not model_key: label2id = None + id2label = None tokenizer = AutoTokenizer.from_pretrained( "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract" ) @@ -76,7 +77,8 @@ def preprocess_mesh( tokenizer = AutoTokenizer.from_pretrained(model_key) model = BertMesh.from_pretrained(model_key, trust_remote_code=True) - label2id = {v: k for k, v in model.id2label.items()} + id2label = model.id2label + label2id = {v: k for k, v in id2label.items()} # We only have 1 file, so no sharding is available https://huggingface.co/docs/datasets/loading#multiprocessing dset = load_dataset("json", data_files=data_path, num_proc=1) @@ -109,12 +111,10 @@ def preprocess_mesh( num_proc=num_proc, desc="Tokenizing", fn_kwargs={"tokenizer": tokenizer, "x_col": "abstractText"}, - remove_columns=["abstractText"], load_from_cache_file=False, ) logger.info("Time taken to tokenize: {}".format(time.time() - t1)) - columns_to_remove = ["meshMajor"] # Generate label2id if None if label2id is None: logger.info("Getting the labels...") @@ -139,8 +139,7 @@ def preprocess_mesh( label2id = dict() for idx, label in enumerate(tqdm(unique_labels_set)): label2id.update({label: idx}) - - columns_to_remove.append("labels") + id2label.update({idx: label}) t1 = time.time() dset = dset.map( @@ -150,7 +149,6 @@ def preprocess_mesh( desc="Encoding labels", num_proc=num_proc, fn_kwargs={"label2id": label2id}, - remove_columns=columns_to_remove, ) logger.info("Time taken to encode labels: {}".format(time.time() - t1)) @@ -199,8 +197,10 @@ def preprocess_mesh( dset.save_to_disk(os.path.join(save_to_path, "dataset"), num_proc=num_proc) with open(os.path.join(save_to_path, "label2id"), "w") as f: json.dump(label2id, f) + with open(os.path.join(save_to_path, "id2label"), "w") as f: + json.dump(id2label, f) - return dset, label2id + return dset, label2id, id2label @preprocess_app.command() diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index a6437415..70032b37 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -65,9 +65,11 @@ def train_bertmesh( dset = load_from_disk(os.path.join(data_path, "dataset")) with open(os.path.join(data_path, "label2id"), "r") as f: label2id = json.load(f) + with open(os.path.join(data_path, "id2label"), "r") as f: + id2label = json.load(f) else: logger.info("Preprocessing the data on the fly...") - dset, label2id = preprocess_mesh( + dset, label2id, id2label = preprocess_mesh( data_path=data_path, model_key=model_key, test_size=test_size, @@ -112,7 +114,7 @@ def train_bertmesh( "dropout": model_args.dropout, "multilabel_attention": model_args.multilabel_attention, "label2id": label2id, - "id2label": {v: k for k, v in label2id.items()}, + "id2label": id2label, "freeze_backbone": model_args.freeze_backbone, "hidden_dropout_prob": model_args.hidden_dropout_prob, "attention_probs_dropout_prob": model_args.attention_probs_dropout_prob, From 6d58ea9a43bf5d6ef17a9d3447a3380853d1259c Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 28 Aug 2023 18:27:01 +0100 Subject: [PATCH 281/300] Adds more columns to preprocessing --- grants_tagger_light/preprocessing/preprocess_mesh.py | 1 + 1 file changed, 1 insertion(+) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 7de2acb3..9382ea1d 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -137,6 +137,7 @@ def preprocess_mesh( # Most efficient way to do dictionary creation logger.info("Creating label2id dictionary...") label2id = dict() + id2label = dict() for idx, label in enumerate(tqdm(unique_labels_set)): label2id.update({label: idx}) id2label.update({idx: label}) From c75c36a68983ef77d6558f6c3396bfd715369c1f Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 28 Aug 2023 18:31:23 +0100 Subject: [PATCH 282/300] Adds more columns to preprocessing --- grants_tagger_light/augmentation/augment_openai.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 3d248c94..382c5ab5 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -47,7 +47,7 @@ def _process_response(result): a = json_response['abstract'] tl = json_response['title'] - f.write(json.dumps({ + res = { "journal": result.metadata['model_key'], "meshMajor": result.metadata['tags'], "year": result.metadata['year'], @@ -57,7 +57,11 @@ def _process_response(result): "existing_example": result.metadata['existing_example'], "required_examples": result.metadata['required_examples'], "featured_tag": result.metadata['featured_tag'] - })) + } + + print(res) + + f.write(json.dumps(res)) f.write('\n') f.flush() From 2e0474eda8d6ab59eb10285dbd3345254584ba15 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 28 Aug 2023 18:36:56 +0100 Subject: [PATCH 283/300] Adds more columns to preprocessing --- grants_tagger_light/augmentation/augment_openai.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 382c5ab5..38315cac 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -44,23 +44,18 @@ def _process_response(result): logger.info(f"Error processing output: {e}. Skipping...") continue - a = json_response['abstract'] - tl = json_response['title'] - res = { "journal": result.metadata['model_key'], "meshMajor": result.metadata['tags'], "year": result.metadata['year'], - "abstractText": a, + "abstractText": json_response['abstract'].replace("'", "").replace('"', ''), "pmid": uuid.uuid4().hex, - "title": tl, - "existing_example": result.metadata['existing_example'], + "title": json_response['title'].replace("'", "").replace('"', ''), + "existing_example": result.metadata['existing_example'].replace("'", "").replace('"', ''), "required_examples": result.metadata['required_examples'], "featured_tag": result.metadata['featured_tag'] } - print(res) - f.write(json.dumps(res)) f.write('\n') f.flush() @@ -77,7 +72,7 @@ def _make_requests(self, model_key, save_to_path): - year = datetime.date.year + year = datetime.date.today().year for num in range(len(collect_concurrent_calls)): t = collect_concurrent_calls[num][0] From 1c173e1bf88d316b8a0b83e0c73a836715192503 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 28 Aug 2023 19:15:07 +0100 Subject: [PATCH 284/300] Adds more columns to preprocessing --- grants_tagger_light/augmentation/augment.py | 4 ++-- grants_tagger_light/augmentation/augment_openai.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index e8a35075..b0950146 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -67,7 +67,7 @@ def augment( if tags_file_path is not None: with open(tags_file_path, 'r') as f: tags = f.read().split('\n') - logger.info(f"Tags file path found. Filtering tags (examples found: {tags[:15]}...)") + logger.info(f"Tags file path found. Filtering {len(tags)} tags (examples found: {tags[:15]}...)") sorted_merged_element_counts_dict = {k: v for k, v in sorted_merged_element_counts_dict.items() if k in tags} @@ -88,7 +88,7 @@ def augment( logger.info(f"Augmenting a total of {len(tags_to_augment)} tags, " f"from {biggest_tags_to_augment} to {smallest_tags_to_augment}") - logger.info(f"Collecting existing examples of those tags to send in the prompt") + logger.info(f"RAG: Collecting existing examples of those tags to send in the prompt") dset = dset.filter(lambda x: any(np.isin(tags_to_augment, x["meshMajor"])), num_proc=num_proc) dset = dset.map( lambda _, y: {'idx': y}, diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 38315cac..aa6d1bc2 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -38,6 +38,7 @@ def _process_response(result): for r in result.response['choices']: if 'message' in r: if 'content' in r['message']: + print(r['message']['content']) try: json_response = JsonParser.parse_json(r['message']['content']) except Exception as e: From 726837fa55dc009bbd37f2fb93f0190f9b8bea62 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 28 Aug 2023 19:21:59 +0100 Subject: [PATCH 285/300] Prevents crashes --- .../augmentation/augment_openai.py | 53 ++++++++++--------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index aa6d1bc2..8a36e7ed 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -38,30 +38,33 @@ def _process_response(result): for r in result.response['choices']: if 'message' in r: if 'content' in r['message']: - print(r['message']['content']) try: - json_response = JsonParser.parse_json(r['message']['content']) - except Exception as e: - logger.info(f"Error processing output: {e}. Skipping...") - continue - - res = { - "journal": result.metadata['model_key'], - "meshMajor": result.metadata['tags'], - "year": result.metadata['year'], - "abstractText": json_response['abstract'].replace("'", "").replace('"', ''), - "pmid": uuid.uuid4().hex, - "title": json_response['title'].replace("'", "").replace('"', ''), - "existing_example": result.metadata['existing_example'].replace("'", "").replace('"', ''), - "required_examples": result.metadata['required_examples'], - "featured_tag": result.metadata['featured_tag'] - } - - f.write(json.dumps(res)) - f.write('\n') - f.flush() - - logger.info(f"Data received successfully for {result.metadata['featured_tag']}") + logger.info(r['message']['content']) + try: + json_response = JsonParser.parse_json(r['message']['content']) + except Exception as e: + logger.info(f"Error processing output: {e}. Skipping...") + continue + + res = { + "journal": result.metadata['model_key'], + "meshMajor": result.metadata['tags'], + "year": result.metadata['year'], + "abstractText": json_response['abstract'].replace("'", "").replace('"', ''), + "pmid": uuid.uuid4().hex, + "title": json_response['title'].replace("'", "").replace('"', ''), + "existing_example": result.metadata['existing_example'].replace("'", "").replace('"', ''), + "required_examples": result.metadata['required_examples'], + "featured_tag": result.metadata['featured_tag'] + } + + f.write(json.dumps(res)) + f.write('\n') + f.flush() + + logger.info(f"Data received successfully for {result.metadata['featured_tag']}") + except: + logger.info(f"Missing or malformed data. Skipping") def _make_requests(self, collect_concurrent_calls, @@ -88,9 +91,9 @@ def _make_requests(self, existing_examples = min(required_examples, len(tmp_dset)) n_per_example = math.ceil(required_examples / existing_examples) - + logger.info(f"Augmenting {t} with {required_examples} examples, using {existing_examples} in RAG mode") for i in range(existing_examples): - logger.info(f"Augmenting {t} with {n_per_example} examples using 1 already existing example") + abstract = tmp_dset['abstractText'][i] tags = tmp_dset['meshMajor'][i] data = { From 60a3ec98ce0f2a74aa9cf60ff80e594a01d24422 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Mon, 28 Aug 2023 20:07:36 +0100 Subject: [PATCH 286/300] Better hyperparams --- examples/resume_train_by_epoch.sh | 16 ++++++++++------ examples/resume_train_by_steps.sh | 6 +++--- examples/train_by_epochs.sh | 13 +++++++++---- examples/train_by_steps.sh | 6 +++--- 4 files changed, 25 insertions(+), 16 deletions(-) diff --git a/examples/resume_train_by_epoch.sh b/examples/resume_train_by_epoch.sh index 4c35d68c..261f9d57 100644 --- a/examples/resume_train_by_epoch.sh +++ b/examples/resume_train_by_epoch.sh @@ -12,20 +12,24 @@ CHECKPOINT="checkpoint-100000" grants-tagger train bertmesh \ bertmesh_outs/pipeline_test/$CHECKPOINT \ $SOURCE \ - --output_dir bertmesh_outs/pipeline_test_from_$CHECKPOINT/ \ - --ignore_data_skip True \ + --output_dir bertmesh_outs/pipeline_test/ \ --per_device_train_batch_size 16 \ --per_device_eval_batch_size 1 \ --multilabel_attention True \ - --freeze_backbone unfreeze_bias \ + --freeze_backbone unfreeze \ --num_train_epochs 5 \ --learning_rate 5e-5 \ --dropout 0.1 \ --hidden_size 1024 \ - --warmup_steps 1000 \ - --max_grad_norm 5.0 \ - --scheduler-type cosine \ + --warmup_steps 5000 \ + --max_grad_norm 2.0 \ + --scheduler_type cosine_hard_restart \ + --weight_decay 0.2 \ + --correct_bias True \ --threshold 0.25 \ + --prune_labels_in_evaluation True \ + --hidden_dropout_prob 0.2 \ + --attention_probs_dropout_prob 0.2 \ --fp16 \ --torch_compile \ --evaluation_strategy epoch \ diff --git a/examples/resume_train_by_steps.sh b/examples/resume_train_by_steps.sh index 39c6baa5..831fbd22 100644 --- a/examples/resume_train_by_steps.sh +++ b/examples/resume_train_by_steps.sh @@ -16,16 +16,16 @@ grants-tagger train bertmesh \ --per_device_train_batch_size 16 \ --per_device_eval_batch_size 1 \ --multilabel_attention True \ - --freeze_backbone unfreeze_bias \ + --freeze_backbone unfreeze \ --num_train_epochs 5 \ --learning_rate 5e-5 \ --dropout 0.1 \ --hidden_size 1024 \ --warmup_steps 5000 \ - --max_grad_norm 1.0 \ + --max_grad_norm 2.0 \ --scheduler_type cosine_hard_restart \ --weight_decay 0.2 \ - --correct_bias False \ + --correct_bias True \ --threshold 0.25 \ --prune_labels_in_evaluation True \ --hidden_dropout_prob 0.2 \ diff --git a/examples/train_by_epochs.sh b/examples/train_by_epochs.sh index 9980de0f..51686cdf 100644 --- a/examples/train_by_epochs.sh +++ b/examples/train_by_epochs.sh @@ -11,15 +11,20 @@ grants-tagger train bertmesh \ --per_device_train_batch_size 16 \ --per_device_eval_batch_size 1 \ --multilabel_attention True \ - --freeze_backbone unfreeze_bias \ + --freeze_backbone unfreeze \ --num_train_epochs 5 \ --learning_rate 5e-5 \ --dropout 0.1 \ --hidden_size 1024 \ - --warmup_steps 1000 \ - --max_grad_norm 5.0 \ - --scheduler-type cosine \ + --warmup_steps 5000 \ + --max_grad_norm 2.0 \ + --scheduler_type cosine_hard_restart \ + --weight_decay 0.2 \ + --correct_bias True \ --threshold 0.25 \ + --prune_labels_in_evaluation True \ + --hidden_dropout_prob 0.2 \ + --attention_probs_dropout_prob 0.2 \ --fp16 \ --torch_compile \ --evaluation_strategy epoch \ diff --git a/examples/train_by_steps.sh b/examples/train_by_steps.sh index f4c664b9..a114970f 100644 --- a/examples/train_by_steps.sh +++ b/examples/train_by_steps.sh @@ -11,16 +11,16 @@ grants-tagger train bertmesh \ --per_device_train_batch_size 16 \ --per_device_eval_batch_size 1 \ --multilabel_attention True \ - --freeze_backbone unfreeze_bias \ + --freeze_backbone unfreeze \ --num_train_epochs 5 \ --learning_rate 5e-5 \ --dropout 0.1 \ --hidden_size 1024 \ --warmup_steps 5000 \ - --max_grad_norm 1.0 \ + --max_grad_norm 2.0 \ --scheduler_type cosine_hard_restart \ --weight_decay 0.2 \ - --correct_bias False \ + --correct_bias True \ --threshold 0.25 \ --prune_labels_in_evaluation True \ --hidden_dropout_prob 0.2 \ From 4f70bbcb6369847e8e125cc0c118539d4e133868 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 31 Aug 2023 17:36:01 +0100 Subject: [PATCH 287/300] Check if fixes tests --- README.md | 130 +++++++++--------- examples/preprocess_and_train_by_steps.sh | 4 +- examples/resume_train_by_epoch.sh | 9 +- examples/resume_train_by_steps.sh | 9 +- examples/train_by_epochs.sh | 4 +- examples/train_by_steps.sh | 6 +- .../preprocessing/preprocess_mesh.py | 3 + .../training/cli_args/bertmesh_args.py | 12 +- .../training/cli_args/train_args.py | 7 +- grants_tagger_light/training/train.py | 4 +- tests/test_preprocess_mesh.py | 2 +- tests/test_split_data.py | 4 +- tests/test_train.py | 17 +-- 13 files changed, 105 insertions(+), 106 deletions(-) diff --git a/README.md b/README.md index 8289bf13..7535dde7 100644 --- a/README.md +++ b/README.md @@ -74,7 +74,7 @@ in square brackets the commands that are not implemented yet This process is optional to run, since it can be directly managed by the `Train` process. - If you run it manually, it will store the data in local first, which can help if you need finetune in the future, rerun, etc. -- If not, the project will preprocess and then run, without any extra I/O operations on disk, +- If not run it, the `train` step will preprocess and then run, without any extra I/O operations on disk, which may add latency depending on the infrastructure. It requires data in `jsonl` format for parallelization purposes. In `data/raw` you can find `allMesH_2021.jsonl` @@ -96,65 +96,49 @@ your own data under development. ### Preprocessing bertmesh ``` - Usage: grants-tagger preprocess mesh [OPTIONS] DATA_PATH SAVE_TO_PATH - MODEL_KEY - -╭─ Arguments ──────────────────────────────────────────────────────────────────────────────────────────────────────╮ -│ * data_path TEXT Path to mesh.jsonl [default: None] [required] │ -│ * save_to_path TEXT Path to save the serialized PyArrow dataset after preprocessing [default: None] │ -│ [required] │ -│ * model_key TEXT Key to use when loading tokenizer and label2id. Leave blank if training from │ -│ scratch │ -│ [default: None] │ -│ [required] │ -╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ -╭─ Options ────────────────────────────────────────────────────────────────────────────────────────────────────────╮ -│ --test-size FLOAT Fraction of data to use for testing (if less than 1) or number of rows │ -│ [default: 0.05] │ -│ --num-proc INTEGER Number of processes to use for preprocessing [default: 8] │ -│ --max-samples INTEGER Maximum number of samples to use for preprocessing [default: -1] │ -│ --batch-size INTEGER Size of the preprocessing batch [default: 256] │ -│ --train-years TEXT Comma-separated years you want to include in training (e.g: 2020,2021) │ -│ [default: None, meaning all years] │ -│ --test-years TEXT Comma-separated years you want to include in test (e.g: 2020,2021) │ -│ [default: None, meaning all years] │ -│ --tags TEXT Comma-separated tags you want to included (e.g: Pandemics,COVID19) │ -│ --help Show this message and exit. │ -╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯``` + Usage: grants-tagger preprocess mesh [OPTIONS] DATA_PATH SAVE_TO_PATH + MODEL_KEY + +╭─ Arguments ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ * data_path TEXT Path to mesh.jsonl [default: None] [required] │ +│ * save_to_path TEXT Path to save the serialized PyArrow dataset after preprocessing [default: None] [required] │ +│ * model_key TEXT Key to use when loading tokenizer and label2id. Leave blank if training from scratch [default: None] [required] │ +╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +╭─ Options ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ --test-size FLOAT Fraction of data to use for testing in (0,1] or number of rows [default: None] │ +│ --num-proc INTEGER Number of processes to use for preprocessing [default: 8] │ +│ --max-samples INTEGER Maximum number of samples to use for preprocessing [default: -1] │ +│ --batch-size INTEGER Size of the preprocessing batch [default: 256] │ +│ --tags TEXT Comma-separated tags you want to include in the dataset (the rest will be discarded) [default: None] │ +│ --train-years TEXT Comma-separated years you want to include in the training dataset [default: None] │ +│ --test-years TEXT Comma-separated years you want to include in the test dataset [default: None] │ +│ --help Show this message and exit. │ +╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ ``` ## 🔥 Train -Train acts as the entry point command for training all models. Currently, we only support -the BertMesh model. The command will train a model and save it to the specified path. +The command will train a model and save it to the specified path. Currently we support on BertMesh. ### bertmesh ``` - Usage: grants-tagger train bertmesh [OPTIONS] MODEL_KEY DATA_PATH - -╭─ Arguments ──────────────────────────────────────────────────────────────────────────────────────────────────────╮ -│ * model_key TEXT Pretrained model key. Local path or HF location [default: None] [required] │ -│ * data_path TEXT Path to allMeSH_2021.jsonl (or similar) or to a folder after preprocessing and saving │ -│ to disk │ -│ [default: None] │ -│ [required] │ -│ --shards INTEGER Number os shards to divide training IterativeDataset to (improves performance) │ -│ [default: -1, meaning no shards]. Recommended: os.cpu_count() │ -│ --num-proc INTEGER Number of processes to use for preprocessing [default: os.cpu_count()] │ -│ --help Show this message and exit. │ -╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ - -If you are running directly training without calling preprocess, you can specify the same parameters as preprocess: - -╭─ Options ────────────────────────────────────────────────────────────────────────────────────────────────────────╮ -│ --test-size FLOAT Fraction of data to use for testing [default: 0.05] │ -│ --max-samples INTEGER Maximum number of samples to use from the json [default: -1] │ -│ --train-years TEXT Comma-separated years you want to include in training (e.g: 2020,2021) │ -│ [default: None, meaning all years] │ -│ --test-years TEXT Comma-separated years you want to include in test (e.g: 2020,2021) │ -│ [default: None, meaning all years] │ -│ --tags TEXT Comma-separated tags you want to included (e.g: Pandemics,COVID19) │ -╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ + Usage: grants-tagger train bertmesh [OPTIONS] MODEL_KEY DATA_PATH + +╭─ Arguments ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ * model_key TEXT Pretrained model key. Local path or HF location [default: None] [required] │ +│ * data_path TEXT Path to allMeSH_2021.jsonl (or similar) or to a folder after preprocessing and saving to disk [default: None] [required] │ +╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +╭─ Options ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ --test-size FLOAT Fraction of data to use for testing (0,1] or number of rows [default: None] │ +│ --num-proc INTEGER Number of processes to use for preprocessing [default: 8] │ +│ --max-samples INTEGER Maximum number of samples to use from the json [default: -1] │ +│ --shards INTEGER Number os shards to divide training IterativeDataset to (improves performance) [default: 8] │ +│ --from-checkpoint TEXT Name of the checkpoint to resume training [default: None] │ +│ --tags TEXT Comma-separated tags you want to include in the dataset (the rest will be discarded) [default: None] │ +│ --train-years TEXT Comma-separated years you want to include in the training dataset [default: None] │ +│ --test-years TEXT Comma-separated years you want to include in the test dataset [default: None] │ +│ --help Show this message and exit. │ +╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ ``` #### About `model_key` @@ -168,30 +152,40 @@ to improve performance on big datasets. To enable it: - set shards to something bigger than 1 (Recommended: same number as cpu cores) #### Other arguments -Besides those arguments, feel free to add any other TrainingArgument from Hugging Face or Wand DB. Examples: +Besides those arguments, feel free to add any other TrainingArgument from Hugging Face or Wand DB. +This is the example used to train reaching a ~0.6 F1, also available at `examples/train_by_epochs.sh` ```commandline grants-tagger train bertmesh \ "" \ - data/raw/allMeSH_2021.jsonl \ - --test-size 0.005 \ - --shards 250 \ - --output_dir bertmesh_outs/pipeline_test/ \ - --per_device_train_batch_size 32 \ - --num_train_epochs 1 \ - --save_strategy steps \ - --save_steps 50000 \ + [YOUR_PREPROCESSED_FOLDER] \ + --output_dir [YOUR_OUTPUT_FOLDER] \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 1 \ + --multilabel_attention True \ + --freeze_backbone unfreeze \ + --num_train_epochs 7 \ + --learning_rate 5e-5 \ + --dropout 0.1 \ + --hidden_size 1024 \ + --warmup_steps 5000 \ + --max_grad_norm 2.0 \ + --scheduler_type cosine_hard_restart \ + --weight_decay 0.2 \ + --correct_bias True \ + --threshold 0.25 \ + --prune_labels_in_evaluation True \ + --hidden_dropout_prob 0.2 \ + --attention_probs_dropout_prob 0.2 \ --fp16 \ --torch_compile \ + --evaluation_strategy epoch \ + --eval_accumulation_steps 20 \ + --save_strategy epoch \ --wandb_project wellcome-mesh \ --wandb_name test-train-all \ - --wandb_api_key ${WANDB_API_KEY} \ - --per_device_eval_batch_size 8 \ - --eval_steps 50000 \ - --evaluation_strategy steps + --wandb_api_key ${WANDB_API_KEY} ``` - - ## 📈 Evaluate Evaluate enables evaluation of the performance of various approaches including diff --git a/examples/preprocess_and_train_by_steps.sh b/examples/preprocess_and_train_by_steps.sh index dd65a405..e6ad8cd7 100644 --- a/examples/preprocess_and_train_by_steps.sh +++ b/examples/preprocess_and_train_by_steps.sh @@ -1,7 +1,7 @@ # Run on g5.12xlargeinstance -# In that case, `test-size`, `train-years` and `test-years` will be taken from the preprocessed folder -SOURCE="output_folder_from_preprocessing" +# Without preprocessing (on-the-fly) +SOURCE="data/raw/allMeSH_2021.jsonl" grants-tagger train bertmesh \ "" \ diff --git a/examples/resume_train_by_epoch.sh b/examples/resume_train_by_epoch.sh index 261f9d57..41f8901b 100644 --- a/examples/resume_train_by_epoch.sh +++ b/examples/resume_train_by_epoch.sh @@ -1,10 +1,7 @@ # Run on g5.12xlarge instance -# Without preprocessing (on-the-fly) -SOURCE="data/raw/allMeSH_2021.jsonl" - -# After preprocessing first -# SOURCE="output_folder_from_preprocessing" +# After preprocessing +SOURCE="output_folder_from_preprocessing" # Checkpoint CHECKPOINT="checkpoint-100000" @@ -17,7 +14,7 @@ grants-tagger train bertmesh \ --per_device_eval_batch_size 1 \ --multilabel_attention True \ --freeze_backbone unfreeze \ - --num_train_epochs 5 \ + --num_train_epochs 3 \ --learning_rate 5e-5 \ --dropout 0.1 \ --hidden_size 1024 \ diff --git a/examples/resume_train_by_steps.sh b/examples/resume_train_by_steps.sh index 831fbd22..43a31d95 100644 --- a/examples/resume_train_by_steps.sh +++ b/examples/resume_train_by_steps.sh @@ -1,10 +1,7 @@ # Run on g5.12xlarge instance -# Without preprocessing (on-the-fly) -SOURCE="data/raw/allMeSH_2021.jsonl" - -# After preprocessing first -# SOURCE="output_folder_from_preprocessing" +# After preprocessing +SOURCE="output_folder_from_preprocessing" # Checkpoint CHECKPOINT="checkpoint-100000" @@ -17,7 +14,7 @@ grants-tagger train bertmesh \ --per_device_eval_batch_size 1 \ --multilabel_attention True \ --freeze_backbone unfreeze \ - --num_train_epochs 5 \ + --num_train_epochs 3 \ --learning_rate 5e-5 \ --dropout 0.1 \ --hidden_size 1024 \ diff --git a/examples/train_by_epochs.sh b/examples/train_by_epochs.sh index 51686cdf..4cf4f8f6 100644 --- a/examples/train_by_epochs.sh +++ b/examples/train_by_epochs.sh @@ -1,6 +1,6 @@ # Run on g5.12xlargeinstance -# In this case, `test-size`, `train-years` and `test-years` will be taken from the preprocessed folder +# After preprocessing SOURCE="output_folder_from_preprocessing" @@ -12,7 +12,7 @@ grants-tagger train bertmesh \ --per_device_eval_batch_size 1 \ --multilabel_attention True \ --freeze_backbone unfreeze \ - --num_train_epochs 5 \ + --num_train_epochs 7 \ --learning_rate 5e-5 \ --dropout 0.1 \ --hidden_size 1024 \ diff --git a/examples/train_by_steps.sh b/examples/train_by_steps.sh index a114970f..eb23036e 100644 --- a/examples/train_by_steps.sh +++ b/examples/train_by_steps.sh @@ -1,8 +1,8 @@ # Run on g5.12xlargeinstance -# If you have already preprocessed the data, you will have a folder. Use the folder instead. +# After preprocessing SOURCE="output_folder_from_preprocessing" -# In that case, `test-size`, `train-years` and `test-years` will be taken from the preprocessed folder + grants-tagger train bertmesh \ "" \ @@ -12,7 +12,7 @@ grants-tagger train bertmesh \ --per_device_eval_batch_size 1 \ --multilabel_attention True \ --freeze_backbone unfreeze \ - --num_train_epochs 5 \ + --num_train_epochs 7 \ --learning_rate 5e-5 \ --dropout 0.1 \ --hidden_size 1024 \ diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 9382ea1d..0330f7e8 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -96,6 +96,9 @@ def preprocess_mesh( logger.info(f"Removing all years which are not in {years}") dset = dset.filter(lambda x: any(np.isin(years, [str(x["year"])])), num_proc=num_proc) + if tags is None: + tags = [] + if len(tags) > 0: logger.info(f"Removing all tags which are not in {tags}") dset = dset.filter(lambda x: any(np.isin(tags, x["meshMajor"])), num_proc=num_proc) diff --git a/grants_tagger_light/training/cli_args/bertmesh_args.py b/grants_tagger_light/training/cli_args/bertmesh_args.py index f96426ac..262ffe5c 100644 --- a/grants_tagger_light/training/cli_args/bertmesh_args.py +++ b/grants_tagger_light/training/cli_args/bertmesh_args.py @@ -6,9 +6,9 @@ class BertMeshModelArguments: pretrained_model_key: str = field( default="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract" ) - hidden_size: int = field(default=512) - dropout: float = field(default=0) - multilabel_attention: bool = field(default=False) - freeze_backbone: str = field(default=None) # unfreeze, unfreeze_bias, freeze - hidden_dropout_prob: float = field(default=0.1) - attention_probs_dropout_prob: float = field(default=0.1) + hidden_size: int = field(default=1024) + dropout: float = field(default=0.1) + multilabel_attention: bool = field(default=True) + freeze_backbone: str = field(default="unfreeze") # unfreeze, unfreeze_bias, freeze + hidden_dropout_prob: float = field(default=0.2) + attention_probs_dropout_prob: float = field(default=0.2) diff --git a/grants_tagger_light/training/cli_args/train_args.py b/grants_tagger_light/training/cli_args/train_args.py index 2a8467f1..e8bc1b4d 100644 --- a/grants_tagger_light/training/cli_args/train_args.py +++ b/grants_tagger_light/training/cli_args/train_args.py @@ -66,7 +66,12 @@ class BertMeshTrainingArguments(TrainingArguments): weight_decay: float = field(default=0.1) prune_labels_in_evaluation: bool = field(default=False) threshold: float = field(default=0.5) - scheduler_type: str = field(default="cosine") + scheduler_type: str = field(default="cosine_hard_restart") + save_steps: int = field(default=500) + eval_steps: int = field(default=None) + max_steps: int = field(default=-1) + no_cuda: bool = field(default=False) + warmup_steps: int = field(default=0) def __post_init__(self): super().__post_init__() diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 70032b37..1c5035f5 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -193,7 +193,7 @@ def sklearn_metrics(prediction: EvalPrediction): model.parameters(), lr=training_args.learning_rate, weight_decay=training_args.weight_decay, - correct_bias=training_args.correct_bias) + correct_bias=training_args.correct_bias if hasattr(training_args, 'correct_bias') else True) if training_args.warmup_steps is None: training_args.warmup_steps = 0 @@ -226,6 +226,8 @@ def sklearn_metrics(prediction: EvalPrediction): training_args.optim = optimizer training_args.lr_scheduler_type = scheduler + logger.info(f"Test dataset size: {len(val_dset)}") + trainer = Trainer( model=model, args=training_args, diff --git a/tests/test_preprocess_mesh.py b/tests/test_preprocess_mesh.py index a984b9a6..d0e11a2d 100644 --- a/tests/test_preprocess_mesh.py +++ b/tests/test_preprocess_mesh.py @@ -91,7 +91,7 @@ def test_json_to_jsonl(json_data_path): def test_preprocess_mesh(jsonl_data_path): - dset, label2id = preprocess_mesh( + dset, label2id, id2label = preprocess_mesh( data_path=jsonl_data_path, model_key="", num_proc=2, batch_size=1, test_size=0.5 ) assert "train" in dset diff --git a/tests/test_split_data.py b/tests/test_split_data.py index c4e9ba6d..a225b63b 100644 --- a/tests/test_split_data.py +++ b/tests/test_split_data.py @@ -27,12 +27,12 @@ def test_split_data(): examples = 0 with open(train_output_path) as f: - for line in f: + for _ in f: examples += 1 assert examples == 9 examples = 0 with open(test_output_path) as f: - for line in f: + for _ in f: examples += 1 assert examples == 1 diff --git a/tests/test_train.py b/tests/test_train.py index 42e49159..e4a1afaa 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -1,11 +1,12 @@ from grants_tagger_light.training.train import train_bertmesh -from grants_tagger_light.training.cli_args import BertMeshModelArguments -from transformers import TrainingArguments +from grants_tagger_light.training.cli_args import BertMeshModelArguments, BertMeshTrainingArguments import tempfile import pytest # Note dummy data is not necessarily annotated correctly dummy_data = """{"journal":"dummyJournal","meshMajor":["COVID-19","SARS-CoV-2"],"year":"2023","abstractText":"This is an article about coronavirus.","title":"article1","pmid":"pmid1"} +{"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria", "title": "article3", "pmid": "pmid3"} +{"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria", "title": "article3", "pmid": "pmid3"} {"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria", "title": "article3", "pmid": "pmid3"}""" # noqa @@ -26,17 +27,16 @@ def save_path(): def _train_bertmesh_from_model_key(data_path, save_path, model_key): # 1 train step, 1 eval step, save after training - training_args = TrainingArguments( + training_args = BertMeshTrainingArguments( output_dir=save_path, max_steps=1, per_device_train_batch_size=2, per_device_eval_batch_size=2, - evaluation_strategy="steps", - eval_steps=1, - save_strategy="steps", - save_steps=1, + evaluation_strategy="no", + save_strategy="no", report_to="none", no_cuda=True, + num_train_epochs=1, ) model_args = BertMeshModelArguments() @@ -47,8 +47,9 @@ def _train_bertmesh_from_model_key(data_path, save_path, model_key): max_samples=-1, training_args=training_args, model_args=model_args, - num_proc=2, + num_proc=1, test_size=0.5, + shards=1 ) From 8b080ebbe665937b6d9f269134aeaf7202a10c53 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Thu, 31 Aug 2023 17:42:06 +0100 Subject: [PATCH 288/300] Tries to fix torch-cpu recent issue --- pyproject.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 574a313a..edcc4bf9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,9 @@ typer = "^0.9.0" datasets = "2.13.1" accelerate = "^0.19.0" dvc = {extras = ["s3"], version = "^2.58.2"} -torch = {version = "2.0.1", source = "torch-cpu"} +# torch-cpu is not available anymore! +# torch = {version = "2.0.1", source = "torch-cpu"} +torch = "2.0.1" transformers = "4.29.2" libpecos = "^1.0.0" loguru = "^0.7.0" From 52231e042c8dbedf1e3be520b792cbd369667727 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 1 Sep 2023 10:45:37 +0100 Subject: [PATCH 289/300] Refactors augmentation --- .../augmentation/augment_openai.py | 248 ++++++++++-------- .../augmentation/parallel_augment_openai.py | 79 ++++++ tests/test_train.py | 9 +- 3 files changed, 226 insertions(+), 110 deletions(-) create mode 100644 grants_tagger_light/augmentation/parallel_augment_openai.py diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 8a36e7ed..b3e6481f 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -1,26 +1,26 @@ import datetime import json -import math import os import random import uuid -from loguru import logger import openai -from openai_multi_client import OpenAIMultiClient + +from loguru import logger import numpy as np from grants_tagger_light.augmentation.JsonParser import JsonParser class AugmentOpenAI: - def __init__(self, prompt_template_path, model_key='gpt-3.5-turbo'): - if 'OPENAI_API_KEY' not in os.environ: - logger.error("OPENAI_API_KEY not found in env vars. Please define it before running this program.") - with open(prompt_template_path, 'r') as f: + def __init__(self, prompt_template_path, model_key="gpt-3.5-turbo"): + if "OPENAI_API_KEY" not in os.environ: + logger.error( + "OPENAI_API_KEY not found in env vars. Please define it before running this program." + ) + with open(prompt_template_path, "r") as f: self.prompt_template = f.read() self.model_key = model_key - self.api = OpenAIMultiClient(endpoint="chats", data_template={"model": self.model_key}) def _create_message(self, abstract, tag): prompt = self.prompt_template.replace('{TOPIC}', tag) @@ -29,104 +29,142 @@ def _create_message(self, abstract, tag): return [{"role": "user", "content": prompt}] @staticmethod - def _process_response(result): - if result.failed: - logger.warning(f"Failed to get augmentation for {result.metadata['featured_tag']}") - return - - with open(result.metadata['save_to_path'], 'a') as f: - for r in result.response['choices']: - if 'message' in r: - if 'content' in r['message']: - try: - logger.info(r['message']['content']) - try: - json_response = JsonParser.parse_json(r['message']['content']) - except Exception as e: - logger.info(f"Error processing output: {e}. Skipping...") - continue - - res = { - "journal": result.metadata['model_key'], - "meshMajor": result.metadata['tags'], - "year": result.metadata['year'], - "abstractText": json_response['abstract'].replace("'", "").replace('"', ''), - "pmid": uuid.uuid4().hex, - "title": json_response['title'].replace("'", "").replace('"', ''), - "existing_example": result.metadata['existing_example'].replace("'", "").replace('"', ''), - "required_examples": result.metadata['required_examples'], - "featured_tag": result.metadata['featured_tag'] - } - - f.write(json.dumps(res)) - f.write('\n') - f.flush() - - logger.info(f"Data received successfully for {result.metadata['featured_tag']}") - except: - logger.info(f"Missing or malformed data. Skipping") - - def _make_requests(self, - collect_concurrent_calls, - dset, - temperature, - top_p, - presence_penalty, - num_proc, - model_key, - save_to_path): + def _parse_response(answer, metadata): + print(json.dumps(answer, indent=2)) + with open(metadata["save_to_path"], "a") as f: + try: + json_response = JsonParser.parse_json(answer) + except Exception as e: + logger.info(f"Error processing output: {e}. Skipping...") + return + + res = { + "journal": metadata['model_key'], + "meshMajor": metadata['tags'], + "year": metadata['year'], + "abstractText": json_response['abstract'].replace("'", "").replace('"', ''), + "pmid": uuid.uuid4().hex, + "title": json_response['title'].replace("'", "").replace('"', ''), + "existing_example": metadata['existing_example'].replace("'", "").replace('"', ''), + "required_examples": metadata['required_examples'], + "featured_tag": metadata['featured_tag'] + } + + f.write(json.dumps(res)) + f.write('\n') + f.flush() + + logger.info(f"Data received successfully for {metadata['featured_tag']}") + + @staticmethod + def process_choices(choices, metadata): + for c in choices: + if "message" in c: + if "content" in c["message"]: + AugmentOpenAI._parse_response(c["message"]["content"], metadata) + @staticmethod + def _process_response(result): + AugmentOpenAI.process_choices(result.choices, result.metadata) + + def _prepare_request( + self, + tag, + missing_num, + dset, + temperature, + top_p, + presence_penalty, + num_proc, + model_key, + save_to_path, + ): year = datetime.date.today().year + logger.info(f"Augmenting {tag} with {missing_num} examples") + # RAG: I select similar articles to provide them to the LLM + + tmp_dset = dset.filter(lambda x: any(np.isin([tag], x["meshMajor"])), num_proc=num_proc) + + abstracts_num = [i for i in range(len(tmp_dset))] + random.shuffle(abstracts_num) + + for i in range(missing_num): + selected_row = abstracts_num[i % len(tmp_dset)] + abstract = tmp_dset['abstractText'][selected_row] + tags = tmp_dset['meshMajor'][selected_row] + data = { + "model": self.model_key, + "n": 1, + "temperature": temperature, + "top_p": top_p, + "presence_penalty": presence_penalty, + "messages": self._create_message(abstract, tag) + } + + metadata = { + 'featured_tag': tag, + 'tags': tags, + 'required_examples': missing_num, + 'existing_example': abstract, + 'year': year, + 'model_key': model_key, + 'save_to_path': save_to_path + } + + yield data, metadata + + def _make_requests( + self, + collect_concurrent_calls, + dset, + temperature, + top_p, + presence_penalty, + num_proc, + model_key, + save_to_path, + ): + for num in range(len(collect_concurrent_calls)): - t = collect_concurrent_calls[num][0] - n = collect_concurrent_calls[num][1] - # RAG: I select similar articles to provide them to the LLM, maximum ´n´ (I don't need more) - tmp_dset = dset.filter(lambda x: any(np.isin([t], x["meshMajor"])), num_proc=num_proc) - # I remove them from the dataset to process to make it smaller and quicker over time - # dset = dset.filter(lambda example, idx: idx not in tmp_dset['idx'], with_indices=True, - # num_proc=num_proc) - - required_examples = n - existing_examples = min(required_examples, len(tmp_dset)) - - n_per_example = math.ceil(required_examples / existing_examples) - logger.info(f"Augmenting {t} with {required_examples} examples, using {existing_examples} in RAG mode") - for i in range(existing_examples): - - abstract = tmp_dset['abstractText'][i] - tags = tmp_dset['meshMajor'][i] - data = { - "model": self.model_key, - "n": n_per_example, - "temperature": temperature, - "top_p": top_p, - "presence_penalty": presence_penalty, - "messages": self._create_message(abstract, t) - } - - metadata = { - 'featured_tag': t, - 'tags': tags, - 'required_examples': n, - 'existing_example': abstract, - 'year': year, - 'model_key': model_key, - 'save_to_path': save_to_path - } - - self.api.request(data=data, metadata=metadata, callback=self._process_response) - - def generate(self, collect_concurrent_calls, dset, save_to_path, model_key, - temperature=1.5, top_p=1, presence_penalty=0, num_proc=os.cpu_count()): - self.api.run_request_function(self._make_requests, - collect_concurrent_calls=collect_concurrent_calls, - dset=dset, - temperature=temperature, - top_p=top_p, - presence_penalty=presence_penalty, - num_proc=num_proc, - model_key=model_key, - save_to_path=save_to_path) - - self.api.pull_all() + tag = collect_concurrent_calls[num][0] + missing_num = collect_concurrent_calls[num][1] + + for data, metadata in self._prepare_request( + tag, + missing_num, + dset, + temperature, + top_p, + presence_penalty, + num_proc, + model_key, + save_to_path, + ): + chat_completion = openai.ChatCompletion.create(**data) + chat_completion.metadata = metadata + + self._process_response(chat_completion) + + def generate( + self, + collect_concurrent_calls, + dset, + save_to_path, + model_key, + temperature=1.5, + top_p=1, + presence_penalty=0, + num_proc=os.cpu_count(), + ): + self._make_requests( + collect_concurrent_calls=collect_concurrent_calls, + dset=dset, + temperature=temperature, + top_p=top_p, + presence_penalty=presence_penalty, + num_proc=num_proc, + model_key=model_key, + save_to_path=save_to_path, + ) + diff --git a/grants_tagger_light/augmentation/parallel_augment_openai.py b/grants_tagger_light/augmentation/parallel_augment_openai.py new file mode 100644 index 00000000..4d9088d3 --- /dev/null +++ b/grants_tagger_light/augmentation/parallel_augment_openai.py @@ -0,0 +1,79 @@ +import os + +from loguru import logger +from openai_multi_client import OpenAIMultiClient + +from grants_tagger_light.augmentation.augment_openai import AugmentOpenAI + + +class ParallelAugmentOpenAI(AugmentOpenAI): + def __init__(self, prompt_template_path, model_key="gpt-3.5-turbo"): + super().__init__(prompt_template_path, model_key) + self.api = OpenAIMultiClient( + endpoint="chats", data_template={"model": self.model_key} + ) + + @staticmethod + def _process_response(result): + if result.failed: + logger.warning( + f"Failed to get augmentation for {result.metadata['featured_tag']}" + ) + return + choices = result.response["choices"] + AugmentOpenAI.process_choices(choices, result.metadata) + + def _make_requests( + self, + collect_concurrent_calls, + dset, + temperature, + top_p, + presence_penalty, + num_proc, + model_key, + save_to_path, + ): + for num in range(len(collect_concurrent_calls)): + tag = collect_concurrent_calls[num][0] + missing_num = collect_concurrent_calls[num][1] + + for data, metadata in self._prepare_request( + tag, + missing_num, + dset, + temperature, + top_p, + presence_penalty, + num_proc, + model_key, + save_to_path, + ): + self.api.request( + data=data, metadata=metadata, callback=self._process_response + ) + + def generate( + self, + collect_concurrent_calls, + dset, + save_to_path, + model_key, + temperature=1.5, + top_p=1, + presence_penalty=0, + num_proc=os.cpu_count(), + ): + self.api.run_request_function( + self._make_requests, + collect_concurrent_calls=collect_concurrent_calls, + dset=dset, + temperature=temperature, + top_p=top_p, + presence_penalty=presence_penalty, + num_proc=num_proc, + model_key=model_key, + save_to_path=save_to_path, + ) + + self.api.pull_all() diff --git a/tests/test_train.py b/tests/test_train.py index e4a1afaa..aa7374a2 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -5,15 +5,13 @@ # Note dummy data is not necessarily annotated correctly dummy_data = """{"journal":"dummyJournal","meshMajor":["COVID-19","SARS-CoV-2"],"year":"2023","abstractText":"This is an article about coronavirus.","title":"article1","pmid":"pmid1"} -{"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria", "title": "article3", "pmid": "pmid3"} -{"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria", "title": "article3", "pmid": "pmid3"} {"journal":"dummyJournal","meshMajor":["Malaria"],"year":"2023","abstractText":"This is an article about malaria", "title": "article3", "pmid": "pmid3"}""" # noqa @pytest.fixture def data_path(): with tempfile.TemporaryDirectory() as tmpdirname: - data_path = tmpdirname + "/data.json" + data_path = tmpdirname + "/data.jsonl" with open(data_path, "w") as f: f.write(dummy_data) yield data_path @@ -37,6 +35,7 @@ def _train_bertmesh_from_model_key(data_path, save_path, model_key): report_to="none", no_cuda=True, num_train_epochs=1, + dataloader_num_workers=1 ) model_args = BertMeshModelArguments() @@ -57,5 +56,5 @@ def test_train_bertmesh_from_model_key(data_path, save_path): _train_bertmesh_from_model_key(data_path, save_path, "Wellcome/WellcomeBertMesh") -def test_train_bertmesh_from_scratch(data_path, save_path): - _train_bertmesh_from_model_key(data_path, save_path, "") +"""def test_train_bertmesh_from_scratch(data_path, save_path): + _train_bertmesh_from_model_key(data_path, save_path, "")""" From 69b56ea8484948a1e36ccc8e000a5c840f5268d3 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 1 Sep 2023 12:10:13 +0100 Subject: [PATCH 290/300] Refactors augmentation --- grants_tagger_light/augmentation/augment.py | 24 +++++++++++++++---- .../augmentation/augment_openai.py | 20 ++++++++++------ 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index b0950146..9082435c 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -11,6 +11,8 @@ from datasets import load_from_disk +from grants_tagger_light.augmentation.parallel_augment_openai import ParallelAugmentOpenAI + augment_app = typer.Typer() @@ -78,7 +80,10 @@ def augment( with open(f"{save_to_path}.count", 'w') as f: f.write(json.dumps(sorted_merged_element_counts_dict, indent=2)) - tags_to_augment = list(sorted_merged_element_counts_dict.keys()) + tags_to_augment_counts = { + k: v for k, v in sorted_merged_element_counts_dict.items() if v < min_examples + } + tags_to_augment = list(tags_to_augment_counts.keys()) biggest_tags_to_augment = [f"{k}({sorted_merged_element_counts_dict[k]})" for k in tags_to_augment[:5]] @@ -98,10 +103,17 @@ def augment( desc="Creating idx", num_proc=num_proc, ) + + if concurrent_calls == 1: + openai = AugmentOpenAI + else: + openai = ParallelAugmentOpenAI + collect_concurrent_calls = [] + for t in tags_to_augment: if len(collect_concurrent_calls) >= concurrent_calls: - AugmentOpenAI(prompt_template_path=prompt_template, model_key=model_key).generate( + openai(prompt_template_path=prompt_template, model_key=model_key).generate( collect_concurrent_calls, dset, save_to_path, @@ -113,8 +125,9 @@ def augment( else: collect_concurrent_calls.append((t, examples)) + # Remaining rows of the last batch if len(collect_concurrent_calls) > 0: - AugmentOpenAI(prompt_template_path=prompt_template, model_key=model_key).generate( + openai(prompt_template_path=prompt_template, model_key=model_key).generate( collect_concurrent_calls, dset, save_to_path, @@ -158,11 +171,14 @@ def augment_cli( ), concurrent_calls: int = typer.Option( os.cpu_count()*2, + min=1, help="Concurrent calls with 1 tag each to the different model" ), temperature: float = typer.Option( 1.5, - help="A value between -2 and 2. The bigger - the more creative." + min=0, + max=2, + help="A value between 0 and 2. The bigger - the more creative." ), tags_file_path: str = typer.Option( None, diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index b3e6481f..b02db140 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -1,5 +1,6 @@ import datetime import json +import math import os import random import uuid @@ -86,16 +87,21 @@ def _prepare_request( tmp_dset = dset.filter(lambda x: any(np.isin([tag], x["meshMajor"])), num_proc=num_proc) - abstracts_num = [i for i in range(len(tmp_dset))] - random.shuffle(abstracts_num) + # abstracts_num = [i for i in range(len(tmp_dset))] + # random.shuffle(abstracts_num) - for i in range(missing_num): - selected_row = abstracts_num[i % len(tmp_dset)] - abstract = tmp_dset['abstractText'][selected_row] - tags = tmp_dset['meshMajor'][selected_row] + required_examples = missing_num + existing_examples = min(required_examples, len(tmp_dset)) + + n_per_example = math.ceil(required_examples / existing_examples) + logger.info(f"Augmenting {tag} with {required_examples} examples, using {existing_examples} in RAG mode") + + for i in range(existing_examples): + abstract = tmp_dset['abstractText'][i] + tags = tmp_dset['meshMajor'][i] data = { "model": self.model_key, - "n": 1, + "n": n_per_example, "temperature": temperature, "top_p": top_p, "presence_penalty": presence_penalty, From 77761be7a3204b01856eac75759f8bc470016af4 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 1 Sep 2023 12:18:43 +0100 Subject: [PATCH 291/300] Refactors augmentation --- grants_tagger_light/augmentation/augment.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 9082435c..c89289cf 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -80,10 +80,7 @@ def augment( with open(f"{save_to_path}.count", 'w') as f: f.write(json.dumps(sorted_merged_element_counts_dict, indent=2)) - tags_to_augment_counts = { - k: v for k, v in sorted_merged_element_counts_dict.items() if v < min_examples - } - tags_to_augment = list(tags_to_augment_counts.keys()) + tags_to_augment = list(sorted_merged_element_counts_dict.keys()) biggest_tags_to_augment = [f"{k}({sorted_merged_element_counts_dict[k]})" for k in tags_to_augment[:5]] From d6247dbd0e565d36fc3b4296658d198529c28a1c Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 1 Sep 2023 14:23:30 +0100 Subject: [PATCH 292/300] Fixes tests --- examples/augment.sh | 3 +- examples/augment_specific_tags.sh | 4 +-- examples/preprocess_and_train_by_epochs.sh | 30 +++++++++++-------- examples/preprocess_and_train_by_steps.sh | 25 +++++++++++----- examples/preprocess_splitting_by_fract.sh | 2 +- examples/preprocess_splitting_by_rows.sh | 4 +-- examples/preprocess_splitting_by_years.sh | 3 +- examples/resume_train_by_epoch.sh | 4 +-- examples/resume_train_by_steps.sh | 4 +-- examples/train_by_epochs.sh | 5 ++-- examples/train_by_steps.sh | 5 ++-- grants_tagger_light/augmentation/augment.py | 1 - .../preprocessing/preprocess_mesh.py | 4 +-- tests/test_train.py | 4 +-- 14 files changed, 56 insertions(+), 42 deletions(-) diff --git a/examples/augment.sh b/examples/augment.sh index cb85617a..9ad482b1 100644 --- a/examples/augment.sh +++ b/examples/augment.sh @@ -1,2 +1,3 @@ -grants-tagger augment mesh [FOLDER_AFTER_PREPROCESSING] [OUTPUT_FOLDER] \ +grants-tagger augment mesh [FOLDER_AFTER_PREPROCESSING] [SET_YOUR_OUTPUT_FOLDER_HERE] \ + --min-examples 25 \ --concurrent-calls 25 \ No newline at end of file diff --git a/examples/augment_specific_tags.sh b/examples/augment_specific_tags.sh index 1c504d61..3ce920c8 100644 --- a/examples/augment_specific_tags.sh +++ b/examples/augment_specific_tags.sh @@ -1,5 +1,5 @@ # Augments data using a file with 1 label per line and years -grants-tagger augment mesh [FOLDER_AFTER_PREPROCESSING] [OUTPUT_FOLDER] \ - --tags-file-path [YOUR_TAGS_FILE] \ +grants-tagger augment mesh [FOLDER_AFTER_PREPROCESSING] [SET_YOUR_OUTPUT_FOLDER_HERE] \ + --tags-file-path tags_to_augment.txt \ --examples 25 \ --concurrent-calls 25 \ No newline at end of file diff --git a/examples/preprocess_and_train_by_epochs.sh b/examples/preprocess_and_train_by_epochs.sh index e706d05f..0e2f2f29 100644 --- a/examples/preprocess_and_train_by_epochs.sh +++ b/examples/preprocess_and_train_by_epochs.sh @@ -1,31 +1,37 @@ -# Run on g5.12xlargeinstance +# Run on g5.12xlarge instance -# Without preprocessing (on-the-fly) +# Without saving (on-the-fly) SOURCE="data/raw/allMeSH_2021.jsonl" grants-tagger train bertmesh \ "" \ $SOURCE \ - --test-size 10000 \ - --output_dir bertmesh_outs/pipeline_test/ \ + --test-size 25000 \ --train-years 2016,2017,2018,2019 \ --test-years 2020,2021 \ - --per_device_train_batch_size 32 \ + --output_dir bertmesh_outs/pipeline_test/ \ + --per_device_train_batch_size 16 \ --per_device_eval_batch_size 1 \ --multilabel_attention True \ - --freeze_backbone False \ - --num_train_epochs 5 \ + --freeze_backbone unfreeze \ + --num_train_epochs 7 \ --learning_rate 5e-5 \ --dropout 0.1 \ --hidden_size 1024 \ - --warmup_steps 1000 \ - --max_grad_norm 5.0 \ - --scheduler-type cosine \ + --warmup_steps 5000 \ + --max_grad_norm 2.0 \ + --scheduler_type cosine_hard_restart \ + --weight_decay 0.2 \ + --correct_bias True \ + --threshold 0.25 \ + --prune_labels_in_evaluation True \ + --hidden_dropout_prob 0.2 \ + --attention_probs_dropout_prob 0.2 \ --fp16 \ --torch_compile \ - --evaluation_strategy epoch \ + --evaluation_strategy epochs \ --eval_accumulation_steps 20 \ - --save_strategy epoch \ + --save_strategy epochs \ --wandb_project wellcome-mesh \ --wandb_name test-train-all \ --wandb_api_key ${WANDB_API_KEY} diff --git a/examples/preprocess_and_train_by_steps.sh b/examples/preprocess_and_train_by_steps.sh index e6ad8cd7..83ec7478 100644 --- a/examples/preprocess_and_train_by_steps.sh +++ b/examples/preprocess_and_train_by_steps.sh @@ -1,23 +1,32 @@ -# Run on g5.12xlargeinstance +# Run on g5.12xlarge instance -# Without preprocessing (on-the-fly) +# Without saving (on-the-fly) SOURCE="data/raw/allMeSH_2021.jsonl" grants-tagger train bertmesh \ "" \ $SOURCE \ + --test-size 25000 \ + --train-years 2016,2017,2018,2019 \ + --test-years 2020,2021 \ --output_dir bertmesh_outs/pipeline_test/ \ - --per_device_train_batch_size 32 \ + --per_device_train_batch_size 16 \ --per_device_eval_batch_size 1 \ --multilabel_attention True \ - --freeze_backbone False \ - --num_train_epochs 5 \ + --freeze_backbone unfreeze \ + --num_train_epochs 7 \ --learning_rate 5e-5 \ --dropout 0.1 \ --hidden_size 1024 \ - --warmup_steps 1000 \ - --max_grad_norm 5.0 \ - --scheduler-type cosine \ + --warmup_steps 5000 \ + --max_grad_norm 2.0 \ + --scheduler_type cosine_hard_restart \ + --weight_decay 0.2 \ + --correct_bias True \ + --threshold 0.25 \ + --prune_labels_in_evaluation True \ + --hidden_dropout_prob 0.2 \ + --attention_probs_dropout_prob 0.2 \ --fp16 \ --torch_compile \ --evaluation_strategy steps \ diff --git a/examples/preprocess_splitting_by_fract.sh b/examples/preprocess_splitting_by_fract.sh index 9cf761d3..526133f2 100644 --- a/examples/preprocess_splitting_by_fract.sh +++ b/examples/preprocess_splitting_by_fract.sh @@ -1,2 +1,2 @@ -grants-tagger preprocess mesh data/raw/allMeSH_2021.jsonl ./kk '' \ +grants-tagger preprocess mesh data/raw/allMeSH_2021.jsonl [SET_YOUR_OUTPUT_FOLDER_HERE] '' \ --test-size 0.05 \ No newline at end of file diff --git a/examples/preprocess_splitting_by_rows.sh b/examples/preprocess_splitting_by_rows.sh index eb1e5dae..42ba15a4 100644 --- a/examples/preprocess_splitting_by_rows.sh +++ b/examples/preprocess_splitting_by_rows.sh @@ -1,2 +1,2 @@ -grants-tagger preprocess mesh data/raw/allMeSH_2021.jsonl ./kk '' \ - --test-size 10000 \ No newline at end of file +grants-tagger preprocess mesh data/raw/allMeSH_2021.jsonl [SET_YOUR_OUTPUT_FOLDER_HERE] '' \ + --test-size 25000 \ No newline at end of file diff --git a/examples/preprocess_splitting_by_years.sh b/examples/preprocess_splitting_by_years.sh index 6e2b0949..629e74fa 100644 --- a/examples/preprocess_splitting_by_years.sh +++ b/examples/preprocess_splitting_by_years.sh @@ -1,3 +1,4 @@ -grants-tagger preprocess mesh data/raw/allMeSH_2021.jsonl ./kk '' \ +grants-tagger preprocess mesh data/raw/allMeSH_2021.jsonl [SET_YOUR_OUTPUT_FOLDER_HERE] '' \ + --test-size 25000 \ --train-years 2016,2017,2018,2019 \ --test-years 2020,2021 \ No newline at end of file diff --git a/examples/resume_train_by_epoch.sh b/examples/resume_train_by_epoch.sh index 41f8901b..38b520b4 100644 --- a/examples/resume_train_by_epoch.sh +++ b/examples/resume_train_by_epoch.sh @@ -1,7 +1,7 @@ # Run on g5.12xlarge instance # After preprocessing -SOURCE="output_folder_from_preprocessing" +SOURCE="[SET_YOUR_PREPROCESSING_FOLDER_HERE]" # Checkpoint CHECKPOINT="checkpoint-100000" @@ -18,7 +18,7 @@ grants-tagger train bertmesh \ --learning_rate 5e-5 \ --dropout 0.1 \ --hidden_size 1024 \ - --warmup_steps 5000 \ + --warmup_steps 0 \ --max_grad_norm 2.0 \ --scheduler_type cosine_hard_restart \ --weight_decay 0.2 \ diff --git a/examples/resume_train_by_steps.sh b/examples/resume_train_by_steps.sh index 43a31d95..3c251cf3 100644 --- a/examples/resume_train_by_steps.sh +++ b/examples/resume_train_by_steps.sh @@ -1,7 +1,7 @@ # Run on g5.12xlarge instance # After preprocessing -SOURCE="output_folder_from_preprocessing" +SOURCE="[SET_YOUR_PREPROCESSING_FOLDER_HERE]" # Checkpoint CHECKPOINT="checkpoint-100000" @@ -18,7 +18,7 @@ grants-tagger train bertmesh \ --learning_rate 5e-5 \ --dropout 0.1 \ --hidden_size 1024 \ - --warmup_steps 5000 \ + --warmup_steps 0 \ --max_grad_norm 2.0 \ --scheduler_type cosine_hard_restart \ --weight_decay 0.2 \ diff --git a/examples/train_by_epochs.sh b/examples/train_by_epochs.sh index 4cf4f8f6..bf06afac 100644 --- a/examples/train_by_epochs.sh +++ b/examples/train_by_epochs.sh @@ -1,8 +1,7 @@ -# Run on g5.12xlargeinstance +# Run on g5.12xlarge instance # After preprocessing -SOURCE="output_folder_from_preprocessing" - +SOURCE="[SET_YOUR_PREPROCESSING_FOLDER_HERE]" grants-tagger train bertmesh \ "" \ diff --git a/examples/train_by_steps.sh b/examples/train_by_steps.sh index eb23036e..67cabcd6 100644 --- a/examples/train_by_steps.sh +++ b/examples/train_by_steps.sh @@ -1,8 +1,7 @@ -# Run on g5.12xlargeinstance +# Run on g5.12xlarge instance # After preprocessing -SOURCE="output_folder_from_preprocessing" - +SOURCE="[SET_YOUR_PREPROCESSING_FOLDER_HERE]" grants-tagger train bertmesh \ "" \ diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index c89289cf..bdafccef 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -56,7 +56,6 @@ def augment( dset = load_from_disk(os.path.join(data_path, "dataset")) if "train" in dset: dset = dset["train"] - logger.info("Obtaining count values from the labels...") pool = multiprocessing.Pool(processes=num_proc) element_counts_list = pool.map(_count_elements_in_sublist, dset['meshMajor']) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 0330f7e8..0e55b23e 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -122,7 +122,7 @@ def preprocess_mesh( if label2id is None: logger.info("Getting the labels...") dset = dset.map( - lambda x: {"labels": x["meshMajor"]}, + lambda x: {"meshMajor": x["meshMajor"]}, batched=True, batch_size=batch_size, num_proc=num_proc, @@ -134,7 +134,7 @@ def preprocess_mesh( logger.info("Obtaining unique values from the labels...") # Iterate through the lists and add elements to the set - for arr in tqdm(dset["labels"]): + for arr in tqdm(dset["meshMajor"]): unique_labels_set.update(arr) # Most efficient way to do dictionary creation diff --git a/tests/test_train.py b/tests/test_train.py index aa7374a2..512c2ea9 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -56,5 +56,5 @@ def test_train_bertmesh_from_model_key(data_path, save_path): _train_bertmesh_from_model_key(data_path, save_path, "Wellcome/WellcomeBertMesh") -"""def test_train_bertmesh_from_scratch(data_path, save_path): - _train_bertmesh_from_model_key(data_path, save_path, "")""" +def test_train_bertmesh_from_scratch(data_path, save_path): + _train_bertmesh_from_model_key(data_path, save_path, "") From 8298b5a9354a5f0adae26bfdfa097fad6485da91 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 1 Sep 2023 14:26:00 +0100 Subject: [PATCH 293/300] Fixes black --- grants_tagger_light/augmentation/augment.py | 132 ++++++++++-------- .../augmentation/augment_openai.py | 110 ++++++++------- .../preprocessing/preprocess_mesh.py | 83 ++++++----- .../training/cli_args/bertmesh_args.py | 2 +- grants_tagger_light/training/train.py | 124 +++++++++------- tests/test_train.py | 9 +- 6 files changed, 257 insertions(+), 203 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index bdafccef..377c822d 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -11,7 +11,9 @@ from datasets import load_from_disk -from grants_tagger_light.augmentation.parallel_augment_openai import ParallelAugmentOpenAI +from grants_tagger_light.augmentation.parallel_augment_openai import ( + ParallelAugmentOpenAI, +) augment_app = typer.Typer() @@ -40,59 +42,77 @@ def _merge_dicts(dict_list): def augment( data_path: str, save_to_path: str, - model_key: str = 'gpt-3.5-turbo', + model_key: str = "gpt-3.5-turbo", num_proc: int = os.cpu_count(), batch_size: int = 64, min_examples: int = None, examples: int = 25, - prompt_template: str = 'grants_tagger_light/augmentation/prompt.template', - concurrent_calls: int = os.cpu_count()*2, + prompt_template: str = "grants_tagger_light/augmentation/prompt.template", + concurrent_calls: int = os.cpu_count() * 2, temperature: float = 1.5, tags_file_path: str = None, ): - if model_key.strip().lower() not in ['gpt-3.5-turbo', 'text-davinci', 'gpt-4']: - raise NotImplementedError(f"{model_key} not implemented as an augmentation framework") + if model_key.strip().lower() not in ["gpt-3.5-turbo", "text-davinci", "gpt-4"]: + raise NotImplementedError( + f"{model_key} not implemented as an augmentation framework" + ) dset = load_from_disk(os.path.join(data_path, "dataset")) if "train" in dset: dset = dset["train"] logger.info("Obtaining count values from the labels...") pool = multiprocessing.Pool(processes=num_proc) - element_counts_list = pool.map(_count_elements_in_sublist, dset['meshMajor']) + element_counts_list = pool.map(_count_elements_in_sublist, dset["meshMajor"]) pool.close() pool.join() merged_element_counts = _merge_dicts(element_counts_list) - sorted_merged_element_counts = sorted(merged_element_counts.items(), key=lambda x: x[1], reverse=True) + sorted_merged_element_counts = sorted( + merged_element_counts.items(), key=lambda x: x[1], reverse=True + ) sorted_merged_element_counts_dict = dict(sorted_merged_element_counts) if tags_file_path is not None: - with open(tags_file_path, 'r') as f: - tags = f.read().split('\n') - logger.info(f"Tags file path found. Filtering {len(tags)} tags (examples found: {tags[:15]}...)") - sorted_merged_element_counts_dict = {k: v for k, v in sorted_merged_element_counts_dict.items() - if k in tags} + with open(tags_file_path, "r") as f: + tags = f.read().split("\n") + logger.info( + f"Tags file path found. Filtering {len(tags)} tags (examples found: {tags[:15]}...)" + ) + sorted_merged_element_counts_dict = { + k: v for k, v in sorted_merged_element_counts_dict.items() if k in tags + } if min_examples is not None: - sorted_merged_element_counts_dict = {k: v for k, v in sorted_merged_element_counts_dict.items() - if v < min_examples} + sorted_merged_element_counts_dict = { + k: v + for k, v in sorted_merged_element_counts_dict.items() + if v < min_examples + } - with open(f"{save_to_path}.count", 'w') as f: + with open(f"{save_to_path}.count", "w") as f: f.write(json.dumps(sorted_merged_element_counts_dict, indent=2)) tags_to_augment = list(sorted_merged_element_counts_dict.keys()) - biggest_tags_to_augment = [f"{k}({sorted_merged_element_counts_dict[k]})" - for k in tags_to_augment[:5]] - smallest_tags_to_augment = [f"{k}({sorted_merged_element_counts_dict[k]})" - for k in tags_to_augment[-5:]] + biggest_tags_to_augment = [ + f"{k}({sorted_merged_element_counts_dict[k]})" for k in tags_to_augment[:5] + ] + smallest_tags_to_augment = [ + f"{k}({sorted_merged_element_counts_dict[k]})" for k in tags_to_augment[-5:] + ] - logger.info(f"Augmenting a total of {len(tags_to_augment)} tags, " - f"from {biggest_tags_to_augment} to {smallest_tags_to_augment}") + logger.info( + f"Augmenting a total of {len(tags_to_augment)} tags, " + f"from {biggest_tags_to_augment} to {smallest_tags_to_augment}" + ) - logger.info(f"RAG: Collecting existing examples of those tags to send in the prompt") - dset = dset.filter(lambda x: any(np.isin(tags_to_augment, x["meshMajor"])), num_proc=num_proc) + logger.info( + f"RAG: Collecting existing examples of those tags to send in the prompt" + ) + dset = dset.filter( + lambda x: any(np.isin(tags_to_augment, x["meshMajor"])), num_proc=num_proc + ) dset = dset.map( - lambda _, y: {'idx': y}, + lambda _, y: {"idx": y}, with_indices=True, batched=True, batch_size=batch_size, @@ -135,51 +155,44 @@ def augment( @augment_app.command() def augment_cli( - data_path: str = typer.Argument( - ..., - help="Path to mesh.jsonl"), + data_path: str = typer.Argument(..., help="Path to mesh.jsonl"), save_to_path: str = typer.Argument( ..., help="Path to save the serialized PyArrow dataset after preprocessing" ), model_key: str = typer.Option( "gpt-3.5-turbo", - help="LLM to use data augmentation. By now, only `openai` is supported" + help="LLM to use data augmentation. By now, only `openai` is supported", ), num_proc: int = typer.Option( - os.cpu_count(), - help="Number of processes to use for data augmentation" + os.cpu_count(), help="Number of processes to use for data augmentation" ), batch_size: int = typer.Option( - 64, - help="Preprocessing batch size (for dataset, filter, map, ...)" + 64, help="Preprocessing batch size (for dataset, filter, map, ...)" ), min_examples: int = typer.Option( None, - help="Minimum number of examples to require. Less than that will trigger data augmentation." - ), - examples: int = typer.Option( - 25, - help="Examples to generate per each tag." + help="Minimum number of examples to require. Less than that will trigger data augmentation.", ), + examples: int = typer.Option(25, help="Examples to generate per each tag."), prompt_template: str = typer.Option( - 'grants_tagger_light/augmentation/prompt.template', - help="File to use as a prompt. Make sure to ask the LLM to return a dict with two fields: `abstract` and `tags`" + "grants_tagger_light/augmentation/prompt.template", + help="File to use as a prompt. Make sure to ask the LLM to return a dict with two fields: `abstract` and `tags`", ), concurrent_calls: int = typer.Option( - os.cpu_count()*2, + os.cpu_count() * 2, min=1, - help="Concurrent calls with 1 tag each to the different model" + help="Concurrent calls with 1 tag each to the different model", ), temperature: float = typer.Option( 1.5, min=0, max=2, - help="A value between 0 and 2. The bigger - the more creative." + help="A value between 0 and 2. The bigger - the more creative.", ), tags_file_path: str = typer.Option( None, - help="Text file containing one line per tag to be considered. The rest will be discarded." - ) + help="Text file containing one line per tag to be considered. The rest will be discarded.", + ), ): if not os.path.isdir(data_path): logger.error( @@ -194,20 +207,19 @@ def augment_cli( exit(-1) if float(temperature) > 2.0 or float(temperature) < -2.0: - logger.error( - "Temperature should be in the range [-2, 2]" - ) + logger.error("Temperature should be in the range [-2, 2]") exit(-1) - augment(data_path, - save_to_path, - model_key=model_key, - num_proc=num_proc, - batch_size=batch_size, - min_examples=min_examples, - examples=examples, - prompt_template=prompt_template, - concurrent_calls=concurrent_calls, - temperature=temperature, - tags_file_path=tags_file_path - ) + augment( + data_path, + save_to_path, + model_key=model_key, + num_proc=num_proc, + batch_size=batch_size, + min_examples=min_examples, + examples=examples, + prompt_template=prompt_template, + concurrent_calls=concurrent_calls, + temperature=temperature, + tags_file_path=tags_file_path, + ) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index b02db140..0fa15d21 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -24,8 +24,8 @@ def __init__(self, prompt_template_path, model_key="gpt-3.5-turbo"): self.model_key = model_key def _create_message(self, abstract, tag): - prompt = self.prompt_template.replace('{TOPIC}', tag) - prompt = prompt.replace('{ABSTRACT}', abstract) + prompt = self.prompt_template.replace("{TOPIC}", tag) + prompt = prompt.replace("{ABSTRACT}", abstract) return [{"role": "user", "content": prompt}] @@ -40,19 +40,23 @@ def _parse_response(answer, metadata): return res = { - "journal": metadata['model_key'], - "meshMajor": metadata['tags'], - "year": metadata['year'], - "abstractText": json_response['abstract'].replace("'", "").replace('"', ''), + "journal": metadata["model_key"], + "meshMajor": metadata["tags"], + "year": metadata["year"], + "abstractText": json_response["abstract"] + .replace("'", "") + .replace('"', ""), "pmid": uuid.uuid4().hex, - "title": json_response['title'].replace("'", "").replace('"', ''), - "existing_example": metadata['existing_example'].replace("'", "").replace('"', ''), - "required_examples": metadata['required_examples'], - "featured_tag": metadata['featured_tag'] + "title": json_response["title"].replace("'", "").replace('"', ""), + "existing_example": metadata["existing_example"] + .replace("'", "") + .replace('"', ""), + "required_examples": metadata["required_examples"], + "featured_tag": metadata["featured_tag"], } f.write(json.dumps(res)) - f.write('\n') + f.write("\n") f.flush() logger.info(f"Data received successfully for {metadata['featured_tag']}") @@ -85,7 +89,9 @@ def _prepare_request( logger.info(f"Augmenting {tag} with {missing_num} examples") # RAG: I select similar articles to provide them to the LLM - tmp_dset = dset.filter(lambda x: any(np.isin([tag], x["meshMajor"])), num_proc=num_proc) + tmp_dset = dset.filter( + lambda x: any(np.isin([tag], x["meshMajor"])), num_proc=num_proc + ) # abstracts_num = [i for i in range(len(tmp_dset))] # random.shuffle(abstracts_num) @@ -94,58 +100,59 @@ def _prepare_request( existing_examples = min(required_examples, len(tmp_dset)) n_per_example = math.ceil(required_examples / existing_examples) - logger.info(f"Augmenting {tag} with {required_examples} examples, using {existing_examples} in RAG mode") + logger.info( + f"Augmenting {tag} with {required_examples} examples, using {existing_examples} in RAG mode" + ) for i in range(existing_examples): - abstract = tmp_dset['abstractText'][i] - tags = tmp_dset['meshMajor'][i] + abstract = tmp_dset["abstractText"][i] + tags = tmp_dset["meshMajor"][i] data = { "model": self.model_key, "n": n_per_example, "temperature": temperature, "top_p": top_p, "presence_penalty": presence_penalty, - "messages": self._create_message(abstract, tag) + "messages": self._create_message(abstract, tag), } metadata = { - 'featured_tag': tag, - 'tags': tags, - 'required_examples': missing_num, - 'existing_example': abstract, - 'year': year, - 'model_key': model_key, - 'save_to_path': save_to_path + "featured_tag": tag, + "tags": tags, + "required_examples": missing_num, + "existing_example": abstract, + "year": year, + "model_key": model_key, + "save_to_path": save_to_path, } yield data, metadata def _make_requests( - self, - collect_concurrent_calls, - dset, - temperature, - top_p, - presence_penalty, - num_proc, - model_key, - save_to_path, + self, + collect_concurrent_calls, + dset, + temperature, + top_p, + presence_penalty, + num_proc, + model_key, + save_to_path, ): - for num in range(len(collect_concurrent_calls)): tag = collect_concurrent_calls[num][0] missing_num = collect_concurrent_calls[num][1] for data, metadata in self._prepare_request( - tag, - missing_num, - dset, - temperature, - top_p, - presence_penalty, - num_proc, - model_key, - save_to_path, + tag, + missing_num, + dset, + temperature, + top_p, + presence_penalty, + num_proc, + model_key, + save_to_path, ): chat_completion = openai.ChatCompletion.create(**data) chat_completion.metadata = metadata @@ -153,15 +160,15 @@ def _make_requests( self._process_response(chat_completion) def generate( - self, - collect_concurrent_calls, - dset, - save_to_path, - model_key, - temperature=1.5, - top_p=1, - presence_penalty=0, - num_proc=os.cpu_count(), + self, + collect_concurrent_calls, + dset, + save_to_path, + model_key, + temperature=1.5, + top_p=1, + presence_penalty=0, + num_proc=os.cpu_count(), ): self._make_requests( collect_concurrent_calls=collect_concurrent_calls, @@ -173,4 +180,3 @@ def generate( model_key=model_key, save_to_path=save_to_path, ) - diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index 0e55b23e..b9e82e40 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -60,7 +60,7 @@ def preprocess_mesh( batch_size: int = 256, tags: list = None, train_years: list = None, - test_years: list = None + test_years: list = None, ): if max_samples != -1: logger.info(f"Filtering examples to {max_samples}") @@ -94,14 +94,18 @@ def preprocess_mesh( if len(years) > 0: logger.info(f"Removing all years which are not in {years}") - dset = dset.filter(lambda x: any(np.isin(years, [str(x["year"])])), num_proc=num_proc) + dset = dset.filter( + lambda x: any(np.isin(years, [str(x["year"])])), num_proc=num_proc + ) if tags is None: tags = [] if len(tags) > 0: logger.info(f"Removing all tags which are not in {tags}") - dset = dset.filter(lambda x: any(np.isin(tags, x["meshMajor"])), num_proc=num_proc) + dset = dset.filter( + lambda x: any(np.isin(tags, x["meshMajor"])), num_proc=num_proc + ) # Remove unused columns to save space & time dset = dset.remove_columns(["journal", "pmid", "title"]) @@ -161,33 +165,44 @@ def preprocess_mesh( t1 = time.time() if len(years) > 0: logger.info("Splitting the dataset by training and test years") - train_dset = dset.filter(_filter_rows_by_years, - batched=True, - batch_size=batch_size, - desc=f"Creating training dataset with years {train_years}", - num_proc=num_proc, - fn_kwargs={"years": train_years}) - test_dset = dset.filter(_filter_rows_by_years, - batched=True, - batch_size=batch_size, - desc=f"Creating test dataset with years {test_years}", - num_proc=num_proc, - fn_kwargs={"years": test_years}) + train_dset = dset.filter( + _filter_rows_by_years, + batched=True, + batch_size=batch_size, + desc=f"Creating training dataset with years {train_years}", + num_proc=num_proc, + fn_kwargs={"years": train_years}, + ) + test_dset = dset.filter( + _filter_rows_by_years, + batched=True, + batch_size=batch_size, + desc=f"Creating test dataset with years {test_years}", + num_proc=num_proc, + fn_kwargs={"years": test_years}, + ) if test_size is None or test_size == 1.0: test_size = len(test_dset) logger.info(f"Using the whole dataset of {test_size} rows") - dset = DatasetDict({'train': train_dset, 'test': test_dset}) + dset = DatasetDict({"train": train_dset, "test": test_dset}) else: if test_size > 1.0: test_size = int(test_size) logger.info(f"Using a test_size frac or number of rows of of {test_size}") - dset = DatasetDict({'train': train_dset, 'test': test_dset.train_test_split(test_size)['test']}) + dset = DatasetDict( + { + "train": train_dset, + "test": test_dset.train_test_split(test_size)["test"], + } + ) else: if test_size is None: test_size = 0.05 - logger.info(f"Test size not found. Setting it to a frac of the whole dataset equal to {test_size}") + logger.info( + f"Test size not found. Setting it to a frac of the whole dataset equal to {test_size}" + ) elif test_size > 1.0: test_size = int(test_size) dset = dset.train_test_split(test_size=test_size) @@ -209,12 +224,9 @@ def preprocess_mesh( @preprocess_app.command() def preprocess_mesh_cli( - data_path: str = typer.Argument( - ..., - help="Path to mesh.jsonl"), + data_path: str = typer.Argument(..., help="Path to mesh.jsonl"), save_to_path: str = typer.Argument( - ..., - help="Path to save the serialized PyArrow dataset after preprocessing" + ..., help="Path to save the serialized PyArrow dataset after preprocessing" ), model_key: str = typer.Argument( ..., @@ -222,31 +234,28 @@ def preprocess_mesh_cli( "Leave blank if training from scratch", # noqa ), test_size: float = typer.Option( - None, - help="Fraction of data to use for testing in (0,1] or number of rows"), + None, help="Fraction of data to use for testing in (0,1] or number of rows" + ), num_proc: int = typer.Option( - os.cpu_count(), - help="Number of processes to use for preprocessing" + os.cpu_count(), help="Number of processes to use for preprocessing" ), max_samples: int = typer.Option( -1, help="Maximum number of samples to use for preprocessing", ), - batch_size: int = typer.Option( - 256, - help="Size of the preprocessing batch"), + batch_size: int = typer.Option(256, help="Size of the preprocessing batch"), tags: str = typer.Option( None, help="Comma-separated tags you want to include in the dataset " - "(the rest will be discarded)"), + "(the rest will be discarded)", + ), train_years: str = typer.Option( - None, - help="Comma-separated years you want to include in the training dataset"), + None, help="Comma-separated years you want to include in the training dataset" + ), test_years: str = typer.Option( - None, - help="Comma-separated years you want to include in the test dataset"), + None, help="Comma-separated years you want to include in the test dataset" + ), ): - if not data_path.endswith("jsonl"): logger.error( "It seems your input MeSH data is not in `jsonl` format. " @@ -274,5 +283,5 @@ def preprocess_mesh_cli( save_to_path=save_to_path, tags=parse_tags(tags), train_years=parse_years(train_years), - test_years=parse_years(test_years) + test_years=parse_years(test_years), ) diff --git a/grants_tagger_light/training/cli_args/bertmesh_args.py b/grants_tagger_light/training/cli_args/bertmesh_args.py index 262ffe5c..c59a3364 100644 --- a/grants_tagger_light/training/cli_args/bertmesh_args.py +++ b/grants_tagger_light/training/cli_args/bertmesh_args.py @@ -9,6 +9,6 @@ class BertMeshModelArguments: hidden_size: int = field(default=1024) dropout: float = field(default=0.1) multilabel_attention: bool = field(default=True) - freeze_backbone: str = field(default="unfreeze") # unfreeze, unfreeze_bias, freeze + freeze_backbone: str = field(default="unfreeze") # unfreeze, unfreeze_bias, freeze hidden_dropout_prob: float = field(default=0.2) attention_probs_dropout_prob: float = field(default=0.2) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 1c5035f5..5ffc530b 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -8,7 +8,7 @@ get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup, - get_linear_schedule_with_warmup + get_linear_schedule_with_warmup, ) from grants_tagger_light.models.bert_mesh import BertMesh from grants_tagger_light.preprocessing.preprocess_mesh import preprocess_mesh @@ -48,7 +48,7 @@ def train_bertmesh( from_checkpoint: str = None, tags: list = None, train_years: list = None, - test_years: list = None + test_years: list = None, ): if not model_key: assert isinstance(model_args, BertMeshModelArguments), ( @@ -78,13 +78,13 @@ def train_bertmesh( batch_size=training_args.per_device_train_batch_size, tags=tags, train_years=train_years, - test_years=test_years + test_years=test_years, ) train_dset, val_dset = dset["train"], dset["test"] metric_labels = [] - for x in train_dset['label_ids']: + for x in train_dset["label_ids"]: metric_labels.extend(x) train_dset_size = len(train_dset) @@ -92,7 +92,11 @@ def train_bertmesh( if max_samples > 0: train_dset_size = min(max_samples, train_dset_size) logger.info(f"Training max samples: {train_dset_size}.") - train_dset.filter(lambda example, idx: idx < train_dset_size, with_indices=True, num_proc=num_proc) + train_dset.filter( + lambda example, idx: idx < train_dset_size, + with_indices=True, + num_proc=num_proc, + ) else: logger.info("Training with all data...") @@ -101,13 +105,16 @@ def train_bertmesh( train_dset = Sharding(num_shards=shards).shard(train_dset) if not model_key: - logger.info(f"Model key not found. Training from scratch {model_args.pretrained_model_key}") + logger.info( + f"Model key not found. Training from scratch {model_args.pretrained_model_key}" + ) # Instantiate model from scratch logger.info(f"Loading `{model_args.pretrained_model_key}` tokenizer...") config = AutoConfig.from_pretrained(model_args.pretrained_model_key) - config.update({ + config.update( + { "pretrained_model": model_args.pretrained_model_key, "num_labels": len(label2id), "hidden_size": model_args.hidden_size, @@ -118,14 +125,17 @@ def train_bertmesh( "freeze_backbone": model_args.freeze_backbone, "hidden_dropout_prob": model_args.hidden_dropout_prob, "attention_probs_dropout_prob": model_args.attention_probs_dropout_prob, - }) + } + ) logger.info(f"Hidden size: {config.hidden_size}") logger.info(f"Dropout: {config.dropout}") logger.info(f"Multilabel Attention: {config.multilabel_attention}") logger.info(f"Freeze Backbone: {config.freeze_backbone}") logger.info(f"Num labels: {config.num_labels}") logger.info(f"hidden_dropout_prob: {config.hidden_dropout_prob}") - logger.info(f"attention_probs_dropout_prob: {config.attention_probs_dropout_prob}") + logger.info( + f"attention_probs_dropout_prob: {config.attention_probs_dropout_prob}" + ) model = BertMesh(config) @@ -134,15 +144,15 @@ def train_bertmesh( model = BertMesh.from_pretrained(model_key, trust_remote_code=True) if model_args.freeze_backbone is None: - model_args.freeze_backbone = 'freeze' + model_args.freeze_backbone = "freeze" - if model_args.freeze_backbone.lower().strip() == 'unfreeze': + if model_args.freeze_backbone.lower().strip() == "unfreeze": logger.info("Unfreezing weights&biases in the backbone") model.unfreeze_backbone() - elif model_args.freeze_backbone.lower().strip() == 'unfreeze_bias': + elif model_args.freeze_backbone.lower().strip() == "unfreeze_bias": logger.info("Unfreezing only biases in the backbone") model.unfreeze_backbone(only_bias=True) - elif model_args.freeze_backbone.lower().strip() == 'freeze': + elif model_args.freeze_backbone.lower().strip() == "freeze": logger.info("Freezing backbone") model.freeze_backbone() @@ -160,7 +170,9 @@ def sklearn_metrics(prediction: EvalPrediction): # report = classification_report(y_pred, y_true, output_dict=True) if training_args.prune_labels_in_evaluation: - logger.info(f"For metric purposes, only considering labels present in `training`: {metric_labels[:15]}") + logger.info( + f"For metric purposes, only considering labels present in `training`: {metric_labels[:15]}" + ) mask = np.zeros(y_pred.shape, dtype=bool) mask[np.arange(y_pred.shape[0])[:, np.newaxis], metric_labels] = True @@ -170,7 +182,9 @@ def sklearn_metrics(prediction: EvalPrediction): filtered_y_pred = y_pred filtered_y_true = y_true - report = classification_report(filtered_y_pred, filtered_y_true, output_dict=True) + report = classification_report( + filtered_y_pred, filtered_y_true, output_dict=True + ) metric_dict = { "micro_avg": report["micro avg"], @@ -193,32 +207,46 @@ def sklearn_metrics(prediction: EvalPrediction): model.parameters(), lr=training_args.learning_rate, weight_decay=training_args.weight_decay, - correct_bias=training_args.correct_bias if hasattr(training_args, 'correct_bias') else True) + correct_bias=training_args.correct_bias + if hasattr(training_args, "correct_bias") + else True, + ) if training_args.warmup_steps is None: training_args.warmup_steps = 0 - if training_args.scheduler_type.lower().strip() == 'cosine': - scheduler = get_cosine_schedule_with_warmup(optimizer, - num_warmup_steps=training_args.warmup_steps, - num_training_steps=training_args.max_steps) - elif training_args.scheduler_type.lower().strip() == 'constant': - scheduler = get_constant_schedule_with_warmup(optimizer, - num_warmup_steps=training_args.warmup_steps) - elif training_args.scheduler_type.lower().strip() == 'cosine_hard_restart': - scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, - num_warmup_steps=training_args.warmup_steps, - num_training_steps=training_args.max_steps, - num_cycles=training_args.num_train_epochs) - elif training_args.scheduler_type.lower().strip() == 'linear': - scheduler = get_linear_schedule_with_warmup(optimizer, - num_warmup_steps=training_args.warmup_steps, - num_training_steps=training_args.max_steps) + if training_args.scheduler_type.lower().strip() == "cosine": + scheduler = get_cosine_schedule_with_warmup( + optimizer, + num_warmup_steps=training_args.warmup_steps, + num_training_steps=training_args.max_steps, + ) + elif training_args.scheduler_type.lower().strip() == "constant": + scheduler = get_constant_schedule_with_warmup( + optimizer, num_warmup_steps=training_args.warmup_steps + ) + elif training_args.scheduler_type.lower().strip() == "cosine_hard_restart": + scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( + optimizer, + num_warmup_steps=training_args.warmup_steps, + num_training_steps=training_args.max_steps, + num_cycles=training_args.num_train_epochs, + ) + elif training_args.scheduler_type.lower().strip() == "linear": + scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=training_args.warmup_steps, + num_training_steps=training_args.max_steps, + ) else: - logger.warning(f"{training_args.scheduler_type}: not found or not valid. Falling back to `linear`") - scheduler = get_linear_schedule_with_warmup(optimizer, - num_warmup_steps=training_args.warmup_steps, - num_training_steps=training_args.max_steps) + logger.warning( + f"{training_args.scheduler_type}: not found or not valid. Falling back to `linear`" + ) + scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=training_args.warmup_steps, + num_training_steps=training_args.max_steps, + ) logger.info(f"Optimizer: {optimizer}") logger.info(f"Scheduler: {training_args.scheduler_type}") @@ -236,7 +264,6 @@ def sklearn_metrics(prediction: EvalPrediction): data_collator=collator, compute_metrics=sklearn_metrics, optimizers=(optimizer, scheduler), - ) logger.info(training_args) @@ -272,11 +299,10 @@ def train_bertmesh_cli( "or to a folder after preprocessing and saving to disk", ), test_size: float = typer.Option( - None, - help="Fraction of data to use for testing (0,1] or number of rows"), + None, help="Fraction of data to use for testing (0,1] or number of rows" + ), num_proc: int = typer.Option( - os.cpu_count(), - help="Number of processes to use for preprocessing" + os.cpu_count(), help="Number of processes to use for preprocessing" ), max_samples: int = typer.Option( -1, @@ -288,21 +314,19 @@ def train_bertmesh_cli( "IterativeDataset to (improves performance)", ), from_checkpoint: str = typer.Option( - None, - help="Name of the checkpoint to resume training" + None, help="Name of the checkpoint to resume training" ), tags: str = typer.Option( None, help="Comma-separated tags you want to include in the dataset " - "(the rest will be discarded)"), + "(the rest will be discarded)", + ), train_years: str = typer.Option( - None, - help="Comma-separated years you want to include in the training dataset" + None, help="Comma-separated years you want to include in the training dataset" ), test_years: str = typer.Option( - None, - help="Comma-separated years you want to include in the test dataset" - ) + None, help="Comma-separated years you want to include in the test dataset" + ), ): parser = HfArgumentParser( ( @@ -333,5 +357,5 @@ def train_bertmesh_cli( from_checkpoint=from_checkpoint, tags=parse_tags(tags), train_years=parse_years(train_years), - test_years=parse_years(test_years) + test_years=parse_years(test_years), ) diff --git a/tests/test_train.py b/tests/test_train.py index 512c2ea9..c12a4b2b 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -1,5 +1,8 @@ from grants_tagger_light.training.train import train_bertmesh -from grants_tagger_light.training.cli_args import BertMeshModelArguments, BertMeshTrainingArguments +from grants_tagger_light.training.cli_args import ( + BertMeshModelArguments, + BertMeshTrainingArguments, +) import tempfile import pytest @@ -35,7 +38,7 @@ def _train_bertmesh_from_model_key(data_path, save_path, model_key): report_to="none", no_cuda=True, num_train_epochs=1, - dataloader_num_workers=1 + dataloader_num_workers=1, ) model_args = BertMeshModelArguments() @@ -48,7 +51,7 @@ def _train_bertmesh_from_model_key(data_path, save_path, model_key): model_args=model_args, num_proc=1, test_size=0.5, - shards=1 + shards=1, ) From 6a440fa2030018ea935cabea6985577357db8da8 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 1 Sep 2023 14:26:50 +0100 Subject: [PATCH 294/300] Fixes black --- grants_tagger_light/models/bert_mesh/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grants_tagger_light/models/bert_mesh/model.py b/grants_tagger_light/models/bert_mesh/model.py index 71f8b5b4..e4367344 100644 --- a/grants_tagger_light/models/bert_mesh/model.py +++ b/grants_tagger_light/models/bert_mesh/model.py @@ -65,7 +65,7 @@ def freeze_backbone(self): def unfreeze_backbone(self, only_bias=False): for name, param in self.bert.named_parameters(): if only_bias: - if 'bias' in name.lower(): + if "bias" in name.lower(): logger.info(f"Unfreezing {name}") param.requires_grad = True else: From 37c9a2b0d1119e52f90239c7efa6df6a543229f6 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 1 Sep 2023 14:37:05 +0100 Subject: [PATCH 295/300] Fixes ruff --- .../augmentation/JsonParser.py | 4 ++-- grants_tagger_light/augmentation/augment.py | 21 ++++++++++++------- .../augmentation/augment_openai.py | 7 ++++--- .../evaluation/evaluate_model.py | 2 +- grants_tagger_light/models/bert_mesh/model.py | 17 +++------------ .../preprocessing/preprocess_mesh.py | 4 +++- grants_tagger_light/training/train.py | 14 +++++-------- grants_tagger_light/utils/utils.py | 2 +- 8 files changed, 33 insertions(+), 38 deletions(-) diff --git a/grants_tagger_light/augmentation/JsonParser.py b/grants_tagger_light/augmentation/JsonParser.py index 4e41c0e8..fe9a2e97 100644 --- a/grants_tagger_light/augmentation/JsonParser.py +++ b/grants_tagger_light/augmentation/JsonParser.py @@ -8,8 +8,8 @@ class JsonParser: def __init(self): - """Class to parse json produced by LLMs. Inspiration taken from langchain. It fixes quotes, - it escapes separators, etc.""" + """Class to parse json produced by LLMs. Inspiration taken from langchain. + It fixes quotes, it escapes separators, etc.""" pass @staticmethod diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index 377c822d..de4d05cd 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -75,7 +75,8 @@ def augment( with open(tags_file_path, "r") as f: tags = f.read().split("\n") logger.info( - f"Tags file path found. Filtering {len(tags)} tags (examples found: {tags[:15]}...)" + f"Tags file path found. Filtering {len(tags)} tags " + f"(examples found: {tags[:15]}...)" ) sorted_merged_element_counts_dict = { k: v for k, v in sorted_merged_element_counts_dict.items() if k in tags @@ -106,7 +107,7 @@ def augment( ) logger.info( - f"RAG: Collecting existing examples of those tags to send in the prompt" + "RAG: Collecting existing examples of those tags to send in the prompt" ) dset = dset.filter( lambda x: any(np.isin(tags_to_augment, x["meshMajor"])), num_proc=num_proc @@ -171,12 +172,15 @@ def augment_cli( ), min_examples: int = typer.Option( None, - help="Minimum number of examples to require. Less than that will trigger data augmentation.", + help="Minimum number of examples to require. " + "Less than that will trigger data augmentation.", ), examples: int = typer.Option(25, help="Examples to generate per each tag."), prompt_template: str = typer.Option( "grants_tagger_light/augmentation/prompt.template", - help="File to use as a prompt. Make sure to ask the LLM to return a dict with two fields: `abstract` and `tags`", + help="File to use as a prompt. " + "Make sure to ask the LLM to return a dict with two fields: " + "`abstract` and `tags`", ), concurrent_calls: int = typer.Option( os.cpu_count() * 2, @@ -191,18 +195,21 @@ def augment_cli( ), tags_file_path: str = typer.Option( None, - help="Text file containing one line per tag to be considered. The rest will be discarded.", + help="Text file containing one line per tag to be considered. " + "The rest will be discarded.", ), ): if not os.path.isdir(data_path): logger.error( - "The data path should be a folder with saved data from `preprocessing` step." + "The data path should be a folder with saved data from " + "`preprocessing` step." ) exit(-1) if tags_file_path is None and min_examples is None: logger.error( - "To understand which tags need to be augmented, set either --min-examples or --tags-file-path" + "To understand which tags need to be augmented, " + "set either --min-examples or --tags-file-path" ) exit(-1) diff --git a/grants_tagger_light/augmentation/augment_openai.py b/grants_tagger_light/augmentation/augment_openai.py index 0fa15d21..22c6dcd4 100644 --- a/grants_tagger_light/augmentation/augment_openai.py +++ b/grants_tagger_light/augmentation/augment_openai.py @@ -2,7 +2,6 @@ import json import math import os -import random import uuid import openai @@ -17,7 +16,8 @@ class AugmentOpenAI: def __init__(self, prompt_template_path, model_key="gpt-3.5-turbo"): if "OPENAI_API_KEY" not in os.environ: logger.error( - "OPENAI_API_KEY not found in env vars. Please define it before running this program." + "OPENAI_API_KEY not found in env vars. " + "Please define it before running this program." ) with open(prompt_template_path, "r") as f: self.prompt_template = f.read() @@ -101,7 +101,8 @@ def _prepare_request( n_per_example = math.ceil(required_examples / existing_examples) logger.info( - f"Augmenting {tag} with {required_examples} examples, using {existing_examples} in RAG mode" + f"Augmenting {tag} with {required_examples} examples, " + f"using {existing_examples} in RAG mode" ) for i in range(existing_examples): diff --git a/grants_tagger_light/evaluation/evaluate_model.py b/grants_tagger_light/evaluation/evaluate_model.py index 4f7c8352..eb43b720 100644 --- a/grants_tagger_light/evaluation/evaluate_model.py +++ b/grants_tagger_light/evaluation/evaluate_model.py @@ -56,7 +56,7 @@ def evaluate_model( Y_pred_proba = sp.csr_matrix(Y_pred_proba) - if type(threshold) != list: + if not isinstance(threshold, list): threshold = [threshold] widths = (12, 5, 5, 5) diff --git a/grants_tagger_light/models/bert_mesh/model.py b/grants_tagger_light/models/bert_mesh/model.py index e4367344..2f8245fb 100644 --- a/grants_tagger_light/models/bert_mesh/model.py +++ b/grants_tagger_light/models/bert_mesh/model.py @@ -28,35 +28,24 @@ def __init__( super().__init__(config=config) self.config.auto_map = {"AutoModel": "model.BertMesh"} self.pretrained_model = self.config.pretrained_model - logger.info(f"Pretrained model: {self.pretrained_model}") self.num_labels = self.config.num_labels - logger.info(f"Num labels: {self.num_labels}") self.hidden_size = getattr(self.config, "hidden_size", 512) - logger.info(f"Hidden Size: {self.hidden_size}") self.dropout = getattr(self.config, "dropout", 0.1) - logger.info(f"Dropout: {self.dropout}") self.multilabel_attention = getattr(self.config, "multilabel_attention", False) - logger.info(f"Multilabel attention: {self.multilabel_attention}") - self.id2label = self.config.id2label self.bert = AutoModel.from_pretrained(self.pretrained_model) # 768 self.multilabel_attention_layer = MultiLabelAttention( 768, self.num_labels ) # num_labels, 768 - logger.info(f"multilabel_attention_layer: {self.multilabel_attention_layer}") self.linear_1 = torch.nn.Linear(768, self.hidden_size) # 768, 1024 - logger.info(f"linear_1: {self.linear_1}") self.linear_2 = torch.nn.Linear(self.hidden_size, 1) # 1024, 1 - logger.info(f"linear_2: {self.linear_2}") self.linear_out = torch.nn.Linear(self.hidden_size, self.num_labels) - logger.info(f"linear_out: {self.linear_out}") self.dropout_layer = torch.nn.Dropout(self.dropout) - logger.info(f"dropout_layer: {self.dropout_layer}") def freeze_backbone(self): for param in self.bert.parameters(): @@ -66,16 +55,16 @@ def unfreeze_backbone(self, only_bias=False): for name, param in self.bert.named_parameters(): if only_bias: if "bias" in name.lower(): - logger.info(f"Unfreezing {name}") + # logger.info(f"Unfreezing {name}") param.requires_grad = True else: param.requires_grad = False else: + # logger.info(f"Unfreezing {name}") param.requires_grad = True - logger.info(f"Unfreezing {name}") def forward(self, input_ids, labels=None, **kwargs): - if type(input_ids) is list: + if isinstance(input_ids, list): # coming from tokenizer input_ids = torch.tensor(input_ids) diff --git a/grants_tagger_light/preprocessing/preprocess_mesh.py b/grants_tagger_light/preprocessing/preprocess_mesh.py index b9e82e40..ef1ca963 100644 --- a/grants_tagger_light/preprocessing/preprocess_mesh.py +++ b/grants_tagger_light/preprocessing/preprocess_mesh.py @@ -201,7 +201,9 @@ def preprocess_mesh( if test_size is None: test_size = 0.05 logger.info( - f"Test size not found. Setting it to a frac of the whole dataset equal to {test_size}" + f"Test size not found. " + f"Setting it to a frac of the whole dataset equal to " + f"{test_size}" ) elif test_size > 1.0: test_size = int(test_size) diff --git a/grants_tagger_light/training/train.py b/grants_tagger_light/training/train.py index 5ffc530b..680e8fd6 100644 --- a/grants_tagger_light/training/train.py +++ b/grants_tagger_light/training/train.py @@ -106,7 +106,8 @@ def train_bertmesh( if not model_key: logger.info( - f"Model key not found. Training from scratch {model_args.pretrained_model_key}" + f"Model key not found. " + f"Training from scratch {model_args.pretrained_model_key}" ) # Instantiate model from scratch @@ -157,7 +158,6 @@ def train_bertmesh( model.freeze_backbone() def sklearn_metrics(prediction: EvalPrediction): - logger.info(f"Threshold: {training_args.threshold}") # This is a batch, so it's an array (rows) of array (labels) # Array of arrays with probas [[5.4e-5 1.3e-3...] [5.4e-5 1.3e-3...] ... ] y_pred = prediction.predictions @@ -170,9 +170,6 @@ def sklearn_metrics(prediction: EvalPrediction): # report = classification_report(y_pred, y_true, output_dict=True) if training_args.prune_labels_in_evaluation: - logger.info( - f"For metric purposes, only considering labels present in `training`: {metric_labels[:15]}" - ) mask = np.zeros(y_pred.shape, dtype=bool) mask[np.arange(y_pred.shape[0])[:, np.newaxis], metric_labels] = True @@ -240,7 +237,8 @@ def sklearn_metrics(prediction: EvalPrediction): ) else: logger.warning( - f"{training_args.scheduler_type}: not found or not valid. Falling back to `linear`" + f"{training_args.scheduler_type}: not found or not valid. " + f"Falling back to `linear`" ) scheduler = get_linear_schedule_with_warmup( optimizer, @@ -254,8 +252,6 @@ def sklearn_metrics(prediction: EvalPrediction): training_args.optim = optimizer training_args.lr_scheduler_type = scheduler - logger.info(f"Test dataset size: {len(val_dset)}") - trainer = Trainer( model=model, args=training_args, @@ -265,7 +261,7 @@ def sklearn_metrics(prediction: EvalPrediction): compute_metrics=sklearn_metrics, optimizers=(optimizer, scheduler), ) - logger.info(training_args) + # logger.info(training_args) if from_checkpoint is None: logger.info("Training...") diff --git a/grants_tagger_light/utils/utils.py b/grants_tagger_light/utils/utils.py index 05aa36ae..8264ffe2 100644 --- a/grants_tagger_light/utils/utils.py +++ b/grants_tagger_light/utils/utils.py @@ -155,7 +155,7 @@ def convert_dvc_to_sklearn_params(parameters): return {} # indication of sklearn pipeline - has_nested_params = any([v for v in parameters.values() if type(v) is dict]) + has_nested_params = any([v for v in parameters.values() if isinstance(v, dict)]) if has_nested_params: return { f"{pipeline_name}__{param_name}": param_value From 3a8fd489e392a83cfcc86c7a9a972fee6464c487 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 1 Sep 2023 14:37:47 +0100 Subject: [PATCH 296/300] Fixes ruff --- grants_tagger_light/models/bert_mesh/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grants_tagger_light/models/bert_mesh/model.py b/grants_tagger_light/models/bert_mesh/model.py index 2f8245fb..d1a9c5df 100644 --- a/grants_tagger_light/models/bert_mesh/model.py +++ b/grants_tagger_light/models/bert_mesh/model.py @@ -2,7 +2,7 @@ from transformers.modeling_outputs import SequenceClassifierOutput import torch import torch.nn.functional as F -from loguru import logger +# from loguru import logger class MultiLabelAttention(torch.nn.Module): From 06c5479cd8635fe2adae357f19ee825d1deb0852 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 1 Sep 2023 14:39:07 +0100 Subject: [PATCH 297/300] Fixes black --- grants_tagger_light/models/bert_mesh/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/grants_tagger_light/models/bert_mesh/model.py b/grants_tagger_light/models/bert_mesh/model.py index d1a9c5df..bc14d850 100644 --- a/grants_tagger_light/models/bert_mesh/model.py +++ b/grants_tagger_light/models/bert_mesh/model.py @@ -2,6 +2,7 @@ from transformers.modeling_outputs import SequenceClassifierOutput import torch import torch.nn.functional as F + # from loguru import logger From e1fbe2bfa0495ccf5d05ca252aff292056f46608 Mon Sep 17 00:00:00 2001 From: "Jose J. Martinez" Date: Fri, 1 Sep 2023 14:41:35 +0100 Subject: [PATCH 298/300] Fixes black --- grants_tagger_light/augmentation/augment.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/grants_tagger_light/augmentation/augment.py b/grants_tagger_light/augmentation/augment.py index de4d05cd..05a38cda 100644 --- a/grants_tagger_light/augmentation/augment.py +++ b/grants_tagger_light/augmentation/augment.py @@ -106,12 +106,12 @@ def augment( f"from {biggest_tags_to_augment} to {smallest_tags_to_augment}" ) - logger.info( - "RAG: Collecting existing examples of those tags to send in the prompt" - ) + logger.info("Collecting existing examples of those tags to send in the prompt") + dset = dset.filter( lambda x: any(np.isin(tags_to_augment, x["meshMajor"])), num_proc=num_proc ) + dset = dset.map( lambda _, y: {"idx": y}, with_indices=True, @@ -173,14 +173,14 @@ def augment_cli( min_examples: int = typer.Option( None, help="Minimum number of examples to require. " - "Less than that will trigger data augmentation.", + "Less than that will trigger data augmentation.", ), examples: int = typer.Option(25, help="Examples to generate per each tag."), prompt_template: str = typer.Option( "grants_tagger_light/augmentation/prompt.template", help="File to use as a prompt. " - "Make sure to ask the LLM to return a dict with two fields: " - "`abstract` and `tags`", + "Make sure to ask the LLM to return a dict with two fields: " + "`abstract` and `tags`", ), concurrent_calls: int = typer.Option( os.cpu_count() * 2, @@ -196,7 +196,7 @@ def augment_cli( tags_file_path: str = typer.Option( None, help="Text file containing one line per tag to be considered. " - "The rest will be discarded.", + "The rest will be discarded.", ), ): if not os.path.isdir(data_path): From 35c54f86d917744d36a18ea168e79a434b644ad2 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 6 Sep 2023 07:20:26 +0000 Subject: [PATCH 299/300] Modify script to include active portfolio sample --- data/grants_comparison/.gitignore | 2 +- data/grants_comparison/comparison.csv.dvc | 4 ---- data/raw/.gitignore | 1 + data/raw/active_grants_last_5_years.csv.dvc | 4 ++++ pipelines/generate_grants/dvc.lock | 19 +++++++++++---- pipelines/generate_grants/dvc.yaml | 24 +++++++++++++++---- .../create_xlinear_bertmesh_comparison_csv.py | 17 +++++++++++++ 7 files changed, 57 insertions(+), 14 deletions(-) delete mode 100644 data/grants_comparison/comparison.csv.dvc create mode 100644 data/raw/active_grants_last_5_years.csv.dvc diff --git a/data/grants_comparison/.gitignore b/data/grants_comparison/.gitignore index 1357bfdf..1fd228cc 100644 --- a/data/grants_comparison/.gitignore +++ b/data/grants_comparison/.gitignore @@ -1,2 +1,2 @@ -/comparison.csv /meshterms_list.txt +/comparison.csv diff --git a/data/grants_comparison/comparison.csv.dvc b/data/grants_comparison/comparison.csv.dvc deleted file mode 100644 index 96a38344..00000000 --- a/data/grants_comparison/comparison.csv.dvc +++ /dev/null @@ -1,4 +0,0 @@ -outs: -- md5: a445d8b6396364ef601638e57d6c9a46 - size: 845271 - path: comparison.csv diff --git a/data/raw/.gitignore b/data/raw/.gitignore index 69e7cea9..b0254cca 100644 --- a/data/raw/.gitignore +++ b/data/raw/.gitignore @@ -2,3 +2,4 @@ /allMeSH_2021.jsonl /desc2021.xml /disease_tags_validation_grants.xlsx +/active_grants_last_5_years.csv diff --git a/data/raw/active_grants_last_5_years.csv.dvc b/data/raw/active_grants_last_5_years.csv.dvc new file mode 100644 index 00000000..19717091 --- /dev/null +++ b/data/raw/active_grants_last_5_years.csv.dvc @@ -0,0 +1,4 @@ +outs: +- md5: d664be2a9000d44bb0325f364ec20e27 + size: 4953477 + path: active_grants_last_5_years.csv diff --git a/pipelines/generate_grants/dvc.lock b/pipelines/generate_grants/dvc.lock index db87cd92..1c9bff45 100644 --- a/pipelines/generate_grants/dvc.lock +++ b/pipelines/generate_grants/dvc.lock @@ -1,9 +1,18 @@ schema: '2.0' stages: generate: - cmd: python ../../scripts/create_grants_sample.py --s3-url s3://datalabs-data/dimensions/grants/grants - --num-parquet-files-to-consider 10 --num-samples-per-cat 10 --pre-annotate True + cmd: python scripts/create_xlinear_bertmesh_comparison_csv.py --s3-url s3://datalabs-data/dimensions/grants/grants + --num-parquet-files-to-consider 10 --num-samples-per-cat 10 --mesh-metadata-path + data/raw/desc2021.xml --mesh-terms-list-path data/grants_comparison/meshterms_list.txt + --active-portfolio-path data/raw/active_grants_last_5_years.csv --bertmesh-path + Wellcome/WellcomeBertMesh --bertmesh-thresh 0.5 --pre-annotate-bertmesh --xlinear-path + models/xlinear-0.2.5/model --xlinear-label-binarizer-path models/xlinear-0.2.5/label_binarizer.pkl + --xlinear-thresh 0.2 --pre-annotate-xlinear --output-path data/grants_comparison/comparison.csv + deps: + - path: scripts/create_xlinear_bertmesh_comparison_csv.py + md5: 3c138310387fb1d0f9fa650b3bc55842 + size: 8123 outs: - - path: grants_sample.jsonl - md5: 76bbfd9043e20866382ff9713cba7483 - size: 387951 + - path: data/grants_comparison/comparison.csv + md5: 72506f7ea22006c68c5707588ca7ecf2 + size: 600428 diff --git a/pipelines/generate_grants/dvc.yaml b/pipelines/generate_grants/dvc.yaml index 6b2a09dd..0a982435 100644 --- a/pipelines/generate_grants/dvc.yaml +++ b/pipelines/generate_grants/dvc.yaml @@ -1,9 +1,25 @@ vars: - s3-url: "s3://datalabs-data/dimensions/grants/grants" - - scripts_location: "../../scripts" - - argilla_project_name: "grants" stages: generate: - cmd: python ${scripts_location}/create_grants_sample.py --s3-url ${s3-url} --num-parquet-files-to-consider 10 --num-samples-per-cat 10 --pre-annotate True + cmd: >- + python scripts/create_xlinear_bertmesh_comparison_csv.py + --s3-url ${s3-url} + --num-parquet-files-to-consider 10 + --num-samples-per-cat 10 + --mesh-metadata-path data/raw/desc2021.xml + --mesh-terms-list-path data/grants_comparison/meshterms_list.txt + --active-portfolio-path data/raw/active_grants_last_5_years.csv + --bertmesh-path Wellcome/WellcomeBertMesh + --bertmesh-thresh 0.5 + --pre-annotate-bertmesh + --xlinear-path models/xlinear-0.2.5/model + --xlinear-label-binarizer-path models/xlinear-0.2.5/label_binarizer.pkl + --xlinear-thresh 0.2 + --pre-annotate-xlinear + --output-path data/grants_comparison/comparison.csv + deps: + - scripts/create_xlinear_bertmesh_comparison_csv.py + wdir: "../.." outs: - - grants_sample.jsonl + - data/grants_comparison/comparison.csv diff --git a/scripts/create_xlinear_bertmesh_comparison_csv.py b/scripts/create_xlinear_bertmesh_comparison_csv.py index 48e00ff0..01e66787 100644 --- a/scripts/create_xlinear_bertmesh_comparison_csv.py +++ b/scripts/create_xlinear_bertmesh_comparison_csv.py @@ -89,6 +89,8 @@ def create_comparison_csv( num_samples_per_cat: int, mesh_metadata_path: str, mesh_terms_list_path: str, + active_portfolio_path: str, + active_portfolio_sample: int, pre_annotate_bertmesh: bool, bertmesh_path: str, bertmesh_thresh: float, @@ -122,10 +124,20 @@ def create_comparison_csv( lambda x: x.sample(min(len(x), num_samples_per_cat)) ) + # Add active portfolio + active_grants = pd.read_csv(active_portfolio_path) + active_grants = active_grants[~active_grants["Synopsis"].isna()] + active_grants.sample(frac=1) + active_grants_sample = active_grants.iloc[:active_portfolio_sample] + active_grants_sample = pd.DataFrame({"abstract": active_grants_sample["Synopsis"]}) + grants_sample = pd.concat([grants_sample, active_grants_sample]) + abstracts = grants_sample["abstract"].tolist() + print(f"{len(abstracts)} abstracts to tag") # Annotate with bertmesh if pre_annotate_bertmesh: + print("Tagging with bertmesh") tags = predict_tags_bertmesh( abstracts, bertmesh_path, @@ -141,6 +153,7 @@ def create_comparison_csv( # Annotate with xlinear if pre_annotate_xlinear: + print("Tagging with xlinear") model = MeshXLinear( model_path=xlinear_path, label_binarizer_path=xlinear_label_binarizer_path, @@ -187,6 +200,8 @@ def create_comparison_csv( parser.add_argument("--num-samples-per-cat", type=int, default=10) parser.add_argument("--mesh-metadata-path", type=str) parser.add_argument("--mesh-terms-list-path", type=str) + parser.add_argument("--active-portfolio-path", type=str) + parser.add_argument("--active-portfolio-sample", type=int, default=200) parser.add_argument("--pre-annotate-bertmesh", action="store_true") parser.add_argument( "--bertmesh-path", type=str, default="Wellcome/WellcomeBertMesh" @@ -206,6 +221,8 @@ def create_comparison_csv( num_samples_per_cat=args.num_samples_per_cat, mesh_metadata_path=args.mesh_metadata_path, mesh_terms_list_path=args.mesh_terms_list_path, + active_portfolio_path=args.active_portfolio_path, + active_portfolio_sample=args.active_portfolio_sample, pre_annotate_bertmesh=args.pre_annotate_bertmesh, bertmesh_path=args.bertmesh_path, bertmesh_thresh=args.bertmesh_thresh, From d7dde6574338ad5685364a910f909ff84b4964fb Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 6 Sep 2023 07:41:35 +0000 Subject: [PATCH 300/300] Update data --- pipelines/generate_grants/dvc.lock | 8 ++++---- scripts/create_xlinear_bertmesh_comparison_csv.py | 2 ++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pipelines/generate_grants/dvc.lock b/pipelines/generate_grants/dvc.lock index 1c9bff45..8c6a5d0c 100644 --- a/pipelines/generate_grants/dvc.lock +++ b/pipelines/generate_grants/dvc.lock @@ -10,9 +10,9 @@ stages: --xlinear-thresh 0.2 --pre-annotate-xlinear --output-path data/grants_comparison/comparison.csv deps: - path: scripts/create_xlinear_bertmesh_comparison_csv.py - md5: 3c138310387fb1d0f9fa650b3bc55842 - size: 8123 + md5: 0a91bf23be4068bdc7c4b7a32d80ff2d + size: 8214 outs: - path: data/grants_comparison/comparison.csv - md5: 72506f7ea22006c68c5707588ca7ecf2 - size: 600428 + md5: bc4fd9f4a670409dad07ffd03cf421f1 + size: 596654 diff --git a/scripts/create_xlinear_bertmesh_comparison_csv.py b/scripts/create_xlinear_bertmesh_comparison_csv.py index 01e66787..1f525a2c 100644 --- a/scripts/create_xlinear_bertmesh_comparison_csv.py +++ b/scripts/create_xlinear_bertmesh_comparison_csv.py @@ -123,6 +123,7 @@ def create_comparison_csv( grants_sample = all_grants.groupby("for_first_level_name", group_keys=False).apply( lambda x: x.sample(min(len(x), num_samples_per_cat)) ) + grants_sample["active_portfolio"] = 0 # Add active portfolio active_grants = pd.read_csv(active_portfolio_path) @@ -130,6 +131,7 @@ def create_comparison_csv( active_grants.sample(frac=1) active_grants_sample = active_grants.iloc[:active_portfolio_sample] active_grants_sample = pd.DataFrame({"abstract": active_grants_sample["Synopsis"]}) + active_grants_sample["active_portfolio"] = 1 grants_sample = pd.concat([grants_sample, active_grants_sample]) abstracts = grants_sample["abstract"].tolist()