Enable GPU tests for mc-ebc, mc-ec, fp-ebc (#1625)
Summary:
Pull Request resolved: #1625

GPU tests are currently not run. Parameterize the test backend with hypothesis so the NCCL (GPU) path is exercised.

Reviewed By: PaulZhang12

Differential Revision: D52702754

fbshipit-source-id: 690d781958805d2bdab812584a1e15bdb5853009
henrylhtsang authored and facebook-github-bot committed Jan 12, 2024
1 parent 1fff2f7 commit 7488452
Showing 3 changed files with 51 additions and 30 deletions.
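
The common pattern across all three files: the hardcoded backend="nccl" argument becomes a hypothesis-drawn test parameter, and @settings lifts the per-example deadline, since spawning a multi-process group takes far longer than hypothesis's default 200 ms. A condensed, standalone sketch of the resulting decorator stack (the test class and body below are illustrative, not TorchRec code):

import unittest

import torch
from hypothesis import given, settings, strategies as st


class ExampleShardingTest(unittest.TestCase):
    # The skipIf gate still guards hosts with fewer than two GPUs;
    # the drawn `backend` replaces a hardcoded backend="nccl" argument.
    @unittest.skipIf(
        torch.cuda.device_count() <= 1,
        "Not enough GPUs, this test requires at least two GPUs",
    )
    # pyre-ignore
    @given(backend=st.sampled_from(["nccl"]))
    @settings(deadline=20000)  # multi-process setup exceeds the 200 ms default
    def test_sharding(self, backend: str) -> None:
        # The real tests forward `backend` to _run_multi_process_test;
        # asserting the drawn value keeps this sketch self-contained.
        self.assertEqual(backend, "nccl")
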
14 changes: 9 additions & 5 deletions torchrec/distributed/tests/test_fp_embeddingbag.py
@@ -335,18 +335,22 @@ def test_sharding_ebc(
use_fp_collection=use_fp_collection,
)

@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs, this test requires at least two GPUs",
)
@settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
# pyre-ignore
@given(use_fp_collection=st.booleans())
def test_sharding_fp_ebc_from_meta(self, use_fp_collection: bool) -> None:
@given(use_fp_collection=st.booleans(), backend=st.sampled_from(["nccl", "gloo"]))
def test_sharding_fp_ebc_from_meta(
self, use_fp_collection: bool, backend: str
) -> None:
embedding_bag_config, kjt_input_per_rank = get_configs_and_kjt_inputs()
self._run_multi_process_test(
callable=_test_sharding_from_meta,
world_size=2,
tables=embedding_bag_config,
sharder=FeatureProcessedEmbeddingBagCollectionSharder(),
backend="nccl"
if (torch.cuda.is_available() and torch.cuda.device_count() >= 2)
else "gloo",
backend=backend,
use_fp_collection=use_fp_collection,
)
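
For context, a minimal sketch of what a spawned per-rank callable can do with the drawn backend (the function name and env defaults are hypothetical; the real callables are TorchRec test helpers): sampling "gloo" keeps the fp-ebc test runnable on CPU-only hosts, while "nccl" exercises the GPU path.

import os

import torch
import torch.distributed as dist


def _demo_per_rank_callable(rank: int, world_size: int, backend: str) -> None:
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    if backend == "nccl":
        # NCCL needs one CUDA device per rank; gloo also runs on CPU.
        torch.cuda.set_device(rank)
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    try:
        t = torch.ones(1, device="cuda" if backend == "nccl" else "cpu")
        dist.all_reduce(t)  # sums across ranks: every rank sees world_size
        assert t.item() == world_size
    finally:
        dist.destroy_process_group()
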
34 changes: 21 additions & 13 deletions torchrec/distributed/tests/test_mc_embedding.py
Expand Up @@ -11,6 +11,7 @@

import torch
import torch.nn as nn
from hypothesis import given, settings, strategies as st
from torchrec.distributed.embedding import ShardedEmbeddingCollection
from torchrec.distributed.mc_embedding import (
ManagedCollisionEmbeddingCollectionSharder,
@@ -256,13 +257,14 @@ def _test_sharding_and_remapping( # noqa C901

@skip_if_asan_class
class ShardedMCEmbeddingCollectionParallelTest(MultiProcessTestBase):

# pyre-ignore
@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs, this test requires at least two GPUs",
)
def test_uneven_sharding(self) -> None:
# pyre-ignore
@given(backend=st.sampled_from(["nccl"]))
@settings(deadline=20000)
def test_uneven_sharding(self, backend: str) -> None:
WORLD_SIZE = 2

embedding_config = [
@@ -285,15 +287,17 @@ def test_uneven_sharding(self) -> None:
world_size=WORLD_SIZE,
tables=embedding_config,
sharder=ManagedCollisionEmbeddingCollectionSharder(),
backend="nccl",
backend=backend,
)

# pyre-ignore
@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs, this test requires at least two GPUs",
)
def test_even_sharding(self) -> None:
# pyre-ignore
@given(backend=st.sampled_from(["nccl"]))
@settings(deadline=20000)
def test_even_sharding(self, backend: str) -> None:
WORLD_SIZE = 2

embedding_config = [
@@ -316,15 +320,17 @@ def test_even_sharding(self) -> None:
world_size=WORLD_SIZE,
tables=embedding_config,
sharder=ManagedCollisionEmbeddingCollectionSharder(),
backend="nccl",
backend=backend,
)

# pyre-ignore
@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs, this test requires at least two GPUs",
)
def test_sharding_zch_mc_ec(self) -> None:
# pyre-ignore
@given(backend=st.sampled_from(["nccl"]))
@settings(deadline=20000)
def test_sharding_zch_mc_ec(self, backend: str) -> None:

WORLD_SIZE = 2

@@ -420,15 +426,17 @@ def test_sharding_zch_mc_ec(self) -> None:
kjt_input_per_rank=kjt_input_per_rank,
kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank,
sharder=ManagedCollisionEmbeddingCollectionSharder(),
backend="nccl",
backend=backend,
)

# pyre-ignore
@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs, this test requires at least two GPUs",
)
def test_sharding_zch_mch_mc_ec(self) -> None:
# pyre-ignore
@given(backend=st.sampled_from(["nccl"]))
@settings(deadline=20000)
def test_sharding_zch_mch_mc_ec(self, backend: str) -> None:

WORLD_SIZE = 2

@@ -553,5 +561,5 @@ def test_sharding_zch_mch_mc_ec(self) -> None:
kjt_input_per_rank=kjt_input_per_rank,
kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank,
sharder=ManagedCollisionEmbeddingCollectionSharder(),
backend="nccl",
backend=backend,
)
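
A design note on the mc-ec changes above (the same applies to mc-ebc below): st.sampled_from(["nccl"]) has a single element, so every draw is "nccl", but routing the test through @given still subjects it to hypothesis's deadline, hence @settings(deadline=20000). An equivalent explicit form (test name hypothetical; timedelta is interchangeable with integer milliseconds):

from datetime import timedelta

from hypothesis import given, settings, strategies as st


@given(backend=st.sampled_from(["nccl"]))
@settings(deadline=timedelta(milliseconds=20000))  # same as deadline=20000
def test_backend_draw(backend: str) -> None:
    assert backend == "nccl"  # a one-element strategy always draws "nccl"
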
33 changes: 21 additions & 12 deletions torchrec/distributed/tests/test_mc_embeddingbag.py
Expand Up @@ -11,6 +11,7 @@

import torch
import torch.nn as nn
from hypothesis import given, settings, strategies as st
from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
from torchrec.distributed.mc_embeddingbag import (
ManagedCollisionEmbeddingBagCollectionSharder,
@@ -264,12 +265,14 @@ def _test_sharding_and_remapping( # noqa C901

@skip_if_asan_class
class ShardedMCEmbeddingBagCollectionParallelTest(MultiProcessTestBase):
# pyre-ignore
@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs, this test requires at least two GPUs",
)
def test_uneven_sharding(self) -> None:
# pyre-ignore
@given(backend=st.sampled_from(["nccl"]))
@settings(deadline=20000)
def test_uneven_sharding(self, backend: str) -> None:
WORLD_SIZE = 2

embedding_bag_config = [
@@ -292,15 +295,17 @@ def test_uneven_sharding(self) -> None:
world_size=WORLD_SIZE,
tables=embedding_bag_config,
sharder=ManagedCollisionEmbeddingBagCollectionSharder(),
backend="nccl",
backend=backend,
)

# pyre-ignore
@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs, this test requires at least two GPUs",
)
def test_even_sharding(self) -> None:
# pyre-ignore
@given(backend=st.sampled_from(["nccl"]))
@settings(deadline=20000)
def test_even_sharding(self, backend: str) -> None:
WORLD_SIZE = 2

embedding_bag_config = [
@@ -323,15 +328,17 @@ def test_even_sharding(self) -> None:
world_size=WORLD_SIZE,
tables=embedding_bag_config,
sharder=ManagedCollisionEmbeddingBagCollectionSharder(),
backend="nccl",
backend=backend,
)

# pyre-ignore
@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs, this test requires at least two GPUs",
)
def test_sharding_zch_mc_ebc(self) -> None:
# pyre-ignore
@given(backend=st.sampled_from(["nccl"]))
@settings(deadline=20000)
def test_sharding_zch_mc_ebc(self, backend: str) -> None:

WORLD_SIZE = 2

@@ -427,15 +434,17 @@ def test_sharding_zch_mc_ebc(self) -> None:
kjt_input_per_rank=kjt_input_per_rank,
kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank,
sharder=ManagedCollisionEmbeddingBagCollectionSharder(),
backend="nccl",
backend=backend,
)

# pyre-ignore
@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs, this test requires at least two GPUs",
)
def test_sharding_zch_mch_mc_ebc(self) -> None:
# pyre-ignore
@given(backend=st.sampled_from(["nccl"]))
@settings(deadline=20000)
def test_sharding_zch_mch_mc_ebc(self, backend: str) -> None:

WORLD_SIZE = 2

@@ -560,5 +569,5 @@ def test_sharding_zch_mch_mc_ebc(self) -> None:
kjt_input_per_rank=kjt_input_per_rank,
kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank,
sharder=ManagedCollisionEmbeddingBagCollectionSharder(),
backend="nccl",
backend=backend,
)
