Improve error handling in enumerator (#1619)
Summary:
Pull Request resolved: #1619

Improve warning message handling for the enumerator.

Reviewed By: joshuadeng

Differential Revision: D52545226

fbshipit-source-id: 850bb58688bc85e476df83169d5aacba71d57373
henrylhtsang authored and facebook-github-bot committed Jan 11, 2024
1 parent b581c59 commit 1fff2f7
Showing 2 changed files with 189 additions and 15 deletions.
47 changes: 33 additions & 14 deletions torchrec/distributed/planner/enumerators.py
@@ -186,38 +186,51 @@ def populate_estimates(self, sharding_options: List[ShardingOption]) -> None:
        for estimator in self._estimators:
            estimator.estimate(sharding_options, self._sharder_map)

-    def _filter_sharding_types(self, name: str, sharding_types: List[str]) -> List[str]:
+    def _filter_sharding_types(
+        self, name: str, allowed_sharding_types: List[str]
+    ) -> List[str]:
        if not self._constraints or not self._constraints.get(name):
-            return sharding_types
+            return allowed_sharding_types
        constraints: ParameterConstraints = self._constraints[name]
        if not constraints.sharding_types:
-            return sharding_types
+            return allowed_sharding_types
        constrained_sharding_types: List[str] = constraints.sharding_types

-        sharding_types = list(set(constrained_sharding_types) & set(sharding_types))
+        filtered_sharding_types = list(
+            set(constrained_sharding_types) & set(allowed_sharding_types)
+        )

-        if not sharding_types:
+        if not filtered_sharding_types:
            logger.warn(
-                f"No available sharding types after applying user provided constraints for {name}"
+                "No available sharding types after applying user provided "
+                f"constraints for {name}. Constrained sharding types: "
+                f"{constrained_sharding_types}, allowed sharding types: "
+                f"{allowed_sharding_types}, filtered sharding types: "
+                f"{filtered_sharding_types}. Please check if the constrained "
+                "sharding types are too restrictive, if the sharder allows the "
+                "sharding types, or if non-strings are passed in."
            )
-        return sharding_types
+        return filtered_sharding_types

    def _filter_compute_kernels(
        self,
        name: str,
-        compute_kernels: List[str],
+        allowed_compute_kernels: List[str],
    ) -> List[str]:

+        # for the log message only
+        constrained_compute_kernels: List[str] = [
+            compute_kernel.value for compute_kernel in EmbeddingComputeKernel
+        ]
        if not self._constraints or not self._constraints.get(name):
-            filtered_compute_kernels = compute_kernels
+            filtered_compute_kernels = allowed_compute_kernels
        else:
            constraints: ParameterConstraints = self._constraints[name]
            if not constraints.compute_kernels:
-                filtered_compute_kernels = compute_kernels
+                filtered_compute_kernels = allowed_compute_kernels
            else:
-                constrained_compute_kernels: List[str] = constraints.compute_kernels
+                constrained_compute_kernels = constraints.compute_kernels
                filtered_compute_kernels = list(
-                    set(constrained_compute_kernels) & set(compute_kernels)
+                    set(constrained_compute_kernels) & set(allowed_compute_kernels)
                )

        if EmbeddingComputeKernel.DENSE.value in filtered_compute_kernels:
@@ -228,7 +241,13 @@ def _filter_compute_kernels(

        if not filtered_compute_kernels:
            logger.warn(
-                f"No available compute kernels after applying user provided constraints for {name}"
+                "No available compute kernels after applying user provided "
+                f"constraints for {name}. Constrained compute kernels: "
+                f"{constrained_compute_kernels}, allowed compute kernels: "
+                f"{allowed_compute_kernels}, filtered compute kernels: "
+                f"{filtered_compute_kernels}. Please check if the constrained "
+                "compute kernels are too restrictive, if the sharder allows the "
+                "compute kernels, or if non-strings are passed in."
            )
        return filtered_compute_kernels
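
Both methods now follow the same pattern: intersect the user-provided constraints with what the sharder allows and, when the intersection is empty, log every input to that decision rather than only the table name. A minimal standalone sketch of the pattern (illustrative names, not TorchRec API):

    import logging
    from typing import List

    logger: logging.Logger = logging.getLogger(__name__)

    def filter_by_constraints(
        name: str, constrained: List[str], allowed: List[str]
    ) -> List[str]:
        # Intersect what the user asked for with what the sharder supports.
        filtered = list(set(constrained) & set(allowed))
        if not filtered:
            # Log all three lists so an empty result is debuggable: was the
            # constraint too strict, the sharder too limited, or was a
            # non-string (e.g. an enum member) passed in by mistake?
            logger.warning(
                f"No overlap for {name}. Constrained: {constrained}, "
                f"allowed: {allowed}, filtered: {filtered}."
            )
        return filtered

    # e.g. filter_by_constraints("table_0", ["column_wise"], ["row_wise"]) -> []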

157 changes: 156 additions & 1 deletion torchrec/distributed/planner/tests/test_enumerators.py
@@ -7,7 +7,7 @@

import unittest
from typing import cast, List
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch

import torch
from torchrec.distributed.embedding_tower_sharding import (
@@ -16,6 +16,9 @@
)
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.mc_embeddingbag import (
ManagedCollisionEmbeddingBagCollectionSharder,
)
from torchrec.distributed.planner.constants import BIGINT_DTYPE
from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
from torchrec.distributed.planner.shard_estimators import (
@@ -649,6 +652,158 @@ def test_filtering(self) -> None:
self.assertIn(sharding_option.compute_kernel, expected_compute_kernels)
self.assertNotIn(sharding_option.compute_kernel, unexpected_compute_kernels)

def test_filter_sharding_types_ebc(self) -> None:
constraint = ParameterConstraints(
sharding_types=[
ShardingType.TABLE_WISE.value,
ShardingType.TABLE_ROW_WISE.value,
ShardingType.COLUMN_WISE.value,
],
)
constraints = {"table_0": constraint}
enumerator = EmbeddingEnumerator(
topology=MagicMock(),
batch_size=MagicMock(),
constraints=constraints,
)

sharder = EmbeddingBagCollectionSharder()
allowed_sharding_types = enumerator._filter_sharding_types(
"table_0", sharder.sharding_types("cuda")
)

self.assertEqual(
set(allowed_sharding_types),
{
ShardingType.TABLE_WISE.value,
ShardingType.TABLE_ROW_WISE.value,
ShardingType.COLUMN_WISE.value,
},
)

def test_filter_sharding_types_mch_ebc(self) -> None:
constraint = ParameterConstraints(
sharding_types=[
ShardingType.ROW_WISE.value,
ShardingType.TABLE_ROW_WISE.value,
ShardingType.COLUMN_WISE.value,
],
)
constraints = {"table_0": constraint}
enumerator = EmbeddingEnumerator(
topology=MagicMock(),
batch_size=MagicMock(),
constraints=constraints,
)

sharder = ManagedCollisionEmbeddingBagCollectionSharder()
allowed_sharding_types = enumerator._filter_sharding_types(
"table_0", sharder.sharding_types("cuda")
)

self.assertEqual(
set(allowed_sharding_types),
{
ShardingType.ROW_WISE.value,
},
)

def test_filter_sharding_types_mch_ebc_no_available(self) -> None:
constraint = ParameterConstraints(
sharding_types=[
ShardingType.TABLE_ROW_WISE.value,
ShardingType.COLUMN_WISE.value,
],
)
constraints = {"table_0": constraint}
enumerator = EmbeddingEnumerator(
topology=MagicMock(),
batch_size=MagicMock(),
constraints=constraints,
)

sharder = ManagedCollisionEmbeddingBagCollectionSharder()
with self.assertWarns(Warning):
allowed_sharding_types = enumerator._filter_sharding_types(
"table_0", sharder.sharding_types("cuda")
)

self.assertEqual(allowed_sharding_types, [])

def test_filter_compute_kernels_ebc(self) -> None:
constraint = ParameterConstraints(
compute_kernels=[
EmbeddingComputeKernel.DENSE.value,
EmbeddingComputeKernel.FUSED.value,
EmbeddingComputeKernel.FUSED_UVM.value,
],
)
constraints = {"table_0": constraint}
enumerator = EmbeddingEnumerator(
topology=MagicMock(),
batch_size=MagicMock(),
constraints=constraints,
)

sharder = EmbeddingBagCollectionSharder()
allowed_compute_kernels = enumerator._filter_compute_kernels(
"table_0", sharder.compute_kernels(ShardingType.ROW_WISE.value, "cuda")
)

self.assertEqual(
set(allowed_compute_kernels),
{
EmbeddingComputeKernel.FUSED.value,
EmbeddingComputeKernel.FUSED_UVM.value,
},
)

def test_filter_compute_kernels_mch_ebc(self) -> None:
constraint = ParameterConstraints(
compute_kernels=[
EmbeddingComputeKernel.DENSE.value,
EmbeddingComputeKernel.FUSED.value,
EmbeddingComputeKernel.FUSED_UVM.value,
],
)
constraints = {"table_0": constraint}
enumerator = EmbeddingEnumerator(
topology=MagicMock(),
batch_size=MagicMock(),
constraints=constraints,
)

sharder = ManagedCollisionEmbeddingBagCollectionSharder()
allowed_compute_kernels = enumerator._filter_compute_kernels(
"table_0", sharder.compute_kernels(ShardingType.ROW_WISE.value, "cuda")
)

self.assertEqual(
set(allowed_compute_kernels),
{EmbeddingComputeKernel.FUSED.value},
)

def test_filter_compute_kernels_mch_ebc_no_available(self) -> None:
constraint = ParameterConstraints(
compute_kernels=[
EmbeddingComputeKernel.DENSE.value,
],
)
constraints = {"table_0": constraint}
enumerator = EmbeddingEnumerator(
topology=MagicMock(),
batch_size=MagicMock(),
constraints=constraints,
)

sharder = ManagedCollisionEmbeddingBagCollectionSharder()
with self.assertWarns(Warning):
allowed_compute_kernels = enumerator._filter_compute_kernels(
"table_0", sharder.compute_kernels(ShardingType.ROW_WISE.value, "cuda")
)

self.assertEqual(allowed_compute_kernels, [])

def test_tower_sharding(self) -> None:
# five tables
# tower_0: tables[2], tables[3]
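
The new tests reach these paths through the enumerator's private helpers. For completeness, a hedged repro sketch outside the test suite, mirroring test_filter_sharding_types_mch_ebc_no_available above (it assumes the same imports the test file uses and that the MCH EBC sharder only allows row-wise sharding in this setting, as the tests show):

    import logging
    from unittest.mock import MagicMock

    from torchrec.distributed.mc_embeddingbag import (
        ManagedCollisionEmbeddingBagCollectionSharder,
    )
    from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
    from torchrec.distributed.planner.types import ParameterConstraints
    from torchrec.distributed.types import ShardingType

    logging.basicConfig(level=logging.WARNING)

    # Constrain "table_0" to a sharding type the sharder does not offer,
    # so the intersection is empty and the new diagnostic warning fires.
    enumerator = EmbeddingEnumerator(
        topology=MagicMock(),
        batch_size=MagicMock(),
        constraints={
            "table_0": ParameterConstraints(
                sharding_types=[ShardingType.COLUMN_WISE.value],
            )
        },
    )
    sharder = ManagedCollisionEmbeddingBagCollectionSharder()
    filtered = enumerator._filter_sharding_types(
        "table_0", sharder.sharding_types("cuda")
    )
    assert filtered == []  # the warning message lists all three inputs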
