diff --git a/torchrec/ir/schema.py b/torchrec/ir/schema.py
index f0fbf706e..e24f9a89f 100644
--- a/torchrec/ir/schema.py
+++ b/torchrec/ir/schema.py
@@ -32,3 +32,12 @@ class EBCMetadata:
     tables: List[EmbeddingBagConfigMetadata]
     is_weighted: bool
     device: Optional[str]
+
+
+@dataclass
+class FPEBCMetadata:
+    tables: List[EmbeddingBagConfigMetadata]
+    is_weighted: bool
+    device: Optional[str]
+    fp_type: str
+    fp_json: Optional[str]
diff --git a/torchrec/ir/serializer.py b/torchrec/ir/serializer.py
index 23e2d7a07..11a8a4a9a 100644
--- a/torchrec/ir/serializer.py
+++ b/torchrec/ir/serializer.py
@@ -14,11 +14,17 @@ import torch
 
 from torch import nn
-from torchrec.ir.schema import EBCMetadata, EmbeddingBagConfigMetadata
+from torchrec.ir.schema import EBCMetadata, EmbeddingBagConfigMetadata, FPEBCMetadata
 
 from torchrec.ir.types import SerializerInterface
 from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig, PoolingType
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
+from torchrec.modules.feature_processor_ import (
+    FeatureProcessor,
+    FeatureProcessorNameMap,
+    FeatureProcessorsCollection,
+)
+from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -128,6 +134,105 @@ def deserialize(
         )
 
 
+class FPEBCJsonSerializer(SerializerInterface):
+    """
+    Serializer for torch.export IR using JSON.
+    """
+
+    @classmethod
+    def serialize(
+        cls,
+        module: nn.Module,
+    ) -> torch.Tensor:
+        if not isinstance(module, FeatureProcessedEmbeddingBagCollection):
+            raise ValueError(
+                f"Expected module to be of type FeatureProcessedEmbeddingBagCollection, got {type(module)}"
+            )
+        if isinstance(module._feature_processors, nn.ModuleDict):
+            fp_type = "dict"
+            param_dict = {
+                feature: processor.get_init_kwargs()
+                for feature, processor in module._feature_processors.items()
+            }
+            type_dict = {
+                feature: type(processor).__name__
+                for feature, processor in module._feature_processors.items()
+            }
+            fp_json = json.dumps(
+                {
+                    "param_dict": param_dict,
+                    "type_dict": type_dict,
+                }
+            )
+        elif isinstance(module._feature_processors, FeatureProcessorsCollection):
+            fp_type = type(module._feature_processors).__name__
+            param_dict = module._feature_processors.get_init_kwargs()
+            fp_json = json.dumps(param_dict)
+        else:
+            raise ValueError(
+                f"Expected module._feature_processors to be of type dict or FeatureProcessorsCollection, got {type(module._feature_processors)}"
+            )
+        ebc = module._embedding_bag_collection
+        ebc_metadata = FPEBCMetadata(
+            tables=[
+                embedding_bag_config_to_metadata(table_config)
+                for table_config in ebc.embedding_bag_configs()
+            ],
+            is_weighted=ebc.is_weighted(),
+            device=str(ebc.device),
+            fp_type=fp_type,
+            fp_json=fp_json,
+        )
+
+        ebc_metadata_dict = ebc_metadata.__dict__
+        ebc_metadata_dict["tables"] = [
+            table_config.__dict__ for table_config in ebc_metadata_dict["tables"]
+        ]
+
+        return torch.frombuffer(
+            json.dumps(ebc_metadata_dict).encode(), dtype=torch.uint8
+        )
+
+    @classmethod
+    def deserialize(
+        cls, input: torch.Tensor, typename: str, device: Optional[torch.device] = None
+    ) -> nn.Module:
+        if typename != "FeatureProcessedEmbeddingBagCollection":
+            raise ValueError(
+                f"Expected typename to be FeatureProcessedEmbeddingBagCollection, got {typename}"
+            )
+
+        raw_bytes = input.numpy().tobytes()
+        ebc_metadata_dict = json.loads(raw_bytes.decode())
+        tables = [
+            EmbeddingBagConfigMetadata(**table_config)
+            for table_config in ebc_metadata_dict["tables"]
+        ]
+        device = get_deserialized_device(ebc_metadata_dict.get("device"), device)
+        ebc = EmbeddingBagCollection(
+            tables=[
+                embedding_metadata_to_config(table_config) for table_config in tables
+            ],
+            is_weighted=ebc_metadata_dict["is_weighted"],
+            device=device,
+        )
+        fp_dict = json.loads(ebc_metadata_dict["fp_json"])
+        if ebc_metadata_dict["fp_type"] == "dict":
+            feature_processors: Dict[str, FeatureProcessor] = {}
+            for feature, fp_type in fp_dict["type_dict"].items():
+                feature_processors[feature] = FeatureProcessorNameMap[fp_type](
+                    **fp_dict["param_dict"][feature]
+                )
+        else:
+            feature_processors = FeatureProcessorNameMap[ebc_metadata_dict["fp_type"]](
+                **fp_dict
+            )
+        return FeatureProcessedEmbeddingBagCollection(
+            ebc,
+            feature_processors,
+        )
+
+
 class JsonSerializer(SerializerInterface):
     """
     Serializer for torch.export IR using thrift.
@@ -135,6 +240,7 @@ class JsonSerializer(SerializerInterface):
 
     module_to_serializer_cls: Dict[str, Type[SerializerInterface]] = {
         "EmbeddingBagCollection": EBCJsonSerializer,
+        "FeatureProcessedEmbeddingBagCollection": FPEBCJsonSerializer,
     }
 
     @classmethod
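Note (illustrative, not part of the patch): a minimal sketch of how the new FPEBCJsonSerializer is meant to round-trip a FeatureProcessedEmbeddingBagCollection, assuming a small single-table FPEBC built the same way as in the tests below.

```python
import torch

from torchrec.ir.serializer import FPEBCJsonSerializer
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection

# A tiny weighted FPEBC with a FeatureProcessorsCollection (hypothetical example config).
fpebc = FeatureProcessedEmbeddingBagCollection(
    EmbeddingBagCollection(
        tables=[
            EmbeddingBagConfig(
                name="t1", embedding_dim=4, num_embeddings=10, feature_names=["f1"]
            )
        ],
        is_weighted=True,
    ),
    PositionWeightedModuleCollection(max_feature_lengths={"f1": 100}),
)

# serialize() packs an FPEBCMetadata record (tables, fp_type, fp_json, device) as JSON
# inside a uint8 tensor; deserialize() rebuilds the module from that buffer.
ir_metadata: torch.Tensor = FPEBCJsonSerializer.serialize(fpebc)
rebuilt = FPEBCJsonSerializer.deserialize(
    ir_metadata, typename="FeatureProcessedEmbeddingBagCollection"
)
assert isinstance(rebuilt, FeatureProcessedEmbeddingBagCollection)
```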
diff --git a/torchrec/ir/tests/test_serializer.py b/torchrec/ir/tests/test_serializer.py
index f332ac994..0042f12cf 100644
--- a/torchrec/ir/tests/test_serializer.py
+++ b/torchrec/ir/tests/test_serializer.py
@@ -25,23 +25,28 @@ from torchrec.modules.embedding_configs import EmbeddingBagConfig
 
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
-from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection
+from torchrec.modules.feature_processor_ import (
+    PositionWeightedModule,
+    PositionWeightedModuleCollection,
+)
 from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
 from torchrec.modules.utils import operator_registry_state
-from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
 
 class TestJsonSerializer(unittest.TestCase):
+    # in the model we have 5 duplicated EBCs, 1 fpEBC with fpCollection, and 1 fpEBC with fpDict
     def generate_model(self) -> nn.Module:
         class Model(nn.Module):
-            def __init__(self, ebc, fpebc):
+            def __init__(self, ebc, fpebc1, fpebc2):
                 super().__init__()
                 self.ebc1 = ebc
                 self.ebc2 = copy.deepcopy(ebc)
                 self.ebc3 = copy.deepcopy(ebc)
                 self.ebc4 = copy.deepcopy(ebc)
                 self.ebc5 = copy.deepcopy(ebc)
-                self.fpebc = fpebc
+                self.fpebc1 = fpebc1
+                self.fpebc2 = fpebc2
 
             def forward(
                 self,
@@ -53,22 +58,17 @@ def forward(
                 kt4 = self.ebc4(features)
                 kt5 = self.ebc5(features)
 
-                fpebc_res = self.fpebc(features)
+                fpebc1_res = self.fpebc1(features)
+                fpebc2_res = self.fpebc2(features)
                 ebc_kt_vals = [kt.values() for kt in [kt1, kt2, kt3, kt4, kt5]]
-                sparse_arch_vals = sum(ebc_kt_vals)
-                sparse_arch_res = KeyedTensor(
-                    keys=kt1.keys(),
-                    values=sparse_arch_vals,
-                    length_per_key=kt1.length_per_key(),
-                )
-                return KeyedTensor.regroup(
-                    [sparse_arch_res, fpebc_res], [["f1"], ["f2", "f3"]]
+                return (
+                    ebc_kt_vals + list(fpebc1_res.values()) + list(fpebc2_res.values())
                 )
 
         tb1_config = EmbeddingBagConfig(
             name="t1",
-            embedding_dim=4,
+            embedding_dim=3,
             num_embeddings=10,
             feature_names=["f1"],
         )
@@ -80,7 +80,7 @@ def forward(
         )
         tb3_config = EmbeddingBagConfig(
             name="t3",
-            embedding_dim=4,
+            embedding_dim=5,
             num_embeddings=10,
             feature_names=["f3"],
         )
@@ -91,7 +91,7 @@ def forward(
         )
 
         max_feature_lengths = {"f1": 100, "f2": 100}
-        fpebc = FeatureProcessedEmbeddingBagCollection(
+        fpebc1 = FeatureProcessedEmbeddingBagCollection(
             EmbeddingBagCollection(
                 tables=[tb1_config, tb2_config],
                 is_weighted=True,
@@ -100,8 +100,18 @@ def forward(
                 max_feature_lengths=max_feature_lengths,
             ),
         )
+        fpebc2 = FeatureProcessedEmbeddingBagCollection(
+            EmbeddingBagCollection(
+                tables=[tb1_config, tb3_config],
+                is_weighted=True,
+            ),
+            {
+                "f1": PositionWeightedModule(max_feature_length=10),
+                "f3": PositionWeightedModule(max_feature_length=20),
+            },
+        )
 
-        model = Model(ebc, fpebc)
+        model = Model(ebc, fpebc1, fpebc2)
         return model
 
@@ -132,12 +142,16 @@ def test_serialize_deserialize_ebc(self) -> None:
         for i, tensor in enumerate(ep_output):
             self.assertEqual(eager_out[i].shape, tensor.shape)
 
-        # Only 2 custom op registered, as dimensions of ebc are same
-        self.assertEqual(len(operator_registry_state.op_registry_schema), 2)
+        # Should have 3 custom ops registered: the EBCs share the same dim,
+        # and the two fpEBCs have different dims
+        self.assertEqual(len(operator_registry_state.op_registry_schema), 3)
 
         total_dim_ebc = sum(model.ebc1._lengths_per_embedding)
-        total_dim_fpebc = sum(
-            model.fpebc._embedding_bag_collection._lengths_per_embedding
+        total_dim_fpebc1 = sum(
+            model.fpebc1._embedding_bag_collection._lengths_per_embedding
+        )
+        total_dim_fpebc2 = sum(
+            model.fpebc2._embedding_bag_collection._lengths_per_embedding
         )
         # Check if custom op is registered with the correct name
         # EmbeddingBagCollection type and total dim
@@ -146,7 +160,11 @@ def test_serialize_deserialize_ebc(self) -> None:
             in operator_registry_state.op_registry_schema
         )
         self.assertTrue(
-            f"EmbeddingBagCollection_{total_dim_fpebc}"
+            f"FeatureProcessedEmbeddingBagCollection_{total_dim_fpebc1}"
+            in operator_registry_state.op_registry_schema
+        )
+        self.assertTrue(
+            f"FeatureProcessedEmbeddingBagCollection_{total_dim_fpebc2}"
             in operator_registry_state.op_registry_schema
         )
 
@@ -155,28 +173,68 @@ def test_serialize_deserialize_ebc(self) -> None:
 
         # Deserialize EBC
         deserialized_model = deserialize_embedding_modules(ep, JsonSerializer)
 
+        # check EBC config
         for i in range(5):
             ebc_name = f"ebc{i + 1}"
-            assert isinstance(
+            self.assertIsInstance(
                 getattr(deserialized_model, ebc_name), EmbeddingBagCollection
             )
 
-            for deserialized_config, org_config in zip(
+            for deserialized, original in zip(
                 getattr(deserialized_model, ebc_name).embedding_bag_configs(),
                 getattr(model, ebc_name).embedding_bag_configs(),
             ):
-                assert deserialized_config.name == org_config.name
-                assert deserialized_config.embedding_dim == org_config.embedding_dim
-                assert deserialized_config.num_embeddings, org_config.num_embeddings
-                assert deserialized_config.feature_names, org_config.feature_names
+                self.assertEqual(deserialized.name, original.name)
+                self.assertEqual(deserialized.embedding_dim, original.embedding_dim)
+                self.assertEqual(deserialized.num_embeddings, original.num_embeddings)
+                self.assertEqual(deserialized.feature_names, original.feature_names)
+
+        # check FPEBC config
+        for i in range(2):
+            fpebc_name = f"fpebc{i + 1}"
+            assert isinstance(
+                getattr(deserialized_model, fpebc_name),
+                FeatureProcessedEmbeddingBagCollection,
+            )
+
+            deserialized_fp = getattr(
+                deserialized_model, fpebc_name
+            )._feature_processors
+            original_fp = getattr(model, fpebc_name)._feature_processors
+            if isinstance(original_fp, nn.ModuleDict):
+                for deserialized, original in zip(
+                    deserialized_fp.values(), original_fp.values()
+                ):
+                    self.assertDictEqual(
+                        deserialized.get_init_kwargs(), original.get_init_kwargs()
+                    )
+            else:
+                self.assertDictEqual(
+                    deserialized_fp.get_init_kwargs(), original_fp.get_init_kwargs()
+                )
+
+            for deserialized, original in zip(
+                getattr(
+                    deserialized_model, fpebc_name
+                )._embedding_bag_collection.embedding_bag_configs(),
+                getattr(
+                    model, fpebc_name
+                )._embedding_bag_collection.embedding_bag_configs(),
+            ):
+                self.assertEqual(deserialized.name, original.name)
+                self.assertEqual(deserialized.embedding_dim, original.embedding_dim)
+                self.assertEqual(deserialized.num_embeddings, original.num_embeddings)
+                self.assertEqual(deserialized.feature_names, original.feature_names)
 
         deserialized_model.load_state_dict(model.state_dict())
-        # Run forward on deserialized model
+
+        # Run forward on deserialized model and compare the output
         deserialized_out = deserialized_model(id_list_features)
 
-        for i, tensor in enumerate(deserialized_out):
-            assert eager_out[i].shape == tensor.shape
-            assert torch.allclose(eager_out[i], tensor)
+        self.assertEqual(len(deserialized_out), len(eager_out))
+        for deserialized, original in zip(deserialized_out, eager_out):
+            self.assertEqual(deserialized.shape, original.shape)
+            self.assertTrue(torch.allclose(deserialized, original))
 
     def test_dynamic_shape_ebc(self) -> None:
         model = self.generate_model()
diff --git a/torchrec/ir/utils.py b/torchrec/ir/utils.py
index 1676f166c..bbaaddffb 100644
--- a/torchrec/ir/utils.py
+++ b/torchrec/ir/utils.py
@@ -37,8 +37,14 @@ def serialize_embedding_modules(
     Returns the modified module and the list of fqns that had the buffer added.
     """
     preserve_fqns = []
+    serialized_fqns = set()
     for fqn, module in model.named_modules():
         if type(module).__name__ in serializer_cls.module_to_serializer_cls:
+            # avoid serializing a submodule nested inside a module that has already been serialized
+            if any(fqn.startswith(s_fqn) for s_fqn in serialized_fqns):
+                continue
+            else:
+                serialized_fqns.add(fqn)
             serialized_module = serializer_cls.serialize(module)
             module.register_buffer("ir_metadata", serialized_module, persistent=False)
             preserve_fqns.append(fqn)
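Note (illustrative, not part of the patch): the intended effect of the new prefix check in serialize_embedding_modules(), sketched in isolation with hypothetical fqns. Only modules whose type name is registered in module_to_serializer_cls reach this branch; the check then drops any fqn nested under one that was already serialized.

```python
serialized_fqns = set()
candidate_fqns = [
    "fpebc1",                            # FeatureProcessedEmbeddingBagCollection
    "fpebc1._embedding_bag_collection",  # nested EBC: skipped, parent already serialized
    "ebc1",                              # standalone EBC: serialized normally
]
for fqn in candidate_fqns:
    if any(fqn.startswith(s_fqn) for s_fqn in serialized_fqns):
        continue
    serialized_fqns.add(fqn)

assert serialized_fqns == {"fpebc1", "ebc1"}
```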
+ """ + return {} + @abc.abstractmethod def forward( self, @@ -173,6 +192,11 @@ def __init__( self.reset_parameters() + def get_init_kwargs(self) -> Dict[str, Dict[str, int]]: + return { + "max_feature_lengths": self.max_feature_lengths, + } + def reset_parameters(self) -> None: with torch.no_grad(): for key, _length in self.max_feature_lengths.items(): @@ -210,3 +234,9 @@ def _apply(self, *args, **kwargs) -> nn.Module: self.position_weights_dict[k] = param return self + + +FeatureProcessorNameMap: Dict[str, Any] = { + "PositionWeightedModule": PositionWeightedModule, + "PositionWeightedModuleCollection": PositionWeightedModuleCollection, +} diff --git a/torchrec/modules/fp_embedding_modules.py b/torchrec/modules/fp_embedding_modules.py index b85c9f946..cc8b0a273 100644 --- a/torchrec/modules/fp_embedding_modules.py +++ b/torchrec/modules/fp_embedding_modules.py @@ -16,6 +16,7 @@ FeatureProcessor, FeatureProcessorsCollection, ) +from torchrec.modules.utils import is_non_strict_exporting, register_custom_op from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor @@ -125,6 +126,24 @@ def __init__( feature_names_set.update(table_config.feature_names) self._feature_names: List[str] = list(feature_names_set) + def _non_strict_exporting_forward(self, features: KeyedJaggedTensor) -> KeyedTensor: + batch_size = features.stride() + arg_list = [ + features.values(), + features.weights_or_none(), + features.lengths_or_none(), + features.offsets_or_none(), + ] # if want to include the weights: `+ [bag.weight for bag in self.embedding_bags.values()]` + ebc = self._embedding_bag_collection + dims = [sum(ebc._lengths_per_embedding)] + fp_ebc_op = register_custom_op(type(self).__name__, dims) + outputs = fp_ebc_op(arg_list, batch_size) + return KeyedTensor( + keys=ebc._embedding_names, + values=outputs[0], + length_per_key=ebc._lengths_per_embedding, + ) + def forward( self, features: KeyedJaggedTensor, @@ -136,7 +155,8 @@ def forward( Returns: KeyedTensor """ - + if is_non_strict_exporting() and not torch.jit.is_scripting(): + return self._non_strict_exporting_forward(features) if isinstance(self._feature_processors, FeatureProcessorsCollection): fp_features = self._feature_processors(features) else: diff --git a/torchrec/modules/tests/test_embedding_modules.py b/torchrec/modules/tests/test_embedding_modules.py index 934bdb229..6ae2b3ec8 100644 --- a/torchrec/modules/tests/test_embedding_modules.py +++ b/torchrec/modules/tests/test_embedding_modules.py @@ -227,7 +227,7 @@ def test_device(self) -> None: self.assertEqual(torch.device("cpu"), ebc.embedding_bags["t1"].weight.device) self.assertEqual(torch.device("cpu"), ebc.device) - def test_exporting(self) -> None: + def test_ir_export(self) -> None: class MyModule(torch.nn.Module): def __init__(self): super().__init__() @@ -296,6 +296,13 @@ def forward( "Shoulde be exact 2 EmbeddingBagCollection nodes in the exported graph", ) + # export_program's module should produce the same output shape + output = m(features) + exported = ep.module()(features) + self.assertEqual( + output.size(), exported.size(), "Output should match exported output" + ) + class EmbeddingCollectionTest(unittest.TestCase): def test_forward(self) -> None: diff --git a/torchrec/modules/tests/test_fp_embedding_modules.py b/torchrec/modules/tests/test_fp_embedding_modules.py index 688c56be5..b09c2f192 100644 --- a/torchrec/modules/tests/test_fp_embedding_modules.py +++ b/torchrec/modules/tests/test_fp_embedding_modules.py @@ -21,22 +21,12 @@ 
diff --git a/torchrec/modules/tests/test_fp_embedding_modules.py b/torchrec/modules/tests/test_fp_embedding_modules.py
index 688c56be5..b09c2f192 100644
--- a/torchrec/modules/tests/test_fp_embedding_modules.py
+++ b/torchrec/modules/tests/test_fp_embedding_modules.py
@@ -21,22 +21,12 @@
     PositionWeightedModuleCollection,
 )
 from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
-from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
 
 
 class PositionWeightedModuleEmbeddingBagCollectionTest(unittest.TestCase):
-    def test_position_weighted_module_ebc(self) -> None:
-        #        0       1        2  <-- batch
-        # 0   [0,1] None    [2]
-        # 1   [3]    [4]    [5,6,7]
-        # ^
-        # feature
-        features = KeyedJaggedTensor.from_offsets_sync(
-            keys=["f1", "f2"],
-            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
-            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
-        )
+    def generate_fp_ebc(self) -> FeatureProcessedEmbeddingBagCollection:
         ebc = EmbeddingBagCollection(
             tables=[
                 EmbeddingBagConfig(
@@ -52,8 +42,21 @@ def test_position_weighted_module_ebc(self) -> None:
             "f1": cast(FeatureProcessor, PositionWeightedModule(max_feature_length=10)),
             "f2": cast(FeatureProcessor, PositionWeightedModule(max_feature_length=5)),
         }
+        return FeatureProcessedEmbeddingBagCollection(ebc, feature_processors)
 
-        fp_ebc = FeatureProcessedEmbeddingBagCollection(ebc, feature_processors)
+    def test_position_weighted_module_ebc(self) -> None:
+        #        0       1        2  <-- batch
+        # 0   [0,1] None    [2]
+        # 1   [3]    [4]    [5,6,7]
+        # ^
+        # feature
+        features = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f2"],
+            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
+        )
+
+        fp_ebc = self.generate_fp_ebc()
 
         pooled_embeddings = fp_ebc(features)
         self.assertEqual(pooled_embeddings.keys(), ["f1", "f2"])
@@ -86,6 +89,56 @@ def test_position_weighted_module_ebc_with_excessive_features(self) -> None:
             offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8, 9, 9, 9]),
         )
 
+        fp_ebc = self.generate_fp_ebc()
+
+        pooled_embeddings = fp_ebc(features)
+        self.assertEqual(pooled_embeddings.keys(), ["f1", "f2"])
+        self.assertEqual(pooled_embeddings.values().size(), (3, 16))
+        self.assertEqual(pooled_embeddings.offset_per_key(), [0, 8, 16])
+
+    def test_ir_export(self) -> None:
+        class MyModule(torch.nn.Module):
+            def __init__(self, fp_ebc) -> None:
+                super().__init__()
+                self._fp_ebc = fp_ebc
+
+            def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
+                return self._fp_ebc(features)
+
+        m = MyModule(self.generate_fp_ebc())
+        features = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8, 9, 9, 9]),
+        )
+        ep = torch.export.export(
+            m,
+            (features,),
+            {},
+            strict=False,
+        )
+        self.assertEqual(
+            sum(n.name.startswith("_embedding_bag") for n in ep.graph.nodes),
+            0,
+        )
+        self.assertEqual(
+            sum(
+                n.name.startswith("feature_processed_embedding_bag_collection")
+                for n in ep.graph.nodes
+            ),
+            1,
+            "Should be exactly 1 FPEBC node in the exported graph",
+        )
+
+        # the exported program's module should produce the same output shape
+        output = m(features)
+        exported = ep.module()(features)
+        self.assertEqual(output.keys(), exported.keys())
+        self.assertEqual(output.values().size(), exported.values().size())
+
+
+class PositionWeightedModuleCollectionEmbeddingBagCollectionTest(unittest.TestCase):
+    def generate_fp_ebc(self) -> FeatureProcessedEmbeddingBagCollection:
         ebc = EmbeddingBagCollection(
             tables=[
                 EmbeddingBagConfig(
@@ -97,20 +150,11 @@ def test_position_weighted_module_ebc_with_excessive_features(self) -> None:
             ],
             is_weighted=True,
         )
-        feature_processors = {
-            "f1": cast(FeatureProcessor, PositionWeightedModule(max_feature_length=10)),
-            "f2": cast(FeatureProcessor, PositionWeightedModule(max_feature_length=5)),
-        }
-
-        fp_ebc = FeatureProcessedEmbeddingBagCollection(ebc, feature_processors)
-
-        pooled_embeddings = fp_ebc(features)
-        self.assertEqual(pooled_embeddings.keys(), ["f1", "f2"])
-        self.assertEqual(pooled_embeddings.values().size(), (3, 16))
-        self.assertEqual(pooled_embeddings.offset_per_key(), [0, 8, 16])
+        return FeatureProcessedEmbeddingBagCollection(
+            ebc, PositionWeightedModuleCollection({"f1": 10, "f2": 10})
+        )
 
 
-class PositionWeightedModuleCollectionEmbeddingBagCollectionTest(unittest.TestCase):
     def test_position_weighted_collection_module_ebc(self) -> None:
         #        0       1        2  <-- batch
         # 0   [0,1] None    [2]
         # 1   [3]    [4]    [5,6,7]
         # ^
         # feature
         features = KeyedJaggedTensor.from_offsets_sync(
             keys=["f1", "f2"],
             values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
             offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
         )
@@ -123,21 +167,7 @@ def test_position_weighted_collection_module_ebc(self) -> None:
             offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
         )
 
-        ebc = EmbeddingBagCollection(
-            tables=[
-                EmbeddingBagConfig(
-                    name="t1", embedding_dim=8, num_embeddings=16, feature_names=["f1"]
-                ),
-                EmbeddingBagConfig(
-                    name="t2", embedding_dim=8, num_embeddings=16, feature_names=["f2"]
-                ),
-            ],
-            is_weighted=True,
-        )
-
-        fp_ebc = FeatureProcessedEmbeddingBagCollection(
-            ebc, PositionWeightedModuleCollection({"f1": 10, "f2": 10})
-        )
+        fp_ebc = self.generate_fp_ebc()
 
         pooled_embeddings = fp_ebc(features)
         self.assertEqual(pooled_embeddings.keys(), ["f1", "f2"])
@@ -155,3 +185,43 @@ def test_position_weighted_collection_module_ebc(self) -> None:
             pooled_embeddings_gm_script.offset_per_key(),
             pooled_embeddings.offset_per_key(),
         )
+
+    def test_ir_export(self) -> None:
+        class MyModule(torch.nn.Module):
+            def __init__(self, fp_ebc) -> None:
+                super().__init__()
+                self._fp_ebc = fp_ebc
+
+            def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
+                return self._fp_ebc(features)
+
+        m = MyModule(self.generate_fp_ebc())
+        features = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8, 9, 9, 9]),
+        )
+        ep = torch.export.export(
+            m,
+            (features,),
+            {},
+            strict=False,
+        )
+        self.assertEqual(
+            sum(n.name.startswith("_embedding_bag") for n in ep.graph.nodes),
+            0,
+        )
+        self.assertEqual(
+            sum(
+                n.name.startswith("feature_processed_embedding_bag_collection")
+                for n in ep.graph.nodes
+            ),
+            1,
+            "Should be exactly 1 FPEBC node in the exported graph",
+        )
+
+        # the exported program's module should produce the same output shape
+        output = m(features)
+        exported = ep.module()(features)
+        self.assertEqual(output.keys(), exported.keys())
+        self.assertEqual(output.values().size(), exported.values().size())
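Note (illustrative, not part of the patch): how get_init_kwargs() and FeatureProcessorNameMap are meant to work together during deserialization — the processor reports the kwargs needed to rebuild it, and the name map turns the serialized class name back into a constructor.

```python
from torchrec.modules.feature_processor_ import (
    FeatureProcessorNameMap,
    PositionWeightedModule,
)

fp = PositionWeightedModule(max_feature_length=10)
kwargs = fp.get_init_kwargs()  # {"max_feature_length": 10}
rebuilt = FeatureProcessorNameMap[type(fp).__name__](**kwargs)

assert isinstance(rebuilt, PositionWeightedModule)
assert rebuilt.position_weight.shape[0] == 10
```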