register custom_op for fpEBC #2067

Closed · wants to merge 1 commit
18 changes: 17 additions & 1 deletion torchrec/ir/schema.py
@@ -8,7 +8,7 @@
# pyre-strict

from dataclasses import dataclass
from typing import List, Optional
from typing import List, Optional, Tuple

from torchrec.modules.embedding_configs import DataType, PoolingType

@@ -32,3 +32,19 @@ class EBCMetadata:
tables: List[EmbeddingBagConfigMetadata]
is_weighted: bool
device: Optional[str]


@dataclass
class FPEBCMetadata:
is_fp_collection: bool
features: List[str]


@dataclass
class PositionWeightedModuleMetadata:
max_feature_length: int


@dataclass
class PositionWeightedModuleCollectionMetadata:
max_feature_lengths: List[Tuple[str, int]]
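Note: these metadata dataclasses are the JSON-facing representation that the serializers below reduce each module to. Because they hold only primitives, lists, and tuples, they round-trip through __dict__ and keyword expansion. A minimal sketch of that round trip (the direct use of json here is illustrative; the actual encoding lives in JsonSerializer, which is outside this diff):

    import json

    from torchrec.ir.schema import FPEBCMetadata, PositionWeightedModuleCollectionMetadata

    # FPEBCMetadata survives a JSON round trip unchanged.
    fp_meta = FPEBCMetadata(is_fp_collection=False, features=["f1", "f3"])
    assert FPEBCMetadata(**json.loads(json.dumps(fp_meta.__dict__))) == fp_meta

    pwmc_meta = PositionWeightedModuleCollectionMetadata(
        max_feature_lengths=[("f1", 100), ("f2", 100)],
    )
    restored = PositionWeightedModuleCollectionMetadata(
        **json.loads(json.dumps(pwmc_meta.__dict__))
    )
    # JSON decodes tuples as lists, so normalize before comparing.
    assert [tuple(p) for p in restored.max_feature_lengths] == pwmc_meta.max_feature_lengths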
122 changes: 121 additions & 1 deletion torchrec/ir/serializer.py
@@ -14,11 +14,24 @@
import torch

from torch import nn
from torchrec.ir.schema import EBCMetadata, EmbeddingBagConfigMetadata
from torchrec.ir.schema import (
EBCMetadata,
EmbeddingBagConfigMetadata,
FPEBCMetadata,
PositionWeightedModuleCollectionMetadata,
PositionWeightedModuleMetadata,
)

from torchrec.ir.types import SerializerInterface
from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig, PoolingType
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.feature_processor_ import (
FeatureProcessor,
FeatureProcessorsCollection,
PositionWeightedModule,
PositionWeightedModuleCollection,
)
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection

logger: logging.Logger = logging.getLogger(__name__)

@@ -196,3 +209,110 @@ def deserialize_from_dict(


JsonSerializer.module_to_serializer_cls["EmbeddingBagCollection"] = EBCJsonSerializer


class PWMJsonSerializer(JsonSerializer):
_module_cls = PositionWeightedModule

@classmethod
def serialize_to_dict(cls, module: nn.Module) -> Dict[str, Any]:
metadata = PositionWeightedModuleMetadata(
max_feature_length=module.position_weight.shape[0],
)
return metadata.__dict__

@classmethod
def deserialize_from_dict(
cls,
metadata_dict: Dict[str, Any],
device: Optional[torch.device] = None,
unflatten_ep: Optional[nn.Module] = None,
) -> nn.Module:
metadata = PositionWeightedModuleMetadata(**metadata_dict)
return PositionWeightedModule(metadata.max_feature_length, device)


JsonSerializer.module_to_serializer_cls["PositionWeightedModule"] = PWMJsonSerializer


class PWMCJsonSerializer(JsonSerializer):
_module_cls = PositionWeightedModuleCollection

@classmethod
def serialize_to_dict(cls, module: nn.Module) -> Dict[str, Any]:
metadata = PositionWeightedModuleCollectionMetadata(
max_feature_lengths=[ # convert to list of tuples to preserve the order
(feature, length) for feature, length in module.max_feature_lengths.items()
],
)
return metadata.__dict__

@classmethod
def deserialize_from_dict(
cls,
metadata_dict: Dict[str, Any],
device: Optional[torch.device] = None,
unflatten_ep: Optional[nn.Module] = None,
) -> nn.Module:
metadata = PositionWeightedModuleCollectionMetadata(**metadata_dict)
max_feature_lengths = {
feature: length for feature, length in metadata.max_feature_lengths
}
return PositionWeightedModuleCollection(max_feature_lengths, device)


JsonSerializer.module_to_serializer_cls["PositionWeightedModuleCollection"] = (
PWMCJsonSerializer
)
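Note: the list-of-tuples encoding above exists to keep feature order stable across the JSON round trip, as the inline comment states. A sketch of the intended round trip, assuming PositionWeightedModuleCollection exposes the constructor dict as its max_feature_lengths attribute (which the serializer's read of module.max_feature_lengths implies):

    from torchrec.ir.serializer import PWMCJsonSerializer
    from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection

    pwmc = PositionWeightedModuleCollection(max_feature_lengths={"f1": 100, "f2": 100})
    state = PWMCJsonSerializer.serialize_to_dict(pwmc)
    # state == {"max_feature_lengths": [("f1", 100), ("f2", 100)]}
    rebuilt = PWMCJsonSerializer.deserialize_from_dict(state)
    assert rebuilt.max_feature_lengths == pwmc.max_feature_lengths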


class FPEBCJsonSerializer(JsonSerializer):
_module_cls = FeatureProcessedEmbeddingBagCollection
_children = ["_feature_processors", "_embedding_bag_collection"]

@classmethod
def serialize_to_dict(
cls,
module: nn.Module,
) -> Dict[str, Any]:
if isinstance(module._feature_processors, FeatureProcessorsCollection):
metadata = FPEBCMetadata(
is_fp_collection=True,
features=[],
)
else:
metadata = FPEBCMetadata(
is_fp_collection=False,
features=list(module._feature_processors.keys()),
)
return metadata.__dict__

@classmethod
def deserialize_from_dict(
cls,
metadata_dict: Dict[str, Any],
device: Optional[torch.device] = None,
unflatten_ep: Optional[nn.Module] = None,
) -> nn.Module:
metadata = FPEBCMetadata(**metadata_dict)
assert unflatten_ep is not None
if metadata.is_fp_collection:
feature_processors = unflatten_ep._feature_processors
assert isinstance(feature_processors, FeatureProcessorsCollection)
else:
feature_processors: dict[str, FeatureProcessor] = {}
for feature in metadata.features:
fp = getattr(unflatten_ep._feature_processors, feature)
assert isinstance(fp, FeatureProcessor)
feature_processors[feature] = fp
ebc = unflatten_ep._embedding_bag_collection
assert isinstance(ebc, EmbeddingBagCollection)
return FeatureProcessedEmbeddingBagCollection(
ebc,
feature_processors,
)


JsonSerializer.module_to_serializer_cls["FeatureProcessedEmbeddingBagCollection"] = (
FPEBCJsonSerializer
)
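Note: the registrations in this file are the extension point for new sparse module types: a serializer subclass pins _module_cls, implements serialize_to_dict / deserialize_from_dict, and is registered under the module's class name; FPEBCJsonSerializer additionally declares _children, which appears to tell the deserializer which sub-modules are rebuilt first and handed back via unflatten_ep. A minimal sketch for a hypothetical feature processor (MyFeatureProcessor and MyFPJsonSerializer are made-up names, not part of this PR):

    from typing import Any, Dict, Optional

    import torch
    from torch import nn
    from torchrec.ir.serializer import JsonSerializer


    class MyFeatureProcessor(nn.Module):
        # Hypothetical module used only to illustrate the registration pattern.
        def __init__(self, scale: float = 1.0) -> None:
            super().__init__()
            self.scale = scale


    class MyFPJsonSerializer(JsonSerializer):
        _module_cls = MyFeatureProcessor

        @classmethod
        def serialize_to_dict(cls, module: nn.Module) -> Dict[str, Any]:
            # Reduce the module to JSON-friendly metadata.
            return {"scale": module.scale}

        @classmethod
        def deserialize_from_dict(
            cls,
            metadata_dict: Dict[str, Any],
            device: Optional[torch.device] = None,
            unflatten_ep: Optional[nn.Module] = None,
        ) -> nn.Module:
            return MyFeatureProcessor(scale=metadata_dict["scale"])


    # Keyed by class name, mirroring the registrations above.
    JsonSerializer.module_to_serializer_cls["MyFeatureProcessor"] = MyFPJsonSerializer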
113 changes: 77 additions & 36 deletions torchrec/ir/tests/test_serializer.py
@@ -25,7 +25,10 @@

from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection
from torchrec.modules.feature_processor_ import (
PositionWeightedModule,
PositionWeightedModuleCollection,
)
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
from torchrec.modules.utils import operator_registry_state
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
@@ -90,16 +93,18 @@ def deserialize_from_dict(


class TestJsonSerializer(unittest.TestCase):
# The model has 5 duplicated EBCs, 1 fpEBC with an fpCollection, and 1 fpEBC with an fpDict
def generate_model(self) -> nn.Module:
class Model(nn.Module):
def __init__(self, ebc, fpebc):
def __init__(self, ebc, fpebc1, fpebc2):
super().__init__()
self.ebc1 = ebc
self.ebc2 = copy.deepcopy(ebc)
self.ebc3 = copy.deepcopy(ebc)
self.ebc4 = copy.deepcopy(ebc)
self.ebc5 = copy.deepcopy(ebc)
self.fpebc = fpebc
self.fpebc1 = fpebc1
self.fpebc2 = fpebc2

def forward(
self,
@@ -111,22 +116,16 @@ def forward(
kt4 = self.ebc4(features)
kt5 = self.ebc5(features)

fpebc_res = self.fpebc(features)
ebc_kt_vals = [kt.values() for kt in [kt1, kt2, kt3, kt4, kt5]]
sparse_arch_vals = sum(ebc_kt_vals)
sparse_arch_res = KeyedTensor(
keys=kt1.keys(),
values=sparse_arch_vals,
length_per_key=kt1.length_per_key(),
)

return KeyedTensor.regroup(
[sparse_arch_res, fpebc_res], [["f1"], ["f2", "f3"]]
)
fpebc1_res = self.fpebc1(features)
fpebc2_res = self.fpebc2(features)
res: List[torch.Tensor] = []
for kt in [kt1, kt2, kt3, kt4, kt5, fpebc1_res, fpebc2_res]:
res.extend(KeyedTensor.regroup([kt], [[key] for key in kt.keys()]))
return res

tb1_config = EmbeddingBagConfig(
name="t1",
embedding_dim=4,
embedding_dim=3,
num_embeddings=10,
feature_names=["f1"],
)
@@ -138,7 +137,7 @@ def forward(
)
tb3_config = EmbeddingBagConfig(
name="t3",
embedding_dim=4,
embedding_dim=5,
num_embeddings=10,
feature_names=["f3"],
)
@@ -149,7 +148,7 @@ def forward(
)
max_feature_lengths = {"f1": 100, "f2": 100}

fpebc = FeatureProcessedEmbeddingBagCollection(
fpebc1 = FeatureProcessedEmbeddingBagCollection(
EmbeddingBagCollection(
tables=[tb1_config, tb2_config],
is_weighted=True,
@@ -158,8 +157,18 @@ def forward(
max_feature_lengths=max_feature_lengths,
),
)
fpebc2 = FeatureProcessedEmbeddingBagCollection(
EmbeddingBagCollection(
tables=[tb1_config, tb3_config],
is_weighted=True,
),
{
"f1": PositionWeightedModule(max_feature_length=10),
"f3": PositionWeightedModule(max_feature_length=20),
},
)

model = Model(ebc, fpebc)
model = Model(ebc, fpebc1, fpebc2)

return model

@@ -190,12 +199,16 @@ def test_serialize_deserialize_ebc(self) -> None:
for i, tensor in enumerate(ep_output):
self.assertEqual(eager_out[i].shape, tensor.shape)

# Only 2 custom op registered, as dimensions of ebc are same
self.assertEqual(len(operator_registry_state.op_registry_schema), 2)
# Should have 3 custom ops registered: the EBCs all share the same dims,
# while the two fpEBCs have two different total dims
self.assertEqual(len(operator_registry_state.op_registry_schema), 3)

total_dim_ebc = sum(model.ebc1._lengths_per_embedding)
total_dim_fpebc = sum(
model.fpebc._embedding_bag_collection._lengths_per_embedding
total_dim_fpebc1 = sum(
model.fpebc1._embedding_bag_collection._lengths_per_embedding
)
total_dim_fpebc2 = sum(
model.fpebc2._embedding_bag_collection._lengths_per_embedding
)
# Check if custom op is registered with the correct name
# EmbeddingBagCollection type and total dim
@@ -204,35 +217,63 @@
in operator_registry_state.op_registry_schema
)
self.assertTrue(
f"EmbeddingBagCollection_{total_dim_fpebc}"
f"EmbeddingBagCollection_{total_dim_fpebc1}"
in operator_registry_state.op_registry_schema
)
self.assertTrue(
f"EmbeddingBagCollection_{total_dim_fpebc2}"
in operator_registry_state.op_registry_schema
)
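Note: the names asserted above encode the module type plus the summed per-table embedding dims, so the number of registered custom ops equals the number of distinct total dims. A quick check with the dims visible in this diff (tb1 is 3, tb3 is 5; tb2's definition is collapsed above, so its dim of 4 is an assumption here):

    # Custom-op naming convention exercised by the assertions above:
    # f"EmbeddingBagCollection_{sum of per-table embedding dims}"
    fpebc1_total = 3 + 4  # tb1 ("f1") + tb2 ("f2"), tb2's dim assumed
    fpebc2_total = 3 + 5  # tb1 ("f1") + tb3 ("f3")
    assert fpebc1_total != fpebc2_total  # hence two ops beyond the one shared by ebc1..ebc5
    print(f"EmbeddingBagCollection_{fpebc1_total}", f"EmbeddingBagCollection_{fpebc2_total}")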

# Deserialize EBC
deserialized_model = deserialize_embedding_modules(ep, JsonSerializer)

# check EBC config
for i in range(5):
ebc_name = f"ebc{i + 1}"
assert isinstance(
self.assertIsInstance(
getattr(deserialized_model, ebc_name), EmbeddingBagCollection
)

for deserialized_config, org_config in zip(
for deserialized, original in zip(
getattr(deserialized_model, ebc_name).embedding_bag_configs(),
getattr(model, ebc_name).embedding_bag_configs(),
):
assert deserialized_config.name == org_config.name
assert deserialized_config.embedding_dim == org_config.embedding_dim
assert deserialized_config.num_embeddings, org_config.num_embeddings
assert deserialized_config.feature_names, org_config.feature_names
self.assertEqual(deserialized.name, original.name)
self.assertEqual(deserialized.embedding_dim, original.embedding_dim)
self.assertEqual(deserialized.num_embeddings, original.num_embeddings)
self.assertEqual(deserialized.feature_names, original.feature_names)

# check FPEBC config
for i in range(2):
fpebc_name = f"fpebc{i + 1}"
assert isinstance(
getattr(deserialized_model, fpebc_name),
FeatureProcessedEmbeddingBagCollection,
)

for deserialized, original in zip(
getattr(
deserialized_model, fpebc_name
)._embedding_bag_collection.embedding_bag_configs(),
getattr(
model, fpebc_name
)._embedding_bag_collection.embedding_bag_configs(),
):
self.assertEqual(deserialized.name, original.name)
self.assertEqual(deserialized.embedding_dim, original.embedding_dim)
self.assertEqual(deserialized.num_embeddings, original.num_embeddings)
self.assertEqual(deserialized.feature_names, original.feature_names)

deserialized_model.load_state_dict(model.state_dict())
# Run forward on deserialized model

# Run forward on deserialized model and compare the output
deserialized_out = deserialized_model(id_list_features)

for i, tensor in enumerate(deserialized_out):
assert eager_out[i].shape == tensor.shape
assert torch.allclose(eager_out[i], tensor)
self.assertEqual(len(deserialized_out), len(eager_out))
for deserialized, original in zip(deserialized_out, eager_out):
self.assertEqual(deserialized.shape, original.shape)
self.assertTrue(torch.allclose(deserialized, original))

def test_dynamic_shape_ebc(self) -> None:
model = self.generate_model()
@@ -259,7 +300,7 @@ def test_dynamic_shape_ebc(self) -> None:
dynamic_shapes=collection.dynamic_shapes(model, (feature1,)),
strict=False,
# Allows KJT to not be unflattened and run a forward on unflattened EP
preserve_module_call_signature=(tuple(sparse_fqns)),
preserve_module_call_signature=tuple(sparse_fqns),
)

# Run forward on ExportedProgram
Expand All @@ -271,8 +312,8 @@ def test_dynamic_shape_ebc(self) -> None:

# Deserialize EBC
deserialized_model = deserialize_embedding_modules(ep, JsonSerializer)

deserialized_model.load_state_dict(model.state_dict())

# Run forward on deserialized model
deserialized_out = deserialized_model(feature2)

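Note: end to end, the tests exercise a serialize, export, deserialize round trip. A condensed sketch of that flow; only deserialize_embedding_modules(ep, JsonSerializer) appears verbatim in this diff, so the import paths, the serialize_embedding_modules entry point, and the KJT values below are assumptions based on the collapsed parts of the test:

    import torch
    from torchrec.ir.serializer import JsonSerializer
    from torchrec.ir.utils import deserialize_embedding_modules, serialize_embedding_modules
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    # The 5-EBC / 2-fpEBC module built by generate_model() in the test class above.
    model = TestJsonSerializer().generate_model()
    features = KeyedJaggedTensor.from_offsets_sync(
        keys=["f1", "f2", "f3"],
        values=torch.tensor([0, 1, 2, 3, 2, 3]),
        offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]),
    )
    eager_out = model(features)

    # Swap the sparse modules for IR custom ops, then export.
    model, sparse_fqns = serialize_embedding_modules(model, JsonSerializer)
    ep = torch.export.export(
        model,
        (features,),
        strict=False,
        preserve_module_call_signature=tuple(sparse_fqns),
    )

    # Rebuild eager sparse modules from the exported program and compare outputs.
    deserialized_model = deserialize_embedding_modules(ep, JsonSerializer)
    deserialized_model.load_state_dict(model.state_dict())
    for expected, actual in zip(eager_out, deserialized_model(features)):
        assert torch.allclose(expected, actual)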