diff --git a/torchrec/distributed/planner/enumerators.py b/torchrec/distributed/planner/enumerators.py
index 15834a6c1..6f6c2f61e 100644
--- a/torchrec/distributed/planner/enumerators.py
+++ b/torchrec/distributed/planner/enumerators.py
@@ -186,38 +186,51 @@ def populate_estimates(self, sharding_options: List[ShardingOption]) -> None:
         for estimator in self._estimators:
             estimator.estimate(sharding_options, self._sharder_map)

-    def _filter_sharding_types(self, name: str, sharding_types: List[str]) -> List[str]:
+    def _filter_sharding_types(
+        self, name: str, allowed_sharding_types: List[str]
+    ) -> List[str]:
         if not self._constraints or not self._constraints.get(name):
-            return sharding_types
+            return allowed_sharding_types
         constraints: ParameterConstraints = self._constraints[name]
         if not constraints.sharding_types:
-            return sharding_types
+            return allowed_sharding_types
         constrained_sharding_types: List[str] = constraints.sharding_types

-        sharding_types = list(set(constrained_sharding_types) & set(sharding_types))
+        filtered_sharding_types = list(
+            set(constrained_sharding_types) & set(allowed_sharding_types)
+        )

-        if not sharding_types:
+        if not filtered_sharding_types:
             logger.warn(
-                f"No available sharding types after applying user provided constraints for {name}"
+                "No available sharding types after applying user provided "
+                f"constraints for {name}. Constrained sharding types: "
+                f"{constrained_sharding_types}, allowed sharding types: "
+                f"{allowed_sharding_types}, filtered sharding types: "
+                f"{filtered_sharding_types}. Please check if the constrained "
+                "sharding types are too restrictive, if the sharder allows the "
+                "sharding types, or if non-strings are passed in."
             )
-        return sharding_types
+        return filtered_sharding_types

     def _filter_compute_kernels(
         self,
         name: str,
-        compute_kernels: List[str],
+        allowed_compute_kernels: List[str],
     ) -> List[str]:
-
+        # for the log message only
+        constrained_compute_kernels: List[str] = [
+            compute_kernel.value for compute_kernel in EmbeddingComputeKernel
+        ]
         if not self._constraints or not self._constraints.get(name):
-            filtered_compute_kernels = compute_kernels
+            filtered_compute_kernels = allowed_compute_kernels
         else:
             constraints: ParameterConstraints = self._constraints[name]
             if not constraints.compute_kernels:
-                filtered_compute_kernels = compute_kernels
+                filtered_compute_kernels = allowed_compute_kernels
             else:
-                constrained_compute_kernels: List[str] = constraints.compute_kernels
+                constrained_compute_kernels = constraints.compute_kernels
                 filtered_compute_kernels = list(
-                    set(constrained_compute_kernels) & set(compute_kernels)
+                    set(constrained_compute_kernels) & set(allowed_compute_kernels)
                 )

         if EmbeddingComputeKernel.DENSE.value in filtered_compute_kernels:
@@ -228,7 +241,13 @@ def _filter_compute_kernels(

         if not filtered_compute_kernels:
             logger.warn(
-                f"No available compute kernels after applying user provided constraints for {name}"
+                "No available compute kernels after applying user provided "
+                f"constraints for {name}. Constrained compute kernels: "
+                f"{constrained_compute_kernels}, allowed compute kernels: "
+                f"{allowed_compute_kernels}, filtered compute kernels: "
+                f"{filtered_compute_kernels}. Please check if the constrained "
+                "compute kernels are too restrictive, if the sharder allows the "
+                "compute kernels, or if non-strings are passed in."
             )
         return filtered_compute_kernels
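Note on the change above: both helpers implement the same pattern, intersecting the user's ParameterConstraints with what the sharder reports as supported, and the expanded warning now prints all of the inputs whenever that intersection comes up empty. A minimal standalone sketch of the pattern (illustrative only, not torchrec code; the function name and constraint shape are simplified):

    import logging
    from typing import Dict, List, Optional

    logger = logging.getLogger(__name__)

    def filter_options(
        name: str,
        allowed: List[str],
        constraints: Optional[Dict[str, List[str]]] = None,
    ) -> List[str]:
        # No constraint registered for this table: keep everything allowed.
        if not constraints or name not in constraints:
            return allowed
        constrained = constraints[name]
        filtered = list(set(constrained) & set(allowed))
        if not filtered:
            # As in the patch: log every input so an empty result is diagnosable.
            logger.warning(
                "No options left for %s: constrained=%s, allowed=%s",
                name,
                constrained,
                allowed,
            )
        return filtered

    # filter_options("table_0", ["row_wise"], {"table_0": ["table_wise"]}) -> []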
diff --git a/torchrec/distributed/planner/tests/test_enumerators.py b/torchrec/distributed/planner/tests/test_enumerators.py
index 8bd463670..9291aed3e 100644
--- a/torchrec/distributed/planner/tests/test_enumerators.py
+++ b/torchrec/distributed/planner/tests/test_enumerators.py
@@ -7,7 +7,7 @@

 import unittest
 from typing import cast, List
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch

 import torch
 from torchrec.distributed.embedding_tower_sharding import (
@@ -16,6 +16,9 @@
 )
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
+from torchrec.distributed.mc_embeddingbag import (
+    ManagedCollisionEmbeddingBagCollectionSharder,
+)
 from torchrec.distributed.planner.constants import BIGINT_DTYPE
 from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
 from torchrec.distributed.planner.shard_estimators import (
@@ -649,6 +652,158 @@ def test_filtering(self) -> None:
         self.assertIn(sharding_option.compute_kernel, expected_compute_kernels)
         self.assertNotIn(sharding_option.compute_kernel, unexpected_compute_kernels)

+    def test_filter_sharding_types_ebc(self) -> None:
+        constraint = ParameterConstraints(
+            sharding_types=[
+                ShardingType.TABLE_WISE.value,
+                ShardingType.TABLE_ROW_WISE.value,
+                ShardingType.COLUMN_WISE.value,
+            ],
+        )
+        constraints = {"table_0": constraint}
+        enumerator = EmbeddingEnumerator(
+            topology=MagicMock(),
+            batch_size=MagicMock(),
+            constraints=constraints,
+        )
+
+        sharder = EmbeddingBagCollectionSharder()
+        allowed_sharding_types = enumerator._filter_sharding_types(
+            "table_0", sharder.sharding_types("cuda")
+        )
+
+        self.assertEqual(
+            set(allowed_sharding_types),
+            {
+                ShardingType.TABLE_WISE.value,
+                ShardingType.TABLE_ROW_WISE.value,
+                ShardingType.COLUMN_WISE.value,
+            },
+        )
+
+    def test_filter_sharding_types_mch_ebc(self) -> None:
+        constraint = ParameterConstraints(
+            sharding_types=[
+                ShardingType.ROW_WISE.value,
+                ShardingType.TABLE_ROW_WISE.value,
+                ShardingType.COLUMN_WISE.value,
+            ],
+        )
+        constraints = {"table_0": constraint}
+        enumerator = EmbeddingEnumerator(
+            topology=MagicMock(),
+            batch_size=MagicMock(),
+            constraints=constraints,
+        )
+
+        sharder = ManagedCollisionEmbeddingBagCollectionSharder()
+        allowed_sharding_types = enumerator._filter_sharding_types(
+            "table_0", sharder.sharding_types("cuda")
+        )
+
+        self.assertEqual(
+            set(allowed_sharding_types),
+            {
+                ShardingType.ROW_WISE.value,
+            },
+        )
+
+    def test_filter_sharding_types_mch_ebc_no_available(self) -> None:
+        constraint = ParameterConstraints(
+            sharding_types=[
+                ShardingType.TABLE_ROW_WISE.value,
+                ShardingType.COLUMN_WISE.value,
+            ],
+        )
+        constraints = {"table_0": constraint}
+        enumerator = EmbeddingEnumerator(
+            topology=MagicMock(),
+            batch_size=MagicMock(),
+            constraints=constraints,
+        )
+
+        sharder = ManagedCollisionEmbeddingBagCollectionSharder()
+        with self.assertWarns(Warning):
+            allowed_sharding_types = enumerator._filter_sharding_types(
+                "table_0", sharder.sharding_types("cuda")
+            )
+
+        self.assertEqual(allowed_sharding_types, [])
+
+    def test_filter_compute_kernels_ebc(self) -> None:
+        constraint = ParameterConstraints(
+            compute_kernels=[
+                EmbeddingComputeKernel.DENSE.value,
+                EmbeddingComputeKernel.FUSED.value,
+                EmbeddingComputeKernel.FUSED_UVM.value,
+            ],
+        )
+        constraints = {"table_0": constraint}
+        enumerator = EmbeddingEnumerator(
+            topology=MagicMock(),
+            batch_size=MagicMock(),
+            constraints=constraints,
+        )
+
+        sharder = EmbeddingBagCollectionSharder()
+        allowed_compute_kernels = enumerator._filter_compute_kernels(
+            "table_0", sharder.compute_kernels(ShardingType.ROW_WISE.value, "cuda")
+        )
+
+        self.assertEqual(
+            set(allowed_compute_kernels),
+            {
+                EmbeddingComputeKernel.FUSED.value,
+                EmbeddingComputeKernel.FUSED_UVM.value,
+            },
+        )
+
+    def test_filter_compute_kernels_mch_ebc(self) -> None:
+        constraint = ParameterConstraints(
+            compute_kernels=[
+                EmbeddingComputeKernel.DENSE.value,
+                EmbeddingComputeKernel.FUSED.value,
+                EmbeddingComputeKernel.FUSED_UVM.value,
+            ],
+        )
+        constraints = {"table_0": constraint}
+        enumerator = EmbeddingEnumerator(
+            topology=MagicMock(),
+            batch_size=MagicMock(),
+            constraints=constraints,
+        )
+
+        sharder = ManagedCollisionEmbeddingBagCollectionSharder()
+        allowed_compute_kernels = enumerator._filter_compute_kernels(
+            "table_0", sharder.compute_kernels(ShardingType.ROW_WISE.value, "cuda")
+        )
+
+        self.assertEqual(
+            set(allowed_compute_kernels),
+            {EmbeddingComputeKernel.FUSED.value},
+        )
+
+    def test_filter_compute_kernels_mch_ebc_no_available(self) -> None:
+        constraint = ParameterConstraints(
+            compute_kernels=[
+                EmbeddingComputeKernel.DENSE.value,
+            ],
+        )
+        constraints = {"table_0": constraint}
+        enumerator = EmbeddingEnumerator(
+            topology=MagicMock(),
+            batch_size=MagicMock(),
+            constraints=constraints,
+        )
+
+        sharder = ManagedCollisionEmbeddingBagCollectionSharder()
+        with self.assertWarns(Warning):
+            allowed_compute_kernels = enumerator._filter_compute_kernels(
+                "table_0", sharder.compute_kernels(ShardingType.ROW_WISE.value, "cuda")
+            )
+
+        self.assertEqual(allowed_compute_kernels, [])
+
     def test_tower_sharding(self) -> None:
         # five tables
         # tower_0: tables[2], tables[3]
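The new tests call the private helpers directly, which is why topology and batch_size can be MagicMock(): neither helper reads them. In normal use the same filtering runs inside EmbeddingEnumerator.enumerate() during planning. A rough usage sketch of that public path (world_size, batch_size, and the table name are assumed values for illustration):

    from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
    from torchrec.distributed.planner.types import ParameterConstraints, Topology
    from torchrec.distributed.types import ShardingType

    # Restrict "table_0" to row-wise sharding. The enumerator intersects this
    # with the sharder-supported types; the warning added in this diff fires
    # only when that intersection is empty.
    constraints = {
        "table_0": ParameterConstraints(
            sharding_types=[ShardingType.ROW_WISE.value],
        ),
    }
    enumerator = EmbeddingEnumerator(
        topology=Topology(world_size=2, compute_device="cuda"),
        batch_size=256,
        constraints=constraints,
    )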