register custom_op for fpEBC (#2067)
Summary:

# context
* convert `FeatureProcessedEmbeddingBagCollection` to a custom op in IR export
* add serialization and deserialization functions for FPEBC
* add an API for the `FeatureProcessorInterface` to export the parameters necessary to create an instance
* use this API (`get_init_kwargs`) in the serialize and deserialize functions to flatten and unflatten the feature processor (see the sketch after this list)
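
For reference, a minimal sketch of what such a `get_init_kwargs` API could look like on a feature processor. The class below is a hypothetical stand-in, not the torchrec implementation; only the `max_feature_length` parameter is taken from this diff.

```python
from typing import Any, Dict

import torch
from torch import nn


class PositionWeightedModuleSketch(nn.Module):
    """Hypothetical feature processor that can report the kwargs needed to re-create it."""

    def __init__(self, max_feature_length: int) -> None:
        super().__init__()
        self.max_feature_length = max_feature_length
        # learned per-position weights, one per possible position within a feature
        self.position_weight = nn.Parameter(torch.ones(max_feature_length))

    def get_init_kwargs(self) -> Dict[str, Any]:
        # Everything required to build an equivalent instance during deserialization.
        return {"max_feature_length": self.max_feature_length}
```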

# details
1. Added the `FPEBCMetadata` schema for FP_EBC, using an `fp_json` string to store the necessary parameters
2. Added `FPEBCJsonSerializer`, which converts the init_kwargs to a JSON string and stores it in the `fp_json` field of the metadata
3. Added an fqn check against `serialized_fqns`, so that when a higher-level module is serialized, its lower-level submodules are skipped (they are already included in the higher-level module)
4. Added an API called `get_init_kwargs` to `FeatureProcessorsCollection` and `FeatureProcessor`, and used a `FeatureProcessorNameMap` to map the class name to the feature processor class (a round-trip sketch follows this list)
5. Added a `_non_strict_exporting_forward` function to FPEBC so that non-strict IR export goes through the custom_op logic
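
A rough sketch of the dict-based flatten/unflatten round trip described in items 2 and 4. The `fp_json` layout mirrors the serializer in this diff; the module-level name map and helper functions here are illustrative (the real map is `FeatureProcessorNameMap` in `torchrec.modules.feature_processor_`), and the sketch reuses `PositionWeightedModuleSketch` from the block above.

```python
import json
from typing import Dict

from torch import nn

# Illustrative classname -> class map, built so keys always match type(fp).__name__.
FEATURE_PROCESSOR_NAME_MAP: Dict[str, type] = {
    cls.__name__: cls for cls in (PositionWeightedModuleSketch,)
}


def flatten_feature_processors(fps: nn.ModuleDict) -> str:
    """Flatten a dict of feature processors into a JSON string (the fp_json field)."""
    return json.dumps(
        {
            "param_dict": {name: fp.get_init_kwargs() for name, fp in fps.items()},
            "type_dict": {name: type(fp).__name__ for name, fp in fps.items()},
        }
    )


def unflatten_feature_processors(fp_json: str) -> Dict[str, nn.Module]:
    """Re-create the feature processors from the fp_json string produced above."""
    payload = json.loads(fp_json)
    return {
        name: FEATURE_PROCESSOR_NAME_MAP[cls_name](**payload["param_dict"][name])
        for name, cls_name in payload["type_dict"].items()
    }
```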

Differential Revision: D57829276
TroyGarden authored and facebook-github-bot committed Jun 4, 2024
1 parent 3226b16 commit e6ff80c
Showing 8 changed files with 382 additions and 76 deletions.
9 changes: 9 additions & 0 deletions torchrec/ir/schema.py
@@ -32,3 +32,12 @@ class EBCMetadata:
tables: List[EmbeddingBagConfigMetadata]
is_weighted: bool
device: Optional[str]


@dataclass
class FPEBCMetadata:
tables: List[EmbeddingBagConfigMetadata]
is_weighted: bool
device: Optional[str]
fp_type: str
fp_json: Optional[str]
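
For illustration, a hedged example of what this metadata could hold for the dict-based feature-processor case. The feature names and `max_feature_length` values follow the test in this diff; the `tables` entries are elided.

```python
import json

# Hypothetical FPEBCMetadata contents for fp_type == "dict".
fp_ebc_metadata_dict = {
    "tables": [],  # same EmbeddingBagConfigMetadata entries as a plain EBC (elided here)
    "is_weighted": True,
    "device": "cpu",
    "fp_type": "dict",
    "fp_json": json.dumps(
        {
            "param_dict": {
                "f1": {"max_feature_length": 10},
                "f3": {"max_feature_length": 20},
            },
            "type_dict": {
                "f1": "PositionWeightedModule",
                "f3": "PositionWeightedModule",
            },
        }
    ),
}
```
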
108 changes: 107 additions & 1 deletion torchrec/ir/serializer.py
@@ -14,11 +14,17 @@
import torch

from torch import nn
from torchrec.ir.schema import EBCMetadata, EmbeddingBagConfigMetadata
from torchrec.ir.schema import EBCMetadata, EmbeddingBagConfigMetadata, FPEBCMetadata

from torchrec.ir.types import SerializerInterface
from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig, PoolingType
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.feature_processor_ import (
FeatureProcessor,
FeatureProcessorNameMap,
FeatureProcessorsCollection,
)
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection

logger: logging.Logger = logging.getLogger(__name__)

@@ -128,13 +134,113 @@ def deserialize(
)


class FPEBCJsonSerializer(SerializerInterface):
"""
Serializer for torch.export IR using thrift.
"""

@classmethod
def serialize(
cls,
module: nn.Module,
) -> torch.Tensor:
if not isinstance(module, FeatureProcessedEmbeddingBagCollection):
raise ValueError(
f"Expected module to be of type FeatureProcessedEmbeddingBagCollection, got {type(module)}"
)
if isinstance(module._feature_processors, nn.ModuleDict):
fp_type = "dict"
param_dict = {
feature: processor.get_init_kwargs()
for feature, processor in module._feature_processors.items()
}
type_dict = {
feature: type(processor).__name__
for feature, processor in module._feature_processors.items()
}
fp_json = json.dumps(
{
"param_dict": param_dict,
"type_dict": type_dict,
}
)
elif isinstance(module._feature_processors, FeatureProcessorsCollection):
fp_type = type(module._feature_processors).__name__
param_dict = module._feature_processors.get_init_kwargs()
fp_json = json.dumps(param_dict)
else:
raise ValueError(
f"Expected module._feature_processors to be of type dict or FeatureProcessorsCollection, got {type(module._feature_processors)}"
)
ebc = module._embedding_bag_collection
ebc_metadata = FPEBCMetadata(
tables=[
embedding_bag_config_to_metadata(table_config)
for table_config in ebc.embedding_bag_configs()
],
is_weighted=ebc.is_weighted(),
device=str(ebc.device),
fp_type=fp_type,
fp_json=fp_json,
)

ebc_metadata_dict = ebc_metadata.__dict__
ebc_metadata_dict["tables"] = [
table_config.__dict__ for table_config in ebc_metadata_dict["tables"]
]

return torch.frombuffer(
json.dumps(ebc_metadata_dict).encode(), dtype=torch.uint8
)

@classmethod
def deserialize(
cls, input: torch.Tensor, typename: str, device: Optional[torch.device] = None
) -> nn.Module:
if typename != "FeatureProcessedEmbeddingBagCollection":
raise ValueError(
f"Expected typename to be EmbeddingBagCollection, got {typename}"
)

raw_bytes = input.numpy().tobytes()
ebc_metadata_dict = json.loads(raw_bytes.decode())
tables = [
EmbeddingBagConfigMetadata(**table_config)
for table_config in ebc_metadata_dict["tables"]
]
device = get_deserialized_device(ebc_metadata_dict.get("device"), device)
ebc = EmbeddingBagCollection(
tables=[
embedding_metadata_to_config(table_config) for table_config in tables
],
is_weighted=ebc_metadata_dict["is_weighted"],
device=device,
)
fp_dict = json.loads(ebc_metadata_dict["fp_json"])
if ebc_metadata_dict["fp_type"] == "dict":
feature_processors: Dict[str, FeatureProcessor] = {}
for feature, fp_type in fp_dict["type_dict"].items():
feature_processors[feature] = FeatureProcessorNameMap[fp_type](
**fp_dict["param_dict"][feature]
)
else:
feature_processors = FeatureProcessorNameMap[ebc_metadata_dict["fp_type"]](
**fp_dict
)
return FeatureProcessedEmbeddingBagCollection(
ebc,
feature_processors,
)


class JsonSerializer(SerializerInterface):
"""
Serializer for torch.export IR using thrift.
"""

module_to_serializer_cls: Dict[str, Type[SerializerInterface]] = {
"EmbeddingBagCollection": EBCJsonSerializer,
"FeatureProcessedEmbeddingBagCollection": FPEBCJsonSerializer,
}

@classmethod
122 changes: 90 additions & 32 deletions torchrec/ir/tests/test_serializer.py
@@ -25,23 +25,28 @@

from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection
from torchrec.modules.feature_processor_ import (
PositionWeightedModule,
PositionWeightedModuleCollection,
)
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
from torchrec.modules.utils import operator_registry_state
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


class TestJsonSerializer(unittest.TestCase):
# in the model we have 5 duplicated EBCs, 1 fpEBC with fpCollection, and 1 fpEBC with fpDict
def generate_model(self) -> nn.Module:
class Model(nn.Module):
def __init__(self, ebc, fpebc):
def __init__(self, ebc, fpebc1, fpebc2):
super().__init__()
self.ebc1 = ebc
self.ebc2 = copy.deepcopy(ebc)
self.ebc3 = copy.deepcopy(ebc)
self.ebc4 = copy.deepcopy(ebc)
self.ebc5 = copy.deepcopy(ebc)
self.fpebc = fpebc
self.fpebc1 = fpebc1
self.fpebc2 = fpebc2

def forward(
self,
@@ -53,22 +58,17 @@ def forward(
kt4 = self.ebc4(features)
kt5 = self.ebc5(features)

fpebc_res = self.fpebc(features)
fpebc1_res = self.fpebc1(features)
fpebc2_res = self.fpebc2(features)
ebc_kt_vals = [kt.values() for kt in [kt1, kt2, kt3, kt4, kt5]]
sparse_arch_vals = sum(ebc_kt_vals)
sparse_arch_res = KeyedTensor(
keys=kt1.keys(),
values=sparse_arch_vals,
length_per_key=kt1.length_per_key(),
)

return KeyedTensor.regroup(
[sparse_arch_res, fpebc_res], [["f1"], ["f2", "f3"]]
return (
ebc_kt_vals + list(fpebc1_res.values()) + list(fpebc2_res.values())
)

tb1_config = EmbeddingBagConfig(
name="t1",
embedding_dim=4,
embedding_dim=3,
num_embeddings=10,
feature_names=["f1"],
)
@@ -80,7 +80,7 @@ def forward(
)
tb3_config = EmbeddingBagConfig(
name="t3",
embedding_dim=4,
embedding_dim=5,
num_embeddings=10,
feature_names=["f3"],
)
@@ -91,7 +91,7 @@ def forward(
)
max_feature_lengths = {"f1": 100, "f2": 100}

fpebc = FeatureProcessedEmbeddingBagCollection(
fpebc1 = FeatureProcessedEmbeddingBagCollection(
EmbeddingBagCollection(
tables=[tb1_config, tb2_config],
is_weighted=True,
@@ -100,8 +100,18 @@ def forward(
max_feature_lengths=max_feature_lengths,
),
)
fpebc2 = FeatureProcessedEmbeddingBagCollection(
EmbeddingBagCollection(
tables=[tb1_config, tb3_config],
is_weighted=True,
),
{
"f1": PositionWeightedModule(max_feature_length=10),
"f3": PositionWeightedModule(max_feature_length=20),
},
)

model = Model(ebc, fpebc)
model = Model(ebc, fpebc1, fpebc2)

return model

@@ -132,12 +142,16 @@ def test_serialize_deserialize_ebc(self) -> None:
for i, tensor in enumerate(ep_output):
self.assertEqual(eager_out[i].shape, tensor.shape)

# Only 2 custom op registered, as dimensions of ebc are same
self.assertEqual(len(operator_registry_state.op_registry_schema), 2)
# Should have 3 custom ops registered, as the EBCs share the same dims
# and the two fpEBCs have different dims
self.assertEqual(len(operator_registry_state.op_registry_schema), 3)

total_dim_ebc = sum(model.ebc1._lengths_per_embedding)
total_dim_fpebc = sum(
model.fpebc._embedding_bag_collection._lengths_per_embedding
total_dim_fpebc1 = sum(
model.fpebc1._embedding_bag_collection._lengths_per_embedding
)
total_dim_fpebc2 = sum(
model.fpebc2._embedding_bag_collection._lengths_per_embedding
)
# Check if custom op is registered with the correct name
# EmbeddingBagCollection type and total dim
@@ -146,7 +160,11 @@
in operator_registry_state.op_registry_schema
)
self.assertTrue(
f"EmbeddingBagCollection_{total_dim_fpebc}"
f"FeatureProcessedEmbeddingBagCollection_{total_dim_fpebc1}"
in operator_registry_state.op_registry_schema
)
self.assertTrue(
f"FeatureProcessedEmbeddingBagCollection_{total_dim_fpebc2}"
in operator_registry_state.op_registry_schema
)

@@ -155,28 +173,68 @@ def test_serialize_deserialize_ebc(self) -> None:
# Deserialize EBC
deserialized_model = deserialize_embedding_modules(ep, JsonSerializer)

# check EBC config
for i in range(5):
ebc_name = f"ebc{i + 1}"
assert isinstance(
self.assertIsInstance(
getattr(deserialized_model, ebc_name), EmbeddingBagCollection
)

for deserialized_config, org_config in zip(
for deserialized, original in zip(
getattr(deserialized_model, ebc_name).embedding_bag_configs(),
getattr(model, ebc_name).embedding_bag_configs(),
):
assert deserialized_config.name == org_config.name
assert deserialized_config.embedding_dim == org_config.embedding_dim
assert deserialized_config.num_embeddings, org_config.num_embeddings
assert deserialized_config.feature_names, org_config.feature_names
self.assertEqual(deserialized.name, original.name)
self.assertEqual(deserialized.embedding_dim, original.embedding_dim)
self.assertEqual(deserialized.num_embeddings, original.num_embeddings)
self.assertEqual(deserialized.feature_names, original.feature_names)

# check FPEBC config
for i in range(2):
fpebc_name = f"fpebc{i + 1}"
self.assertIsInstance(
getattr(deserialized_model, fpebc_name),
FeatureProcessedEmbeddingBagCollection,
)

deserialized_fp = getattr(
deserialized_model, fpebc_name
)._feature_processors
original_fp = getattr(model, fpebc_name)._feature_processors
if isinstance(original_fp, nn.ModuleDict):
for deserialized, original in zip(
deserialized_fp.values(), original_fp.values()
):
self.assertDictEqual(
deserialized.get_init_kwargs(), original.get_init_kwargs()
)
else:
self.assertDictEqual(
deserialized_fp.get_init_kwargs(), original_fp.get_init_kwargs()
)

for deserialized, original in zip(
getattr(
deserialized_model, fpebc_name
)._embedding_bag_collection.embedding_bag_configs(),
getattr(
model, fpebc_name
)._embedding_bag_collection.embedding_bag_configs(),
):
self.assertEqual(deserialized.name, original.name)
self.assertEqual(deserialized.embedding_dim, original.embedding_dim)
self.assertEqual(deserialized.num_embeddings, original.num_embeddings)
self.assertEqual(deserialized.feature_names, original.feature_names)

deserialized_model.load_state_dict(model.state_dict())
# Run forward on deserialized model

# Run forward on deserialized model and compare the output
deserialized_out = deserialized_model(id_list_features)

for i, tensor in enumerate(deserialized_out):
assert eager_out[i].shape == tensor.shape
assert torch.allclose(eager_out[i], tensor)
self.assertEqual(len(deserialized_out), len(eager_out))
for deserialized, original in zip(deserialized_out, eager_out):
self.assertEqual(deserialized.shape, original.shape)
self.assertTrue(torch.allclose(deserialized, original))

def test_dynamic_shape_ebc(self) -> None:
model = self.generate_model()
6 changes: 6 additions & 0 deletions torchrec/ir/utils.py
@@ -37,8 +37,14 @@ def serialize_embedding_modules(
Returns the modified module and the list of fqns that had the buffer added.
"""
preserve_fqns = []
serialized_fqns = set()
for fqn, module in model.named_modules():
if type(module).__name__ in serializer_cls.module_to_serializer_cls:
# this avoids serializing submodules of a module that has already been serialized
if any(fqn.startswith(s_fqn) for s_fqn in serialized_fqns):
continue
else:
serialized_fqns.add(fqn)
serialized_module = serializer_cls.serialize(module)
module.register_buffer("ir_metadata", serialized_module, persistent=False)
preserve_fqns.append(fqn)