Skip to content

Commit

Permalink
Handle KJT with zero batch size for Column-Wise sharded EmbeddingBagC…
Browse files Browse the repository at this point in the history
…ollection

Summary:
Support new use case where some ranks have no embedding ids to look up i.e. `kjt.values() == torch.tensor([])`. In such cases, the expectation is for the returned embedding to be of shape `[0,emb_dim]`, as 0 is a valid tensor dimension. 

This diff adds support for this use case, now only for Column-Wise (CW) sharding + EmbeddingBagCollection use case.

Changes in this diff:

1) CW sharding and VLE uses FBGEMM kernel to permute pooled embeddings, which doesn't work for 0-dim tensors. If tensor has no elements (i.e. 0-dim), permute doesn't do anything so we can return early. We could also support this via an `if` statement in TorchRec codebase, but we run into FX tracing issues in this case.
2) In comm_ops.py, `[output.view(B_local, -1) for output in outputs_by_rank]` isn't supported if `output` tensor has 0 dim as it will error out with `RuntimeError: cannot reshape tensor of 0 elements into shape [0, -1] because the unspecified dimension size -1 can be any value and is ambiguous`. Instead, we can explicitly create a view with the bsz and emb_dim dimensions which will hold even if `output` is 0-dim (outputs will have shape `[0,emb_dim_for_rank]`
3) Added new unit test cases

Reviewed By: dstaay-fb

Differential Revision: D69156551
  • Loading branch information
sarckk authored and facebook-github-bot committed Feb 18, 2025
1 parent ea1cc27 commit 1cb8152
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 27 deletions.
9 changes: 8 additions & 1 deletion torchrec/distributed/comm_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1290,6 +1290,7 @@ def forward(
outputs_by_rank = sharded_output_embeddings.split(
[B_local * D_rank_sum for D_rank_sum in padded_dim_sum_per_rank]
)
final_dim_sum_per_rank = padded_dim_sum_per_rank
if (
myreq.qcomm_ctx is not None
and myreq.qcomm_ctx.padded_dim_sum_per_rank is not None
Expand All @@ -1298,8 +1299,14 @@ def forward(
output.view(B_local, -1)[:, :dim_sum]
for output, dim_sum in zip(outputs_by_rank, dim_sum_per_rank)
]
final_dim_sum_per_rank = dim_sum_per_rank

result = torch.cat(
[output.view(B_local, -1) for output in outputs_by_rank], dim=1
[
output.view(B_local, dim)
for output, dim in zip(outputs_by_rank, final_dim_sum_per_rank)
],
dim=1,
)
return result

Expand Down
7 changes: 6 additions & 1 deletion torchrec/distributed/test_utils/test_model_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def _test_sharding(
pooling: PoolingType = PoolingType.SUM,
data_type: DataType = DataType.FP32,
use_inter_host_allreduce: bool = False,
allow_zero_batch_size: bool = False,
) -> None:
self._build_tables_and_groups(data_type=data_type)
self._run_multi_process_test(
Expand All @@ -172,6 +173,7 @@ def _test_sharding(
variable_batch_per_feature=variable_batch_per_feature,
global_constant_batch=global_constant_batch,
use_inter_host_allreduce=use_inter_host_allreduce,
allow_zero_batch_size=allow_zero_batch_size,
)


Expand Down Expand Up @@ -333,8 +335,9 @@ def test_sharding_dp(
),
variable_batch_size=st.booleans(),
data_type=st.sampled_from([DataType.FP32, DataType.FP16]),
allow_zero_batch_size=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
@settings(verbosity=Verbosity.verbose, max_examples=6, deadline=None)
def test_sharding_cw(
self,
sharder_type: str,
Expand All @@ -345,6 +348,7 @@ def test_sharding_cw(
],
variable_batch_size: bool,
data_type: DataType,
allow_zero_batch_size: bool,
) -> None:
if (
self.device == torch.device("cpu")
Expand Down Expand Up @@ -377,6 +381,7 @@ def test_sharding_cw(
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
variable_batch_size=variable_batch_size,
data_type=data_type,
allow_zero_batch_size=allow_zero_batch_size,
)

# pyre-fixme[56]
Expand Down
57 changes: 32 additions & 25 deletions torchrec/distributed/test_utils/test_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

# pyre-strict

import random
from enum import Enum
from typing import Any, cast, Dict, List, Optional, Protocol, Tuple, Type, Union

Expand Down Expand Up @@ -317,8 +318,12 @@ def sharding_single_rank_test(
node_group_size: Optional[int] = None,
use_inter_host_allreduce: bool = False,
input_type: str = "kjt", # "kjt" or "td"
allow_zero_batch_size: bool = False,
) -> None:
with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
batch_size = (
random.randint(0, batch_size) if allow_zero_batch_size else batch_size
)
# Generate model & inputs.
(global_model, inputs) = gen_model_and_input(
model_class=model_class,
Expand Down Expand Up @@ -464,35 +469,37 @@ def sharding_single_rank_test(
local_input,
)

all_local_pred = []
for _ in range(world_size):
all_local_pred.append(torch.empty_like(local_pred))
dist.all_gather(all_local_pred, local_pred, group=ctx.pg)

# Run second training step of the unsharded model.
assert optim == EmbOptimType.EXACT_SGD
global_opt = torch.optim.SGD(global_model.parameters(), lr=0.1)
# TODO: support non-sharded forward with zero batch size KJT
if not allow_zero_batch_size:
all_local_pred = []
for _ in range(world_size):
all_local_pred.append(torch.empty_like(local_pred))
dist.all_gather(all_local_pred, local_pred, group=ctx.pg)

global_pred = gen_full_pred_after_one_step(
global_model, global_opt, global_input
)
# Run second training step of the unsharded model.
assert optim == EmbOptimType.EXACT_SGD
global_opt = torch.optim.SGD(global_model.parameters(), lr=0.1)

# Compare predictions of sharded vs unsharded models.
if qcomms_config is None:
torch.testing.assert_close(global_pred, torch.cat(all_local_pred))
else:
# With quantized comms, we can relax constraints a bit
rtol = 0.003
if CommType.FP8 in [
qcomms_config.forward_precision,
qcomms_config.backward_precision,
]:
rtol = 0.05
atol = global_pred.max().item() * rtol
torch.testing.assert_close(
global_pred, torch.cat(all_local_pred), rtol=rtol, atol=atol
global_pred = gen_full_pred_after_one_step(
global_model, global_opt, global_input
)

# Compare predictions of sharded vs unsharded models.
if qcomms_config is None:
torch.testing.assert_close(global_pred, torch.cat(all_local_pred))
else:
# With quantized comms, we can relax constraints a bit
rtol = 0.003
if CommType.FP8 in [
qcomms_config.forward_precision,
qcomms_config.backward_precision,
]:
rtol = 0.05
atol = global_pred.max().item() * rtol
torch.testing.assert_close(
global_pred, torch.cat(all_local_pred), rtol=rtol, atol=atol
)


def gen_full_pred_after_one_step(
model: nn.Module,
Expand Down

0 comments on commit 1cb8152

Please sign in to comment.