Skip to content

Commit

Permalink
Don't convert namedtuple to tuple (#1589)
Browse files Browse the repository at this point in the history
* Don't convert namedtuple to tuple

* Test namedtuples sent to device correctly
  • Loading branch information
nathanbreitsch authored Apr 30, 2020
1 parent d40425d commit 3eac6cf
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
13 changes: 9 additions & 4 deletions pytorch_lightning/trainer/distrib_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,10 +461,15 @@ def __transfer_data_to_device(self, batch, device, gpu_id=None):

# when tuple
if isinstance(batch, tuple):
batch = list(batch)
for i, x in enumerate(batch):
batch[i] = self.__transfer_data_to_device(x, device, gpu_id)
return tuple(batch)
# when namedtuple
if hasattr(batch, '_fields'):
elem_type = type(batch)
return elem_type(*(self.__transfer_data_to_device(x, device, gpu_id) for x in batch))
else:
batch = list(batch)
for i, x in enumerate(batch):
batch[i] = self.__transfer_data_to_device(x, device, gpu_id)
return tuple(batch)

# when dict
if isinstance(batch, dict):
Expand Down
8 changes: 8 additions & 0 deletions tests/models/test_cpu.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import namedtuple
import platform

import pytest
Expand Down Expand Up @@ -221,6 +222,13 @@ def test_single_gpu_batch_parse():
assert batch[1][0]['b'].device.index == 0
assert batch[1][0]['b'].type() == 'torch.cuda.FloatTensor'

# namedtuple of tensor
BatchType = namedtuple('BatchType', ['a', 'b'])
batch = [BatchType(a=torch.rand(2, 3), b=torch.rand(2, 3)) for _ in range(2)]
batch = trainer.transfer_batch_to_gpu(batch, 0)
assert batch[0].a.device.index == 0
assert batch[0].a.type() == 'torch.cuda.FloatTensor'


def test_simple_cpu(tmpdir):
"""Verify continue training session on CPU."""
Expand Down

0 comments on commit 3eac6cf

Please sign in to comment.