From 92604217eb4b573f933fc540f12e44e35a6a1606 Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Mon, 22 Apr 2024 15:55:00 -0700
Subject: [PATCH] Register custom ops for EBC and PEA when doing torch.export

Summary:
# context
* when doing torch.export, embedding modules like PEA (PooledEmbeddingArch) and EBC (EmbeddingBagCollection) get flattened into individual embedding_bags, as in the following example (D56282744):
```
(Pdb) ep.graph.print_tabular()
opcode         name              target                       args                                                                                            kwargs
-------------  ----------------  ---------------------------  ----------------------------------------------------------------------------------------------  --------
...
call_function  getitem_23                                     (split_with_sizes_2, 1)                                                                         {}
call_function  _embedding_bag    aten._embedding_bag.default  (p_pea_embedding_modules_t1_weight, getitem_10, getitem_14, False, 0, False, None, True)        {}
call_function  getitem_24                                     (_embedding_bag, 0)                                                                             {}
call_function  _embedding_bag_1  aten._embedding_bag.default  (p_pea_embedding_modules_t2_weight, getitem_11, getitem_15, False, 0, False, None, True)        {}
call_function  getitem_28                                     (_embedding_bag_1, 0)                                                                           {}
call_function  _embedding_bag_2  aten._embedding_bag.default  (p_pea_embedding_modules_t3_weight, getitem_16, getitem_20, False, 0, False, getitem_22, True)  {}
call_function  getitem_32                                     (_embedding_bag_2, 0)                                                                           {}
call_function  _embedding_bag_3  aten._embedding_bag.default  (p_pea_embedding_modules_t4_weight, getitem_17, getitem_21, False, 0, False, getitem_23, True)  {}
call_function  getitem_36                                     (_embedding_bag_3, 0)                                                                           {}
call_function  cat_2             aten.cat.default             ([getitem_24, getitem_28], 1)                                                                   {}
call_function  cat_3             aten.cat.default             ([getitem_32, getitem_36], 1)                                                                   {}
call_function  cat_4             aten.cat.default             ([cat_2, cat_3], 1)                                                                             {}
output         output            output                       ((cat_4,),)                                                                                     {}
```
* this flattening is unnecessary and expensive, because deserialization of the embedding module is handled by separate logic that does not rely on the flattened schema.
* the solution is to treat the embedding module as a black box (custom op) in the graph when doing torch.export:
```
...
placeholder    w_weights                            w_weights                                         ()                                                                         {}
call_function  pooled_embedding_arch_8734585215502  custom.PooledEmbeddingArch_8734585215502.default  ([values, None, lengths, None, w_values, w_weights, w_lengths, None], 2)  {}
call_function  getitem_10                                                                             (pooled_embedding_arch_8734585215502, 0)                                  {}
call_function  getitem_11                                                                             (pooled_embedding_arch_8734585215502, 1)                                  {}
call_function  pooled_embedding_arch_8734585231976  custom.PooledEmbeddingArch_8734585231976.default  ([values, None, lengths, None, w_values, w_weights, w_lengths, None], 2)  {}
call_function  getitem_12                                                                             (pooled_embedding_arch_8734585231976, 0)                                  {}
call_function  getitem_13                                                                             (pooled_embedding_arch_8734585231976, 1)                                  {}
call_function  cat                                  aten.cat.default                                  ([getitem_10, getitem_11, getitem_12, getitem_13], 1)                     {}
output         output                               output                                            ((cat,),)                                                                 {}
```
# details
* get the output tensor shapes (List[Tensor]) from the embedding module in the `_forward_meta` function
* register a custom op whose input is a `List[Optional[Tensor]]` and whose output (List[Tensor]) has the given shapes, via `register_custom_op`
* call this custom op with the original input and return its output, so that the custom op shows up in the graph as a single node with the correct output shapes
* in the actual forward function of the embedding module, use `is_non_strict_exporting()` and `not torch.jit.is_scripting()` to branch into `_forward_meta` (a standalone sketch of the idea follows below)
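
For illustration only, here is a minimal, self-contained sketch of the custom-op trick described above. The `demo` namespace, the op name `blackbox_pooling`, and the hard-coded output dims are made up for this sketch; the actual registration in this diff happens in `register_custom_op`, with per-module op names and dims:
```
import torch
from torch.library import Library

_lib = Library("demo", "FRAGMENT")  # "demo" namespace is hypothetical
_lib.define("blackbox_pooling(Tensor?[] values, int batch_size) -> Tensor[]")


def _blackbox_pooling(values, batch_size):
    # Only shapes matter for export: return empty tensors with the known
    # per-table output dims (hard-coded to [7, 5] for this sketch).
    dims = [7, 5]
    device = next(v.device for v in values if v is not None)
    return [torch.empty(batch_size, d, device=device) for d in dims]


# The same python function serves as both the CPU and the Meta ("fake tensor")
# implementation, which is what lets non-strict torch.export record a single
# call_function node with correct output shapes instead of tracing internals.
_lib.impl("blackbox_pooling", _blackbox_pooling, "CPU")
_lib.impl("blackbox_pooling", _blackbox_pooling, "Meta")

out = torch.ops.demo.blackbox_pooling([torch.ones(2, 3), None], 2)
print([t.shape for t in out])  # [torch.Size([2, 7]), torch.Size([2, 5])]
```
The Meta kernel is the key piece: it is what the fake-tensor tracer invokes during non-strict export, so the op can appear in the graph without the real embedding weights ever being touched.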
Differential Revision: D56443608
---
 torchrec/ir/utils.py                          |  82 +++++++++++++-
 torchrec/modules/embedding_modules.py         |  31 +++++-
 .../modules/tests/test_embedding_modules.py   | 101 ++++++++++++++++++
 3 files changed, 212 insertions(+), 2 deletions(-)

diff --git a/torchrec/ir/utils.py b/torchrec/ir/utils.py
index 5809b9dff..039fdc111 100644
--- a/torchrec/ir/utils.py
+++ b/torchrec/ir/utils.py
@@ -9,15 +9,35 @@
 #!/usr/bin/env python3

-from typing import List, Tuple, Type
+import threading
+from typing import Callable, Dict, List, Optional, Tuple, Type

 import torch

 from torch import nn
 from torch.export.exported_program import ExportedProgram
+from torch.library import Library
 from torchrec.ir.types import SerializerInterface

+lib = Library("custom", "FRAGMENT")
+
+
+class OpRegistryState:
+    """
+    State of operator registry.
+
+    We can only register the op schema once. So if we're registering multiple
+    times we need a lock and check if they're the same schema
+    """
+
+    op_registry_lock = threading.Lock()
+    # operator schema: op_name: schema
+    op_registry_schema: Dict[str, str] = {}
+
+
+operator_registry_state = OpRegistryState()
+

 # TODO: Replace the default interface with the python dataclass interface
 DEFAULT_SERIALIZER_CLS = SerializerInterface
@@ -85,3 +105,63 @@ def deserialize_embedding_modules(
         setattr(parent, attrs[-1], new_module)

     return model
+
+
+def register_custom_op(
+    module: nn.Module, dims: List[int]
+) -> Callable[[List[Optional[torch.Tensor]], int], List[torch.Tensor]]:
+    """
+    Register a customized operator.
+
+    Args:
+        module: customized module instance
+        dims: output dimensions
+    """
+
+    global operator_registry_state
+
+    op_name = f"{type(module).__name__}_{hash(module)}"
+    with operator_registry_state.op_registry_lock:
+        if op_name in operator_registry_state.op_registry_schema:
+            return getattr(torch.ops.custom, op_name)
+
+    def pea_op(
+        values: List[Optional[torch.Tensor]],
+        batch_size: int,
+    ) -> List[torch.Tensor]:
+        device = None
+        for v in values:
+            if v is not None:
+                device = v.device
+                break
+        else:
+            raise AssertionError(
+                f"Custom op {type(module).__name__} expects at least one "
+                "input tensor"
+            )
+
+        return [
+            torch.empty(
+                batch_size,
+                dim,
+                device=device,
+            )
+            for dim in dims
+        ]
+
+    schema_string = f"{op_name}(Tensor?[] values, int batch_size) -> Tensor[]"
+    with operator_registry_state.op_registry_lock:
+        if op_name in operator_registry_state.op_registry_schema:
+            return getattr(torch.ops.custom, op_name)
+        operator_registry_state.op_registry_schema[op_name] = schema_string
+        # Register schema
+        lib.define(schema_string)
+
+        # Register implementation
+        lib.impl(op_name, pea_op, "CPU")
+        lib.impl(op_name, pea_op, "CUDA")
+
+        # Register meta formula
+        lib.impl(op_name, pea_op, "Meta")
+
+    return getattr(torch.ops.custom, op_name)
diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py
index 1655d8f4c..a87ad4b9d 100644
--- a/torchrec/modules/embedding_modules.py
+++ b/torchrec/modules/embedding_modules.py
@@ -12,13 +12,40 @@
 import torch
 import torch.nn as nn
+from torchrec.ir.utils import register_custom_op
 from torchrec.modules.embedding_configs import (
     DataType,
     EmbeddingBagConfig,
     EmbeddingConfig,
     pooling_type_to_str,
 )
-from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.jagged_tensor import (
+    is_non_strict_exporting,
+    JaggedTensor,
+    KeyedJaggedTensor,
+    KeyedTensor,
+)
+
+
+def _forward_meta(
+    ebc: "EmbeddingBagCollectionInterface",
+    features: KeyedJaggedTensor,
+) -> KeyedTensor:
+    batch_size = features.stride()
+    arg_list = [
+        features.values(),
+        features.weights_or_none(),
+        features.lengths_or_none(),
+        features.offsets_or_none(),
+    ]
+    dims = [sum(ebc._lengths_per_embedding)]
+    ebc_op = register_custom_op(ebc, dims)
+    outputs = ebc_op(arg_list, batch_size)
+    return KeyedTensor(
+        keys=ebc._embedding_names,
+        values=outputs[0],
+        length_per_key=ebc._lengths_per_embedding,
+    )
+
+
 @torch.fx.wrap
@@ -212,6 +239,8 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
         Returns:
             KeyedTensor
         """
+        if is_non_strict_exporting() and not torch.jit.is_scripting():
+            return _forward_meta(self, features)
         flat_feature_names: List[str] = []
         for names in self._feature_names:
             flat_feature_names.extend(names)
diff --git a/torchrec/modules/tests/test_embedding_modules.py b/torchrec/modules/tests/test_embedding_modules.py
index 62338bc10..934bdb229 100644
--- a/torchrec/modules/tests/test_embedding_modules.py
+++ b/torchrec/modules/tests/test_embedding_modules.py
@@ -128,6 +128,38 @@ def test_weighted(self) -> None:
         self.assertEqual(pooled_embeddings.keys(), ["f1", "f3", "f2"])
         self.assertEqual(pooled_embeddings.offset_per_key(), [0, 3, 6, 10])

+    def test_forward_with_meta_device(self) -> None:
+        eb1_config = EmbeddingBagConfig(
+            name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1", "f3"]
+        )
+        eb2_config = EmbeddingBagConfig(
+            name="t2",
+            embedding_dim=4,
+            num_embeddings=10,
+            feature_names=["f2"],
+        )
+        ebc = EmbeddingBagCollection(
+            tables=[eb1_config, eb2_config],
+            is_weighted=True,
+            device=torch.device("meta"),
+        )
+
+        features = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f3", "f2"],
+            values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 3, 4, 7], device="meta"),
+            offsets=torch.tensor([0, 2, 4, 6, 8, 10, 12], device="meta"),
+            weights=torch.tensor(
+                [0.1, 0.2, 0.4, 0.5, 0.4, 0.3, 0.2, 0.9, 0.1, 0.3, 0.4, 0.7],
+                device="meta",
+            ),
+        )
+
+        pooled_embeddings = ebc(features)
+        self.assertEqual(pooled_embeddings.values().size(), (2, 10))
+        self.assertEqual(pooled_embeddings.keys(), ["f1", "f3", "f2"])
+        self.assertEqual(pooled_embeddings.offset_per_key(), [0, 3, 6, 10])
+        self.assertEqual(pooled_embeddings.values().device, torch.device("meta"))
+
     def test_fx(self) -> None:
         eb1_config = EmbeddingBagConfig(
             name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1", "f3"]
@@ -195,6 +227,75 @@ def test_device(self) -> None:
         self.assertEqual(torch.device("cpu"), ebc.embedding_bags["t1"].weight.device)
         self.assertEqual(torch.device("cpu"), ebc.device)

+    def test_exporting(self) -> None:
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                eb1_config = EmbeddingBagConfig(
+                    name="t1",
+                    embedding_dim=3,
+                    num_embeddings=10,
+                    feature_names=["f1", "f3"],
+                )
+                eb2_config = EmbeddingBagConfig(
+                    name="t2",
+                    embedding_dim=4,
+                    num_embeddings=10,
+                    feature_names=["f2"],
+                )
+                eb3_config = EmbeddingBagConfig(
+                    name="t3",
+                    embedding_dim=3,
+                    num_embeddings=10,
+                    feature_names=["f1", "f2"],
+                )
+                eb4_config = EmbeddingBagConfig(
+                    name="t4",
+                    embedding_dim=5,
+                    num_embeddings=10,
+                    feature_names=["f3"],
+                )
+                self.ebc1 = EmbeddingBagCollection(
+                    tables=[eb1_config, eb2_config], is_weighted=True
+                )
+                self.ebc2 = EmbeddingBagCollection(
+                    tables=[eb3_config, eb4_config], is_weighted=True
+                )
+
+            def forward(
+                self,
+                features: KeyedJaggedTensor,
+            ) -> torch.Tensor:
+                embeddings1 = self.ebc1(features)
+                embeddings2 = self.ebc2(features)
+                return torch.concat([embeddings1.values(), embeddings2.values()], dim=1)
+
+        features = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f3", "f2"],
+            values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 3, 4, 7]),
+            offsets=torch.tensor([0, 2, 4, 6, 8, 10, 12]),
+            weights=torch.tensor(
+                [0.1, 0.2, 0.4, 0.5, 0.4, 0.3, 0.2, 0.9, 0.1, 0.3, 0.4, 0.7]
+            ),
+        )
+
+        m = MyModule()
+        ep = torch.export.export(
+            m,
+            (features,),
+            {},
+            strict=False,
+        )
+        self.assertEqual(
+            sum(n.name.startswith("_embedding_bag") for n in ep.graph.nodes),
+            0,
+        )
+        self.assertEqual(
+            sum(n.name.startswith("embedding_bag_collection") for n in ep.graph.nodes),
+            2,
+            "Should be exactly 2 EmbeddingBagCollection nodes in the exported graph",
+        )
+

 class EmbeddingCollectionTest(unittest.TestCase):
     def test_forward(self) -> None:
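
As a quick, illustrative way to eyeball the new export behavior locally (reusing `MyModule` and `features` exactly as constructed in `test_exporting` above; this snippet is not part of the patch):
```
m = MyModule()
ep = torch.export.export(m, (features,), {}, strict=False)
ep.graph.print_tabular()
# expected: two embedding_bag_collection_* call_function custom-op nodes
# and no aten._embedding_bag.default nodes in the printed graph
```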