From a5746a8ac86261339568826e4641b10a839d63bd Mon Sep 17 00:00:00 2001
From: Kaustubh Vartak
Date: Thu, 13 Feb 2025 13:19:59 -0800
Subject: [PATCH] Proportional Uneven RW Inference Sharding (#2734)

Summary:
Support bucketization-aware inference sharding in TGIF for ZCH bucket
boundaries from training. A "best effort" sharding is performed across
bucket boundaries, proportional to the memory list.

* Added bucketization awareness to RW sharding.
* TGIF sharding now ensures at most a 1-bucket difference across
  equal-memory uneven shards, as opposed to the previous logic of
  assigning all remainder rows to the last shard.
* InferRWSparseDist now checks for customized embedding_shard_metadata
  for uneven shards before falling back to even division.

Differential Revision: D69057627
---
 torchrec/distributed/quant_state.py          |  1 +
 torchrec/distributed/sharding/rw_sharding.py | 20 ++++++++++++--------
 2 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/torchrec/distributed/quant_state.py b/torchrec/distributed/quant_state.py
index 60572b929..1de388e1b 100644
--- a/torchrec/distributed/quant_state.py
+++ b/torchrec/distributed/quant_state.py
@@ -441,6 +441,7 @@ def sharded_tbes_weights_spec(
             shard_sizes: List[int] = [table.local_rows, table.local_cols]
             shard_offsets: List[int] = table_metadata.shard_offsets
             s: str = "embedding_bags" if is_sqebc else "embeddings"
+            s = ("_embedding_module." if is_sqmcec else "") + s
             unsharded_fqn_weight: str = f"{module_fqn}.{s}.{table_name}.weight"
 
             sharded_fqn_weight: str = (
diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py
index f61ea0bd8..53b459e1d 100644
--- a/torchrec/distributed/sharding/rw_sharding.py
+++ b/torchrec/distributed/sharding/rw_sharding.py
@@ -661,14 +661,18 @@ def __init__(
         self._feature_total_num_buckets: Optional[List[int]] = feature_total_num_buckets
 
         self.feature_block_sizes: List[int] = []
-        for i, hash_size in enumerate(feature_hash_sizes):
-            block_divisor = self._world_size
-            if feature_total_num_buckets is not None:
-                assert feature_total_num_buckets[i] % self._world_size == 0
-                block_divisor = feature_total_num_buckets[i]
-            self.feature_block_sizes.append(
-                (hash_size + block_divisor - 1) // block_divisor
-            )
+        if embedding_shard_metadata is not None:
+            assert len(embedding_shard_metadata) == len(feature_hash_sizes)
+            self.feature_block_sizes = [0] * len(feature_hash_sizes)
+        else:
+            for i, hash_size in enumerate(feature_hash_sizes):
+                block_divisor = self._world_size
+                if feature_total_num_buckets is not None:
+                    assert feature_total_num_buckets[i] % self._world_size == 0
+                    block_divisor = feature_total_num_buckets[i]
+                self.feature_block_sizes.append(
+                    (hash_size + block_divisor - 1) // block_divisor
+                )
         self.tensor_cache: Dict[
             str, Tuple[torch.Tensor, Optional[List[torch.Tensor]]]
         ] = {}
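
Note on the "at most 1 bucket difference" behavior: the sketch below is a minimal
illustration of the idea described in the summary, not the TGIF/TorchRec
implementation; the helper names `split_buckets_remainder_to_last` and
`split_buckets_evenly` are hypothetical.

```python
# Minimal sketch: distributing num_buckets across world_size shards.
# Not the TGIF/TorchRec implementation; helper names are hypothetical.

from typing import List


def split_buckets_remainder_to_last(num_buckets: int, world_size: int) -> List[int]:
    # Previous-style logic: every shard gets floor(num_buckets / world_size)
    # buckets and the last shard absorbs the entire remainder.
    base = num_buckets // world_size
    sizes = [base] * world_size
    sizes[-1] += num_buckets - base * world_size
    return sizes


def split_buckets_evenly(num_buckets: int, world_size: int) -> List[int]:
    # New-style logic: spread the remainder one bucket at a time so that
    # any two shards differ by at most one bucket.
    base, rem = divmod(num_buckets, world_size)
    return [base + (1 if rank < rem else 0) for rank in range(world_size)]


if __name__ == "__main__":
    # 10 buckets over 4 shards:
    print(split_buckets_remainder_to_last(10, 4))  # [2, 2, 2, 4]
    print(split_buckets_evenly(10, 4))             # [3, 3, 2, 2]
```

Spreading the remainder this way is what keeps equal-memory shards within one
bucket of each other, which matters when ZCH bucket boundaries from training
must be respected at inference time.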
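The rw_sharding.py hunk computes per-feature block sizes by ceil-dividing each
hash size by the bucket count (or the world size when the feature is not
bucketized). Below is a standalone sketch of that computation, assuming
made-up input values; the function name `compute_feature_block_sizes` is
hypothetical.

```python
# Sketch of the block-size computation in the rw_sharding.py hunk above;
# the inputs in the usage example are made-up, not taken from the patch.

from typing import List, Optional


def compute_feature_block_sizes(
    feature_hash_sizes: List[int],
    world_size: int,
    feature_total_num_buckets: Optional[List[int]] = None,
) -> List[int]:
    block_sizes: List[int] = []
    for i, hash_size in enumerate(feature_hash_sizes):
        block_divisor = world_size
        if feature_total_num_buckets is not None:
            # Bucketized features: the bucket count must divide evenly
            # across ranks, and a block is one bucket's worth of rows.
            assert feature_total_num_buckets[i] % world_size == 0
            block_divisor = feature_total_num_buckets[i]
        # Ceil division so the last block covers any leftover rows.
        block_sizes.append((hash_size + block_divisor - 1) // block_divisor)
    return block_sizes


if __name__ == "__main__":
    # Hypothetical: a 1000-row table with 8 total buckets -> 125 rows per bucket.
    print(compute_feature_block_sizes([1000], 4, feature_total_num_buckets=[8]))  # [125]
    # Without bucketization, the block size is ceil(1000 / world_size).
    print(compute_feature_block_sizes([1000], 4))  # [250]
```

When embedding_shard_metadata is supplied, the patch skips this computation
entirely and zero-fills feature_block_sizes, since the uneven shard boundaries
come from the provided metadata rather than from even division.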