From 3c785ea39bd0e785f64c70eaa484455ef66e2b6e Mon Sep 17 00:00:00 2001
From: Dennis van der Staay
Date: Tue, 30 Apr 2024 12:29:41 -0700
Subject: [PATCH] Benchmarking

Summary:
Benchmark existing train pipelines, measuring training runtime and memory on multi-GPU setups.

```
buck2 run mode/dev-nosan //torchrec/distributed/train_pipeline/tests:pipeline_benchmarks
```

TrainPipelineBase       | Runtime (P90): 13.1 s | Memory (P90): 8.4 GB
TrainPipelineSparseDist | Runtime (P90): 12.7 s | Memory (P90): 8.8 GB

Reviewed By: henrylhtsang

Differential Revision: D56690925

fbshipit-source-id: d329861cf915f0223f23d35812ece98f05773950
---
 torchrec/distributed/test_utils/test_model.py | 116 ++++++-
 .../tests/pipeline_benchmarks.py              | 301 ++++++++++++++++++
 .../tests/test_train_pipelines_base.py        |  13 +-
 3 files changed, 413 insertions(+), 17 deletions(-)
 create mode 100644 torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py

diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py
index 4bc70d180..c5ba0147f 100644
--- a/torchrec/distributed/test_utils/test_model.py
+++ b/torchrec/distributed/test_utils/test_model.py
@@ -9,7 +9,7 @@
 
 import random
 from dataclasses import dataclass
-from typing import Any, cast, Dict, List, Optional, Tuple, Union
+from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union
 
 import torch
 import torch.nn as nn
@@ -26,6 +26,7 @@
 from torchrec.distributed.fused_embeddingbag import FusedEmbeddingBagCollectionSharder
 from torchrec.distributed.types import QuantizedCommCodecs
 from torchrec.distributed.utils import CopyableMixin
+from torchrec.modules.activation import SwishLayerNorm
 from torchrec.modules.embedding_configs import (
     BaseEmbeddingConfig,
     EmbeddingBagConfig,
@@ -524,6 +525,14 @@ def forward(
         return self.linear1(self.linear0(input))
 
 
+@torch.fx.wrap
+def _concat(
+    dense: torch.Tensor,
+    sparse_embeddings: List[torch.Tensor],
+) -> torch.Tensor:
+    return torch.cat([dense] + sparse_embeddings, dim=1)
+
+
 class TestOverArch(nn.Module):
     """
     Basic nn.Module for testing
@@ -578,13 +587,78 @@ def forward(
         dense: torch.Tensor,
         sparse: KeyedTensor,
     ) -> torch.Tensor:
-        ret_list = []
-        ret_list.append(dense)
-        for embedding_name in self._embedding_names:
-            ret_list.append(sparse[embedding_name])
-        for feature_name in self._weighted_features:
-            ret_list.append(sparse[feature_name])
-        return self.dhn_arch(torch.cat(ret_list, dim=1))
+        sparse_regrouped: List[torch.Tensor] = KeyedTensor.regroup(
+            [sparse], [self._embedding_names, self._weighted_features]
+        )
+
+        return self.dhn_arch(_concat(dense, sparse_regrouped))
+
+
+class TestOverArchLarge(nn.Module):
+    """
+    Basic nn.Module for testing, with 5 layers.
+ """ + + def __init__( + self, + tables: List[EmbeddingBagConfig], + weighted_tables: List[EmbeddingBagConfig], + embedding_names: Optional[List[str]] = None, + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + if device is None: + device = torch.device("cpu") + self._embedding_names: List[str] = ( + embedding_names + if embedding_names + else [feature for table in tables for feature in table.feature_names] + ) + self._weighted_features: List[str] = [ + feature for table in weighted_tables for feature in table.feature_names + ] + in_features = ( + 8 + + sum([table.embedding_dim * len(table.feature_names) for table in tables]) + + sum( + [ + table.embedding_dim * len(table.feature_names) + for table in weighted_tables + ] + ) + ) + out_features = 1000 + layers = [ + torch.nn.Linear( + in_features=in_features, + out_features=out_features, + ), + SwishLayerNorm([out_features]), + ] + + for _ in range(5): + layers += [ + torch.nn.Linear( + in_features=out_features, + out_features=out_features, + ), + SwishLayerNorm([out_features]), + ] + + self.overarch = torch.nn.Sequential(*layers) + + def forward( + self, + dense: torch.Tensor, + sparse: KeyedTensor, + ) -> torch.Tensor: + ret_list = [dense] + ret_list.extend( + KeyedTensor.regroup( + [sparse], [self._embedding_names, self._weighted_features] + ) + ) + return self.overarch(torch.cat(ret_list, dim=1)) @torch.fx.wrap @@ -829,6 +903,7 @@ def __init__( sparse_device: Optional[torch.device] = None, max_feature_lengths_list: Optional[List[Dict[str, int]]] = None, feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None, + over_arch_clazz: Type[nn.Module] = TestOverArch, ) -> None: super().__init__( tables=cast(List[BaseEmbeddingConfig], tables), @@ -850,23 +925,26 @@ def __init__( embedding_names = ( list(embedding_groups.values())[0] if embedding_groups else None ) - self.over = TestOverArch(tables, weighted_tables, embedding_names, dense_device) + self.over: nn.Module = over_arch_clazz( + tables, weighted_tables, embedding_names, dense_device + ) self.register_buffer( "dummy_ones", torch.ones(1, device=dense_device), ) - def forward( - self, - input: ModelInput, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - dense_r = self.dense(input.float_features) - sparse_r = self.sparse( + def sparse_forward(self, input: ModelInput) -> KeyedTensor: + return self.sparse( features=input.idlist_features, weighted_features=input.idscore_features, batch_size=input.float_features.size(0), ) - over_r = self.over(dense_r, sparse_r) + + def dense_forward( + self, input: ModelInput, sparse_output: KeyedTensor + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + dense_r = self.dense(input.float_features) + over_r = self.over(dense_r, sparse_output) pred = torch.sigmoid(torch.mean(over_r, dim=1)) + self.dummy_ones if self.training: return ( @@ -876,6 +954,12 @@ def forward( else: return pred + def forward( + self, + input: ModelInput, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + return self.dense_forward(input, self.sparse_forward(input)) + class TestTowerInteraction(nn.Module): """ diff --git a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py new file mode 100644 index 000000000..62d52ab06 --- /dev/null +++ b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py @@ -0,0 +1,301 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#!/usr/bin/env python3
+
+import copy
+import multiprocessing
+import os
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple
+
+import click
+
+import torch
+import torch.distributed as dist
+from fbgemm_gpu.split_embedding_configs import EmbOptimType
+from torch import nn, optim
+from torch.optim import Optimizer
+from torchrec.distributed import DistributedModelParallel
+from torchrec.distributed.benchmark.benchmark_utils import benchmark
+from torchrec.distributed.embedding_types import EmbeddingComputeKernel
+from torchrec.distributed.test_utils.multi_process import MultiProcessContext
+from torchrec.distributed.test_utils.test_model import (
+    ModelInput,
+    TestEBCSharder,
+    TestOverArchLarge,
+    TestSparseNN,
+)
+from torchrec.distributed.train_pipeline import (
+    TrainPipeline,
+    TrainPipelineBase,
+    TrainPipelineSparseDist,
+)
+from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
+from torchrec.modules.embedding_configs import EmbeddingBagConfig
+
+from torchrec.test_utils import get_free_port
+
+
+@click.command()
+@click.option(
+    "--world_size",
+    type=int,
+    default=4,
+    help="Number of GPUs to run on.",
+)
+@click.option(
+    "--n_features",
+    default=100,
+    help="Total number of sparse features (split evenly between unweighted and weighted).",
+)
+@click.option(
+    "--dim_emb",
+    type=int,
+    default=512,
+    help="Embedding dimension.",
+)
+@click.option(
+    "--n_batches",
+    type=int,
+    default=20,
+    help="Number of batches to run.",
+)
+@click.option(
+    "--batch_size",
+    type=int,
+    default=8192,
+    help="Batch size.",
+)
+def main(
+    world_size: int,
+    n_features: int,
+    dim_emb: int,
+    n_batches: int,
+    batch_size: int,
+) -> None:
+    """
+    Benchmarks the runtime and memory of each train pipeline on a multi-GPU setup.
+ """ + + os.environ["MASTER_ADDR"] = str("localhost") + os.environ["MASTER_PORT"] = str(get_free_port()) + + num_features = n_features // 2 + num_weighted_features = n_features // 2 + tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 1000, + embedding_dim=dim_emb, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_features) + ] + weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 1000, + embedding_dim=dim_emb, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(num_weighted_features) + ] + batches = _generate_data( + tables=tables, + weighted_tables=weighted_tables, + num_float_features=10, + num_batches=n_batches, + batch_size=batch_size, + world_size=world_size, + ) + + _run_multi_process_test( + callable=runner, + tables=tables, + weighted_tables=weighted_tables, + sharding_type=ShardingType.TABLE_WISE.value, + kernel_type=EmbeddingComputeKernel.FUSED.value, + batches=batches, + fused_params={}, + world_size=world_size, + ) + + +def _run_multi_process_test( + *, + callable: Callable[ + ..., + None, + ], + world_size: int, + # pyre-ignore + **kwargs, +) -> None: + ctx = multiprocessing.get_context("spawn") + processes = [] + if world_size == 1: + kwargs["world_size"] = 1 + kwargs["rank"] = 0 + callable(**kwargs) + return + + for rank in range(world_size): + kwargs["rank"] = rank + kwargs["world_size"] = world_size + p = ctx.Process( + target=callable, + kwargs=kwargs, + ) + p.start() + processes.append(p) + + for p in processes: + p.join() + + +def _generate_data( + tables: List[EmbeddingBagConfig], + weighted_tables: List[EmbeddingBagConfig], + num_float_features: int = 10, + num_batches: int = 100, + batch_size: int = 4096, + world_size: int = 1, +) -> List[List[ModelInput]]: + return [ + ModelInput.generate( + tables=tables, + weighted_tables=weighted_tables, + batch_size=batch_size, + world_size=world_size, + num_float_features=num_float_features, + )[1] + for i in range(num_batches) + ] + + +def _generate_sharded_model_and_optimizer( + model: nn.Module, + sharding_type: str, + kernel_type: str, + pg: dist.ProcessGroup, + device: torch.device, + fused_params: Optional[Dict[str, Any]] = None, +) -> Tuple[nn.Module, Optimizer]: + sharder = TestEBCSharder( + sharding_type=sharding_type, + kernel_type=kernel_type, + fused_params=fused_params, + ) + sharded_model = DistributedModelParallel( + module=copy.deepcopy(model), + env=ShardingEnv.from_process_group(pg), + init_data_parallel=True, + device=device, + sharders=[ + cast( + ModuleSharder[nn.Module], + sharder, + ) + ], + ).to(device) + optimizer = optim.SGD( + [ + param + for name, param in sharded_model.named_parameters() + if "sparse" not in name + ], + lr=0.1, + ) + return sharded_model, optimizer + + +def runner( + tables: List[EmbeddingBagConfig], + weighted_tables: List[EmbeddingBagConfig], + rank: int, + sharding_type: str, + kernel_type: str, + fused_params: Dict[str, Any], + world_size: int, + batches: List[List[ModelInput]], +) -> None: + + torch.autograd.set_detect_anomaly(True) + with MultiProcessContext( + rank=rank, + world_size=world_size, + backend="nccl", + use_deterministic_algorithms=False, + ) as ctx: + + unsharded_model = TestSparseNN( + tables=tables, + weighted_tables=weighted_tables, + dense_device=ctx.device, + sparse_device=torch.device("meta"), + over_arch_clazz=TestOverArchLarge, + ) + + sharded_model, optimizer = _generate_sharded_model_and_optimizer( + model=unsharded_model, + 
+            sharding_type=sharding_type,
+            kernel_type=kernel_type,
+            # pyre-ignore
+            pg=ctx.pg,
+            device=ctx.device,
+            fused_params={
+                "optimizer": EmbOptimType.EXACT_ADAGRAD,
+                "learning_rate": 0.1,
+            },
+        )
+        bench_inputs = [batch[rank] for batch in batches]
+        for pipeline_clazz in [
+            TrainPipelineBase,
+            TrainPipelineSparseDist,
+        ]:
+            pipeline = pipeline_clazz(
+                model=sharded_model,
+                optimizer=optimizer,
+                device=ctx.device,
+            )
+            pipeline.progress(iter(bench_inputs))
+
+            def _func_to_benchmark(
+                model: nn.Module,
+                bench_inputs: List[ModelInput],
+                pipeline: TrainPipeline,
+            ) -> None:
+                dataloader = iter(bench_inputs)
+                while True:
+                    try:
+                        pipeline.progress(dataloader)
+                    except StopIteration:
+                        break
+
+            result = benchmark(
+                name=pipeline_clazz.__name__,
+                model=sharded_model,
+                num_benchmarks=5,
+                output_dir="",
+                warmup_inputs=[],
+                # pyre-ignore
+                bench_inputs=bench_inputs,
+                prof_inputs=[],
+                world_size=world_size,
+                func_to_benchmark=_func_to_benchmark,
+                benchmark_func_kwargs={"pipeline": pipeline},
+                rank=rank,
+                enable_logging=False,
+            )
+            if rank == 0:
+                print(
+                    f" {pipeline_clazz.__name__: <{35}} | Runtime (P90): {result.runtime_percentile(90)/1000:5.1f} s | Memory (P90): {result.max_mem_percentile(90)/1000:5.1f} GB"
+                )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py
index 067a679ce..283b08114 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py
@@ -132,5 +132,16 @@ def _generate_sharded_model_and_optimizer(
             )
         ],
     )
-    optimizer = optim.SGD(sharded_model.parameters(), lr=0.1)
+    # sharded params already get a fused optimizer (default SGD w/ lr=0.1), so drop them here
+    fused_named_parameters: List[str] = [
+        x for x in DistributedModelParallel._sharded_parameter_names(sharded_model)
+    ]
+    optimizer = optim.SGD(
+        [
+            y
+            for x, y in sharded_model.named_parameters()
+            if x not in fused_named_parameters
+        ],
+        lr=0.1,
+    )
     return sharded_model, optimizer
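
Note: since `main` spawns its own worker processes (one per rank, via `_run_multi_process_test`), the benchmark should also be launchable as a plain Python script outside Buck. A minimal sketch of such an invocation, assuming an OSS checkout with the file at the path shown in the diff and a 4-GPU host; the flags simply mirror the `click` options and defaults above:

```
python torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py \
    --world_size 4 \
    --n_features 100 \
    --dim_emb 512 \
    --n_batches 20 \
    --batch_size 8192
```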