Benchmarking
Summary:
Benchmark the existing train pipelines: measure training performance and memory on multi-GPU setups.

```
buck2 run mode/dev-nosan //torchrec/distributed/train_pipeline/tests:pipeline_benchmarks
```

```
TrainPipelineBase                   | Runtime (P90):  13.1 s | Memory (P90):   8.4 GB
TrainPipelineSparseDist             | Runtime (P90):  12.7 s | Memory (P90):   8.8 GB
```
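As a point of reference (not part of the original commit message), a minimal sketch of how the two benchmarked pipelines are typically driven; `model`, `optimizer`, `dataloader`, and `num_batches` are placeholders for a sharded TorchRec model, its optimizer, and an input iterator:

```
import torch
from torchrec.distributed.train_pipeline import TrainPipelineSparseDist

# Placeholders (assumptions): `model` is a sharded TorchRec model, e.g. wrapped
# in DistributedModelParallel; `optimizer` is its optimizer; `dataloader`
# yields ModelInput batches.
device = torch.device("cuda")
pipeline = TrainPipelineSparseDist(model, optimizer, device)

dataloader_iter = iter(dataloader)
for _ in range(num_batches):
    # progress() runs one optimization step while overlapping the sparse
    # input dist / lookup of the next batch with the current dense pass.
    output = pipeline.progress(dataloader_iter)
```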

Reviewed By: henrylhtsang

Differential Revision: D56690925

fbshipit-source-id: d329861cf915f0223f23d35812ece98f05773950
dstaay-fb authored and facebook-github-bot committed Apr 30, 2024
1 parent fa37d69 commit 3c785ea
Showing 3 changed files with 413 additions and 17 deletions.
116 changes: 100 additions & 16 deletions torchrec/distributed/test_utils/test_model.py
@@ -9,7 +9,7 @@

import random
from dataclasses import dataclass
from typing import Any, cast, Dict, List, Optional, Tuple, Union
from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union

import torch
import torch.nn as nn
@@ -26,6 +26,7 @@
from torchrec.distributed.fused_embeddingbag import FusedEmbeddingBagCollectionSharder
from torchrec.distributed.types import QuantizedCommCodecs
from torchrec.distributed.utils import CopyableMixin
from torchrec.modules.activation import SwishLayerNorm
from torchrec.modules.embedding_configs import (
BaseEmbeddingConfig,
EmbeddingBagConfig,
@@ -524,6 +525,14 @@ def forward(
return self.linear1(self.linear0(input))


@torch.fx.wrap
def _concat(
dense: torch.Tensor,
sparse_embeddings: List[torch.Tensor],
) -> torch.Tensor:
return torch.cat([dense] + sparse_embeddings, dim=1)
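As an aside, a self-contained sketch of what the `torch.fx.wrap` decorator buys here: symbolic tracing records the wrapped helper as a single opaque call instead of tracing through the Python list construction and `torch.cat`. The helper below mirrors the one above but is redefined so the sketch stands on its own:

```
import torch
import torch.fx
from typing import List

@torch.fx.wrap
def _concat(dense: torch.Tensor, sparse_embeddings: List[torch.Tensor]) -> torch.Tensor:
    return torch.cat([dense] + sparse_embeddings, dim=1)

class Toy(torch.nn.Module):
    def forward(self, dense: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        return _concat(dense, [emb])

# The wrapped helper shows up as one call_function node in the traced graph.
gm = torch.fx.symbolic_trace(Toy())
print(gm.graph)
```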


class TestOverArch(nn.Module):
"""
Basic nn.Module for testing
@@ -578,13 +587,78 @@ def forward(
dense: torch.Tensor,
sparse: KeyedTensor,
) -> torch.Tensor:
ret_list = []
ret_list.append(dense)
for embedding_name in self._embedding_names:
ret_list.append(sparse[embedding_name])
for feature_name in self._weighted_features:
ret_list.append(sparse[feature_name])
return self.dhn_arch(torch.cat(ret_list, dim=1))
sparse_regrouped: List[torch.Tensor] = KeyedTensor.regroup(
[sparse], [self._embedding_names, self._weighted_features]
)

return self.dhn_arch(_concat(dense, sparse_regrouped))
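For reference, a small self-contained sketch of the `KeyedTensor.regroup` call used above; keys and shapes are made up for illustration:

```
import torch
from torchrec.sparse.jagged_tensor import KeyedTensor

kt = KeyedTensor.from_tensor_list(
    keys=["f1", "f2", "w1"],
    tensors=[torch.randn(2, 4), torch.randn(2, 4), torch.randn(2, 4)],
)
# Regroup the per-key slices into one concatenated tensor per group,
# replacing the previous per-key indexing loop.
grouped = KeyedTensor.regroup([kt], [["f1", "f2"], ["w1"]])
print([t.shape for t in grouped])  # [torch.Size([2, 8]), torch.Size([2, 4])]
```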


class TestOverArchLarge(nn.Module):
"""
Basic nn.Module for testing, with 5 extra Linear + SwishLayerNorm blocks.
"""

def __init__(
self,
tables: List[EmbeddingBagConfig],
weighted_tables: List[EmbeddingBagConfig],
embedding_names: Optional[List[str]] = None,
device: Optional[torch.device] = None,
) -> None:
super().__init__()
if device is None:
device = torch.device("cpu")
self._embedding_names: List[str] = (
embedding_names
if embedding_names
else [feature for table in tables for feature in table.feature_names]
)
self._weighted_features: List[str] = [
feature for table in weighted_tables for feature in table.feature_names
]
in_features = (
8
+ sum([table.embedding_dim * len(table.feature_names) for table in tables])
+ sum(
[
table.embedding_dim * len(table.feature_names)
for table in weighted_tables
]
)
)
out_features = 1000
layers = [
torch.nn.Linear(
in_features=in_features,
out_features=out_features,
),
SwishLayerNorm([out_features]),
]

for _ in range(5):
layers += [
torch.nn.Linear(
in_features=out_features,
out_features=out_features,
),
SwishLayerNorm([out_features]),
]

self.overarch = torch.nn.Sequential(*layers)

def forward(
self,
dense: torch.Tensor,
sparse: KeyedTensor,
) -> torch.Tensor:
ret_list = [dense]
ret_list.extend(
KeyedTensor.regroup(
[sparse], [self._embedding_names, self._weighted_features]
)
)
return self.overarch(torch.cat(ret_list, dim=1))
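To illustrate the expected shapes, a hypothetical usage sketch of `TestOverArchLarge`; the table configs and sizes are invented, and the import assumes the class added in this file:

```
import torch
from torchrec.distributed.test_utils.test_model import TestOverArchLarge
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.sparse.jagged_tensor import KeyedTensor

tables = [
    EmbeddingBagConfig(name="t0", embedding_dim=4, num_embeddings=10, feature_names=["f0"]),
    EmbeddingBagConfig(name="t1", embedding_dim=4, num_embeddings=10, feature_names=["f1"]),
]
weighted_tables = [
    EmbeddingBagConfig(name="w0", embedding_dim=4, num_embeddings=10, feature_names=["w0"]),
]
arch = TestOverArchLarge(tables=tables, weighted_tables=weighted_tables)

dense = torch.randn(2, 8)  # matches the hard-coded 8 dense input features
sparse = KeyedTensor.from_tensor_list(
    keys=["f0", "f1", "w0"],
    tensors=[torch.randn(2, 4) for _ in range(3)],
)
print(arch(dense, sparse).shape)  # torch.Size([2, 1000])
```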


@torch.fx.wrap
@@ -829,6 +903,7 @@ def __init__(
sparse_device: Optional[torch.device] = None,
max_feature_lengths_list: Optional[List[Dict[str, int]]] = None,
feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None,
over_arch_clazz: Type[nn.Module] = TestOverArch,
) -> None:
super().__init__(
tables=cast(List[BaseEmbeddingConfig], tables),
@@ -850,23 +925,26 @@ def __init__(
embedding_names = (
list(embedding_groups.values())[0] if embedding_groups else None
)
self.over = TestOverArch(tables, weighted_tables, embedding_names, dense_device)
self.over: nn.Module = over_arch_clazz(
tables, weighted_tables, embedding_names, dense_device
)
self.register_buffer(
"dummy_ones",
torch.ones(1, device=dense_device),
)

def forward(
self,
input: ModelInput,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
dense_r = self.dense(input.float_features)
sparse_r = self.sparse(
def sparse_forward(self, input: ModelInput) -> KeyedTensor:
return self.sparse(
features=input.idlist_features,
weighted_features=input.idscore_features,
batch_size=input.float_features.size(0),
)
over_r = self.over(dense_r, sparse_r)

def dense_forward(
self, input: ModelInput, sparse_output: KeyedTensor
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
dense_r = self.dense(input.float_features)
over_r = self.over(dense_r, sparse_output)
pred = torch.sigmoid(torch.mean(over_r, dim=1)) + self.dummy_ones
if self.training:
return (
@@ -876,6 +954,12 @@
else:
return pred

def forward(
self,
input: ModelInput,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
return self.dense_forward(input, self.sparse_forward(input))
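The split into `sparse_forward` / `dense_forward` is what lets a train pipeline overlap stages across batches; below is a toy, single-process sketch of that pattern, with modules and shapes invented for illustration:

```
import torch

class SplitModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.dense = torch.nn.Linear(8, 1)

    def sparse_forward(self, batch: torch.Tensor) -> torch.Tensor:
        # Stand-in for the embedding lookup / input dist stage.
        return 2 * batch

    def dense_forward(self, batch: torch.Tensor, sparse_out: torch.Tensor) -> torch.Tensor:
        return self.dense(batch + sparse_out)

model = SplitModel()
batches = [torch.randn(4, 8) for _ in range(3)]

# Kick off the sparse stage for the first batch, then stay one batch ahead.
sparse_out = model.sparse_forward(batches[0])
for i, batch in enumerate(batches):
    next_sparse = model.sparse_forward(batches[i + 1]) if i + 1 < len(batches) else None
    loss = model.dense_forward(batch, sparse_out).sum()
    loss.backward()
    sparse_out = next_sparse
```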


class TestTowerInteraction(nn.Module):
"""