forked from wellcometrust/grants_tagger_light
Commit
Andrei Apostol committed Jul 18, 2023
1 parent c4bd743 · commit 1660ef0
Showing 5 changed files with 222 additions and 133 deletions.
@@ -1,8 +1,8 @@
import typer

from .preprocess_mesh import preprocess_mesh_cli

preprocess_app = typer.Typer()
preprocess_app.command(
    "mesh",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)(preprocess_mesh_cli)
preprocess_app.command("bioasq-mesh")(preprocess_mesh_cli)

__all__ = ["preprocess_app"]
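As a usage sketch, the sub-app registered here would typically be mounted on a parent Typer app; the parent app name, module path, and console entry point below are assumptions, not part of this commit:

import typer

# Hypothetical parent CLI; the import path is illustrative only.
from grants_tagger_light.preprocessing import preprocess_app

app = typer.Typer()
app.add_typer(preprocess_app, name="preprocess")
# The command registered above would then be reachable as, e.g.:
#   <cli-name> preprocess bioasq-mesh <input_path> <train_output_path> <label_binarizer_path>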
@@ -1,147 +1,135 @@
"""
Preprocess JSON Mesh data from BioASQ to JSONL
"""
import json
import numpy as np
import os
from pathlib import Path
from typing import Optional

import pandas as pd
import typer
from transformers import AutoTokenizer
from datasets import Dataset, disable_caching
from loguru import logger
from grants_tagger_light.models.bert_mesh import BertMesh
from tqdm import tqdm

# TODO refactor the two load funcs into a class
from grants_tagger_light.utils import (
    write_jsonl,
)

disable_caching()
preprocess_app = typer.Typer()


def yield_raw_data(input_path):
    with open(input_path, encoding="latin-1") as f_i:
        f_i.readline()  # skip first line ({"articles":[) which is not valid JSON
        for i, line in enumerate(f_i):
            item = json.loads(line[:-2])
            yield item
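A note on the format yield_raw_data expects: in the BioASQ allMeSH dump the first line is the wrapper {"articles":[, which is why it is skipped, and each subsequent line is one article object followed by a trailing comma, which is why line[:-2] is enough to make the line valid JSON. A minimal sketch of that parsing step (field values are made up):

import json

# Illustrative record; real lines carry more fields (pmid, title, ...).
line = '{"abstractText": "An abstract", "meshMajor": ["Humans"], "journal": "X", "year": 2020},\n'
item = json.loads(line[:-2])  # strip the trailing ",\n" to get valid JSON
assert item["meshMajor"] == ["Humans"]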
def _tokenize(batch, tokenizer: AutoTokenizer, x_col: str):
    return tokenizer(
        batch[x_col],
        padding="max_length",
        truncation=True,
        max_length=512,
    )


def process_data(item, filter_tags=None, filter_years=None):
    text = item["abstractText"]
    tags = item["meshMajor"]
    journal = item["journal"]
    year = item["year"]
    if filter_tags:
        tags = list(set(tags).intersection(filter_tags))
    if not tags:
        return
    if filter_years:
        min_year, max_year = filter_years.split(",")
        if year > int(max_year):
            return
        if year < int(min_year):
            return
    data = {"text": text, "tags": tags, "meta": {"journal": journal, "year": year}}
    return data
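For orientation, a small worked example of what process_data returns, using a made-up article and the field names from the code above:

# Hypothetical input record; the helper keeps only tags present in filter_tags.
article = {
    "abstractText": "We study heart disease risk factors.",
    "meshMajor": ["Humans", "Risk Factors", "Heart Diseases"],
    "journal": "Some Journal",
    "year": 2018,
}
example = process_data(article, filter_tags={"Heart Diseases"}, filter_years=None)
# example == {"text": "We study heart disease risk factors.",
#             "tags": ["Heart Diseases"],
#             "meta": {"journal": "Some Journal", "year": 2018}}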
def _encode_labels(sample, label2id: dict):
    sample["label_ids"] = []
    for label in sample["meshMajor"]:
        try:
            sample["label_ids"].append(label2id[label])
        except KeyError:
            logger.warning(f"Label {label} not found in label2id")
def yield_data(input_path, filter_tags, filter_years, buffer_size=10_000):
    data_batch = []
    for item in tqdm(yield_raw_data(input_path), total=15_000_000):  # approx 15M docs
        processed_item = process_data(item, filter_tags, filter_years)

    return sample
        if processed_item:
            data_batch.append(processed_item)

        if len(data_batch) >= buffer_size:
            yield data_batch
            data_batch = []


def _get_label2id(dset):
    label_set = set()
    for sample in dset:
        for label in sample["meshMajor"]:
            label_set.add(label)
    label2id = {label: idx for idx, label in enumerate(label_set)}
    return label2id
    if data_batch:
        yield data_batch
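The buffering in yield_data is a plain accumulate-and-flush pattern: items are appended until buffer_size is reached, the batch is yielded, and any remainder is flushed at the end. A self-contained toy version of the same pattern:

def batched(items, buffer_size=3):
    # Accumulate items and flush in fixed-size batches, plus a final partial batch.
    batch = []
    for item in items:
        batch.append(item)
        if len(batch) >= buffer_size:
            yield batch
            batch = []
    if batch:
        yield batch

print(list(batched(range(7))))  # [[0, 1, 2], [3, 4, 5], [6]]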
def preprocess_mesh(
    data_path: str,
    save_loc: str,
    model_key: str,
    test_size: float = 0.05,
    num_proc: int = 8,
    max_samples: int = np.inf,
    raw_data_path,
    processed_data_path,
    mesh_tags_path=None,
    filter_years=None,
    n_max=None,
    buffer_size=10_000,
):
    if not model_key:
        label2id = None
        # Use the same pretrained tokenizer as in Wellcome/WellcomeBertMesh
        tokenizer = AutoTokenizer.from_pretrained(
            "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
        )
    if mesh_tags_path:
        filter_tags_data = pd.read_csv(mesh_tags_path)
        filter_tags = filter_tags_data["DescriptorName"].tolist()
        filter_tags = set(filter_tags)
    else:
        # Load the model to get its label2id
        tokenizer = AutoTokenizer.from_pretrained(model_key)
        model = BertMesh.from_pretrained(model_key, trust_remote_code=True)

        label2id = {v: k for k, v in model.id2label.items()}

    def _datagen(mesh_json_path: str, max_samples: int = np.inf):
        with open(mesh_json_path, "r", encoding="latin1") as f:
            for idx, line in enumerate(f):
                # Skip 1st line
                if idx == 0:
                    continue
                sample = json.loads(line[:-2])

                if idx > max_samples:
                    break

                yield sample

    dset = Dataset.from_generator(
        _datagen,
        gen_kwargs={"mesh_json_path": data_path, "max_samples": max_samples},
    )
        filter_tags = None

    # Remove unused columns to save space & time
    dset = dset.remove_columns(["journal", "year", "pmid", "title"])

    dset = dset.map(
        _tokenize,
        batched=True,
        batch_size=32,
        num_proc=num_proc,
        desc="Tokenizing",
        fn_kwargs={"tokenizer": tokenizer, "x_col": "abstractText"},
        remove_columns=["abstractText"],
    )
    if filter_years:
        min_year, max_year = filter_years.split(",")
        filter_years = [int(min_year), int(max_year)]

    # Generate label2id if None
    if label2id is None:
        label2id = _get_label2id(dset)

    dset = dset.map(
        _encode_labels,
        batched=False,
        num_proc=num_proc,
        desc="Encoding labels",
        fn_kwargs={"label2id": label2id},
        remove_columns=["meshMajor"],
    )
    # If only using a tiny set of data, need to reduce buffer size
    if n_max is not None and n_max < buffer_size:
        buffer_size = n_max

    # Split into train and test
    dset = dset.train_test_split(test_size=test_size)
    with open(processed_data_path, "w") as f:
        for i, data_batch in enumerate(
            yield_data(
                raw_data_path, filter_tags, filter_years, buffer_size=buffer_size
            )
        ):
            write_jsonl(f, data_batch)
            if n_max and (i + 1) * buffer_size >= n_max:
                break

    # Save to disk
    dset.save_to_disk(os.path.join(save_loc, "dataset"))

    with open(os.path.join(save_loc, "label2id.json"), "w") as f:
        json.dump(label2id, f, indent=4)
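write_jsonl itself is imported from grants_tagger_light.utils and is not part of this commit; the loop above only relies on it appending one JSON object per line to the open file handle. A rough, purely illustrative equivalent:

import json

def write_jsonl_sketch(f, data_batch):
    # Assumed behaviour of the real write_jsonl helper: one JSON object per line.
    for example in data_batch:
        f.write(json.dumps(example) + "\n")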
preprocess_mesh_app = typer.Typer()


@preprocess_app.command()
@preprocess_mesh_app.command()
def preprocess_mesh_cli(
    data_path: str = typer.Argument(..., help="Path to mesh.json"),
    save_loc: str = typer.Argument(..., help="Path to save processed data"),
    model_key: str = typer.Argument(
        ...,
        help="Key to use when loading tokenizer and label2id. Leave blank if training from scratch",  # noqa
    input_path: Optional[str] = typer.Argument(None, help="path to BioASQ JSON data"),
    train_output_path: Optional[str] = typer.Argument(
        None, help="path to JSONL output file that will be generated for the train set"
    ),
    label_binarizer_path: Optional[Path] = typer.Argument(
        None, help="path to pickle file that will contain the label binarizer"
    ),
    test_output_path: Optional[str] = typer.Option(
        None, help="path to JSONL output file that will be generated for the test set"
    ),
    mesh_tags_path: Optional[str] = typer.Option(
        None, help="path to mesh tags to filter"
    ),
    test_size: float = typer.Option(0.05, help="Fraction of data to use for testing"),
    num_proc: int = typer.Option(
        8, help="Number of processes to use for preprocessing"
    test_split: Optional[float] = typer.Option(
        0.01, help="split percentage for test data. if None no split."
    ),
    max_samples: int = typer.Option(
        -1,
        help="Maximum number of samples to use for preprocessing",
    filter_years: Optional[str] = typer.Option(
        None, help="years to keep in form min_year,max_year with both inclusive"
    ),
    n_max: Optional[int] = typer.Option(
        None,
        help="""
        Maximum limit on the number of datapoints in the set
        (including training and test)""",
    ),
):
    if max_samples == -1:
        max_samples = np.inf

    preprocess_mesh(
        data_path=data_path,
        save_loc=save_loc,
        model_key=model_key,
        test_size=test_size,
        num_proc=num_proc,
        max_samples=max_samples,
        input_path,
        train_output_path,
        mesh_tags_path=mesh_tags_path,
        filter_years=filter_years,
        n_max=n_max,
    )


if __name__ == "__main__":
    typer.run(preprocess_mesh)
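A sketch of calling the streaming variant of preprocess_mesh directly, with placeholder paths and the parameters defined above:

# All paths and limits are placeholders, not values from this repository.
preprocess_mesh(
    raw_data_path="data/raw/allMeSH.json",
    processed_data_path="data/processed/mesh.jsonl",
    mesh_tags_path=None,       # or a CSV with a DescriptorName column to filter tags
    filter_years="2016,2021",  # inclusive range, parsed inside the function
    n_max=100_000,             # cap on the number of processed examples
    buffer_size=10_000,
)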
@@ -1,3 +1,4 @@
from .mesh_json_loader import load_mesh_json
from .multilabel_collator import MultilabelDataCollator

__all__ = ["MultilabelDataCollator"]
__all__ = ["load_mesh_json", "MultilabelDataCollator"]
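With the re-export above, training code can import both helpers straight from the dataloaders package:

# Consumer-side import enabled by this __init__ change.
from grants_tagger_light.training.dataloaders import (
    MultilabelDataCollator,
    load_mesh_json,
)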
94 changes: 94 additions & 0 deletions
grants_tagger_light/training/dataloaders/mesh_json_loader.py
@@ -0,0 +1,94 @@
import json
import numpy as np
from transformers import AutoTokenizer
from datasets import Dataset
from loguru import logger

# TODO refactor the two load funcs into a class


def _tokenize(batch, tokenizer: AutoTokenizer, x_col: str):
    return tokenizer(
        batch[x_col],
        padding="max_length",
        truncation=True,
        max_length=512,
    )


def _encode_labels(sample, label2id: dict):
    sample["label_ids"] = []
    for label in sample["meshMajor"]:
        try:
            sample["label_ids"].append(label2id[label])
        except KeyError:
            logger.warning(f"Label {label} not found in label2id")

    return sample


def _get_label2id(dset):
    label_set = set()
    for sample in dset:
        for label in sample["meshMajor"]:
            label_set.add(label)
    label2id = {label: idx for idx, label in enumerate(label_set)}
    return label2id


def load_mesh_json(
    data_path: str,
    tokenizer: AutoTokenizer,
    label2id: dict,
    test_size: float = 0.05,
    num_proc: int = 8,
    max_samples: int = np.inf,
):
    def _datagen(mesh_json_path: str, max_samples: int = np.inf):
        with open(mesh_json_path, "r", encoding="latin1") as f:
            for idx, line in enumerate(f):
                # Skip 1st line
                if idx == 0:
                    continue
                sample = json.loads(line[:-2])

                if idx > max_samples:
                    break

                yield sample

    dset = Dataset.from_generator(
        _datagen,
        gen_kwargs={"mesh_json_path": data_path, "max_samples": max_samples},
    )

    # Remove unused columns to save space & time
    dset = dset.remove_columns(["journal", "year", "pmid", "title"])

    dset = dset.map(
        _tokenize,
        batched=True,
        batch_size=32,
        num_proc=num_proc,
        desc="Tokenizing",
        fn_kwargs={"tokenizer": tokenizer, "x_col": "abstractText"},
        remove_columns=["abstractText"],
    )

    # Generate label2id if None
    if label2id is None:
        label2id = _get_label2id(dset)

    dset = dset.map(
        _encode_labels,
        batched=False,
        num_proc=num_proc,
        desc="Encoding labels",
        fn_kwargs={"label2id": label2id},
        remove_columns=["meshMajor"],
    )

    # Split into train and test
    dset = dset.train_test_split(test_size=test_size)

    return dset["train"], dset["test"], label2id
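A minimal sketch of calling load_mesh_json from training code, assuming the same PubMedBERT tokenizer referenced elsewhere in this commit and letting the function derive label2id from the data; paths and the sample cap are placeholders:

from transformers import AutoTokenizer

from grants_tagger_light.training.dataloaders import load_mesh_json

tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
)
train_dset, test_dset, label2id = load_mesh_json(
    data_path="data/raw/allMeSH.json",
    tokenizer=tokenizer,
    label2id=None,       # built from the data via _get_label2id
    test_size=0.05,
    num_proc=8,
    max_samples=10_000,
)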