From 048be62b33c315782dfb450f9198a0ae9b3badb2 Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Thu, 6 Jun 2024 13:29:16 -0700
Subject: [PATCH] register custom_op for fpEBC (#2067)

Summary:
# context
* convert `FeatureProcessedEmbeddingBagCollection` to a custom op in IR export
* add serialization and deserialization functions for FPEBC
* add an API to the `FeatureProcessorInterface` that exports the parameters needed to create an instance
* use this API (`get_init_kwargs`) in the serialize and deserialize functions to flatten and unflatten the feature processor

# details
1. Added the `FPEBCMetadata` schema for FP_EBC, using an `fp_json` string to store the necessary parameters
2. Added `FPEBCJsonSerializer`, which converts the init_kwargs to a json string and stores it in the `fp_json` field of the metadata
3. Added an fqn check for `serialized_fqns`, so that when a higher-level module is serialized, the lower-level modules can be skipped (they are already included in the higher-level module)
4. Added an API called `get_init_kwargs` to `FeatureProcessorsCollection` and `FeatureProcessor`, and used a `FeatureProcessorNameMap` to map the class name to the feature processor class
5. Added a `_non_strict_exporting_forward` function for FPEBC so that non-strict IR export goes through the custom_op logic

Differential Revision: D57829276
---
 torchrec/ir/schema.py                         |  18 +-
 torchrec/ir/serializer.py                     | 161 +++++++++++++++++-
 torchrec/ir/tests/test_serializer.py          | 122 +++++++++----
 torchrec/ir/utils.py                          |   4 +-
 .../modules/tests/test_embedding_modules.py   |   9 +-
 .../tests/test_fp_embedding_modules.py        | 147 +++++++++++-----
 6 files changed, 382 insertions(+), 79 deletions(-)

diff --git a/torchrec/ir/schema.py b/torchrec/ir/schema.py
index f0fbf706e..87f14cabc 100644
--- a/torchrec/ir/schema.py
+++ b/torchrec/ir/schema.py
@@ -8,7 +8,7 @@
 # pyre-strict
 
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
 from torchrec.modules.embedding_configs import DataType, PoolingType
 
@@ -32,3 +32,19 @@ class EBCMetadata:
     tables: List[EmbeddingBagConfigMetadata]
     is_weighted: bool
     device: Optional[str]
+
+
+@dataclass
+class FPEBCMetadata:
+    is_fp_collection: bool
+    feature_list: List[str]
+
+
+@dataclass
+class PositionWeightedModuleMetadata:
+    max_feature_length: int
+
+
+@dataclass
+class PositionWeightedModuleCollectionMetadata:
+    max_feature_lengths: List[Tuple[str, int]]
diff --git a/torchrec/ir/serializer.py b/torchrec/ir/serializer.py
index 0eb232d8c..45ce5953c 100644
--- a/torchrec/ir/serializer.py
+++ b/torchrec/ir/serializer.py
@@ -14,11 +14,24 @@
 import torch
 
 from torch import nn
-from torchrec.ir.schema import EBCMetadata, EmbeddingBagConfigMetadata
+from torchrec.ir.schema import (
+    EBCMetadata,
+    EmbeddingBagConfigMetadata,
+    FPEBCMetadata,
+    PositionWeightedModuleCollectionMetadata,
+    PositionWeightedModuleMetadata,
+)
 
 from torchrec.ir.types import SerializerInterface
 from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig, PoolingType
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
+from torchrec.modules.feature_processor_ import (
+    FeatureProcessor,
+    FeatureProcessorsCollection,
+    PositionWeightedModule,
+    PositionWeightedModuleCollection,
+)
+from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -71,7 +84,7 @@ def get_deserialized_device(
 
 class EBCJsonSerializer(SerializerInterface):
     """
-    Serializer for torch.export IR using thrift.
+    Serializer for torch.export IR using json.
     """
 
     @classmethod
@@ -132,13 +145,155 @@ def deserialize(
         )
 
 
+class PWMJsonSerializer(SerializerInterface):
+    """
+    Serializer for torch.export IR using json.
+    """
+
+    @classmethod
+    def serialize(cls, module: nn.Module) -> torch.Tensor:
+        if not isinstance(module, PositionWeightedModule):
+            raise ValueError(
+                f"Expected module to be of type PositionWeightedModule, got {type(module)}"
+            )
+        metadata = PositionWeightedModuleMetadata(
+            max_feature_length=module.position_weight.shape[0],
+        )
+        return torch.frombuffer(
+            json.dumps(metadata.__dict__).encode(), dtype=torch.uint8
+        )
+
+    @classmethod
+    def deserialize(
+        cls,
+        input: torch.Tensor,
+        typename: str,
+        device: Optional[torch.device] = None,
+        children: Dict[str, nn.Module] = {},
+    ) -> nn.Module:
+        if typename != "PositionWeightedModule":
+            raise ValueError(
+                f"Expected typename to be PositionWeightedModule, got {typename}"
+            )
+        raw_bytes = input.numpy().tobytes()
+        metadata = json.loads(raw_bytes)
+        return PositionWeightedModule(metadata["max_feature_length"], device)
+
+
+class PWMCJsonSerializer(SerializerInterface):
+    """
+    Serializer for torch.export IR using json.
+    """
+
+    @classmethod
+    def serialize(cls, module: nn.Module) -> torch.Tensor:
+        if not isinstance(module, PositionWeightedModuleCollection):
+            raise ValueError(
+                f"Expected module to be of type PositionWeightedModuleCollection, got {type(module)}"
+            )
+        metadata = PositionWeightedModuleCollectionMetadata(
+            max_feature_lengths=[  # convert to list of tuples to preserve the order
+                (feature, len) for feature, len in module.max_feature_lengths.items()
+            ],
+        )
+        return torch.frombuffer(
+            json.dumps(metadata.__dict__).encode(), dtype=torch.uint8
+        )
+
+    @classmethod
+    def deserialize(
+        cls,
+        input: torch.Tensor,
+        typename: str,
+        device: Optional[torch.device] = None,
+        children: Dict[str, nn.Module] = {},
+    ) -> nn.Module:
+        if typename != "PositionWeightedModuleCollection":
+            raise ValueError(
+                f"Expected typename to be PositionWeightedModuleCollection, got {typename}"
+            )
+        raw_bytes = input.numpy().tobytes()
+        metadata = PositionWeightedModuleCollectionMetadata(**json.loads(raw_bytes))
+        max_feature_lengths = {
+            feature: len for feature, len in metadata.max_feature_lengths
+        }
+        return PositionWeightedModuleCollection(max_feature_lengths, device)
+
+
+class FPEBCJsonSerializer(SerializerInterface):
+    """
+    Serializer for torch.export IR using json.
+ """ + + @classmethod + def requires_children(cls, typename: str) -> bool: + return True + + @classmethod + def serialize( + cls, + module: nn.Module, + ) -> torch.Tensor: + if not isinstance(module, FeatureProcessedEmbeddingBagCollection): + raise ValueError( + f"Expected module to be of type FeatureProcessedEmbeddingBagCollection, got {type(module)}" + ) + elif isinstance(module._feature_processors, FeatureProcessorsCollection): + metadata = FPEBCMetadata( + is_fp_collection=True, + feature_list=[], + ) + else: + metadata = FPEBCMetadata( + is_fp_collection=False, + feature_list=list(module._feature_processors.keys()), + ) + + return torch.frombuffer( + json.dumps(metadata.__dict__).encode(), dtype=torch.uint8 + ) + + @classmethod + def deserialize( + cls, + input: torch.Tensor, + typename: str, + device: Optional[torch.device] = None, + children: Dict[str, nn.Module] = {}, + ) -> nn.Module: + if typename != "FeatureProcessedEmbeddingBagCollection": + raise ValueError( + f"Expected typename to be EmbeddingBagCollection, got {typename}" + ) + raw_bytes = input.numpy().tobytes() + metadata = FPEBCMetadata(**json.loads(raw_bytes.decode())) + if metadata.is_fp_collection: + feature_processors = children["_feature_processor"] + assert isinstance(feature_processors, FeatureProcessorsCollection) + else: + feature_processors: dict[str, FeatureProcessor] = {} + for feature in metadata.feature_list: + fp = children[f"_feature_processor.{feature}"] + assert isinstance(fp, FeatureProcessor) + feature_processors[feature] = fp + ebc = children["_embedding_bag_collection"] + assert isinstance(ebc, EmbeddingBagCollection) + return FeatureProcessedEmbeddingBagCollection( + ebc, + feature_processors, + ) + + class JsonSerializer(SerializerInterface): """ - Serializer for torch.export IR using thrift. + Serializer for torch.export IR using json. 
""" module_to_serializer_cls: Dict[str, Type[SerializerInterface]] = { "EmbeddingBagCollection": EBCJsonSerializer, + "FeatureProcessedEmbeddingBagCollection": FPEBCJsonSerializer, + "PositionWeightedModule": PWMJsonSerializer, + "PositionWeightedModuleCollection": PWMCJsonSerializer, } @classmethod diff --git a/torchrec/ir/tests/test_serializer.py b/torchrec/ir/tests/test_serializer.py index 9b06c7c1c..e2ad1a49a 100644 --- a/torchrec/ir/tests/test_serializer.py +++ b/torchrec/ir/tests/test_serializer.py @@ -25,10 +25,13 @@ from torchrec.modules.embedding_configs import EmbeddingBagConfig from torchrec.modules.embedding_modules import EmbeddingBagCollection -from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection +from torchrec.modules.feature_processor_ import ( + PositionWeightedModule, + PositionWeightedModuleCollection, +) from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection from torchrec.modules.utils import operator_registry_state -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor class CompoundModule(nn.Module): @@ -94,16 +97,18 @@ def deserialize( class TestJsonSerializer(unittest.TestCase): + # in the model we have 5 duplicated EBCs, 1 fpEBC with fpCollection, and 1 fpEBC with fpDict def generate_model(self) -> nn.Module: class Model(nn.Module): - def __init__(self, ebc, fpebc): + def __init__(self, ebc, fpebc1, fpebc2): super().__init__() self.ebc1 = ebc self.ebc2 = copy.deepcopy(ebc) self.ebc3 = copy.deepcopy(ebc) self.ebc4 = copy.deepcopy(ebc) self.ebc5 = copy.deepcopy(ebc) - self.fpebc = fpebc + self.fpebc1 = fpebc1 + self.fpebc2 = fpebc2 def forward( self, @@ -115,22 +120,17 @@ def forward( kt4 = self.ebc4(features) kt5 = self.ebc5(features) - fpebc_res = self.fpebc(features) + fpebc1_res = self.fpebc1(features) + fpebc2_res = self.fpebc2(features) ebc_kt_vals = [kt.values() for kt in [kt1, kt2, kt3, kt4, kt5]] - sparse_arch_vals = sum(ebc_kt_vals) - sparse_arch_res = KeyedTensor( - keys=kt1.keys(), - values=sparse_arch_vals, - length_per_key=kt1.length_per_key(), - ) - return KeyedTensor.regroup( - [sparse_arch_res, fpebc_res], [["f1"], ["f2", "f3"]] + return ( + ebc_kt_vals + list(fpebc1_res.values()) + list(fpebc2_res.values()) ) tb1_config = EmbeddingBagConfig( name="t1", - embedding_dim=4, + embedding_dim=3, num_embeddings=10, feature_names=["f1"], ) @@ -142,7 +142,7 @@ def forward( ) tb3_config = EmbeddingBagConfig( name="t3", - embedding_dim=4, + embedding_dim=5, num_embeddings=10, feature_names=["f3"], ) @@ -153,7 +153,7 @@ def forward( ) max_feature_lengths = {"f1": 100, "f2": 100} - fpebc = FeatureProcessedEmbeddingBagCollection( + fpebc1 = FeatureProcessedEmbeddingBagCollection( EmbeddingBagCollection( tables=[tb1_config, tb2_config], is_weighted=True, @@ -162,8 +162,18 @@ def forward( max_feature_lengths=max_feature_lengths, ), ) + fpebc2 = FeatureProcessedEmbeddingBagCollection( + EmbeddingBagCollection( + tables=[tb1_config, tb3_config], + is_weighted=True, + ), + { + "f1": PositionWeightedModule(max_feature_length=10), + "f3": PositionWeightedModule(max_feature_length=20), + }, + ) - model = Model(ebc, fpebc) + model = Model(ebc, fpebc1, fpebc2) return model @@ -194,12 +204,16 @@ def test_serialize_deserialize_ebc(self) -> None: for i, tensor in enumerate(ep_output): self.assertEqual(eager_out[i].shape, tensor.shape) - # Only 2 custom op registered, as dimensions of ebc are same - 
self.assertEqual(len(operator_registry_state.op_registry_schema), 2) + # Should have 3 custom op registered, as dimensions of ebc are same, + # and two fpEBCs have different dims + self.assertEqual(len(operator_registry_state.op_registry_schema), 3) total_dim_ebc = sum(model.ebc1._lengths_per_embedding) - total_dim_fpebc = sum( - model.fpebc._embedding_bag_collection._lengths_per_embedding + total_dim_fpebc1 = sum( + model.fpebc1._embedding_bag_collection._lengths_per_embedding + ) + total_dim_fpebc2 = sum( + model.fpebc2._embedding_bag_collection._lengths_per_embedding ) # Check if custom op is registered with the correct name # EmbeddingBagCollection type and total dim @@ -208,35 +222,79 @@ def test_serialize_deserialize_ebc(self) -> None: in operator_registry_state.op_registry_schema ) self.assertTrue( - f"EmbeddingBagCollection_{total_dim_fpebc}" + f"EmbeddingBagCollection_{total_dim_fpebc1}" + in operator_registry_state.op_registry_schema + ) + self.assertTrue( + f"EmbeddingBagCollection_{total_dim_fpebc2}" in operator_registry_state.op_registry_schema ) # Deserialize EBC deserialized_model = deserialize_embedding_modules(ep, JsonSerializer) + # check EBC config for i in range(5): ebc_name = f"ebc{i + 1}" - assert isinstance( + self.assertIsInstance( getattr(deserialized_model, ebc_name), EmbeddingBagCollection ) - for deserialized_config, org_config in zip( + for deserialized, orginal in zip( getattr(deserialized_model, ebc_name).embedding_bag_configs(), getattr(model, ebc_name).embedding_bag_configs(), ): - assert deserialized_config.name == org_config.name - assert deserialized_config.embedding_dim == org_config.embedding_dim - assert deserialized_config.num_embeddings, org_config.num_embeddings - assert deserialized_config.feature_names, org_config.feature_names + self.assertEqual(deserialized.name, orginal.name) + self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim) + self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings) + self.assertEqual(deserialized.feature_names, orginal.feature_names) + + # check FPEBC config + for i in range(2): + fpebc_name = f"fpebc{i + 1}" + assert isinstance( + getattr(deserialized_model, fpebc_name), + FeatureProcessedEmbeddingBagCollection, + ) + + deserialized_fp = getattr( + deserialized_model, fpebc_name + )._feature_processors + original_fp = getattr(model, fpebc_name)._feature_processors + if isinstance(original_fp, nn.ModuleDict): + for deserialized, orginal in zip( + deserialized_fp.values(), original_fp.values() + ): + self.assertDictEqual( + deserialized.get_init_kwargs(), orginal.get_init_kwargs() + ) + else: + self.assertDictEqual( + deserialized_fp.get_init_kwargs(), original_fp.get_init_kwargs() + ) + + for deserialized, orginal in zip( + getattr( + deserialized_model, fpebc_name + )._embedding_bag_collection.embedding_bag_configs(), + getattr( + model, fpebc_name + )._embedding_bag_collection.embedding_bag_configs(), + ): + self.assertEqual(deserialized.name, orginal.name) + self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim) + self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings) + self.assertEqual(deserialized.feature_names, orginal.feature_names) deserialized_model.load_state_dict(model.state_dict()) - # Run forward on deserialized model + + # Run forward on deserialized model and compare the output deserialized_out = deserialized_model(id_list_features) - for i, tensor in enumerate(deserialized_out): - assert eager_out[i].shape == tensor.shape - assert 
torch.allclose(eager_out[i], tensor)
+        self.assertEqual(len(deserialized_out), len(eager_out))
+        for deserialized, original in zip(deserialized_out, eager_out):
+            self.assertEqual(deserialized.shape, original.shape)
+            self.assertTrue(torch.allclose(deserialized, original))
 
     def test_dynamic_shape_ebc(self) -> None:
         model = self.generate_model()
diff --git a/torchrec/ir/utils.py b/torchrec/ir/utils.py
index eb72daf99..4f68682f5 100644
--- a/torchrec/ir/utils.py
+++ b/torchrec/ir/utils.py
@@ -153,14 +153,14 @@ def deserialize_embedding_modules(
 
     Returns the unflattened ExportedProgram with the deserialized modules.
     """
-    model = torch.export.unflatten(ep)
     module_type_dict: Dict[str, str] = {}
     for node in ep.graph.nodes:
         if "nn_module_stack" in node.meta:
             for fqn, type_name in node.meta["nn_module_stack"].values():
                 # Only get the module type name, not the full type name
                 module_type_dict[fqn] = type_name.split(".")[-1]
 
+    model = torch.export.unflatten(ep)
     fqn_to_new_module: Dict[str, nn.Module] = {}
     generator = model.named_modules()
     fqn, root = _next_or_none(generator)
diff --git a/torchrec/modules/tests/test_embedding_modules.py b/torchrec/modules/tests/test_embedding_modules.py
index 934bdb229..6ae2b3ec8 100644
--- a/torchrec/modules/tests/test_embedding_modules.py
+++ b/torchrec/modules/tests/test_embedding_modules.py
@@ -227,7 +227,7 @@ def test_device(self) -> None:
         self.assertEqual(torch.device("cpu"), ebc.embedding_bags["t1"].weight.device)
         self.assertEqual(torch.device("cpu"), ebc.device)
 
-    def test_exporting(self) -> None:
+    def test_ir_export(self) -> None:
         class MyModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
@@ -296,6 +296,13 @@ def forward(
             "Shoulde be exact 2 EmbeddingBagCollection nodes in the exported graph",
         )
 
+        # export_program's module should produce the same output shape
+        output = m(features)
+        exported = ep.module()(features)
+        self.assertEqual(
+            output.size(), exported.size(), "Output should match exported output"
+        )
+
 
 class EmbeddingCollectionTest(unittest.TestCase):
     def test_forward(self) -> None:
diff --git a/torchrec/modules/tests/test_fp_embedding_modules.py b/torchrec/modules/tests/test_fp_embedding_modules.py
index 688c56be5..ce973ddd8 100644
--- a/torchrec/modules/tests/test_fp_embedding_modules.py
+++ b/torchrec/modules/tests/test_fp_embedding_modules.py
@@ -21,22 +21,12 @@
     PositionWeightedModuleCollection,
 )
 from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
-from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
 
 
 class PositionWeightedModuleEmbeddingBagCollectionTest(unittest.TestCase):
-    def test_position_weighted_module_ebc(self) -> None:
-        # 0       1        2  <-- batch
-        # 0 [0,1] None    [2]
-        # 1 [3]    [4]    [5,6,7]
-        # ^
-        # feature
-        features = KeyedJaggedTensor.from_offsets_sync(
-            keys=["f1", "f2"],
-            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
-            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
-        )
+    def generate_fp_ebc(self) -> FeatureProcessedEmbeddingBagCollection:
         ebc = EmbeddingBagCollection(
             tables=[
                 EmbeddingBagConfig(
@@ -52,8 +42,21 @@ def test_position_weighted_module_ebc(self) -> None:
             "f1": cast(FeatureProcessor, PositionWeightedModule(max_feature_length=10)),
             "f2": cast(FeatureProcessor, PositionWeightedModule(max_feature_length=5)),
         }
+        return FeatureProcessedEmbeddingBagCollection(ebc, feature_processors)
 
-        fp_ebc = FeatureProcessedEmbeddingBagCollection(ebc, feature_processors)
+    def
test_position_weighted_module_ebc(self) -> None: + # 0 1 2 <-- batch + # 0 [0,1] None [2] + # 1 [3] [4] [5,6,7] + # ^ + # feature + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + + fp_ebc = self.generate_fp_ebc() pooled_embeddings = fp_ebc(features) self.assertEqual(pooled_embeddings.keys(), ["f1", "f2"]) @@ -86,6 +89,53 @@ def test_position_weighted_module_ebc_with_excessive_features(self) -> None: offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8, 9, 9, 9]), ) + fp_ebc = self.generate_fp_ebc() + + pooled_embeddings = fp_ebc(features) + self.assertEqual(pooled_embeddings.keys(), ["f1", "f2"]) + self.assertEqual(pooled_embeddings.values().size(), (3, 16)) + self.assertEqual(pooled_embeddings.offset_per_key(), [0, 8, 16]) + + def test_ir_export(self) -> None: + class MyModule(torch.nn.Module): + def __init__(self, fp_ebc) -> None: + super().__init__() + self._fp_ebc = fp_ebc + + def forward(self, features: KeyedJaggedTensor) -> KeyedTensor: + return self._fp_ebc(features) + + m = MyModule(self.generate_fp_ebc()) + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2", "f3"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8, 9, 9, 9]), + ) + ep = torch.export.export( + m, + (features,), + {}, + strict=False, + ) + self.assertEqual( + sum(n.name.startswith("_embedding_bag") for n in ep.graph.nodes), + 0, + ) + self.assertEqual( + sum(n.name.startswith("embedding_bag_collection") for n in ep.graph.nodes), + 1, + "Shoulde be exact 1 EBC nodes in the exported graph", + ) + + # export_program's module should produce the same output shape + output = m(features) + exported = ep.module()(features) + self.assertEqual(output.keys(), exported.keys()) + self.assertEqual(output.values().size(), exported.values().size()) + + +class PositionWeightedModuleCollectionEmbeddingBagCollectionTest(unittest.TestCase): + def generate_fp_ebc(self) -> FeatureProcessedEmbeddingBagCollection: ebc = EmbeddingBagCollection( tables=[ EmbeddingBagConfig( @@ -97,20 +147,11 @@ def test_position_weighted_module_ebc_with_excessive_features(self) -> None: ], is_weighted=True, ) - feature_processors = { - "f1": cast(FeatureProcessor, PositionWeightedModule(max_feature_length=10)), - "f2": cast(FeatureProcessor, PositionWeightedModule(max_feature_length=5)), - } - - fp_ebc = FeatureProcessedEmbeddingBagCollection(ebc, feature_processors) - - pooled_embeddings = fp_ebc(features) - self.assertEqual(pooled_embeddings.keys(), ["f1", "f2"]) - self.assertEqual(pooled_embeddings.values().size(), (3, 16)) - self.assertEqual(pooled_embeddings.offset_per_key(), [0, 8, 16]) + return FeatureProcessedEmbeddingBagCollection( + ebc, PositionWeightedModuleCollection({"f1": 10, "f2": 10}) + ) -class PositionWeightedModuleCollectionEmbeddingBagCollectionTest(unittest.TestCase): def test_position_weighted_collection_module_ebc(self) -> None: # 0 1 2 <-- batch # 0 [0,1] None [2] @@ -123,21 +164,7 @@ def test_position_weighted_collection_module_ebc(self) -> None: offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), ) - ebc = EmbeddingBagCollection( - tables=[ - EmbeddingBagConfig( - name="t1", embedding_dim=8, num_embeddings=16, feature_names=["f1"] - ), - EmbeddingBagConfig( - name="t2", embedding_dim=8, num_embeddings=16, feature_names=["f2"] - ), - ], - is_weighted=True, - ) - - fp_ebc = FeatureProcessedEmbeddingBagCollection( - ebc, PositionWeightedModuleCollection({"f1": 10, "f2": 
10}) - ) + fp_ebc = self.generate_fp_ebc() pooled_embeddings = fp_ebc(features) self.assertEqual(pooled_embeddings.keys(), ["f1", "f2"]) @@ -155,3 +182,43 @@ def test_position_weighted_collection_module_ebc(self) -> None: pooled_embeddings_gm_script.offset_per_key(), pooled_embeddings.offset_per_key(), ) + + def test_ir_export(self) -> None: + class MyModule(torch.nn.Module): + def __init__(self, fp_ebc) -> None: + super().__init__() + self._fp_ebc = fp_ebc + + def forward(self, features: KeyedJaggedTensor) -> KeyedTensor: + return self._fp_ebc(features) + + m = MyModule(self.generate_fp_ebc()) + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2", "f3"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8, 9, 9, 9]), + ) + ep = torch.export.export( + m, + (features,), + {}, + strict=False, + ) + self.assertEqual( + sum(n.name.startswith("_embedding_bag") for n in ep.graph.nodes), + 0, + ) + self.assertEqual( + sum( + n.name.startswith("feature_processed_embedding_bag_collection") + for n in ep.graph.nodes + ), + 1, + "Shoulde be exact 1 FPEBC nodes in the exported graph", + ) + + # export_program's module should produce the same output shape + output = m(features) + exported = ep.module()(features) + self.assertEqual(output.keys(), exported.keys()) + self.assertEqual(output.values().size(), exported.values().size())
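
Note on the metadata format shared by the serializers added above: each one flattens a small dataclass to a JSON string and stores it as a 1-D uint8 buffer tensor, and deserialization reverses the process. Below is a minimal, self-contained sketch of that round trip in plain Python outside TorchRec, reusing the FPEBCMetadata fields from the schema in this patch; the helper names to_tensor/from_tensor are illustrative only and are not part of the TorchRec API.

import json
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class FPEBCMetadata:
    is_fp_collection: bool
    feature_list: List[str]


def to_tensor(metadata: FPEBCMetadata) -> torch.Tensor:
    # dataclass -> json string -> bytes -> 1-D uint8 tensor,
    # mirroring the serialize() methods above
    return torch.frombuffer(json.dumps(metadata.__dict__).encode(), dtype=torch.uint8)


def from_tensor(buffer: torch.Tensor) -> FPEBCMetadata:
    # 1-D uint8 tensor -> raw bytes -> json -> dataclass kwargs,
    # mirroring the deserialize() methods above
    raw_bytes = buffer.numpy().tobytes()
    return FPEBCMetadata(**json.loads(raw_bytes.decode()))


original = FPEBCMetadata(is_fp_collection=False, feature_list=["f1", "f3"])
assert from_tensor(to_tensor(original)) == original

The tensor representation is presumably what allows the metadata to be carried inside the exported IR alongside the module's other buffers.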