Cache the features order in QuantEBC #2762

Open · wants to merge 1 commit into base: main
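Summary: `QuantEmbeddingBagCollection` and `QuantEmbeddingCollection` currently recompute the KJT permute order (`[kjt_keys.index(k) for k in self._feature_names]`) on every forward call. Since the feature key order is fixed at inference time, this PR adds an opt-in `cache_features_order` flag: on the first forward the order is computed once and registered as a non-persistent `_features_order_tensor` buffer, and subsequent calls reuse it. Below is a minimal sketch of the caching pattern; `CachedPermuteModule` is an illustrative stand-in, not the actual TorchRec module.

```python
# Minimal sketch of the feature-order caching pattern introduced in this PR.
# `CachedPermuteModule` is illustrative only; the real logic lives in
# QuantEmbeddingBagCollection / QuantEmbeddingCollection.forward.
from typing import List

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


class CachedPermuteModule(torch.nn.Module):
    def __init__(
        self, feature_names: List[str], cache_features_order: bool = False
    ) -> None:
        super().__init__()
        self._feature_names = feature_names
        self._cache_features_order = cache_features_order
        self._features_order: List[int] = []

    def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor:
        kjt_keys = features.keys()
        if self._cache_features_order:
            if self._features_order == []:
                # First call: compute the permute order once and keep it as a
                # non-persistent buffer so it moves with the module's device
                # and stays out of the state_dict.
                self._features_order = [
                    kjt_keys.index(k) for k in self._feature_names
                ]
                self.register_buffer(
                    "_features_order_tensor",
                    torch.tensor(
                        self._features_order,
                        device=features.device(),
                        dtype=torch.int32,
                    ),
                    persistent=False,
                )
            return features.permute(
                self._features_order, self._features_order_tensor
            )
        # Uncached path (previous behavior): recompute the order every call.
        order = [kjt_keys.index(k) for k in self._feature_names]
        return features.permute(order)
```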
7 changes: 7 additions & 0 deletions torchrec/inference/modules.py
@@ -66,6 +66,7 @@
EmbeddingCollection as QuantEmbeddingCollection,
FeatureProcessedEmbeddingBagCollection as QuantFeatureProcessedEmbeddingBagCollection,
MODULE_ATTR_EMB_CONFIG_NAME_TO_NUM_ROWS_POST_PRUNING_DICT,
quant_prep_enable_cache_features_order,
quant_prep_enable_register_tbes,
)

@@ -428,6 +429,9 @@ def _quantize_fp_module(
"""

quant_prep_enable_register_tbes(model, [FeatureProcessedEmbeddingBagCollection])
quant_prep_enable_cache_features_order(
model, [FeatureProcessedEmbeddingBagCollection]
)
# pyre-fixme[16]: `FeatureProcessedEmbeddingBagCollection` has no attribute
# `qconfig`.
fp_module.qconfig = QuantConfig(
@@ -466,6 +470,9 @@ def _quantize_fp_module(
)

quant_prep_enable_register_tbes(model, list(additional_mapping.keys()))
quant_prep_enable_cache_features_order(
model, [EmbeddingBagCollection, EmbeddingCollection]
)
quantize_embeddings(
model,
dtype=quantization_dtype,
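The inference entry points above opt the float `EmbeddingBagCollection` / `EmbeddingCollection` (and the feature-processed variant) into caching before quantization, so `from_float` can carry the flag over to the quantized modules. A hedged usage sketch follows, assuming a `model` that contains those float submodules and import paths as in this diff; `prepare_for_quantization` is a hypothetical helper, and the actual quantization call is elided.

```python
# Hypothetical usage sketch: enable cached feature ordering (and TBE
# registration) on a float model before running the usual quantization flow.
# Import paths are assumed from this diff.
import torch
from torchrec.modules.embedding_modules import (
    EmbeddingBagCollection,
    EmbeddingCollection,
)
from torchrec.quant.embedding_modules import (
    quant_prep_enable_cache_features_order,
    quant_prep_enable_register_tbes,
)


def prepare_for_quantization(model: torch.nn.Module) -> None:
    # Mirrors what torchrec/inference/modules.py does in this PR: set the
    # opt-in attributes on the float modules so from_float() picks them up.
    quant_prep_enable_register_tbes(
        model, [EmbeddingBagCollection, EmbeddingCollection]
    )
    quant_prep_enable_cache_features_order(
        model, [EmbeddingBagCollection, EmbeddingCollection]
    )
    # ...then call the usual quantize_embeddings(...) flow.
```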
87 changes: 83 additions & 4 deletions torchrec/quant/embedding_modules.py
@@ -73,6 +73,7 @@
from torchrec.types import ModuleNoCopyMixin

torch.fx.wrap("_get_batching_hinted_output")
torch.fx.wrap("len")

try:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -106,6 +107,8 @@

MODULE_ATTR_USE_BATCHING_HINTED_OUTPUT: str = "__use_batching_hinted_output"

MODULE_ATTR_CACHE_FEATURES_ORDER: str = "__cache_features_order"

DEFAULT_ROW_ALIGNMENT = 16


@@ -120,6 +123,15 @@ def _get_kjt_keys(feature: KeyedJaggedTensor) -> List[str]:
return feature.keys()


@torch.fx.wrap
def _permute_kjt(
features: KeyedJaggedTensor,
permute_order: List[int],
permute_order_tensor: Optional[Tensor] = None,
) -> KeyedJaggedTensor:
return features.permute(permute_order, permute_order_tensor)


@torch.fx.wrap
def _cat_embeddings(embeddings: List[Tensor]) -> Tensor:
return embeddings[0] if len(embeddings) == 1 else torch.cat(embeddings, dim=1)
@@ -177,6 +189,16 @@ def quant_prep_customize_row_alignment(
)


def quant_prep_enable_cache_features_order(
module: nn.Module, module_types: List[Type[torch.nn.Module]]
) -> None:
for_each_module_of_type_do(
module,
module_types,
lambda m: setattr(m, MODULE_ATTR_CACHE_FEATURES_ORDER, True),
)


def quantize_state_dict(
module: nn.Module,
table_name_to_quantized_weights: Dict[str, Tuple[Tensor, Tensor]],
@@ -323,6 +345,7 @@ def __init__(
register_tbes: bool = False,
quant_state_dict_split_scale_bias: bool = False,
row_alignment: int = DEFAULT_ROW_ALIGNMENT,
cache_features_order: bool = False,
) -> None:
super().__init__()
self._is_weighted = is_weighted
@@ -333,6 +356,7 @@ def __init__(
self._feature_names: List[str] = []
self._feature_splits: List[int] = []
self._length_per_key: List[int] = []
self._features_order: List[int] = []
# Registering in a List instead of ModuleList because we don't want them to be auto-registered.
# Their states will be modified via self.embedding_bags
self._emb_modules: List[nn.Module] = []
@@ -458,6 +482,7 @@ def __init__(
self.register_tbes = register_tbes
if register_tbes:
self.tbes: torch.nn.ModuleList = torch.nn.ModuleList(self._emb_modules)
setattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, cache_features_order)

def forward(
self,
@@ -473,8 +498,26 @@ def forward(

embeddings = []
kjt_keys = _get_kjt_keys(features)
kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
kjt_permute = features.permute(kjt_permute_order)
# Cache the features order since the features will always have the same order of keys in inference.
if getattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, False):
if self._features_order == []:
for k in self._feature_names:
self._features_order.append(kjt_keys.index(k))
self.register_buffer(
"_features_order_tensor",
torch.tensor(
data=self._features_order,
device=features.device(),
dtype=torch.int32,
),
persistent=False,
)
kjt_permute = _permute_kjt(
features, self._features_order, self._features_order_tensor
)
else:
kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
kjt_permute = _permute_kjt(features, kjt_permute_order)
kjts_per_key = kjt_permute.split(self._feature_splits)

for i, (emb_op, _) in enumerate(
@@ -550,6 +593,9 @@ def from_float(
row_alignment=getattr(
module, MODULE_ATTR_ROW_ALIGNMENT_INT, DEFAULT_ROW_ALIGNMENT
),
cache_features_order=getattr(
module, MODULE_ATTR_CACHE_FEATURES_ORDER, False
),
)

def embedding_bag_configs(
@@ -584,6 +630,7 @@ def __init__(
# feature processor is Optional only for the sake of the last position in constructor
# Enforcing it to be non-None, for None case EmbeddingBagCollection must be used.
feature_processor: Optional[FeatureProcessorsCollection] = None,
cache_features_order: bool = False,
) -> None:
super().__init__(
tables,
@@ -594,6 +641,7 @@ def __init__(
register_tbes,
quant_state_dict_split_scale_bias,
row_alignment,
cache_features_order,
)
assert (
feature_processor is not None
@@ -661,6 +709,9 @@ def from_float(
),
# pyre-ignore
feature_processor=fp_ebc._feature_processors,
cache_features_order=getattr(
module, MODULE_ATTR_CACHE_FEATURES_ORDER, False
),
)


@@ -687,6 +738,7 @@ def __init__( # noqa C901
register_tbes: bool = False,
quant_state_dict_split_scale_bias: bool = False,
row_alignment: int = DEFAULT_ROW_ALIGNMENT,
cache_features_order: bool = False,
) -> None:
super().__init__()
self._emb_modules: List[IntNBitTableBatchedEmbeddingBagsCodegen] = []
@@ -698,6 +750,7 @@ def __init__( # noqa C901
self.row_alignment = row_alignment
self._key_to_tables: Dict[DataType, List[EmbeddingConfig]] = defaultdict(list)
self._feature_names: List[str] = []
self._features_order: List[int] = []

self._table_name_to_quantized_weights: Optional[
Dict[str, Tuple[Tensor, Tensor]]
@@ -808,6 +861,7 @@ def __init__( # noqa C901
self.register_tbes = register_tbes
if register_tbes:
self.tbes: torch.nn.ModuleList = torch.nn.ModuleList(self._emb_modules)
setattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, cache_features_order)

def forward(
self,
@@ -823,9 +877,28 @@ def forward(

feature_embeddings: Dict[str, JaggedTensor] = {}
kjt_keys = _get_kjt_keys(features)
kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
kjt_permute = features.permute(kjt_permute_order)
# Cache the features order since the features will always have the same order of keys in inference.
if getattr(self, MODULE_ATTR_CACHE_FEATURES_ORDER, False):
if self._features_order == []:
for k in self._feature_names:
self._features_order.append(kjt_keys.index(k))
self.register_buffer(
"_features_order_tensor",
torch.tensor(
data=self._features_order,
device=features.device(),
dtype=torch.int32,
),
persistent=False,
)
kjt_permute = _permute_kjt(
features, self._features_order, self._features_order_tensor
)
else:
kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
kjt_permute = _permute_kjt(features, kjt_permute_order)
kjts_per_key = kjt_permute.split(self._feature_splits)

for i, (emb_module, key) in enumerate(
zip(self._emb_modules, self._key_to_tables.keys())
):
@@ -896,6 +969,9 @@ def from_float(
row_alignment=getattr(
module, MODULE_ATTR_ROW_ALIGNMENT_INT, DEFAULT_ROW_ALIGNMENT
),
cache_features_order=getattr(
module, MODULE_ATTR_CACHE_FEATURES_ORDER, False
),
)

def _get_name(self) -> str:
@@ -953,6 +1029,7 @@ def __init__(
row_alignment: int = DEFAULT_ROW_ALIGNMENT,
managed_collision_collection: Optional[ManagedCollisionCollection] = None,
return_remapped_features: bool = False,
cache_features_order: bool = False,
) -> None:
super().__init__(
tables,
@@ -963,6 +1040,7 @@ def __init__(
register_tbes,
quant_state_dict_split_scale_bias,
row_alignment,
cache_features_order,
)
assert (
managed_collision_collection
@@ -1062,4 +1140,5 @@ def from_float(
),
managed_collision_collection=mc_ec._managed_collision_collection,
return_remapped_features=mc_ec._return_remapped_features,
cache_features_order=getattr(ec, MODULE_ATTR_CACHE_FEATURES_ORDER, False),
)