From a649b4e7034e4922fe6f6ea64e1490ac81301f3b Mon Sep 17 00:00:00 2001
From: Paul Zhang
Date: Tue, 30 Apr 2024 23:30:01 -0700
Subject: [PATCH] Create BoundsCheckMode fused_param for inference

Summary:
Introduce a BoundsCheckMode fused_param to configure the TBE BoundsCheckMode. There is really no reason to run bounds_check_indices in the inference use case (AIMP has it off by default: https://fburl.com/code/q8zhundg), and it causes issues with the PT2 IR (bounds_check_indices is a mutating op).

Reviewed By: ZhengkaiZ

Differential Revision: D56743992

fbshipit-source-id: 1a474170bd5a16f6ef9ddc8d047eca6739d213fc
---
 torchrec/distributed/fused_params.py           | 13 +++++++++++++
 torchrec/distributed/quant_embedding_kernel.py | 10 +++++++++-
 torchrec/distributed/tests/test_pt2.py         | 18 +++++++++++++++++-
 torchrec/inference/modules.py                  |  9 ++++++++-
 4 files changed, 47 insertions(+), 3 deletions(-)

diff --git a/torchrec/distributed/fused_params.py b/torchrec/distributed/fused_params.py
index 491262fb0..7fb3985c1 100644
--- a/torchrec/distributed/fused_params.py
+++ b/torchrec/distributed/fused_params.py
@@ -15,6 +15,7 @@
     IntNBitTableBatchedEmbeddingBagsCodegen,
 )
 from torchrec.distributed.embedding_types import GroupedEmbeddingConfig
+from torchrec.distributed.types import BoundsCheckMode
 
 FUSED_PARAM_REGISTER_TBE_BOOL: str = "__register_tbes_in_named_modules"
 FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS: str = (
@@ -22,6 +23,7 @@
 )
 FUSED_PARAM_TBE_ROW_ALIGNMENT: str = "__register_tbe_row_alignment"
 FUSED_PARAM_IS_WEIGHTED: str = "__register_tbe_is_weighted"
+FUSED_PARAM_BOUNDS_CHECK_MODE: str = "__register_tbe_bounds_check_mode"
 
 
 class TBEToRegisterMixIn:
@@ -65,6 +67,15 @@ def is_fused_param_weighted(fused_params: Optional[Dict[str, Any]]) -> Optional[
     return fused_params[FUSED_PARAM_IS_WEIGHTED]
 
 
+def fused_param_bounds_check_mode(
+    fused_params: Optional[Dict[str, Any]]
+) -> Optional[BoundsCheckMode]:
+    if fused_params is None or FUSED_PARAM_BOUNDS_CHECK_MODE not in fused_params:
+        return None
+    else:
+        return fused_params[FUSED_PARAM_BOUNDS_CHECK_MODE]
+
+
 def is_fused_param_quant_state_dict_split_scale_bias(
     fused_params: Optional[Dict[str, Any]]
 ) -> bool:
@@ -90,5 +101,7 @@ def tbe_fused_params(
         fused_params_for_tbe.pop(FUSED_PARAM_TBE_ROW_ALIGNMENT)
     if FUSED_PARAM_IS_WEIGHTED in fused_params_for_tbe:
         fused_params_for_tbe.pop(FUSED_PARAM_IS_WEIGHTED)
+    if FUSED_PARAM_BOUNDS_CHECK_MODE in fused_params_for_tbe:
+        fused_params_for_tbe.pop(FUSED_PARAM_BOUNDS_CHECK_MODE)
 
     return fused_params_for_tbe
diff --git a/torchrec/distributed/quant_embedding_kernel.py b/torchrec/distributed/quant_embedding_kernel.py
index 01ba10a58..13a5d4651 100644
--- a/torchrec/distributed/quant_embedding_kernel.py
+++ b/torchrec/distributed/quant_embedding_kernel.py
@@ -32,13 +32,14 @@
     GroupedEmbeddingConfig,
 )
 from torchrec.distributed.fused_params import (
-    get_fused_param_tbe_row_alignment,
+    fused_param_bounds_check_mode,
     is_fused_param_quant_state_dict_split_scale_bias,
     is_fused_param_register_tbe,
     is_fused_param_weighted,
     tbe_fused_params,
     TBEToRegisterMixIn,
 )
+from torchrec.distributed.types import BoundsCheckMode
 from torchrec.distributed.utils import append_prefix
 from torchrec.modules.embedding_configs import (
     DATA_TYPE_NUM_BITS,
@@ -197,6 +198,10 @@ def __init__(
         self._quant_state_dict_split_scale_bias: bool = (
             is_fused_param_quant_state_dict_split_scale_bias(fused_params)
         )
+        bounds_check_mode: Optional[BoundsCheckMode] = fused_param_bounds_check_mode(
+            fused_params
+        )
+
         index_remapping = [
             table.pruning_indices_remapping for table in config.embedding_tables
         ]
@@ -233,6 +238,9 @@ def __init__(
                 feature_table_map=self._feature_table_map,
                 row_alignment=self._tbe_row_alignment,
                 uvm_host_mapped=True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
+                bounds_check_mode=(
+                    bounds_check_mode if bounds_check_mode else BoundsCheckMode.WARNING
+                ),
                 **(tbe_fused_params(fused_params) or {}),
             )
         )
diff --git a/torchrec/distributed/tests/test_pt2.py b/torchrec/distributed/tests/test_pt2.py
index 987c43352..8bf36ecaf 100644
--- a/torchrec/distributed/tests/test_pt2.py
+++ b/torchrec/distributed/tests/test_pt2.py
@@ -24,6 +24,7 @@
 from fbgemm_gpu import sparse_ops  # noqa: F401, E402
 
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
+from torchrec.distributed.fused_params import FUSED_PARAM_BOUNDS_CHECK_MODE
 from torchrec.distributed.shard import _shard_modules
 from torchrec.distributed.test_utils.infer_utils import (
     assert_close,
@@ -35,7 +36,7 @@
     replace_sharded_quant_modules_tbes_with_mock_tbes,
     TestQuantEBCSharder,
 )
-from torchrec.distributed.types import ShardingEnv, ShardingType
+from torchrec.distributed.types import BoundsCheckMode, ShardingEnv, ShardingType
 from torchrec.sparse.jagged_tensor import ComputeKJTToJTDict, KeyedJaggedTensor
 
 
@@ -82,17 +83,22 @@ def _sharded_quant_ebc_model(
     sharding_type: ShardingType = ShardingType.TABLE_WISE
 
+    fused_params = {
+        FUSED_PARAM_BOUNDS_CHECK_MODE: BoundsCheckMode.NONE,
+    }
     if feature_processor:
         sharder = TestQuantFPEBCSharder(
             sharding_type=sharding_type.value,
             kernel_type=EmbeddingComputeKernel.QUANT.value,
             shardable_params=[table.name for table in mi.tables],
+            fused_params=fused_params,
         )
     else:
         sharder = TestQuantEBCSharder(
             sharding_type=sharding_type.value,
             kernel_type=EmbeddingComputeKernel.QUANT.value,
             shardable_params=[table.name for table in mi.tables],
+            fused_params=fused_params,
         )
 
     # pyre-ignore
     plan = mi.planner.plan(
@@ -308,6 +314,11 @@ def test_sharded_quant_ebc_non_strict_export(self) -> None:
 
         ep.module()(kjt.values(), kjt.lengths())
 
+        # PT2 IR autofunctionalizes mutation funcs (bounds_check_indices)
+        # ensure such node isn't present, as it causes issues with IR
+        for n in ep.graph_module.graph.nodes:
+            self.assertFalse("auto_functionalized" in str(n.name))
+
         # TODO: Fix Unflatten
         # torch.export.unflatten(ep)
 
@@ -338,6 +349,11 @@ def test_sharded_quant_fpebc_non_strict_export(self) -> None:
         )
         ep.module()(kjt.values(), kjt.lengths())
 
+        # PT2 IR autofunctionalizes mutation funcs (bounds_check_indices)
+        # ensure such node isn't present, as it causes issues with IR
+        for n in ep.graph_module.graph.nodes:
+            self.assertFalse("auto_functionalized" in str(n.name))
+
         # TODO: Fix Unflatten
         # torch.export.unflatten(ep)
 
diff --git a/torchrec/inference/modules.py b/torchrec/inference/modules.py
index 97f8b165f..27db09280 100644
--- a/torchrec/inference/modules.py
+++ b/torchrec/inference/modules.py
@@ -21,6 +21,7 @@
 from torchrec import distributed as trec_dist, inference as trec_infer
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.fused_params import (
+    FUSED_PARAM_BOUNDS_CHECK_MODE,
     FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
     FUSED_PARAM_REGISTER_TBE_BOOL,
 )
@@ -36,7 +37,12 @@
     QuantFeatureProcessedEmbeddingBagCollectionSharder,
 )
 from torchrec.distributed.shard import _shard_modules
-from torchrec.distributed.types import ModuleSharder, ShardingPlan, ShardingType
+from torchrec.distributed.types import (
+    BoundsCheckMode,
+    ModuleSharder,
+    ShardingPlan,
+    ShardingType,
+)
 from torchrec.modules.embedding_configs import QuantConfig
 from torchrec.modules.embedding_modules import (
@@ -375,6 +381,7 @@ def shard_quant_model(
     _fused_param: Dict[str, Any] = {
         FUSED_PARAM_REGISTER_TBE_BOOL: True,
         FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS: True,
+        FUSED_PARAM_BOUNDS_CHECK_MODE: BoundsCheckMode.NONE,
     }
 
     _sharders: List[ModuleSharder[torch.nn.Module]] = [
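
Usage sketch: the snippet below shows how a caller would opt an inference model out of bounds checking via the new fused param. This is a minimal sketch, not part of the diff; it assumes QuantEmbeddingBagCollectionSharder accepts fused_params the same way the test sharders above do, and it mirrors the _fused_param defaults that shard_quant_model now sets in torchrec/inference/modules.py.

    from torchrec.distributed.fused_params import (
        FUSED_PARAM_BOUNDS_CHECK_MODE,
        FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
        FUSED_PARAM_REGISTER_TBE_BOOL,
    )
    from torchrec.distributed.quant_embeddingbag import (
        QuantEmbeddingBagCollectionSharder,
    )
    from torchrec.distributed.types import BoundsCheckMode

    # Same defaults shard_quant_model uses after this change; the new key
    # keeps bounds_check_indices (a mutating op) out of the PT2 exported IR.
    fused_params = {
        FUSED_PARAM_REGISTER_TBE_BOOL: True,
        FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS: True,
        FUSED_PARAM_BOUNDS_CHECK_MODE: BoundsCheckMode.NONE,
    }

    # The sharder threads fused_params down to the quant TBE kernel; pass it
    # to the planner / _shard_modules as usual.
    sharder = QuantEmbeddingBagCollectionSharder(fused_params=fused_params)

When the fused param is absent, the quant kernel falls back to BoundsCheckMode.WARNING, so callers that do not set it keep today's behavior.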