
Commit

revert separate preprocess command
Andrei Apostol committed Jul 18, 2023
1 parent c4bd743 commit 1660ef0
Showing 5 changed files with 222 additions and 133 deletions.
8 changes: 4 additions & 4 deletions grants_tagger_light/preprocessing/__init__.py
@@ -1,8 +1,8 @@
import typer

from .preprocess_mesh import preprocess_mesh_cli

preprocess_app = typer.Typer()
preprocess_app.command(
"mesh",
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)(preprocess_mesh_cli)
preprocess_app.command("bioasq-mesh")(preprocess_mesh_cli)

__all__ = ["preprocess_app"]
222 changes: 105 additions & 117 deletions grants_tagger_light/preprocessing/preprocess_mesh.py
@@ -1,147 +1,135 @@
"""
Preprocess JSON Mesh data from BioASQ to JSONL
"""
import json
import numpy as np
import os
from pathlib import Path
from typing import Optional

import pandas as pd
import typer
from transformers import AutoTokenizer
from datasets import Dataset, disable_caching
from loguru import logger
from grants_tagger_light.models.bert_mesh import BertMesh
from tqdm import tqdm

# TODO refactor the two load funcs into a class
from grants_tagger_light.utils import (
    write_jsonl,
)

disable_caching()
preprocess_app = typer.Typer()

def yield_raw_data(input_path):
    with open(input_path, encoding="latin-1") as f_i:
        f_i.readline()  # skip first line ({"articles":[) which is not valid JSON
        for i, line in enumerate(f_i):
            item = json.loads(line[:-2])
            yield item

def _tokenize(batch, tokenizer: AutoTokenizer, x_col: str):
    return tokenizer(
        batch[x_col],
        padding="max_length",
        truncation=True,
        max_length=512,
    )

def process_data(item, filter_tags=None, filter_years=None):
    text = item["abstractText"]
    tags = item["meshMajor"]
    journal = item["journal"]
    year = item["year"]
    if filter_tags:
        tags = list(set(tags).intersection(filter_tags))
    if not tags:
        return
    if filter_years:
        min_year, max_year = filter_years.split(",")
        if year > int(max_year):
            return
        if year < int(min_year):
            return
    data = {"text": text, "tags": tags, "meta": {"journal": journal, "year": year}}
    return data


def _encode_labels(sample, label2id: dict):
sample["label_ids"] = []
for label in sample["meshMajor"]:
try:
sample["label_ids"].append(label2id[label])
except KeyError:
logger.warning(f"Label {label} not found in label2id")
def yield_data(input_path, filter_tags, filter_years, buffer_size=10_000):
    data_batch = []
    for item in tqdm(yield_raw_data(input_path), total=15_000_000):  # approx 15M docs
        processed_item = process_data(item, filter_tags, filter_years)

    return sample
        if processed_item:
            data_batch.append(processed_item)

        if len(data_batch) >= buffer_size:
            yield data_batch
            data_batch = []

def _get_label2id(dset):
    label_set = set()
    for sample in dset:
        for label in sample["meshMajor"]:
            label_set.add(label)
    label2id = {label: idx for idx, label in enumerate(label_set)}
    return label2id
    if data_batch:
        yield data_batch


def preprocess_mesh(
    data_path: str,
    save_loc: str,
    model_key: str,
    test_size: float = 0.05,
    num_proc: int = 8,
    max_samples: int = np.inf,
    raw_data_path,
    processed_data_path,
    mesh_tags_path=None,
    filter_years=None,
    n_max=None,
    buffer_size=10_000,
):
    if not model_key:
        label2id = None
        # Use the same pretrained tokenizer as in Wellcome/WellcomeBertMesh
        tokenizer = AutoTokenizer.from_pretrained(
            "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
        )
    if mesh_tags_path:
        filter_tags_data = pd.read_csv(mesh_tags_path)
        filter_tags = filter_tags_data["DescriptorName"].tolist()
        filter_tags = set(filter_tags)
    else:
        # Load the model to get its label2id
        tokenizer = AutoTokenizer.from_pretrained(model_key)
        model = BertMesh.from_pretrained(model_key, trust_remote_code=True)

        label2id = {v: k for k, v in model.id2label.items()}

    def _datagen(mesh_json_path: str, max_samples: int = np.inf):
        with open(mesh_json_path, "r", encoding="latin1") as f:
            for idx, line in enumerate(f):
                # Skip 1st line
                if idx == 0:
                    continue
                sample = json.loads(line[:-2])

                if idx > max_samples:
                    break

                yield sample

    dset = Dataset.from_generator(
        _datagen,
        gen_kwargs={"mesh_json_path": data_path, "max_samples": max_samples},
    )
        filter_tags = None

    # Remove unused columns to save space & time
    dset = dset.remove_columns(["journal", "year", "pmid", "title"])

    dset = dset.map(
        _tokenize,
        batched=True,
        batch_size=32,
        num_proc=num_proc,
        desc="Tokenizing",
        fn_kwargs={"tokenizer": tokenizer, "x_col": "abstractText"},
        remove_columns=["abstractText"],
    )
    if filter_years:
        min_year, max_year = filter_years.split(",")
        filter_years = [int(min_year), int(max_year)]

    # Generate label2id if None
    if label2id is None:
        label2id = _get_label2id(dset)

    dset = dset.map(
        _encode_labels,
        batched=False,
        num_proc=num_proc,
        desc="Encoding labels",
        fn_kwargs={"label2id": label2id},
        remove_columns=["meshMajor"],
    )
    # If only using a tiny set of data, need to reduce buffer size
    if n_max is not None and n_max < buffer_size:
        buffer_size = n_max

    # Split into train and test
    dset = dset.train_test_split(test_size=test_size)
    with open(processed_data_path, "w") as f:
        for i, data_batch in enumerate(
            yield_data(
                raw_data_path, filter_tags, filter_years, buffer_size=buffer_size
            )
        ):
            write_jsonl(f, data_batch)
            if n_max and (i + 1) * buffer_size >= n_max:
                break

    # Save to disk
    dset.save_to_disk(os.path.join(save_loc, "dataset"))

    with open(os.path.join(save_loc, "label2id.json"), "w") as f:
        json.dump(label2id, f, indent=4)
preprocess_mesh_app = typer.Typer()


@preprocess_app.command()
@preprocess_mesh_app.command()
def preprocess_mesh_cli(
    data_path: str = typer.Argument(..., help="Path to mesh.json"),
    save_loc: str = typer.Argument(..., help="Path to save processed data"),
    model_key: str = typer.Argument(
        ...,
        help="Key to use when loading tokenizer and label2id. Leave blank if training from scratch", # noqa
    input_path: Optional[str] = typer.Argument(None, help="path to BioASQ JSON data"),
    train_output_path: Optional[str] = typer.Argument(
        None, help="path to JSONL output file that will be generated for the train set"
    ),
    label_binarizer_path: Optional[Path] = typer.Argument(
        None, help="path to pickle file that will contain the label binarizer"
    ),
    test_output_path: Optional[str] = typer.Option(
        None, help="path to JSONL output file that will be generated for the test set"
    ),
    mesh_tags_path: Optional[str] = typer.Option(
        None, help="path to mesh tags to filter"
    ),
    test_size: float = typer.Option(0.05, help="Fraction of data to use for testing"),
    num_proc: int = typer.Option(
        8, help="Number of processes to use for preprocessing"
    test_split: Optional[float] = typer.Option(
        0.01, help="split percentage for test data. if None no split."
    ),
    max_samples: int = typer.Option(
        -1,
        help="Maximum number of samples to use for preprocessing",
    filter_years: Optional[str] = typer.Option(
        None, help="years to keep in form min_year,max_year with both inclusive"
    ),
    n_max: Optional[int] = typer.Option(
        None,
        help="""
        Maximum limit on the number of datapoints in the set
        (including training and test)""",
    ),
):
    if max_samples == -1:
        max_samples = np.inf

    preprocess_mesh(
        data_path=data_path,
        save_loc=save_loc,
        model_key=model_key,
        test_size=test_size,
        num_proc=num_proc,
        max_samples=max_samples,
        input_path,
        train_output_path,
        mesh_tags_path=mesh_tags_path,
        filter_years=filter_years,
        n_max=n_max,
    )


if __name__ == "__main__":
    typer.run(preprocess_mesh)
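As a reference, a minimal sketch of calling the reverted preprocess_mesh() directly from Python; the file paths, year range, and n_max value are placeholder assumptions, not values taken from this commit:

from grants_tagger_light.preprocessing.preprocess_mesh import preprocess_mesh

# Stream the raw BioASQ JSON through yield_data and write filtered JSONL batches.
# All paths and filter values below are illustrative.
preprocess_mesh(
    raw_data_path="allMeSH_2021.json",
    processed_data_path="mesh_processed.jsonl",
    mesh_tags_path=None,  # or a CSV with a DescriptorName column to filter tags
    filter_years="2016,2021",  # inclusive min_year,max_year
    n_max=100_000,  # approximate cap on the number of documents written
)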
3 changes: 2 additions & 1 deletion grants_tagger_light/training/dataloaders/__init__.py
@@ -1,3 +1,4 @@
from .mesh_json_loader import load_mesh_json
from .multilabel_collator import MultilabelDataCollator

__all__ = ["MultilabelDataCollator"]
__all__ = ["load_mesh_json", "MultilabelDataCollator"]
94 changes: 94 additions & 0 deletions grants_tagger_light/training/dataloaders/mesh_json_loader.py
@@ -0,0 +1,94 @@
import json
import numpy as np
from transformers import AutoTokenizer
from datasets import Dataset
from loguru import logger

# TODO refactor the two load funcs into a class


def _tokenize(batch, tokenizer: AutoTokenizer, x_col: str):
    return tokenizer(
        batch[x_col],
        padding="max_length",
        truncation=True,
        max_length=512,
    )


def _encode_labels(sample, label2id: dict):
sample["label_ids"] = []
for label in sample["meshMajor"]:
try:
sample["label_ids"].append(label2id[label])
except KeyError:
logger.warning(f"Label {label} not found in label2id")

return sample


def _get_label2id(dset):
    label_set = set()
    for sample in dset:
        for label in sample["meshMajor"]:
            label_set.add(label)
    label2id = {label: idx for idx, label in enumerate(label_set)}
    return label2id


def load_mesh_json(
    data_path: str,
    tokenizer: AutoTokenizer,
    label2id: dict,
    test_size: float = 0.05,
    num_proc: int = 8,
    max_samples: int = np.inf,
):
    def _datagen(mesh_json_path: str, max_samples: int = np.inf):
        with open(mesh_json_path, "r", encoding="latin1") as f:
            for idx, line in enumerate(f):
                # Skip 1st line
                if idx == 0:
                    continue
                sample = json.loads(line[:-2])

                if idx > max_samples:
                    break

                yield sample

    dset = Dataset.from_generator(
        _datagen,
        gen_kwargs={"mesh_json_path": data_path, "max_samples": max_samples},
    )

    # Remove unused columns to save space & time
    dset = dset.remove_columns(["journal", "year", "pmid", "title"])

    dset = dset.map(
        _tokenize,
        batched=True,
        batch_size=32,
        num_proc=num_proc,
        desc="Tokenizing",
        fn_kwargs={"tokenizer": tokenizer, "x_col": "abstractText"},
        remove_columns=["abstractText"],
    )

    # Generate label2id if None
    if label2id is None:
        label2id = _get_label2id(dset)

    dset = dset.map(
        _encode_labels,
        batched=False,
        num_proc=num_proc,
        desc="Encoding labels",
        fn_kwargs={"label2id": label2id},
        remove_columns=["meshMajor"],
    )

    # Split into train and test
    dset = dset.train_test_split(test_size=test_size)

    return dset["train"], dset["test"], label2id
