From 7ad55ab0f4236fee3e7d12f1cafb1d03ab3bf703 Mon Sep 17 00:00:00 2001 From: nathanbreitsch Date: Mon, 27 Apr 2020 21:40:05 -0400 Subject: [PATCH] Test namedtuples sent to device correctly --- tests/models/test_cpu.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/models/test_cpu.py b/tests/models/test_cpu.py index 612286404041b6..eb3b28769e2063 100644 --- a/tests/models/test_cpu.py +++ b/tests/models/test_cpu.py @@ -1,3 +1,4 @@ +from collections import namedtuple import platform import pytest @@ -221,6 +222,13 @@ def test_single_gpu_batch_parse(): assert batch[1][0]['b'].device.index == 0 assert batch[1][0]['b'].type() == 'torch.cuda.FloatTensor' + # namedtuple of tensor + BatchType = namedtuple('BatchType', ['a', 'b']) + batch = [BatchType(a=torch.rand(2, 3), b=torch.rand(2, 3)) for _ in range(2)] + batch = trainer.transfer_batch_to_gpu(batch, 0) + assert batch[0].a.device.index == 0 + assert batch[0].a.type() == 'torch.cuda.FloatTensor' + def test_simple_cpu(tmpdir): """Verify continue training session on CPU."""