diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 73efaf67c486b..db4e132c0b445 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -461,10 +461,15 @@ def __transfer_data_to_device(self, batch, device, gpu_id=None): # when tuple if isinstance(batch, tuple): - batch = list(batch) - for i, x in enumerate(batch): - batch[i] = self.__transfer_data_to_device(x, device, gpu_id) - return tuple(batch) + # when namedtuple + if hasattr(batch, '_fields'): + elem_type = type(batch) + return elem_type(*(self.__transfer_data_to_device(x, device, gpu_id) for x in batch)) + else: + batch = list(batch) + for i, x in enumerate(batch): + batch[i] = self.__transfer_data_to_device(x, device, gpu_id) + return tuple(batch) # when dict if isinstance(batch, dict): diff --git a/tests/models/test_cpu.py b/tests/models/test_cpu.py index 612286404041b..eb3b28769e206 100644 --- a/tests/models/test_cpu.py +++ b/tests/models/test_cpu.py @@ -1,3 +1,4 @@ +from collections import namedtuple import platform import pytest @@ -221,6 +222,13 @@ def test_single_gpu_batch_parse(): assert batch[1][0]['b'].device.index == 0 assert batch[1][0]['b'].type() == 'torch.cuda.FloatTensor' + # namedtuple of tensor + BatchType = namedtuple('BatchType', ['a', 'b']) + batch = [BatchType(a=torch.rand(2, 3), b=torch.rand(2, 3)) for _ in range(2)] + batch = trainer.transfer_batch_to_gpu(batch, 0) + assert batch[0].a.device.index == 0 + assert batch[0].a.type() == 'torch.cuda.FloatTensor' + def test_simple_cpu(tmpdir): """Verify continue training session on CPU."""