Register custom ops for EBC and PEA when doing torch.export #1913

Closed
wants to merge 1 commit into from

Conversation

TroyGarden
Contributor

Summary:

context

  • when doing torch.export, embedding modules like PEA (PooledEmbeddingArch) and EBC (EmbeddingBagCollection) are flattened into individual embedding_bag calls, as in the following example (D56282744); a toy repro sketch appears after the second graph dump below:
(Pdb) ep.graph.print_tabular()
opcode         name                               target                             args                                                                                            kwargs
-------------  ---------------------------------  ---------------------------------  ----------------------------------------------------------------------------------------------  ---------------------
...
call_function  getitem_23                         <built-in function getitem>        (split_with_sizes_2, 1)                                                                         {}
call_function  _embedding_bag                     aten._embedding_bag.default        (p_pea_embedding_modules_t1_weight, getitem_10, getitem_14, False, 0, False, None, True)        {}
call_function  getitem_24                         <built-in function getitem>        (_embedding_bag, 0)                                                                             {}
call_function  _embedding_bag_1                   aten._embedding_bag.default        (p_pea_embedding_modules_t2_weight, getitem_11, getitem_15, False, 0, False, None, True)        {}
call_function  getitem_28                         <built-in function getitem>        (_embedding_bag_1, 0)                                                                           {}
call_function  _embedding_bag_2                   aten._embedding_bag.default        (p_pea_embedding_modules_t3_weight, getitem_16, getitem_20, False, 0, False, getitem_22, True)  {}
call_function  getitem_32                         <built-in function getitem>        (_embedding_bag_2, 0)                                                                           {}
call_function  _embedding_bag_3                   aten._embedding_bag.default        (p_pea_embedding_modules_t4_weight, getitem_17, getitem_21, False, 0, False, getitem_23, True)  {}
call_function  getitem_36                         <built-in function getitem>        (_embedding_bag_3, 0)                                                                           {}
call_function  cat_2                              aten.cat.default                   ([getitem_24, getitem_28], 1)                                                                   {}
call_function  cat_3                              aten.cat.default                   ([getitem_32, getitem_36], 1)                                                                   {}
call_function  cat_4                              aten.cat.default                   ([cat_2, cat_3], 1)                                                                             {}
output         output                             output                             ((cat_4,),)                                                                                     {}
  • this flattening is unnecessary and expensive, because deserialization of the embedding module is handled by separate logic that does not rely on the flattened schema.
  • the solution is to treat the embedding module as a black box (a custom op) in the graph when doing torch.export:
...
placeholder    w_weights                            w_weights                                         ()                                                                        {}
call_function  pooled_embedding_arch_8734585215502  custom.PooledEmbeddingArch_8734585215502.default  ([values, None, lengths, None, w_values, w_weights, w_lengths, None], 2)  {}
call_function  getitem_10                           <built-in function getitem>                       (pooled_embedding_arch_8734585215502, 0)                                  {}
call_function  getitem_11                           <built-in function getitem>                       (pooled_embedding_arch_8734585215502, 1)                                  {}
call_function  pooled_embedding_arch_8734585231976  custom.PooledEmbeddingArch_8734585231976.default  ([values, None, lengths, None, w_values, w_weights, w_lengths, None], 2)  {}
call_function  getitem_12                           <built-in function getitem>                       (pooled_embedding_arch_8734585231976, 0)                                  {}
call_function  getitem_13                           <built-in function getitem>                       (pooled_embedding_arch_8734585231976, 1)                                  {}
call_function  cat                                  aten.cat.default                                  ([getitem_10, getitem_11, getitem_12, getitem_13], 1)                     {}
output         output                               output                                            ((cat,),)                                                                 {}
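For reference, the following is a minimal, self-contained sketch of how such a graph dump can be produced. It uses a toy module with plain nn.EmbeddingBag tables rather than the actual torchrec PEA/EBC; the ToyPooledArch module, its table sizes, and the inputs are made up for illustration, but it shows torch.export recording per-table embedding_bag calls plus the final cat.

```
import torch
import torch.nn as nn


class ToyPooledArch(nn.Module):  # hypothetical stand-in for a pooled embedding module
    def __init__(self) -> None:
        super().__init__()
        self.t1 = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4)
        self.t2 = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4)

    def forward(self, ids: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
        # one pooled lookup per table, concatenated along the feature dimension
        return torch.cat([self.t1(ids, offsets), self.t2(ids, offsets)], dim=1)


ids = torch.tensor([1, 2, 3, 4])
offsets = torch.tensor([0, 2])  # two bags of two ids each

ep = torch.export.export(ToyPooledArch(), (ids, offsets))
ep.graph.print_tabular()  # per-table embedding_bag calls followed by aten.cat, similar to the dump above
```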

details

  • get the output tensor shapes (List[Tensor]) from the embedding modules in the _meta_forward function
  • register a custom op whose input is a List[Optional[Tensor]] and whose output is a List[Tensor] with the given shapes, in register_custom_op (a minimal sketch of this pattern follows this list)
  • call this custom op with the original input to get the desired output, so that the custom op appears in the graph as a single node with the correct output shapes
  • in the actual forward function of the embedding module, use is_non_strict_exporting() and not torch.jit.is_scripting() to branch to the _meta_forward function.
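The snippet below is a minimal sketch of this registration pattern, not the actual torchrec register_custom_op: the namespace custom_sketch, the op name pooled_arch, and the hard-coded output dimensions are placeholders. It registers an op that takes a List[Optional[Tensor]] and returns a List[Tensor] with pre-determined shapes, plus a Meta kernel so that tracing/export only needs the shapes.

```
from typing import List, Optional

import torch
from torch.library import Library

_lib = Library("custom_sketch", "DEF")  # hypothetical namespace for this sketch
_lib.define("pooled_arch(Tensor?[] inputs, int batch_size) -> Tensor[]")

_OUT_DIMS = [8, 8]  # per-output embedding dims; in practice derived from the embedding module


def _pooled_arch_cpu(inputs: List[Optional[torch.Tensor]], batch_size: int) -> List[torch.Tensor]:
    # stand-in eager kernel; the real op would run the embedding module
    return [torch.zeros(batch_size, dim) for dim in _OUT_DIMS]


def _pooled_arch_meta(inputs: List[Optional[torch.Tensor]], batch_size: int) -> List[torch.Tensor]:
    # meta kernel: only the output shapes matter during export
    return [torch.empty(batch_size, dim, device="meta") for dim in _OUT_DIMS]


_lib.impl("pooled_arch", _pooled_arch_cpu, "CPU")
_lib.impl("pooled_arch", _pooled_arch_meta, "Meta")

# calling the op: the input list may contain None entries (Tensor?[])
outs = torch.ops.custom_sketch.pooled_arch([torch.rand(6), None, torch.rand(4)], 2)
print([tuple(t.shape) for t in outs])  # [(2, 8), (2, 8)]
```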

Reviewed By: PaulZhang12

Differential Revision: D56443608

@facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Apr 22, 2024
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D56443608

TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Apr 22, 2024
TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Apr 23, 2024
TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Apr 23, 2024
TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Apr 23, 2024
TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Apr 23, 2024
TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Apr 23, 2024
TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Apr 23, 2024
TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Apr 23, 2024