diff --git a/prompt2model/dataset_processor/base.py b/prompt2model/dataset_processor/base.py
index 113d754a2..ef615cad2 100644
--- a/prompt2model/dataset_processor/base.py
+++ b/prompt2model/dataset_processor/base.py
@@ -99,7 +99,7 @@ def filter_empty_strings(example: dict) -> bool:
                 modified_dataset_dict[dataset_split] = (
                     dataset_dict[dataset_split]
                     .filter(filter_empty_strings)
-                    .map(mapping_function)
+                    .map(mapping_function, remove_columns=["input_col", "output_col"])
                 )
             modified_dataset_dict = datasets.DatasetDict(modified_dataset_dict)
             modified_dataset_dicts.append(modified_dataset_dict)
diff --git a/tests/dataset_processor_test.py b/tests/dataset_processor_test.py
index 369a8db94..82aa0d587 100644
--- a/tests/dataset_processor_test.py
+++ b/tests/dataset_processor_test.py
@@ -138,8 +138,6 @@ def test_dataset_processor_t5_style():
                             "convert to text2text\nExample:\nfoo\nLabel:\n",
                             "convert to text2text\nExample:\nbar\nLabel:\n",
                         ],
-                        "input_col": ["foo", "bar"],
-                        "output_col": ["baz", "qux"],
                         "model_output": ["baz", "qux"],
                     }
                 ),
@@ -149,8 +147,6 @@ def test_dataset_processor_t5_style():
                             "convert to text2text\nExample:\nfoo\nLabel:\n",
                             "convert to text2text\nExample:\nbar\nLabel:\n",
                         ],
-                        "input_col": ["foo", "bar"],
-                        "output_col": ["baz", "qux"],
                         "model_output": ["baz", "qux"],
                     }
                 ),
@@ -164,8 +160,6 @@ def test_dataset_processor_t5_style():
                             "convert to text2text\nExample:\nspam\nLabel:\n",
                             "convert to text2text\nExample:\neggs\nLabel:\n",
                         ],
-                        "input_col": ["spam", "eggs"],
-                        "output_col": ["ham", "sau"],
                         "model_output": ["ham", "sau"],
                     }
                 ),
@@ -175,8 +169,6 @@ def test_dataset_processor_t5_style():
                             "convert to text2text\nExample:\nspam\nLabel:\n",
                             "convert to text2text\nExample:\neggs\nLabel:\n",
                         ],
-                        "input_col": ["spam", "eggs"],
-                        "output_col": ["ham", "sau"],
                         "model_output": ["ham", "sau"],
                     }
                 ),
@@ -188,6 +180,94 @@ def test_dataset_processor_t5_style():
     gc.collect()
 
 
+def test_dataset_processor_with_numerical_column():
+    """Test process_dataset_dict with numerical column values.
+
+    Processes one dataset dict whose `output_col` holds strings and one whose
+    `output_col` holds integers, concatenates the results per split, and
+    expects the integer outputs to appear as strings in `model_output`.
+    """
+    t5_processor = TextualizeProcessor(has_encoder=True)
+    raw_dataset_dicts = [
+        datasets.DatasetDict(
+            {
+                "train": datasets.Dataset.from_dict(
+                    {
+                        "input_col": ["foo", "bar"],
+                        "output_col": ["baz", "qux"],
+                    }
+                ),
+                "test": datasets.Dataset.from_dict(
+                    {
+                        "input_col": ["spam", "eggs"],
+                        "output_col": ["ham", "sau"],
+                    }
+                ),
+            }
+        ),
+        datasets.DatasetDict(
+            {
+                "train": datasets.Dataset.from_dict(
+                    {
+                        "input_col": ["foo", "bar"],
+                        "output_col": [0, 1],
+                    }
+                ),
+                "test": datasets.Dataset.from_dict(
+                    {
+                        "input_col": ["spam", "eggs"],
+                        "output_col": [1, 2],
+                    }
+                ),
+            }
+        ),
+    ]
+    t5_modified_dataset_dicts = t5_processor.process_dataset_dict(
+        INSTRUCTION, raw_dataset_dicts
+    )
+    expected_dataset_dict = datasets.DatasetDict(
+        {
+            "train": datasets.Dataset.from_dict(
+                {
+                    "model_input": [
+                        "convert to text2text\nExample:\nfoo\nLabel:\n",
+                        "convert to text2text\nExample:\nbar\nLabel:\n",
+                        "convert to text2text\nExample:\nfoo\nLabel:\n",
+                        "convert to text2text\nExample:\nbar\nLabel:\n",
+                    ],
+                    "model_output": ["baz", "qux", "0", "1"],
+                }
+            ),
+            "test": datasets.Dataset.from_dict(
+                {
+                    "model_input": [
+                        "convert to text2text\nExample:\nspam\nLabel:\n",
+                        "convert to text2text\nExample:\neggs\nLabel:\n",
+                        "convert to text2text\nExample:\nspam\nLabel:\n",
+                        "convert to text2text\nExample:\neggs\nLabel:\n",
+                    ],
+                    "model_output": ["ham", "sau", "1", "2"],
+                }
+            ),
+        }
+    )
+    training_datasets = []
+    test_datasets = []
+    for modified_dataset_dict in t5_modified_dataset_dicts:
+        training_datasets.append(modified_dataset_dict["train"])
+        test_datasets.append(modified_dataset_dict["test"])
+
+    concatenated_training_dataset = datasets.concatenate_datasets(training_datasets)
+    concatenated_test_dataset = datasets.concatenate_datasets(test_datasets)
+    actual_dataset_dict = datasets.DatasetDict(
+        {"train": concatenated_training_dataset, "test": concatenated_test_dataset}
+    )
+    # Assert the comparison result; a bare call would silently discard it.
+    assert are_dataset_dicts_identical(expected_dataset_dict, actual_dataset_dict)
+
+    gc.collect()
+
+
 def test_dataset_processor_decoder_only_style():
     """Test the `process_dataset_dict` function of a GPT-type `TextualizeProcessor`."""
     _, gpt2_tokenizer = create_gpt2_model_and_tokenizer()
@@ -213,8 +293,6 @@ def test_dataset_processor_decoder_only_style():
                             "convert to text2text\nExample:\nfoo\nLabel:\nbaz<|endoftext|>",  # noqa: E501
                             "convert to text2text\nExample:\nbar\nLabel:\nqux<|endoftext|>",  # noqa: E501
                         ],
-                        "input_col": ["foo", "bar"],
-                        "output_col": ["baz", "qux"],
                         "model_output": ["baz<|endoftext|>", "qux<|endoftext|>"],
                     }
                 ),
@@ -224,8 +302,6 @@ def test_dataset_processor_decoder_only_style():
                             "convert to text2text\nExample:\nfoo\nLabel:\n",
                             "convert to text2text\nExample:\nbar\nLabel:\n",
                         ],
-                        "input_col": ["foo", "bar"],
-                        "output_col": ["baz", "qux"],
                         "model_output": ["baz", "qux"],
                     }
                 ),
@@ -239,8 +315,6 @@ def test_dataset_processor_decoder_only_style():
                             "convert to text2text\nExample:\nspam\nLabel:\nham<|endoftext|>",  # noqa: E501
                             "convert to text2text\nExample:\neggs\nLabel:\nsau<|endoftext|>",  # noqa: E501
                         ],
-                        "input_col": ["spam", "eggs"],
-                        "output_col": ["ham", "sau"],
                         "model_output": ["ham<|endoftext|>", "sau<|endoftext|>"],
                     }
                 ),
@@ -250,8 +324,6 @@ def test_dataset_processor_decoder_only_style():
                             "convert to text2text\nExample:\nspam\nLabel:\n",
                             "convert to text2text\nExample:\neggs\nLabel:\n",
                         ],
-                        "input_col": ["spam", "eggs"],
-                        "output_col": ["ham", "sau"],
                         "model_output": ["ham", "sau"],
                     }
                 ),
@@ -341,8 +413,6 @@ def test_empty_filter_t5_type():
                         "model_input": [
                             "convert to text2text\nExample:\ntest\nLabel:\n",
                         ],
-                        "input_col": ["test"],
-                        "output_col": ["key"],
                         "model_output": ["key"],
                     }
                 ),
@@ -351,12 +421,6 @@ def test_empty_filter_t5_type():
                         "model_input": [
                             "convert to text2text\nExample:\nfoo\nLabel:\n",
                         ],
-                        "input_col": [
-                            "foo",
-                        ],
-                        "output_col": [
-                            "baz",
-                        ],
                         "model_output": [
                             "baz",
                         ],
@@ -369,8 +433,6 @@ def test_empty_filter_t5_type():
                 "train": datasets.Dataset.from_dict(
                     {
                         "model_input": [],
-                        "input_col": [],
-                        "output_col": [],
                         "model_output": [],
                     }
                 ),
@@ -403,8 +465,6 @@ def test_empty_filter_decoder_only_style():
                         "model_input": [
                             "convert to text2text\nExample:\ntest\nLabel:\nkey<|endoftext|>",  # noqa: E501
                         ],
-                        "input_col": ["test"],
-                        "output_col": ["key"],
                         "model_output": ["key<|endoftext|>"],
                     }
                 ),
@@ -413,8 +473,6 @@ def test_empty_filter_decoder_only_style():
                         "model_input": [
                             "convert to text2text\nExample:\nfoo\nLabel:\n",
                         ],
-                        "input_col": ["foo"],
-                        "output_col": ["baz"],
                         "model_output": ["baz"],
                     }
                 ),
@@ -425,8 +483,6 @@ def test_empty_filter_decoder_only_style():
                 "train": datasets.Dataset.from_dict(
                     {
                         "model_input": [],
-                        "input_col": [],
-                        "output_col": [],
                         "model_output": [],
                     }
                 ),