Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix deserialization of nested fqn + serialization returns fqns for preserve_module_signature #1906

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions torchrec/ir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

#!/usr/bin/env python3

from typing import Type
from typing import List, Tuple, Type

import torch

Expand All @@ -25,19 +23,34 @@
def serialize_embedding_modules(
model: nn.Module,
serializer_cls: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS,
) -> nn.Module:
for _, module in model.named_modules():
) -> Tuple[nn.Module, List[str]]:
"""
Takes all the modules that are of type `serializer_cls` and serializes them
in the given format with a registered buffer to the module.

Returns the modified module and the list of fqns that had the buffer added.
"""
preserve_fqns = []
for fqn, module in model.named_modules():
if type(module).__name__ in serializer_cls.module_to_serializer_cls:
serialized_module = serializer_cls.serialize(module)
module.register_buffer("ir_metadata", serialized_module, persistent=False)
preserve_fqns.append(fqn)

return model
return model, preserve_fqns


def deserialize_embedding_modules(
ep: ExportedProgram,
serializer_cls: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS,
) -> nn.Module:
"""
Takes ExportedProgram (IR) and looks for ir_metadata buffer.
If found, deserializes the buffer and replaces the module with the deserialized
module.

Returns the unflattened ExportedProgram with the deserialized modules.
"""
model = torch.export.unflatten(ep)
module_type_dict = {}
for node in ep.graph.nodes:
Expand All @@ -62,6 +75,11 @@ def deserialize_embedding_modules(
fqn_to_new_module[fqn] = deserialized_module

for fqn, new_module in fqn_to_new_module.items():
setattr(model, fqn, new_module)
# handle nested attribute like "x.y.z"
attrs = fqn.split(".")
parent = model
for a in attrs[:-1]:
parent = getattr(parent, a)
setattr(parent, attrs[-1], new_module)

return model
Loading