Skip to content

Commit

Permalink
Fix deserialization of nested fqn + serialization returns fqns for preserve_module_signature (#1906)
Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: #1906

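Return the fqns of all sparse modules from serialization, for use as preserve_module_signature, so that those modules can be validly swapped out for eager modules in the unflattened model. Also fix a bug in deserialization: a nested (dotted) fqn such as "x.y.z" is now resolved to its parent module before the replacement is set, instead of being passed whole to setattr on the root model.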

Reviewed By: IvanKobzarev

Differential Revision: D56361242

fbshipit-source-id: a73966ef6426682dbca2ed2f5b9f014ea7c15c11
  • Loading branch information
PaulZhang12 authored and facebook-github-bot committed Apr 19, 2024
1 parent f120e42 commit e7e7a34
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions torchrec/ir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

#!/usr/bin/env python3

from typing import Type
from typing import List, Tuple, Type

import torch

Expand All @@ -25,19 +23,34 @@
def serialize_embedding_modules(
model: nn.Module,
serializer_cls: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS,
) -> nn.Module:
for _, module in model.named_modules():
) -> Tuple[nn.Module, List[str]]:
"""
Serializes every submodule of `model` whose class name is registered in
`serializer_cls.module_to_serializer_cls`, and attaches the serialized result
to that submodule as a non-persistent buffer named `ir_metadata`.
Returns the same model (modified in place) and the list of fqns of the
submodules that had the buffer added.
"""
preserve_fqns = []
for fqn, module in model.named_modules():
# Matching is by exact class name, not isinstance, so a subclass with a
# different name is only picked up if its own name is registered.
if type(module).__name__ in serializer_cls.module_to_serializer_cls:
serialized_module = serializer_cls.serialize(module)
module.register_buffer("ir_metadata", serialized_module, persistent=False)
preserve_fqns.append(fqn)

return model
return model, preserve_fqns


def deserialize_embedding_modules(
ep: ExportedProgram,
serializer_cls: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS,
) -> nn.Module:
"""
Takes ExportedProgram (IR) and looks for ir_metadata buffer.
If found, deserializes the buffer and replaces the module with the deserialized
module. Replacement targets are addressed by fqn, including nested ones such
as "x.y.z".
Returns the module produced by `torch.export.unflatten(ep)` with the
deserialized modules swapped in.
"""
model = torch.export.unflatten(ep)
module_type_dict = {}
for node in ep.graph.nodes:
Expand All @@ -62,6 +75,11 @@ def deserialize_embedding_modules(
fqn_to_new_module[fqn] = deserialized_module

for fqn, new_module in fqn_to_new_module.items():
setattr(model, fqn, new_module)
# handle nested attribute like "x.y.z": setattr does not resolve dotted
# names, so walk down to the direct parent and set the last component there
attrs = fqn.split(".")
parent = model
for a in attrs[:-1]:
parent = getattr(parent, a)
setattr(parent, attrs[-1], new_module)

return model

0 comments on commit e7e7a34

Please sign in to comment.