From 6784155d73c70968e787628bb7e1407eb2d5f71b Mon Sep 17 00:00:00 2001 From: Huanyu He Date: Thu, 6 Jun 2024 13:29:16 -0700 Subject: [PATCH] support serialization of compound module Summary: # context * to support compound module serialization such as fpEBC and fpPEA, we modified the serializer interface a little * the basic idea is that: a. if a compound module A consists of child module B and C, such that module B and C have their own serializer available. b. for serialization of module A, we can just capture it relation from B and C, and used it when deserialization c. specifically, during the deserialization, A's children will be passed in as a dict of (child_fqn => child_module) # note * in order to apply this approach, it requires that module A's construction **takes in it's children as objects** * **DO NOT** create A's children in the A's construction # details * serialization schedule ``` comp.ebc does NOT require further serialization of its children comp.ebc.embedding_bags is skipped for further serialization comp.ebc.embedding_bags.t1 is skipped for further serialization comp.ebc.embedding_bags.t2 is skipped for further serialization comp.ebc.embedding_bags.t3 is skipped for further serialization comp.comp is resumed for serialization comp.comp Requires further serialization of its children comp.comp.ebc does NOT require further serialization of its children comp.comp.ebc.embedding_bags is skipped for further serialization comp.comp.ebc.embedding_bags.t1 is skipped for further serialization comp.comp.ebc.embedding_bags.t2 is skipped for further serialization comp.comp.ebc.embedding_bags.t3 is skipped for further serialization comp.comp.comp is resumed for serialization comp.comp.comp Requires further serialization of its children comp.comp.comp.ebc does NOT require further serialization of its children comp.comp.comp.ebc.embedding_bags is skipped for further serialization comp.comp.comp.ebc.embedding_bags.t1 is skipped for further serialization comp.comp.comp.ebc.embedding_bags.t2 is skipped for further serialization comp.comp.comp.ebc.embedding_bags.t3 is skipped for further serialization ``` * `parent_fqn`'s children ``` comp.comp.comp {'ebc': EmbeddingBagCollection( (embedding_bags): ModuleDict( (t1): EmbeddingBag(10, 4, mode='sum') (t2): EmbeddingBag(10, 4, mode='sum') (t3): EmbeddingBag(10, 4, mode='sum') ) )} None None ``` ``` comp.comp {'ebc': EmbeddingBagCollection( (embedding_bags): ModuleDict( (t1): EmbeddingBag(10, 4, mode='sum') (t2): EmbeddingBag(10, 4, mode='sum') (t3): EmbeddingBag(10, 4, mode='sum') ) ), 'comp': CompoundModule( (ebc): EmbeddingBagCollection( (embedding_bags): ModuleDict( (t1): EmbeddingBag(10, 4, mode='sum') (t2): EmbeddingBag(10, 4, mode='sum') (t3): EmbeddingBag(10, 4, mode='sum') ) ) )} None None ``` ``` comp {'ebc': EmbeddingBagCollection( (embedding_bags): ModuleDict( (t1): EmbeddingBag(10, 4, mode='sum') (t2): EmbeddingBag(10, 4, mode='sum') (t3): EmbeddingBag(10, 4, mode='sum') ) ), 'comp': CompoundModule( (ebc): EmbeddingBagCollection( (embedding_bags): ModuleDict( (t1): EmbeddingBag(10, 4, mode='sum') (t2): EmbeddingBag(10, 4, mode='sum') (t3): EmbeddingBag(10, 4, mode='sum') ) ) (comp): CompoundModule( (ebc): EmbeddingBagCollection( (embedding_bags): ModuleDict( (t1): EmbeddingBag(10, 4, mode='sum') (t2): EmbeddingBag(10, 4, mode='sum') (t3): EmbeddingBag(10, 4, mode='sum') ) ) ) )} None None ``` Differential Revision: D58221182 --- torchrec/ir/serializer.py | 18 +++- torchrec/ir/tests/test_serializer.py | 69 ++++++++++++++- torchrec/ir/types.py | 5 ++ torchrec/ir/utils.py | 128 +++++++++++++++++++++++---- 4 files changed, 195 insertions(+), 25 deletions(-) diff --git a/torchrec/ir/serializer.py b/torchrec/ir/serializer.py index 23e2d7a07..0eb232d8c 100644 --- a/torchrec/ir/serializer.py +++ b/torchrec/ir/serializer.py @@ -104,7 +104,11 @@ def serialize( @classmethod def deserialize( - cls, input: torch.Tensor, typename: str, device: Optional[torch.device] = None + cls, + input: torch.Tensor, + typename: str, + device: Optional[torch.device] = None, + children: Dict[str, nn.Module] = {}, ) -> nn.Module: if typename != "EmbeddingBagCollection": raise ValueError( @@ -152,7 +156,11 @@ def serialize( @classmethod def deserialize( - cls, input: torch.Tensor, typename: str, device: Optional[torch.device] = None + cls, + input: torch.Tensor, + typename: str, + device: Optional[torch.device] = None, + children: Dict[str, nn.Module] = {}, ) -> nn.Module: if typename not in cls.module_to_serializer_cls: raise ValueError( @@ -160,5 +168,9 @@ def deserialize( ) return cls.module_to_serializer_cls[typename].deserialize( - input, typename, device + input, typename, device, children ) + + @classmethod + def requires_children(cls, typename: str) -> bool: + return cls.module_to_serializer_cls[typename].requires_children(typename) diff --git a/torchrec/ir/tests/test_serializer.py b/torchrec/ir/tests/test_serializer.py index 31357cf90..9b06c7c1c 100644 --- a/torchrec/ir/tests/test_serializer.py +++ b/torchrec/ir/tests/test_serializer.py @@ -11,7 +11,7 @@ import copy import unittest -from typing import Callable, List, Optional +from typing import Callable, Dict, List, Optional, Union import torch from torch import nn @@ -33,20 +33,64 @@ class CompoundModule(nn.Module): def __init__( - self, ebc: EmbeddingBagCollection, comp: Optional["CompoundModule"] = None + self, + ebc: EmbeddingBagCollection, + comp: Optional["CompoundModule"] = None, + mlist: List[Union[EmbeddingBagCollection, "CompoundModule"]] = [], ) -> None: super().__init__() self.ebc = ebc self.comp = comp + self.list = nn.ModuleList(mlist) def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]: res = self.comp(features) if self.comp else [] res.append(self.ebc(features).values()) + for m in self.list: + if isinstance(m, CompoundModule): + res.extend(m(features)) + else: + res.append(m(features).values()) return res class CompoundModuleSerializer(SerializerInterface): - pass + @classmethod + def requires_children(cls, typename: str) -> bool: + return True + + @classmethod + def serialize( + cls, + module: nn.Module, + ) -> torch.Tensor: + if not isinstance(module, CompoundModule): + raise ValueError( + f"Expected module to be of type CompoundModule, got {type(module)}" + ) + return torch.empty(0) + + @classmethod + def deserialize( + cls, + input: torch.Tensor, + typename: str, + device: Optional[torch.device] = None, + children: Dict[str, nn.Module] = {}, + ) -> nn.Module: + if typename != "CompoundModule": + raise ValueError(f"Expected typename to be CompoundModule, got {typename}") + ebc = children["ebc"] + comp = children.get("comp") + assert isinstance(ebc, EmbeddingBagCollection) + if comp is not None: + assert isinstance(comp, CompoundModule) + i = 0 + mlist = [] + while f"list.{i}" in children: + mlist.append(children[f"list.{i}"]) + i += 1 + return CompoundModule(ebc, comp, mlist) class TestJsonSerializer(unittest.TestCase): @@ -300,7 +344,21 @@ def test_compound_module(self) -> None: is_weighted=False, ) - model = CompoundModule(ebc(), CompoundModule(ebc(), CompoundModule(ebc()))) + class MyModel(nn.Module): + def __init__(self, comp: CompoundModule) -> None: + super().__init__() + self.comp = comp + + def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]: + return self.comp(features) + + model = MyModel( + CompoundModule( + ebc=ebc(), + comp=CompoundModule(ebc(), CompoundModule(ebc(), mlist=[ebc(), ebc()])), + mlist=[ebc(), CompoundModule(ebc(), CompoundModule(ebc()))], + ) + ) id_list_features = KeyedJaggedTensor.from_offsets_sync( keys=["f1", "f2", "f3"], values=torch.tensor([0, 1, 2, 3, 2, 3]), @@ -309,6 +367,9 @@ def test_compound_module(self) -> None: eager_out = model(id_list_features) + JsonSerializer.module_to_serializer_cls["CompoundModule"] = ( + CompoundModuleSerializer + ) # Serialize model, sparse_fqns = serialize_embedding_modules(model, JsonSerializer) ep = torch.export.export( diff --git a/torchrec/ir/types.py b/torchrec/ir/types.py index 397a4aa98..2408ce0ab 100644 --- a/torchrec/ir/types.py +++ b/torchrec/ir/types.py @@ -46,6 +46,11 @@ def deserialize( input: Any, typename: str, device: Optional[torch.device] = None, + children: Dict[str, nn.Module] = {}, ) -> nn.Module: # Take the bytes in the buffer and regenerate the eager embedding module pass + + @classmethod + def requires_children(cls, typename: str) -> bool: + return False diff --git a/torchrec/ir/utils.py b/torchrec/ir/utils.py index 1676f166c..eb72daf99 100644 --- a/torchrec/ir/utils.py +++ b/torchrec/ir/utils.py @@ -10,7 +10,7 @@ #!/usr/bin/env python3 from collections import defaultdict -from typing import Dict, List, Optional, Tuple, Type, Union +from typing import Dict, Iterator, List, Optional, Tuple, Type, Union import torch @@ -36,16 +36,111 @@ def serialize_embedding_modules( Returns the modified module and the list of fqns that had the buffer added. """ + # store the fqns of the modules that is serialized, so that the ir_export can preserve + # the fqns of those modules preserve_fqns = [] + + # store the fqn of the module that does not require serializing its children, + # so that we can skip based on the children's fqn + skip_children: Optional[str] = None for fqn, module in model.named_modules(): + if skip_children is not None: + if fqn.startswith(skip_children): + # {fqn} is skipped for further serialization + continue + else: + # reset the skip_children to None because fqn is no longer a child + skip_children = None if type(module).__name__ in serializer_cls.module_to_serializer_cls: serialized_module = serializer_cls.serialize(module) + # store the fqn of a serialized module that doesn't not require further serialization of its children + if not serializer_cls.requires_children(type(module).__name__): + skip_children = fqn module.register_buffer("ir_metadata", serialized_module, persistent=False) preserve_fqns.append(fqn) return model, preserve_fqns +def _next_or_none( + generator: Iterator[Tuple[str, nn.Module]] +) -> Tuple[Optional[str], Optional[nn.Module]]: + try: + return next(generator) + except StopIteration: + return None, None + + +def _deserialize_node( + parent_fqn: str, + fqn: Optional[str], + module: Optional[nn.Module], + generator: Iterator[Tuple[str, nn.Module]], + serializer_cls: Type[SerializerInterface], + module_type_dict: Dict[str, str], + fqn_to_new_module: Dict[str, nn.Module], + device: Optional[torch.device] = None, +) -> Tuple[Dict[str, nn.Module], Optional[str], Optional[nn.Module]]: + """ + returns: + 1. the children of the parent_fqn Dict[relative_fqn -> module] + 2. the next node Optional[fqn], Optional[module], which is not a child of the parent_fqn + """ + children: Dict[str, nn.Module] = {} + # we only starts the while loop when the current fqn is a child of the parent_fqn + # it stops at either the current node is not a child of the parent_fqn or + # the generator is exhausted + while fqn is not None and module is not None and fqn.startswith(parent_fqn): + # the current node is a serialized module, need to deserialize it + if "ir_metadata" in dict(module.named_buffers()): + serialized_module = module.get_buffer("ir_metadata") + if fqn not in module_type_dict: + raise RuntimeError( + f"Cannot find the type of module {fqn} in the exported program" + ) + + # current module's deserialization requires its children + if serializer_cls.requires_children(module_type_dict[fqn]): + # set current fqn as the new parent_fqn, and call deserialize_node function + # recursively to get the children of current_fqn, and the next sibling of current_fqn + next_fqn, next_module = _next_or_none(generator) + grand_children, next_fqn, next_module = _deserialize_node( + fqn, + next_fqn, + next_module, + generator, + serializer_cls, + module_type_dict, + fqn_to_new_module, + device, + ) + # deserialize the current module with its children + deserialized_module = serializer_cls.deserialize( + serialized_module, + module_type_dict[fqn], + device=device, + children=grand_children, + ) + else: + # current module's deserialization doesn't require its children + # deserialize it first then get the next sibling + deserialized_module = serializer_cls.deserialize( + serialized_module, module_type_dict[fqn], device=device + ) + next_fqn, next_module = _next_or_none(generator) + + # register the deserialized module + rel_fqn = fqn[len(parent_fqn) + 1 :] if len(parent_fqn) > 0 else fqn + children[rel_fqn] = deserialized_module + fqn_to_new_module[fqn] = deserialized_module + + else: # current node doesn't require deserialization, move on + next_fqn, next_module = _next_or_none(generator) + # move to the next node + fqn, module = next_fqn, next_module + return children, fqn, module + + def deserialize_embedding_modules( ep: ExportedProgram, serializer_cls: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS, @@ -59,29 +154,26 @@ def deserialize_embedding_modules( Returns the unflattened ExportedProgram with the deserialized modules. """ model = torch.export.unflatten(ep) - module_type_dict = {} + module_type_dict: Dict[str, str] = {} for node in ep.graph.nodes: if "nn_module_stack" in node.meta: for fqn, type_name in node.meta["nn_module_stack"].values(): # Only get the module type name, not the full type name module_type_dict[fqn] = type_name.split(".")[-1] - fqn_to_new_module = {} - for fqn, module in model.named_modules(): - if "ir_metadata" in dict(module.named_buffers()): - serialized_module = dict(module.named_buffers())["ir_metadata"] - - if fqn not in module_type_dict: - raise RuntimeError( - f"Cannot find the type of module {fqn} in the exported program" - ) - - deserialized_module = serializer_cls.deserialize( - serialized_module, - module_type_dict[fqn], - device, - ) - fqn_to_new_module[fqn] = deserialized_module + fqn_to_new_module: Dict[str, nn.Module] = {} + generator = model.named_modules() + fqn, root = _next_or_none(generator) + _deserialize_node( + "", + fqn, + root, + generator, + serializer_cls, + module_type_dict, + fqn_to_new_module, + device, + ) for fqn, new_module in fqn_to_new_module.items(): # handle nested attribute like "x.y.z"