Register custom ops for EBC and PEA when doing torch.export
Summary:
# context
* when doing torch.export, embedding modules like PEA (PooledEmbeddingArch) and EBC (EmbeddingBagCollection) are flattened into individual `_embedding_bag` ops, as in the following example (D56282744):
```
(Pdb) ep.graph.print_tabular()
opcode         name                               target                             args                                                                                            kwargs
-------------  ---------------------------------  ---------------------------------  ----------------------------------------------------------------------------------------------  ---------------------
...
call_function  getitem_23                         <built-in function getitem>        (split_with_sizes_2, 1)                                                                         {}
call_function  _embedding_bag                     aten._embedding_bag.default        (p_pea_embedding_modules_t1_weight, getitem_10, getitem_14, False, 0, False, None, True)        {}
call_function  getitem_24                         <built-in function getitem>        (_embedding_bag, 0)                                                                             {}
call_function  _embedding_bag_1                   aten._embedding_bag.default        (p_pea_embedding_modules_t2_weight, getitem_11, getitem_15, False, 0, False, None, True)        {}
call_function  getitem_28                         <built-in function getitem>        (_embedding_bag_1, 0)                                                                           {}
call_function  _embedding_bag_2                   aten._embedding_bag.default        (p_pea_embedding_modules_t3_weight, getitem_16, getitem_20, False, 0, False, getitem_22, True)  {}
call_function  getitem_32                         <built-in function getitem>        (_embedding_bag_2, 0)                                                                           {}
call_function  _embedding_bag_3                   aten._embedding_bag.default        (p_pea_embedding_modules_t4_weight, getitem_17, getitem_21, False, 0, False, getitem_23, True)  {}
call_function  getitem_36                         <built-in function getitem>        (_embedding_bag_3, 0)                                                                           {}
call_function  cat_2                              aten.cat.default                   ([getitem_24, getitem_28], 1)                                                                   {}
call_function  cat_3                              aten.cat.default                   ([getitem_32, getitem_36], 1)                                                                   {}
call_function  cat_4                              aten.cat.default                   ([cat_2, cat_3], 1)                                                                             {}
output         output                             output                             ((cat_4,),)                                                                                     {}
```
* this flattening is unnecessary and expensive, because deserialization of the embedding module is handled by separate logic that does not rely on the flattened schema.
* the solution is to treat the embedding module as a black box (a custom op) in the graph during torch.export, as in the graph below (a short export sketch follows the dump):
```
...
placeholder    w_weights                            w_weights                                         ()                                                                        {}
call_function  pooled_embedding_arch_8734585215502  custom.PooledEmbeddingArch_8734585215502.default  ([values, None, lengths, None, w_values, w_weights, w_lengths, None], 2)  {}
call_function  getitem_10                           <built-in function getitem>                       (pooled_embedding_arch_8734585215502, 0)                                  {}
call_function  getitem_11                           <built-in function getitem>                       (pooled_embedding_arch_8734585215502, 1)                                  {}
call_function  pooled_embedding_arch_8734585231976  custom.PooledEmbeddingArch_8734585231976.default  ([values, None, lengths, None, w_values, w_weights, w_lengths, None], 2)  {}
call_function  getitem_12                           <built-in function getitem>                       (pooled_embedding_arch_8734585231976, 0)                                  {}
call_function  getitem_13                           <built-in function getitem>                       (pooled_embedding_arch_8734585231976, 1)                                  {}
call_function  cat                                  aten.cat.default                                  ([getitem_10, getitem_11, getitem_12, getitem_13], 1)                     {}
output         output                               output                                            ((cat,),)                                                                 {}
```
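
For reference, a minimal sketch of producing and inspecting such a graph with non-strict export (adapted from the unit test added in this diff; the `Wrapper` module, table configs, and feature values here are illustrative):

```
import torch

from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


class Wrapper(torch.nn.Module):
    """Wraps an EBC so the exported program returns a plain tensor."""

    def __init__(self) -> None:
        super().__init__()
        self.ebc = EmbeddingBagCollection(
            tables=[
                EmbeddingBagConfig(
                    name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1", "f3"]
                ),
                EmbeddingBagConfig(
                    name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
                ),
            ],
            is_weighted=True,
        )

    def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
        return self.ebc(features).values()


features = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f3", "f2"],
    values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 3, 4, 7]),
    offsets=torch.tensor([0, 2, 4, 6, 8, 10, 12]),
    weights=torch.tensor([0.1, 0.2, 0.4, 0.5, 0.4, 0.3, 0.2, 0.9, 0.1, 0.3, 0.4, 0.7]),
)

# Non-strict export takes the meta-forward branch, so each EBC shows up as a
# single custom-op node instead of being flattened into _embedding_bag calls.
ep = torch.export.export(Wrapper(), (features,), {}, strict=False)
ep.graph.print_tabular()
```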

# details
* get the output tensor shapes (`List[Tensor]`) from the embedding module in the `_forward_meta` function
* register a custom op whose input is `List[Optional[Tensor]]` and whose output is a `List[Tensor]` with the given shapes, via `register_custom_op`
* call this custom op with the original inputs and return the desired outputs, so that in the graph the custom op shows up as a single node with the correct output shapes
* in the actual forward function of the embedding module, use `is_non_strict_exporting()` and `not torch.jit.is_scripting()` to branch into `_forward_meta` (see the sketch after this list).
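
Condensed, this is the pattern (a simplified sketch of the code in `torchrec/ir/utils.py` and `torchrec/modules/embedding_modules.py` below; helper names and the device-selection shortcut are simplified here):

```
import threading
from typing import Dict, List, Optional

import torch
from torch.library import Library

lib = Library("custom", "FRAGMENT")
_lock = threading.Lock()
_registered_schemas: Dict[str, str] = {}


def register_custom_op(module: torch.nn.Module, dims: List[int]):
    """Register a per-module-instance custom op with (batch_size, dim) outputs."""
    op_name = f"{type(module).__name__}_{hash(module)}"

    def fake_impl(
        values: List[Optional[torch.Tensor]], batch_size: int
    ) -> List[torch.Tensor]:
        # Only shapes and device matter for export; the data is never read.
        device = next((v.device for v in values if v is not None), None)
        return [torch.empty(batch_size, dim, device=device) for dim in dims]

    schema = f"{op_name}(Tensor?[] values, int batch_size) -> Tensor[]"
    with _lock:
        if op_name not in _registered_schemas:
            _registered_schemas[op_name] = schema
            lib.define(schema)
            for backend in ("CPU", "CUDA", "Meta"):
                lib.impl(op_name, fake_impl, backend)
    return getattr(torch.ops.custom, op_name)


# Inside the embedding module's forward, non-strict export branches into a meta
# forward that calls the registered op instead of the real embedding bags:
#
#     if is_non_strict_exporting() and not torch.jit.is_scripting():
#         op = register_custom_op(self, dims=[sum(self._lengths_per_embedding)])
#         outputs = op(
#             [features.values(), features.weights_or_none(),
#              features.lengths_or_none(), features.offsets_or_none()],
#             features.stride(),
#         )
```

The fake kernel only fabricates correctly shaped outputs; that is enough for export to record a single node with the right output metadata, while the real computation is restored when the embedding module is deserialized.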

Differential Revision: D56443608
TroyGarden authored and facebook-github-bot committed Apr 22, 2024
1 parent 3fbd547 commit 9260421
Showing 3 changed files with 212 additions and 2 deletions.
82 changes: 81 additions & 1 deletion torchrec/ir/utils.py
@@ -9,15 +9,35 @@

#!/usr/bin/env python3

from typing import List, Tuple, Type
import threading
from typing import Callable, Dict, List, Optional, Tuple, Type

import torch

from torch import nn
from torch.export.exported_program import ExportedProgram
from torch.library import Library
from torchrec.ir.types import SerializerInterface


lib = Library("custom", "FRAGMENT")


class OpRegistryState:
"""
State of operator registry.
We can only register the op schema once. So if we're registering multiple
times we need a lock and check if they're the same schema
"""

op_registry_lock = threading.Lock()
# operator schema: op_name: schema
op_registry_schema: Dict[str, str] = {}


operator_registry_state = OpRegistryState()

# TODO: Replace the default interface with the python dataclass interface
DEFAULT_SERIALIZER_CLS = SerializerInterface

@@ -85,3 +105,63 @@ def deserialize_embedding_modules(
        setattr(parent, attrs[-1], new_module)

    return model


def register_custom_op(
    module: nn.Module, dims: List[int]
) -> Callable[[List[Optional[torch.Tensor]], int], List[torch.Tensor]]:
    """
    Register a customized operator.
    Args:
        module: customized module instance
        dims: output dimensions
    """

    global operator_registry_state

    op_name = f"{type(module).__name__}_{hash(module)}"
    with operator_registry_state.op_registry_lock:
        if op_name in operator_registry_state.op_registry_schema:
            return getattr(torch.ops.custom, op_name)

    def pea_op(
        values: List[Optional[torch.Tensor]],
        batch_size: int,
    ) -> List[torch.Tensor]:
        device = None
        for v in values:
            if v is not None:
                device = v.device
                break
        else:
            raise AssertionError(
                f"Custom op {type(module).__name__} expects at least one "
                "input tensor"
            )

        return [
            torch.empty(
                batch_size,
                dim,
                device=device,
            )
            for dim in dims
        ]

    schema_string = f"{op_name}(Tensor?[] values, int batch_size) -> Tensor[]"
    with operator_registry_state.op_registry_lock:
        if op_name in operator_registry_state.op_registry_schema:
            return getattr(torch.ops.custom, op_name)
        operator_registry_state.op_registry_schema[op_name] = schema_string
        # Register schema
        lib.define(schema_string)

        # Register implementation
        lib.impl(op_name, pea_op, "CPU")
        lib.impl(op_name, pea_op, "CUDA")

        # Register meta formula
        lib.impl(op_name, pea_op, "Meta")

    return getattr(torch.ops.custom, op_name)
31 changes: 30 additions & 1 deletion torchrec/modules/embedding_modules.py
@@ -12,13 +12,40 @@

import torch
import torch.nn as nn
from torchrec.ir.utils import register_custom_op
from torchrec.modules.embedding_configs import (
    DataType,
    EmbeddingBagConfig,
    EmbeddingConfig,
    pooling_type_to_str,
)
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
from torchrec.sparse.jagged_tensor import (
    is_non_strict_exporting,
    JaggedTensor,
    KeyedJaggedTensor,
    KeyedTensor,
)


def _forward_meta(
ebc: "EmbeddingBagCollectionInterface",
features: KeyedJaggedTensor,
) -> KeyedTensor:
batch_size = features.stride()
arg_list = [
features.values(),
features.weights_or_none(),
features.lengths_or_none(),
features.offsets_or_none(),
]
dims = [sum(ebc._lengths_per_embedding)]
ebc_op = register_custom_op(ebc, dims)
outputs = ebc_op(arg_list, batch_size)
return KeyedTensor(
keys=ebc._embedding_names,
values=outputs[0],
length_per_key=ebc._lengths_per_embedding,
)


@torch.fx.wrap
@@ -212,6 +239,8 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
        Returns:
            KeyedTensor
        """
        if is_non_strict_exporting() and not torch.jit.is_scripting():
            return _forward_meta(self, features)
        flat_feature_names: List[str] = []
        for names in self._feature_names:
            flat_feature_names.extend(names)
101 changes: 101 additions & 0 deletions torchrec/modules/tests/test_embedding_modules.py
@@ -128,6 +128,38 @@ def test_weighted(self) -> None:
        self.assertEqual(pooled_embeddings.keys(), ["f1", "f3", "f2"])
        self.assertEqual(pooled_embeddings.offset_per_key(), [0, 3, 6, 10])

    def test_forward_with_meta_device(self) -> None:
        eb1_config = EmbeddingBagConfig(
            name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1", "f3"]
        )
        eb2_config = EmbeddingBagConfig(
            name="t2",
            embedding_dim=4,
            num_embeddings=10,
            feature_names=["f2"],
        )
        ebc = EmbeddingBagCollection(
            tables=[eb1_config, eb2_config],
            is_weighted=True,
            device=torch.device("meta"),
        )

        features = KeyedJaggedTensor.from_offsets_sync(
            keys=["f1", "f3", "f2"],
            values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 3, 4, 7], device="meta"),
            offsets=torch.tensor([0, 2, 4, 6, 8, 10, 12], device="meta"),
            weights=torch.tensor(
                [0.1, 0.2, 0.4, 0.5, 0.4, 0.3, 0.2, 0.9, 0.1, 0.3, 0.4, 0.7],
                device="meta",
            ),
        )

        pooled_embeddings = ebc(features)
        self.assertEqual(pooled_embeddings.values().size(), (2, 10))
        self.assertEqual(pooled_embeddings.keys(), ["f1", "f3", "f2"])
        self.assertEqual(pooled_embeddings.offset_per_key(), [0, 3, 6, 10])
        self.assertEqual(pooled_embeddings.values().device, torch.device("meta"))

    def test_fx(self) -> None:
        eb1_config = EmbeddingBagConfig(
            name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1", "f3"]
@@ -195,6 +227,75 @@ def test_device(self) -> None:
        self.assertEqual(torch.device("cpu"), ebc.embedding_bags["t1"].weight.device)
        self.assertEqual(torch.device("cpu"), ebc.device)

    def test_exporting(self) -> None:
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                eb1_config = EmbeddingBagConfig(
                    name="t1",
                    embedding_dim=3,
                    num_embeddings=10,
                    feature_names=["f1", "f3"],
                )
                eb2_config = EmbeddingBagConfig(
                    name="t2",
                    embedding_dim=4,
                    num_embeddings=10,
                    feature_names=["f2"],
                )
                eb3_config = EmbeddingBagConfig(
                    name="t3",
                    embedding_dim=3,
                    num_embeddings=10,
                    feature_names=["f1", "f2"],
                )
                eb4_config = EmbeddingBagConfig(
                    name="t4",
                    embedding_dim=5,
                    num_embeddings=10,
                    feature_names=["f3"],
                )
                self.ebc1 = EmbeddingBagCollection(
                    tables=[eb1_config, eb2_config], is_weighted=True
                )
                self.ebc2 = EmbeddingBagCollection(
                    tables=[eb3_config, eb4_config], is_weighted=True
                )

            def forward(
                self,
                features: KeyedJaggedTensor,
            ) -> torch.Tensor:
                embeddings1 = self.ebc1(features)
                embeddings2 = self.ebc2(features)
                return torch.concat([embeddings1.values(), embeddings2.values()], dim=1)

        features = KeyedJaggedTensor.from_offsets_sync(
            keys=["f1", "f3", "f2"],
            values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 3, 4, 7]),
            offsets=torch.tensor([0, 2, 4, 6, 8, 10, 12]),
            weights=torch.tensor(
                [0.1, 0.2, 0.4, 0.5, 0.4, 0.3, 0.2, 0.9, 0.1, 0.3, 0.4, 0.7]
            ),
        )

        m = MyModule()
        ep = torch.export.export(
            m,
            (features,),
            {},
            strict=False,
        )
        self.assertEqual(
            sum(n.name.startswith("_embedding_bag") for n in ep.graph.nodes),
            0,
        )
        self.assertEqual(
            sum(n.name.startswith("embedding_bag_collection") for n in ep.graph.nodes),
            2,
            "Should be exactly 2 EmbeddingBagCollection nodes in the exported graph",
        )


class EmbeddingCollectionTest(unittest.TestCase):
    def test_forward(self) -> None:
