Add VBE KJT support to EmbeddingCollection (#2047)
Summary:

- Pads VBE KJT lengths to the final batch size so the KJT is compatible with the EC kernel (see the sketches after this list).
- Works with index dedup.
- Expands embeddings with VBE inverse indices.
- Long-term solution is to fix sequence TBE so it no longer needs lengths/batch-size info, just length per key.
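
A minimal sketch of the length-padding step on plain tensors (the keys, strides, and lengths below are invented for illustration; only the copy loop mirrors the new `pad_vbe_kjt_lengths` helper in this diff):

```python
import torch

# Toy VBE input: "f1" has a batch of 2, "f2" a batch of 3 (values are made up).
stride_per_key = [2, 3]
# Per-sample lengths, concatenated key by key (2 + 3 = 5 entries).
lengths = torch.tensor([1, 2, 3, 0, 1])
# The EC kernel expects every key to span the same final batch size, say 4.
final_stride = 4

padded = torch.zeros(final_stride * len(stride_per_key), dtype=lengths.dtype)
cum = 0
for i, stride in enumerate(stride_per_key):
    # Copy each key's lengths into its fixed-size slot; the tail stays zero.
    padded[i * final_stride : i * final_stride + stride] = lengths[cum : cum + stride]
    cum += stride

print(padded)  # tensor([1, 2, 0, 0, 3, 0, 1, 0])
```

On the way out, the new `_compute_sequence_vbe_context` uses the KJT's inverse indices to build a `recat` index that re-expands per-key data to the final batch. Roughly, for the same toy input (inverse index values again invented):

```python
import torch

# inverse_indices maps each slot of the final batch (size 4) back to a row of
# that key's original variable batch.
inverse_indices = torch.tensor([[0, 1, 0, 1],   # "f1" had 2 real samples
                                [0, 1, 2, 0]])  # "f2" had 3 real samples
stride_per_key = torch.tensor([2, 3])
lengths = torch.tensor([1, 2, 3, 0, 1])         # unpadded lengths, keys concatenated

# Offset each key's indices into the concatenated lengths (what _to_offsets
# provides in the real code), then flatten into a single gather index.
offsets = torch.cumsum(stride_per_key, dim=0) - stride_per_key  # tensor([0, 2])
recat = (inverse_indices + offsets.unsqueeze(-1)).flatten()

print(torch.index_select(lengths, 0, recat).view(2, 4))
# tensor([[1, 2, 1, 2],
#         [3, 0, 1, 3]])
```

The resulting `recat` and re-indexed lengths travel to `construct_jagged_tensors` via `SequenceVBEContext`, which is how the looked-up embeddings are expanded back per feature.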

Differential Revision: D51600051
joshuadeng authored and facebook-github-bot committed May 28, 2024
1 parent 3ca8e8b commit cb47e29
Showing 4 changed files with 226 additions and 30 deletions.
147 changes: 133 additions & 14 deletions torchrec/distributed/embedding.py
@@ -12,7 +12,7 @@
import logging
import warnings
from collections import defaultdict, deque, OrderedDict
from dataclasses import dataclass, field
from dataclasses import dataclass
from itertools import accumulate
from typing import Any, cast, Dict, List, MutableMapping, Optional, Tuple, Type, Union

@@ -72,10 +72,15 @@
EmbeddingCollection,
EmbeddingCollectionInterface,
)
from torchrec.modules.utils import construct_jagged_tensors
from torchrec.modules.utils import construct_jagged_tensors, SequenceVBEContext
from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
from torchrec.sparse.jagged_tensor import (
_pin_and_move,
_to_offsets,
JaggedTensor,
KeyedJaggedTensor,
)

try:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -323,6 +328,33 @@ def create_sharding_infos_by_sharding_device_group(
return sharding_type_device_group_to_sharding_infos


def pad_vbe_kjt_lengths(features: KeyedJaggedTensor) -> KeyedJaggedTensor:
final_stride = features.inverse_indices()[1].numel() // len(
features.inverse_indices()[0]
)
new_lengths = torch.zeros(
final_stride * len(features.keys()),
device=features.device(),
dtype=features.lengths().dtype,
)
cum_stride = 0
for i, stride in enumerate(features.stride_per_key()):
new_lengths[i * final_stride : i * final_stride + stride] = features.lengths()[
cum_stride : cum_stride + stride
]
cum_stride += stride

return KeyedJaggedTensor(
keys=features.keys(),
values=features.values(),
lengths=new_lengths,
stride=final_stride,
length_per_key=features.length_per_key(),
offset_per_key=features.offset_per_key(),
)


@dataclass
class EmbeddingCollectionContext(Multistreamable):
# Torch Dynamo does not support default_factory=list:
# https://github.com/pytorch/pytorch/issues/120108
@@ -333,11 +365,13 @@ def __init__(
sharding_contexts: Optional[List[SequenceShardingContext]] = None,
input_features: Optional[List[KeyedJaggedTensor]] = None,
reverse_indices: Optional[List[torch.Tensor]] = None,
seq_vbe_ctx: Optional[List[SequenceVBEContext]] = None,
) -> None:
super().__init__()
self.sharding_contexts: List[SequenceShardingContext] = sharding_contexts or []
self.input_features: List[KeyedJaggedTensor] = input_features or []
self.reverse_indices: List[torch.Tensor] = reverse_indices or []
self.seq_vbe_ctx: List[SequenceVBEContext] = seq_vbe_ctx or []

def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
for ctx in self.sharding_contexts:
@@ -346,6 +380,8 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
f.record_stream(stream)
for r in self.reverse_indices:
r.record_stream(stream)
for s in self.seq_vbe_ctx:
s.record_stream(stream)


class EmbeddingCollectionAwaitable(LazyAwaitable[Dict[str, JaggedTensor]]):
@@ -385,6 +421,9 @@ def _wait_impl(self) -> Dict[str, JaggedTensor]:
if i >= len(self._ctx.reverse_indices)
else self._ctx.reverse_indices[i]
)
seq_vbe_ctx = (
None if i >= len(self._ctx.seq_vbe_ctx) else self._ctx.seq_vbe_ctx[i]
)
jt_dict.update(
construct_jagged_tensors(
embeddings=w.wait(),
@@ -394,6 +433,7 @@ def _wait_impl(self) -> Dict[str, JaggedTensor]:
features_to_permute_indices=self._features_to_permute_indices,
original_features=original_features,
reverse_indices=reverse_indices,
seq_vbe_ctx=seq_vbe_ctx,
)
)
return jt_dict
@@ -506,6 +546,7 @@ def __init__(
module.embedding_configs(), table_name_to_parameter_sharding
)
self._need_indices: bool = module.need_indices()
self._inverse_indices_permute_per_sharding: Optional[List[torch.Tensor]] = None

for index, (sharding, lookup) in enumerate(
zip(
@@ -847,8 +888,8 @@ def _create_output_dist(

def _dedup_indices(
self,
input_feature_splits: List[KeyedJaggedTensor],
ctx: EmbeddingCollectionContext,
input_feature_splits: List[KeyedJaggedTensor],
) -> List[KeyedJaggedTensor]:
if not self._use_index_dedup:
return input_feature_splits
@@ -874,37 +915,113 @@ def _dedup_indices(
offsets=offsets,
values=unique_indices,
)

ctx.input_features.append(input_feature)
ctx.reverse_indices.append(reverse_indices)
features_by_shards.append(dedup_features)

return features_by_shards

def _create_inverse_indices_permute_per_sharding(
self, inverse_indices: Tuple[List[str], torch.Tensor]
) -> None:
if (
len(self._embedding_names_per_sharding) == 1
and self._embedding_names_per_sharding[0] == inverse_indices[0]
):
return
index_per_name = {name: i for i, name in enumerate(inverse_indices[0])}
permute_per_sharding = []
for emb_names in self._embedding_names_per_sharding:
permute = _pin_and_move(
torch.tensor(
[index_per_name[name.split("@")[0]] for name in emb_names]
),
inverse_indices[1].device,
)
permute_per_sharding.append(permute)
self._inverse_indices_permute_per_sharding = permute_per_sharding

def _compute_sequence_vbe_context(
self,
ctx: EmbeddingCollectionContext,
unpadded_features: KeyedJaggedTensor,
) -> None:
assert (
unpadded_features.inverse_indices_or_none() is not None
), "inverse indices must be provided from KJT if using variable batch size per feature."

inverse_indices = unpadded_features.inverse_indices()
stride = inverse_indices[1].numel() // len(inverse_indices[0])
if self._inverse_indices_permute_per_sharding is None:
self._create_inverse_indices_permute_per_sharding(inverse_indices)

if self._features_order:
features = unpadded_features.permute(
self._features_order,
self._features_order_tensor,
)

features_by_sharding = features.split(self._feature_splits)
for i, feature in enumerate(features_by_sharding):
if self._inverse_indices_permute_per_sharding is not None:
permute = self._inverse_indices_permute_per_sharding[i]
permuted_indices = torch.index_select(inverse_indices[1], 0, permute)
else:
permuted_indices = inverse_indices[1]
stride_per_key = _pin_and_move(
torch.tensor(feature.stride_per_key()), feature.device()
)
offsets = _to_offsets(stride_per_key)[:-1].unsqueeze(-1)
recat = (permuted_indices + offsets).flatten().int()

if self._need_indices:
reindexed_lengths, reindexed_values, _ = (
torch.ops.fbgemm.permute_1D_sparse_data(
recat,
feature.lengths(),
feature.values(),
)
)
else:
reindexed_lengths = torch.index_select(feature.lengths(), 0, recat)
reindexed_values = None

reindexed_lengths = reindexed_lengths.view(-1, stride)
reindexed_length_per_key = torch.sum(reindexed_lengths, dim=1).tolist()

ctx.seq_vbe_ctx.append(
SequenceVBEContext(
recat=recat,
unpadded_lengths=feature.lengths(),
reindexed_lengths=reindexed_lengths,
reindexed_length_per_key=reindexed_length_per_key,
reindexed_values=reindexed_values,
)
)

# pyre-ignore [14]
def input_dist(
self,
ctx: EmbeddingCollectionContext,
features: KeyedJaggedTensor,
) -> Awaitable[Awaitable[KJTList]]:
if features.variable_stride_per_key():
raise ValueError(
"Variable batch per feature is not supported with EmbeddingCollection"
)
if self._has_uninitialized_input_dist:
self._create_input_dist(input_feature_names=features.keys())
self._has_uninitialized_input_dist = False
with torch.no_grad():
unpadded_features = None
if features.variable_stride_per_key():
unpadded_features = features
features = pad_vbe_kjt_lengths(unpadded_features)

if self._features_order:
features = features.permute(
self._features_order,
self._features_order_tensor,
)

input_feature_splits = features.split(
self._feature_splits,
)
features_by_shards = self._dedup_indices(input_feature_splits, ctx)
features_by_shards = features.split(self._feature_splits)
if self._use_index_dedup:
features_by_shards = self._dedup_indices(ctx, features_by_shards)

awaitables = []
for input_dist, features in zip(self._input_dists, features_by_shards):
@@ -919,6 +1036,8 @@ def input_dist(
),
)
)
if unpadded_features is not None:
self._compute_sequence_vbe_context(ctx, unpadded_features)
return KJTListSplitsAwaitable(awaitables, ctx)

def compute(
2 changes: 2 additions & 0 deletions torchrec/distributed/tests/test_sequence_model.py
@@ -305,6 +305,7 @@ def __init__(
kernel_type: str,
qcomms_config: Optional[QCommsConfig] = None,
fused_params: Optional[Dict[str, Any]] = None,
use_index_dedup: bool = False,
) -> None:
self._sharding_type = sharding_type
self._kernel_type = kernel_type
@@ -321,6 +322,7 @@ def __init__(
super().__init__(
fused_params=fused_params,
qcomm_codecs_registry=qcomm_codecs_registry,
use_index_dedup=use_index_dedup,
)

"""
25 changes: 19 additions & 6 deletions torchrec/distributed/tests/test_sequence_model_parallel.py
@@ -62,8 +62,9 @@ class SequenceModelParallelTest(MultiProcessTestBase):
},
]
),
variable_batch_size=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
@settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
def test_sharding_nccl_rw(
self,
sharding_type: str,
@@ -72,6 +73,7 @@ def test_sharding_nccl_rw(
apply_optimizer_in_backward_config: Optional[
Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
],
variable_batch_size: bool,
) -> None:
assume(
apply_optimizer_in_backward_config is None
@@ -88,6 +90,7 @@ def test_sharding_nccl_rw(
backend="nccl",
qcomms_config=qcomms_config,
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
variable_batch_size=variable_batch_size,
)

@unittest.skipIf(
@@ -152,8 +155,9 @@ def test_sharding_nccl_dp(
},
]
),
variable_batch_size=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
@settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
def test_sharding_nccl_tw(
self,
sharding_type: str,
@@ -162,6 +166,7 @@ def test_sharding_nccl_tw(
apply_optimizer_in_backward_config: Optional[
Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
],
variable_batch_size: bool,
) -> None:
assume(
apply_optimizer_in_backward_config is None
@@ -178,7 +183,7 @@
backend="nccl",
qcomms_config=qcomms_config,
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
variable_batch_size=False,
variable_batch_size=variable_batch_size,
)

@unittest.skipIf(
@@ -203,15 +208,17 @@ def test_sharding_nccl_tw(
},
]
),
variable_batch_size=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
@settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
def test_sharding_nccl_cw(
self,
sharding_type: str,
kernel_type: str,
apply_optimizer_in_backward_config: Optional[
Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
],
variable_batch_size: bool,
) -> None:
assume(
apply_optimizer_in_backward_config is None
@@ -230,7 +237,7 @@
for table in self.tables
},
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
variable_batch_size=False,
variable_batch_size=variable_batch_size,
)

@unittest.skipIf(
@@ -246,25 +253,28 @@ def test_sharding_nccl_cw(
ShardingType.ROW_WISE.value,
]
),
index_dedup=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
def test_sharding_variable_batch(
self,
sharding_type: str,
index_dedup: bool,
) -> None:
self._test_sharding(
sharders=[
TestEmbeddingCollectionSharder(
sharding_type=sharding_type,
kernel_type=EmbeddingComputeKernel.FUSED.value,
use_index_dedup=index_dedup,
)
],
backend="nccl",
constraints={
table.name: ParameterConstraints(min_partition=4)
for table in self.tables
},
variable_batch_size=True,
variable_batch_per_feature=True,
)

# pyre-fixme[56]
@@ -347,6 +357,7 @@ def _test_sharding(
Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
] = None,
variable_batch_size: bool = False,
variable_batch_per_feature: bool = False,
) -> None:
self._run_multi_process_test(
callable=sharding_single_rank_test,
@@ -362,4 +373,6 @@ def _test_sharding(
qcomms_config=qcomms_config,
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
variable_batch_size=variable_batch_size,
variable_batch_per_feature=variable_batch_per_feature,
global_constant_batch=True,
)