Skip to content

Commit

Permalink
Train-test shape mismatch (#1022)
Browse files Browse the repository at this point in the history
* add stratification for holdout

* remove excessive spaces

* add tests
  • Loading branch information
maypink authored Jan 22, 2023
1 parent 6a9c271 commit 03ae732
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 2 deletions.
9 changes: 8 additions & 1 deletion fedot/core/data/data_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fedot.core.data.data import InputData
from fedot.core.data.multi_modal import MultiModalData
from fedot.core.repository.dataset_types import DataTypesEnum
from fedot.core.repository.tasks import TaskTypesEnum


def _split_time_series(data: InputData, task, *args, **kwargs):
Expand Down Expand Up @@ -110,13 +111,19 @@ def _split_any(data: InputData, task, data_type, split_ratio, with_shuffle=False
input_features = data.features
input_target = data.target
idx = data.idx
if task.task_type == TaskTypesEnum.classification and with_shuffle:
stratify = input_target
else:
stratify = None

idx_train, idx_test, x_train, x_test, y_train, y_test = \
train_test_split(idx,
input_features,
input_target,
test_size=1. - split_ratio,
shuffle=with_shuffle,
random_state=random_state)
random_state=random_state,
stratify=stratify)

# Prepare data to train the operation
train_data = InputData(idx=idx_train, features=x_train, target=y_train,
Expand Down
59 changes: 58 additions & 1 deletion test/unit/data/test_data_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,50 @@ def get_image_classification_data():
return input_data


def get_imbalanced_data_to_test_mismatch():
    """Build a tiny tabular classification dataset with skewed class counts.

    Of the 14 rows, 8 belong to class 0 while classes 1, 2 and 3 have only
    two rows each, so a careless split can easily leave a class out of the
    train subset.

    :return: InputData with 14 samples, 3 features and 4 target classes
    """
    classification_task = Task(TaskTypesEnum.classification)

    features = np.array([
        [0, 0, 15],
        [0, 1, 2],
        [8, 12, 0],
        [0, 1, 0],
        [1, 1, 0],
        [0, 11, 9],
        [5, 1, 10],
        [8, 16, 4],
        [3, 1, 5],
        [0, 1, 6],
        [2, 7, 9],
        [0, 1, 2],
        [14, 1, 0],
        [0, 4, 10],
    ])
    labels = np.array([0, 0, 0, 0, 2, 0, 0, 1, 2, 1, 0, 0, 3, 3])

    # One sequential index per feature row
    row_ids = np.arange(0, len(features))

    return InputData(idx=row_ids,
                     features=features,
                     target=labels,
                     task=classification_task,
                     data_type=DataTypesEnum.table)


def get_balanced_data_to_test_mismatch():
    """Build a tiny tabular classification dataset with near-even class counts.

    Of the 14 rows, classes 0 and 1 have four rows each and classes 2 and 3
    have three rows each.

    :return: InputData with 14 samples, 3 features and 4 target classes
    """
    classification_task = Task(TaskTypesEnum.classification)

    features = np.array([
        [0, 0, 15],
        [0, 1, 2],
        [8, 12, 0],
        [0, 1, 0],
        [1, 1, 0],
        [0, 11, 9],
        [5, 1, 10],
        [8, 16, 4],
        [3, 1, 5],
        [0, 1, 6],
        [2, 7, 9],
        [0, 1, 2],
        [14, 1, 0],
        [0, 4, 10],
    ])
    labels = np.array([0, 1, 2, 3, 2, 1, 0, 1, 2, 1, 0, 0, 3, 3])

    # One sequential index per feature row
    row_ids = np.arange(0, len(features))

    return InputData(idx=row_ids,
                     features=features,
                     target=labels,
                     task=classification_task,
                     data_type=DataTypesEnum.table)


def test_split_data():
dataframe = pd.DataFrame(data=[[1, 2, 3],
[4, 5, 6],
Expand Down Expand Up @@ -92,6 +136,20 @@ def test_advanced_time_series_splitting():
assert np.allclose(test_data.target, np.array([16, 17, 18, 19]))


@pytest.mark.parametrize('data_splitter, data',
                         # test StratifiedKFold
                         [(DataSourceSplitter(cv_folds=3, shuffle=True), get_imbalanced_data_to_test_mismatch()),
                          # test KFold
                          (DataSourceSplitter(cv_folds=3, shuffle=True), get_balanced_data_to_test_mismatch()),
                          # test hold-out
                          (DataSourceSplitter(shuffle=True), get_imbalanced_data_to_test_mismatch())])
def test_data_splitting_without_shape_mismatch(data_splitter: DataSourceSplitter, data: InputData):
    """ Checks that the data is split correctly into train and test subsets:
    every class present in a test subset must also be present in the
    corresponding train subset.

    :param data_splitter: configured splitter used to build the data source
    :param data: classification dataset to be split
    """
    data_source = data_splitter.build(data=data)
    for train_data, test_data in data_source():
        # Flatten first so the check works whether the target is a 1-D
        # array or a column vector (rows of a 2-D array are not hashable).
        train_classes = set(np.ravel(train_data.target))
        test_classes = set(np.ravel(test_data.target))

        unseen_classes = test_classes - train_classes
        assert not unseen_classes, \
            f'Test subset contains classes missing from train subset: {sorted(unseen_classes)}'


def test_data_splitting_perform_correctly_after_build():
"""
Check if data splitting perform correctly through Objective Builder - Objective Evaluate
Expand All @@ -105,7 +163,6 @@ def test_data_splitting_perform_correctly_after_build():

# Imitate evaluation process
for fold_id, (train_data, test_data) in enumerate(data_source()):

expected_output = output_by_fold[fold_id]
assert train_data.features.shape == expected_output['train_features_size']
assert test_data.features.shape == expected_output['test_features_size']
Expand Down

0 comments on commit 03ae732

Please sign in to comment.