Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Train-test shape mismatch #1022

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion fedot/core/data/data_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fedot.core.data.data import InputData
from fedot.core.data.multi_modal import MultiModalData
from fedot.core.repository.dataset_types import DataTypesEnum
from fedot.core.repository.tasks import TaskTypesEnum


def _split_time_series(data: InputData, task, *args, **kwargs):
Expand Down Expand Up @@ -110,13 +111,19 @@ def _split_any(data: InputData, task, data_type, split_ratio, with_shuffle=False
input_features = data.features
input_target = data.target
idx = data.idx
if task.task_type == TaskTypesEnum.classification and with_shuffle:
stratify = input_target
else:
stratify = None

idx_train, idx_test, x_train, x_test, y_train, y_test = \
train_test_split(idx,
input_features,
input_target,
test_size=1. - split_ratio,
shuffle=with_shuffle,
random_state=random_state)
random_state=random_state,
stratify=stratify)

# Prepare data to train the operation
train_data = InputData(idx=idx_train, features=x_train, target=y_train,
Expand Down
59 changes: 58 additions & 1 deletion test/unit/data/test_data_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,50 @@ def get_image_classification_data():
return input_data


def get_imbalanced_data_to_test_mismatch():
    """Build a tiny tabular classification dataset with skewed class counts.

    The target holds 14 labels: class 0 occurs eight times, while classes
    1, 2 and 3 occur twice each. Indices are simply 0..13.
    """
    features = np.array([
        [0, 0, 15],
        [0, 1, 2],
        [8, 12, 0],
        [0, 1, 0],
        [1, 1, 0],
        [0, 11, 9],
        [5, 1, 10],
        [8, 16, 4],
        [3, 1, 5],
        [0, 1, 6],
        [2, 7, 9],
        [0, 1, 2],
        [14, 1, 0],
        [0, 4, 10],
    ])
    labels = np.array([0, 0, 0, 0, 2, 0, 0, 1, 2, 1, 0, 0, 3, 3])
    return InputData(idx=np.arange(0, len(features)),
                     features=features,
                     target=labels,
                     task=Task(TaskTypesEnum.classification),
                     data_type=DataTypesEnum.table)


def get_balanced_data_to_test_mismatch():
    """Build a tiny tabular classification dataset with near-even class counts.

    The target holds 14 labels: classes 0 and 1 occur four times each,
    classes 2 and 3 occur three times each. Indices are simply 0..13.
    """
    features = np.array([
        [0, 0, 15],
        [0, 1, 2],
        [8, 12, 0],
        [0, 1, 0],
        [1, 1, 0],
        [0, 11, 9],
        [5, 1, 10],
        [8, 16, 4],
        [3, 1, 5],
        [0, 1, 6],
        [2, 7, 9],
        [0, 1, 2],
        [14, 1, 0],
        [0, 4, 10],
    ])
    labels = np.array([0, 1, 2, 3, 2, 1, 0, 1, 2, 1, 0, 0, 3, 3])
    return InputData(idx=np.arange(0, len(features)),
                     features=features,
                     target=labels,
                     task=Task(TaskTypesEnum.classification),
                     data_type=DataTypesEnum.table)


def test_split_data():
dataframe = pd.DataFrame(data=[[1, 2, 3],
[4, 5, 6],
Expand Down Expand Up @@ -92,6 +136,20 @@ def test_advanced_time_series_splitting():
assert np.allclose(test_data.target, np.array([16, 17, 18, 19]))


@pytest.mark.parametrize(
    'data_splitter, data',
    [
        # Case labelled "test StratifiedKFold" by the original author
        # NOTE(review): which splitter is actually selected is decided inside DataSourceSplitter - confirm
        (DataSourceSplitter(cv_folds=3, shuffle=True), get_imbalanced_data_to_test_mismatch()),
        # Case labelled "test KFold" by the original author
        (DataSourceSplitter(cv_folds=3, shuffle=True), get_balanced_data_to_test_mismatch()),
        # Case labelled "test hold-out" by the original author (no cv_folds passed)
        (DataSourceSplitter(shuffle=True), get_imbalanced_data_to_test_mismatch()),
    ])
def test_data_splitting_without_shape_mismatch(data_splitter: DataSourceSplitter, data: InputData):
    """ Checks that every produced train/test pair has no class in test that is missing from train """
    produce_folds = data_splitter.build(data=data)
    for train_data, test_data in produce_folds():
        train_classes = set(train_data.target)
        test_classes = set(test_data.target)
        # Every label seen in the test subset must also be present in the train subset
        assert train_classes >= test_classes


def test_data_splitting_perform_correctly_after_build():
"""
Check if data splitting perform correctly through Objective Builder - Objective Evaluate
Expand All @@ -105,7 +163,6 @@ def test_data_splitting_perform_correctly_after_build():

# Imitate evaluation process
for fold_id, (train_data, test_data) in enumerate(data_source()):

expected_output = output_by_fold[fold_id]
assert train_data.features.shape == expected_output['train_features_size']
assert test_data.features.shape == expected_output['test_features_size']
Expand Down