Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Cache all_optimizer_states to speed up model sharding (#2747)
Summary: Similar to D68578829, the `get_optimizer_state()` method in the `emb_module` is invoked thousands of times when `_gen_named_parameters_by_table_fused` is called as it generates EmbeddingFusedOptimizer instances for each iteration. By extracting this operation out of the loop and passing it as a parameter to achieve a caching effect, we can save a lot of time. Specifically, ~6.6s from https://www.internalfb.com/family_of_labs/test_results/694448901 and ~20s from https://www.internalfb.com/family_of_labs/test_results/694448900 Reviewed By: dstaay-fb, lijia19, andywag Differential Revision: D69443708
- Loading branch information