Skip to content

Commit

Permalink
refactor the _non_strict_export_forward in EBC and PEA to include emb…
Browse files Browse the repository at this point in the history
…edding_bag.weight (#1942)

Summary:
Pull Request resolved: #1942

# context
* previously the EBC and PEA only takes the KJT as the input of the custom_op, which leaves the weights of embedding bags as dangling placeholders in the graph.
* with this diff, we add those weights into the custom_op so that in the graph those weights are connected to the custom_op

# PEA
* embedding_bag weights like `p_pea1_embedding_modules_t1_weight` is dangling before, and it's used by `pooled_embedding_arch` afterwards.
* the total count of nodes are the same.
* before
```
(Pdb) ep.graph.print_tabular()
opcode         name                                   target                                              args                                                                      kwargs
-------------  -------------------------------------  --------------------------------------------------  ------------------------------------------------------------------------  --------
placeholder    p_pea1_embedding_modules_t1_weight     p_pea1_embedding_modules_t1_weight                  ()                                                                        {}
placeholder    p_pea1_embedding_modules_t2_weight     p_pea1_embedding_modules_t2_weight                  ()                                                                        {}
placeholder    p_pea1_embedding_modules_t3_weight     p_pea1_embedding_modules_t3_weight                  ()                                                                        {}
placeholder    p_pea1_embedding_modules_t4_weight     p_pea1_embedding_modules_t4_weight                  ()                                                                        {}
placeholder    p_pea2_embedding_modules_t1_weight     p_pea2_embedding_modules_t1_weight                  ()                                                                        {}
placeholder    p_pea2_embedding_modules_t2_weight     p_pea2_embedding_modules_t2_weight                  ()                                                                        {}
placeholder    p_pea2_embedding_modules_t3_weight     p_pea2_embedding_modules_t3_weight                  ()                                                                        {}
placeholder    values                                 values                                              ()                                                                        {}
placeholder    lengths                                lengths                                             ()                                                                        {}
placeholder    w_values                               w_values                                            ()                                                                        {}
placeholder    w_lengths                              w_lengths                                           ()                                                                        {}
placeholder    w_weights                              w_weights                                           ()                                                                        {}
call_function  pooled_embedding_arch_139590115359376  custom.PooledEmbeddingArch_139590115359376.default  ([values, None, lengths, None, w_values, w_weights, w_lengths, None], 2)  {}
call_function  getitem_10                             <built-in function getitem>                         (pooled_embedding_arch_139590115359376, 0)                                {}
call_function  getitem_11                             <built-in function getitem>                         (pooled_embedding_arch_139590115359376, 1)                                {}
call_function  pooled_embedding_arch_139590115359377  custom.PooledEmbeddingArch_139590115359376.default  ([values, None, lengths, None, w_values, w_weights, w_lengths, None], 2)  {}
call_function  getitem_12                             <built-in function getitem>                         (pooled_embedding_arch_139590115359377, 0)                                {}
call_function  getitem_13                             <built-in function getitem>                         (pooled_embedding_arch_139590115359377, 1)                                {}
call_function  cat                                    aten.cat.default                                    ([getitem_10, getitem_11, getitem_12, getitem_13], 1)                     {}
output         output                                 output                                              ((cat,),)                                                                 {}
```
* after
```
(Pdb) ep.graph.print_tabular()
opcode         name                                   target                                              args                                                                                                                                                                                                                      kwargs
-------------  -------------------------------------  --------------------------------------------------  ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------  --------
placeholder    p_pea1_embedding_modules_t1_weight     p_pea1_embedding_modules_t1_weight                  ()                                                                                                                                                                                                                        {}
placeholder    p_pea1_embedding_modules_t2_weight     p_pea1_embedding_modules_t2_weight                  ()                                                                                                                                                                                                                        {}
placeholder    p_pea1_embedding_modules_t3_weight     p_pea1_embedding_modules_t3_weight                  ()                                                                                                                                                                                                                        {}
placeholder    p_pea1_embedding_modules_t4_weight     p_pea1_embedding_modules_t4_weight                  ()                                                                                                                                                                                                                        {}
placeholder    p_pea2_embedding_modules_t1_weight     p_pea2_embedding_modules_t1_weight                  ()                                                                                                                                                                                                                        {}
placeholder    p_pea2_embedding_modules_t2_weight     p_pea2_embedding_modules_t2_weight                  ()                                                                                                                                                                                                                        {}
placeholder    p_pea2_embedding_modules_t3_weight     p_pea2_embedding_modules_t3_weight                  ()                                                                                                                                                                                                                        {}
placeholder    values                                 values                                              ()                                                                                                                                                                                                                        {}
placeholder    lengths                                lengths                                             ()                                                                                                                                                                                                                        {}
placeholder    w_values                               w_values                                            ()                                                                                                                                                                                                                        {}
placeholder    w_lengths                              w_lengths                                           ()                                                                                                                                                                                                                        {}
placeholder    w_weights                              w_weights                                           ()                                                                                                                                                                                                                        {}
call_function  pooled_embedding_arch_140307162538352  custom.PooledEmbeddingArch_140307162538352.default  ([p_pea1_embedding_modules_t1_weight, p_pea1_embedding_modules_t2_weight, p_pea1_embedding_modules_t3_weight, p_pea1_embedding_modules_t4_weight, values, None, lengths, None, w_values, w_weights, w_lengths, None], 2)  {}
call_function  getitem_10                             <built-in function getitem>                         (pooled_embedding_arch_140307162538352, 0)                                                                                                                                                                                {}
call_function  getitem_11                             <built-in function getitem>                         (pooled_embedding_arch_140307162538352, 1)                                                                                                                                                                                {}
call_function  pooled_embedding_arch_140307162538353  custom.PooledEmbeddingArch_140307162538352.default  ([p_pea2_embedding_modules_t1_weight, p_pea2_embedding_modules_t2_weight, p_pea2_embedding_modules_t3_weight, values, None, lengths, None, w_values, w_weights, w_lengths, None], 2)                                      {}
call_function  getitem_12                             <built-in function getitem>                         (pooled_embedding_arch_140307162538353, 0)                                                                                                                                                                                {}
call_function  getitem_13                             <built-in function getitem>                         (pooled_embedding_arch_140307162538353, 1)                                                                                                                                                                                {}
call_function  cat                                    aten.cat.default                                    ([getitem_10, getitem_11, getitem_12, getitem_13], 1)                                                                                                                                                                     {}
output         output                                 output                                              ((cat,),)                                                                                                                                                                                                                 {}
```

# EBC
* embedding_bag weights like `p_ebc1_embedding_bags_t1_weight` is dangling before, and it's used by `embedding_bag_collection` afterwards.
* the total count of nodes are the same.
* before
```
(Pdb) ep.graph.print_tabular()
opcode         name                                      target                                                 args                                                                 kwargs
-------------  ----------------------------------------  -----------------------------------------------------  -------------------------------------------------------------------  --------
placeholder    p_ebc1_embedding_bags_t1_weight           p_ebc1_embedding_bags_t1_weight                        ()                                                                   {}
placeholder    p_ebc1_embedding_bags_t2_weight           p_ebc1_embedding_bags_t2_weight                        ()                                                                   {}
placeholder    p_ebc2_embedding_bags_t3_weight           p_ebc2_embedding_bags_t3_weight                        ()                                                                   {}
placeholder    p_ebc2_embedding_bags_t4_weight           p_ebc2_embedding_bags_t4_weight                        ()                                                                   {}
placeholder    features__values                          features__values                                       ()                                                                   {}
placeholder    features__weights                         features__weights                                      ()                                                                   {}
placeholder    features__lengths                         features__lengths                                      ()                                                                   {}
placeholder    features__offsets                         features__offsets                                      ()                                                                   {}
call_function  embedding_bag_collection_140218446926688  custom.EmbeddingBagCollection_140218446926688.default  ([features__values, features__weights, None, features__offsets], 2)  {}
call_function  getitem                                   <built-in function getitem>                            (embedding_bag_collection_140218446926688, 0)                        {}
call_function  embedding_bag_collection_140218446926689  custom.EmbeddingBagCollection_140218446926688.default  ([features__values, features__weights, None, features__offsets], 2)  {}
call_function  getitem_1                                 <built-in function getitem>                            (embedding_bag_collection_140218446926689, 0)                        {}
call_function  cat                                       aten.cat.default                                       ([getitem, getitem_1], 1)                                            {}
output         output                                    output                                                 ((cat,),)                                                            {}
```

* after
```
(Pdb) ep.graph.print_tabular()
opcode         name                                      target                                                 args                                                                                                                                   kwargs
-------------  ----------------------------------------  -----------------------------------------------------  -------------------------------------------------------------------------------------------------------------------------------------  --------
placeholder    p_ebc1_embedding_bags_t1_weight           p_ebc1_embedding_bags_t1_weight                        ()                                                                                                                                     {}
placeholder    p_ebc1_embedding_bags_t2_weight           p_ebc1_embedding_bags_t2_weight                        ()                                                                                                                                     {}
placeholder    p_ebc2_embedding_bags_t3_weight           p_ebc2_embedding_bags_t3_weight                        ()                                                                                                                                     {}
placeholder    p_ebc2_embedding_bags_t4_weight           p_ebc2_embedding_bags_t4_weight                        ()                                                                                                                                     {}
placeholder    features__values                          features__values                                       ()                                                                                                                                     {}
placeholder    features__weights                         features__weights                                      ()                                                                                                                                     {}
placeholder    features__lengths                         features__lengths                                      ()                                                                                                                                     {}
placeholder    features__offsets                         features__offsets                                      ()                                                                                                                                     {}
call_function  embedding_bag_collection_140614752342208  custom.EmbeddingBagCollection_140614752342208.default  ([features__values, features__weights, None, features__offsets, p_ebc1_embedding_bags_t1_weight, p_ebc1_embedding_bags_t2_weight], 2)  {}
call_function  getitem                                   <built-in function getitem>                            (embedding_bag_collection_140614752342208, 0)                                                                                          {}
call_function  embedding_bag_collection_140614752342209  custom.EmbeddingBagCollection_140614752342208.default  ([features__values, features__weights, None, features__offsets, p_ebc2_embedding_bags_t3_weight, p_ebc2_embedding_bags_t4_weight], 2)  {}
call_function  getitem_1                                 <built-in function getitem>                            (embedding_bag_collection_140614752342209, 0)                                                                                          {}
call_function  cat                                       aten.cat.default                                       ([getitem, getitem_1], 1)                                                                                                              {}
output         output                                    output                                                 ((cat,),)                                                                                                                              {}
```

Reviewed By: PaulZhang12

Differential Revision: D56904214

fbshipit-source-id: 193a95a393d8ccdad582640079da5b7879aaea36
  • Loading branch information
TroyGarden authored and facebook-github-bot committed May 3, 2024
1 parent 4b42322 commit 9d4c676
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 20 deletions.
22 changes: 2 additions & 20 deletions torchrec/modules/embedding_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,10 @@
EmbeddingConfig,
pooling_type_to_str,
)
from torchrec.modules.utils import register_custom_op
from torchrec.modules.utils import is_non_strict_exporting, register_custom_op
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor


try:
if torch.jit.is_scripting():
raise Exception()

from torch.compiler import (
is_compiling as is_compiler_compiling,
is_dynamo_compiling as is_torchdynamo_compiling,
)

def is_non_strict_exporting() -> bool:
return not is_torchdynamo_compiling() and is_compiler_compiling()

except Exception:

def is_non_strict_exporting() -> bool:
return False


@torch.fx.wrap
def reorder_inverse_indices(
inverse_indices: Optional[Tuple[List[str], torch.Tensor]],
Expand Down Expand Up @@ -233,7 +215,7 @@ def _non_strict_exporting_forward(
features.weights_or_none(),
features.lengths_or_none(),
features.offsets_or_none(),
]
] + [bag.weight for bag in self.embedding_bags.values()]
dims = [sum(self._lengths_per_embedding)]
ebc_op = register_custom_op(self, dims)
outputs = ebc_op(arg_list, batch_size)
Expand Down
18 changes: 18 additions & 0 deletions torchrec/modules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,24 @@
lib = torch.library.Library("custom", "FRAGMENT")


try:
if torch.jit.is_scripting():
raise Exception()

from torch.compiler import (
is_compiling as is_compiler_compiling,
is_dynamo_compiling as is_torchdynamo_compiling,
)

def is_non_strict_exporting() -> bool:
return not is_torchdynamo_compiling() and is_compiler_compiling()

except Exception:

def is_non_strict_exporting() -> bool:
return False


class OpRegistryState:
"""
State of operator registry.
Expand Down

0 comments on commit 9d4c676

Please sign in to comment.