Replace to_dict with permute in QEBC (#1876)
Summary:
Pull Request resolved: #1876

Replace to_dict + concat with permute + split in QEBC

Reviewed By: IvanKobzarev, seanx92

Differential Revision: D56069966

fbshipit-source-id: 6df57958e21f6fe12b2f3fe64a03ad599af94977
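
A minimal sketch of the idea as a standalone function (illustrative only: the real change lives inside QEBC's forward, and `feature_names`/`feature_splits` stand in for the per-group metadata the module now precomputes in `__init__`):

```python
from typing import List

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def group_by_permute_split(
    features: KeyedJaggedTensor,
    feature_names: List[str],
    feature_splits: List[int],
) -> List[KeyedJaggedTensor]:
    # One permute into grouped key order followed by one split,
    # instead of to_dict() per key plus torch.cat per group.
    order = [features.keys().index(k) for k in feature_names]
    return features.permute(order).split(feature_splits)
```

This avoids materializing a per-feature JaggedTensor dict and the per-group torch.cat calls that the old path needed.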
gnahzg authored and facebook-github-bot committed Apr 22, 2024
1 parent 205f5ba commit 7d6f3f4
Showing 2 changed files with 28 additions and 31 deletions.
51 changes: 24 additions & 27 deletions torchrec/quant/embedding_modules.py
@@ -93,6 +93,12 @@ def _get_feature_length(feature: KeyedJaggedTensor) -> Tensor:
     return feature.lengths()
 
 
+@torch.fx.wrap
+def _get_kjt_keys(feature: KeyedJaggedTensor) -> List[str]:
+    # This is an FX rule to help with batch hinting for jagged sequence tensor coalescing.
+    return feature.keys()
+
+
 def for_each_module_of_type_do(
     module: nn.Module,
     module_types: List[Type[torch.nn.Module]],
@@ -320,6 +326,8 @@ def __init__(
         self._key_to_tables: Dict[
             Tuple[PoolingType, DataType], List[EmbeddingBagConfig]
         ] = defaultdict(list)
+        self._feature_names: List[str] = []
+        self._feature_splits: List[int] = []
         self._length_per_key: List[int] = []
         # Registering in a List instead of ModuleList because we don't want them to be auto-registered.
         # Their states will be modified via self.embedding_bags
@@ -389,6 +397,11 @@ def __init__(
             if weight_lists is None:
                 emb_module.initialize_weights()
             self._emb_modules.append(emb_module)
+            for table in emb_configs:
+                self._feature_names.extend(table.feature_names)
+            self._feature_splits.append(
+                sum(table.num_features() for table in emb_configs)
+            )
 
         ordered_tables = list(itertools.chain(*self._key_to_tables.values()))
         self._embedding_names: List[str] = list(
@@ -462,47 +475,31 @@ def forward(
             KeyedTensor
         """
 
-        feature_dict = self._kjt_to_jt_dict(features)
         embeddings = []
+        kjt_keys = _get_kjt_keys(features)
+        kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
+        kjt_permute = features.permute(kjt_permute_order)
+        kjts_per_key = kjt_permute.split(self._feature_splits)
 
-        # TODO: ideally we can accept KJTs with any feature order. However, this would require an order check + permute, which would break TorchScript.
-        # Once TorchScript is no longer a requirement, we should revisit this.
-
-        for emb_op, (_key, tables) in zip(
-            self._emb_modules, self._key_to_tables.items()
+        for i, (emb_op, _) in enumerate(
+            zip(self._emb_modules, self._key_to_tables.keys())
         ):
-            indices = []
-            lengths = []
-            offsets = []
-            weights = []
-
-            for table in tables:
-                for feature in table.feature_names:
-                    f = feature_dict[feature]
-                    indices.append(f.values())
-                    lengths.append(f.lengths())
-                    if self._is_weighted:
-                        weights.append(f.weights())
-
-            indices = torch.cat(indices)
-            lengths = torch.cat(lengths)
-
-            offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
-            if self._is_weighted:
-                weights = torch.cat(weights)
+            f = kjts_per_key[i]
+            indices = f.values()
+            offsets = f.offsets()
 
             embeddings.append(
                 # Syntax for FX to generate call_module instead of call_function, keeping the TBE copied unchanged into the fx.GraphModule; possible only for registered modules
                 emb_op(
                     indices=indices,
                     offsets=offsets,
-                    per_sample_weights=weights if self._is_weighted else None,
+                    per_sample_weights=f.weights() if self._is_weighted else None,
                 )
                 if self.register_tbes
                 else emb_op.forward(
                     indices=indices,
                     offsets=offsets,
-                    per_sample_weights=weights if self._is_weighted else None,
+                    per_sample_weights=f.weights() if self._is_weighted else None,
                )
            )
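
To make the equivalence concrete, a small self-contained check (the keys, values, and grouping below are made up for illustration and are not from this PR):

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Three features, two samples each; grouped as [f3] and [f1, f2].
kjt = KeyedJaggedTensor(
    keys=["f1", "f2", "f3"],
    values=torch.tensor([1, 2, 3, 4, 5, 6]),
    lengths=torch.tensor([1, 1, 1, 1, 1, 1]),
)
feature_names, feature_splits = ["f3", "f1", "f2"], [1, 2]

# Old path: per-key dict, then concatenate values within each group.
jt = kjt.to_dict()
old_group0 = torch.cat([jt["f3"].values()])
old_group1 = torch.cat([jt["f1"].values(), jt["f2"].values()])

# New path: one permute into grouped order, then one split.
order = [kjt.keys().index(k) for k in feature_names]
groups = kjt.permute(order).split(feature_splits)

assert torch.equal(groups[0].values(), old_group0)
assert torch.equal(groups[1].values(), old_group1)
```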
8 changes: 4 additions & 4 deletions torchrec/quant/tests/test_embedding_modules.py
@@ -459,13 +459,13 @@ def test_trace_and_script(self) -> None:
         )
         self.assertEqual(
             non_placeholder_nodes[0].op,
-            "call_module",
-            f"First non-placeholder node must be call_module, got {non_placeholder_nodes[0].op} instead",
+            "call_function",
+            f"First non-placeholder node must be call_function, got {non_placeholder_nodes[0].op} instead",
         )
         self.assertEqual(
             non_placeholder_nodes[0].name,
-            "_kjt_to_jt_dict",
-            f"First non-placeholder node must be _kjt_to_jt_dict, got {non_placeholder_nodes[0].name} instead",
+            "_get_kjt_keys",
+            f"First non-placeholder node must be '_get_kjt_keys', got {non_placeholder_nodes[0].name} instead",
         )
 
         features = KeyedJaggedTensor(
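
The updated assertions follow from how torch.fx treats wrapped functions: a function registered via torch.fx.wrap is not traced through, so it appears in the graph as a call_function leaf. A standalone sketch (the helper name here is illustrative, not this PR's code):

```python
import torch
import torch.fx


@torch.fx.wrap
def _get_size(x: torch.Tensor) -> int:
    # FX records this call as a graph node instead of tracing into it.
    return x.shape[0]


class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * _get_size(x)


gm = torch.fx.symbolic_trace(Toy())
non_placeholder = [n for n in gm.graph.nodes if n.op != "placeholder"]
# As in the test above: the first real node is a call_function named
# after the wrapped helper, not a call_module.
assert non_placeholder[0].op == "call_function"
assert non_placeholder[0].name == "_get_size"
```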
