# Review notes. Text ahead of the first `diff --git` header is skipped by patch tools.
#
# composer/core/data_spec.py, `_default_get_num_samples_in_batch`, dict branch:
#   The single list comprehension reading `t.shape[0]` from every dict value is
#   replaced by a loop that appends `t.shape[0]` for `torch.Tensor` values,
#   `len(t)` for `list` values, and raises `ValueError` for any other type.
#   NOTE(review): the removed line accepted any value exposing `.shape`; the new
#   code raises for values that are neither `torch.Tensor` nor `list` (a tuple,
#   for instance). Confirm that narrowing is intended.
#   NOTE(review): the loop appends to `dim0_sizes`, so it presumably relies on
#   that list being created earlier in the function, outside this hunk - confirm.
#
# tests/trainer/test_dataspec.py:
#   Adds the parametrized `test_num_samples_in_batch`, which builds a `DataSpec`
#   and asserts `get_num_samples_in_batch(batch) == num_samples` for: dict,
#   list of dict, dict of lists, tuple, list, tensor, and 4-dim tensor batches.
#   `N` is referenced but not defined in this hunk; presumably a module-level
#   constant of the test file - verify.
#
# NOTE(review): this patch was recovered from a copy whose line breaks had been
# collapsed. Indentation and blank context lines were reconstructed to match the
# hunk line counts; check them against the repository before applying.
diff --git a/composer/core/data_spec.py b/composer/core/data_spec.py
index e1971d0c7e..feb38a4687 100644
--- a/composer/core/data_spec.py
+++ b/composer/core/data_spec.py
@@ -244,7 +244,16 @@ def _default_get_num_samples_in_batch(self, batch: Batch) -> int:
                                          '`get_num_samples_in_batch(your_batch) -> int` method.')
                     dim0_sizes.append(t.shape[0])
         elif isinstance(batch, dict):
-            dim0_sizes = [t.shape[0] for t in batch.values()]
+            for t in batch.values():
+                if isinstance(t, torch.Tensor):
+                    dim0_sizes.append(t.shape[0])
+                elif isinstance(t, list):
+                    dim0_sizes.append(len(t))
+                else:
+                    raise ValueError('Unable to determine the batch size as batch is a dict '
+                                     f'with an element of type {type(t)} which is not Tensor '
+                                     'or list. Please use a DataSpec and provide a '
+                                     '`get_num_samples_in_batch(your_batch) -> int` method.')
 
         if len(set(dim0_sizes)) == 1:
             return dim0_sizes[0]
diff --git a/tests/trainer/test_dataspec.py b/tests/trainer/test_dataspec.py
index 94afca1972..068eb6ad3a 100644
--- a/tests/trainer/test_dataspec.py
+++ b/tests/trainer/test_dataspec.py
@@ -86,3 +86,29 @@ def test_small_batch_at_end_warning():
 
     with pytest.warns(UserWarning, match='Cannot split tensor of length.*'):
         trainer.fit()
+
+
+@pytest.mark.parametrize(
+    'batch,num_samples',
+    [
+        [{
+            'a': torch.rand(N, 8),
+            'b': torch.rand(N, 64)
+        }, N],  # dict
+        [[{
+            'a': torch.rand(N, 8)
+        }, {
+            'c': torch.rand(N, 64)
+        }], N],  # list of dict
+        [{
+            'a': [1, 2],
+            'b': [3, 4]
+        }, 2],  # dict of lists
+        [(torch.rand(N, 8), torch.rand(N, 64)), N],  # tuple
+        [[torch.rand(N, 8), torch.rand(N, 64)], N],  # list
+        [torch.rand(N, 8), N],  # tensor
+        [torch.rand(N, 8, 4, 2), N],  # 4-dim tensor
+    ])
+def test_num_samples_in_batch(batch, num_samples):
+    data_spec = DataSpec(dataloader=DataLoader(RandomClassificationDataset(size=17), batch_size=4))
+    assert data_spec.get_num_samples_in_batch(batch) == num_samples