Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor the _non_strict_export_forward in EBC and PEA to include emb…
…edding_bag.weight (#1942) Summary: Pull Request resolved: #1942 # context * previously the EBC and PEA only takes the KJT as the input of the custom_op, which leaves the weights of embedding bags as dangling placeholders in the graph. * with this diff, we add those weights into the custom_op so that in the graph those weights are connected to the custom_op # PEA * embedding_bag weights like `p_pea1_embedding_modules_t1_weight` is dangling before, and it's used by `pooled_embedding_arch` afterwards. * the total count of nodes are the same. * before ``` (Pdb) ep.graph.print_tabular() opcode name target args kwargs ------------- ------------------------------------- -------------------------------------------------- ------------------------------------------------------------------------ -------- placeholder p_pea1_embedding_modules_t1_weight p_pea1_embedding_modules_t1_weight () {} placeholder p_pea1_embedding_modules_t2_weight p_pea1_embedding_modules_t2_weight () {} placeholder p_pea1_embedding_modules_t3_weight p_pea1_embedding_modules_t3_weight () {} placeholder p_pea1_embedding_modules_t4_weight p_pea1_embedding_modules_t4_weight () {} placeholder p_pea2_embedding_modules_t1_weight p_pea2_embedding_modules_t1_weight () {} placeholder p_pea2_embedding_modules_t2_weight p_pea2_embedding_modules_t2_weight () {} placeholder p_pea2_embedding_modules_t3_weight p_pea2_embedding_modules_t3_weight () {} placeholder values values () {} placeholder lengths lengths () {} placeholder w_values w_values () {} placeholder w_lengths w_lengths () {} placeholder w_weights w_weights () {} call_function pooled_embedding_arch_139590115359376 custom.PooledEmbeddingArch_139590115359376.default ([values, None, lengths, None, w_values, w_weights, w_lengths, None], 2) {} call_function getitem_10 <built-in function getitem> (pooled_embedding_arch_139590115359376, 0) {} call_function getitem_11 <built-in function getitem> (pooled_embedding_arch_139590115359376, 1) {} call_function pooled_embedding_arch_139590115359377 custom.PooledEmbeddingArch_139590115359376.default ([values, None, lengths, None, w_values, w_weights, w_lengths, None], 2) {} call_function getitem_12 <built-in function getitem> (pooled_embedding_arch_139590115359377, 0) {} call_function getitem_13 <built-in function getitem> (pooled_embedding_arch_139590115359377, 1) {} call_function cat aten.cat.default ([getitem_10, getitem_11, getitem_12, getitem_13], 1) {} output output output ((cat,),) {} ``` * after ``` (Pdb) ep.graph.print_tabular() opcode name target args kwargs ------------- ------------------------------------- -------------------------------------------------- ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ -------- placeholder p_pea1_embedding_modules_t1_weight p_pea1_embedding_modules_t1_weight () {} placeholder p_pea1_embedding_modules_t2_weight p_pea1_embedding_modules_t2_weight () {} placeholder p_pea1_embedding_modules_t3_weight p_pea1_embedding_modules_t3_weight () {} placeholder p_pea1_embedding_modules_t4_weight p_pea1_embedding_modules_t4_weight () {} placeholder p_pea2_embedding_modules_t1_weight p_pea2_embedding_modules_t1_weight () {} placeholder p_pea2_embedding_modules_t2_weight p_pea2_embedding_modules_t2_weight () {} placeholder p_pea2_embedding_modules_t3_weight p_pea2_embedding_modules_t3_weight () {} placeholder values values () {} placeholder lengths lengths () {} placeholder w_values w_values () {} placeholder w_lengths w_lengths () {} placeholder w_weights w_weights () {} call_function pooled_embedding_arch_140307162538352 custom.PooledEmbeddingArch_140307162538352.default ([p_pea1_embedding_modules_t1_weight, p_pea1_embedding_modules_t2_weight, p_pea1_embedding_modules_t3_weight, p_pea1_embedding_modules_t4_weight, values, None, lengths, None, w_values, w_weights, w_lengths, None], 2) {} call_function getitem_10 <built-in function getitem> (pooled_embedding_arch_140307162538352, 0) {} call_function getitem_11 <built-in function getitem> (pooled_embedding_arch_140307162538352, 1) {} call_function pooled_embedding_arch_140307162538353 custom.PooledEmbeddingArch_140307162538352.default ([p_pea2_embedding_modules_t1_weight, p_pea2_embedding_modules_t2_weight, p_pea2_embedding_modules_t3_weight, values, None, lengths, None, w_values, w_weights, w_lengths, None], 2) {} call_function getitem_12 <built-in function getitem> (pooled_embedding_arch_140307162538353, 0) {} call_function getitem_13 <built-in function getitem> (pooled_embedding_arch_140307162538353, 1) {} call_function cat aten.cat.default ([getitem_10, getitem_11, getitem_12, getitem_13], 1) {} output output output ((cat,),) {} ``` # EBC * embedding_bag weights like `p_ebc1_embedding_bags_t1_weight` is dangling before, and it's used by `embedding_bag_collection` afterwards. * the total count of nodes are the same. * before ``` (Pdb) ep.graph.print_tabular() opcode name target args kwargs ------------- ---------------------------------------- ----------------------------------------------------- ------------------------------------------------------------------- -------- placeholder p_ebc1_embedding_bags_t1_weight p_ebc1_embedding_bags_t1_weight () {} placeholder p_ebc1_embedding_bags_t2_weight p_ebc1_embedding_bags_t2_weight () {} placeholder p_ebc2_embedding_bags_t3_weight p_ebc2_embedding_bags_t3_weight () {} placeholder p_ebc2_embedding_bags_t4_weight p_ebc2_embedding_bags_t4_weight () {} placeholder features__values features__values () {} placeholder features__weights features__weights () {} placeholder features__lengths features__lengths () {} placeholder features__offsets features__offsets () {} call_function embedding_bag_collection_140218446926688 custom.EmbeddingBagCollection_140218446926688.default ([features__values, features__weights, None, features__offsets], 2) {} call_function getitem <built-in function getitem> (embedding_bag_collection_140218446926688, 0) {} call_function embedding_bag_collection_140218446926689 custom.EmbeddingBagCollection_140218446926688.default ([features__values, features__weights, None, features__offsets], 2) {} call_function getitem_1 <built-in function getitem> (embedding_bag_collection_140218446926689, 0) {} call_function cat aten.cat.default ([getitem, getitem_1], 1) {} output output output ((cat,),) {} ``` * after ``` (Pdb) ep.graph.print_tabular() opcode name target args kwargs ------------- ---------------------------------------- ----------------------------------------------------- ------------------------------------------------------------------------------------------------------------------------------------- -------- placeholder p_ebc1_embedding_bags_t1_weight p_ebc1_embedding_bags_t1_weight () {} placeholder p_ebc1_embedding_bags_t2_weight p_ebc1_embedding_bags_t2_weight () {} placeholder p_ebc2_embedding_bags_t3_weight p_ebc2_embedding_bags_t3_weight () {} placeholder p_ebc2_embedding_bags_t4_weight p_ebc2_embedding_bags_t4_weight () {} placeholder features__values features__values () {} placeholder features__weights features__weights () {} placeholder features__lengths features__lengths () {} placeholder features__offsets features__offsets () {} call_function embedding_bag_collection_140614752342208 custom.EmbeddingBagCollection_140614752342208.default ([features__values, features__weights, None, features__offsets, p_ebc1_embedding_bags_t1_weight, p_ebc1_embedding_bags_t2_weight], 2) {} call_function getitem <built-in function getitem> (embedding_bag_collection_140614752342208, 0) {} call_function embedding_bag_collection_140614752342209 custom.EmbeddingBagCollection_140614752342208.default ([features__values, features__weights, None, features__offsets, p_ebc2_embedding_bags_t3_weight, p_ebc2_embedding_bags_t4_weight], 2) {} call_function getitem_1 <built-in function getitem> (embedding_bag_collection_140614752342209, 0) {} call_function cat aten.cat.default ([getitem, getitem_1], 1) {} output output output ((cat,),) {} ``` Reviewed By: PaulZhang12 Differential Revision: D56904214 fbshipit-source-id: 193a95a393d8ccdad582640079da5b7879aaea36
- Loading branch information