Skip to content

Commit

Permalink
support serialization of compound module
Browse files Browse the repository at this point in the history
Summary:
# context
* to support compound module serialization such as fpEBC and fpPEA, we modified the serializer interface a little
* the basic idea is that:
a. suppose a compound module A consists of child modules B and C, and both B and C have their own serializers available.
b. to serialize module A, we only need to capture its relationship to B and C, and use that relationship during deserialization
c. specifically, during the deserialization, A's children will be passed in as a dict of (child_fqn => child_module)

# design doc
[**TorchRec Composable Serializer Design**](https://docs.google.com/document/d/1WUtmzdcqZmwLd4Do8g1fQRjChnRw0ZimyUrCNhxa4nA/edit#heading=h.ezrtdguw0lwq)

# note
* to apply this approach, module A's constructor must **take in its children as objects**
* **DO NOT** create A's children inside A's constructor

# details
* serialization schedule
```
comp.ebc does NOT require further serialization of its children
comp.ebc.embedding_bags is skipped for further serialization
comp.ebc.embedding_bags.t1 is skipped for further serialization
comp.ebc.embedding_bags.t2 is skipped for further serialization
comp.ebc.embedding_bags.t3 is skipped for further serialization
comp.comp is resumed for serialization
comp.comp Requires further serialization of its children
comp.comp.ebc does NOT require further serialization of its children
comp.comp.ebc.embedding_bags is skipped for further serialization
comp.comp.ebc.embedding_bags.t1 is skipped for further serialization
comp.comp.ebc.embedding_bags.t2 is skipped for further serialization
comp.comp.ebc.embedding_bags.t3 is skipped for further serialization
comp.comp.comp is resumed for serialization
comp.comp.comp Requires further serialization of its children
comp.comp.comp.ebc does NOT require further serialization of its children
comp.comp.comp.ebc.embedding_bags is skipped for further serialization
comp.comp.comp.ebc.embedding_bags.t1 is skipped for further serialization
comp.comp.comp.ebc.embedding_bags.t2 is skipped for further serialization
comp.comp.comp.ebc.embedding_bags.t3 is skipped for further serialization
```
* `parent_fqn`'s children
```
comp.comp.comp {'ebc': EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (t1): EmbeddingBag(10, 4, mode='sum')
    (t2): EmbeddingBag(10, 4, mode='sum')
    (t3): EmbeddingBag(10, 4, mode='sum')
  )
)} None None
```
```
comp.comp {'ebc': EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (t1): EmbeddingBag(10, 4, mode='sum')
    (t2): EmbeddingBag(10, 4, mode='sum')
    (t3): EmbeddingBag(10, 4, mode='sum')
  )
), 'comp': CompoundModule(
  (ebc): EmbeddingBagCollection(
    (embedding_bags): ModuleDict(
      (t1): EmbeddingBag(10, 4, mode='sum')
      (t2): EmbeddingBag(10, 4, mode='sum')
      (t3): EmbeddingBag(10, 4, mode='sum')
    )
  )
)} None None
```
```
comp {'ebc': EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (t1): EmbeddingBag(10, 4, mode='sum')
    (t2): EmbeddingBag(10, 4, mode='sum')
    (t3): EmbeddingBag(10, 4, mode='sum')
  )
), 'comp': CompoundModule(
  (ebc): EmbeddingBagCollection(
    (embedding_bags): ModuleDict(
      (t1): EmbeddingBag(10, 4, mode='sum')
      (t2): EmbeddingBag(10, 4, mode='sum')
      (t3): EmbeddingBag(10, 4, mode='sum')
    )
  )
  (comp): CompoundModule(
    (ebc): EmbeddingBagCollection(
      (embedding_bags): ModuleDict(
        (t1): EmbeddingBag(10, 4, mode='sum')
        (t2): EmbeddingBag(10, 4, mode='sum')
        (t3): EmbeddingBag(10, 4, mode='sum')
      )
    )
  )
)} None None
```

Differential Revision: D58221182
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Jun 13, 2024
1 parent 744ba95 commit 11726d3
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 21 deletions.
17 changes: 15 additions & 2 deletions torchrec/ir/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ def get_deserialized_device(

class JsonSerializerBase(SerializerInterface):
_module_cls: Optional[Type[nn.Module]] = None
_requires_children: bool = False

# Report whether deserializing the module handled by this serializer needs
# the module's children to be passed in. Returns the class-level flag
# `_requires_children` (declared as False on this base class); the `typename`
# argument is not consulted by this implementation.
@classmethod
def requires_children(cls, typename: str) -> bool:
return cls._requires_children

@classmethod
def serialize_to_dict(cls, module: nn.Module) -> Dict[str, Any]:
Expand All @@ -81,6 +86,7 @@ def deserialize_from_dict(
cls,
metadata_dict: Dict[str, Any],
device: Optional[torch.device] = None,
children: Dict[str, nn.Module] = {},
) -> nn.Module:
raise NotImplementedError()

Expand All @@ -106,10 +112,11 @@ def deserialize(
input: torch.Tensor,
typename: str,
device: Optional[torch.device] = None,
children: Dict[str, nn.Module] = {},
) -> nn.Module:
raw_bytes = input.numpy().tobytes()
metadata_dict = json.loads(raw_bytes.decode())
module = cls.deserialize_from_dict(metadata_dict, device)
module = cls.deserialize_from_dict(metadata_dict, device, children)
if cls._module_cls is None:
raise ValueError(
"Must assign a nn.Module to class static variable _module_cls"
Expand Down Expand Up @@ -148,6 +155,7 @@ def deserialize_from_dict(
cls,
metadata_dict: Dict[str, Any],
device: Optional[torch.device] = None,
children: Dict[str, nn.Module] = {},
) -> nn.Module:
tables = [
EmbeddingBagConfigMetadata(**table_config)
Expand Down Expand Up @@ -192,12 +200,17 @@ def deserialize(
input: torch.Tensor,
typename: str,
device: Optional[torch.device] = None,
children: Dict[str, nn.Module] = {},
) -> nn.Module:
if typename not in cls.module_to_serializer_cls:
raise ValueError(
f"Expected typename to be one of {list(cls.module_to_serializer_cls.keys())}, got {typename}"
)

return cls.module_to_serializer_cls[typename].deserialize(
input, typename, device
input, typename, device, children
)

# Delegate the "requires children" query to the serializer class registered
# for `typename` in `module_to_serializer_cls`.
# NOTE(review): unlike `deserialize`, there is no membership check here, so an
# unregistered typename presumably surfaces as a KeyError (assuming
# `module_to_serializer_cls` is a plain dict) -- confirm this is intended.
@classmethod
def requires_children(cls, typename: str) -> bool:
return cls.module_to_serializer_cls[typename].requires_children(typename)
53 changes: 52 additions & 1 deletion torchrec/ir/tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@

import copy
import unittest
from typing import Callable, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union

import torch
from torch import nn
from torchrec.ir.serializer import JsonSerializer
from torchrec.ir.types import SerializerInterface

from torchrec.ir.utils import (
deserialize_embedding_modules,
Expand Down Expand Up @@ -54,6 +55,45 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]:
return res


class CompoundModuleSerializer(SerializerInterface):
    """
    Serializer for ``CompoundModule`` used by the tests.

    No metadata is stored for the module itself: ``serialize`` returns an
    empty tensor and ``deserialize`` rebuilds the module purely from its
    already-deserialized children.
    """

    @classmethod
    def requires_children(cls, typename: str) -> bool:
        """Always True: a CompoundModule is rebuilt from its children."""
        return True

    @classmethod
    def serialize(
        cls,
        module: nn.Module,
    ) -> torch.Tensor:
        """
        Return an empty tensor as the serialized form of ``module``.

        Raises:
            ValueError: if ``module`` is not a ``CompoundModule``.
        """
        if not isinstance(module, CompoundModule):
            raise ValueError(
                f"Expected module to be of type CompoundModule, got {type(module)}"
            )
        return torch.empty(0)

    @classmethod
    def deserialize(
        cls,
        input: torch.Tensor,
        typename: str,
        device: Optional[torch.device] = None,
        children: Optional[Dict[str, nn.Module]] = None,
    ) -> nn.Module:
        """
        Rebuild a ``CompoundModule`` from its deserialized children.

        Args:
            input: serialized buffer; unused because ``serialize`` stores nothing.
            typename: must be ``"CompoundModule"``.
            device: unused by this serializer.
            children: maps relative fqn -> child module. ``"ebc"`` is required,
                ``"comp"`` is optional, and ``"list.0"``, ``"list.1"``, ... are
                collected in order until the first missing index.

        Raises:
            ValueError: if ``typename`` is wrong, or a child is missing or has
                an unexpected type.
        """
        if typename != "CompoundModule":
            raise ValueError(f"Expected typename to be CompoundModule, got {typename}")
        # None instead of a shared mutable `{}` default
        children = {} if children is None else children

        # explicit raises (not asserts) so validation survives `python -O` and
        # matches the ValueError style used by `serialize`
        ebc = children.get("ebc")
        if not isinstance(ebc, EmbeddingBagCollection):
            raise ValueError(
                f"Expected child 'ebc' to be an EmbeddingBagCollection, got {type(ebc)}"
            )
        comp = children.get("comp")
        if comp is not None and not isinstance(comp, CompoundModule):
            raise ValueError(
                f"Expected child 'comp' to be a CompoundModule, got {type(comp)}"
            )

        # list entries arrive as "list.0", "list.1", ...; stop at the first gap
        mlist = []
        i = 0
        while f"list.{i}" in children:
            mlist.append(children[f"list.{i}"])
            i += 1
        return CompoundModule(ebc, comp, mlist)


class TestJsonSerializer(unittest.TestCase):
def generate_model(self) -> nn.Module:
class Model(nn.Module):
Expand Down Expand Up @@ -328,6 +368,9 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]:

eager_out = model(id_list_features)

JsonSerializer.module_to_serializer_cls["CompoundModule"] = (
CompoundModuleSerializer
)
# Serialize
model, sparse_fqns = serialize_embedding_modules(model, JsonSerializer)
ep = torch.export.export(
Expand All @@ -346,6 +389,14 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]:

# Deserialize
deserialized_model = deserialize_embedding_modules(ep, JsonSerializer)

# Check if Compound Module is deserialized correctly
self.assertIsInstance(deserialized_model.comp, CompoundModule)
self.assertIsInstance(deserialized_model.comp.comp, CompoundModule)
self.assertIsInstance(deserialized_model.comp.comp.comp, CompoundModule)
self.assertIsInstance(deserialized_model.comp.list[1], CompoundModule)
self.assertIsInstance(deserialized_model.comp.list[1].comp, CompoundModule)

deserialized_model.load_state_dict(model.state_dict())
# Run forward on deserialized model
deserialized_out = deserialized_model(id_list_features)
Expand Down
5 changes: 5 additions & 0 deletions torchrec/ir/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ def deserialize(
input: Any,
typename: str,
device: Optional[torch.device] = None,
children: Dict[str, nn.Module] = {},
) -> nn.Module:
# Take the bytes in the buffer and regenerate the eager embedding module
pass

# Default answer for serializers: deserialization does not need the module's
# children, so this returns False regardless of `typename`.
@classmethod
def requires_children(cls, typename: str) -> bool:
return False
128 changes: 110 additions & 18 deletions torchrec/ir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#!/usr/bin/env python3

from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Type, Union
from typing import Dict, Iterator, List, Optional, Tuple, Type, Union

import torch

Expand All @@ -36,16 +36,111 @@ def serialize_embedding_modules(
Returns the modified module and the list of fqns that had the buffer added.
"""
# store the fqns of the modules that are serialized, so that the ir_export can preserve
# the fqns of those modules
preserve_fqns = []

# store the fqn of the module that does not require serializing its children,
# so that we can skip based on the children's fqn
skip_children: Optional[str] = None
for fqn, module in model.named_modules():
if skip_children is not None:
if fqn.startswith(skip_children):
# {fqn} is skipped for further serialization
continue
else:
# reset the skip_children to None because fqn is no longer a child
skip_children = None
if type(module).__name__ in serializer_cls.module_to_serializer_cls:
serialized_module = serializer_cls.serialize(module)
# store the fqn of a serialized module that does not require further serialization of its children
if not serializer_cls.requires_children(type(module).__name__):
skip_children = fqn
module.register_buffer("ir_metadata", serialized_module, persistent=False)
preserve_fqns.append(fqn)

return model, preserve_fqns


def _next_or_none(
    generator: Iterator[Tuple[str, nn.Module]]
) -> Tuple[Optional[str], Optional[nn.Module]]:
    """
    Advance ``generator`` by one item.

    Returns the next ``(fqn, module)`` pair, or ``(None, None)`` once the
    generator is exhausted.
    """
    return next(generator, (None, None))


def _deserialize_node(
    parent_fqn: str,
    fqn: Optional[str],
    module: Optional[nn.Module],
    generator: Iterator[Tuple[str, nn.Module]],
    serializer_cls: Type[SerializerInterface],
    module_type_dict: Dict[str, str],
    fqn_to_new_module: Dict[str, nn.Module],
    device: Optional[torch.device] = None,
) -> Tuple[Dict[str, nn.Module], Optional[str], Optional[nn.Module]]:
    """
    Deserialize every serialized module in the subtree rooted at ``parent_fqn``.

    The traversal consumes ``generator`` starting from the current node
    ``(fqn, module)`` and relies on each module's descendants directly
    following it in the generator (the order produced by ``named_modules()``).

    Args:
        parent_fqn: fqn of the subtree root; the empty string means the whole model.
        fqn: fqn of the current node, or None if the generator is exhausted.
        module: the current node, or None if the generator is exhausted.
        generator: yields the remaining ``(fqn, module)`` pairs.
        serializer_cls: provides ``requires_children`` and ``deserialize``.
        module_type_dict: maps fqn -> module type name.
        fqn_to_new_module: output dict, filled with fqn -> deserialized module.
        device: forwarded to ``serializer_cls.deserialize``.

    Returns:
        1. the deserialized modules found under ``parent_fqn``, keyed by their
           fqn relative to ``parent_fqn``
        2. the next node ``(fqn, module)`` that is not inside ``parent_fqn``'s
           subtree, or ``(None, None)`` if the generator is exhausted

    Raises:
        RuntimeError: if a serialized module's fqn is missing from ``module_type_dict``.
    """
    children: Dict[str, nn.Module] = {}
    # A descendant's fqn is "<parent_fqn>.<rest>". Matching on the dotted prefix
    # (rather than the bare parent_fqn) keeps a sibling such as "comp2" from
    # being mistaken for a child of "comp". The root's fqn is empty, so every
    # node belongs to its subtree.
    child_prefix = f"{parent_fqn}." if parent_fqn else ""

    # the loop stops when the current node is outside the parent's subtree or
    # the generator is exhausted
    while fqn is not None and module is not None and fqn.startswith(child_prefix):
        # Only the module's own buffers matter; recurse=False avoids collecting
        # the buffers of the whole subtree for every visited node.
        if "ir_metadata" in dict(module.named_buffers(recurse=False)):
            # the current node is a serialized module, need to deserialize it
            serialized_module = module.get_buffer("ir_metadata")
            if fqn not in module_type_dict:
                raise RuntimeError(
                    f"Cannot find the type of module {fqn} in the exported program"
                )
            typename = module_type_dict[fqn]

            if serializer_cls.requires_children(typename):
                # Recurse with the current fqn as the new parent: this yields
                # the current module's children and its next sibling.
                next_fqn, next_module = _next_or_none(generator)
                grand_children, next_fqn, next_module = _deserialize_node(
                    fqn,
                    next_fqn,
                    next_module,
                    generator,
                    serializer_cls,
                    module_type_dict,
                    fqn_to_new_module,
                    device,
                )
                # deserialize the current module with its children
                deserialized_module = serializer_cls.deserialize(
                    serialized_module,
                    typename,
                    device=device,
                    children=grand_children,
                )
            else:
                # children are not needed: deserialize first, then move on
                deserialized_module = serializer_cls.deserialize(
                    serialized_module, typename, device=device
                )
                next_fqn, next_module = _next_or_none(generator)

            # register the deserialized module under its relative and full fqn
            children[fqn[len(child_prefix) :]] = deserialized_module
            fqn_to_new_module[fqn] = deserialized_module
        else:
            # current node doesn't require deserialization, move on
            next_fqn, next_module = _next_or_none(generator)
        # move to the next node
        fqn, module = next_fqn, next_module
    return children, fqn, module


def deserialize_embedding_modules(
ep: ExportedProgram,
serializer_cls: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS,
Expand All @@ -59,29 +154,26 @@ def deserialize_embedding_modules(
Returns the unflattened ExportedProgram with the deserialized modules.
"""
model = torch.export.unflatten(ep)
module_type_dict = {}
module_type_dict: Dict[str, str] = {}
for node in ep.graph.nodes:
if "nn_module_stack" in node.meta:
for fqn, type_name in node.meta["nn_module_stack"].values():
# Only get the module type name, not the full type name
module_type_dict[fqn] = type_name.split(".")[-1]

fqn_to_new_module = {}
for fqn, module in model.named_modules():
if "ir_metadata" in dict(module.named_buffers()):
serialized_module = dict(module.named_buffers())["ir_metadata"]

if fqn not in module_type_dict:
raise RuntimeError(
f"Cannot find the type of module {fqn} in the exported program"
)

deserialized_module = serializer_cls.deserialize(
serialized_module,
module_type_dict[fqn],
device,
)
fqn_to_new_module[fqn] = deserialized_module
fqn_to_new_module: Dict[str, nn.Module] = {}
generator = model.named_modules()
fqn, root = _next_or_none(generator)
_deserialize_node(
"",
fqn,
root,
generator,
serializer_cls,
module_type_dict,
fqn_to_new_module,
device,
)

for fqn, new_module in fqn_to_new_module.items():
# handle nested attribute like "x.y.z"
Expand Down

0 comments on commit 11726d3

Please sign in to comment.