Skip to content

Commit

Permalink
fix test in OSS env without CUDA device (pytorch#2688)
Browse files Browse the repository at this point in the history
Summary:

# context
* to fix OSS CPU test failure due to lack of CUDA device.

Reviewed By: dstaay-fb

Differential Revision: D68340773
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Jan 17, 2025
1 parent 33168a1 commit 97a0fd0
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions torchrec/sparse/tests/test_tensor_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,20 @@
import unittest

import torch
from hypothesis import given, settings, strategies as st, Verbosity
from tensordict import TensorDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.sparse.tensor_dict import maybe_td_to_kjt
from torchrec.sparse.tests.utils import repeat_test


class TestTensorDIct(unittest.TestCase):
@repeat_test(device_str=["cpu", "cuda", "meta"])
# pyre-ignore[56]
@given(device_str=st.sampled_from(["cpu", "cuda", "meta"]))
@settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
@unittest.skipIf(
torch.cuda.device_count() <= 0,
"CUDA is not available",
)
def test_kjt_input(self, device_str: str) -> None:
device = torch.device(device_str)
values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device)
Expand All @@ -30,7 +36,13 @@ def test_kjt_input(self, device_str: str) -> None:
features = maybe_td_to_kjt(kjt)
self.assertEqual(features, kjt)

@repeat_test(device_str=["cpu", "cuda", "meta"])
# pyre-ignore[56]
@given(device_str=st.sampled_from(["cpu", "cuda", "meta"]))
@settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
@unittest.skipIf(
torch.cuda.device_count() <= 0,
"CUDA is not available",
)
def test_td_kjt(self, device_str: str) -> None:
device = torch.device(device_str)
values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device)
Expand Down

0 comments on commit 97a0fd0

Please sign in to comment.