Add VBE KJT support to EmbeddingCollection #2047

Closed · wants to merge 1 commit
144 changes: 129 additions & 15 deletions torchrec/distributed/embedding.py
@@ -12,7 +12,6 @@
import logging
import warnings
from collections import defaultdict, deque, OrderedDict
from dataclasses import dataclass, field
from itertools import accumulate
from typing import Any, cast, Dict, List, MutableMapping, Optional, Tuple, Type, Union

@@ -72,10 +71,15 @@
EmbeddingCollection,
EmbeddingCollectionInterface,
)
from torchrec.modules.utils import construct_jagged_tensors
from torchrec.modules.utils import construct_jagged_tensors, SequenceVBEContext
from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
from torchrec.sparse.jagged_tensor import (
_pin_and_move,
_to_offsets,
JaggedTensor,
KeyedJaggedTensor,
)

try:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -323,6 +327,30 @@ def create_sharding_infos_by_sharding_device_group(
return sharding_type_device_group_to_sharding_infos


def pad_vbe_kjt_lengths(features: KeyedJaggedTensor) -> KeyedJaggedTensor:
max_stride = max(features.stride_per_key())
new_lengths = torch.zeros(
max_stride * len(features.keys()),
device=features.device(),
dtype=features.lengths().dtype,
)
cum_stride = 0
for i, stride in enumerate(features.stride_per_key()):
new_lengths[i * max_stride : i * max_stride + stride] = features.lengths()[
cum_stride : cum_stride + stride
]
cum_stride += stride

return KeyedJaggedTensor(
keys=features.keys(),
values=features.values(),
lengths=new_lengths,
stride=max_stride,
length_per_key=features.length_per_key(),
offset_per_key=features.offset_per_key(),
)
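
# Illustration (not part of this diff): a minimal sketch of what the padding
# does, assuming a VBE KJT built with stride_per_key_per_rank; constructor
# details may differ across torchrec versions and the feature names are made up.
vbe_kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.arange(7),
    lengths=torch.tensor([1, 2, 3, 0, 1]),  # f1 -> [1, 2], f2 -> [3, 0, 1]
    stride_per_key_per_rank=[[2], [3]],
)
padded = pad_vbe_kjt_lengths(vbe_kjt)
# padded.stride() == 3; padded.lengths() == [1, 2, 0, 3, 0, 1]:
# each key's lengths are zero-padded up to the largest per-key stride, while
# values, length_per_key, and offset_per_key are carried over unchanged.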


class EmbeddingCollectionContext(Multistreamable):
# Torch Dynamo does not support default_factory=list:
# https://github.com/pytorch/pytorch/issues/120108
@@ -333,11 +361,13 @@ def __init__(
sharding_contexts: Optional[List[SequenceShardingContext]] = None,
input_features: Optional[List[KeyedJaggedTensor]] = None,
reverse_indices: Optional[List[torch.Tensor]] = None,
seq_vbe_ctx: Optional[List[SequenceVBEContext]] = None,
) -> None:
super().__init__()
self.sharding_contexts: List[SequenceShardingContext] = sharding_contexts or []
self.input_features: List[KeyedJaggedTensor] = input_features or []
self.reverse_indices: List[torch.Tensor] = reverse_indices or []
self.seq_vbe_ctx: List[SequenceVBEContext] = seq_vbe_ctx or []

def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
for ctx in self.sharding_contexts:
@@ -346,6 +376,8 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
f.record_stream(stream)
for r in self.reverse_indices:
r.record_stream(stream)
for s in self.seq_vbe_ctx:
s.record_stream(stream)


class EmbeddingCollectionAwaitable(LazyAwaitable[Dict[str, JaggedTensor]]):
@@ -385,6 +417,9 @@ def _wait_impl(self) -> Dict[str, JaggedTensor]:
if i >= len(self._ctx.reverse_indices)
else self._ctx.reverse_indices[i]
)
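# seq_vbe_ctx is only populated for VBE inputs; otherwise fall back to None
# so construct_jagged_tensors keeps its existing non-VBE behavior.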
seq_vbe_ctx = (
None if i >= len(self._ctx.seq_vbe_ctx) else self._ctx.seq_vbe_ctx[i]
)
jt_dict.update(
construct_jagged_tensors(
embeddings=w.wait(),
@@ -394,6 +429,7 @@ def _wait_impl(self) -> Dict[str, JaggedTensor]:
features_to_permute_indices=self._features_to_permute_indices,
original_features=original_features,
reverse_indices=reverse_indices,
seq_vbe_ctx=seq_vbe_ctx,
)
)
return jt_dict
@@ -506,6 +542,7 @@ def __init__(
module.embedding_configs(), table_name_to_parameter_sharding
)
self._need_indices: bool = module.need_indices()
self._inverse_indices_permute_per_sharding: Optional[List[torch.Tensor]] = None

for index, (sharding, lookup) in enumerate(
zip(
@@ -847,11 +884,9 @@ def _create_output_dist(

def _dedup_indices(
self,
input_feature_splits: List[KeyedJaggedTensor],
ctx: EmbeddingCollectionContext,
input_feature_splits: List[KeyedJaggedTensor],
) -> List[KeyedJaggedTensor]:
if not self._use_index_dedup:
return input_feature_splits
with record_function("## dedup_ec_indices ##"):
features_by_shards = []
for i, input_feature in enumerate(input_feature_splits):
@@ -881,30 +916,107 @@

return features_by_shards

def _create_inverse_indices_permute_per_sharding(
self, inverse_indices: Tuple[List[str], torch.Tensor]
) -> None:
if (
len(self._embedding_names_per_sharding) == 1
and self._embedding_names_per_sharding[0] == inverse_indices[0]
):
return
index_per_name = {name: i for i, name in enumerate(inverse_indices[0])}
permute_per_sharding = []
for emb_names in self._embedding_names_per_sharding:
permute = _pin_and_move(
torch.tensor(
[index_per_name[name.split("@")[0]] for name in emb_names]
),
inverse_indices[1].device,
)
permute_per_sharding.append(permute)
self._inverse_indices_permute_per_sharding = permute_per_sharding
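
# Illustration (not part of this diff), with made-up feature names: if
# inverse_indices[0] == ["f1", "f2", "f3"] and the per-sharding embedding
# names are [["f2"], ["f1", "f3@shard_0"]], then
# index_per_name == {"f1": 0, "f2": 1, "f3": 2} and the resulting permutes
# are [tensor([1]), tensor([0, 2])] -- one row selector into
# inverse_indices[1] per sharding, with any "@..." suffix stripped first.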

def _compute_sequence_vbe_context(
self,
ctx: EmbeddingCollectionContext,
unpadded_features: KeyedJaggedTensor,
) -> None:
assert (
unpadded_features.inverse_indices_or_none() is not None
), "inverse indices must be provided from KJT if using variable batch size per feature."

inverse_indices = unpadded_features.inverse_indices()
stride = inverse_indices[1].numel() // len(inverse_indices[0])
if self._inverse_indices_permute_per_sharding is None:
self._create_inverse_indices_permute_per_sharding(inverse_indices)

if self._features_order:
unpadded_features = unpadded_features.permute(
self._features_order,
self._features_order_tensor,
)

features_by_sharding = unpadded_features.split(self._feature_splits)
for i, feature in enumerate(features_by_sharding):
if self._inverse_indices_permute_per_sharding is not None:
permute = self._inverse_indices_permute_per_sharding[i]
permuted_indices = torch.index_select(inverse_indices[1], 0, permute)
else:
permuted_indices = inverse_indices[1]
stride_per_key = _pin_and_move(
torch.tensor(feature.stride_per_key()), feature.device()
)
offsets = _to_offsets(stride_per_key)[:-1].unsqueeze(-1)
recat = (permuted_indices + offsets).flatten().int()

if self._need_indices:
reindexed_lengths, reindexed_values, _ = (
torch.ops.fbgemm.permute_1D_sparse_data(
recat,
feature.lengths(),
feature.values(),
)
)
else:
reindexed_lengths = torch.index_select(feature.lengths(), 0, recat)
reindexed_values = None

reindexed_lengths = reindexed_lengths.view(-1, stride)
reindexed_length_per_key = torch.sum(reindexed_lengths, dim=1).tolist()

ctx.seq_vbe_ctx.append(
SequenceVBEContext(
recat=recat,
unpadded_lengths=feature.lengths(),
reindexed_lengths=reindexed_lengths,
reindexed_length_per_key=reindexed_length_per_key,
reindexed_values=reindexed_values,
)
)
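
# Worked illustration (not part of this diff, all numbers hypothetical),
# assuming a single sharding whose keys match inverse_indices[0] so no
# permute is applied:
#   inverse_indices = (["f1", "f2"], tensor([[0, 1, 0], [1, 0, 1]]))
#   => stride = 6 // 2 = 3
#   unpadded feature: stride_per_key() == [2, 2], lengths() == [1, 2, 3, 1]
#   offsets = _to_offsets([2, 2])[:-1].unsqueeze(-1) == [[0], [2]]
#   recat = ([[0, 1, 0], [1, 0, 1]] + [[0], [2]]).flatten() == [0, 1, 0, 3, 2, 3]
#   reindexed_lengths = lengths[recat].view(-1, 3) == [[1, 2, 1], [1, 3, 1]]
#   reindexed_length_per_key == [4, 5]
# i.e. recat expands each key's deduplicated VBE rows back to the full
# constant batch, which construct_jagged_tensors consumes via the
# SequenceVBEContext recorded on the context.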

# pyre-ignore [14]
def input_dist(
self,
ctx: EmbeddingCollectionContext,
features: KeyedJaggedTensor,
) -> Awaitable[Awaitable[KJTList]]:
if features.variable_stride_per_key():
raise ValueError(
"Variable batch per feature is not supported with EmbeddingCollection"
)
if self._has_uninitialized_input_dist:
self._create_input_dist(input_feature_names=features.keys())
self._has_uninitialized_input_dist = False
with torch.no_grad():
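# If the incoming KJT is variable-stride (VBE), keep the unpadded KJT around
# and pad its lengths to a constant stride so the rest of input_dist
# (permute, split, dedup, dist) can treat it like a fixed-batch KJT; the
# unpadded KJT is then used to build a SequenceVBEContext per sharding.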
unpadded_features = None
if features.variable_stride_per_key():
unpadded_features = features
features = pad_vbe_kjt_lengths(unpadded_features)

if self._features_order:
features = features.permute(
self._features_order,
self._features_order_tensor,
)

input_feature_splits = features.split(
self._feature_splits,
)
features_by_shards = self._dedup_indices(input_feature_splits, ctx)
features_by_shards = features.split(self._feature_splits)
if self._use_index_dedup:
features_by_shards = self._dedup_indices(ctx, features_by_shards)

awaitables = []
for input_dist, features in zip(self._input_dists, features_by_shards):
@@ -919,6 +1031,8 @@ def input_dist(
),
)
)
if unpadded_features is not None:
self._compute_sequence_vbe_context(ctx, unpadded_features)
return KJTListSplitsAwaitable(awaitables, ctx)

def compute(
2 changes: 2 additions & 0 deletions torchrec/distributed/tests/test_sequence_model.py
@@ -305,6 +305,7 @@ def __init__(
kernel_type: str,
qcomms_config: Optional[QCommsConfig] = None,
fused_params: Optional[Dict[str, Any]] = None,
use_index_dedup: bool = False,
) -> None:
self._sharding_type = sharding_type
self._kernel_type = kernel_type
@@ -321,6 +322,7 @@ def __init__(
super().__init__(
fused_params=fused_params,
qcomm_codecs_registry=qcomm_codecs_registry,
use_index_dedup=use_index_dedup,
)

"""
27 changes: 20 additions & 7 deletions torchrec/distributed/tests/test_sequence_model_parallel.py
@@ -62,8 +62,9 @@ class SequenceModelParallelTest(MultiProcessTestBase):
},
]
),
variable_batch_size=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
@settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
def test_sharding_nccl_rw(
self,
sharding_type: str,
@@ -72,6 +73,7 @@ def test_sharding_nccl_rw(
apply_optimizer_in_backward_config: Optional[
Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
],
variable_batch_size: bool,
) -> None:
assume(
apply_optimizer_in_backward_config is None
@@ -88,6 +90,7 @@
backend="nccl",
qcomms_config=qcomms_config,
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
variable_batch_size=variable_batch_size,
)

@unittest.skipIf(
@@ -152,8 +155,9 @@ def test_sharding_nccl_dp(
},
]
),
variable_batch_size=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
@settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
def test_sharding_nccl_tw(
self,
sharding_type: str,
@@ -162,6 +166,7 @@ def test_sharding_nccl_tw(
apply_optimizer_in_backward_config: Optional[
Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
],
variable_batch_size: bool,
) -> None:
assume(
apply_optimizer_in_backward_config is None
@@ -178,7 +183,7 @@
backend="nccl",
qcomms_config=qcomms_config,
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
variable_batch_size=False,
variable_batch_size=variable_batch_size,
)

@unittest.skipIf(
@@ -203,15 +208,17 @@ def test_sharding_nccl_tw(
},
]
),
variable_batch_size=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
@settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
def test_sharding_nccl_cw(
self,
sharding_type: str,
kernel_type: str,
apply_optimizer_in_backward_config: Optional[
Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
],
variable_batch_size: bool,
) -> None:
assume(
apply_optimizer_in_backward_config is None
Expand All @@ -230,7 +237,7 @@ def test_sharding_nccl_cw(
for table in self.tables
},
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
variable_batch_size=False,
variable_batch_size=variable_batch_size,
)

@unittest.skipIf(
@@ -246,25 +253,28 @@ def test_sharding_nccl_cw(
ShardingType.ROW_WISE.value,
]
),
index_dedup=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
@settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
def test_sharding_variable_batch(
self,
sharding_type: str,
index_dedup: bool,
) -> None:
self._test_sharding(
sharders=[
TestEmbeddingCollectionSharder(
sharding_type=sharding_type,
kernel_type=EmbeddingComputeKernel.FUSED.value,
use_index_dedup=index_dedup,
)
],
backend="nccl",
constraints={
table.name: ParameterConstraints(min_partition=4)
for table in self.tables
},
variable_batch_size=True,
variable_batch_per_feature=True,
)

# pyre-fixme[56]
@@ -347,6 +357,7 @@ def _test_sharding(
Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
] = None,
variable_batch_size: bool = False,
variable_batch_per_feature: bool = False,
) -> None:
self._run_multi_process_test(
callable=sharding_single_rank_test,
@@ -362,4 +373,6 @@
qcomms_config=qcomms_config,
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
variable_batch_size=variable_batch_size,
variable_batch_per_feature=variable_batch_per_feature,
global_constant_batch=True,
)