new method to load data ✨
gaetanbrison committed Dec 4, 2024
1 parent 76366c3 commit ebfcded
Showing 1 changed file with 73 additions and 0 deletions.
73 changes: 73 additions & 0 deletions carte_ai/data/load_data.py
@@ -52,6 +52,79 @@ def set_split(data, config_data, num_train, random_state=42):
return X_train, X_test, y_train, y_test


def set_split_hf(data, target_name, entity_name, num_train, random_state=42):
"""
Helper function to split a dataset into train and test sets using grouped splitting.
Parameters
----------
data : pandas.DataFrame
The input dataset containing features and the target variable.
target_name : str
The name of the target column in `data`.
entity_name : str
The column name in `data` that defines the grouping variable.
Samples with the same group value will not be split across train and test sets.
num_train : int
The desired number of training groups.
random_state : int, optional
Random seed for reproducibility of the split. Default is 42.
Returns
-------
X_train : pandas.DataFrame
The training set features.
X_test : pandas.DataFrame
The test set features.
y_train : numpy.ndarray
The target values for the training set.
y_test : numpy.ndarray
The target values for the test set.
Examples
--------
>>> import pandas as pd
>>> from sklearn.model_selection import GroupShuffleSplit
>>> import numpy as np
>>> data = pd.DataFrame({
... 'feature1': [1, 2, 3, 4, 5],
... 'feature2': [10, 20, 30, 40, 50],
... 'target': [0, 1, 0, 1, 0],
... 'group': [1, 1, 2, 2, 3]
... })
>>> X_train, X_test, y_train, y_test = set_split_hf(data, 'target', 'group', num_train=2)
>>> X_train
feature1 feature2
0 1 10
1 2 20
>>> y_train
array([0, 1])
Notes
-----
This function uses `GroupShuffleSplit` from scikit-learn to ensure that groups
specified by `entity_name` are preserved either in the train or test split but not both.
"""
    # Separate the features from the target column
    X = data.drop(columns=target_name)
    y = data[target_name].to_numpy()

    # Encode the entity column as integer group ids for GroupShuffleSplit
    groups = np.array(data.groupby(entity_name).ngroup())
    num_groups = len(np.unique(groups))

    # Hold out (num_groups - num_train) groups for the test set
    gss = GroupShuffleSplit(
        n_splits=1, test_size=int(num_groups - num_train), random_state=random_state
    )

    idx_train, idx_test = next(iter(gss.split(X, y, groups=groups)))
    X_train, X_test = X.iloc[idx_train], X.iloc[idx_test]
    y_train, y_test = y[idx_train], y[idx_test]

    return X_train, X_test, y_train, y_test




# Define individual methods for each dataset


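A minimal usage sketch for the new helper, assuming a toy pandas DataFrame; the column names 'company', 'revenue', and 'price' are illustrative and not part of the commit:

import pandas as pd
from carte_ai.data.load_data import set_split_hf

# Toy table: each 'company' is an entity whose rows must stay together.
# Column names here are hypothetical, chosen only for illustration.
df = pd.DataFrame({
    "company": ["acme", "acme", "globex", "globex", "initech", "initech"],
    "revenue": [1.0, 1.2, 3.4, 3.1, 0.7, 0.9],
    "price": [10, 11, 30, 29, 7, 8],
})

# Keep 2 of the 3 companies for training; the remaining one forms the test set.
X_train, X_test, y_train, y_test = set_split_hf(
    df, target_name="price", entity_name="company", num_train=2, random_state=42
)

# Grouped splitting guarantee: no company appears on both sides of the split.
assert set(X_train["company"]).isdisjoint(X_test["company"])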