From 55a85bd38c97d7b374a663459f99a93302d22e79 Mon Sep 17 00:00:00 2001
From: Takuya Makino
Date: Fri, 23 Feb 2024 16:36:08 +0900
Subject: [PATCH 1/2] Enable loading dataset from Hugging Face Hub

---
 README.md                                     | 46 +++++++++-------
 bunruija/data/dataset.py                      | 54 +++++++++----------
 bunruija/dataclass.py                         | 20 ++++---
 bunruija/evaluator.py                         |  5 +-
 bunruija/gen_yaml.py                          |  8 +--
 bunruija/trainer.py                           | 34 ++++++------
 example/jglue/jcola/create_jcola_data.py      |  8 +--
 example/jglue/jnli/create_jnli_data.py        | 10 ++--
 example/jglue/marc_ja/create_marc_ja_data.py  |  8 +--
 .../jglue/settings/classification/svm.yaml    |  8 ++-
 .../sentence_pair_classification/lstm.yaml    |  9 ++--
 .../transformer.yaml                          | 11 ++--
 example/livedoor_corpus/create_data.py        | 27 ++++------
 example/livedoor_corpus/settings/lgbm.yaml    | 10 ++--
 example/livedoor_corpus/settings/lstm.yaml    | 10 ++--
 example/livedoor_corpus/settings/prado.yaml   | 10 ++--
 example/livedoor_corpus/settings/qrnn.yaml    | 10 ++--
 .../livedoor_corpus/settings/stacking.yaml    | 10 ++--
 example/livedoor_corpus/settings/svm.yaml     | 10 ++--
 .../livedoor_corpus/settings/transformer.yaml | 10 ++--
 tests/data/test_dataset.py                    | 44 +++++----------
 tests/test_binary.py                          | 10 ++--
 22 files changed, 200 insertions(+), 172 deletions(-)

diff --git a/README.md b/README.md
index 8338c63..abb3db8 100755
--- a/README.md
+++ b/README.md
@@ -22,9 +22,10 @@ Example of `sklearn.svm.SVC`
 
 ```yaml
 data:
-  train: train.jsonl
-  dev: dev.jsonl
-  test: test.jsonl
+  label_column: category
+  text_column: title
+  args:
+    path: data/jsonl
 
 output_dir: models/svm-model
 
@@ -53,9 +54,10 @@ Example of BERT
 
 ```yaml
 data:
-  train: train.jsonl
-  dev: dev.jsonl
-  test: test.jsonl
+  label_column: category
+  text_column: title
+  args:
+    path: data/jsonl
 
 output_dir: models/transformer-model
 
@@ -96,36 +98,44 @@ You can set data-related settings in `data`.
 
 ```yaml
 data:
-  train: train.jsonl # training data
-  dev: dev.jsonl # development data
-  test: test.jsonl # test data
-  label_column: label
-  text_column: text
+  label_column: category
+  text_column: title
+  args:
+    # Use local data in `data/jsonl`. This path is assumed to contain data files such as train.jsonl, validation.jsonl, and test.jsonl.
+    path: data/jsonl
+
+    # If you want to use data on the Hugging Face Hub, use the following args instead.
+    # Data is from https://huggingface.co/datasets/shunk031/livedoor-news-corpus
+    # path: shunk031/livedoor-news-corpus
+    # random_state: 0
+    # shuffle: true
+
 ```
 
-You can set local files in `train`, `dev`, and `test`.
-Supported types are `csv`, `json` and `jsonl`.
+Data is loaded via [datasets.load_dataset](https://huggingface.co/docs/datasets/main/en/package_reference/loading_methods#datasets.load_dataset),
+so you can load local data as well as datasets on the [Hugging Face Hub](https://huggingface.co/datasets).
+When loading data, `args` is passed to `load_dataset`.
+
 `label_column` and `text_column` are field names of label and text.
-When you set `label_column` to `label` and `text_column` to `text`, which are the default values, actual data must be as follows:
 
 Format of `csv`:
 
 ```csv
-label,text
-label_name,sentence
+category,title
+sports,I like sports!
 …
 ```
 
 Format of `json`:
 
 ```json
-[{"label", "label_name", "text": "sentence"}]
+[{"category": "sports", "title": "I like sports!"}]
 ```
 
 Format of `jsonl`:
 
 ```json
-{"label", "label_name", "text": "sentence"}
+{"category": "sports", "title": "I like sports!"}
 ```
 
 ### pipeline
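For reference, a minimal sketch of what the new `data.args` block resolves to at runtime. The directory layout and the Hub dataset name are taken from the README example above; the `category`/`title` field names are the ones configured there:

```python
from datasets import load_dataset

# Local loading: `path` points at a directory that is assumed to contain
# train.jsonl / validation.jsonl / test.jsonl, as the README describes.
train = load_dataset(path="data/jsonl", split="train")

# Hub loading: the commented-out args above would instead be forwarded as
# load_dataset(path="shunk031/livedoor-news-corpus",
#              random_state=0, shuffle=True, split="train")

for sample in train:
    # label_column / text_column name fields of each sample
    print(sample["category"], sample["title"])
    break
```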
diff --git a/bunruija/data/dataset.py b/bunruija/data/dataset.py
index e750cf7..cded57c 100644
--- a/bunruija/data/dataset.py
+++ b/bunruija/data/dataset.py
@@ -1,5 +1,6 @@
-from pathlib import Path
+from collections import UserDict
 
+import datasets
 from datasets import (
     Dataset,
     DatasetDict,
@@ -10,43 +11,38 @@ def load_data(
-    data_path: str | Path,
+    dataset_args: UserDict,
+    split: datasets.Split,
     label_column: str = "label",
     text_column: str | list[str] = "text",
 ) -> tuple[list[str], list[str] | list[list[str]]]:
-    if isinstance(data_path, str):
-        data_path = Path(data_path)
+    dataset: Dataset | DatasetDict | IterableDataset | IterableDatasetDict = (
+        load_dataset(split=split, **dataset_args)
+    )
+    assert isinstance(dataset, Dataset)
 
     labels: list[str] = []
     texts: list[str] | list[list[str]]
     texts = []  # type: ignore
 
-    if data_path.suffix in [".csv", ".json", ".jsonl"]:
-        suffix: str = data_path.suffix[1:]
+    for idx, sample in enumerate(dataset):
+        label: str
 
-        # Because datasets does not support jsonl suffix, convert it to json
-        if suffix == "jsonl":
-            suffix = "json"
+        # If the label feature has a `names` attribute (e.g. ClassLabel), map the label id to its label string
+        if hasattr(dataset.features[label_column], "names"):
+            label = dataset.features[label_column].names[sample[label_column]]
+        else:
+            label = sample[label_column]
 
-        # When data_files is only a single data_path, data split is "train"
-        dataset: DatasetDict | Dataset | IterableDataset | IterableDatasetDict = (
-            load_dataset(suffix, data_files=str(data_path), split="train")
-        )
-        assert isinstance(dataset, Dataset)
+        labels.append(label)
 
-        for idx, sample in enumerate(dataset):
-            labels.append(sample[label_column])
+        if isinstance(text_column, str):
+            input_example = sample[text_column]
+            texts.append(input_example)
+        elif isinstance(text_column, list):
+            if len(text_column) != 2:
+                raise ValueError(f"{len(text_column)=}")
 
-            if isinstance(text_column, str):
-                input_example = sample[text_column]
-                texts.append(input_example)
-            elif isinstance(text_column, list):
-                if len(text_column) != 2:
-                    raise ValueError(f"{len(text_column)=}")
-
-                input_example = [sample[text_column[0]], sample[text_column[1]]]
-                texts.append(input_example)
-        return labels, texts
-
-    else:
-        raise ValueError(data_path.suffix)
+            input_example = [sample[text_column[0]], sample[text_column[1]]]
+            texts.append(input_example)
+    return labels, texts
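A usage sketch of the revised `load_data` signature, mirroring the call sites in `trainer.py` and the tests further down; the local path is illustrative:

```python
import datasets

from bunruija.data.dataset import load_data

# `dataset_args` mirrors the `data.args` block of a YAML setting.
labels, texts = load_data(
    dataset_args={"path": "data/jsonl"},  # illustrative local path
    split=datasets.Split.TRAIN,
    label_column="category",
    text_column="title",
)
assert len(labels) == len(texts)
```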
diff --git a/bunruija/dataclass.py b/bunruija/dataclass.py
index 742aabb..a0d5f95 100644
--- a/bunruija/dataclass.py
+++ b/bunruija/dataclass.py
@@ -1,3 +1,4 @@
+from collections import UserDict
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any
@@ -13,23 +14,16 @@ class PipelineUnit:
 
 @dataclass
 class DataConfig:
-    train: Path = field(default_factory=Path)
-    dev: Path = field(default_factory=Path)
-    test: Path = field(default_factory=Path)
     label_column: str = "label"
-    text_column: str = "text"
-
-    def __post_init__(self):
-        self.train = Path(self.train)
-        self.dev = Path(self.dev)
-        self.test = Path(self.test)
+    text_column: str | list[str] = "text"
 
 
 @dataclass
 class BunruijaConfig:
-    data: DataConfig
     pipeline: list[PipelineUnit]
     output_dir: Path
+    data: DataConfig | None = None
+    dataset_args: UserDict | None = None
 
     @classmethod
     def from_yaml(cls, config_file):
@@ -37,8 +31,12 @@ def from_yaml(cls, config_file):
             yaml = ruamel.yaml.YAML()
             config = yaml.load(f)
 
+        label_column: str = config["data"].pop("label_column", "label")
+        text_column: str | list[str] = config["data"].pop("text_column", "text")
+
         return cls(
-            data=DataConfig(**config["data"]),
+            data=DataConfig(label_column=label_column, text_column=text_column),
            pipeline=[PipelineUnit(**unit) for unit in config["pipeline"]],
             output_dir=Path(config.get("output_dir", "output")),
+            dataset_args=UserDict(config["data"]["args"]),
         )
diff --git a/bunruija/evaluator.py b/bunruija/evaluator.py
index fde21cd..e1af83d 100755
--- a/bunruija/evaluator.py
+++ b/bunruija/evaluator.py
@@ -1,5 +1,6 @@
 from argparse import Namespace
 
+import datasets
 import numpy as np
 from sklearn.metrics import classification_report, confusion_matrix
 
@@ -18,10 +19,12 @@ def __init__(self, args: Namespace):
 
     def evaluate(self):
         labels_test, X_test = load_data(
-            self.config.data.test,
+            self.config.dataset_args,
+            split=datasets.Split.TEST,
             label_column=self.config.data.label_column,
             text_column=self.config.data.text_column,
         )
+
         y_test: np.ndarray = self.predictor.label_encoder.transform(labels_test)
         y_pred: np.ndarray = self.predictor(X_test)
diff --git a/bunruija/gen_yaml.py b/bunruija/gen_yaml.py
index a60f1cf..4770b91 100755
--- a/bunruija/gen_yaml.py
+++ b/bunruija/gen_yaml.py
@@ -72,9 +72,11 @@ def main(args):
 
     setting = {
         "data": {
-            "train": "train.csv",
-            "dev": "dev.csv",
-            "test": "test.csv",
+            "label_column": "label",
+            "text_column": "text",
+            "args": {
+                "path": "",
+            },
         },
         "pipeline": [
             infer_vectorizer(model_cls),
diff --git a/bunruija/trainer.py b/bunruija/trainer.py
index 2782788..344ce91 100755
--- a/bunruija/trainer.py
+++ b/bunruija/trainer.py
@@ -1,3 +1,4 @@
+import datasets
 import numpy as np
 import sklearn  # type: ignore
 from sklearn.preprocessing import LabelEncoder  # type: ignore
@@ -17,7 +18,8 @@ def __init__(self, config_file: str):
 
     def train(self):
         labels_train, X_train = load_data(
-            self.config.data.train,
+            self.config.dataset_args,
+            split=datasets.Split.TRAIN,
             label_column=self.config.data.label_column,
             text_column=self.config.data.text_column,
         )
@@ -29,21 +31,21 @@ def train(self):
 
         self.saver(self.model, label_encoder)
 
-        if self.config.data.dev.exists():
-            labels_dev, X_dev = load_data(
-                self.config.data.dev,
-                label_column=self.config.data.label_column,
-                text_column=self.config.data.text_column,
-            )
+        labels_dev, X_dev = load_data(
+            self.config.dataset_args,
+            split=datasets.Split.VALIDATION,
+            label_column=self.config.data.label_column,
+            text_column=self.config.data.text_column,
+        )
 
-            y_dev: np.ndarray = label_encoder.transform(labels_dev)
+        y_dev: np.ndarray = label_encoder.transform(labels_dev)
 
-            y_pred = self.model.predict(X_dev)
+        y_pred = self.model.predict(X_dev)
 
-            fscore = sklearn.metrics.f1_score(y_dev, y_pred, average="micro")
-            print(f"F-score on dev: {fscore}")
-            target_names: list[str] = list(label_encoder.classes_)
-            report = sklearn.metrics.classification_report(
-                y_pred, y_dev, target_names=target_names
-            )
-            print(report)
+        fscore = sklearn.metrics.f1_score(y_dev, y_pred, average="micro")
+        print(f"F-score on dev: {fscore}")
+        target_names: list[str] = list(label_encoder.classes_)
+        report = sklearn.metrics.classification_report(
+            y_dev, y_pred, target_names=target_names
+        )
+        print(report)
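To see how the reworked `BunruijaConfig.from_yaml` splits the `data` section into `DataConfig` and `dataset_args`, a small sketch; the settings path is illustrative:

```python
from bunruija.dataclass import BunruijaConfig

# e.g. example/livedoor_corpus/settings/svm.yaml after this patch
config = BunruijaConfig.from_yaml("settings/svm.yaml")

print(config.data.label_column)   # "category"
print(config.data.text_column)    # "title"
print(dict(config.dataset_args))  # {"path": "data/jsonl"} plus any other load_dataset kwargs
```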
diff --git a/example/jglue/jcola/create_jcola_data.py b/example/jglue/jcola/create_jcola_data.py
index 80a04fe..7d7c2e8 100755
--- a/example/jglue/jcola/create_jcola_data.py
+++ b/example/jglue/jcola/create_jcola_data.py
@@ -11,7 +11,7 @@ def write_json(ds: Dataset, name: Path):
     for sample in ds:
         category: str = ds.features["label"].names[sample["label"]]
         sample_ = {
-            "text": sample["sentence"],
+            "sentence": sample["sentence"],
             "label": category,
         }
         print(json.dumps(sample_), file=f)
@@ -20,7 +20,9 @@ def write_json(ds: Dataset, name: Path):
 
 def main():
     parser = ArgumentParser()
-    parser.add_argument("--output_dir", default="example/jglue/jcola/data", type=Path)
+    parser.add_argument(
+        "--output_dir", default="example/jglue/jcola/data/jsonl", type=Path
+    )
     args = parser.parse_args()
 
     if not args.output_dir.exists():
@@ -29,7 +31,7 @@ def main():
     dataset = load_dataset("shunk031/JGLUE", name="JCoLA")
 
     write_json(dataset["train"], args.output_dir / "train.jsonl")
-    write_json(dataset["validation"], args.output_dir / "dev.jsonl")
+    write_json(dataset["validation"], args.output_dir / "validation.jsonl")
 
 
 if __name__ == "__main__":
diff --git a/example/jglue/jnli/create_jnli_data.py b/example/jglue/jnli/create_jnli_data.py
index 555814f..b49d388 100755
--- a/example/jglue/jnli/create_jnli_data.py
+++ b/example/jglue/jnli/create_jnli_data.py
@@ -11,8 +11,8 @@ def write_json(ds: Dataset, name: Path):
     for sample in ds:
         category: str = ds.features["label"].names[sample["label"]]
         sample_ = {
-            "text1": sample["sentence1"],
-            "text2": sample["sentence2"],
+            "sentence1": sample["sentence1"],
+            "sentence2": sample["sentence2"],
             "label": category,
         }
         print(json.dumps(sample_), file=f)
@@ -21,7 +21,9 @@ def write_json(ds: Dataset, name: Path):
 
 def main():
     parser = ArgumentParser()
-    parser.add_argument("--output_dir", default="example/jglue/jnli/data", type=Path)
+    parser.add_argument(
+        "--output_dir", default="example/jglue/jnli/data/jsonl", type=Path
+    )
     args = parser.parse_args()
 
     if not args.output_dir.exists():
@@ -31,7 +33,7 @@ def main():
     print(dataset)
 
     write_json(dataset["train"], args.output_dir / "train.jsonl")
-    write_json(dataset["validation"], args.output_dir / "dev.jsonl")
+    write_json(dataset["validation"], args.output_dir / "validation.jsonl")
 
 
 if __name__ == "__main__":
diff --git a/example/jglue/marc_ja/create_marc_ja_data.py b/example/jglue/marc_ja/create_marc_ja_data.py
index 2046fe6..ff6d834 100755
--- a/example/jglue/marc_ja/create_marc_ja_data.py
+++ b/example/jglue/marc_ja/create_marc_ja_data.py
@@ -11,7 +11,7 @@ def write_json(ds: Dataset, name: Path):
     for sample in ds:
         category: str = ds.features["label"].names[sample["label"]]
         sample_ = {
-            "text": sample["sentence"],
+            "sentence": sample["sentence"],
             "label": category,
         }
         print(json.dumps(sample_), file=f)
@@ -20,7 +20,9 @@ def write_json(ds: Dataset, name: Path):
 
 def main():
     parser = ArgumentParser()
-    parser.add_argument("--output_dir", default="example/jglue/jcola/data", type=Path)
+    parser.add_argument(
+        "--output_dir", default="example/jglue/marc_ja/data/jsonl", type=Path
+    )
     args = parser.parse_args()
 
     if not args.output_dir.exists():
@@ -29,7 +31,7 @@ def main():
     dataset = load_dataset("shunk031/JGLUE", name="MARC-ja")
 
     write_json(dataset["train"], args.output_dir / "train.jsonl")
-    write_json(dataset["validation"], args.output_dir / "dev.jsonl")
+    write_json(dataset["validation"], args.output_dir / "validation.jsonl")
 
 
 if __name__ == "__main__":
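The JCoLA setting below points `args` straight at the Hub with a dataset config name; a sketch of the equivalent direct call, using the same dataset id that `create_jcola_data.py` loads above:

```python
from datasets import load_dataset

# Mirrors `args: {path: shunk031/JGLUE, name: JCoLA}` from the setting below.
train = load_dataset(path="shunk031/JGLUE", name="JCoLA", split="train")

# The label feature is a ClassLabel, so load_data can map label ids back to
# strings via its `names` attribute, as done in bunruija/data/dataset.py above.
print(train.features["label"].names)
```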
diff --git a/example/jglue/settings/classification/svm.yaml b/example/jglue/settings/classification/svm.yaml
index 8eda254..155313b 100755
--- a/example/jglue/settings/classification/svm.yaml
+++ b/example/jglue/settings/classification/svm.yaml
@@ -1,6 +1,10 @@
 data:
-  train: data/train.jsonl
-  dev: data/dev.jsonl
+  label_column: label
+  text_column: sentence
+  args:
+    path: shunk031/JGLUE
+    name: JCoLA
+    # path: data/jsonl
 
 output_dir: models/svm-model
diff --git a/example/jglue/settings/sentence_pair_classification/lstm.yaml b/example/jglue/settings/sentence_pair_classification/lstm.yaml
index 8fd391a..712a3c3 100755
--- a/example/jglue/settings/sentence_pair_classification/lstm.yaml
+++ b/example/jglue/settings/sentence_pair_classification/lstm.yaml
@@ -1,9 +1,10 @@
 data:
-  train: data/train.jsonl
-  dev: data/dev.jsonl
+  label_column: label
   text_column:
-    - text1
-    - text2
+    - sentence1
+    - sentence2
+  args:
+    path: data/jsonl
 
 output_dir: models/lstm-model
diff --git a/example/jglue/settings/sentence_pair_classification/transformer.yaml b/example/jglue/settings/sentence_pair_classification/transformer.yaml
index b5268f6..3a3950a 100755
--- a/example/jglue/settings/sentence_pair_classification/transformer.yaml
+++ b/example/jglue/settings/sentence_pair_classification/transformer.yaml
@@ -1,11 +1,12 @@
 data:
-  train: data/train.jsonl
-  dev: data/dev.jsonl
+  label_column: label
   text_column:
-    - text1
-    - text2
+    - sentence1
+    - sentence2
+  args:
+    path: data/jsonl
 
-output_dir: models/lstm-model
+output_dir: models/transformer-model
 
 pipeline:
   - type: bunruija.feature_extraction.SequencePairVectorizer
diff --git a/example/livedoor_corpus/create_data.py b/example/livedoor_corpus/create_data.py
index 4d4e058..a9160c1 100755
--- a/example/livedoor_corpus/create_data.py
+++ b/example/livedoor_corpus/create_data.py
@@ -1,4 +1,3 @@
-import csv
 import json
 from argparse import ArgumentParser
 from pathlib import Path
@@ -7,23 +6,13 @@
 from loguru import logger  # type: ignore
 
 
-def write_csv(ds: Dataset, name: Path):
-    with open(name, "w") as f:
-        writer = csv.writer(f)
-        writer.writerow(["label", "text"])
-        for sample in ds:
-            category: str = ds.features["category"].names[sample["category"]]
-            writer.writerow([f"{category}", sample["title"]])
-    logger.info(f"{name}")
-
-
 def write_json(ds: Dataset, name: Path):
     with open(name, "w") as f:
         for sample in ds:
             category: str = ds.features["category"].names[sample["category"]]
             sample_ = {
-                "text": sample["title"],
-                "label": category,
+                "title": sample["title"],
+                "category": category,
             }
             print(json.dumps(sample_), file=f)
     logger.info(f"{name}")
@@ -31,7 +20,9 @@ def write_json(ds: Dataset, name: Path):
 
 def main():
     parser = ArgumentParser()
-    parser.add_argument("--output_dir", default="example/livedoor_corpus", type=Path)
+    parser.add_argument(
+        "--output_dir", default="example/livedoor_corpus/data/jsonl", type=Path
+    )
     args = parser.parse_args()
 
     dataset = load_dataset(
@@ -39,12 +30,12 @@ def main():
         random_state=0,
         shuffle=True,
     )
-    write_csv(dataset["train"], args.output_dir / "train.csv")
-    write_csv(dataset["validation"], args.output_dir / "dev.csv")
-    write_csv(dataset["test"], args.output_dir / "test.csv")
+
+    if not args.output_dir.exists():
+        args.output_dir.mkdir(parents=True)
 
     write_json(dataset["train"], args.output_dir / "train.jsonl")
-    write_json(dataset["validation"], args.output_dir / "dev.jsonl")
+    write_json(dataset["validation"], args.output_dir / "validation.jsonl")
     write_json(dataset["test"], args.output_dir / "test.jsonl")
diff --git a/example/livedoor_corpus/settings/lgbm.yaml b/example/livedoor_corpus/settings/lgbm.yaml
index c0277f5..90f2f2a 100755
--- a/example/livedoor_corpus/settings/lgbm.yaml
+++ b/example/livedoor_corpus/settings/lgbm.yaml
@@ -1,7 +1,11 @@
 data:
-  train: train.csv
-  dev: dev.csv
-  test: test.csv
+  label_column: category
+  text_column: title
+  args:
+    # path: shunk031/livedoor-news-corpus
+    # random_state: 0
+    # shuffle: true
+    path: data/jsonl
 
 output_dir: models/lgb-model
diff --git a/example/livedoor_corpus/settings/lstm.yaml b/example/livedoor_corpus/settings/lstm.yaml
index c79588a..9f769f0 100755
--- a/example/livedoor_corpus/settings/lstm.yaml
+++ b/example/livedoor_corpus/settings/lstm.yaml
@@ -1,7 +1,11 @@
 data:
-  train: train.csv
-  dev: dev.csv
-  test: test.csv
+  label_column: category
+  text_column: title
+  args:
+    # path: shunk031/livedoor-news-corpus
+    # random_state: 0
+    # shuffle: true
+    path: data/jsonl
 
 output_dir: models/lstm-model
diff --git a/example/livedoor_corpus/settings/prado.yaml b/example/livedoor_corpus/settings/prado.yaml
index 7f41a55..89456b5 100755
--- a/example/livedoor_corpus/settings/prado.yaml
+++ b/example/livedoor_corpus/settings/prado.yaml
@@ -1,7 +1,11 @@
 data:
-  train: train.csv
-  dev: dev.csv
-  test: test.csv
+  label_column: category
+  text_column: title
+  args:
+    # path: shunk031/livedoor-news-corpus
+    # random_state: 0
+    # shuffle: true
+    path: data/jsonl
 
 output_dir: models/prado-model
diff --git a/example/livedoor_corpus/settings/qrnn.yaml b/example/livedoor_corpus/settings/qrnn.yaml
index 82929eb..7505938 100755
--- a/example/livedoor_corpus/settings/qrnn.yaml
+++ b/example/livedoor_corpus/settings/qrnn.yaml
@@ -1,7 +1,11 @@
 data:
-  train: train.csv
-  dev: dev.csv
-  test: test.csv
+  label_column: category
+  text_column: title
+  args:
+    # path: shunk031/livedoor-news-corpus
+    # random_state: 0
+    # shuffle: true
+    path: data/jsonl
 
 output_dir: models/qrnn-model
diff --git a/example/livedoor_corpus/settings/stacking.yaml b/example/livedoor_corpus/settings/stacking.yaml
index d31b494..bb77dec 100755
--- a/example/livedoor_corpus/settings/stacking.yaml
+++ b/example/livedoor_corpus/settings/stacking.yaml
@@ -1,7 +1,11 @@
 data:
-  train: train.csv
-  dev: dev.csv
-  test: test.csv
+  label_column: category
+  text_column: title
+  args:
+    # path: shunk031/livedoor-news-corpus
+    # random_state: 0
+    # shuffle: true
+    path: data/jsonl
 
 output_dir: models/stacking-model
diff --git a/example/livedoor_corpus/settings/svm.yaml b/example/livedoor_corpus/settings/svm.yaml
index 2abe261..62ce0af 100755
--- a/example/livedoor_corpus/settings/svm.yaml
+++ b/example/livedoor_corpus/settings/svm.yaml
@@ -1,7 +1,11 @@
 data:
-  train: train.csv
-  dev: dev.csv
-  test: test.csv
+  label_column: category
+  text_column: title
+  args:
+    # path: shunk031/livedoor-news-corpus
+    # random_state: 0
+    # shuffle: true
+    path: data/jsonl
 
 output_dir: models/svm-model
diff --git a/example/livedoor_corpus/settings/transformer.yaml b/example/livedoor_corpus/settings/transformer.yaml
index 6a03ac9..ddc5302 100755
--- a/example/livedoor_corpus/settings/transformer.yaml
+++ b/example/livedoor_corpus/settings/transformer.yaml
@@ -1,7 +1,11 @@
 data:
-  train: train.csv
-  dev: dev.csv
-  test: test.csv
+  label_column: category
+  text_column: title
+  args:
+    # path: shunk031/livedoor-news-corpus
+    # random_state: 0
+    # shuffle: true
+    path: data/jsonl
 
 output_dir: models/transformer-model
diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py
index fd060c8..7589a74 100644
--- a/tests/data/test_dataset.py
+++ b/tests/data/test_dataset.py
@@ -3,6 +3,7 @@
 import tempfile
 from pathlib import Path
 
+import datasets
 import pytest
 
 from bunruija.data.dataset import load_data
@@ -55,38 +56,21 @@ def test_load_data(suffix):
         },
     ]
 
-    with tempfile.TemporaryDirectory("test_load_data}") as data_dir:
-        data_file = Path(data_dir) / ("sample." + suffix)
-        create_dummy_data(samples, data_file, "label", "text")
-        labels, texts = load_data(data_file)
+    with tempfile.TemporaryDirectory("test_load_data") as data_dir:
+        data_dir = Path(data_dir) / suffix
+        if not data_dir.exists():
+            data_dir.mkdir(parents=True)
 
-        assert labels == [sample["label"] for sample in samples]
-        assert texts == [sample["text"] for sample in samples]
+        for split in ["train", "validation", "test"]:
+            data_file = data_dir / (f"{split}." + suffix)
+            create_dummy_data(samples, data_file, "label", "text")
 
-
-@pytest.mark.parametrize("suffix", ["csv", "jsonl", "json"])
-def test_load_data_2(suffix):
-    samples = [
-        {
-            "category": "A",
-            "sample": "text 1",
-        },
-        {
-            "category": "B",
-            "sample": "text 2",
-        },
-        {
-            "category": "C",
-            "sample": "text 3",
-        },
-    ]
-
-    with tempfile.TemporaryDirectory("test_load_data}") as data_dir:
-        data_file = Path(data_dir) / ("sample." + suffix)
-        create_dummy_data(samples, data_file, "category", "sample")
         labels, texts = load_data(
-            data_file, label_column="category", text_column="sample"
+            dataset_args={"path": str(data_dir)},
+            label_column="label",
+            text_column="text",
+            split=datasets.Split.TRAIN,
         )
 
-        assert labels == [sample["category"] for sample in samples]
-        assert texts == [sample["sample"] for sample in samples]
+        assert labels == [sample["label"] for sample in samples]
+        assert texts == [sample["text"] for sample in samples]
diff --git a/tests/test_binary.py b/tests/test_binary.py
index fe40649..757720e 100755
--- a/tests/test_binary.py
+++ b/tests/test_binary.py
@@ -46,7 +46,7 @@ def _create_dummy_data(filename):
             writer.writerow([label, sample_str])
 
     _create_dummy_data("train.csv")
-    _create_dummy_data("dev.csv")
+    _create_dummy_data("validation.csv")
     _create_dummy_data("test.csv")
 
 
@@ -56,9 +56,7 @@ def rewrite_data_path(self, data_dir, yaml_file):
         with open(yaml_file, "r") as f:
             setting = yaml.load(f)
 
-        setting["data"]["train"] = str(Path(data_dir) / "train.csv")
-        setting["data"]["dev"] = str(Path(data_dir) / "dev.csv")
-        setting["data"]["test"] = str(Path(data_dir) / "test.csv")
+        setting["data"]["args"]["path"] = str(data_dir)
 
         setting["output_dir"] = str(Path(data_dir) / "output_dir")
 
@@ -66,6 +64,10 @@ def rewrite_data_path(self, data_dir, yaml_file):
 
     def execute(self, model):
         with tempfile.TemporaryDirectory(f"test_{model}") as data_dir:
+            data_dir = Path(data_dir) / "csv"
+            if not data_dir.exists():
+                data_dir.mkdir(parents=True)
+
             create_dummy_data(data_dir)
 
             yaml_file = str(Path(data_dir) / "test-binary.yaml")

From 6c8f2bddda5ef6df52364bf595a91d73daedce8b Mon Sep 17 00:00:00 2001
From: Takuya Makino
Date: Fri, 23 Feb 2024 16:36:35 +0900
Subject: [PATCH 2/2] Bump up to 0.2.0

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 48e11df..598ca70 100755
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "bunruija"
-version = "0.1.0"
+version = "0.2.0"
 description = "A text classification toolkit"
 authors = ["Takuya Makino "]
 homepage = "https://github.com/tma15"