Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Unifying TBE API using List (Frontend)
Summary: X-link: pytorch/FBGEMM#3711 X-link: facebookresearch/FBGEMM#793 **Backend**: D68054868 --- As the number of arguments in TBE keeps growing, some of the optimizers run into number of arguments limitation (i.e., 64) during pytorch operation registration. **For long-term growth and maintenance, we hence redesign TBE API by packing some of the arguments into list. Note that not all arguments are packed.** We pack the arguments as a list for each type. For **common** arguments, we pack - weights and arguments of type `Momentum` into TensorList - other tensors and optional tensors to list of optional tensors `aux_tensor` - `int` arguments into `aux_int` - `float` arguments into `aux_float` - `bool` arguments into `aux_bool`. Similarly for **optimizer-specific** arguments, we pack - arguments of type `Momentum` that are *__not__ optional* into TensorList - *optional* tensors to list of optional tensors `optim_tensor` - `int` arguments into `optim_int` - `float` arguments into `optim_float` - `bool` arguments into `optim_bool`. We see issues with pytorch registration across packing SymInt in python-C++, so we unroll and pass SymInt arguments individually. **This significantly reduces number of arguments.** For example, `split_embedding_codegen_lookup_rowwise_adagrad_with_counter_function`, which currently has 61 arguments only have 26 arguments with this API design. Please refer to the design doc on which arguments are packed and signature. Design doc: https://docs.google.com/document/d/1dCBg7dcf7Yq9FHVrvXsAmFtBxkDi9o6u0r-Ptd4UDPE/edit?tab=t.0#heading=h.6bip5pwqq8xb Full signature for each optimizer lookup function will be provided shortly. Reviewed By: sryap Differential Revision: D68055168
- Loading branch information