Skip to content

Commit

Permalink
sequential_split does not support integer lengths
Browse files Browse the repository at this point in the history
  • Loading branch information
andrrizzi committed Apr 11, 2023
1 parent 365f08a commit 38b00d7
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
7 changes: 5 additions & 2 deletions mlcvs/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ def __init__(
The dataset or a list of datasets. If a list, the datasets can have
different keys but they must all have the same number of samples.
lengths : list-like, optional
Lengths of the training/validation/test datasets. This can be a list
of integers or of (float) fractions. The default is ``[0.8,0.2]``.
Lengths of the training, validation, and (optionally) test datasets.
This must be a list of (float) fractions summing to 1. The default is
``[0.8,0.2]``.
batch_size : int or list-like, optional
Batch size, by default 0 (== ``len(dataset)``).
random_split: bool, optional
Expand Down Expand Up @@ -250,6 +251,8 @@ def sequential_split(dataset, lengths: Sequence) -> list:

# LB change: do sequential rather than random splitting
return [Subset(dataset, np.arange(offset-length,offset)) for offset, length in zip(_accumulate(lengths), lengths)]
else:
raise NotImplementedError('The lengths must sum to 1.')


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion mlcvs/tests/test_utils_data_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
# TESTS
# =============================================================================

@pytest.mark.parametrize('lengths', [[40, 10], [0.8, 0.2], [0.7, 0.2, 0.1]])
@pytest.mark.parametrize('lengths', [[0.8, 0.2], [0.7, 0.2, 0.1]])
@pytest.mark.parametrize('fields', [[], ['labels', 'weights']])
@pytest.mark.parametrize('random_split', [True, False])
def test_dictionary_data_module_split(lengths, fields, random_split):
Expand Down

0 comments on commit 38b00d7

Please sign in to comment.