Fix embedding order in EmbeddingBagCollection (#1646)
Summary:
Pull Request resolved: #1646

1. The first bug was introduced in D37352626: tables are now grouped by (pooling, data_type), but _embedding_names and _length_per_key were still built in the original table order.
2. The second bug was also introduced in D37352626: embeddings of shape (batch, dim) were stacked along the batch dimension, but they should be concatenated along the dim dimension (see the sketch after this summary).

The unit tests didn't capture any of these bugs, because:
1. The numerical tolerance was too loose (1.0).  This diff tightens it to 0.1.
2. The test inputs were not sufficient to trigger the bugs, e.g. they used a batch size of one, a single grouping key, and a single embedding dim.
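
A minimal, self-contained sketch of the two failure modes (not part of the original change; the table names, dims, and batch size below are made up for illustration):

```python
import torch
from collections import defaultdict

# Bug 1 (ordering): grouping tables by (pooling, data_type) changes the order in
# which lookups are produced, so names/lengths built from the original config
# order no longer line up with the outputs.
tables = [("t1", "SUM"), ("t2", "MEAN"), ("t3", "SUM")]  # hypothetical (name, pooling) pairs
key_to_tables = defaultdict(list)
for name, pooling in tables:
    key_to_tables[(pooling, "INT8")].append(name)
grouped_order = [t for group in key_to_tables.values() for t in group]
assert grouped_order == ["t1", "t3", "t2"]  # differs from the config order t1, t2, t3

# Bug 2 (stack vs cat): each grouped lookup yields a (batch, group_dim) tensor.
batch, dim = 2, 4
emb_a = torch.arange(batch * dim).float().reshape(batch, dim)        # group A output
emb_b = 100 + torch.arange(batch * dim).float().reshape(batch, dim)  # group B output

buggy = torch.stack([emb_a, emb_b]).reshape(-1, 2 * dim)  # rows mix batch elements across groups
fixed = torch.cat([emb_a, emb_b], dim=1)                  # shape (batch, dim_a + dim_b)

print(buggy[0])  # [0., 1., 2., 3., 4., 5., 6., 7.]         -- batch rows 0 and 1 of group A
print(fixed[0])  # [0., 1., 2., 3., 100., 101., 102., 103.] -- row 0 of A, then row 0 of B
```

With a single grouping key and a batch of one, the buggy and fixed paths produce identical values, which is why the old, loose-tolerance tests passed.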

Reviewed By: YazhiGao

Differential Revision: D52894935

fbshipit-source-id: e4c01013c9127cc1d2ad64203e2ba9fe6fdc7f35
yingufan authored and facebook-github-bot committed Jan 23, 2024
1 parent be17fd7 commit b34db7d
Showing 2 changed files with 28 additions and 24 deletions.
17 changes: 8 additions & 9 deletions torchrec/quant/embedding_modules.py
@@ -323,14 +323,9 @@ def __init__(
             if table.name in table_names:
                 raise ValueError(f"Duplicate table name {table.name}")
             table_names.add(table.name)
-            self._length_per_key.extend(
-                [table.embedding_dim] * len(table.feature_names)
-            )
             key = (table.pooling, table.data_type)
             self._key_to_tables[key].append(table)

-        self._sum_length_per_key: int = sum(self._length_per_key)
-
         location = (
             EmbeddingLocation.HOST if device.type == "cpu" else EmbeddingLocation.DEVICE
         )
@@ -381,9 +376,15 @@ def __init__(
             emb_module.initialize_weights()
             self._emb_modules.append(emb_module)

+        ordered_tables = list(itertools.chain(*self._key_to_tables.values()))
         self._embedding_names: List[str] = list(
-            itertools.chain(*get_embedding_names_by_table(self._embedding_bag_configs))
+            itertools.chain(*get_embedding_names_by_table(ordered_tables))
         )
+        for table in ordered_tables:
+            self._length_per_key.extend(
+                [table.embedding_dim] * len(table.feature_names)
+            )
+
         # We map over the parameters from FBGEMM backed kernels to the canonical nn.EmbeddingBag
         # representation. This provides consistency between this class and the EmbeddingBagCollection
         # nn.Module API calls (state_dict, named_modules, etc)
@@ -491,11 +492,9 @@ def forward(
                 )
             )

-        embeddings = torch.stack(embeddings).reshape(-1, self._sum_length_per_key)
-
         return KeyedTensor(
             keys=self._embedding_names,
-            values=embeddings,
+            values=torch.cat(embeddings, dim=1),
             length_per_key=self._length_per_key,
         )

35 changes: 20 additions & 15 deletions torchrec/quant/tests/test_embedding_modules.py
@@ -6,6 +6,7 @@
 # LICENSE file in the root directory of this source tree.

 import unittest
+from dataclasses import replace
 from typing import Dict, List, Optional, Type

 import hypothesis.strategies as st
@@ -16,6 +17,7 @@
     DataType,
     EmbeddingBagConfig,
     EmbeddingConfig,
+    PoolingType,
     QuantConfig,
 )
 from torchrec.modules.embedding_modules import (
Expand All @@ -37,8 +39,9 @@ def _asserting_same_embeddings(
pooled_embeddings_2: KeyedTensor,
atol: float = 1e-08,
) -> None:

self.assertEqual(pooled_embeddings_1.keys(), pooled_embeddings_2.keys())
self.assertEqual(
set(pooled_embeddings_1.keys()), set(pooled_embeddings_2.keys())
)
for key in pooled_embeddings_1.keys():
self.assertEqual(
pooled_embeddings_1[key].shape, pooled_embeddings_2[key].shape
@@ -92,7 +95,7 @@ def _test_ebc(

         self.assertEqual(quantized_embeddings.values().dtype, output_type)

-        self._asserting_same_embeddings(embeddings, quantized_embeddings, atol=1.0)
+        self._asserting_same_embeddings(embeddings, quantized_embeddings, atol=0.1)

         # test state dict
         state_dict = ebc.state_dict()
@@ -147,36 +150,38 @@ def test_ebc(
             feature_names=["f1"],
             data_type=data_type,
         )
-        eb2_config = EmbeddingBagConfig(
-            name="t2",
-            embedding_dim=16,
-            num_embeddings=10,
-            feature_names=["f2"],
-            data_type=data_type,
+        eb1_mean_config = replace(
+            eb1_config,
+            name="t1_mean",
+            pooling=PoolingType.MEAN,
+            embedding_dim=32,
         )
+        eb2_config = replace(eb1_config, name="t2", feature_names=["f2"])
         features = (
             KeyedJaggedTensor(
                 keys=["f1", "f2"],
-                values=torch.as_tensor([0, 1]),
-                lengths=torch.as_tensor([1, 1]),
+                values=torch.as_tensor([0, 2, 1, 3]),
+                lengths=torch.as_tensor([1, 1, 2, 0]),
             )
             if not permute_order
             else KeyedJaggedTensor(
                 keys=["f2", "f1"],
-                values=torch.as_tensor([1, 0]),
-                lengths=torch.as_tensor([1, 1]),
+                values=torch.as_tensor([1, 3, 0, 2]),
+                lengths=torch.as_tensor([2, 0, 1, 1]),
             )
         )
+        # The key for grouping tables is (pooling, data_type). Test having a different
+        # key value in the middle.
         self._test_ebc(
-            [eb1_config, eb2_config],
+            [eb1_config, eb1_mean_config, eb2_config],
             features,
             quant_type,
             output_type,
             quant_state_dict_split_scale_bias,
         )

         self._test_ebc(
-            [eb1_config, eb2_config],
+            [eb1_config, eb1_mean_config, eb2_config],
             features,
             quant_type,
             output_type,
