Cache all_optimizer_states to speed up model sharding #2747

Closed · wants to merge 1 commit
6 changes: 5 additions & 1 deletion torchrec/distributed/batched_embedding_kernel.py
@@ -216,6 +216,7 @@ def __init__( # noqa C901
create_for_table: Optional[str] = None,
param_weight_for_table: Optional[nn.Parameter] = None,
embedding_weights_by_table: Optional[List[torch.Tensor]] = None,
+        all_optimizer_states: Optional[List[Dict[str, torch.Tensor]]] = None,
) -> None:
"""
Implementation of a FusedOptimizer. Designed as a base class Embedding kernels
@@ -395,7 +396,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
embedding_weights_by_table or emb_module.split_embedding_weights()
)

-        all_optimizer_states = emb_module.get_optimizer_state()
+        all_optimizer_states = all_optimizer_states or emb_module.get_optimizer_state()
optimizer_states_keys_by_table: Dict[str, List[torch.Tensor]] = {}
for (
table_config,
@@ -678,6 +679,8 @@ def _gen_named_parameters_by_table_fused(
# TODO: move logic to FBGEMM to avoid accessing fbgemm internals
# Cache embedding_weights_by_table
embedding_weights_by_table = emb_module.split_embedding_weights()
+    # Cache all_optimizer_states
+    all_optimizer_states = emb_module.get_optimizer_state()
for t_idx, (rows, dim, location, _) in enumerate(emb_module.embedding_specs):
table_name = config.embedding_tables[t_idx].name
if table_name not in table_name_to_count:
@@ -714,6 +717,7 @@ def _gen_named_parameters_by_table_fused(
create_for_table=table_name,
param_weight_for_table=weight,
embedding_weights_by_table=embedding_weights_by_table,
+                all_optimizer_states=all_optimizer_states,
)
]
yield (table_name, weight)
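For context, here is a minimal, self-contained sketch of the caching pattern this PR applies. EmbModule, TableOptimizer, and gen_per_table_optimizers are hypothetical stand-ins (not torchrec classes) for the fbgemm emb_module, EmbeddingFusedOptimizer, and _gen_named_parameters_by_table_fused respectively:

from typing import Dict, Iterator, List, Optional

import torch


class EmbModule:
    # Hypothetical stand-in for the fbgemm embedding module whose
    # get_optimizer_state() does O(tables) work on every call.
    def __init__(self, num_tables: int) -> None:
        self._states = [{"momentum1": torch.zeros(4)} for _ in range(num_tables)]

    def get_optimizer_state(self) -> List[Dict[str, torch.Tensor]]:
        # Pretend this walk over all tables is expensive.
        return [dict(s) for s in self._states]


class TableOptimizer:
    # Hypothetical stand-in for the per-table EmbeddingFusedOptimizer.
    def __init__(
        self,
        emb_module: EmbModule,
        table_idx: int,
        all_optimizer_states: Optional[List[Dict[str, torch.Tensor]]] = None,
    ) -> None:
        # Same fallback as the diff: fetch only when no cached value is given.
        all_optimizer_states = all_optimizer_states or emb_module.get_optimizer_state()
        self.state = all_optimizer_states[table_idx]


def gen_per_table_optimizers(
    emb_module: EmbModule, num_tables: int
) -> Iterator[TableOptimizer]:
    # Fetch once outside the per-table loop and share the result, so sharding
    # N tables costs one get_optimizer_state() call instead of N.
    all_optimizer_states = emb_module.get_optimizer_state()
    for t_idx in range(num_tables):
        yield TableOptimizer(emb_module, t_idx, all_optimizer_states=all_optimizer_states)

One caveat of the `cached or fetch()` fallback: an explicitly passed empty list is falsy in Python and would trigger a refetch. In this code path the generator always passes the freshly cached list, and callers that omit the argument get the old per-instance fetch via the None default, so behavior is unchanged for them.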