From f8e2ff5cd43271afd076063be7ddcde996861c96 Mon Sep 17 00:00:00 2001
From: Qiang Zhang
Date: Thu, 6 Jun 2024 07:40:30 -0700
Subject: [PATCH] Enable Module level Inference Benchmark for CPU device

Summary:
Adds a --device_type flag (default "cuda") to the benchmark CLI and threads a
device_type argument through benchmark_module, init_module_and_run_benchmark,
and benchmark so the module-level inference benchmark can run on CPU. CUDA
event timing, peak-memory accounting, device synchronization, and the CUDA
profiler trace are gated on device_type == "cuda"; CPU runs are timed with
timeit.repeat instead.

Differential Revision: D58143260
---
 .../distributed/benchmark/benchmark_utils.py | 140 +++++++++++-------
 1 file changed, 84 insertions(+), 56 deletions(-)

diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py
index e8f26d8f9..1236a8a13 100644
--- a/torchrec/distributed/benchmark/benchmark_utils.py
+++ b/torchrec/distributed/benchmark/benchmark_utils.py
@@ -17,6 +17,7 @@
 import logging
 import os
 import time
+import timeit
 
 from dataclasses import dataclass
 from enum import Enum
@@ -37,6 +38,7 @@
 from torch.autograd.profiler import record_function
 from torchrec.distributed import DistributedModelParallel
 from torchrec.distributed.embedding_types import ShardingType
+from torchrec.distributed.global_settings import set_propogate_device
 from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
 from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
@@ -411,6 +413,7 @@ def init_argparse_and_args() -> argparse.Namespace:
     parser.add_argument("--output_dir", type=str, default="/var/tmp/torchrec-bench")
     parser.add_argument("--num_benchmarks", type=int, default=5)
     parser.add_argument("--embedding_config_json", type=str, default="")
+    parser.add_argument("--device_type", type=str, default="cuda")
 
     args = parser.parse_args()
 
@@ -436,7 +439,9 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:
         scripted_module = torch.jit.script(graph_module)
         return scripted_module
 
-    topology: Topology = Topology(world_size=world_size, compute_device="cuda")
+    set_propogate_device(True)
+
+    topology: Topology = Topology(world_size=world_size, compute_device=device.type)
     planner = EmbeddingShardingPlanner(
         topology=topology,
         batch_size=batch_size,
@@ -498,6 +503,7 @@ def benchmark(
     benchmark_func_kwargs: Optional[Dict[str, Any]],
     rank: int,
     enable_logging: bool = True,
+    device_type: str = "cuda",
 ) -> BenchmarkResult:
     max_mem_allocated: List[int] = []
     if enable_logging:
@@ -506,47 +512,64 @@
     for _input in warmup_inputs:
         model(_input)
 
-    if rank == -1:
-        # Reset memory for measurement, no process per rank so do all
-        for di in range(world_size):
-            torch.cuda.reset_peak_memory_stats(di)
-    else:
-        torch.cuda.reset_peak_memory_stats(rank)
+    if device_type == "cuda":
+        if rank == -1:
+            # Reset memory for measurement, no process per rank so do all
+            for di in range(world_size):
+                torch.cuda.reset_peak_memory_stats(di)
+        else:
+            torch.cuda.reset_peak_memory_stats(rank)
 
-    # Measure time taken for batches in bench_inputs
-    start = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)]
-    end = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)]
+    start = []
+    end = []
+    if device_type == "cuda":
+        # Measure time taken for batches in bench_inputs
+        start = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)]
+        end = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)]
 
     if benchmark_func_kwargs is None:
         # Need this to unwrap
         benchmark_func_kwargs = {}
 
-    for i in range(num_benchmarks):
-        start[i].record()
-        func_to_benchmark(model, bench_inputs, **benchmark_func_kwargs)
-        end[i].record()
+    times = []
+    if device_type == "cuda":
+        for i in range(num_benchmarks):
+            start[i].record()
+            func_to_benchmark(model, bench_inputs, **benchmark_func_kwargs)
+            end[i].record()
+    elif device_type == "cpu":
+        times = timeit.repeat(
+            lambda: func_to_benchmark(model, bench_inputs, **benchmark_func_kwargs),
+            number=1,
+            repeat=num_benchmarks,
+        )
 
-    if rank == -1:
-        for di in range(world_size):
-            torch.cuda.synchronize(di)
-    else:
-        torch.cuda.synchronize(rank)
+    if device_type == "cuda":
+        if rank == -1:
+            for di in range(world_size):
+                torch.cuda.synchronize(di)
+        else:
+            torch.cuda.synchronize(rank)
 
     # TODO: First Benchmark Run for Eager Mode produces outlier
     # Start counting after first as workaround for standard deviation
-    elapsed_time = torch.tensor(
-        [si.elapsed_time(ei) for si, ei in zip(start[1:], end[1:])]
-    )
-
-    if rank == -1:
-        # Add up all memory allocated in inference mode
-        for di in range(world_size):
-            b = torch.cuda.max_memory_allocated(di)
-            max_mem_allocated.append(b // 1024 // 1024)
+    if device_type == "cuda":
+        elapsed_time = torch.tensor(
+            [si.elapsed_time(ei) for si, ei in zip(start[1:], end[1:])]
+        )
     else:
-        # Only add up memory allocated for current rank in training mode
-        b = torch.cuda.max_memory_allocated(rank)
-        max_mem_allocated.append(b // 1024 // 1024)
+        elapsed_time = torch.tensor(times)
+
+    if device_type == "cuda":
+        if rank == -1:
+            # Add up all memory allocated in inference mode
+            for di in range(world_size):
+                b = torch.cuda.max_memory_allocated(di)
+                max_mem_allocated.append(b // 1024 // 1024)
+        else:
+            # Only add up memory allocated for current rank in training mode
+            b = torch.cuda.max_memory_allocated(rank)
+            max_mem_allocated.append(b // 1024 // 1024)
 
     if output_dir != "":
         # Only do profiling if output_dir is set
@@ -574,22 +597,23 @@ def trace_handler(prof) -> None:
         # - cd FlameGraph
         # - ./flamegraph.pl --title "CPU time" --countname "us." profiler.stacks > perf_viz.svg
 
-        with torch.profiler.profile(
-            activities=[
-                torch.profiler.ProfilerActivity.CPU,
-                torch.profiler.ProfilerActivity.CUDA,
-            ],
-            record_shapes=True,
-            profile_memory=True,
-            with_stack=True,
-            with_flops=True,
-            with_modules=True,
-            on_trace_ready=trace_handler,
-        ) as p:
-            for _input in prof_inputs:
-                with record_function("## forward ##"):
-                    model(_input)
-                p.step()
+        if device_type == "cuda":
+            with torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                record_shapes=True,
+                profile_memory=True,
+                with_stack=True,
+                with_flops=True,
+                with_modules=True,
+                on_trace_ready=trace_handler,
+            ) as p:
+                for _input in prof_inputs:
+                    with record_function("## forward ##"):
+                        model(_input)
+                    p.step()
 
         if rank == -1:
             for di in range(torch.cuda.device_count()):
@@ -659,26 +683,29 @@ def init_module_and_run_benchmark(
 
     if rank >= 0:
         warmup_inputs_cuda = [
-            warmup_input.to(torch.device(f"cuda:{rank}"))
+            warmup_input.to(torch.device(f"{device.type}:{rank}"))
             for warmup_input in warmup_inputs[rank]
         ]
         bench_inputs_cuda = [
-            bench_input.to(torch.device(f"cuda:{rank}"))
+            bench_input.to(torch.device(f"{device.type}:{rank}"))
             for bench_input in bench_inputs[rank]
         ]
         prof_inputs_cuda = [
-            prof_input.to(torch.device(f"cuda:{rank}"))
+            prof_input.to(torch.device(f"{device.type}:{rank}"))
             for prof_input in prof_inputs[rank]
         ]
     else:
         warmup_inputs_cuda = [
-            warmup_input.to(torch.device("cuda:0")) for warmup_input in warmup_inputs[0]
+            warmup_input.to(torch.device(f"{device.type}:0"))
+            for warmup_input in warmup_inputs[0]
         ]
         bench_inputs_cuda = [
-            bench_input.to(torch.device("cuda:0")) for bench_input in bench_inputs[0]
+            bench_input.to(torch.device(f"{device.type}:0"))
+            for bench_input in bench_inputs[0]
         ]
         prof_inputs_cuda = [
-            prof_input.to(torch.device("cuda:0")) for prof_input in prof_inputs[0]
+            prof_input.to(torch.device(f"{device.type}:0"))
+            for prof_input in prof_inputs[0]
         ]
 
     with (
@@ -713,6 +740,7 @@ def init_module_and_run_benchmark(
         func_to_benchmark=func_to_benchmark,
         benchmark_func_kwargs=benchmark_func_kwargs,
         rank=rank,
+        device_type=device.type,
     )
 
     if queue is not None:
@@ -801,6 +829,7 @@ def benchmark_module(
     benchmark_func_kwargs: Optional[Dict[str, Any]] = None,
     pooling_configs: Optional[List[int]] = None,
     variable_batch_embeddings: bool = False,
+    device_type: str = "cuda",
 ) -> List[BenchmarkResult]:
     """
     Args:
@@ -884,7 +913,7 @@ def benchmark_module(
             callable=init_module_and_run_benchmark,
             module=wrapped_module,
             sharder=sharder,
-            device=torch.device("cuda"),
+            device=torch.device(device_type),
             sharding_type=sharding_type,
             compile_mode=compile_mode,
             world_size=world_size,
@@ -903,8 +932,7 @@ def benchmark_module(
         res = init_module_and_run_benchmark(
             module=wrapped_module,
             sharder=sharder,
-            # TODO: GPU hardcode for now, expand if needed for heter hardware
-            device=torch.device("cuda:0"),
+            device=torch.device(device_type),
             sharding_type=sharding_type,
             compile_mode=compile_mode,
             world_size=world_size,
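
Note for reviewers: the heart of the change is the device-gated timing path in benchmark(). CUDA runs keep the torch.cuda.Event measurement (asynchronous kernels need events plus a synchronize before reading times), while CPU runs fall back to timeit.repeat with number=1. The snippet below is a minimal standalone sketch of that strategy, not part of the patch; the helper name time_inference and its arguments are illustrative only. Since torch.cuda.Event.elapsed_time reports milliseconds while timeit reports seconds, the sketch converts the CPU path to milliseconds for comparability.

import timeit
from typing import Any, Callable, List

import torch


def time_inference(
    model: torch.nn.Module,
    bench_inputs: List[Any],
    func: Callable[[torch.nn.Module, List[Any]], None],
    num_benchmarks: int = 5,
    device_type: str = "cuda",
) -> torch.Tensor:
    # Returns per-iteration latencies in milliseconds.
    if device_type == "cuda":
        # CUDA kernels launch asynchronously: bracket each iteration with
        # events and synchronize once before reading elapsed times.
        start = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)]
        end = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)]
        for i in range(num_benchmarks):
            start[i].record()
            func(model, bench_inputs)
            end[i].record()
        torch.cuda.synchronize()
        return torch.tensor([s.elapsed_time(e) for s, e in zip(start, end)])
    # CPU execution is synchronous, so plain wall-clock timing is enough.
    times = timeit.repeat(
        lambda: func(model, bench_inputs), number=1, repeat=num_benchmarks
    )
    return torch.tensor(times) * 1000.0  # timeit returns seconds; convert to ms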