diff --git a/test/smoke_test/smoke_test.py b/test/smoke_test/smoke_test.py index 3b5b18c35..f4c06150e 100644 --- a/test/smoke_test/smoke_test.py +++ b/test/smoke_test/smoke_test.py @@ -206,9 +206,9 @@ def smoke_test_conv2d() -> None: assert output is not None -def smoke_test_linalg() -> None: - print("Testing smoke_test_linalg") - A = torch.randn(5, 3) +def test_linalg(device="cpu") -> None: + print(f"Testing smoke_test_linalg on {device}") + A = torch.randn(5, 3, device=device) U, S, Vh = torch.linalg.svd(A, full_matrices=False) assert U.shape == A.shape and S.shape == torch.Size([3]) and Vh.shape == torch.Size([3, 3]) torch.dist(A, U @ torch.diag(S) @ Vh) @@ -217,15 +217,15 @@ def smoke_test_linalg() -> None: assert U.shape == torch.Size([5, 5]) and S.shape == torch.Size([3]) and Vh.shape == torch.Size([3, 3]) torch.dist(A, U[:, :3] @ torch.diag(S) @ Vh) - A = torch.randn(7, 5, 3) + A = torch.randn(7, 5, 3, device=device) U, S, Vh = torch.linalg.svd(A, full_matrices=False) torch.dist(A, U @ torch.diag_embed(S) @ Vh) - if is_cuda_system: + if device == "cuda": supported_dtypes = [torch.float32, torch.float64] for dtype in supported_dtypes: print(f"Testing smoke_test_linalg with cuda for {dtype}") - A = torch.randn(20, 16, 50, 100, device="cuda").type(dtype) + A = torch.randn(20, 16, 50, 100, device=device, dtype=dtype) torch.linalg.svd(A) @@ -293,7 +293,9 @@ def main() -> None: check_version(options.package) smoke_test_conv2d() - smoke_test_linalg() + test_linalg() + if is_cuda_system: + test_linalg("cuda") if options.package == "all": smoke_test_modules()