From 10c247fcbf5908ca3f06a1d7c7e3575f012b30e8 Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Mon, 1 Jul 2024 12:17:47 -0700
Subject: [PATCH] register custom_op for fpEBC (#2067)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2067

# context
* convert `FeatureProcessedEmbeddingBagCollection` to a custom op in IR export
* add serialization and deserialization functions for FPEBC
* add an API to the `FeatureProcessorInterface` to export the parameters necessary to create an instance
* use this API (`get_init_kwargs`) in the serialize and deserialize functions to flatten and unflatten the feature processor

# details
1. Added the `FPEBCMetadata` schema for FP_EBC, which uses an `fp_json` string to store the necessary parameters
2. Added `FPEBCJsonSerializer`, which converts the init_kwargs to a JSON string and stores it in the `fp_json` field of the metadata
3. Added an fqn check against `serialized_fqns`, so that when a higher-level module is serialized, its lower-level modules can be skipped (they are already included in the higher-level module)
4. Added an API called `get_init_kwargs` to `FeatureProcessorsCollection` and `FeatureProcessor`, and used a `FeatureProcessorNameMap` to map the class name to the feature processor class
5. Added a `_non_strict_exporting_forward` function to FPEBC so that non-strict IR export goes through the custom_op logic

Reviewed By: PaulZhang12

Differential Revision: D57829276
---
 torchrec/ir/schema.py                        |  18 ++-
 torchrec/ir/serializer.py                    | 122 ++++++++++++++-
 torchrec/ir/tests/test_serializer.py         | 113 +++++++++-----
 .../modules/tests/test_embedding_modules.py  |   9 +-
 .../tests/test_fp_embedding_modules.py       | 144 +++++++++++++-----
 5 files changed, 327 insertions(+), 79 deletions(-)

diff --git a/torchrec/ir/schema.py b/torchrec/ir/schema.py
index f0fbf706e..0560812bd 100644
--- a/torchrec/ir/schema.py
+++ b/torchrec/ir/schema.py
@@ -8,7 +8,7 @@
 # pyre-strict
 
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
 from torchrec.modules.embedding_configs import DataType, PoolingType
 
@@ -32,3 +32,19 @@ class EBCMetadata:
     tables: List[EmbeddingBagConfigMetadata]
     is_weighted: bool
     device: Optional[str]
+
+
+@dataclass
+class FPEBCMetadata:
+    is_fp_collection: bool
+    features: List[str]
+
+
+@dataclass
+class PositionWeightedModuleMetadata:
+    max_feature_length: int
+
+
+@dataclass
+class PositionWeightedModuleCollectionMetadata:
+    max_feature_lengths: List[Tuple[str, int]]
diff --git a/torchrec/ir/serializer.py b/torchrec/ir/serializer.py
index ffc1fe69a..eb98317ca 100644
--- a/torchrec/ir/serializer.py
+++ b/torchrec/ir/serializer.py
@@ -14,11 +14,24 @@
 import torch
 
 from torch import nn
-from torchrec.ir.schema import EBCMetadata, EmbeddingBagConfigMetadata
+from torchrec.ir.schema import (
+    EBCMetadata,
+    EmbeddingBagConfigMetadata,
+    FPEBCMetadata,
+    PositionWeightedModuleCollectionMetadata,
+    PositionWeightedModuleMetadata,
+)
 
 from torchrec.ir.types import SerializerInterface
 from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig, PoolingType
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
+from torchrec.modules.feature_processor_ import (
+    FeatureProcessor,
+    FeatureProcessorsCollection,
+    PositionWeightedModule,
+    PositionWeightedModuleCollection,
+)
+from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -196,3 +209,110 @@ def deserialize_from_dict(
JsonSerializer.module_to_serializer_cls["EmbeddingBagCollection"] = EBCJsonSerializer + + +class PWMJsonSerializer(JsonSerializer): + _module_cls = PositionWeightedModule + + @classmethod + def serialize_to_dict(cls, module: nn.Module) -> Dict[str, Any]: + metadata = PositionWeightedModuleMetadata( + max_feature_length=module.position_weight.shape[0], + ) + return metadata.__dict__ + + @classmethod + def deserialize_from_dict( + cls, + metadata_dict: Dict[str, Any], + device: Optional[torch.device] = None, + unflatten_ep: Optional[nn.Module] = None, + ) -> nn.Module: + metadata = PositionWeightedModuleMetadata(**metadata_dict) + return PositionWeightedModule(metadata.max_feature_length, device) + + +JsonSerializer.module_to_serializer_cls["PositionWeightedModule"] = PWMJsonSerializer + + +class PWMCJsonSerializer(JsonSerializer): + _module_cls = PositionWeightedModuleCollection + + @classmethod + def serialize_to_dict(cls, module: nn.Module) -> Dict[str, Any]: + metadata = PositionWeightedModuleCollectionMetadata( + max_feature_lengths=[ # convert to list of tuples to preserve the order + (feature, len) for feature, len in module.max_feature_lengths.items() + ], + ) + return metadata.__dict__ + + @classmethod + def deserialize_from_dict( + cls, + metadata_dict: Dict[str, Any], + device: Optional[torch.device] = None, + unflatten_ep: Optional[nn.Module] = None, + ) -> nn.Module: + metadata = PositionWeightedModuleCollectionMetadata(**metadata_dict) + max_feature_lengths = { + feature: len for feature, len in metadata.max_feature_lengths + } + return PositionWeightedModuleCollection(max_feature_lengths, device) + + +JsonSerializer.module_to_serializer_cls["PositionWeightedModuleCollection"] = ( + PWMCJsonSerializer +) + + +class FPEBCJsonSerializer(JsonSerializer): + _module_cls = FeatureProcessedEmbeddingBagCollection + _children = ["_feature_processors", "_embedding_bag_collection"] + + @classmethod + def serialize_to_dict( + cls, + module: nn.Module, + ) -> Dict[str, Any]: + if isinstance(module._feature_processors, FeatureProcessorsCollection): + metadata = FPEBCMetadata( + is_fp_collection=True, + features=[], + ) + else: + metadata = FPEBCMetadata( + is_fp_collection=False, + features=list(module._feature_processors.keys()), + ) + return metadata.__dict__ + + @classmethod + def deserialize_from_dict( + cls, + metadata_dict: Dict[str, Any], + device: Optional[torch.device] = None, + unflatten_ep: Optional[nn.Module] = None, + ) -> nn.Module: + metadata = FPEBCMetadata(**metadata_dict) + assert unflatten_ep is not None + if metadata.is_fp_collection: + feature_processors = unflatten_ep._feature_processors + assert isinstance(feature_processors, FeatureProcessorsCollection) + else: + feature_processors: dict[str, FeatureProcessor] = {} + for feature in metadata.features: + fp = getattr(unflatten_ep._feature_processors, feature) + assert isinstance(fp, FeatureProcessor) + feature_processors[feature] = fp + ebc = unflatten_ep._embedding_bag_collection + assert isinstance(ebc, EmbeddingBagCollection) + return FeatureProcessedEmbeddingBagCollection( + ebc, + feature_processors, + ) + + +JsonSerializer.module_to_serializer_cls["FeatureProcessedEmbeddingBagCollection"] = ( + FPEBCJsonSerializer +) diff --git a/torchrec/ir/tests/test_serializer.py b/torchrec/ir/tests/test_serializer.py index 19d75b28b..e9e0865c8 100644 --- a/torchrec/ir/tests/test_serializer.py +++ b/torchrec/ir/tests/test_serializer.py @@ -25,7 +25,10 @@ from torchrec.modules.embedding_configs import EmbeddingBagConfig from 
torchrec.modules.embedding_modules import EmbeddingBagCollection -from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection +from torchrec.modules.feature_processor_ import ( + PositionWeightedModule, + PositionWeightedModuleCollection, +) from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection from torchrec.modules.utils import operator_registry_state from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor @@ -90,16 +93,18 @@ def deserialize_from_dict( class TestJsonSerializer(unittest.TestCase): + # in the model we have 5 duplicated EBCs, 1 fpEBC with fpCollection, and 1 fpEBC with fpDict def generate_model(self) -> nn.Module: class Model(nn.Module): - def __init__(self, ebc, fpebc): + def __init__(self, ebc, fpebc1, fpebc2): super().__init__() self.ebc1 = ebc self.ebc2 = copy.deepcopy(ebc) self.ebc3 = copy.deepcopy(ebc) self.ebc4 = copy.deepcopy(ebc) self.ebc5 = copy.deepcopy(ebc) - self.fpebc = fpebc + self.fpebc1 = fpebc1 + self.fpebc2 = fpebc2 def forward( self, @@ -111,22 +116,16 @@ def forward( kt4 = self.ebc4(features) kt5 = self.ebc5(features) - fpebc_res = self.fpebc(features) - ebc_kt_vals = [kt.values() for kt in [kt1, kt2, kt3, kt4, kt5]] - sparse_arch_vals = sum(ebc_kt_vals) - sparse_arch_res = KeyedTensor( - keys=kt1.keys(), - values=sparse_arch_vals, - length_per_key=kt1.length_per_key(), - ) - - return KeyedTensor.regroup( - [sparse_arch_res, fpebc_res], [["f1"], ["f2", "f3"]] - ) + fpebc1_res = self.fpebc1(features) + fpebc2_res = self.fpebc2(features) + res: List[torch.Tensor] = [] + for kt in [kt1, kt2, kt3, kt4, kt5, fpebc1_res, fpebc2_res]: + res.extend(KeyedTensor.regroup([kt], [[key] for key in kt.keys()])) + return res tb1_config = EmbeddingBagConfig( name="t1", - embedding_dim=4, + embedding_dim=3, num_embeddings=10, feature_names=["f1"], ) @@ -138,7 +137,7 @@ def forward( ) tb3_config = EmbeddingBagConfig( name="t3", - embedding_dim=4, + embedding_dim=5, num_embeddings=10, feature_names=["f3"], ) @@ -149,7 +148,7 @@ def forward( ) max_feature_lengths = {"f1": 100, "f2": 100} - fpebc = FeatureProcessedEmbeddingBagCollection( + fpebc1 = FeatureProcessedEmbeddingBagCollection( EmbeddingBagCollection( tables=[tb1_config, tb2_config], is_weighted=True, @@ -158,8 +157,18 @@ def forward( max_feature_lengths=max_feature_lengths, ), ) + fpebc2 = FeatureProcessedEmbeddingBagCollection( + EmbeddingBagCollection( + tables=[tb1_config, tb3_config], + is_weighted=True, + ), + { + "f1": PositionWeightedModule(max_feature_length=10), + "f3": PositionWeightedModule(max_feature_length=20), + }, + ) - model = Model(ebc, fpebc) + model = Model(ebc, fpebc1, fpebc2) return model @@ -190,12 +199,16 @@ def test_serialize_deserialize_ebc(self) -> None: for i, tensor in enumerate(ep_output): self.assertEqual(eager_out[i].shape, tensor.shape) - # Only 2 custom op registered, as dimensions of ebc are same - self.assertEqual(len(operator_registry_state.op_registry_schema), 2) + # Should have 3 custom op registered, as dimensions of ebc are same, + # and two fpEBCs have different dims + self.assertEqual(len(operator_registry_state.op_registry_schema), 3) total_dim_ebc = sum(model.ebc1._lengths_per_embedding) - total_dim_fpebc = sum( - model.fpebc._embedding_bag_collection._lengths_per_embedding + total_dim_fpebc1 = sum( + model.fpebc1._embedding_bag_collection._lengths_per_embedding + ) + total_dim_fpebc2 = sum( + model.fpebc2._embedding_bag_collection._lengths_per_embedding ) # Check if custom op is registered with the 
correct name # EmbeddingBagCollection type and total dim @@ -204,35 +217,63 @@ def test_serialize_deserialize_ebc(self) -> None: in operator_registry_state.op_registry_schema ) self.assertTrue( - f"EmbeddingBagCollection_{total_dim_fpebc}" + f"EmbeddingBagCollection_{total_dim_fpebc1}" + in operator_registry_state.op_registry_schema + ) + self.assertTrue( + f"EmbeddingBagCollection_{total_dim_fpebc2}" in operator_registry_state.op_registry_schema ) # Deserialize EBC deserialized_model = deserialize_embedding_modules(ep, JsonSerializer) + # check EBC config for i in range(5): ebc_name = f"ebc{i + 1}" - assert isinstance( + self.assertIsInstance( getattr(deserialized_model, ebc_name), EmbeddingBagCollection ) - for deserialized_config, org_config in zip( + for deserialized, orginal in zip( getattr(deserialized_model, ebc_name).embedding_bag_configs(), getattr(model, ebc_name).embedding_bag_configs(), ): - assert deserialized_config.name == org_config.name - assert deserialized_config.embedding_dim == org_config.embedding_dim - assert deserialized_config.num_embeddings, org_config.num_embeddings - assert deserialized_config.feature_names, org_config.feature_names + self.assertEqual(deserialized.name, orginal.name) + self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim) + self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings) + self.assertEqual(deserialized.feature_names, orginal.feature_names) + + # check FPEBC config + for i in range(2): + fpebc_name = f"fpebc{i + 1}" + assert isinstance( + getattr(deserialized_model, fpebc_name), + FeatureProcessedEmbeddingBagCollection, + ) + + for deserialized, orginal in zip( + getattr( + deserialized_model, fpebc_name + )._embedding_bag_collection.embedding_bag_configs(), + getattr( + model, fpebc_name + )._embedding_bag_collection.embedding_bag_configs(), + ): + self.assertEqual(deserialized.name, orginal.name) + self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim) + self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings) + self.assertEqual(deserialized.feature_names, orginal.feature_names) deserialized_model.load_state_dict(model.state_dict()) - # Run forward on deserialized model + + # Run forward on deserialized model and compare the output deserialized_out = deserialized_model(id_list_features) - for i, tensor in enumerate(deserialized_out): - assert eager_out[i].shape == tensor.shape - assert torch.allclose(eager_out[i], tensor) + self.assertEqual(len(deserialized_out), len(eager_out)) + for deserialized, orginal in zip(deserialized_out, eager_out): + self.assertEqual(deserialized.shape, orginal.shape) + self.assertTrue(torch.allclose(deserialized, orginal)) def test_dynamic_shape_ebc(self) -> None: model = self.generate_model() @@ -259,7 +300,7 @@ def test_dynamic_shape_ebc(self) -> None: dynamic_shapes=collection.dynamic_shapes(model, (feature1,)), strict=False, # Allows KJT to not be unflattened and run a forward on unflattened EP - preserve_module_call_signature=(tuple(sparse_fqns)), + preserve_module_call_signature=tuple(sparse_fqns), ) # Run forward on ExportedProgram @@ -271,8 +312,8 @@ def test_dynamic_shape_ebc(self) -> None: # Deserialize EBC deserialized_model = deserialize_embedding_modules(ep, JsonSerializer) - deserialized_model.load_state_dict(model.state_dict()) + # Run forward on deserialized model deserialized_out = deserialized_model(feature2) diff --git a/torchrec/modules/tests/test_embedding_modules.py b/torchrec/modules/tests/test_embedding_modules.py index 
934bdb229..6ae2b3ec8 100644 --- a/torchrec/modules/tests/test_embedding_modules.py +++ b/torchrec/modules/tests/test_embedding_modules.py @@ -227,7 +227,7 @@ def test_device(self) -> None: self.assertEqual(torch.device("cpu"), ebc.embedding_bags["t1"].weight.device) self.assertEqual(torch.device("cpu"), ebc.device) - def test_exporting(self) -> None: + def test_ir_export(self) -> None: class MyModule(torch.nn.Module): def __init__(self): super().__init__() @@ -296,6 +296,13 @@ def forward( "Shoulde be exact 2 EmbeddingBagCollection nodes in the exported graph", ) + # export_program's module should produce the same output shape + output = m(features) + exported = ep.module()(features) + self.assertEqual( + output.size(), exported.size(), "Output should match exported output" + ) + class EmbeddingCollectionTest(unittest.TestCase): def test_forward(self) -> None: diff --git a/torchrec/modules/tests/test_fp_embedding_modules.py b/torchrec/modules/tests/test_fp_embedding_modules.py index 688c56be5..ff03eb3c2 100644 --- a/torchrec/modules/tests/test_fp_embedding_modules.py +++ b/torchrec/modules/tests/test_fp_embedding_modules.py @@ -21,22 +21,12 @@ PositionWeightedModuleCollection, ) from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor class PositionWeightedModuleEmbeddingBagCollectionTest(unittest.TestCase): - def test_position_weighted_module_ebc(self) -> None: - # 0 1 2 <-- batch - # 0 [0,1] None [2] - # 1 [3] [4] [5,6,7] - # ^ - # feature - features = KeyedJaggedTensor.from_offsets_sync( - keys=["f1", "f2"], - values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), - offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), - ) + def generate_fp_ebc(self) -> FeatureProcessedEmbeddingBagCollection: ebc = EmbeddingBagCollection( tables=[ EmbeddingBagConfig( @@ -52,8 +42,21 @@ def test_position_weighted_module_ebc(self) -> None: "f1": cast(FeatureProcessor, PositionWeightedModule(max_feature_length=10)), "f2": cast(FeatureProcessor, PositionWeightedModule(max_feature_length=5)), } + return FeatureProcessedEmbeddingBagCollection(ebc, feature_processors) - fp_ebc = FeatureProcessedEmbeddingBagCollection(ebc, feature_processors) + def test_position_weighted_module_ebc(self) -> None: + # 0 1 2 <-- batch + # 0 [0,1] None [2] + # 1 [3] [4] [5,6,7] + # ^ + # feature + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + + fp_ebc = self.generate_fp_ebc() pooled_embeddings = fp_ebc(features) self.assertEqual(pooled_embeddings.keys(), ["f1", "f2"]) @@ -86,6 +89,53 @@ def test_position_weighted_module_ebc_with_excessive_features(self) -> None: offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8, 9, 9, 9]), ) + fp_ebc = self.generate_fp_ebc() + + pooled_embeddings = fp_ebc(features) + self.assertEqual(pooled_embeddings.keys(), ["f1", "f2"]) + self.assertEqual(pooled_embeddings.values().size(), (3, 16)) + self.assertEqual(pooled_embeddings.offset_per_key(), [0, 8, 16]) + + def test_ir_export(self) -> None: + class MyModule(torch.nn.Module): + def __init__(self, fp_ebc) -> None: + super().__init__() + self._fp_ebc = fp_ebc + + def forward(self, features: KeyedJaggedTensor) -> KeyedTensor: + return self._fp_ebc(features) + + m = MyModule(self.generate_fp_ebc()) + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2", "f3"], + 
values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8, 9, 9, 9]), + ) + ep = torch.export.export( + m, + (features,), + {}, + strict=False, + ) + self.assertEqual( + sum(n.name.startswith("_embedding_bag") for n in ep.graph.nodes), + 0, + ) + self.assertEqual( + sum(n.name.startswith("embedding_bag_collection") for n in ep.graph.nodes), + 1, + "Shoulde be exact 1 EBC nodes in the exported graph", + ) + + # export_program's module should produce the same output shape + output = m(features) + exported = ep.module()(features) + self.assertEqual(output.keys(), exported.keys()) + self.assertEqual(output.values().size(), exported.values().size()) + + +class PositionWeightedModuleCollectionEmbeddingBagCollectionTest(unittest.TestCase): + def generate_fp_ebc(self) -> FeatureProcessedEmbeddingBagCollection: ebc = EmbeddingBagCollection( tables=[ EmbeddingBagConfig( @@ -97,20 +147,11 @@ def test_position_weighted_module_ebc_with_excessive_features(self) -> None: ], is_weighted=True, ) - feature_processors = { - "f1": cast(FeatureProcessor, PositionWeightedModule(max_feature_length=10)), - "f2": cast(FeatureProcessor, PositionWeightedModule(max_feature_length=5)), - } - - fp_ebc = FeatureProcessedEmbeddingBagCollection(ebc, feature_processors) - - pooled_embeddings = fp_ebc(features) - self.assertEqual(pooled_embeddings.keys(), ["f1", "f2"]) - self.assertEqual(pooled_embeddings.values().size(), (3, 16)) - self.assertEqual(pooled_embeddings.offset_per_key(), [0, 8, 16]) + return FeatureProcessedEmbeddingBagCollection( + ebc, PositionWeightedModuleCollection({"f1": 10, "f2": 10}) + ) -class PositionWeightedModuleCollectionEmbeddingBagCollectionTest(unittest.TestCase): def test_position_weighted_collection_module_ebc(self) -> None: # 0 1 2 <-- batch # 0 [0,1] None [2] @@ -123,21 +164,7 @@ def test_position_weighted_collection_module_ebc(self) -> None: offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), ) - ebc = EmbeddingBagCollection( - tables=[ - EmbeddingBagConfig( - name="t1", embedding_dim=8, num_embeddings=16, feature_names=["f1"] - ), - EmbeddingBagConfig( - name="t2", embedding_dim=8, num_embeddings=16, feature_names=["f2"] - ), - ], - is_weighted=True, - ) - - fp_ebc = FeatureProcessedEmbeddingBagCollection( - ebc, PositionWeightedModuleCollection({"f1": 10, "f2": 10}) - ) + fp_ebc = self.generate_fp_ebc() pooled_embeddings = fp_ebc(features) self.assertEqual(pooled_embeddings.keys(), ["f1", "f2"]) @@ -155,3 +182,40 @@ def test_position_weighted_collection_module_ebc(self) -> None: pooled_embeddings_gm_script.offset_per_key(), pooled_embeddings.offset_per_key(), ) + + def test_ir_export(self) -> None: + class MyModule(torch.nn.Module): + def __init__(self, fp_ebc) -> None: + super().__init__() + self._fp_ebc = fp_ebc + + def forward(self, features: KeyedJaggedTensor) -> KeyedTensor: + return self._fp_ebc(features) + + m = MyModule(self.generate_fp_ebc()) + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2", "f3"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8, 9, 9, 9]), + ) + ep = torch.export.export( + m, + (features,), + {}, + strict=False, + ) + self.assertEqual( + sum(n.name.startswith("_embedding_bag") for n in ep.graph.nodes), + 0, + ) + self.assertEqual( + sum(n.name.startswith("embedding_bag_collection") for n in ep.graph.nodes), + 1, + "Shoulde be exact 1 EBC nodes in the exported graph", + ) + + # export_program's module should produce the same output shape + output = m(features) + 
exported = ep.module()(features) + self.assertEqual(output.keys(), exported.keys()) + self.assertEqual(output.values().size(), exported.values().size())
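
# usage sketch
A minimal example of the new export path, distilled from the `test_ir_export` tests above. The toy table configs and the `MyModule` wrapper are illustrative only; the node-count check mirrors the assertions in the tests rather than defining the API.

```python
import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor


class MyModule(torch.nn.Module):
    """Thin wrapper around an FP-EBC, as in the tests."""

    def __init__(self, fp_ebc: FeatureProcessedEmbeddingBagCollection) -> None:
        super().__init__()
        self._fp_ebc = fp_ebc

    def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
        return self._fp_ebc(features)


# Toy FP-EBC with a PositionWeightedModuleCollection feature processor.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(name="t1", embedding_dim=8, num_embeddings=16, feature_names=["f1"]),
        EmbeddingBagConfig(name="t2", embedding_dim=8, num_embeddings=16, feature_names=["f2"]),
    ],
    is_weighted=True,
)
m = MyModule(
    FeatureProcessedEmbeddingBagCollection(
        ebc, PositionWeightedModuleCollection({"f1": 10, "f2": 10})
    )
)

features = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)

# Non-strict export takes the _non_strict_exporting_forward path, so the whole
# FP-EBC is captured as one registered custom op instead of being traced into
# embedding-bag kernels.
ep = torch.export.export(m, (features,), {}, strict=False)
assert sum(n.name.startswith("embedding_bag_collection") for n in ep.graph.nodes) == 1
assert sum(n.name.startswith("_embedding_bag") for n in ep.graph.nodes) == 0

# The exported program still produces pooled embeddings of the same shape.
out = ep.module()(features)
print(out.keys(), out.values().size())
```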