Unifying TBE API using List (Frontend) (pytorch#3711)
Summary:
X-link: pytorch/torchrec#2751


X-link: facebookresearch/FBGEMM#793

**Backend**: D68054868

---


As the number of arguments in TBE keeps growing, some of the optimizers run into the argument-count limit (i.e., 64) during PyTorch operator registration.

**For long-term growth and maintainability, we hence redesign the TBE API by packing some of the arguments into lists. Note that not all arguments are packed.**

We pack the arguments into a list per type; a short sketch follows the two lists below.
For **common** arguments, we pack
- weights and arguments of type `Momentum` into a `TensorList`
- other tensors and optional tensors into a list of optional tensors, `aux_tensor`
- `int` arguments into `aux_int`
- `float` arguments into `aux_float`
- `bool` arguments into `aux_bool`.

Similarly, for **optimizer-specific** arguments, we pack
- arguments of type `Momentum` that are *__not__ optional* into a `TensorList`
- *optional* tensors into a list of optional tensors, `optim_tensor`
- `int` arguments into `optim_int`
- `float` arguments into `optim_float`
- `bool` arguments into `optim_bool`.
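
As a rough illustration of the resulting call shape, here is a minimal sketch in Python. The function name and the per-slot contents are hypothetical (the authoritative signatures are in the design doc linked below); only the list names (`aux_tensor`, `aux_int`, ..., `optim_bool`) come from the summary above.

```python
# Minimal sketch of the unified, list-packed TBE lookup signature.
# All names below are illustrative, not the generated API.
from typing import List, Optional

import torch


def tbe_lookup_sketch(
    # common arguments
    weights: List[torch.Tensor],                 # TensorList: weights and Momentum tensors
    aux_tensor: List[Optional[torch.Tensor]],    # other (optional) tensors
    aux_int: List[int],
    aux_float: List[float],
    aux_bool: List[bool],
    # optimizer-specific arguments
    optim_tensor: List[Optional[torch.Tensor]],  # optional optimizer tensors
    optim_int: List[int],
    optim_float: List[float],
    optim_bool: List[bool],
    # SymInt arguments remain unrolled and are passed individually
    total_D: int,
    max_D: int,
) -> torch.Tensor:
    """Stand-in for a generated lookup function; the real one dispatches to the C++ op."""
    ...
```

Packing by type keeps each list's schema stable as optimizers gain arguments, which is what keeps the total argument count under the registration limit.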

We observed issues with PyTorch registration when packing `SymInt` arguments across the Python/C++ boundary, so `SymInt` arguments are unrolled and passed individually.

**This significantly reduces the number of arguments.** For example, `split_embedding_codegen_lookup_rowwise_adagrad_with_counter_function`, which currently has 61 arguments, has only 26 arguments with this API design.

Please refer to the design doc for which arguments are packed and for the full signatures.
Design doc:
https://docs.google.com/document/d/1dCBg7dcf7Yq9FHVrvXsAmFtBxkDi9o6u0r-Ptd4UDPE/edit?tab=t.0#heading=h.6bip5pwqq8xb

The full signature for each optimizer lookup function will be provided shortly.

Reviewed By: sryap

Differential Revision: D68055168
spcyppt authored and facebook-github-bot committed Feb 21, 2025
1 parent b97f57f commit a9c9021
Showing 9 changed files with 407 additions and 344 deletions.
2 changes: 1 addition & 1 deletion fbgemm_gpu/codegen/genscript/generate_backward_split.py
@@ -438,7 +438,7 @@ def generate() -> None:
ssd_optimizers.append(optim)

BackwardSplitGenerator.generate_backward_split(
ssd_tensors=ssd_tensors, **optimizer
ssd_tensors=ssd_tensors, aux_args=aux_args, **optimizer
)
BackwardSplitGenerator.generate_rocm_backward_split()

@@ -601,6 +601,7 @@ Tensor {{ embedding_cuda_op }}(

{%- if "learning_rate" in args.split_kernel_arg_names %}
// convert `learning rate` to float since `learning rate` is float in kernels
TORCH_CHECK(learning_rate_tensor.is_cpu(), "learning_rate_tensor tensor needs to be on CPU. Ensure learning_rate_tensor is on CPU or contact FBGEMM team if you get this error.")
const float learning_rate = learning_rate_tensor.item<float>();
{%- endif %}

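
For context, a minimal caller-side sketch of the tensor this check expects, mirroring the frontend change later in this diff (assuming only that the learning rate arrives as a scalar fp32 CPU tensor):

```python
# Sketch: the learning rate is now passed as a scalar fp32 tensor that must
# live on the CPU so the kernel wrapper can read it back as a plain float.
import torch

learning_rate = 0.01
learning_rate_tensor = torch.tensor(
    learning_rate, device=torch.device("cpu"), dtype=torch.float
)

assert learning_rate_tensor.device.type == "cpu"  # what the TORCH_CHECK above enforces
lr = learning_rate_tensor.item()                  # what .item<float>() does on the C++ side
```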

Large diffs are not rendered by default.

@@ -597,7 +597,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
"""

embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]
optimizer_args: invokers.lookup_args.OptimizerArgs
optimizer_args: invokers.lookup_args.OptimizerArgsPT2
lxu_cache_locations_list: List[Tensor]
lxu_cache_locations_empty: Tensor
timesteps_prefetched: List[int]
@@ -1070,6 +1070,9 @@ def __init__( # noqa C901
# which should not be effective when CounterBasedRegularizationDefinition
# and CowClipDefinition are not used
counter_halflife = -1
learning_rate_tensor = torch.tensor(
learning_rate, device=torch.device("cpu"), dtype=torch.float
)

# TO DO: Enable this on the new interface
# learning_rate_tensor = torch.tensor(
@@ -1085,12 +1088,12 @@ def __init__( # noqa C901
"`use_rowwise_bias_correction` is only supported for OptimType.ADAM",
)

self.optimizer_args = invokers.lookup_args.OptimizerArgs(
self.optimizer_args = invokers.lookup_args.OptimizerArgsPT2(
stochastic_rounding=stochastic_rounding,
gradient_clipping=gradient_clipping,
max_gradient=max_gradient,
max_norm=max_norm,
learning_rate=learning_rate,
learning_rate_tensor=learning_rate_tensor,
eps=eps,
beta1=beta1,
beta2=beta2,
@@ -2032,7 +2035,6 @@ def forward( # noqa: C901
momentum1,
momentum2,
iter_int,
self.use_rowwise_bias_correction,
row_counter=(
row_counter if self.use_rowwise_bias_correction else None
),
@@ -2918,7 +2920,7 @@ def _set_learning_rate(self, lr: float) -> float:
Helper function to script `set_learning_rate`.
Note that returning None does not work.
"""
self.optimizer_args = self.optimizer_args._replace(learning_rate=lr)
self.optimizer_args.learning_rate_tensor.fill_(lr)
return 0.0

@torch.jit.ignore
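
A brief sketch of what this in-place update means (illustrative; the tensor below stands in for `optimizer_args.learning_rate_tensor`):

```python
import torch

# Stand-in for optimizer_args.learning_rate_tensor, which now holds the learning rate.
lr_tensor = torch.tensor(0.01, device=torch.device("cpu"), dtype=torch.float)

# _set_learning_rate now mutates the tensor in place rather than rebuilding
# optimizer_args with a new float field.
lr_tensor.fill_(0.001)

print(lr_tensor.item())  # 0.001 -- callers read the value back via .item()
```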
7 changes: 5 additions & 2 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -544,12 +544,15 @@ def __init__(
)
cowclip_regularization = CowClipDefinition()

self.optimizer_args = invokers.lookup_args_ssd.OptimizerArgs(
learning_rate_tensor = torch.tensor(
learning_rate, device=self.current_device, dtype=torch.float
)
self.optimizer_args = invokers.lookup_args_ssd.OptimizerArgsPT2(
stochastic_rounding=stochastic_rounding,
gradient_clipping=gradient_clipping,
max_gradient=max_gradient,
max_norm=max_norm,
learning_rate=learning_rate,
learning_rate_tensor=learning_rate_tensor,
eps=eps,
beta1=beta1,
beta2=beta2,
84 changes: 58 additions & 26 deletions fbgemm_gpu/test/tbe/training/backward_adagrad_common.py
@@ -32,7 +32,7 @@
round_up,
to_device,
)
from hypothesis import assume, HealthCheck, Verbosity
from hypothesis import HealthCheck, Verbosity

from .. import common # noqa E402
from ..common import ( # noqa E402
@@ -73,7 +73,7 @@
"D": st.integers(min_value=2, max_value=128),
"B": st.integers(min_value=1, max_value=128),
"log_E": st.integers(min_value=3, max_value=5),
"L": st.integers(min_value=0, max_value=20),
"L": st.integers(min_value=1, max_value=20),
"D_gradcheck": st.integers(min_value=1, max_value=2),
"stochastic_rounding": st.booleans(),
"weighted": st.booleans(),
@@ -82,8 +82,10 @@
"use_cache": st.booleans(),
"cache_algorithm": st.sampled_from(CacheAlgorithm),
"use_cpu": use_cpu_strategy(),
"output_dtype": st.sampled_from(
[SparseType.FP32, SparseType.FP16, SparseType.BF16]
"output_dtype": (
st.sampled_from([SparseType.FP32, SparseType.FP16, SparseType.BF16])
if gpu_available
else st.sampled_from([SparseType.FP32])
),
}

@@ -118,33 +120,59 @@ def execute_backward_adagrad( # noqa C901
compile: bool = False,
) -> None:
# NOTE: cache is not applicable to CPU version.
assume(not use_cpu or not use_cache)
# assume(not use_cpu or not use_cache)
if use_cpu and use_cache:
return

# NOTE: torch.autograd.gradcheck() is too time-consuming for CPU version
# so we have to limit (T * B * L * D)!
assume(not use_cpu or T * B * L * D <= 1024)
assume(not (use_cpu and weights_precision == SparseType.FP16))
# assume(not use_cpu or T * B * L * D <= 1024)
if use_cpu and T * B * L * D > 1024:
return
# assume(not (use_cpu and weights_precision == SparseType.FP16))
if use_cpu and weights_precision == SparseType.FP16:
return
# assume(not (use_cpu and weights_precision == SparseType.BF16))
if use_cpu and weights_precision == SparseType.BF16:
return
# max_norm is only applicable to EXACT_ROWWISE_ADAGRAD GPU version
assume(max_norm == 0.0 or (not use_cpu and row_wise))

assume(
pooling_mode == PoolingMode.SUM or not weighted
) # No bag ops only work on GPUs, no mixed, no weighted
assume(not use_cpu or pooling_mode != PoolingMode.NONE)
assume(not mixed or pooling_mode != PoolingMode.NONE)
assume(not weighted or pooling_mode != PoolingMode.NONE)
# assume(max_norm == 0.0 or (not use_cpu and row_wise))
if max_norm != 0.0 and (use_cpu or not row_wise):
return
# assume(
# pooling_mode == PoolingMode.SUM or not weighted
# ) # No bag ops only work on GPUs, no mixed, no weighted
if pooling_mode != PoolingMode.SUM and weighted:
return
# assume(not use_cpu or pooling_mode != PoolingMode.NONE)
if use_cpu and pooling_mode == PoolingMode.NONE:
return
# assume(not mixed or pooling_mode != PoolingMode.NONE)
if mixed and pooling_mode == PoolingMode.NONE:
return
# assume(not weighted or pooling_mode != PoolingMode.NONE)
if weighted and pooling_mode == PoolingMode.NONE:
return
# TODO: Support these cases
assume(
not mixed_B
or (
weights_precision != SparseType.INT8
and output_dtype != SparseType.INT8
and pooling_mode != PoolingMode.NONE
)
)
# assume(
# not mixed_B
# or (
# weights_precision != SparseType.INT8
# and output_dtype != SparseType.INT8
# and pooling_mode != PoolingMode.NONE
# )
# )
if mixed_B and (
weights_precision == SparseType.INT8
or output_dtype == SparseType.INT8
or pooling_mode == PoolingMode.NONE
):
return

# Disable dynamo tests due to unknown failure
assume(not compile)
# assume(not compile)
if compile:
return

emb_op = SplitTableBatchedEmbeddingBagsCodegen

@@ -162,9 +190,13 @@ def execute_backward_adagrad( # noqa C901
raise RuntimeError("Unknown PoolingMode!")

# stochastic rounding only implemented for rowwise
assume(not stochastic_rounding or row_wise)
# assume(not stochastic_rounding or row_wise)
if stochastic_rounding and not row_wise:
return
# only row-wise supports caching
assume(row_wise or not use_cache)
# assume(row_wise or not use_cache)
if not row_wise and use_cache:
return

E = int(10**log_E)
if use_cpu:
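
The test changes in this file (and in forward_test.py below) follow a single pattern: Hypothesis `assume(...)` filters are replaced with explicit early returns on the negated condition, with the original `assume` kept as a comment. A minimal sketch of the transformation, using made-up parameters:

```python
from hypothesis import given, strategies as st


@given(use_cpu=st.booleans(), use_cache=st.booleans())
def test_sketch(use_cpu: bool, use_cache: bool) -> None:
    # Before: assume(not use_cpu or not use_cache) -- Hypothesis discards the
    # generated example and draws another one.
    # After: return early, so the invalid combination simply passes.
    if use_cpu and use_cache:  # cache is not applicable to the CPU version
        return
    ...  # actual test body
```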
@@ -190,7 +190,7 @@ def apply_gwd(
apply_gwd_per_table(
prev_iter_values,
weights_values,
emb.optimizer_args.learning_rate,
emb.optimizer_args.learning_rate_tensor.item(),
emb.optimizer_args.weight_decay,
step,
emb.current_device,
51 changes: 36 additions & 15 deletions fbgemm_gpu/test/tbe/training/forward_test.py
@@ -32,7 +32,7 @@
round_up,
to_device,
)
from hypothesis import assume, given, HealthCheck, settings, Verbosity
from hypothesis import given, HealthCheck, settings, Verbosity

from .. import common # noqa E402
from ..common import (
@@ -76,6 +76,9 @@
"test_faketensor__test_forward_gpu_uvm_cache_int8": [
unittest.skip("Operator not implemented for Meta tensors"),
],
"test_faketensor__test_forward_cpu_fp32": [
unittest.skip("Operator not implemented for Meta tensors"),
],
# TODO: Make it compatible with opcheck tests
"test_faketensor__test_forward_gpu_uvm_cache_fp16": [
unittest.skip(
@@ -129,26 +132,44 @@ def execute_forward_( # noqa C901
use_experimental_tbe: bool,
) -> None:
# NOTE: cache is not applicable to CPU version.
assume(not use_cpu or not use_cache)
# assume(not use_cpu or not use_cache)
if use_cpu and use_cache:
return
# NOTE: limit (T * B * L * D) to avoid timeout for CPU version!
assume(not use_cpu or T * B * L * D <= 2048)
# assume(not use_cpu or T * B * L * D <= 2048)
if use_cpu and T * B * L * D > 2048:
return
# NOTE: CPU does not support FP16.
assume(not (use_cpu and weights_precision == SparseType.FP16))
# assume(not (use_cpu and weights_precision == SparseType.FP16))
if use_cpu and weights_precision == SparseType.FP16:
return

# NOTE: weighted operation can be done only for SUM.
assume(pooling_mode == PoolingMode.SUM or not weighted)
# assume(pooling_mode == PoolingMode.SUM or not weighted)
if pooling_mode != PoolingMode.SUM and weighted:
return
# NOTE: No bag ops only work on GPUs, no mixed
assume(not use_cpu or pooling_mode != PoolingMode.NONE)
assume(not mixed or pooling_mode != PoolingMode.NONE)
# assume(not use_cpu or pooling_mode != PoolingMode.NONE)
if use_cpu and pooling_mode == PoolingMode.NONE:
return
# assume(not mixed or pooling_mode != PoolingMode.NONE)
if mixed and pooling_mode == PoolingMode.NONE:
return
# TODO: Support these cases
assume(
not mixed_B
or (
weights_precision != SparseType.INT8
and output_dtype != SparseType.INT8
and pooling_mode != PoolingMode.NONE
)
)
# assume(
# not mixed_B
# or (
# weights_precision != SparseType.INT8
# and output_dtype != SparseType.INT8
# and pooling_mode != PoolingMode.NONE
# )
# )
if mixed_B and (
weights_precision == SparseType.INT8
or output_dtype == SparseType.INT8
or pooling_mode == PoolingMode.NONE
):
return

emb_op = SplitTableBatchedEmbeddingBagsCodegen
if pooling_mode == PoolingMode.SUM:
2 changes: 1 addition & 1 deletion fbgemm_gpu/test/tbe/utils/split_embeddings_test.py
@@ -594,7 +594,7 @@ def test_update_hyper_parameters(self) -> None:
} | {"lr": 1.0, "lower_bound": 2.0}
cc.update_hyper_parameters(updated_parameters)
self.assertAlmostEqual(
cc.optimizer_args.learning_rate, updated_parameters["lr"]
cc.optimizer_args.learning_rate_tensor.item(), updated_parameters["lr"]
)
self.assertAlmostEqual(cc.optimizer_args.eps, updated_parameters["eps"])
self.assertAlmostEqual(cc.optimizer_args.beta1, updated_parameters["beta1"])
