Skip to content

Commit

Permalink
Extend test_linalg from smoke_test.py
Browse files Browse the repository at this point in the history
To take device as an argument and run tests on both cpu and cuda
  • Loading branch information
malfet committed Dec 9, 2023
1 parent bb9b32c commit 4f298cb
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions test/smoke_test/smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,9 @@ def smoke_test_conv2d() -> None:
assert output is not None


def smoke_test_linalg() -> None:
print("Testing smoke_test_linalg")
A = torch.randn(5, 3)
def test_linalg(device="cpu") -> None:
print(f"Testing smoke_test_linalg on {device}")
A = torch.randn(5, 3, device=device)
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
assert U.shape == A.shape and S.shape == torch.Size([3]) and Vh.shape == torch.Size([3, 3])
torch.dist(A, U @ torch.diag(S) @ Vh)
Expand All @@ -217,15 +217,15 @@ def smoke_test_linalg() -> None:
assert U.shape == torch.Size([5, 5]) and S.shape == torch.Size([3]) and Vh.shape == torch.Size([3, 3])
torch.dist(A, U[:, :3] @ torch.diag(S) @ Vh)

A = torch.randn(7, 5, 3)
A = torch.randn(7, 5, 3, device=device)
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
torch.dist(A, U @ torch.diag_embed(S) @ Vh)

if is_cuda_system:
if device == "cuda":
supported_dtypes = [torch.float32, torch.float64]
for dtype in supported_dtypes:
print(f"Testing smoke_test_linalg with cuda for {dtype}")
A = torch.randn(20, 16, 50, 100, device="cuda").type(dtype)
A = torch.randn(20, 16, 50, 100, device=device, dtype=dtype)
torch.linalg.svd(A)


Expand Down Expand Up @@ -293,7 +293,9 @@ def main() -> None:

check_version(options.package)
smoke_test_conv2d()
smoke_test_linalg()
test_linalg()
if is_cuda_system:
test_linalg("cuda")

if options.package == "all":
smoke_test_modules()
Expand Down

0 comments on commit 4f298cb

Please sign in to comment.