Register custom ops for EBC and PEA when doing torch.export #1913
Conversation
This pull request was exported from Phabricator. Differential Revision: D56443608

Register custom ops for EBC and PEA when doing torch.export (#1913)

Summary:

# context

* when doing torch.export, embedding modules like PEA (PooledEmbeddingArch) and EBC (EmbeddingBagCollection) are flattened into individual embedding_bags, as in the following example (D56282744):

```
(Pdb) ep.graph.print_tabular()
opcode         name              target                       args                                                                                            kwargs
-------------  ----------------  ---------------------------  ----------------------------------------------------------------------------------------------  ------
...
call_function  getitem_23        <built-in function getitem>  (split_with_sizes_2, 1)                                                                         {}
call_function  _embedding_bag    aten._embedding_bag.default  (p_pea_embedding_modules_t1_weight, getitem_10, getitem_14, False, 0, False, None, True)        {}
call_function  getitem_24        <built-in function getitem>  (_embedding_bag, 0)                                                                             {}
call_function  _embedding_bag_1  aten._embedding_bag.default  (p_pea_embedding_modules_t2_weight, getitem_11, getitem_15, False, 0, False, None, True)        {}
call_function  getitem_28        <built-in function getitem>  (_embedding_bag_1, 0)                                                                           {}
call_function  _embedding_bag_2  aten._embedding_bag.default  (p_pea_embedding_modules_t3_weight, getitem_16, getitem_20, False, 0, False, getitem_22, True)  {}
call_function  getitem_32        <built-in function getitem>  (_embedding_bag_2, 0)                                                                           {}
call_function  _embedding_bag_3  aten._embedding_bag.default  (p_pea_embedding_modules_t4_weight, getitem_17, getitem_21, False, 0, False, getitem_23, True)  {}
call_function  getitem_36        <built-in function getitem>  (_embedding_bag_3, 0)                                                                           {}
call_function  cat_2             aten.cat.default             ([getitem_24, getitem_28], 1)                                                                   {}
call_function  cat_3             aten.cat.default             ([getitem_32, getitem_36], 1)                                                                   {}
call_function  cat_4             aten.cat.default             ([cat_2, cat_3], 1)                                                                             {}
output         output            output                       ((cat_4,),)                                                                                     {}
```

* this flattening is unnecessary and expensive, because deserialization of the embedding module is handled by separate logic that does not use the flattened schema.
* the solution is to treat the embedding module as a black box (custom op) in the graph when doing torch.export:

```
...
placeholder    w_weights                            w_weights                                         ()                                                                         {}
call_function  pooled_embedding_arch_8734585215502  custom.PooledEmbeddingArch_8734585215502.default  ([values, None, lengths, None, w_values, w_weights, w_lengths, None], 2)  {}
call_function  getitem_10                           <built-in function getitem>                       (pooled_embedding_arch_8734585215502, 0)                                   {}
call_function  getitem_11                           <built-in function getitem>                       (pooled_embedding_arch_8734585215502, 1)                                   {}
call_function  pooled_embedding_arch_8734585231976  custom.PooledEmbeddingArch_8734585231976.default  ([values, None, lengths, None, w_values, w_weights, w_lengths, None], 2)  {}
call_function  getitem_12                           <built-in function getitem>                       (pooled_embedding_arch_8734585231976, 0)                                   {}
call_function  getitem_13                           <built-in function getitem>                       (pooled_embedding_arch_8734585231976, 1)                                   {}
call_function  cat                                  aten.cat.default                                  ([getitem_10, getitem_11, getitem_12, getitem_13], 1)                      {}
output         output                               output                                            ((cat,),)                                                                  {}
```

# details

* get the output tensor shapes (List[Tensor]) from the embedding modules in the `_non_strict_exporting_forward` function
* register a custom_op whose input is a `List[Optional[Tensor]]` and whose output (List[Tensor]) has the given shapes, in `register_custom_op` (see the sketch after this summary)
* call this custom_op with the original input and return its output, so that in the graph the custom_op appears as a single node with the correct output shapes
* in the actual forward function of the embedding module, use `is_non_strict_exporting()` and `not torch.jit.is_scripting()` to branch to the `_non_strict_exporting_forward` function.

Differential Revision: D56443608
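To make the `register_custom_op` step above concrete, here is a minimal sketch using `torch.library`: a per-module op is defined in a `custom` namespace that takes a `List[Optional[Tensor]]` plus a batch size and returns tensors with pre-computed shapes. The namespace, op naming, schema, and function bodies below are assumptions for illustration, not the exact TorchRec implementation.

```python
from typing import List, Optional

import torch
from torch.library import Library

# Hypothetical sketch: a "custom" namespace matching the custom.* targets shown in the
# exported graph above. Everything below is illustrative, not TorchRec's actual code.
_LIB = Library("custom", "FRAGMENT")


def register_custom_op(op_name: str, output_shapes: List[List[int]]):
    """Define custom.<op_name>(Tensor?[] values, int batch_size) -> Tensor[] whose
    outputs have the pre-computed `output_shapes`, so torch.export can record the
    embedding module as a single opaque node."""
    _LIB.define(f"{op_name}(Tensor?[] values, int batch_size) -> Tensor[]")

    def _impl(
        values: List[Optional[torch.Tensor]], batch_size: int
    ) -> List[torch.Tensor]:
        # Only the output shapes matter for export; the real embedding module is
        # restored later by the separate deserialization logic.
        device = next((v.device for v in values if v is not None), torch.device("cpu"))
        return [torch.empty(shape, device=device) for shape in output_shapes]

    _LIB.impl(op_name, _impl, "CPU")
    _LIB.impl(op_name, _impl, "Meta")  # shape propagation under fake tensors
    return getattr(torch.ops.custom, op_name)


# Hypothetical usage inside a _non_strict_exporting_forward-style path:
#   op = register_custom_op("PooledEmbeddingArch_1234", [[batch, 8], [batch, 16]])
#   pooled = op([values, None, lengths, None, w_values, w_weights, w_lengths, None], 2)
```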
Force-pushed from 9260421 to ce6cd55.
Force-pushed from ce6cd55 to 24ef7b6.
Force-pushed from 24ef7b6 to 41a5b42.
Force-pushed from 41a5b42 to 00e0e85.
Force-pushed from 00e0e85 to 6f639b3.
Force-pushed from 6f639b3 to ea23208.
Force-pushed from ea23208 to 09abc36.
Force-pushed from 09abc36 to d12f207.
Force-pushed from d12f207 to cf3170c.
This pull request was exported from Phabricator. Differential Revision: D56443608

Summary (final revision, which renames the export-time path to `_meta_forward`):

# details

* get the output tensor shapes (List[Tensor]) from the embedding modules in the `_meta_forward` function
* register a custom_op whose input is a `List[Optional[Tensor]]` and whose output (List[Tensor]) has the given shapes, in `register_custom_op`
* call this custom_op with the original input and return its output, so that in the graph the custom_op appears as a single node with the correct output shapes
* in the actual forward function of the embedding module, use `is_non_strict_exporting()` and `not torch.jit.is_scripting()` to branch to the `_meta_forward` function.

Reviewed By: PaulZhang12

Differential Revision: D56443608
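For the last bullet, a minimal sketch of the branching, with an assumed implementation of `is_non_strict_exporting()` and a toy module standing in for PEA/EBC (neither is the real TorchRec code):

```python
import torch


def is_non_strict_exporting() -> bool:
    # Assumed stand-in for the helper referenced above: true while torch.export runs
    # in non-strict (non-dynamo) mode. The real helper lives in TorchRec.
    return not torch.compiler.is_dynamo_compiling() and torch.compiler.is_compiling()


class ToyEmbeddingModule(torch.nn.Module):
    """Illustrative module showing only the branching, not the real PEA/EBC forward."""

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        if not torch.jit.is_scripting() and is_non_strict_exporting():
            # During non-strict torch.export, route through the custom-op path so the
            # whole module is captured as a single node in the exported graph.
            return self._meta_forward(features)
        return self._real_forward(features)

    def _meta_forward(self, features: torch.Tensor) -> torch.Tensor:
        # Would call the registered custom op (see the earlier sketch); here it only
        # returns a correctly shaped placeholder.
        return torch.empty(features.shape[0], 8, device=features.device)

    def _real_forward(self, features: torch.Tensor) -> torch.Tensor:
        # Stand-in for the normal eager / TorchScript embedding lookup.
        return torch.zeros(features.shape[0], 8, device=features.device)
```

With this gate, torch.jit.script only ever sees the regular path, while non-strict torch.export takes the `_meta_forward` path and records the custom op as a single node with the correct output shapes.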