Cache all_optimizer_states to speed up model sharding (#2747)
Summary:

Similar to D68578829, `emb_module.get_optimizer_state()` is invoked thousands of times when `_gen_named_parameters_by_table_fused` runs, because the loop constructs an `EmbeddingFusedOptimizer` instance on every iteration and each construction repeats the call.

Hoisting this call out of the loop and passing its result into the constructor as a parameter gives a caching effect and saves a substantial amount of time: ~6.6s on https://www.internalfb.com/family_of_labs/test_results/694448901 and ~20s on https://www.internalfb.com/family_of_labs/test_results/694448900.
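
A minimal sketch of the pattern, with hypothetical names (`FusedOptimizerSketch`, `_FakeEmbModule`, and `gen_per_table_optimizers` are illustrative stand-ins, not the actual torchrec classes or signatures):

```python
from typing import Dict, List, Optional

import torch


class _FakeEmbModule:
    """Stand-in for an fbgemm-style embedding module with a costly state query."""

    def get_optimizer_state(self) -> List[Dict[str, torch.Tensor]]:
        # Pretend this is expensive; it returns one state dict per table.
        return [{"momentum1": torch.zeros(4)} for _ in range(2)]


class FusedOptimizerSketch:
    """Per-table optimizer wrapper that can reuse a precomputed state list."""

    def __init__(
        self,
        emb_module: _FakeEmbModule,
        all_optimizer_states: Optional[List[Dict[str, torch.Tensor]]] = None,
    ) -> None:
        # Only hit the expensive call when no cached value was supplied.
        self.all_optimizer_states = (
            all_optimizer_states or emb_module.get_optimizer_state()
        )


def gen_per_table_optimizers(emb_module: _FakeEmbModule, table_names: List[str]):
    # Before the change, get_optimizer_state() ran once per table inside the
    # loop; hoisting it out means it runs exactly once per generator call.
    all_optimizer_states = emb_module.get_optimizer_state()
    for name in table_names:
        yield name, FusedOptimizerSketch(
            emb_module, all_optimizer_states=all_optimizer_states
        )


optimizers = dict(gen_per_table_optimizers(_FakeEmbModule(), ["t0", "t1"]))
```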

Reviewed By: dstaay-fb, lijia19, andywag

Differential Revision: D69443708
hstonec authored and facebook-github-bot committed Feb 18, 2025
1 parent ea1cc27 commit 1e25182
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion torchrec/distributed/batched_embedding_kernel.py
@@ -216,6 +216,7 @@ def __init__( # noqa C901
         create_for_table: Optional[str] = None,
         param_weight_for_table: Optional[nn.Parameter] = None,
         embedding_weights_by_table: Optional[List[torch.Tensor]] = None,
+        all_optimizer_states: Optional[List[Dict[str, torch.Tensor]]] = None,
     ) -> None:
         """
         Implementation of a FusedOptimizer. Designed as a base class Embedding kernels
@@ -395,7 +396,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
             embedding_weights_by_table or emb_module.split_embedding_weights()
         )
 
-        all_optimizer_states = emb_module.get_optimizer_state()
+        all_optimizer_states = all_optimizer_states or emb_module.get_optimizer_state()
         optimizer_states_keys_by_table: Dict[str, List[torch.Tensor]] = {}
         for (
             table_config,
@@ -678,6 +679,8 @@ def _gen_named_parameters_by_table_fused(
     # TODO: move logic to FBGEMM to avoid accessing fbgemm internals
     # Cache embedding_weights_by_table
     embedding_weights_by_table = emb_module.split_embedding_weights()
+    # Cache all_optimizer_states
+    all_optimizer_states = emb_module.get_optimizer_state()
     for t_idx, (rows, dim, location, _) in enumerate(emb_module.embedding_specs):
         table_name = config.embedding_tables[t_idx].name
         if table_name not in table_name_to_count:
@@ -714,6 +717,7 @@ def _gen_named_parameters_by_table_fused(
                 create_for_table=table_name,
                 param_weight_for_table=weight,
                 embedding_weights_by_table=embedding_weights_by_table,
+                all_optimizer_states=all_optimizer_states,
             )
         ]
         yield (table_name, weight)
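Design note on the change above: `all_optimizer_states` is added as an optional parameter that defaults to `None`, and the constructor falls back to `emb_module.get_optimizer_state()` via `or`, so existing call sites keep working unchanged; only the hot loop in `_gen_named_parameters_by_table_fused` opts into the cached value. (As a general Python caveat, not something stated in this commit, an empty list passed for `all_optimizer_states` would also trigger the fallback, which is harmless here because the recomputed state would be empty as well.)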
