Feature/load data from hf #44

Merged
merged 2 commits on Feb 23, 2024
46 changes: 28 additions & 18 deletions README.md
@@ -22,9 +22,10 @@ Example of `sklearn.svm.SVC`

```yaml
data:
train: train.jsonl
dev: dev.jsonl
test: test.jsonl
label_column: category
text_column: title
args:
path: data/jsonl

output_dir: models/svm-model

@@ -53,9 +54,10 @@ Example of BERT

```yaml
data:
train: train.jsonl
dev: dev.jsonl
test: test.jsonl
label_column: category
text_column: title
args:
path: data/jsonl

output_dir: models/transformer-model

@@ -96,36 +98,44 @@ You can set data-related settings in `data`.

```yaml
data:
train: train.jsonl # training data
dev: dev.jsonl # development data
test: test.jsonl # test data
label_column: label
text_column: text
label_column: category
text_column: title
args:
# Use local data in `data/jsonl`. This path is assumed to contain data files such as train.jsonl, validation.jsonl, and test.jsonl
path: data/jsonl

# If you want to use data on Hugging Face Hub, use the following args instead.
# Data is from https://huggingface.co/datasets/shunk031/livedoor-news-corpus
# path: shunk031/livedoor-news-corpus
# random_state: 0
# shuffle: true

```

You can set local files in `train`, `dev`, and `test`.
Supported types are `csv`, `json`, and `jsonl`.
Data is loaded via [datasets.load_dataset](https://huggingface.co/docs/datasets/main/en/package_reference/loading_methods#datasets.load_dataset),
so you can load local data as well as data on the [Hugging Face Hub](https://huggingface.co/datasets).
When loading data, the `args` mapping is passed to `load_dataset` as keyword arguments.
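
As a rough sketch, the `args` mapping in the YAML above corresponds to a `load_dataset` call like the following (mirroring the new `load_data` in `bunruija/data/dataset.py` from this PR):

```python
import datasets
from datasets import load_dataset

# Each key under `args` becomes a keyword argument of `load_dataset`,
# so `args: {path: data/jsonl}` is equivalent to:
dataset = load_dataset(path="data/jsonl", split=datasets.Split.TRAIN)

# The commented-out Hub variant above would instead be:
# dataset = load_dataset(
#     path="shunk031/livedoor-news-corpus",
#     random_state=0,
#     shuffle=True,
#     split=datasets.Split.TRAIN,
# )
```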

`label_column` and `text_column` are the field names of the label and the text.
When you set `label_column` to `label` and `text_column` to `text` (the default values), the actual data must look as follows:

Format of `csv`:

```csv
label,text
label_name,sentence
category,sentence
sports,I like sports!
```

Format of `json`:

```json
[{"label", "label_name", "text": "sentence"}]
[{"category", "sports", "text": "I like sports!"}]
```

Format of `jsonl`:

```json
{"label", "label_name", "text": "sentence"}
{"category", "sports", "text": "I like suports!"}
```
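
`text_column` can also be a list of exactly two field names for sentence-pair tasks, in which case each text becomes a pair (see the JNLI example below). A minimal sketch of such a call, assuming local data under `data/jsonl` with JNLI-style `sentence1`/`sentence2` fields and the import path shown in this PR's diff:

```python
from collections import UserDict

import datasets

from bunruija.data.dataset import load_data

# With a two-element text_column, each entry of `texts` is a
# [sentence1, sentence2] pair instead of a single string.
labels, texts = load_data(
    UserDict({"path": "data/jsonl"}),
    split=datasets.Split.TRAIN,
    label_column="label",
    text_column=["sentence1", "sentence2"],
)
```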

### pipeline
54 changes: 25 additions & 29 deletions bunruija/data/dataset.py
@@ -1,5 +1,6 @@
from pathlib import Path
from collections import UserDict

import datasets
from datasets import (
Dataset,
DatasetDict,
@@ -10,43 +11,38 @@


def load_data(
data_path: str | Path,
dataset_args: UserDict,
split: datasets.Split,
label_column: str = "label",
text_column: str | list[str] = "text",
) -> tuple[list[str], list[str] | list[list[str]]]:
if isinstance(data_path, str):
data_path = Path(data_path)
dataset: Dataset | DatasetDict | IterableDataset | IterableDatasetDict = (
load_dataset(split=split, **dataset_args)
)
assert isinstance(dataset, Dataset)

labels: list[str] = []
texts: list[str] | list[list[str]]
texts = [] # type: ignore

if data_path.suffix in [".csv", ".json", ".jsonl"]:
suffix: str = data_path.suffix[1:]
for idx, sample in enumerate(dataset):
label: str

# Because datasets does not support jsonl suffix, convert it to json
if suffix == "jsonl":
suffix = "json"
# If the label feature has a `names` attribute, convert the label id to its actual label string
if hasattr(dataset.features[label_column], "names"):
label = dataset.features[label_column].names[sample[label_column]]
else:
label = sample[label_column]

# When data_files is only a single data_path, data split is "train"
dataset: DatasetDict | Dataset | IterableDataset | IterableDatasetDict = (
load_dataset(suffix, data_files=str(data_path), split="train")
)
assert isinstance(dataset, Dataset)
labels.append(label)

for idx, sample in enumerate(dataset):
labels.append(sample[label_column])
if isinstance(text_column, str):
input_example = sample[text_column]
texts.append(input_example)
elif isinstance(text_column, list):
if len(text_column) != 2:
raise ValueError(f"{len(text_column)=}")

if isinstance(text_column, str):
input_example = sample[text_column]
texts.append(input_example)
elif isinstance(text_column, list):
if len(text_column) != 2:
raise ValueError(f"{len(text_column)=}")

input_example = [sample[text_column[0]], sample[text_column[1]]]
texts.append(input_example)
return labels, texts

else:
raise ValueError(data_path.suffix)
input_example = [sample[text_column[0]], sample[text_column[1]]]
texts.append(input_example)
return labels, texts
20 changes: 9 additions & 11 deletions bunruija/dataclass.py
@@ -1,3 +1,4 @@
from collections import UserDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
@@ -13,32 +14,29 @@ class PipelineUnit:

@dataclass
class DataConfig:
train: Path = field(default_factory=Path)
dev: Path = field(default_factory=Path)
test: Path = field(default_factory=Path)
label_column: str = "label"
text_column: str = "text"

def __post_init__(self):
self.train = Path(self.train)
self.dev = Path(self.dev)
self.test = Path(self.test)
text_column: str | list[str] = "text"


@dataclass
class BunruijaConfig:
data: DataConfig
pipeline: list[PipelineUnit]
output_dir: Path
data: DataConfig | None = None
dataset_args: UserDict | None = None

@classmethod
def from_yaml(cls, config_file):
with open(config_file) as f:
yaml = ruamel.yaml.YAML()
config = yaml.load(f)

label_column: str = config["data"].pop("label_column", "label")
text_column: str | list[str] = config["data"].pop("text_column", "text")

return cls(
data=DataConfig(**config["data"]),
data=DataConfig(label_column=label_column, text_column=text_column),
pipeline=[PipelineUnit(**unit) for unit in config["pipeline"]],
output_dir=Path(config.get("output_dir", "output")),
dataset_args=UserDict(config["data"]["args"]),
)
5 changes: 4 additions & 1 deletion bunruija/evaluator.py
@@ -1,5 +1,6 @@
from argparse import Namespace

import datasets
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix

@@ -18,10 +19,12 @@ def __init__(self, args: Namespace):

def evaluate(self):
labels_test, X_test = load_data(
self.config.data.test,
self.config.dataset_args,
split=datasets.Split.TEST,
label_column=self.config.data.label_column,
text_column=self.config.data.text_column,
)

y_test: np.ndarray = self.predictor.label_encoder.transform(labels_test)
y_pred: np.ndarray = self.predictor(X_test)

8 changes: 5 additions & 3 deletions bunruija/gen_yaml.py
@@ -72,9 +72,11 @@ def main(args):

setting = {
"data": {
"train": "train.csv",
"dev": "dev.csv",
"test": "test.csv",
"label_column": "label",
"text_coliumn": "text",
"args": {
"path": "",
},
},
"pipeline": [
infer_vectorizer(model_cls),
34 changes: 18 additions & 16 deletions bunruija/trainer.py
@@ -1,3 +1,4 @@
import datasets
import numpy as np
import sklearn # type: ignore
from sklearn.preprocessing import LabelEncoder # type: ignore
@@ -17,7 +18,8 @@ def __init__(self, config_file: str):

def train(self):
labels_train, X_train = load_data(
self.config.data.train,
self.config.dataset_args,
split=datasets.Split.TRAIN,
label_column=self.config.data.label_column,
text_column=self.config.data.text_column,
)
@@ -29,21 +31,21 @@ def train(self):

self.saver(self.model, label_encoder)

if self.config.data.dev.exists():
labels_dev, X_dev = load_data(
self.config.data.dev,
label_column=self.config.data.label_column,
text_column=self.config.data.text_column,
)
labels_dev, X_dev = load_data(
self.config.dataset_args,
split=datasets.Split.VALIDATION,
label_column=self.config.data.label_column,
text_column=self.config.data.text_column,
)

y_dev: np.ndarray = label_encoder.transform(labels_dev)
y_dev: np.ndarray = label_encoder.transform(labels_dev)

y_pred = self.model.predict(X_dev)
y_pred = self.model.predict(X_dev)

fscore = sklearn.metrics.f1_score(y_dev, y_pred, average="micro")
print(f"F-score on dev: {fscore}")
target_names: list[str] = list(label_encoder.classes_)
report = sklearn.metrics.classification_report(
y_pred, y_dev, target_names=target_names
)
print(report)
fscore = sklearn.metrics.f1_score(y_dev, y_pred, average="micro")
print(f"F-score on dev: {fscore}")
target_names: list[str] = list(label_encoder.classes_)
report = sklearn.metrics.classification_report(
y_pred, y_dev, target_names=target_names
)
print(report)
8 changes: 5 additions & 3 deletions example/jglue/jcola/create_jcola_data.py
@@ -11,7 +11,7 @@ def write_json(ds: Dataset, name: Path):
for sample in ds:
category: str = ds.features["label"].names[sample["label"]]
sample_ = {
"text": sample["sentence"],
"sentence": sample["sentence"],
"label": category,
}
print(json.dumps(sample_), file=f)
@@ -20,7 +20,9 @@ def write_json(ds: Dataset, name: Path):

def main():
parser = ArgumentParser()
parser.add_argument("--output_dir", default="example/jglue/jcola/data", type=Path)
parser.add_argument(
"--output_dir", default="example/jglue/jcola/data/jsonl", type=Path
)
args = parser.parse_args()

if not args.output_dir.exists():
@@ -29,7 +31,7 @@ def main():
dataset = load_dataset("shunk031/JGLUE", name="JCoLA")

write_json(dataset["train"], args.output_dir / "train.jsonl")
write_json(dataset["validation"], args.output_dir / "dev.jsonl")
write_json(dataset["validation"], args.output_dir / "validation.jsonl")


if __name__ == "__main__":
10 changes: 6 additions & 4 deletions example/jglue/jnli/create_jnli_data.py
@@ -11,8 +11,8 @@ def write_json(ds: Dataset, name: Path):
for sample in ds:
category: str = ds.features["label"].names[sample["label"]]
sample_ = {
"text1": sample["sentence1"],
"text2": sample["sentence2"],
"sentence1": sample["sentence1"],
"sentence2": sample["sentence2"],
"label": category,
}
print(json.dumps(sample_), file=f)
@@ -21,7 +21,9 @@ def write_json(ds: Dataset, name: Path):

def main():
parser = ArgumentParser()
parser.add_argument("--output_dir", default="example/jglue/jnli/data", type=Path)
parser.add_argument(
"--output_dir", default="example/jglue/jnli/data/jsonl", type=Path
)
args = parser.parse_args()

if not args.output_dir.exists():
@@ -31,7 +33,7 @@ def main():
print(dataset)

write_json(dataset["train"], args.output_dir / "train.jsonl")
write_json(dataset["validation"], args.output_dir / "dev.jsonl")
write_json(dataset["validation"], args.output_dir / "validation.jsonl")


if __name__ == "__main__":
8 changes: 5 additions & 3 deletions example/jglue/marc_ja/create_marc_ja_data.py
@@ -11,7 +11,7 @@ def write_json(ds: Dataset, name: Path):
for sample in ds:
category: str = ds.features["label"].names[sample["label"]]
sample_ = {
"text": sample["sentence"],
"sentence": sample["sentence"],
"label": category,
}
print(json.dumps(sample_), file=f)
@@ -20,7 +20,9 @@ def write_json(ds: Dataset, name: Path):

def main():
parser = ArgumentParser()
parser.add_argument("--output_dir", default="example/jglue/jcola/data", type=Path)
parser.add_argument(
"--output_dir", default="example/jglue/marc_ja/data/jsonl", type=Path
)
args = parser.parse_args()

if not args.output_dir.exists():
Expand All @@ -29,7 +31,7 @@ def main():
dataset = load_dataset("shunk031/JGLUE", name="MARC-ja")

write_json(dataset["train"], args.output_dir / "train.jsonl")
write_json(dataset["validation"], args.output_dir / "dev.jsonl")
write_json(dataset["validation"], args.output_dir / "validation.jsonl")


if __name__ == "__main__":
Expand Down
8 changes: 6 additions & 2 deletions example/jglue/settings/classification/svm.yaml
@@ -1,6 +1,10 @@
data:
train: data/train.jsonl
dev: data/dev.jsonl
label_column: label
text_column: sentence
args:
path: shunk031/JGLUE
name: JCoLA
# path: data/jsonl

output_dir: models/svm-model
