Fix FPEBC train pipeline test (#2090)
Summary:
Pull Request resolved: #2090

3 issues that needed fixing:
1) Move batch to GPU
2) Set the compute kernel to fused instead of dense so it works with table-wise (TW) sharding
3) Ensure that the input batch's idlist_features KJT has lengths that do not exceed the max lengths specified for the feature processors (otherwise it would fail on `torch.gather()` in the feature processor due to a shape mismatch between the KJT input lengths and the indices); a toy reproduction of this failure is sketched below
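For context on issue 3: a feature processor like `PositionWeightedModule` keeps one learnable weight per position up to `max_feature_length` and gathers a weight for every value in the jagged batch, so any per-sample length above that limit produces an out-of-range index. A minimal toy sketch of that failure mode (illustrative only, not the actual torchrec implementation):

```python
# Toy sketch of issue 3 -- illustrative only, not the torchrec implementation.
import torch

max_feature_length = 10
position_weight = torch.ones(max_feature_length)  # one weight per position

def gather_position_weights(lengths: torch.Tensor) -> torch.Tensor:
    # Position indices 0, 1, 2, ... for every value of every sample in the KJT.
    positions = torch.cat([torch.arange(int(length)) for length in lengths])
    # Any position >= max_feature_length is an out-of-range index for the gather.
    return torch.gather(position_weight, 0, positions)

gather_position_weights(torch.tensor([3, 10]))    # ok: all positions < 10
# gather_position_weights(torch.tensor([3, 12]))  # fails: positions 10 and 11
#                                                 # exceed the size-10 weight table
```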

Reviewed By: henrylhtsang

Differential Revision: D56950454
sarckk authored and facebook-github-bot committed Jun 11, 2024
1 parent 2267dc8 commit 8489857
Showing 4 changed files with 52 additions and 16 deletions.
14 changes: 14 additions & 0 deletions torchrec/distributed/test_utils/test_model.py
@@ -71,6 +71,7 @@ def generate(
weighted_tables_pooling: Optional[List[int]] = None,
randomize_indices: bool = True,
device: Optional[torch.device] = None,
max_feature_lengths: Optional[List[int]] = None,
) -> Tuple["ModelInput", List["ModelInput"]]:
"""
Returns a global (single-rank training) batch
@@ -102,11 +103,17 @@ def _validate_pooling_factor(

idlist_features_to_num_embeddings = {}
idlist_features_to_pooling_factor = {}
idlist_features_to_max_length = {}
feature_idx = 0
for idx in range(len(tables)):
for feature in tables[idx].feature_names:
idlist_features_to_num_embeddings[feature] = tables[idx].num_embeddings
idlist_features_to_max_length[feature] = (
max_feature_lengths[feature_idx] if max_feature_lengths else None
)
if tables_pooling is not None:
idlist_features_to_pooling_factor[feature] = tables_pooling[idx]
feature_idx += 1

idlist_features = list(idlist_features_to_num_embeddings.keys())
idscore_features = [
@@ -119,6 +126,8 @@ def _validate_pooling_factor(
idlist_pooling_factor = list(idlist_features_to_pooling_factor.values())
idscore_pooling_factor = weighted_tables_pooling

idlist_max_lengths = list(idlist_features_to_max_length.values())

# Generate global batch.
global_idlist_lengths = []
global_idlist_indices = []
@@ -142,6 +151,10 @@ def _validate_pooling_factor(
lengths_ = torch.abs(
torch.randn(batch_size * world_size, device=device) + pooling_avg,
).int()

if idlist_max_lengths[idx]:
lengths_ = torch.clamp(lengths_, max=idlist_max_lengths[idx])

if variable_batch_size:
lengths = torch.zeros(batch_size * world_size, device=device).int()
for r in range(world_size):
@@ -152,6 +165,7 @@ def _validate_pooling_factor(
)
else:
lengths = lengths_

num_indices = cast(int, torch.sum(lengths).item())
if randomize_indices:
indices = torch.randint(
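A hedged usage sketch of the `max_feature_lengths` argument added to `ModelInput.generate` above. The table configs and keyword values here are illustrative (they are not copied from the test); one max length is supplied per unweighted feature, in feature order, and the generated per-sample lengths are clamped so they never exceed those values.

```python
from torchrec.distributed.test_utils.test_model import ModelInput
from torchrec.modules.embedding_configs import EmbeddingBagConfig

# Illustrative tables: four single-feature embedding bags.
tables = [
    EmbeddingBagConfig(
        num_embeddings=100,
        embedding_dim=8,
        name=f"table_{i}",
        feature_names=[f"feature_{i}"],
    )
    for i in range(4)
]

# One max length per feature (feature_0..feature_3); generate() clamps the
# sampled lengths to these values, matching the feature processors' limits.
global_batch, per_rank_batches = ModelInput.generate(
    batch_size=1,
    world_size=1,
    tables=tables,
    weighted_tables=[],
    num_float_features=10,
    max_feature_lengths=[10, 10, 12, 12],
)
```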
36 changes: 25 additions & 11 deletions torchrec/distributed/tests/test_fp_embeddingbag_utils.py
@@ -28,16 +28,30 @@
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

DEFAULT_MAX_FEATURE_LENGTH = 12


class SparseArch(nn.Module):
def __init__(
self,
tables: List[EmbeddingBagConfig],
use_fp_collection: bool,
device: torch.device,
max_feature_lengths: Optional[List[int]] = None,
) -> None:
super().__init__()

feature_names = [
feature_name for table in tables for feature_name in table.feature_names
]

if max_feature_lengths is None:
max_feature_lengths = [DEFAULT_MAX_FEATURE_LENGTH] * len(feature_names)

assert len(max_feature_lengths) == len(
feature_names
), "Expect max_feature_lengths to have the same number of items as feature_names"

self._fp_ebc: FeatureProcessedEmbeddingBagCollection = (
FeatureProcessedEmbeddingBagCollection(
EmbeddingBagCollection(
@@ -49,20 +63,19 @@ def __init__(
cast(
Dict[str, FeatureProcessor],
{
"feature_0": PositionWeightedModule(max_feature_length=10),
"feature_1": PositionWeightedModule(max_feature_length=10),
"feature_2": PositionWeightedModule(max_feature_length=12),
"feature_3": PositionWeightedModule(max_feature_length=12),
feature_name: PositionWeightedModule(
max_feature_length=max_feature_length
)
for feature_name, max_feature_length in zip(
feature_names, max_feature_lengths
)
},
)
if not use_fp_collection
else PositionWeightedModuleCollection(
max_feature_lengths={
"feature_0": 10,
"feature_1": 10,
"feature_2": 12,
"feature_3": 12,
}
max_feature_lengths=dict(
zip(feature_names, max_feature_lengths)
),
)
),
).to(device)
@@ -85,9 +98,10 @@ def create_module_and_freeze(
tables: List[EmbeddingBagConfig],
use_fp_collection: bool,
device: torch.device,
max_feature_lengths: Optional[List[int]] = None,
) -> SparseArch:

sparse_arch = SparseArch(tables, use_fp_collection, device)
sparse_arch = SparseArch(tables, use_fp_collection, device, max_feature_lengths)

torch.manual_seed(0)
for param in sparse_arch.parameters():
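A hedged usage sketch of the updated `create_module_and_freeze` helper (the tables and device are illustrative; when `max_feature_lengths` is omitted, the helper falls back to `DEFAULT_MAX_FEATURE_LENGTH` for every feature).

```python
import torch
from torchrec.distributed.tests.test_fp_embeddingbag_utils import (
    create_module_and_freeze,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig

# Illustrative tables: four single-feature embedding bags.
tables = [
    EmbeddingBagConfig(
        num_embeddings=100,
        embedding_dim=8,
        name=f"table_{i}",
        feature_names=[f"feature_{i}"],
    )
    for i in range(4)
]

# One max length per feature, in feature order, matching the KJT clamping in
# ModelInput.generate; with use_fp_collection=True the helper builds a
# PositionWeightedModuleCollection keyed by feature name.
sparse_arch = create_module_and_freeze(
    tables=tables,
    use_fp_collection=True,
    device=torch.device("cpu"),
    max_feature_lengths=[10, 10, 12, 12],
)
```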
16 changes: 11 additions & 5 deletions torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -178,20 +178,23 @@ def __init__(self, sparse_arch):
def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]:
return self.m(model_input.idlist_features)

max_feature_lengths = [10, 10, 12, 12]
sparse_arch = DummyWrapper(
create_module_and_freeze(
tables=self.tables,
device=self.device,
use_fp_collection=False,
max_feature_lengths=max_feature_lengths,
)
)
compute_kernel = EmbeddingComputeKernel.FUSED.value
module_sharding_plan = construct_module_sharding_plan(
sparse_arch.m._fp_ebc,
per_param_sharding={
"table_0": table_wise(rank=0),
"table_1": table_wise(rank=0),
"table_2": table_wise(rank=0),
"table_3": table_wise(rank=0),
"table_0": table_wise(rank=0, compute_kernel=compute_kernel),
"table_1": table_wise(rank=0, compute_kernel=compute_kernel),
"table_2": table_wise(rank=0, compute_kernel=compute_kernel),
"table_3": table_wise(rank=0, compute_kernel=compute_kernel),
},
local_size=1,
world_size=1,
@@ -219,7 +222,9 @@ def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]:
sharded_sparse_arch_pipeline.state_dict(),
)

data = self._generate_data(num_batches=5, batch_size=1)
data = self._generate_data(
num_batches=5, batch_size=1, max_feature_lengths=max_feature_lengths
)
dataloader = iter(data)

optimizer_no_pipeline = optim.SGD(
Expand All @@ -236,6 +241,7 @@ def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]:
)

for batch in data[:-2]:
batch = batch.to(self.device)
optimizer_no_pipeline.zero_grad()
loss, pred = sharded_sparse_arch_no_pipeline(batch)
loss.backward()
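For issue 2, a minimal hedged sketch of the per-parameter sharding spec that requests the fused compute kernel (a dict comprehension equivalent to the explicit mapping in the diff above; it assumes the same four table names):

```python
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.sharding_plan import table_wise

# Issue 2: ask for the FUSED compute kernel on each table-wise-sharded table
# instead of the default dense kernel; this spec is then passed to
# construct_module_sharding_plan as in the diff above.
compute_kernel = EmbeddingComputeKernel.FUSED.value
per_param_sharding = {
    f"table_{i}": table_wise(rank=0, compute_kernel=compute_kernel)
    for i in range(4)
}
```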
@@ -68,6 +68,7 @@ def _generate_data(
self,
num_batches: int = 5,
batch_size: int = 1,
max_feature_lengths: Optional[List[int]] = None,
) -> List[ModelInput]:
return [
ModelInput.generate(
Expand All @@ -76,6 +77,7 @@ def _generate_data(
batch_size=batch_size,
world_size=1,
num_float_features=10,
max_feature_lengths=max_feature_lengths,
)[0]
for i in range(num_batches)
]
