register custom_op for fpEBC (#2067)
Summary:

# context
* convert `FeatureProcessedEmbeddingBagCollection` (FPEBC) to a custom op in IR export
* add serialization and deserialization functions for FPEBC
* add an API to the `FeatureProcessorInterface` to export the parameters necessary to create an instance
* use this API (`get_init_kwargs`) in the serialize and deserialize functions to flatten and unflatten the feature processor (see the sketch after this list)
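
The diff below does not include the `get_init_kwargs` changes themselves, so the following is only a rough, hypothetical sketch of the idea (the class name and body are illustrative, not the committed code): a feature processor exposes the constructor arguments needed to rebuild it.

```python
from typing import Any, Dict

import torch
from torch import nn


class MyPositionWeights(nn.Module):
    """Hypothetical feature processor that can report its init kwargs."""

    def __init__(self, max_feature_length: int) -> None:
        super().__init__()
        self.position_weight = nn.Parameter(torch.empty(max_feature_length))

    def get_init_kwargs(self) -> Dict[str, Any]:
        # Everything needed to re-create this instance: a serializer can
        # json.dumps() this dict to flatten the module, then pass it back
        # to __init__ to unflatten it.
        return {"max_feature_length": self.position_weight.shape[0]}
```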

# details
1. Added an `FPEBCMetadata` schema for FP_EBC, using an `fp_json` string to store the necessary parameters
2. Added `FPEBCJsonSerializer`, which converts the init_kwargs to a JSON string and stores it in the `fp_json` field of the metadata (the JSON-in-tensor round trip is sketched after this list)
3. Added an FQN check against `serialized_fqns`, so that when a higher-level module is serialized, its lower-level modules can be skipped (they are already included in the higher-level module)
4. Added an API called `get_init_kwargs` to `FeatureProcessorsCollection` and `FeatureProcessor`, and used a `FeatureProcessorNameMap` to map each class name to its feature processor class
5. Added a `_non_strict_exporting_forward` function to FPEBC so that non-strict IR export routes through the custom_op logic
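
All of the serializers in this diff share one mechanic for the flatten step: the metadata dataclass is dumped to JSON and packed into a `uint8` tensor so it can travel with the exported IR, then decoded back on the unflatten side. A minimal, self-contained sketch of that round trip (`ToyMetadata` is a stand-in for the real schemas shown in the diff):

```python
import json
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class ToyMetadata:
    # Stand-in for FPEBCMetadata and friends in torchrec/ir/schema.py.
    is_fp_collection: bool
    feature_list: List[str]


def to_tensor(meta: ToyMetadata) -> torch.Tensor:
    # Pack the JSON bytes into a uint8 tensor, as the serializers below do.
    return torch.frombuffer(json.dumps(meta.__dict__).encode(), dtype=torch.uint8)


def from_tensor(t: torch.Tensor) -> ToyMetadata:
    # Recover the dataclass from the raw bytes.
    return ToyMetadata(**json.loads(t.numpy().tobytes().decode()))


meta = ToyMetadata(is_fp_collection=False, feature_list=["f1", "f2"])
assert from_tensor(to_tensor(meta)) == meta
```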

Differential Revision: D57829276
TroyGarden authored and facebook-github-bot committed Jun 6, 2024
1 parent 6784155 commit 048be62
Showing 6 changed files with 382 additions and 79 deletions.
18 changes: 17 additions & 1 deletion torchrec/ir/schema.py
@@ -8,7 +8,7 @@
# pyre-strict

from dataclasses import dataclass
-from typing import List, Optional
+from typing import List, Optional, Tuple

from torchrec.modules.embedding_configs import DataType, PoolingType

@@ -32,3 +32,19 @@ class EBCMetadata:
    tables: List[EmbeddingBagConfigMetadata]
    is_weighted: bool
    device: Optional[str]


@dataclass
class FPEBCMetadata:
    is_fp_collection: bool
    feature_list: List[str]


@dataclass
class PositionWeightedModuleMetadata:
    max_feature_length: int


@dataclass
class PositionWeightedModuleCollectionMetadata:
    max_feature_lengths: List[Tuple[str, int]]
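
Note that `max_feature_lengths` above is a `List[Tuple[str, int]]` rather than a mapping; per the comment in `PWMCJsonSerializer.serialize` below, the dict is converted to a list of tuples so feature order is preserved through the JSON round trip. A quick sketch of that flatten/unflatten conversion (toy values; not part of the commit):

```python
import json
from typing import Dict, List, Tuple

# Flatten: dict -> ordered list of (feature, length) pairs.
max_feature_lengths: Dict[str, int] = {"f1": 10, "f2": 20}
as_pairs: List[Tuple[str, int]] = list(max_feature_lengths.items())

# JSON encodes tuples as arrays and decodes them as lists, so rebuild tuples.
decoded = [tuple(pair) for pair in json.loads(json.dumps(as_pairs))]

# Unflatten: back to the dict the module constructor expects.
restored = {feature: length for feature, length in decoded}
assert restored == max_feature_lengths
```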
161 changes: 158 additions & 3 deletions torchrec/ir/serializer.py
@@ -14,11 +14,24 @@
import torch

from torch import nn
-from torchrec.ir.schema import EBCMetadata, EmbeddingBagConfigMetadata
+from torchrec.ir.schema import (
+    EBCMetadata,
+    EmbeddingBagConfigMetadata,
+    FPEBCMetadata,
+    PositionWeightedModuleCollectionMetadata,
+    PositionWeightedModuleMetadata,
+)

from torchrec.ir.types import SerializerInterface
from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig, PoolingType
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.feature_processor_ import (
    FeatureProcessor,
    FeatureProcessorsCollection,
    PositionWeightedModule,
    PositionWeightedModuleCollection,
)
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection

logger: logging.Logger = logging.getLogger(__name__)

@@ -71,7 +84,7 @@ def get_deserialized_device(

class EBCJsonSerializer(SerializerInterface):
"""
Serializer for torch.export IR using thrift.
Serializer for torch.export IR using json.
"""

@classmethod
@@ -132,13 +145,155 @@ def deserialize(
        )


class PWMJsonSerializer(SerializerInterface):
    """
    Serializer for torch.export IR using json.
    """

    @classmethod
    def serialize(cls, module: nn.Module) -> torch.Tensor:
        if not isinstance(module, PositionWeightedModule):
            raise ValueError(
                f"Expected module to be of type PositionWeightedModule, got {type(module)}"
            )
        metadata = PositionWeightedModuleMetadata(
            max_feature_length=module.position_weight.shape[0],
        )
        return torch.frombuffer(
            json.dumps(metadata.__dict__).encode(), dtype=torch.uint8
        )

    @classmethod
    def deserialize(
        cls,
        input: torch.Tensor,
        typename: str,
        device: Optional[torch.device] = None,
        children: Dict[str, nn.Module] = {},
    ) -> nn.Module:
        if typename != "PositionWeightedModule":
            raise ValueError(
                f"Expected typename to be PositionWeightedModule, got {typename}"
            )
        raw_bytes = input.numpy().tobytes()
        metadata = json.loads(raw_bytes)
        return PositionWeightedModule(metadata["max_feature_length"], device)


class PWMCJsonSerializer(SerializerInterface):
    """
    Serializer for torch.export IR using json.
    """

    @classmethod
    def serialize(cls, module: nn.Module) -> torch.Tensor:
        if not isinstance(module, PositionWeightedModuleCollection):
            raise ValueError(
                f"Expected module to be of type PositionWeightedModuleCollection, got {type(module)}"
            )
        metadata = PositionWeightedModuleCollectionMetadata(
            max_feature_lengths=[  # convert to a list of tuples to preserve feature order
                (feature, length)
                for feature, length in module.max_feature_lengths.items()
            ],
        )
        return torch.frombuffer(
            json.dumps(metadata.__dict__).encode(), dtype=torch.uint8
        )

    @classmethod
    def deserialize(
        cls,
        input: torch.Tensor,
        typename: str,
        device: Optional[torch.device] = None,
        children: Dict[str, nn.Module] = {},
    ) -> nn.Module:
        if typename != "PositionWeightedModuleCollection":
            raise ValueError(
                f"Expected typename to be PositionWeightedModuleCollection, got {typename}"
            )
        raw_bytes = input.numpy().tobytes()
        metadata = PositionWeightedModuleCollectionMetadata(**json.loads(raw_bytes))
        max_feature_lengths = {
            feature: length for feature, length in metadata.max_feature_lengths
        }
        return PositionWeightedModuleCollection(max_feature_lengths, device)


class FPEBCJsonSerializer(SerializerInterface):
    """
    Serializer for torch.export IR using json.
    """

    @classmethod
    def requires_children(cls, typename: str) -> bool:
        return True

    @classmethod
    def serialize(
        cls,
        module: nn.Module,
    ) -> torch.Tensor:
        if not isinstance(module, FeatureProcessedEmbeddingBagCollection):
            raise ValueError(
                f"Expected module to be of type FeatureProcessedEmbeddingBagCollection, got {type(module)}"
            )
        elif isinstance(module._feature_processors, FeatureProcessorsCollection):
            metadata = FPEBCMetadata(
                is_fp_collection=True,
                feature_list=[],
            )
        else:
            metadata = FPEBCMetadata(
                is_fp_collection=False,
                feature_list=list(module._feature_processors.keys()),
            )

        return torch.frombuffer(
            json.dumps(metadata.__dict__).encode(), dtype=torch.uint8
        )

    @classmethod
    def deserialize(
        cls,
        input: torch.Tensor,
        typename: str,
        device: Optional[torch.device] = None,
        children: Dict[str, nn.Module] = {},
    ) -> nn.Module:
        if typename != "FeatureProcessedEmbeddingBagCollection":
            raise ValueError(
                f"Expected typename to be FeatureProcessedEmbeddingBagCollection, got {typename}"
            )
        raw_bytes = input.numpy().tobytes()
        metadata = FPEBCMetadata(**json.loads(raw_bytes.decode()))
        if metadata.is_fp_collection:
            feature_processors = children["_feature_processors"]
            assert isinstance(feature_processors, FeatureProcessorsCollection)
        else:
            feature_processors: Dict[str, FeatureProcessor] = {}
            for feature in metadata.feature_list:
                fp = children[f"_feature_processors.{feature}"]
                assert isinstance(fp, FeatureProcessor)
                feature_processors[feature] = fp
        ebc = children["_embedding_bag_collection"]
        assert isinstance(ebc, EmbeddingBagCollection)
        return FeatureProcessedEmbeddingBagCollection(
            ebc,
            feature_processors,
        )


class JsonSerializer(SerializerInterface):
"""
Serializer for torch.export IR using thrift.
Serializer for torch.export IR using json.
"""

module_to_serializer_cls: Dict[str, Type[SerializerInterface]] = {
"EmbeddingBagCollection": EBCJsonSerializer,
"FeatureProcessedEmbeddingBagCollection": FPEBCJsonSerializer,
"PositionWeightedModule": PWMJsonSerializer,
"PositionWeightedModuleCollection": PWMCJsonSerializer,
}

    @classmethod