diff --git a/tests/models/test_cpu.py b/tests/models/test_cpu.py index 612286404041b6..eb3b28769e2063 100644 --- a/tests/models/test_cpu.py +++ b/tests/models/test_cpu.py @@ -1,3 +1,4 @@ +from collections import namedtuple import platform import pytest @@ -221,6 +222,13 @@ def test_single_gpu_batch_parse(): assert batch[1][0]['b'].device.index == 0 assert batch[1][0]['b'].type() == 'torch.cuda.FloatTensor' + # namedtuple of tensor + BatchType = namedtuple('BatchType', ['a', 'b']) + batch = [BatchType(a=torch.rand(2, 3), b=torch.rand(2, 3)) for _ in range(2)] + batch = trainer.transfer_batch_to_gpu(batch, 0) + assert batch[0].a.device.index == 0 + assert batch[0].a.type() == 'torch.cuda.FloatTensor' + def test_simple_cpu(tmpdir): """Verify continue training session on CPU."""