Skip to content

Commit

Permalink
support serialization of compound module
Browse files Browse the repository at this point in the history
Summary:
# context
* to support compound module serialization such as fpEBC and fpPEA, we modified the serializer interface a little
* the basic idea is that:
a. if a compound module A consists of child module B and C, such that module B and C have their own serializer available. 
b. for serialization of module A, we can just capture it relation from B and C, and used it when deserialization
c. specifically, during the deserialization, A's children will be passed in as a dict of (child_fqn => child_module)

# note
* in order to apply this approach, it requires that module A's construction **takes in it's children as objects**
* **DO NOT** create A's children in the A's construction

# details
* serialization schedule
```
comp.ebc does NOT require further serialization of its children
comp.ebc.embedding_bags is skipped for further serialization
comp.ebc.embedding_bags.t1 is skipped for further serialization
comp.ebc.embedding_bags.t2 is skipped for further serialization
comp.ebc.embedding_bags.t3 is skipped for further serialization
comp.comp is resumed for serialization
comp.comp Requires further serialization of its children
comp.comp.ebc does NOT require further serialization of its children
comp.comp.ebc.embedding_bags is skipped for further serialization
comp.comp.ebc.embedding_bags.t1 is skipped for further serialization
comp.comp.ebc.embedding_bags.t2 is skipped for further serialization
comp.comp.ebc.embedding_bags.t3 is skipped for further serialization
comp.comp.comp is resumed for serialization
comp.comp.comp Requires further serialization of its children
comp.comp.comp.ebc does NOT require further serialization of its children
comp.comp.comp.ebc.embedding_bags is skipped for further serialization
comp.comp.comp.ebc.embedding_bags.t1 is skipped for further serialization
comp.comp.comp.ebc.embedding_bags.t2 is skipped for further serialization
comp.comp.comp.ebc.embedding_bags.t3 is skipped for further serialization
```
* `parent_fqn`'s children
```
comp.comp.comp {'ebc': EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (t1): EmbeddingBag(10, 4, mode='sum')
    (t2): EmbeddingBag(10, 4, mode='sum')
    (t3): EmbeddingBag(10, 4, mode='sum')
  )
)} None None
```
```
comp.comp {'ebc': EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (t1): EmbeddingBag(10, 4, mode='sum')
    (t2): EmbeddingBag(10, 4, mode='sum')
    (t3): EmbeddingBag(10, 4, mode='sum')
  )
), 'comp': CompoundModule(
  (ebc): EmbeddingBagCollection(
    (embedding_bags): ModuleDict(
      (t1): EmbeddingBag(10, 4, mode='sum')
      (t2): EmbeddingBag(10, 4, mode='sum')
      (t3): EmbeddingBag(10, 4, mode='sum')
    )
  )
)} None None
```
```
comp {'ebc': EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (t1): EmbeddingBag(10, 4, mode='sum')
    (t2): EmbeddingBag(10, 4, mode='sum')
    (t3): EmbeddingBag(10, 4, mode='sum')
  )
), 'comp': CompoundModule(
  (ebc): EmbeddingBagCollection(
    (embedding_bags): ModuleDict(
      (t1): EmbeddingBag(10, 4, mode='sum')
      (t2): EmbeddingBag(10, 4, mode='sum')
      (t3): EmbeddingBag(10, 4, mode='sum')
    )
  )
  (comp): CompoundModule(
    (ebc): EmbeddingBagCollection(
      (embedding_bags): ModuleDict(
        (t1): EmbeddingBag(10, 4, mode='sum')
        (t2): EmbeddingBag(10, 4, mode='sum')
        (t3): EmbeddingBag(10, 4, mode='sum')
      )
    )
  )
)} None None
```

Differential Revision: D58221182
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Jun 6, 2024
1 parent fce86a4 commit 6784155
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 25 deletions.
18 changes: 15 additions & 3 deletions torchrec/ir/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,11 @@ def serialize(

@classmethod
def deserialize(
cls, input: torch.Tensor, typename: str, device: Optional[torch.device] = None
cls,
input: torch.Tensor,
typename: str,
device: Optional[torch.device] = None,
children: Dict[str, nn.Module] = {},
) -> nn.Module:
if typename != "EmbeddingBagCollection":
raise ValueError(
Expand Down Expand Up @@ -152,13 +156,21 @@ def serialize(

@classmethod
def deserialize(
cls, input: torch.Tensor, typename: str, device: Optional[torch.device] = None
cls,
input: torch.Tensor,
typename: str,
device: Optional[torch.device] = None,
children: Dict[str, nn.Module] = {},
) -> nn.Module:
if typename not in cls.module_to_serializer_cls:
raise ValueError(
f"Expected typename to be one of {list(cls.module_to_serializer_cls.keys())}, got {typename}"
)

return cls.module_to_serializer_cls[typename].deserialize(
input, typename, device
input, typename, device, children
)

@classmethod
def requires_children(cls, typename: str) -> bool:
return cls.module_to_serializer_cls[typename].requires_children(typename)
69 changes: 65 additions & 4 deletions torchrec/ir/tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import copy
import unittest
from typing import Callable, List, Optional
from typing import Callable, Dict, List, Optional, Union

import torch
from torch import nn
Expand All @@ -33,20 +33,64 @@

class CompoundModule(nn.Module):
def __init__(
self, ebc: EmbeddingBagCollection, comp: Optional["CompoundModule"] = None
self,
ebc: EmbeddingBagCollection,
comp: Optional["CompoundModule"] = None,
mlist: List[Union[EmbeddingBagCollection, "CompoundModule"]] = [],
) -> None:
super().__init__()
self.ebc = ebc
self.comp = comp
self.list = nn.ModuleList(mlist)

def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]:
res = self.comp(features) if self.comp else []
res.append(self.ebc(features).values())
for m in self.list:
if isinstance(m, CompoundModule):
res.extend(m(features))
else:
res.append(m(features).values())
return res


class CompoundModuleSerializer(SerializerInterface):
pass
@classmethod
def requires_children(cls, typename: str) -> bool:
return True

@classmethod
def serialize(
cls,
module: nn.Module,
) -> torch.Tensor:
if not isinstance(module, CompoundModule):
raise ValueError(
f"Expected module to be of type CompoundModule, got {type(module)}"
)
return torch.empty(0)

@classmethod
def deserialize(
cls,
input: torch.Tensor,
typename: str,
device: Optional[torch.device] = None,
children: Dict[str, nn.Module] = {},
) -> nn.Module:
if typename != "CompoundModule":
raise ValueError(f"Expected typename to be CompoundModule, got {typename}")
ebc = children["ebc"]
comp = children.get("comp")
assert isinstance(ebc, EmbeddingBagCollection)
if comp is not None:
assert isinstance(comp, CompoundModule)
i = 0
mlist = []
while f"list.{i}" in children:
mlist.append(children[f"list.{i}"])
i += 1
return CompoundModule(ebc, comp, mlist)


class TestJsonSerializer(unittest.TestCase):
Expand Down Expand Up @@ -300,7 +344,21 @@ def test_compound_module(self) -> None:
is_weighted=False,
)

model = CompoundModule(ebc(), CompoundModule(ebc(), CompoundModule(ebc())))
class MyModel(nn.Module):
def __init__(self, comp: CompoundModule) -> None:
super().__init__()
self.comp = comp

def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]:
return self.comp(features)

model = MyModel(
CompoundModule(
ebc=ebc(),
comp=CompoundModule(ebc(), CompoundModule(ebc(), mlist=[ebc(), ebc()])),
mlist=[ebc(), CompoundModule(ebc(), CompoundModule(ebc()))],
)
)
id_list_features = KeyedJaggedTensor.from_offsets_sync(
keys=["f1", "f2", "f3"],
values=torch.tensor([0, 1, 2, 3, 2, 3]),
Expand All @@ -309,6 +367,9 @@ def test_compound_module(self) -> None:

eager_out = model(id_list_features)

JsonSerializer.module_to_serializer_cls["CompoundModule"] = (
CompoundModuleSerializer
)
# Serialize
model, sparse_fqns = serialize_embedding_modules(model, JsonSerializer)
ep = torch.export.export(
Expand Down
5 changes: 5 additions & 0 deletions torchrec/ir/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ def deserialize(
input: Any,
typename: str,
device: Optional[torch.device] = None,
children: Dict[str, nn.Module] = {},
) -> nn.Module:
# Take the bytes in the buffer and regenerate the eager embedding module
pass

@classmethod
def requires_children(cls, typename: str) -> bool:
return False
128 changes: 110 additions & 18 deletions torchrec/ir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#!/usr/bin/env python3

from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Type, Union
from typing import Dict, Iterator, List, Optional, Tuple, Type, Union

import torch

Expand All @@ -36,16 +36,111 @@ def serialize_embedding_modules(
Returns the modified module and the list of fqns that had the buffer added.
"""
# store the fqns of the modules that is serialized, so that the ir_export can preserve
# the fqns of those modules
preserve_fqns = []

# store the fqn of the module that does not require serializing its children,
# so that we can skip based on the children's fqn
skip_children: Optional[str] = None
for fqn, module in model.named_modules():
if skip_children is not None:
if fqn.startswith(skip_children):
# {fqn} is skipped for further serialization
continue
else:
# reset the skip_children to None because fqn is no longer a child
skip_children = None
if type(module).__name__ in serializer_cls.module_to_serializer_cls:
serialized_module = serializer_cls.serialize(module)
# store the fqn of a serialized module that doesn't not require further serialization of its children
if not serializer_cls.requires_children(type(module).__name__):
skip_children = fqn
module.register_buffer("ir_metadata", serialized_module, persistent=False)
preserve_fqns.append(fqn)

return model, preserve_fqns


def _next_or_none(
generator: Iterator[Tuple[str, nn.Module]]
) -> Tuple[Optional[str], Optional[nn.Module]]:
try:
return next(generator)
except StopIteration:
return None, None


def _deserialize_node(
parent_fqn: str,
fqn: Optional[str],
module: Optional[nn.Module],
generator: Iterator[Tuple[str, nn.Module]],
serializer_cls: Type[SerializerInterface],
module_type_dict: Dict[str, str],
fqn_to_new_module: Dict[str, nn.Module],
device: Optional[torch.device] = None,
) -> Tuple[Dict[str, nn.Module], Optional[str], Optional[nn.Module]]:
"""
returns:
1. the children of the parent_fqn Dict[relative_fqn -> module]
2. the next node Optional[fqn], Optional[module], which is not a child of the parent_fqn
"""
children: Dict[str, nn.Module] = {}
# we only starts the while loop when the current fqn is a child of the parent_fqn
# it stops at either the current node is not a child of the parent_fqn or
# the generator is exhausted
while fqn is not None and module is not None and fqn.startswith(parent_fqn):
# the current node is a serialized module, need to deserialize it
if "ir_metadata" in dict(module.named_buffers()):
serialized_module = module.get_buffer("ir_metadata")
if fqn not in module_type_dict:
raise RuntimeError(
f"Cannot find the type of module {fqn} in the exported program"
)

# current module's deserialization requires its children
if serializer_cls.requires_children(module_type_dict[fqn]):
# set current fqn as the new parent_fqn, and call deserialize_node function
# recursively to get the children of current_fqn, and the next sibling of current_fqn
next_fqn, next_module = _next_or_none(generator)
grand_children, next_fqn, next_module = _deserialize_node(
fqn,
next_fqn,
next_module,
generator,
serializer_cls,
module_type_dict,
fqn_to_new_module,
device,
)
# deserialize the current module with its children
deserialized_module = serializer_cls.deserialize(
serialized_module,
module_type_dict[fqn],
device=device,
children=grand_children,
)
else:
# current module's deserialization doesn't require its children
# deserialize it first then get the next sibling
deserialized_module = serializer_cls.deserialize(
serialized_module, module_type_dict[fqn], device=device
)
next_fqn, next_module = _next_or_none(generator)

# register the deserialized module
rel_fqn = fqn[len(parent_fqn) + 1 :] if len(parent_fqn) > 0 else fqn
children[rel_fqn] = deserialized_module
fqn_to_new_module[fqn] = deserialized_module

else: # current node doesn't require deserialization, move on
next_fqn, next_module = _next_or_none(generator)
# move to the next node
fqn, module = next_fqn, next_module
return children, fqn, module


def deserialize_embedding_modules(
ep: ExportedProgram,
serializer_cls: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS,
Expand All @@ -59,29 +154,26 @@ def deserialize_embedding_modules(
Returns the unflattened ExportedProgram with the deserialized modules.
"""
model = torch.export.unflatten(ep)
module_type_dict = {}
module_type_dict: Dict[str, str] = {}
for node in ep.graph.nodes:
if "nn_module_stack" in node.meta:
for fqn, type_name in node.meta["nn_module_stack"].values():
# Only get the module type name, not the full type name
module_type_dict[fqn] = type_name.split(".")[-1]

fqn_to_new_module = {}
for fqn, module in model.named_modules():
if "ir_metadata" in dict(module.named_buffers()):
serialized_module = dict(module.named_buffers())["ir_metadata"]

if fqn not in module_type_dict:
raise RuntimeError(
f"Cannot find the type of module {fqn} in the exported program"
)

deserialized_module = serializer_cls.deserialize(
serialized_module,
module_type_dict[fqn],
device,
)
fqn_to_new_module[fqn] = deserialized_module
fqn_to_new_module: Dict[str, nn.Module] = {}
generator = model.named_modules()
fqn, root = _next_or_none(generator)
_deserialize_node(
"",
fqn,
root,
generator,
serializer_cls,
module_type_dict,
fqn_to_new_module,
device,
)

for fqn, new_module in fqn_to_new_module.items():
# handle nested attribute like "x.y.z"
Expand Down

0 comments on commit 6784155

Please sign in to comment.