Register custom ops for EBC and PEA when doing torch.export (#1913)
Summary:

# context
* when doing torch.export, embedding modules like PEA (PooledEmbeddingArch) and EBC (EmbeddingBagCollection) are flattened into individual `embedding_bag` calls, as in the following example (D56282744):
```
(Pdb) ep.graph.print_tabular()
opcode         name              target                       args                                                                                            kwargs
-------------  ----------------  ---------------------------  ----------------------------------------------------------------------------------------------  --------
...
call_function  getitem_23        <built-in function getitem>  (split_with_sizes_2, 1)                                                                         {}
call_function  _embedding_bag    aten._embedding_bag.default  (p_pea_embedding_modules_t1_weight, getitem_10, getitem_14, False, 0, False, None, True)        {}
call_function  getitem_24        <built-in function getitem>  (_embedding_bag, 0)                                                                             {}
call_function  _embedding_bag_1  aten._embedding_bag.default  (p_pea_embedding_modules_t2_weight, getitem_11, getitem_15, False, 0, False, None, True)        {}
call_function  getitem_28        <built-in function getitem>  (_embedding_bag_1, 0)                                                                           {}
call_function  _embedding_bag_2  aten._embedding_bag.default  (p_pea_embedding_modules_t3_weight, getitem_16, getitem_20, False, 0, False, getitem_22, True)  {}
call_function  getitem_32        <built-in function getitem>  (_embedding_bag_2, 0)                                                                           {}
call_function  _embedding_bag_3  aten._embedding_bag.default  (p_pea_embedding_modules_t4_weight, getitem_17, getitem_21, False, 0, False, getitem_23, True)  {}
call_function  getitem_36        <built-in function getitem>  (_embedding_bag_3, 0)                                                                           {}
call_function  cat_2             aten.cat.default             ([getitem_24, getitem_28], 1)                                                                   {}
call_function  cat_3             aten.cat.default             ([getitem_32, getitem_36], 1)                                                                   {}
call_function  cat_4             aten.cat.default             ([cat_2, cat_3], 1)                                                                             {}
output         output            output                       ((cat_4,),)                                                                                     {}
```
* this flattening is unnecessary and expensive, because deserialization of the embedding module is handled by separate logic that does not use the flattened schema.
* the solution is to treat the embedding module as a black box (a custom op) in the graph when doing torch.export:
```
...
placeholder    w_weights                            w_weights                                         ()                                                                         {}
call_function  pooled_embedding_arch_8734585215502  custom.PooledEmbeddingArch_8734585215502.default  ([values, None, lengths, None, w_values, w_weights, w_lengths, None], 2)  {}
call_function  getitem_10                           <built-in function getitem>                       (pooled_embedding_arch_8734585215502, 0)                                   {}
call_function  getitem_11                           <built-in function getitem>                       (pooled_embedding_arch_8734585215502, 1)                                   {}
call_function  pooled_embedding_arch_8734585231976  custom.PooledEmbeddingArch_8734585231976.default  ([values, None, lengths, None, w_values, w_weights, w_lengths, None], 2)  {}
call_function  getitem_12                           <built-in function getitem>                       (pooled_embedding_arch_8734585231976, 0)                                   {}
call_function  getitem_13                           <built-in function getitem>                       (pooled_embedding_arch_8734585231976, 1)                                   {}
call_function  cat                                  aten.cat.default                                  ([getitem_10, getitem_11, getitem_12, getitem_13], 1)                      {}
output         output                               output                                            ((cat,),)                                                                  {}
```

# details
* get the output tensor shapes (`List[Tensor]`) from the embedding module in the `_non_strict_exporting_forward` function
* register a custom op whose input is `List[Optional[Tensor]]` and whose output is a `List[Tensor]` with those shapes, via `register_custom_op`
* call this custom op with the original input and return its output, so that the custom op appears in the graph as a single node with the correct output shapes
* in the actual forward function of the embedding module, use `is_non_strict_exporting()` and `not torch.jit.is_scripting()` to branch into `_non_strict_exporting_forward` (both pieces are sketched right after this summary)

Differential Revision: D56443608
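Below is a minimal, hedged sketch of the custom-op registration idea described above, using the public `torch.library` API. The namespace `custom` matches the graph dump, but the op name `PooledEmbeddingArch_example`, the output dims, and the Library-based registration style are illustrative assumptions rather than TorchRec's actual `register_custom_op` code; the point is only that the op's outputs have shapes fixed at registration time, so torch.export can record one opaque node.

```python
# Hedged sketch: register a black-box custom op with fixed output shapes.
# Op name, namespace layout, and dims are assumptions for illustration,
# not TorchRec's exact implementation.
from typing import List, Optional

import torch
from torch import Tensor

lib = torch.library.Library("custom", "FRAGMENT")
lib.define(
    "PooledEmbeddingArch_example(Tensor?[] values, int batch_size) -> Tensor[]"
)

# Assumed per-table pooled output dims for this example.
_OUT_DIMS: List[int] = [8, 16]


def _impl_cpu(values: List[Optional[Tensor]], batch_size: int) -> List[Tensor]:
    # Only the output shapes matter for export; the real module is restored on
    # deserialization, so placeholder outputs are enough here.
    return [torch.empty(batch_size, dim) for dim in _OUT_DIMS]


def _impl_meta(values: List[Optional[Tensor]], batch_size: int) -> List[Tensor]:
    # Meta/fake kernel so tracing can infer output shapes without real data.
    return [torch.empty(batch_size, dim, device="meta") for dim in _OUT_DIMS]


lib.impl("PooledEmbeddingArch_example", _impl_cpu, "CPU")
lib.impl("PooledEmbeddingArch_example", _impl_meta, "Meta")

# Usage: the None-padded input list mirrors the List[Optional[Tensor]] schema.
out = torch.ops.custom.PooledEmbeddingArch_example([torch.randn(4), None], 2)
print([t.shape for t in out])  # [torch.Size([2, 8]), torch.Size([2, 16])]
```

Registering a separate `Meta` kernel is what lets non-strict export compute the op's output shapes without running a real embedding lookup.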
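And a similarly hedged sketch of the forward-time branching from the last detail bullet, reusing the illustrative op registered above. The local `is_non_strict_exporting()` stub and the toy module are placeholders; the summary only specifies that the real forward checks `is_non_strict_exporting()` and `not torch.jit.is_scripting()` before taking the `_non_strict_exporting_forward` path.

```python
import torch
from torch import Tensor


def is_non_strict_exporting() -> bool:
    # Stand-in for the helper named in the summary; the real check lives in
    # the TorchRec diff, so this placeholder simply returns False in eager mode.
    return False


class ToyPooledModule(torch.nn.Module):
    """Illustrative module standing in for PEA/EBC, not the real implementation."""

    def _non_strict_exporting_forward(self, values: Tensor) -> Tensor:
        # Export-only path: call the registered custom op so the whole module
        # appears in the exported graph as one node with known output shapes.
        outs = torch.ops.custom.PooledEmbeddingArch_example([values, None], 2)
        return torch.cat(outs, dim=1)

    def forward(self, values: Tensor) -> Tensor:
        # TorchScript must never see the export-only branch, hence the
        # `not torch.jit.is_scripting()` guard alongside the export check.
        if not torch.jit.is_scripting() and is_non_strict_exporting():
            return self._non_strict_exporting_forward(values)
        # Placeholder eager computation for the sketch.
        return values.unsqueeze(0)


module = ToyPooledModule()
print(module(torch.randn(4)).shape)  # eager path: torch.Size([1, 4])
```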