From 1e25182c9a913ca643c778baf17eccd37af485d3 Mon Sep 17 00:00:00 2001
From: Shicong Huang
Date: Tue, 18 Feb 2025 12:44:13 -0800
Subject: [PATCH] Cache all_optimizer_states to speed up model sharding (#2747)

Summary:
Similar to D68578829, the `get_optimizer_state()` method on the `emb_module` is invoked thousands of times when `_gen_named_parameters_by_table_fused` is called, since an `EmbeddingFusedOptimizer` instance is created on each loop iteration. By extracting this call out of the loop and passing its result in as a parameter to achieve a caching effect, we save a lot of time.

Specifically, this saves ~6.6s in https://www.internalfb.com/family_of_labs/test_results/694448901 and ~20s in https://www.internalfb.com/family_of_labs/test_results/694448900

Reviewed By: dstaay-fb, lijia19, andywag

Differential Revision: D69443708
---
 torchrec/distributed/batched_embedding_kernel.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index b4db21da4..d7b945bf6 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -216,6 +216,7 @@ def __init__(  # noqa C901
         create_for_table: Optional[str] = None,
         param_weight_for_table: Optional[nn.Parameter] = None,
         embedding_weights_by_table: Optional[List[torch.Tensor]] = None,
+        all_optimizer_states: Optional[List[Dict[str, torch.Tensor]]] = None,
     ) -> None:
         """
         Implementation of a FusedOptimizer. Designed as a base class Embedding kernels
@@ -395,7 +396,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
             embedding_weights_by_table or emb_module.split_embedding_weights()
         )
 
-        all_optimizer_states = emb_module.get_optimizer_state()
+        all_optimizer_states = all_optimizer_states or emb_module.get_optimizer_state()
         optimizer_states_keys_by_table: Dict[str, List[torch.Tensor]] = {}
         for (
             table_config,
@@ -678,6 +679,8 @@ def _gen_named_parameters_by_table_fused(
     # TODO: move logic to FBGEMM to avoid accessing fbgemm internals
     # Cache embedding_weights_by_table
    embedding_weights_by_table = emb_module.split_embedding_weights()
+    # Cache all_optimizer_states
+    all_optimizer_states = emb_module.get_optimizer_state()
     for t_idx, (rows, dim, location, _) in enumerate(emb_module.embedding_specs):
         table_name = config.embedding_tables[t_idx].name
         if table_name not in table_name_to_count:
@@ -714,6 +717,7 @@ def _gen_named_parameters_by_table_fused(
                 create_for_table=table_name,
                 param_weight_for_table=weight,
                 embedding_weights_by_table=embedding_weights_by_table,
+                all_optimizer_states=all_optimizer_states,
             )
         ]
         yield (table_name, weight)
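
Below is a minimal, self-contained sketch (not part of the patch, and not a TorchRec API) of the hoist-and-pass caching pattern the commit describes: an expensive per-module state query is moved out of the loop, its result is threaded through as a parameter, and an `or` fallback keeps the old behavior when no cache is supplied. The names `ToyEmbModule`, `ToyFusedOptimizer`, and `build_per_table_optimizers` are hypothetical and exist only for illustration.

from typing import Dict, List, Optional

import torch


class ToyEmbModule:
    """Stand-in for a fused embedding module with an expensive state query."""

    def __init__(self, num_tables: int) -> None:
        self.num_tables = num_tables

    def get_optimizer_state(self) -> List[Dict[str, torch.Tensor]]:
        # Imagine this call is costly (e.g. it touches device memory for every table).
        return [{"momentum": torch.zeros(8)} for _ in range(self.num_tables)]


class ToyFusedOptimizer:
    def __init__(
        self,
        emb_module: ToyEmbModule,
        table_idx: int,
        all_optimizer_states: Optional[List[Dict[str, torch.Tensor]]] = None,
    ) -> None:
        # Reuse the caller-provided snapshot when available; otherwise fall back
        # to querying the module, mirroring the `or` fallback in the patch.
        states = all_optimizer_states or emb_module.get_optimizer_state()
        self.state: Dict[str, torch.Tensor] = states[table_idx]


def build_per_table_optimizers(emb_module: ToyEmbModule) -> List[ToyFusedOptimizer]:
    # Hoist the expensive call out of the loop and pass the result down,
    # so it runs once instead of once per table.
    all_optimizer_states = emb_module.get_optimizer_state()
    return [
        ToyFusedOptimizer(emb_module, t, all_optimizer_states=all_optimizer_states)
        for t in range(emb_module.num_tables)
    ]


if __name__ == "__main__":
    opts = build_per_table_optimizers(ToyEmbModule(num_tables=4))
    print(len(opts), "per-table optimizers built with a single state query")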