From 11726d387a21dc2711c83e57d4c7ddd10f9c989f Mon Sep 17 00:00:00 2001 From: Huanyu He Date: Thu, 13 Jun 2024 10:07:30 -0700 Subject: [PATCH] support serialization of compound module Summary: # context * to support compound module serialization such as fpEBC and fpPEA, we modified the serializer interface a little * the basic idea is that: a. if a compound module A consists of child module B and C, such that module B and C have their own serializer available. b. for serialization of module A, we can just capture it relation from B and C, and used it when deserialization c. specifically, during the deserialization, A's children will be passed in as a dict of (child_fqn => child_module) # design doc [**TorchRec Composable Serializer Design**](https://docs.google.com/document/d/1WUtmzdcqZmwLd4Do8g1fQRjChnRw0ZimyUrCNhxa4nA/edit#heading=h.ezrtdguw0lwq) # note * in order to apply this approach, it requires that module A's construction **takes in it's children as objects** * **DO NOT** create A's children in the A's construction # details * serialization schedule ``` comp.ebc does NOT require further serialization of its children comp.ebc.embedding_bags is skipped for further serialization comp.ebc.embedding_bags.t1 is skipped for further serialization comp.ebc.embedding_bags.t2 is skipped for further serialization comp.ebc.embedding_bags.t3 is skipped for further serialization comp.comp is resumed for serialization comp.comp Requires further serialization of its children comp.comp.ebc does NOT require further serialization of its children comp.comp.ebc.embedding_bags is skipped for further serialization comp.comp.ebc.embedding_bags.t1 is skipped for further serialization comp.comp.ebc.embedding_bags.t2 is skipped for further serialization comp.comp.ebc.embedding_bags.t3 is skipped for further serialization comp.comp.comp is resumed for serialization comp.comp.comp Requires further serialization of its children comp.comp.comp.ebc does NOT require further serialization of its children comp.comp.comp.ebc.embedding_bags is skipped for further serialization comp.comp.comp.ebc.embedding_bags.t1 is skipped for further serialization comp.comp.comp.ebc.embedding_bags.t2 is skipped for further serialization comp.comp.comp.ebc.embedding_bags.t3 is skipped for further serialization ``` * `parent_fqn`'s children ``` comp.comp.comp {'ebc': EmbeddingBagCollection( (embedding_bags): ModuleDict( (t1): EmbeddingBag(10, 4, mode='sum') (t2): EmbeddingBag(10, 4, mode='sum') (t3): EmbeddingBag(10, 4, mode='sum') ) )} None None ``` ``` comp.comp {'ebc': EmbeddingBagCollection( (embedding_bags): ModuleDict( (t1): EmbeddingBag(10, 4, mode='sum') (t2): EmbeddingBag(10, 4, mode='sum') (t3): EmbeddingBag(10, 4, mode='sum') ) ), 'comp': CompoundModule( (ebc): EmbeddingBagCollection( (embedding_bags): ModuleDict( (t1): EmbeddingBag(10, 4, mode='sum') (t2): EmbeddingBag(10, 4, mode='sum') (t3): EmbeddingBag(10, 4, mode='sum') ) ) )} None None ``` ``` comp {'ebc': EmbeddingBagCollection( (embedding_bags): ModuleDict( (t1): EmbeddingBag(10, 4, mode='sum') (t2): EmbeddingBag(10, 4, mode='sum') (t3): EmbeddingBag(10, 4, mode='sum') ) ), 'comp': CompoundModule( (ebc): EmbeddingBagCollection( (embedding_bags): ModuleDict( (t1): EmbeddingBag(10, 4, mode='sum') (t2): EmbeddingBag(10, 4, mode='sum') (t3): EmbeddingBag(10, 4, mode='sum') ) ) (comp): CompoundModule( (ebc): EmbeddingBagCollection( (embedding_bags): ModuleDict( (t1): EmbeddingBag(10, 4, mode='sum') (t2): EmbeddingBag(10, 4, mode='sum') (t3): EmbeddingBag(10, 4, mode='sum') ) ) ) )} None None ``` Differential Revision: D58221182 --- torchrec/ir/serializer.py | 17 +++- torchrec/ir/tests/test_serializer.py | 53 ++++++++++- torchrec/ir/types.py | 5 ++ torchrec/ir/utils.py | 128 +++++++++++++++++++++++---- 4 files changed, 182 insertions(+), 21 deletions(-) diff --git a/torchrec/ir/serializer.py b/torchrec/ir/serializer.py index 71b66dae8..6af6aa8f8 100644 --- a/torchrec/ir/serializer.py +++ b/torchrec/ir/serializer.py @@ -71,6 +71,11 @@ def get_deserialized_device( class JsonSerializerBase(SerializerInterface): _module_cls: Optional[Type[nn.Module]] = None + _requires_children: bool = False + + @classmethod + def requires_children(cls, typename: str) -> bool: + return cls._requires_children @classmethod def serialize_to_dict(cls, module: nn.Module) -> Dict[str, Any]: @@ -81,6 +86,7 @@ def deserialize_from_dict( cls, metadata_dict: Dict[str, Any], device: Optional[torch.device] = None, + children: Dict[str, nn.Module] = {}, ) -> nn.Module: raise NotImplementedError() @@ -106,10 +112,11 @@ def deserialize( input: torch.Tensor, typename: str, device: Optional[torch.device] = None, + children: Dict[str, nn.Module] = {}, ) -> nn.Module: raw_bytes = input.numpy().tobytes() metadata_dict = json.loads(raw_bytes.decode()) - module = cls.deserialize_from_dict(metadata_dict, device) + module = cls.deserialize_from_dict(metadata_dict, device, children) if cls._module_cls is None: raise ValueError( "Must assign a nn.Module to class static variable _module_cls" @@ -148,6 +155,7 @@ def deserialize_from_dict( cls, metadata_dict: Dict[str, Any], device: Optional[torch.device] = None, + children: Dict[str, nn.Module] = {}, ) -> nn.Module: tables = [ EmbeddingBagConfigMetadata(**table_config) @@ -192,6 +200,7 @@ def deserialize( input: torch.Tensor, typename: str, device: Optional[torch.device] = None, + children: Dict[str, nn.Module] = {}, ) -> nn.Module: if typename not in cls.module_to_serializer_cls: raise ValueError( @@ -199,5 +208,9 @@ def deserialize( ) return cls.module_to_serializer_cls[typename].deserialize( - input, typename, device + input, typename, device, children ) + + @classmethod + def requires_children(cls, typename: str) -> bool: + return cls.module_to_serializer_cls[typename].requires_children(typename) diff --git a/torchrec/ir/tests/test_serializer.py b/torchrec/ir/tests/test_serializer.py index f27ae0d94..80237054a 100644 --- a/torchrec/ir/tests/test_serializer.py +++ b/torchrec/ir/tests/test_serializer.py @@ -11,11 +11,12 @@ import copy import unittest -from typing import Callable, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union import torch from torch import nn from torchrec.ir.serializer import JsonSerializer +from torchrec.ir.types import SerializerInterface from torchrec.ir.utils import ( deserialize_embedding_modules, @@ -54,6 +55,45 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]: return res +class CompoundModuleSerializer(SerializerInterface): + @classmethod + def requires_children(cls, typename: str) -> bool: + return True + + @classmethod + def serialize( + cls, + module: nn.Module, + ) -> torch.Tensor: + if not isinstance(module, CompoundModule): + raise ValueError( + f"Expected module to be of type CompoundModule, got {type(module)}" + ) + return torch.empty(0) + + @classmethod + def deserialize( + cls, + input: torch.Tensor, + typename: str, + device: Optional[torch.device] = None, + children: Dict[str, nn.Module] = {}, + ) -> nn.Module: + if typename != "CompoundModule": + raise ValueError(f"Expected typename to be CompoundModule, got {typename}") + ebc = children["ebc"] + comp = children.get("comp") + assert isinstance(ebc, EmbeddingBagCollection) + if comp is not None: + assert isinstance(comp, CompoundModule) + i = 0 + mlist = [] + while f"list.{i}" in children: + mlist.append(children[f"list.{i}"]) + i += 1 + return CompoundModule(ebc, comp, mlist) + + class TestJsonSerializer(unittest.TestCase): def generate_model(self) -> nn.Module: class Model(nn.Module): @@ -328,6 +368,9 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]: eager_out = model(id_list_features) + JsonSerializer.module_to_serializer_cls["CompoundModule"] = ( + CompoundModuleSerializer + ) # Serialize model, sparse_fqns = serialize_embedding_modules(model, JsonSerializer) ep = torch.export.export( @@ -346,6 +389,14 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]: # Deserialize deserialized_model = deserialize_embedding_modules(ep, JsonSerializer) + + # Check if Compound Module is deserialized correctly + self.assertIsInstance(deserialized_model.comp, CompoundModule) + self.assertIsInstance(deserialized_model.comp.comp, CompoundModule) + self.assertIsInstance(deserialized_model.comp.comp.comp, CompoundModule) + self.assertIsInstance(deserialized_model.comp.list[1], CompoundModule) + self.assertIsInstance(deserialized_model.comp.list[1].comp, CompoundModule) + deserialized_model.load_state_dict(model.state_dict()) # Run forward on deserialized model deserialized_out = deserialized_model(id_list_features) diff --git a/torchrec/ir/types.py b/torchrec/ir/types.py index 397a4aa98..2408ce0ab 100644 --- a/torchrec/ir/types.py +++ b/torchrec/ir/types.py @@ -46,6 +46,11 @@ def deserialize( input: Any, typename: str, device: Optional[torch.device] = None, + children: Dict[str, nn.Module] = {}, ) -> nn.Module: # Take the bytes in the buffer and regenerate the eager embedding module pass + + @classmethod + def requires_children(cls, typename: str) -> bool: + return False diff --git a/torchrec/ir/utils.py b/torchrec/ir/utils.py index a71595134..4f268fe7f 100644 --- a/torchrec/ir/utils.py +++ b/torchrec/ir/utils.py @@ -10,7 +10,7 @@ #!/usr/bin/env python3 from collections import defaultdict -from typing import Dict, List, Optional, Tuple, Type, Union +from typing import Dict, Iterator, List, Optional, Tuple, Type, Union import torch @@ -36,16 +36,111 @@ def serialize_embedding_modules( Returns the modified module and the list of fqns that had the buffer added. """ + # store the fqns of the modules that is serialized, so that the ir_export can preserve + # the fqns of those modules preserve_fqns = [] + + # store the fqn of the module that does not require serializing its children, + # so that we can skip based on the children's fqn + skip_children: Optional[str] = None for fqn, module in model.named_modules(): + if skip_children is not None: + if fqn.startswith(skip_children): + # {fqn} is skipped for further serialization + continue + else: + # reset the skip_children to None because fqn is no longer a child + skip_children = None if type(module).__name__ in serializer_cls.module_to_serializer_cls: serialized_module = serializer_cls.serialize(module) + # store the fqn of a serialized module that doesn't not require further serialization of its children + if not serializer_cls.requires_children(type(module).__name__): + skip_children = fqn module.register_buffer("ir_metadata", serialized_module, persistent=False) preserve_fqns.append(fqn) return model, preserve_fqns +def _next_or_none( + generator: Iterator[Tuple[str, nn.Module]] +) -> Tuple[Optional[str], Optional[nn.Module]]: + try: + return next(generator) + except StopIteration: + return None, None + + +def _deserialize_node( + parent_fqn: str, + fqn: Optional[str], + module: Optional[nn.Module], + generator: Iterator[Tuple[str, nn.Module]], + serializer_cls: Type[SerializerInterface], + module_type_dict: Dict[str, str], + fqn_to_new_module: Dict[str, nn.Module], + device: Optional[torch.device] = None, +) -> Tuple[Dict[str, nn.Module], Optional[str], Optional[nn.Module]]: + """ + returns: + 1. the children of the parent_fqn Dict[relative_fqn -> module] + 2. the next node Optional[fqn], Optional[module], which is not a child of the parent_fqn + """ + children: Dict[str, nn.Module] = {} + # we only starts the while loop when the current fqn is a child of the parent_fqn + # it stops at either the current node is not a child of the parent_fqn or + # the generator is exhausted + while fqn is not None and module is not None and fqn.startswith(parent_fqn): + # the current node is a serialized module, need to deserialize it + if "ir_metadata" in dict(module.named_buffers()): + serialized_module = module.get_buffer("ir_metadata") + if fqn not in module_type_dict: + raise RuntimeError( + f"Cannot find the type of module {fqn} in the exported program" + ) + + # current module's deserialization requires its children + if serializer_cls.requires_children(module_type_dict[fqn]): + # set current fqn as the new parent_fqn, and call deserialize_node function + # recursively to get the children of current_fqn, and the next sibling of current_fqn + next_fqn, next_module = _next_or_none(generator) + grand_children, next_fqn, next_module = _deserialize_node( + fqn, + next_fqn, + next_module, + generator, + serializer_cls, + module_type_dict, + fqn_to_new_module, + device, + ) + # deserialize the current module with its children + deserialized_module = serializer_cls.deserialize( + serialized_module, + module_type_dict[fqn], + device=device, + children=grand_children, + ) + else: + # current module's deserialization doesn't require its children + # deserialize it first then get the next sibling + deserialized_module = serializer_cls.deserialize( + serialized_module, module_type_dict[fqn], device=device + ) + next_fqn, next_module = _next_or_none(generator) + + # register the deserialized module + rel_fqn = fqn[len(parent_fqn) + 1 :] if len(parent_fqn) > 0 else fqn + children[rel_fqn] = deserialized_module + fqn_to_new_module[fqn] = deserialized_module + + else: # current node doesn't require deserialization, move on + next_fqn, next_module = _next_or_none(generator) + # move to the next node + fqn, module = next_fqn, next_module + return children, fqn, module + + def deserialize_embedding_modules( ep: ExportedProgram, serializer_cls: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS, @@ -59,29 +154,26 @@ def deserialize_embedding_modules( Returns the unflattened ExportedProgram with the deserialized modules. """ model = torch.export.unflatten(ep) - module_type_dict = {} + module_type_dict: Dict[str, str] = {} for node in ep.graph.nodes: if "nn_module_stack" in node.meta: for fqn, type_name in node.meta["nn_module_stack"].values(): # Only get the module type name, not the full type name module_type_dict[fqn] = type_name.split(".")[-1] - fqn_to_new_module = {} - for fqn, module in model.named_modules(): - if "ir_metadata" in dict(module.named_buffers()): - serialized_module = dict(module.named_buffers())["ir_metadata"] - - if fqn not in module_type_dict: - raise RuntimeError( - f"Cannot find the type of module {fqn} in the exported program" - ) - - deserialized_module = serializer_cls.deserialize( - serialized_module, - module_type_dict[fqn], - device, - ) - fqn_to_new_module[fqn] = deserialized_module + fqn_to_new_module: Dict[str, nn.Module] = {} + generator = model.named_modules() + fqn, root = _next_or_none(generator) + _deserialize_node( + "", + fqn, + root, + generator, + serializer_cls, + module_type_dict, + fqn_to_new_module, + device, + ) for fqn, new_module in fqn_to_new_module.items(): # handle nested attribute like "x.y.z"