Skip to content

Commit

Permalink
Remove unnecessary columns in the dataset processor (#279)
Browse files Browse the repository at this point in the history
* Remove unnecessary columns in the dataset processor

* Modify dataset.map to achieve this

* Add numerical column test

* Add test case to test concatenating columns of different types

* Improve docstring and test name

---------

Co-authored-by: Eren Chenyang Zhao <[email protected]>
  • Loading branch information
viswavi and Eren Chenyang Zhao authored Aug 24, 2023
1 parent 6006d1e commit ce4626f
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 33 deletions.
2 changes: 1 addition & 1 deletion prompt2model/dataset_processor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def filter_empty_strings(example: dict) -> bool:
modified_dataset_dict[dataset_split] = (
dataset_dict[dataset_split]
.filter(filter_empty_strings)
.map(mapping_function)
.map(mapping_function, remove_columns=["input_col", "output_col"])
)
modified_dataset_dict = datasets.DatasetDict(modified_dataset_dict)
modified_dataset_dicts.append(modified_dataset_dict)
Expand Down
114 changes: 82 additions & 32 deletions tests/dataset_processor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,6 @@ def test_dataset_processor_t5_style():
"<task 0>convert to text2text\nExample:\nfoo\nLabel:\n",
"<task 0>convert to text2text\nExample:\nbar\nLabel:\n",
],
"input_col": ["foo", "bar"],
"output_col": ["baz", "qux"],
"model_output": ["baz", "qux"],
}
),
Expand All @@ -149,8 +147,6 @@ def test_dataset_processor_t5_style():
"<task 0>convert to text2text\nExample:\nfoo\nLabel:\n",
"<task 0>convert to text2text\nExample:\nbar\nLabel:\n",
],
"input_col": ["foo", "bar"],
"output_col": ["baz", "qux"],
"model_output": ["baz", "qux"],
}
),
Expand All @@ -164,8 +160,6 @@ def test_dataset_processor_t5_style():
"<task 1>convert to text2text\nExample:\nspam\nLabel:\n",
"<task 1>convert to text2text\nExample:\neggs\nLabel:\n",
],
"input_col": ["spam", "eggs"],
"output_col": ["ham", "sau"],
"model_output": ["ham", "sau"],
}
),
Expand All @@ -175,8 +169,6 @@ def test_dataset_processor_t5_style():
"<task 1>convert to text2text\nExample:\nspam\nLabel:\n",
"<task 1>convert to text2text\nExample:\neggs\nLabel:\n",
],
"input_col": ["spam", "eggs"],
"output_col": ["ham", "sau"],
"model_output": ["ham", "sau"],
}
),
Expand All @@ -188,6 +180,88 @@ def test_dataset_processor_t5_style():
gc.collect()


def test_dataset_processor_with_numerical_column():
"""Test process_dataset_dict with numerical column values."""
t5_processor = TextualizeProcessor(has_encoder=True)
raw_dataset_dicts = [
datasets.DatasetDict(
{
"train": datasets.Dataset.from_dict(
{
"input_col": ["foo", "bar"],
"output_col": ["baz", "qux"],
}
),
"test": datasets.Dataset.from_dict(
{
"input_col": ["spam", "eggs"],
"output_col": ["ham", "sau"],
}
),
}
),
datasets.DatasetDict(
{
"train": datasets.Dataset.from_dict(
{
"input_col": ["foo", "bar"],
"output_col": [0, 1],
}
),
"test": datasets.Dataset.from_dict(
{
"input_col": ["spam", "eggs"],
"output_col": [1, 2],
}
),
}
),
]
t5_modified_dataset_dicts = t5_processor.process_dataset_dict(
INSTRUCTION, raw_dataset_dicts
)
expected_dataset_dict = datasets.DatasetDict(
{
"train": datasets.Dataset.from_dict(
{
"model_input": [
"<task 0>convert to text2text\nExample:\nfoo\nLabel:\n",
"<task 0>convert to text2text\nExample:\nbar\nLabel:\n",
"<task 1>convert to text2text\nExample:\nfoo\nLabel:\n",
"<task 1>convert to text2text\nExample:\nbar\nLabel:\n",
],
"model_output": ["foo", "bar", "0", "1"],
}
),
"test": datasets.Dataset.from_dict(
{
"model_input": [
"<task 0>convert to text2text\nExample:\nspam\nLabel:\n",
"<task 0>convert to text2text\nExample:\neggs\nLabel:\n",
"<task 1>convert to text2text\nExample:\nspam\nLabel:\n",
"<task 1>convert to text2text\nExample:\neggs\nLabel:\n",
],
"model_output": ["ham", "sau", "1", "2"],
}
),
}
)
training_datasets = []
test_datasets = []
for modified_dataset_dict in t5_modified_dataset_dicts:
training_datasets.append(modified_dataset_dict["train"])
test_datasets.append(modified_dataset_dict["test"])

concatenated_training_dataset = datasets.concatenate_datasets(training_datasets)
concatenated_test_dataset = datasets.concatenate_datasets(test_datasets)
actual_dataset_dict = datasets.DatasetDict(
{"train": concatenated_training_dataset, "test": concatenated_test_dataset}
)
are_dataset_dicts_identical(expected_dataset_dict, actual_dataset_dict)

gc.collect()


def test_dataset_processor_decoder_only_style():
"""Test the `process_dataset_dict` function of a GPT-type `TextualizeProcessor`."""
_, gpt2_tokenizer = create_gpt2_model_and_tokenizer()
Expand All @@ -213,8 +287,6 @@ def test_dataset_processor_decoder_only_style():
"<task 0>convert to text2text\nExample:\nfoo\nLabel:\nbaz<|endoftext|>", # noqa: E501
"<task 0>convert to text2text\nExample:\nbar\nLabel:\nqux<|endoftext|>", # noqa: E501
],
"input_col": ["foo", "bar"],
"output_col": ["baz", "qux"],
"model_output": ["baz<|endoftext|>", "qux<|endoftext|>"],
}
),
Expand All @@ -224,8 +296,6 @@ def test_dataset_processor_decoder_only_style():
"<task 0>convert to text2text\nExample:\nfoo\nLabel:\n",
"<task 0>convert to text2text\nExample:\nbar\nLabel:\n",
],
"input_col": ["foo", "bar"],
"output_col": ["baz", "qux"],
"model_output": ["baz", "qux"],
}
),
Expand All @@ -239,8 +309,6 @@ def test_dataset_processor_decoder_only_style():
"<task 1>convert to text2text\nExample:\nspam\nLabel:\nham<|endoftext|>", # noqa: E501
"<task 1>convert to text2text\nExample:\neggs\nLabel:\nsau<|endoftext|>", # noqa: E501
],
"input_col": ["spam", "eggs"],
"output_col": ["ham", "sau"],
"model_output": ["ham<|endoftext|>", "sau<|endoftext|>"],
}
),
Expand All @@ -250,8 +318,6 @@ def test_dataset_processor_decoder_only_style():
"<task 1>convert to text2text\nExample:\nspam\nLabel:\n",
"<task 1>convert to text2text\nExample:\neggs\nLabel:\n",
],
"input_col": ["spam", "eggs"],
"output_col": ["ham", "sau"],
"model_output": ["ham", "sau"],
}
),
Expand Down Expand Up @@ -341,8 +407,6 @@ def test_empty_filter_t5_type():
"model_input": [
"<task 0>convert to text2text\nExample:\ntest\nLabel:\n",
],
"input_col": ["test"],
"output_col": ["key"],
"model_output": ["key"],
}
),
Expand All @@ -351,12 +415,6 @@ def test_empty_filter_t5_type():
"model_input": [
"<task 0>convert to text2text\nExample:\nfoo\nLabel:\n",
],
"input_col": [
"foo",
],
"output_col": [
"baz",
],
"model_output": [
"baz",
],
Expand All @@ -369,8 +427,6 @@ def test_empty_filter_t5_type():
"train": datasets.Dataset.from_dict(
{
"model_input": [],
"input_col": [],
"output_col": [],
"model_output": [],
}
),
Expand Down Expand Up @@ -403,8 +459,6 @@ def test_empty_filter_decoder_only_style():
"model_input": [
"<task 0>convert to text2text\nExample:\ntest\nLabel:\nkey<|endoftext|>", # noqa: E501
],
"input_col": ["test"],
"output_col": ["key"],
"model_output": ["key<|endoftext|>"],
}
),
Expand All @@ -413,8 +467,6 @@ def test_empty_filter_decoder_only_style():
"model_input": [
"<task 0>convert to text2text\nExample:\nfoo\nLabel:\n",
],
"input_col": ["foo"],
"output_col": ["baz"],
"model_output": ["baz"],
}
),
Expand All @@ -425,8 +477,6 @@ def test_empty_filter_decoder_only_style():
"train": datasets.Dataset.from_dict(
{
"model_input": [],
"input_col": [],
"output_col": [],
"model_output": [],
}
),
Expand Down

0 comments on commit ce4626f

Please sign in to comment.