Unifying TBE API using List (Frontend)
Summary:
X-link: pytorch/FBGEMM#3711

X-link: facebookresearch/FBGEMM#793

**Backend**: D68054868

---


As the number of arguments in TBE keeps growing, some of the optimizers run into the argument-count limit (i.e., 64) during PyTorch operator registration.

**For long-term growth and maintainability, we hence redesign the TBE API by packing some of the arguments into lists. Note that not all arguments are packed.**

We pack the arguments into one list per type (see the sketch below).
For **common** arguments, we pack
- weights and arguments of type `Momentum` into a TensorList
- other tensors and optional tensors into a list of optional tensors, `aux_tensor`
- `int` arguments into `aux_int`
- `float` arguments into `aux_float`
- `bool` arguments into `aux_bool`.

Similarly, for **optimizer-specific** arguments, we pack
- arguments of type `Momentum` that are *__not__ optional* into a TensorList
- *optional* tensors into a list of optional tensors, `optim_tensor`
- `int` arguments into `optim_int`
- `float` arguments into `optim_float`
- `bool` arguments into `optim_bool`.

We ran into PyTorch registration issues when packing SymInt arguments across the Python/C++ boundary, so SymInt arguments are unrolled and passed individually.
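
A minimal, illustrative sketch of this packing scheme in Python; every name and value below is a hypothetical placeholder, not the actual FBGEMM operator signature (see the design doc below for the real one).

```python
from typing import List, Optional

import torch

# Hypothetical placeholder values standing in for real TBE arguments.
dev_weights = torch.zeros(8)
uvm_weights = torch.zeros(8)
momentum1 = torch.zeros(8)

# Common arguments, packed one list per type.
weights: List[torch.Tensor] = [dev_weights, uvm_weights]   # TensorList
aux_tensor: List[Optional[torch.Tensor]] = [None]          # optional tensors
aux_int: List[int] = [0]                                    # e.g. an int flag
aux_float: List[float] = [1.0]                              # e.g. a float knob
aux_bool: List[bool] = [True]                               # e.g. a bool toggle

# Optimizer-specific arguments, packed the same way.
momentum: List[torch.Tensor] = [momentum1]                  # non-optional Momentum tensors
optim_tensor: List[Optional[torch.Tensor]] = [None]
optim_int: List[int] = []
optim_float: List[float] = [1e-8]                           # e.g. eps
optim_bool: List[bool] = []

# SymInt arguments are not packed; they remain individual arguments.
```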

**This significantly reduces the number of arguments.** For example, `split_embedding_codegen_lookup_rowwise_adagrad_with_counter_function`, which currently has 61 arguments, has only 26 arguments with this API design.

Please refer to the design doc for details on which arguments are packed and for the resulting signatures.
Design doc:
https://docs.google.com/document/d/1dCBg7dcf7Yq9FHVrvXsAmFtBxkDi9o6u0r-Ptd4UDPE/edit?tab=t.0#heading=h.6bip5pwqq8xb

The full signature for each optimizer lookup function will be provided shortly.

Reviewed By: sryap

Differential Revision: D68055168
spcyppt authored and facebook-github-bot committed Feb 19, 2025
1 parent 49ac41b commit 8922c3e
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -190,7 +190,7 @@ def __init__(
        state: Dict[Any, Any] = {}
        param_group: Dict[str, Any] = {
            "params": [],
-           "lr": emb_module.optimizer_args.learning_rate,
+           "lr": emb_module.optimizer_args.learning_rate_tensor,
        }

        params: Dict[str, Union[torch.Tensor, ShardedTensor]] = {}
@@ -383,7 +383,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
        state: Dict[Any, Any] = {}
        param_group: Dict[str, Any] = {
            "params": [],
-           "lr": emb_module.optimizer_args.learning_rate,
+           "lr": emb_module.optimizer_args.learning_rate_tensor,
        }

        params: Dict[str, Union[torch.Tensor, ShardedTensor]] = {}
2 changes: 1 addition & 1 deletion torchrec/modules/fused_embedding_modules.py
@@ -68,7 +68,7 @@ def __init__( # noqa C901
        state: Dict[Any, Any] = {}
        param_group: Dict[str, Any] = {
            "params": [],
-           "lr": emb_module.optimizer_args.learning_rate,
+           "lr": emb_module.optimizer_args.learning_rate_tensor,
        }

        params: Dict[str, torch.Tensor] = {}
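
As a side note on the torchrec changes above (my reading of the diff, not something the commit message states): `optimizer_args` now exposes the learning rate as a tensor (`learning_rate_tensor`) rather than a plain float, so the param group's `"lr"` entry holds a tensor. A minimal sketch of what downstream code would then see:

```python
import torch

# Hypothetical stand-in for emb_module.optimizer_args.learning_rate_tensor.
learning_rate_tensor = torch.tensor(0.01)

# Mirrors the param_group construction in the diffs above.
param_group = {"params": [], "lr": learning_rate_tensor}

# Consumers expecting a plain Python float can recover one explicitly.
lr_value = float(param_group["lr"].item())
```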
