add compound module test
Summary:
Context:
* We want to add compound-module support to the IR export function.
* As a first step, this diff adds a test covering a compound module.
* The test validates that the deserialized module produces the same output as the original module (a condensed sketch of the round trip follows).
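
For orientation, here is a condensed sketch of the round trip the new test exercises, built from the serializer calls that appear in the diff below (serialize_embedding_modules, torch.export.export, deserialize_embedding_modules). The import paths and the helper name `roundtrip` are assumptions, not part of the commit.
```
# Condensed sketch, not part of the diff: serialize -> export -> deserialize -> compare.
# Assumes `model` returns a list of tensors (as the test's CompoundModule does) and that
# the serializer utilities live under torchrec.ir as shown; adjust paths if they differ.
import torch

from torchrec.ir.serializer import JsonSerializer
from torchrec.ir.utils import deserialize_embedding_modules, serialize_embedding_modules
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def roundtrip(model: torch.nn.Module, kjt: KeyedJaggedTensor) -> None:
    eager_out = model(kjt)

    # Swap the embedding modules for custom-op stubs, then export to an ExportedProgram.
    model, sparse_fqns = serialize_embedding_modules(model, JsonSerializer)
    ep = torch.export.export(
        model,
        (kjt,),
        {},
        strict=False,
        preserve_module_call_signature=tuple(sparse_fqns),
    )

    # Rebuild the embedding modules from the serialized metadata and compare outputs.
    deserialized = deserialize_embedding_modules(ep, JsonSerializer)
    deserialized.load_state_dict(model.state_dict())
    for x, y in zip(deserialized(kjt), eager_out):
        assert torch.allclose(x, y)
```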

Details:
* graph of the exported program:
```
(Pdb) print(ep.graph)
graph():
    %p_ebc_embedding_bags_t1_weight : [num_users=0] = placeholder[target=p_ebc_embedding_bags_t1_weight]
    %p_ebc_embedding_bags_t2_weight : [num_users=0] = placeholder[target=p_ebc_embedding_bags_t2_weight]
    %p_ebc_embedding_bags_t3_weight : [num_users=0] = placeholder[target=p_ebc_embedding_bags_t3_weight]
    %p_comp_ebc_embedding_bags_t1_weight : [num_users=0] = placeholder[target=p_comp_ebc_embedding_bags_t1_weight]
    %p_comp_ebc_embedding_bags_t2_weight : [num_users=0] = placeholder[target=p_comp_ebc_embedding_bags_t2_weight]
    %p_comp_ebc_embedding_bags_t3_weight : [num_users=0] = placeholder[target=p_comp_ebc_embedding_bags_t3_weight]
    %p_comp_comp_ebc_embedding_bags_t1_weight : [num_users=0] = placeholder[target=p_comp_comp_ebc_embedding_bags_t1_weight]
    %p_comp_comp_ebc_embedding_bags_t2_weight : [num_users=0] = placeholder[target=p_comp_comp_ebc_embedding_bags_t2_weight]
    %p_comp_comp_ebc_embedding_bags_t3_weight : [num_users=0] = placeholder[target=p_comp_comp_ebc_embedding_bags_t3_weight]
    %b_ebc_ir_metadata : [num_users=0] = placeholder[target=b_ebc_ir_metadata]
    %b_comp_ebc_ir_metadata : [num_users=0] = placeholder[target=b_comp_ebc_ir_metadata]
    %b_comp_comp_ebc_ir_metadata : [num_users=0] = placeholder[target=b_comp_comp_ebc_ir_metadata]
    %features__values : [num_users=3] = placeholder[target=features__values]
    %features__weights : [num_users=0] = placeholder[target=features__weights]
    %features__lengths : [num_users=3] = placeholder[target=features__lengths]
    %features__offsets : [num_users=3] = placeholder[target=features__offsets]
    %embedding_bag_collection_12 : [num_users=1] = call_function[target=torch.ops.custom.EmbeddingBagCollection_12.default](args = ([%features__values, None, %features__lengths, %features__offsets], 2), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%embedding_bag_collection_12, 0), kwargs = {})
    %embedding_bag_collection_13 : [num_users=1] = call_function[target=torch.ops.custom.EmbeddingBagCollection_12.default](args = ([%features__values, None, %features__lengths, %features__offsets], 2), kwargs = {})
    %getitem_10 : [num_users=1] = call_function[target=operator.getitem](args = (%embedding_bag_collection_13, 0), kwargs = {})
    %embedding_bag_collection_14 : [num_users=1] = call_function[target=torch.ops.custom.EmbeddingBagCollection_12.default](args = ([%features__values, None, %features__lengths, %features__offsets], 2), kwargs = {})
    %getitem_16 : [num_users=1] = call_function[target=operator.getitem](args = (%embedding_bag_collection_14, 0), kwargs = {})
    return (getitem_4, getitem_10, getitem_16)
```
* tabular view of the same graph:
```
(Pdb) ep.graph.print_tabular()
opcode         name                                      target                                    args                                                                 kwargs
-------------  ----------------------------------------  ----------------------------------------  -------------------------------------------------------------------  --------
placeholder    p_ebc_embedding_bags_t1_weight            p_ebc_embedding_bags_t1_weight            ()                                                                   {}
placeholder    p_ebc_embedding_bags_t2_weight            p_ebc_embedding_bags_t2_weight            ()                                                                   {}
placeholder    p_ebc_embedding_bags_t3_weight            p_ebc_embedding_bags_t3_weight            ()                                                                   {}
placeholder    p_comp_ebc_embedding_bags_t1_weight       p_comp_ebc_embedding_bags_t1_weight       ()                                                                   {}
placeholder    p_comp_ebc_embedding_bags_t2_weight       p_comp_ebc_embedding_bags_t2_weight       ()                                                                   {}
placeholder    p_comp_ebc_embedding_bags_t3_weight       p_comp_ebc_embedding_bags_t3_weight       ()                                                                   {}
placeholder    p_comp_comp_ebc_embedding_bags_t1_weight  p_comp_comp_ebc_embedding_bags_t1_weight  ()                                                                   {}
placeholder    p_comp_comp_ebc_embedding_bags_t2_weight  p_comp_comp_ebc_embedding_bags_t2_weight  ()                                                                   {}
placeholder    p_comp_comp_ebc_embedding_bags_t3_weight  p_comp_comp_ebc_embedding_bags_t3_weight  ()                                                                   {}
placeholder    b_ebc_ir_metadata                         b_ebc_ir_metadata                         ()                                                                   {}
placeholder    b_comp_ebc_ir_metadata                    b_comp_ebc_ir_metadata                    ()                                                                   {}
placeholder    b_comp_comp_ebc_ir_metadata               b_comp_comp_ebc_ir_metadata               ()                                                                   {}
placeholder    features__values                          features__values                          ()                                                                   {}
placeholder    features__weights                         features__weights                         ()                                                                   {}
placeholder    features__lengths                         features__lengths                         ()                                                                   {}
placeholder    features__offsets                         features__offsets                         ()                                                                   {}
call_function  embedding_bag_collection_12               custom.EmbeddingBagCollection_12.default  ([features__values, None, features__lengths, features__offsets], 2)  {}
call_function  getitem_4                                 <built-in function getitem>               (embedding_bag_collection_12, 0)                                     {}
call_function  embedding_bag_collection_13               custom.EmbeddingBagCollection_12.default  ([features__values, None, features__lengths, features__offsets], 2)  {}
call_function  getitem_10                                <built-in function getitem>               (embedding_bag_collection_13, 0)                                     {}
call_function  embedding_bag_collection_14               custom.EmbeddingBagCollection_12.default  ([features__values, None, features__lengths, features__offsets], 2)  {}
call_function  getitem_16                                <built-in function getitem>               (embedding_bag_collection_14, 0)                                     {}
output         output                                    output                                    ((getitem_4, getitem_10, getitem_16),)                               {}
```
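
As both printouts show, each EmbeddingBagCollection that the serializer swapped out becomes a single call_function node targeting the registered custom op (torch.ops.custom.EmbeddingBagCollection_12.default here), with the KJT flattened into values/lengths/offsets placeholders. A hypothetical sanity check, not part of the commit, could count those nodes:
```
# Hypothetical check (not in the diff): count custom-op calls in the exported graph `ep.graph`.
ebc_calls = [
    node
    for node in ep.graph.nodes
    if node.op == "call_function" and "EmbeddingBagCollection" in str(node.target)
]
print(len(ebc_calls))  # 3 in the graph printed above
```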

Differential Revision: D58220170
Huanyu He authored and facebook-github-bot committed Jun 13, 2024
1 parent 8e577df commit 744ba95
Showing 1 changed file: torchrec/ir/tests/test_serializer.py (96 additions, 3 deletions)
@@ -11,7 +11,7 @@

import copy
import unittest
from typing import List
from typing import Callable, List, Optional, Union

import torch
from torch import nn
@@ -31,6 +31,29 @@
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor


class CompoundModule(nn.Module):
    def __init__(
        self,
        ebc: EmbeddingBagCollection,
        comp: Optional["CompoundModule"] = None,
        mlist: List[Union[EmbeddingBagCollection, "CompoundModule"]] = [],
    ) -> None:
        super().__init__()
        self.ebc = ebc
        self.comp = comp
        self.list = nn.ModuleList(mlist)

    def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]:
        res = self.comp(features) if self.comp else []
        res.append(self.ebc(features).values())
        for m in self.list:
            if isinstance(m, CompoundModule):
                res.extend(m(features))
            else:
                res.append(m(features).values())
        return res


class TestJsonSerializer(unittest.TestCase):
    def generate_model(self) -> nn.Module:
        class Model(nn.Module):
@@ -150,8 +173,6 @@ def test_serialize_deserialize_ebc(self) -> None:
                in operator_registry_state.op_registry_schema
            )

        # Can rerun ep forward
        ep.module()(id_list_features)
        # Deserialize EBC
        deserialized_model = deserialize_embedding_modules(ep, JsonSerializer)

@@ -259,3 +280,75 @@ def test_deserialized_device(self) -> None:
            if "_feature_processors" in name:
                continue
            assert param.device.type == device.type, f"{name} should be on {device}"

    def test_compound_module(self) -> None:
        tb1_config = EmbeddingBagConfig(
            name="t1",
            embedding_dim=4,
            num_embeddings=10,
            feature_names=["f1"],
        )
        tb2_config = EmbeddingBagConfig(
            name="t2",
            embedding_dim=4,
            num_embeddings=10,
            feature_names=["f2"],
        )
        tb3_config = EmbeddingBagConfig(
            name="t3",
            embedding_dim=4,
            num_embeddings=10,
            feature_names=["f3"],
        )
        ebc: Callable[[], EmbeddingBagCollection] = lambda: EmbeddingBagCollection(
            tables=[tb1_config, tb2_config, tb3_config],
            is_weighted=False,
        )

        class MyModel(nn.Module):
            def __init__(self, comp: CompoundModule) -> None:
                super().__init__()
                self.comp = comp

            def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]:
                return self.comp(features)

        model = MyModel(
            CompoundModule(
                ebc=ebc(),
                comp=CompoundModule(ebc(), CompoundModule(ebc(), mlist=[ebc(), ebc()])),
                mlist=[ebc(), CompoundModule(ebc(), CompoundModule(ebc()))],
            )
        )
        id_list_features = KeyedJaggedTensor.from_offsets_sync(
            keys=["f1", "f2", "f3"],
            values=torch.tensor([0, 1, 2, 3, 2, 3]),
            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]),
        )

        eager_out = model(id_list_features)

        # Serialize
        model, sparse_fqns = serialize_embedding_modules(model, JsonSerializer)
        ep = torch.export.export(
            model,
            (id_list_features,),
            {},
            strict=False,
            # Allows KJT to not be unflattened and run a forward on unflattened EP
            preserve_module_call_signature=(tuple(sparse_fqns)),
        )

        ep_output = ep.module()(id_list_features)
        self.assertEqual(len(ep_output), len(eager_out))
        for x, y in zip(ep_output, eager_out):
            self.assertEqual(x.shape, y.shape)

        # Deserialize
        deserialized_model = deserialize_embedding_modules(ep, JsonSerializer)
        deserialized_model.load_state_dict(model.state_dict())
        # Run forward on deserialized model
        deserialized_out = deserialized_model(id_list_features)
        self.assertEqual(len(deserialized_out), len(eager_out))
        for x, y in zip(deserialized_out, eager_out):
            self.assertTrue(torch.allclose(x, y))
