diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py
index 412d0811a..2b8e79eae 100644
--- a/torchrec/distributed/embedding.py
+++ b/torchrec/distributed/embedding.py
@@ -12,7 +12,7 @@
 import logging
 import warnings
 from collections import defaultdict, deque, OrderedDict
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from itertools import accumulate
 from typing import Any, cast, Dict, List, MutableMapping, Optional, Tuple, Type, Union
 
@@ -72,10 +72,15 @@
     EmbeddingCollection,
     EmbeddingCollectionInterface,
 )
-from torchrec.modules.utils import construct_jagged_tensors
+from torchrec.modules.utils import construct_jagged_tensors, SequenceVBEContext
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
-from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
+from torchrec.sparse.jagged_tensor import (
+    _pin_and_move,
+    _to_offsets,
+    JaggedTensor,
+    KeyedJaggedTensor,
+)
 
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -323,6 +328,33 @@ def create_sharding_infos_by_sharding_device_group(
     return sharding_type_device_group_to_sharding_infos
 
 
+def pad_vbe_kjt_lengths(features: KeyedJaggedTensor) -> KeyedJaggedTensor:
+    final_stride = features.inverse_indices()[1].numel() // len(
+        features.inverse_indices()[0]
+    )
+    new_lengths = torch.zeros(
+        final_stride * len(features.keys()),
+        device=features.device(),
+        dtype=features.lengths().dtype,
+    )
+    cum_stride = 0
+    for i, stride in enumerate(features.stride_per_key()):
+        new_lengths[i * final_stride : i * final_stride + stride] = features.lengths()[
+            cum_stride : cum_stride + stride
+        ]
+        cum_stride += stride
+
+    return KeyedJaggedTensor(
+        keys=features.keys(),
+        values=features.values(),
+        lengths=new_lengths,
+        stride=final_stride,
+        length_per_key=features.length_per_key(),
+        offset_per_key=features.offset_per_key(),
+    )
+
+
+@dataclass
 class EmbeddingCollectionContext(Multistreamable):
     # Torch Dynamo does not support default_factory=list:
     # https://github.com/pytorch/pytorch/issues/120108
@@ -333,11 +365,13 @@ def __init__(
         sharding_contexts: Optional[List[SequenceShardingContext]] = None,
         input_features: Optional[List[KeyedJaggedTensor]] = None,
         reverse_indices: Optional[List[torch.Tensor]] = None,
+        seq_vbe_ctx: Optional[List[SequenceVBEContext]] = None,
     ) -> None:
         super().__init__()
         self.sharding_contexts: List[SequenceShardingContext] = sharding_contexts or []
         self.input_features: List[KeyedJaggedTensor] = input_features or []
         self.reverse_indices: List[torch.Tensor] = reverse_indices or []
+        self.seq_vbe_ctx: List[SequenceVBEContext] = seq_vbe_ctx or []
 
     def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
         for ctx in self.sharding_contexts:
@@ -346,6 +380,8 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
             f.record_stream(stream)
         for r in self.reverse_indices:
             r.record_stream(stream)
+        for s in self.seq_vbe_ctx:
+            s.record_stream(stream)
 
 
 class EmbeddingCollectionAwaitable(LazyAwaitable[Dict[str, JaggedTensor]]):
@@ -385,6 +421,9 @@ def _wait_impl(self) -> Dict[str, JaggedTensor]:
                 if i >= len(self._ctx.reverse_indices)
                 else self._ctx.reverse_indices[i]
             )
+            seq_vbe_ctx = (
+                None if i >= len(self._ctx.seq_vbe_ctx) else self._ctx.seq_vbe_ctx[i]
+            )
             jt_dict.update(
                 construct_jagged_tensors(
                     embeddings=w.wait(),
@@ -394,6 +433,7 @@
                     features_to_permute_indices=self._features_to_permute_indices,
                     original_features=original_features,
                     reverse_indices=reverse_indices,
+                    seq_vbe_ctx=seq_vbe_ctx,
                 )
             )
         return jt_dict
@@ -506,6 +546,7 @@ def __init__(
             module.embedding_configs(), table_name_to_parameter_sharding
         )
         self._need_indices: bool = module.need_indices()
+        self._inverse_indices_permute_per_sharding: Optional[List[torch.Tensor]] = None
 
         for index, (sharding, lookup) in enumerate(
             zip(
@@ -847,8 +888,8 @@ def _create_output_dist(
 
     def _dedup_indices(
         self,
-        input_feature_splits: List[KeyedJaggedTensor],
         ctx: EmbeddingCollectionContext,
+        input_feature_splits: List[KeyedJaggedTensor],
     ) -> List[KeyedJaggedTensor]:
         if not self._use_index_dedup:
             return input_feature_splits
@@ -874,37 +915,113 @@ def _dedup_indices(
                 offsets=offsets,
                 values=unique_indices,
             )
-            ctx.input_features.append(input_feature)
             ctx.reverse_indices.append(reverse_indices)
             features_by_shards.append(dedup_features)
 
         return features_by_shards
 
+    def _create_inverse_indices_permute_per_sharding(
+        self, inverse_indices: Tuple[List[str], torch.Tensor]
+    ) -> None:
+        if (
+            len(self._embedding_names_per_sharding) == 1
+            and self._embedding_names_per_sharding[0] == inverse_indices[0]
+        ):
+            return
+        index_per_name = {name: i for i, name in enumerate(inverse_indices[0])}
+        permute_per_sharding = []
+        for emb_names in self._embedding_names_per_sharding:
+            permute = _pin_and_move(
+                torch.tensor(
+                    [index_per_name[name.split("@")[0]] for name in emb_names]
+                ),
+                inverse_indices[1].device,
+            )
+            permute_per_sharding.append(permute)
+        self._inverse_indices_permute_per_sharding = permute_per_sharding
+
+    def _compute_sequence_vbe_context(
+        self,
+        ctx: EmbeddingCollectionContext,
+        unpadded_features: KeyedJaggedTensor,
+    ) -> None:
+        assert (
+            unpadded_features.inverse_indices_or_none() is not None
+        ), "inverse indices must be provided from KJT if using variable batch size per feature."
+
+        inverse_indices = unpadded_features.inverse_indices()
+        stride = inverse_indices[1].numel() // len(inverse_indices[0])
+        if self._inverse_indices_permute_per_sharding is None:
+            self._create_inverse_indices_permute_per_sharding(inverse_indices)
+
+        if self._features_order:
+            unpadded_features = unpadded_features.permute(
+                self._features_order,
+                self._features_order_tensor,
+            )
+
+        features_by_sharding = unpadded_features.split(self._feature_splits)
+        for i, feature in enumerate(features_by_sharding):
+            if self._inverse_indices_permute_per_sharding is not None:
+                permute = self._inverse_indices_permute_per_sharding[i]
+                permuted_indices = torch.index_select(inverse_indices[1], 0, permute)
+            else:
+                permuted_indices = inverse_indices[1]
+            stride_per_key = _pin_and_move(
+                torch.tensor(feature.stride_per_key()), feature.device()
+            )
+            offsets = _to_offsets(stride_per_key)[:-1].unsqueeze(-1)
+            recat = (permuted_indices + offsets).flatten().int()
+
+            if self._need_indices:
+                reindexed_lengths, reindexed_values, _ = (
+                    torch.ops.fbgemm.permute_1D_sparse_data(
+                        recat,
+                        feature.lengths(),
+                        feature.values(),
+                    )
+                )
+            else:
+                reindexed_lengths = torch.index_select(feature.lengths(), 0, recat)
+                reindexed_values = None
+
+            reindexed_lengths = reindexed_lengths.view(-1, stride)
+            reindexed_length_per_key = torch.sum(reindexed_lengths, dim=1).tolist()
+
+            ctx.seq_vbe_ctx.append(
+                SequenceVBEContext(
+                    recat=recat,
+                    unpadded_lengths=feature.lengths(),
+                    reindexed_lengths=reindexed_lengths,
+                    reindexed_length_per_key=reindexed_length_per_key,
+                    reindexed_values=reindexed_values,
+                )
+            )
+
     # pyre-ignore [14]
     def input_dist(
         self,
         ctx: EmbeddingCollectionContext,
         features: KeyedJaggedTensor,
     ) -> Awaitable[Awaitable[KJTList]]:
-        if features.variable_stride_per_key():
-            raise ValueError(
-                "Variable batch per feature is not supported with EmbeddingCollection"
-            )
         if self._has_uninitialized_input_dist:
             self._create_input_dist(input_feature_names=features.keys())
             self._has_uninitialized_input_dist = False
         with torch.no_grad():
+            unpadded_features = None
+            if features.variable_stride_per_key():
+                unpadded_features = features
+                features = pad_vbe_kjt_lengths(unpadded_features)
+
             if self._features_order:
                 features = features.permute(
                     self._features_order,
                     self._features_order_tensor,
                 )
-
-            input_feature_splits = features.split(
-                self._feature_splits,
-            )
-            features_by_shards = self._dedup_indices(input_feature_splits, ctx)
+            features_by_shards = features.split(self._feature_splits)
+            if self._use_index_dedup:
+                features_by_shards = self._dedup_indices(ctx, features_by_shards)
 
             awaitables = []
             for input_dist, features in zip(self._input_dists, features_by_shards):
@@ -919,6 +1036,8 @@ def input_dist(
                         ),
                     )
                 )
+            if unpadded_features is not None:
+                self._compute_sequence_vbe_context(ctx, unpadded_features)
             return KJTListSplitsAwaitable(awaitables, ctx)
 
     def compute(
diff --git a/torchrec/distributed/tests/test_sequence_model.py b/torchrec/distributed/tests/test_sequence_model.py
index 9ce5784a2..f30239903 100644
--- a/torchrec/distributed/tests/test_sequence_model.py
+++ b/torchrec/distributed/tests/test_sequence_model.py
@@ -305,6 +305,7 @@ def __init__(
         kernel_type: str,
         qcomms_config: Optional[QCommsConfig] = None,
         fused_params: Optional[Dict[str, Any]] = None,
+        use_index_dedup: bool = False,
     ) -> None:
         self._sharding_type = sharding_type
         self._kernel_type = kernel_type
@@ -321,6 +322,7 @@ def __init__(
         super().__init__(
             fused_params=fused_params,
             qcomm_codecs_registry=qcomm_codecs_registry,
+            use_index_dedup=use_index_dedup,
         )
 
     """
diff --git a/torchrec/distributed/tests/test_sequence_model_parallel.py b/torchrec/distributed/tests/test_sequence_model_parallel.py
index 463bb7974..b2229537f 100644
--- a/torchrec/distributed/tests/test_sequence_model_parallel.py
+++ b/torchrec/distributed/tests/test_sequence_model_parallel.py
@@ -62,8 +62,9 @@ class SequenceModelParallelTest(MultiProcessTestBase):
                 },
             ]
         ),
+        variable_batch_size=st.booleans(),
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
+    @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
     def test_sharding_nccl_rw(
         self,
         sharding_type: str,
@@ -72,6 +73,7 @@ def test_sharding_nccl_rw(
         apply_optimizer_in_backward_config: Optional[
             Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
         ],
+        variable_batch_size: bool,
     ) -> None:
         assume(
             apply_optimizer_in_backward_config is None
@@ -88,6 +90,7 @@ def test_sharding_nccl_rw(
             backend="nccl",
            qcomms_config=qcomms_config,
             apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
+            variable_batch_size=variable_batch_size,
         )
 
     @unittest.skipIf(
@@ -152,8 +155,9 @@ def test_sharding_nccl_dp(
                 },
             ]
         ),
+        variable_batch_size=st.booleans(),
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
+    @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
     def test_sharding_nccl_tw(
         self,
         sharding_type: str,
@@ -162,6 +166,7 @@ def test_sharding_nccl_tw(
         apply_optimizer_in_backward_config: Optional[
             Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
         ],
+        variable_batch_size: bool,
     ) -> None:
         assume(
             apply_optimizer_in_backward_config is None
@@ -178,7 +183,7 @@ def test_sharding_nccl_tw(
             backend="nccl",
             qcomms_config=qcomms_config,
             apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
-            variable_batch_size=False,
+            variable_batch_size=variable_batch_size,
         )
 
     @unittest.skipIf(
@@ -203,8 +208,9 @@ def test_sharding_nccl_tw(
                 },
             ]
         ),
+        variable_batch_size=st.booleans(),
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
+    @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
     def test_sharding_nccl_cw(
         self,
         sharding_type: str,
@@ -212,6 +218,7 @@ def test_sharding_nccl_cw(
         apply_optimizer_in_backward_config: Optional[
             Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
         ],
+        variable_batch_size: bool,
     ) -> None:
         assume(
             apply_optimizer_in_backward_config is None
@@ -230,7 +237,7 @@ def test_sharding_nccl_cw(
                 for table in self.tables
             },
             apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
-            variable_batch_size=False,
+            variable_batch_size=variable_batch_size,
         )
 
     @unittest.skipIf(
@@ -246,17 +253,20 @@ def test_sharding_nccl_cw(
                 ShardingType.ROW_WISE.value,
             ]
         ),
+        index_dedup=st.booleans(),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
     def test_sharding_variable_batch(
         self,
         sharding_type: str,
+        index_dedup: bool,
     ) -> None:
         self._test_sharding(
             sharders=[
                 TestEmbeddingCollectionSharder(
                     sharding_type=sharding_type,
                     kernel_type=EmbeddingComputeKernel.FUSED.value,
+                    use_index_dedup=index_dedup,
                 )
             ],
             backend="nccl",
@@ -264,7 +274,7 @@ def test_sharding_variable_batch(
                 table.name: ParameterConstraints(min_partition=4)
                 for table in self.tables
             },
-            variable_batch_size=True,
+            variable_batch_per_feature=True,
         )
 
     # pyre-fixme[56]
@@ -347,6 +357,7 @@ def _test_sharding(
                 Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
             ] = None,
         variable_batch_size: bool = False,
+        variable_batch_per_feature: bool = False,
     ) -> None:
         self._run_multi_process_test(
             callable=sharding_single_rank_test,
@@ -362,4 +373,6 @@ def _test_sharding(
             qcomms_config=qcomms_config,
             apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
             variable_batch_size=variable_batch_size,
+            variable_batch_per_feature=variable_batch_per_feature,
+            global_constant_batch=True,
         )
diff --git a/torchrec/modules/utils.py b/torchrec/modules/utils.py
index 1d0eb79ee..0125a2a1e 100644
--- a/torchrec/modules/utils.py
+++ b/torchrec/modules/utils.py
@@ -10,11 +10,17 @@
 import copy
 import threading
 from collections import defaultdict
+from dataclasses import dataclass
 from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch.profiler import record_function
-from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
+from torchrec.sparse.jagged_tensor import (
+    _permute_tensor_by_segments,
+    JaggedTensor,
+    KeyedJaggedTensor,
+)
+from torchrec.streamable import Multistreamable
 
 lib = torch.library.Library("custom", "FRAGMENT")
 
@@ -55,6 +61,22 @@ class OpRegistryState:
 operator_registry_state = OpRegistryState()
 
 
+@dataclass
+class SequenceVBEContext(Multistreamable):
+    recat: torch.Tensor
+    unpadded_lengths: torch.Tensor
+    reindexed_lengths: torch.Tensor
+    reindexed_length_per_key: List[int]
+    reindexed_values: Optional[torch.Tensor] = None
+
+    def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
+        self.recat.record_stream(stream)
+        self.unpadded_lengths.record_stream(stream)
+        self.reindexed_lengths.record_stream(stream)
+        if self.reindexed_values is not None:
+            self.reindexed_values.record_stream(stream)
+
+
 @torch.fx.wrap
 def _fx_to_list(tensor: torch.Tensor) -> List[int]:
     return tensor.long().tolist()
@@ -163,6 +185,34 @@ def convert_list_of_modules_to_modulelist(
     )
 
 
+def _vbe_reindex(
+    embeddings: torch.Tensor,
+    seq_vbe_ctx: SequenceVBEContext,
+) -> Tuple[torch.Tensor, torch.Tensor, List[int], Optional[torch.Tensor]]:
+    """
+    Reindexes embeddings for variable batch size per feature scenarios.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor, List[int], torch.Tensor]: the reindexed
+            embeddings, lengths, length_per_key, and values
+    """
+    dim = embeddings.shape[1]
+    reindexed_embeddings, _ = _permute_tensor_by_segments(
+        embeddings.flatten(),
+        seq_vbe_ctx.unpadded_lengths * dim,
+        seq_vbe_ctx.recat,
+    )
+    reindexed_embeddings = reindexed_embeddings.view(-1, dim)
+    # lengths must be of shape (len(keys), stride)
+    assert len(seq_vbe_ctx.reindexed_lengths.shape) == 2
+    return (
+        reindexed_embeddings,
+        seq_vbe_ctx.reindexed_lengths,
+        seq_vbe_ctx.reindexed_length_per_key,
+        seq_vbe_ctx.reindexed_values,
+    )
+
+
 def construct_jagged_tensors(
     embeddings: torch.Tensor,
     features: KeyedJaggedTensor,
@@ -171,6 +221,7 @@ def construct_jagged_tensors(
     features_to_permute_indices: Optional[Dict[str, List[int]]] = None,
     original_features: Optional[KeyedJaggedTensor] = None,
     reverse_indices: Optional[torch.Tensor] = None,
+    seq_vbe_ctx: Optional[SequenceVBEContext] = None,
 ) -> Dict[str, JaggedTensor]:
     with record_function("## construct_jagged_tensors ##"):
         if original_features is not None:
@@ -179,16 +230,24 @@ def construct_jagged_tensors(
             embeddings = torch.index_select(
                 embeddings, 0, reverse_indices.to(torch.int32)
             )
-        ret: Dict[str, JaggedTensor] = {}
-        stride = features.stride()
-        length_per_key = features.length_per_key()
-        values = features.values()
 
-        lengths = features.lengths().view(-1, stride)
-        lengths_tuple = torch.unbind(lengths.view(-1, stride), dim=0)
+        ret: Dict[str, JaggedTensor] = {}
+
+        if seq_vbe_ctx is not None:
+            embeddings, lengths, length_per_key, values = _vbe_reindex(
+                embeddings=embeddings, seq_vbe_ctx=seq_vbe_ctx
+            )
+        else:
+            lengths = features.lengths().view(-1, features.stride())
+            length_per_key = features.length_per_key()
+            values = features.values()
+
+        lengths_tuple = torch.unbind(lengths, dim=0)
         embeddings_list = torch.split(embeddings, length_per_key, dim=0)
-        values_list = torch.split(values, length_per_key) if need_indices else None
+        values_list = (
+            torch.split(values, length_per_key)
+            if need_indices and values is not None
+            else None
+        )
 
         key_indices = defaultdict(list)
         for i, key in enumerate(embedding_names):
@@ -207,8 +266,11 @@ def construct_jagged_tensors(
                     if len(indices) == 1
                     else torch.cat([embeddings_list[i] for i in indices], dim=1)
                 ),
-                # pyre-ignore
-                weights=values_list[indices[0]] if need_indices else None,
+                weights=(
+                    values_list[indices[0]]
+                    if need_indices and values_list is not None
+                    else None
+                ),
             )
         return ret
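The snippet below is not part of the patch: it is a minimal sketch of how the new variable-batch-per-feature input path is expected to behave, assuming a torchrec build that contains this diff (for `pad_vbe_kjt_lengths`) and a `KeyedJaggedTensor` constructor that accepts `stride_per_key_per_rank` and `inverse_indices`, as in recent torchrec releases. The feature names, lengths, and values are made up for illustration only.

```python
import torch
from torchrec.distributed.embedding import pad_vbe_kjt_lengths  # added by this diff
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two features with per-feature batch sizes 2 and 3; inverse_indices maps each
# sample of the global (deduped) batch of size 3 back to its per-feature slot.
kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.arange(8),
    lengths=torch.tensor([2, 1, 1, 2, 2]),  # f1 -> [2, 1], f2 -> [1, 2, 2]
    stride_per_key_per_rank=[[2], [3]],
    inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 1], [0, 1, 2]])),
)

padded = pad_vbe_kjt_lengths(kjt)
# Lengths are padded to a uniform stride equal to the global batch size (3),
# so there are len(keys) * 3 = 6 entries and f1's missing slot is zero-filled.
print(padded.stride())   # 3
print(padded.lengths())  # tensor([2, 1, 0, 1, 2, 2])
```

In `input_dist`, this padded KJT is what flows through the existing fixed-stride input dist and dedup path, while the original unpadded KJT is retained so that `_compute_sequence_vbe_context` can build the `SequenceVBEContext` that `construct_jagged_tensors` later uses (via `_vbe_reindex`) to reindex the output embeddings back to per-feature batch sizes.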