Skip to content

Commit

Permalink
Don't convert namedtuple to tuple
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanbreitsch committed Apr 28, 2020
1 parent e40e27f commit 59fa3cc
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions pytorch_lightning/trainer/distrib_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,10 +461,15 @@ def __transfer_data_to_device(self, batch, device, gpu_id=None):

# when tuple
if isinstance(batch, tuple):
batch = list(batch)
for i, x in enumerate(batch):
batch[i] = self.__transfer_data_to_device(x, device, gpu_id)
return tuple(batch)
# when namedtuple
if hasattr(batch, '_fields'):
elem_type = type(batch)
return elem_type(*(self.__transfer_data_to_device(x, device, gpu_id) for x in batch))
else:
batch = list(batch)
for i, x in enumerate(batch):
batch[i] = self.__transfer_data_to_device(x, device, gpu_id)
return tuple(batch)

# when dict
if isinstance(batch, dict):
Expand Down

0 comments on commit 59fa3cc

Please sign in to comment.