Cache the features order in QuantEBC
Summary: The input KJT (features) will always have the same order of keys in inference, so we can get rid of the indexing operations in subsequent calls of forward() by caching the permute order. A minimal sketch of the idea follows the change stats below.

Differential Revision: D68991644
Joshua Su authored and facebook-github-bot committed Feb 24, 2025
1 parent 7500a0f commit a486050
Showing 2 changed files with 90 additions and 4 deletions.
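To make the change concrete, here is a minimal standalone sketch of the pattern the diff introduces in forward(): the first call computes the key permutation with index() lookups and caches it, and every later call reuses the cached order. This is plain illustrative Python (the class and names are assumptions for the example), not the TorchRec implementation.

```python
from typing import List, Optional


class PermuteOrderCache:
    """Minimal sketch of the caching idea: compute the key permutation once,
    then reuse it on every later call (valid only when the key order is stable)."""

    def __init__(self, feature_names: List[str]) -> None:
        self._feature_names = feature_names
        self._features_order: Optional[List[int]] = None

    def order(self, kjt_keys: List[str]) -> List[int]:
        if self._features_order is None:
            # First call: pay for the index() lookups once and cache the result.
            self._features_order = [kjt_keys.index(k) for k in self._feature_names]
        # Later calls: return the cached order.
        return self._features_order


cache = PermuteOrderCache(feature_names=["f2", "f0", "f1"])
print(cache.order(["f0", "f1", "f2"]))  # [2, 0, 1], computed and cached
print(cache.order(["f0", "f1", "f2"]))  # [2, 0, 1], served from the cache
```

The cache is only valid because, in inference, the input KJT always presents its keys in the same order; that is why the behavior is gated behind an opt-in flag rather than applied unconditionally.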
7 changes: 7 additions & 0 deletions torchrec/inference/modules.py
@@ -66,6 +66,7 @@
EmbeddingCollection as QuantEmbeddingCollection,
FeatureProcessedEmbeddingBagCollection as QuantFeatureProcessedEmbeddingBagCollection,
MODULE_ATTR_EMB_CONFIG_NAME_TO_NUM_ROWS_POST_PRUNING_DICT,
quant_prep_enable_cache_features_order,
quant_prep_enable_register_tbes,
)

@@ -428,6 +429,9 @@ def _quantize_fp_module(
"""

quant_prep_enable_register_tbes(model, [FeatureProcessedEmbeddingBagCollection])
quant_prep_enable_cache_features_order(
model, [FeatureProcessedEmbeddingBagCollection]
)
# pyre-fixme[16]: `FeatureProcessedEmbeddingBagCollection` has no attribute
# `qconfig`.
fp_module.qconfig = QuantConfig(
@@ -466,6 +470,9 @@ def _quantize_fp_module(
)

quant_prep_enable_register_tbes(model, list(additional_mapping.keys()))
quant_prep_enable_cache_features_order(
model, [EmbeddingBagCollection, EmbeddingCollection]
)
quantize_embeddings(
model,
dtype=quantization_dtype,
87 changes: 83 additions & 4 deletions torchrec/quant/embedding_modules.py
@@ -73,6 +73,7 @@
from torchrec.types import ModuleNoCopyMixin

torch.fx.wrap("_get_batching_hinted_output")
torch.fx.wrap("len")

try:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -106,6 +107,8 @@

MODULE_ATTR_USE_BATCHING_HINTED_OUTPUT: str = "__use_batching_hinted_output"

MODULE_ATTR_CACHE_FEATURES_ORDER: str = "__cache_features_order"

DEFAULT_ROW_ALIGNMENT = 16


@@ -120,6 +123,15 @@ def _get_kjt_keys(feature: KeyedJaggedTensor) -> List[str]:
return feature.keys()


@torch.fx.wrap
def _permute_kjt(
features: KeyedJaggedTensor,
permute_order: List[int],
permute_order_tensor: Optional[Tensor] = None,
) -> KeyedJaggedTensor:
return features.permute(permute_order, permute_order_tensor)


@torch.fx.wrap
def _cat_embeddings(embeddings: List[Tensor]) -> Tensor:
return embeddings[0] if len(embeddings) == 1 else torch.cat(embeddings, dim=1)
@@ -177,6 +189,16 @@ def quant_prep_customize_row_alignment(
)


def quant_prep_enable_cache_features_order(
module: nn.Module, module_types: List[Type[torch.nn.Module]]
) -> None:
for_each_module_of_type_do(
module,
module_types,
lambda m: setattr(m, MODULE_ATTR_CACHE_FEATURES_ORDER, True),
)


def quantize_state_dict(
module: nn.Module,
table_name_to_quantized_weights: Dict[str, Tuple[Tensor, Tensor]],
@@ -323,6 +345,7 @@ def __init__(
register_tbes: bool = False,
quant_state_dict_split_scale_bias: bool = False,
row_alignment: int = DEFAULT_ROW_ALIGNMENT,
cache_features_order: bool = False,
) -> None:
super().__init__()
self._is_weighted = is_weighted
@@ -333,6 +356,7 @@ def __init__(
self._feature_names: List[str] = []
self._feature_splits: List[int] = []
self._length_per_key: List[int] = []
self._features_order: List[int] = []
# Registering in a List instead of ModuleList because we don't want them to be auto-registered.
# Their states will be modified via self.embedding_bags
self._emb_modules: List[nn.Module] = []
@@ -458,6 +482,7 @@ def __init__(
self.register_tbes = register_tbes
if register_tbes:
self.tbes: torch.nn.ModuleList = torch.nn.ModuleList(self._emb_modules)
setattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, cache_features_order)

def forward(
self,
@@ -473,8 +498,26 @@ def forward(

embeddings = []
kjt_keys = _get_kjt_keys(features)
kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
kjt_permute = features.permute(kjt_permute_order)
# Cache the features order since the features will always have the same order of keys in inference.
if getattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, False):
if self._features_order == []:
for k in self._feature_names:
self._features_order.append(kjt_keys.index(k))
self.register_buffer(
"_features_order_tensor",
torch.tensor(
data=self._features_order,
device=features.device(),
dtype=torch.int32,
),
persistent=False,
)
kjt_permute = _permute_kjt(
features, self._features_order, self._features_order_tensor
)
else:
kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
kjt_permute = _permute_kjt(features, kjt_permute_order)
kjts_per_key = kjt_permute.split(self._feature_splits)

for i, (emb_op, _) in enumerate(
@@ -550,6 +593,9 @@ def from_float(
row_alignment=getattr(
module, MODULE_ATTR_ROW_ALIGNMENT_INT, DEFAULT_ROW_ALIGNMENT
),
cache_features_order=getattr(
module, MODULE_ATTR_CACHE_FEATURES_ORDER, False
),
)

def embedding_bag_configs(
@@ -584,6 +630,7 @@ def __init__(
# feature processor is Optional only for the sake of the last position in constructor
# Enforcing it to be non-None, for None case EmbeddingBagCollection must be used.
feature_processor: Optional[FeatureProcessorsCollection] = None,
cache_features_order: bool = False,
) -> None:
super().__init__(
tables,
@@ -594,6 +641,7 @@ def __init__(
register_tbes,
quant_state_dict_split_scale_bias,
row_alignment,
cache_features_order,
)
assert (
feature_processor is not None
@@ -661,6 +709,9 @@ def from_float(
),
# pyre-ignore
feature_processor=fp_ebc._feature_processors,
cache_features_order=getattr(
module, MODULE_ATTR_CACHE_FEATURES_ORDER, False
),
)


@@ -687,6 +738,7 @@ def __init__( # noqa C901
register_tbes: bool = False,
quant_state_dict_split_scale_bias: bool = False,
row_alignment: int = DEFAULT_ROW_ALIGNMENT,
cache_features_order: bool = False,
) -> None:
super().__init__()
self._emb_modules: List[IntNBitTableBatchedEmbeddingBagsCodegen] = []
@@ -698,6 +750,7 @@ def __init__( # noqa C901
self.row_alignment = row_alignment
self._key_to_tables: Dict[DataType, List[EmbeddingConfig]] = defaultdict(list)
self._feature_names: List[str] = []
self._features_order: List[int] = []

self._table_name_to_quantized_weights: Optional[
Dict[str, Tuple[Tensor, Tensor]]
@@ -808,6 +861,7 @@ def __init__( # noqa C901
self.register_tbes = register_tbes
if register_tbes:
self.tbes: torch.nn.ModuleList = torch.nn.ModuleList(self._emb_modules)
setattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, cache_features_order)

def forward(
self,
@@ -823,9 +877,28 @@ def forward(

feature_embeddings: Dict[str, JaggedTensor] = {}
kjt_keys = _get_kjt_keys(features)
kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
kjt_permute = features.permute(kjt_permute_order)
# Cache the features order since the features will always have the same order of keys in inference.
if getattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, False):
if self._features_order == []:
for k in self._feature_names:
self._features_order.append(kjt_keys.index(k))
self.register_buffer(
"_features_order_tensor",
torch.tensor(
data=self._features_order,
device=features.device(),
dtype=torch.int32,
),
persistent=False,
)
kjt_permute = _permute_kjt(
features, self._features_order, self._features_order_tensor
)
else:
kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
kjt_permute = _permute_kjt(features, kjt_permute_order)
kjts_per_key = kjt_permute.split(self._feature_splits)

for i, (emb_module, key) in enumerate(
zip(self._emb_modules, self._key_to_tables.keys())
):
@@ -896,6 +969,9 @@ def from_float(
row_alignment=getattr(
module, MODULE_ATTR_ROW_ALIGNMENT_INT, DEFAULT_ROW_ALIGNMENT
),
cache_features_order=getattr(
module, MODULE_ATTR_CACHE_FEATURES_ORDER, False
),
)

def _get_name(self) -> str:
@@ -953,6 +1029,7 @@ def __init__(
row_alignment: int = DEFAULT_ROW_ALIGNMENT,
managed_collision_collection: Optional[ManagedCollisionCollection] = None,
return_remapped_features: bool = False,
cache_features_order: bool = False,
) -> None:
super().__init__(
tables,
@@ -963,6 +1040,7 @@ def __init__(
register_tbes,
quant_state_dict_split_scale_bias,
row_alignment,
cache_features_order,
)
assert (
managed_collision_collection
@@ -1062,4 +1140,5 @@ def from_float(
),
managed_collision_collection=mc_ec._managed_collision_collection,
return_remapped_features=mc_ec._return_remapped_features,
cache_features_order=getattr(ec, MODULE_ATTR_CACHE_FEATURES_ORDER, False),
)
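
As a usage note, enabling the new flag on a float module before conversion would look roughly like the sketch below. This assumes the standard TorchRec flow of setting `qconfig` on a float `EmbeddingBagCollection` and converting it with `from_float()`; the wrapper function name and the externally constructed `ebc` are illustrative assumptions, not part of this commit.

```python
import torch

from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.quant.embedding_modules import (
    EmbeddingBagCollection as QuantEmbeddingBagCollection,
    quant_prep_enable_cache_features_order,
)


def quantize_with_cached_features_order(
    ebc: EmbeddingBagCollection,
) -> QuantEmbeddingBagCollection:
    # Mark the float module so from_float() constructs the quantized module
    # with cache_features_order=True (read via MODULE_ATTR_CACHE_FEATURES_ORDER).
    quant_prep_enable_cache_features_order(ebc, [EmbeddingBagCollection])
    ebc.qconfig = torch.quantization.QConfig(
        activation=torch.quantization.PlaceholderObserver.with_args(dtype=torch.float),
        weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8),
    )
    return QuantEmbeddingBagCollection.from_float(ebc)
```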
