Skip to content

Commit

Permalink
Fix a bug when batch type is dict and one of the values is the list (#…
Browse files Browse the repository at this point in the history
…2599)

* fix a bug when list in dict

* lint

* test

* lint

* adjsut error

---------

Co-authored-by: Michael Hu <[email protected]>
  • Loading branch information
mvpatel2000 and mhh0318 authored Oct 3, 2023
1 parent 0ee109d commit 00939a0
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 1 deletion.
11 changes: 10 additions & 1 deletion composer/core/data_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,16 @@ def _default_get_num_samples_in_batch(self, batch: Batch) -> int:
'`get_num_samples_in_batch(your_batch) -> int` method.')
dim0_sizes.append(t.shape[0])
elif isinstance(batch, dict):
dim0_sizes = [t.shape[0] for t in batch.values()]
for t in batch.values():
if isinstance(t, torch.Tensor):
dim0_sizes.append(t.shape[0])
elif isinstance(t, list):
dim0_sizes.append(len(t))
else:
raise ValueError('Unable to determine the batch size as batch is a dict '
f'with an element of type {type(t)} which is not Tensor '
'or list. Please use a DataSpec and provide a '
'`get_num_samples_in_batch(your_batch) -> int` method.')

if len(set(dim0_sizes)) == 1:
return dim0_sizes[0]
Expand Down
26 changes: 26 additions & 0 deletions tests/trainer/test_dataspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,29 @@ def test_small_batch_at_end_warning():

with pytest.warns(UserWarning, match='Cannot split tensor of length.*'):
trainer.fit()


@pytest.mark.parametrize(
'batch,num_samples',
[
[{
'a': torch.rand(N, 8),
'b': torch.rand(N, 64)
}, N], # dict
[[{
'a': torch.rand(N, 8)
}, {
'c': torch.rand(N, 64)
}], N], # list of dict
[{
'a': [1, 2],
'b': [3, 4]
}, 2], # dict of lists
[(torch.rand(N, 8), torch.rand(N, 64)), N], # tuple
[[torch.rand(N, 8), torch.rand(N, 64)], N], # list
[torch.rand(N, 8), N], # tensor
[torch.rand(N, 8, 4, 2), N], # 4-dim tensor
])
def test_num_samples_in_batch(batch, num_samples):
data_spec = DataSpec(dataloader=DataLoader(RandomClassificationDataset(size=17), batch_size=4))
assert data_spec.get_num_samples_in_batch(batch) == num_samples

0 comments on commit 00939a0

Please sign in to comment.