diff --git a/torchrec/inference/modules.py b/torchrec/inference/modules.py
index 4d136f488..e08ba9ab3 100644
--- a/torchrec/inference/modules.py
+++ b/torchrec/inference/modules.py
@@ -66,6 +66,7 @@
     EmbeddingCollection as QuantEmbeddingCollection,
     FeatureProcessedEmbeddingBagCollection as QuantFeatureProcessedEmbeddingBagCollection,
     MODULE_ATTR_EMB_CONFIG_NAME_TO_NUM_ROWS_POST_PRUNING_DICT,
+    quant_prep_enable_cache_features_order,
     quant_prep_enable_register_tbes,
 )

@@ -428,6 +429,9 @@ def _quantize_fp_module(
     """

     quant_prep_enable_register_tbes(model, [FeatureProcessedEmbeddingBagCollection])
+    quant_prep_enable_cache_features_order(
+        model, [FeatureProcessedEmbeddingBagCollection]
+    )
     # pyre-fixme[16]: `FeatureProcessedEmbeddingBagCollection` has no attribute
     #  `qconfig`.
     fp_module.qconfig = QuantConfig(
@@ -466,6 +470,9 @@ def _quantize_fp_module(
     )

     quant_prep_enable_register_tbes(model, list(additional_mapping.keys()))
+    quant_prep_enable_cache_features_order(
+        model, [EmbeddingBagCollection, EmbeddingCollection]
+    )
     quantize_embeddings(
         model,
         dtype=quantization_dtype,
diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py
index 06239ff7f..5eb3b54a3 100644
--- a/torchrec/quant/embedding_modules.py
+++ b/torchrec/quant/embedding_modules.py
@@ -73,6 +73,7 @@ from torchrec.types import ModuleNoCopyMixin

 torch.fx.wrap("_get_batching_hinted_output")

+torch.fx.wrap("len")

 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -106,6 +107,8 @@

 MODULE_ATTR_USE_BATCHING_HINTED_OUTPUT: str = "__use_batching_hinted_output"

+MODULE_ATTR_CACHE_FEATURES_ORDER: str = "__cache_features_order"
+
 DEFAULT_ROW_ALIGNMENT = 16


@@ -120,6 +123,15 @@ def _get_kjt_keys(feature: KeyedJaggedTensor) -> List[str]:
     return feature.keys()


+@torch.fx.wrap
+def _permute_kjt(
+    features: KeyedJaggedTensor,
+    permute_order: List[int],
+    permute_order_tensor: Optional[Tensor] = None,
+) -> KeyedJaggedTensor:
+    return features.permute(permute_order, permute_order_tensor)
+
+
 @torch.fx.wrap
 def _cat_embeddings(embeddings: List[Tensor]) -> Tensor:
     return embeddings[0] if len(embeddings) == 1 else torch.cat(embeddings, dim=1)
@@ -177,6 +189,16 @@ def quant_prep_customize_row_alignment(
     )


+def quant_prep_enable_cache_features_order(
+    module: nn.Module, module_types: List[Type[torch.nn.Module]]
+) -> None:
+    for_each_module_of_type_do(
+        module,
+        module_types,
+        lambda m: setattr(m, MODULE_ATTR_CACHE_FEATURES_ORDER, True),
+    )
+
+
 def quantize_state_dict(
     module: nn.Module,
     table_name_to_quantized_weights: Dict[str, Tuple[Tensor, Tensor]],
@@ -323,6 +345,7 @@ def __init__(
         register_tbes: bool = False,
         quant_state_dict_split_scale_bias: bool = False,
         row_alignment: int = DEFAULT_ROW_ALIGNMENT,
+        cache_features_order: bool = False,
     ) -> None:
         super().__init__()
         self._is_weighted = is_weighted
@@ -333,6 +356,7 @@ def __init__(
         self._feature_names: List[str] = []
         self._feature_splits: List[int] = []
         self._length_per_key: List[int] = []
+        self._features_order: List[int] = []
         # Registering in a List instead of ModuleList because we want don't want them to be auto-registered.
         # Their states will be modified via self.embedding_bags
         self._emb_modules: List[nn.Module] = []
@@ -458,6 +482,7 @@ def __init__(
         self.register_tbes = register_tbes
         if register_tbes:
             self.tbes: torch.nn.ModuleList = torch.nn.ModuleList(self._emb_modules)
+        setattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, cache_features_order)

     def forward(
         self,
@@ -473,8 +498,26 @@ def forward(
         embeddings = []
         kjt_keys = _get_kjt_keys(features)
-        kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
-        kjt_permute = features.permute(kjt_permute_order)
+        # Cache the features order since the features will always have the same order of keys in inference.
+        if getattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, False):
+            if self._features_order == []:
+                for k in self._feature_names:
+                    self._features_order.append(kjt_keys.index(k))
+                self.register_buffer(
+                    "_features_order_tensor",
+                    torch.tensor(
+                        data=self._features_order,
+                        device=features.device(),
+                        dtype=torch.int32,
+                    ),
+                    persistent=False,
+                )
+            kjt_permute = _permute_kjt(
+                features, self._features_order, self._features_order_tensor
+            )
+        else:
+            kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
+            kjt_permute = _permute_kjt(features, kjt_permute_order)
         kjts_per_key = kjt_permute.split(self._feature_splits)

         for i, (emb_op, _) in enumerate(
@@ -550,6 +593,9 @@ def from_float(
             row_alignment=getattr(
                 module, MODULE_ATTR_ROW_ALIGNMENT_INT, DEFAULT_ROW_ALIGNMENT
             ),
+            cache_features_order=getattr(
+                module, MODULE_ATTR_CACHE_FEATURES_ORDER, False
+            ),
         )

     def embedding_bag_configs(
@@ -584,6 +630,7 @@ def __init__(
         # feature processor is Optional only for the sake of the last position in constructor
         # Enforcing it to be non-None, for None case EmbeddingBagCollection must be used.
         feature_processor: Optional[FeatureProcessorsCollection] = None,
+        cache_features_order: bool = False,
     ) -> None:
         super().__init__(
             tables,
@@ -594,6 +641,7 @@ def __init__(
             register_tbes,
             quant_state_dict_split_scale_bias,
             row_alignment,
+            cache_features_order,
         )
         assert (
             feature_processor is not None
@@ -661,6 +709,9 @@ def from_float(
             ),
             # pyre-ignore
             feature_processor=fp_ebc._feature_processors,
+            cache_features_order=getattr(
+                module, MODULE_ATTR_CACHE_FEATURES_ORDER, False
+            ),
         )


@@ -687,6 +738,7 @@ def __init__(  # noqa C901
         register_tbes: bool = False,
         quant_state_dict_split_scale_bias: bool = False,
         row_alignment: int = DEFAULT_ROW_ALIGNMENT,
+        cache_features_order: bool = False,
     ) -> None:
         super().__init__()
         self._emb_modules: List[IntNBitTableBatchedEmbeddingBagsCodegen] = []
@@ -698,6 +750,7 @@ def __init__(  # noqa C901
         self.row_alignment = row_alignment
         self._key_to_tables: Dict[DataType, List[EmbeddingConfig]] = defaultdict(list)
         self._feature_names: List[str] = []
+        self._features_order: List[int] = []

         self._table_name_to_quantized_weights: Optional[
             Dict[str, Tuple[Tensor, Tensor]]
@@ -808,6 +861,7 @@ def __init__(  # noqa C901
         self.register_tbes = register_tbes
         if register_tbes:
             self.tbes: torch.nn.ModuleList = torch.nn.ModuleList(self._emb_modules)
+        setattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, cache_features_order)

     def forward(
         self,
@@ -823,9 +877,28 @@ def forward(
         feature_embeddings: Dict[str, JaggedTensor] = {}
         kjt_keys = _get_kjt_keys(features)
-        kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
-        kjt_permute = features.permute(kjt_permute_order)
+        # Cache the features order since the features will always have the same order of keys in inference.
+        if getattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, False):
+            if self._features_order == []:
+                for k in self._feature_names:
+                    self._features_order.append(kjt_keys.index(k))
+                self.register_buffer(
+                    "_features_order_tensor",
+                    torch.tensor(
+                        data=self._features_order,
+                        device=features.device(),
+                        dtype=torch.int32,
+                    ),
+                    persistent=False,
+                )
+            kjt_permute = _permute_kjt(
+                features, self._features_order, self._features_order_tensor
+            )
+        else:
+            kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
+            kjt_permute = _permute_kjt(features, kjt_permute_order)
         kjts_per_key = kjt_permute.split(self._feature_splits)
+
         for i, (emb_module, key) in enumerate(
             zip(self._emb_modules, self._key_to_tables.keys())
         ):
@@ -896,6 +969,9 @@ def from_float(
             row_alignment=getattr(
                 module, MODULE_ATTR_ROW_ALIGNMENT_INT, DEFAULT_ROW_ALIGNMENT
             ),
+            cache_features_order=getattr(
+                module, MODULE_ATTR_CACHE_FEATURES_ORDER, False
+            ),
         )

     def _get_name(self) -> str:
@@ -953,6 +1029,7 @@ def __init__(
         row_alignment: int = DEFAULT_ROW_ALIGNMENT,
         managed_collision_collection: Optional[ManagedCollisionCollection] = None,
         return_remapped_features: bool = False,
+        cache_features_order: bool = False,
     ) -> None:
         super().__init__(
             tables,
@@ -963,6 +1040,7 @@ def __init__(
             register_tbes,
             quant_state_dict_split_scale_bias,
             row_alignment,
+            cache_features_order,
         )
         assert (
             managed_collision_collection
@@ -1062,4 +1140,5 @@ def from_float(
             ),
             managed_collision_collection=mc_ec._managed_collision_collection,
             return_remapped_features=mc_ec._return_remapped_features,
+            cache_features_order=getattr(ec, MODULE_ATTR_CACHE_FEATURES_ORDER, False),
         )
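
Usage note (not part of the patch): a minimal sketch of how the new opt-in helper could be exercised end to end. The table/feature names, sizes, and the placeholder-observer qconfig below are illustrative assumptions rather than values taken from the diff.

    import torch
    from torchrec.modules.embedding_configs import EmbeddingBagConfig
    from torchrec.modules.embedding_modules import EmbeddingBagCollection
    from torchrec.quant.embedding_modules import (
        EmbeddingBagCollection as QuantEmbeddingBagCollection,
        quant_prep_enable_cache_features_order,
    )
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    # Hypothetical single-table float module.
    ebc = EmbeddingBagCollection(
        tables=[
            EmbeddingBagConfig(
                name="t1", embedding_dim=16, num_embeddings=100, feature_names=["f1"]
            )
        ]
    )

    # Opt in: sets MODULE_ATTR_CACHE_FEATURES_ORDER on every matching module,
    # which from_float() reads and forwards as cache_features_order=True.
    quant_prep_enable_cache_features_order(ebc, [EmbeddingBagCollection])

    # Standard placeholder-observer qconfig so that from_float() can run.
    ebc.qconfig = torch.quantization.QConfig(
        activation=torch.quantization.PlaceholderObserver.with_args(dtype=torch.float),
        weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8),
    )
    qebc = QuantEmbeddingBagCollection.from_float(ebc)

    features = KeyedJaggedTensor.from_lengths_sync(
        keys=["f1"],
        values=torch.tensor([1, 2, 3]),
        lengths=torch.tensor([2, 0, 1]),
    )
    out = qebc(features)  # first call computes the KJT permute order and caches it
    out = qebc(features)  # later calls reuse _features_order / _features_order_tensor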