From a3a74bafccbb06470bed46b008de1e9f311d38a4 Mon Sep 17 00:00:00 2001
From: Joshua Su
Date: Mon, 24 Feb 2025 16:41:57 -0800
Subject: [PATCH] Cache the features order in QuantEBC (#2762)

Summary:
In inference, the input KJT (features) always has the same order of keys, so we can compute the permute order once on the first call of forward(), cache it, and skip the indexing operations on all subsequent calls.

Differential Revision: D68991644
---
 torchrec/inference/modules.py       |  7 +++
 torchrec/quant/embedding_modules.py | 87 +++++++++++++++++++++++++++--
 2 files changed, 90 insertions(+), 4 deletions(-)

diff --git a/torchrec/inference/modules.py b/torchrec/inference/modules.py
index 4d136f488..e08ba9ab3 100644
--- a/torchrec/inference/modules.py
+++ b/torchrec/inference/modules.py
@@ -66,6 +66,7 @@
     EmbeddingCollection as QuantEmbeddingCollection,
     FeatureProcessedEmbeddingBagCollection as QuantFeatureProcessedEmbeddingBagCollection,
     MODULE_ATTR_EMB_CONFIG_NAME_TO_NUM_ROWS_POST_PRUNING_DICT,
+    quant_prep_enable_cache_features_order,
     quant_prep_enable_register_tbes,
 )
 
@@ -428,6 +429,9 @@ def _quantize_fp_module(
         """
 
         quant_prep_enable_register_tbes(model, [FeatureProcessedEmbeddingBagCollection])
+        quant_prep_enable_cache_features_order(
+            model, [FeatureProcessedEmbeddingBagCollection]
+        )
         # pyre-fixme[16]: `FeatureProcessedEmbeddingBagCollection` has no attribute
         #  `qconfig`.
         fp_module.qconfig = QuantConfig(
@@ -466,6 +470,9 @@ def _quantize_fp_module(
             )
 
     quant_prep_enable_register_tbes(model, list(additional_mapping.keys()))
+    quant_prep_enable_cache_features_order(
+        model, [EmbeddingBagCollection, EmbeddingCollection]
+    )
     quantize_embeddings(
         model,
         dtype=quantization_dtype,
diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py
index 06239ff7f..5eb3b54a3 100644
--- a/torchrec/quant/embedding_modules.py
+++ b/torchrec/quant/embedding_modules.py
@@ -73,6 +73,7 @@ from torchrec.types import ModuleNoCopyMixin
 
 
 torch.fx.wrap("_get_batching_hinted_output")
+torch.fx.wrap("len")
 
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -106,6 +107,8 @@
 
 MODULE_ATTR_USE_BATCHING_HINTED_OUTPUT: str = "__use_batching_hinted_output"
 
+MODULE_ATTR_CACHE_FEATURES_ORDER: str = "__cache_features_order"
+
 DEFAULT_ROW_ALIGNMENT = 16
 
 
@@ -120,6 +123,15 @@ def _get_kjt_keys(feature: KeyedJaggedTensor) -> List[str]:
     return feature.keys()
 
 
+@torch.fx.wrap
+def _permute_kjt(
+    features: KeyedJaggedTensor,
+    permute_order: List[int],
+    permute_order_tensor: Optional[Tensor] = None,
+) -> KeyedJaggedTensor:
+    return features.permute(permute_order, permute_order_tensor)
+
+
 @torch.fx.wrap
 def _cat_embeddings(embeddings: List[Tensor]) -> Tensor:
     return embeddings[0] if len(embeddings) == 1 else torch.cat(embeddings, dim=1)
@@ -177,6 +189,16 @@ def quant_prep_customize_row_alignment(
     )
 
 
+def quant_prep_enable_cache_features_order(
+    module: nn.Module, module_types: List[Type[torch.nn.Module]]
+) -> None:
+    for_each_module_of_type_do(
+        module,
+        module_types,
+        lambda m: setattr(m, MODULE_ATTR_CACHE_FEATURES_ORDER, True),
+    )
+
+
 def quantize_state_dict(
     module: nn.Module,
     table_name_to_quantized_weights: Dict[str, Tuple[Tensor, Tensor]],
@@ -323,6 +345,7 @@ def __init__(
         register_tbes: bool = False,
         quant_state_dict_split_scale_bias: bool = False,
         row_alignment: int = DEFAULT_ROW_ALIGNMENT,
+        cache_features_order: bool = False,
     ) -> None:
         super().__init__()
         self._is_weighted = is_weighted
@@ -333,6 +356,7 @@ def __init__(
         self._feature_names: List[str] = []
         self._feature_splits: List[int] = []
         self._length_per_key: List[int] = []
+        self._features_order: List[int] = []
         # Registering in a List instead of ModuleList because we want don't want them to be auto-registered.
         # Their states will be modified via self.embedding_bags
         self._emb_modules: List[nn.Module] = []
@@ -458,6 +482,7 @@ def __init__(
         self.register_tbes = register_tbes
         if register_tbes:
             self.tbes: torch.nn.ModuleList = torch.nn.ModuleList(self._emb_modules)
+        setattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, cache_features_order)
 
     def forward(
         self,
@@ -473,8 +498,26 @@ def forward(
 
         embeddings = []
         kjt_keys = _get_kjt_keys(features)
-        kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
-        kjt_permute = features.permute(kjt_permute_order)
+        # Cache the features order since the features will always have the same order of keys in inference.
+        if getattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, False):
+            if self._features_order == []:
+                for k in self._feature_names:
+                    self._features_order.append(kjt_keys.index(k))
+                self.register_buffer(
+                    "_features_order_tensor",
+                    torch.tensor(
+                        data=self._features_order,
+                        device=features.device(),
+                        dtype=torch.int32,
+                    ),
+                    persistent=False,
+                )
+            kjt_permute = _permute_kjt(
+                features, self._features_order, self._features_order_tensor
+            )
+        else:
+            kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
+            kjt_permute = _permute_kjt(features, kjt_permute_order)
         kjts_per_key = kjt_permute.split(self._feature_splits)
 
         for i, (emb_op, _) in enumerate(
@@ -550,6 +593,9 @@ def from_float(
             row_alignment=getattr(
                 module, MODULE_ATTR_ROW_ALIGNMENT_INT, DEFAULT_ROW_ALIGNMENT
            ),
+            cache_features_order=getattr(
+                module, MODULE_ATTR_CACHE_FEATURES_ORDER, False
+            ),
         )
 
     def embedding_bag_configs(
@@ -584,6 +630,7 @@ def __init__(
         # feature processor is Optional only for the sake of the last position in constructor
         # Enforcing it to be non-None, for None case EmbeddingBagCollection must be used.
         feature_processor: Optional[FeatureProcessorsCollection] = None,
+        cache_features_order: bool = False,
     ) -> None:
         super().__init__(
             tables,
@@ -594,6 +641,7 @@ def __init__(
             register_tbes,
             quant_state_dict_split_scale_bias,
             row_alignment,
+            cache_features_order,
         )
         assert (
             feature_processor is not None
@@ -661,6 +709,9 @@ def from_float(
             ),
             # pyre-ignore
             feature_processor=fp_ebc._feature_processors,
+            cache_features_order=getattr(
+                module, MODULE_ATTR_CACHE_FEATURES_ORDER, False
+            ),
         )
 
 
@@ -687,6 +738,7 @@ def __init__(  # noqa C901
         register_tbes: bool = False,
         quant_state_dict_split_scale_bias: bool = False,
         row_alignment: int = DEFAULT_ROW_ALIGNMENT,
+        cache_features_order: bool = False,
     ) -> None:
         super().__init__()
         self._emb_modules: List[IntNBitTableBatchedEmbeddingBagsCodegen] = []
@@ -698,6 +750,7 @@ def __init__(  # noqa C901
         self.row_alignment = row_alignment
         self._key_to_tables: Dict[DataType, List[EmbeddingConfig]] = defaultdict(list)
         self._feature_names: List[str] = []
+        self._features_order: List[int] = []
 
         self._table_name_to_quantized_weights: Optional[
             Dict[str, Tuple[Tensor, Tensor]]
@@ -808,6 +861,7 @@ def __init__(  # noqa C901
         self.register_tbes = register_tbes
         if register_tbes:
             self.tbes: torch.nn.ModuleList = torch.nn.ModuleList(self._emb_modules)
+        setattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, cache_features_order)
 
     def forward(
         self,
@@ -823,9 +877,28 @@ def forward(
 
         feature_embeddings: Dict[str, JaggedTensor] = {}
         kjt_keys = _get_kjt_keys(features)
-        kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
-        kjt_permute = features.permute(kjt_permute_order)
+        # Cache the features order since the features will always have the same order of keys in inference.
+        if getattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, False):
+            if self._features_order == []:
+                for k in self._feature_names:
+                    self._features_order.append(kjt_keys.index(k))
+                self.register_buffer(
+                    "_features_order_tensor",
+                    torch.tensor(
+                        data=self._features_order,
+                        device=features.device(),
+                        dtype=torch.int32,
+                    ),
+                    persistent=False,
+                )
+            kjt_permute = _permute_kjt(
+                features, self._features_order, self._features_order_tensor
+            )
+        else:
+            kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
+            kjt_permute = _permute_kjt(features, kjt_permute_order)
         kjts_per_key = kjt_permute.split(self._feature_splits)
+
         for i, (emb_module, key) in enumerate(
             zip(self._emb_modules, self._key_to_tables.keys())
         ):
@@ -896,6 +969,9 @@ def from_float(
             row_alignment=getattr(
                 module, MODULE_ATTR_ROW_ALIGNMENT_INT, DEFAULT_ROW_ALIGNMENT
             ),
+            cache_features_order=getattr(
+                module, MODULE_ATTR_CACHE_FEATURES_ORDER, False
+            ),
         )
 
     def _get_name(self) -> str:
@@ -953,6 +1029,7 @@ def __init__(
         row_alignment: int = DEFAULT_ROW_ALIGNMENT,
         managed_collision_collection: Optional[ManagedCollisionCollection] = None,
         return_remapped_features: bool = False,
+        cache_features_order: bool = False,
     ) -> None:
         super().__init__(
             tables,
@@ -963,6 +1040,7 @@ def __init__(
             register_tbes,
             quant_state_dict_split_scale_bias,
             row_alignment,
+            cache_features_order,
        )
         assert (
             managed_collision_collection
@@ -1062,4 +1140,5 @@ def from_float(
             ),
             managed_collision_collection=mc_ec._managed_collision_collection,
             return_remapped_features=mc_ec._return_remapped_features,
+            cache_features_order=getattr(ec, MODULE_ATTR_CACHE_FEATURES_ORDER, False),
         )
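
For illustration, a standalone sketch of the caching idea used in the forward() changes above. The ToyPermuteCache class is a hypothetical stand-in, not TorchRec code: the key-to-position lookups are paid once on the first call, and the cached permutation is reused afterwards because inference inputs keep a fixed key order.

# Hypothetical toy stand-in for the cached-permute-order pattern; not the
# TorchRec module itself.
from typing import List


class ToyPermuteCache:
    def __init__(self, feature_names: List[str]) -> None:
        self._feature_names = feature_names
        self._features_order: List[int] = []  # filled lazily on the first call

    def permute(self, keys: List[str], values: List[List[int]]) -> List[List[int]]:
        if not self._features_order:
            # First call only: resolve each feature name to its position in the
            # incoming key order and cache the result.
            self._features_order = [keys.index(k) for k in self._feature_names]
        # Later calls reuse the cached order and skip the index() lookups.
        return [values[i] for i in self._features_order]


cache = ToyPermuteCache(["f2", "f0", "f1"])
print(cache.permute(["f0", "f1", "f2"], [[1], [2], [3]]))  # [[3], [1], [2]]
print(cache.permute(["f0", "f1", "f2"], [[4], [5], [6]]))  # [[6], [4], [5]]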
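
The new _permute_kjt helper simply forwards to KeyedJaggedTensor.permute(indices, indices_tensor), which is the call the cached order and buffer feed into. A hedged usage sketch of that API follows, assuming a working torchrec + fbgemm install; the feature names and values are illustrative only.

# Hedged sketch of the KJT permute call that _permute_kjt wraps.
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f0", "f1"],
    values=torch.tensor([10, 11, 12]),
    lengths=torch.tensor([2, 1]),  # "f0" has 2 ids, "f1" has 1 id (batch size 1)
)

order = [1, 0]  # put "f1" first, as a cached features order would
order_tensor = torch.tensor(order, dtype=torch.int32)  # same dtype as the cached buffer
permuted = kjt.permute(order, order_tensor)
print(permuted.keys())  # ['f1', 'f0']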
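
Finally, a hedged sketch of how the flag would be enabled ahead of quantization, mirroring the calls this patch adds to torchrec/inference/modules.py. The prepare_for_quantization wrapper name is made up; the model is assumed to contain float EmbeddingBagCollection / EmbeddingCollection submodules that are later swapped to their quantized counterparts via from_float, which is where MODULE_ATTR_CACHE_FEATURES_ORDER is read back.

# Hedged sketch mirroring the quant-prep calls added in this patch; the
# function name below is hypothetical.
import torch.nn as nn

from torchrec.modules.embedding_modules import (
    EmbeddingBagCollection,
    EmbeddingCollection,
)
from torchrec.quant.embedding_modules import (
    quant_prep_enable_cache_features_order,
    quant_prep_enable_register_tbes,
)


def prepare_for_quantization(model: nn.Module) -> None:
    # Register TBEs on the soon-to-be-quantized collections (pre-existing helper).
    quant_prep_enable_register_tbes(model, [EmbeddingBagCollection, EmbeddingCollection])
    # Mark the collections so that, once quantized via from_float, their forward()
    # caches the KJT permute order on the first call (helper added in this patch).
    quant_prep_enable_cache_features_order(
        model, [EmbeddingBagCollection, EmbeddingCollection]
    )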