diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py index 38b61c668..425603ae1 100644 --- a/torchrec/quant/embedding_modules.py +++ b/torchrec/quant/embedding_modules.py @@ -323,14 +323,9 @@ def __init__( if table.name in table_names: raise ValueError(f"Duplicate table name {table.name}") table_names.add(table.name) - self._length_per_key.extend( - [table.embedding_dim] * len(table.feature_names) - ) key = (table.pooling, table.data_type) self._key_to_tables[key].append(table) - self._sum_length_per_key: int = sum(self._length_per_key) - location = ( EmbeddingLocation.HOST if device.type == "cpu" else EmbeddingLocation.DEVICE ) @@ -381,9 +376,15 @@ def __init__( emb_module.initialize_weights() self._emb_modules.append(emb_module) + ordered_tables = list(itertools.chain(*self._key_to_tables.values())) self._embedding_names: List[str] = list( - itertools.chain(*get_embedding_names_by_table(self._embedding_bag_configs)) + itertools.chain(*get_embedding_names_by_table(ordered_tables)) ) + for table in ordered_tables: + self._length_per_key.extend( + [table.embedding_dim] * len(table.feature_names) + ) + # We map over the parameters from FBGEMM backed kernels to the canonical nn.EmbeddingBag # representation. This provides consistency between this class and the EmbeddingBagCollection # nn.Module API calls (state_dict, named_modules, etc) @@ -491,11 +492,9 @@ def forward( ) ) - embeddings = torch.stack(embeddings).reshape(-1, self._sum_length_per_key) - return KeyedTensor( keys=self._embedding_names, - values=embeddings, + values=torch.cat(embeddings, dim=1), length_per_key=self._length_per_key, ) diff --git a/torchrec/quant/tests/test_embedding_modules.py b/torchrec/quant/tests/test_embedding_modules.py index dad705b68..7f93f62a1 100644 --- a/torchrec/quant/tests/test_embedding_modules.py +++ b/torchrec/quant/tests/test_embedding_modules.py @@ -6,6 +6,7 @@ # LICENSE file in the root directory of this source tree. import unittest +from dataclasses import replace from typing import Dict, List, Optional, Type import hypothesis.strategies as st @@ -16,6 +17,7 @@ DataType, EmbeddingBagConfig, EmbeddingConfig, + PoolingType, QuantConfig, ) from torchrec.modules.embedding_modules import ( @@ -37,8 +39,9 @@ def _asserting_same_embeddings( pooled_embeddings_2: KeyedTensor, atol: float = 1e-08, ) -> None: - - self.assertEqual(pooled_embeddings_1.keys(), pooled_embeddings_2.keys()) + self.assertEqual( + set(pooled_embeddings_1.keys()), set(pooled_embeddings_2.keys()) + ) for key in pooled_embeddings_1.keys(): self.assertEqual( pooled_embeddings_1[key].shape, pooled_embeddings_2[key].shape @@ -92,7 +95,7 @@ def _test_ebc( self.assertEqual(quantized_embeddings.values().dtype, output_type) - self._asserting_same_embeddings(embeddings, quantized_embeddings, atol=1.0) + self._asserting_same_embeddings(embeddings, quantized_embeddings, atol=0.1) # test state dict state_dict = ebc.state_dict() @@ -147,28 +150,30 @@ def test_ebc( feature_names=["f1"], data_type=data_type, ) - eb2_config = EmbeddingBagConfig( - name="t2", - embedding_dim=16, - num_embeddings=10, - feature_names=["f2"], - data_type=data_type, + eb1_mean_config = replace( + eb1_config, + name="t1_mean", + pooling=PoolingType.MEAN, + embedding_dim=32, ) + eb2_config = replace(eb1_config, name="t2", feature_names=["f2"]) features = ( KeyedJaggedTensor( keys=["f1", "f2"], - values=torch.as_tensor([0, 1]), - lengths=torch.as_tensor([1, 1]), + values=torch.as_tensor([0, 2, 1, 3]), + lengths=torch.as_tensor([1, 1, 2, 0]), ) if not permute_order else KeyedJaggedTensor( keys=["f2", "f1"], - values=torch.as_tensor([1, 0]), - lengths=torch.as_tensor([1, 1]), + values=torch.as_tensor([1, 3, 0, 2]), + lengths=torch.as_tensor([2, 0, 1, 1]), ) ) + # The key for grouping tables is (pooling, data_type). Test having a different + # key value in the middle. self._test_ebc( - [eb1_config, eb2_config], + [eb1_config, eb1_mean_config, eb2_config], features, quant_type, output_type, @@ -176,7 +181,7 @@ def test_ebc( ) self._test_ebc( - [eb1_config, eb2_config], + [eb1_config, eb1_mean_config, eb2_config], features, quant_type, output_type,