Enable Module level Inference Benchmark for CPU device
Summary: As titled. Adds a --device_type flag and CPU code paths to the module-level inference benchmark: timing falls back to timeit.repeat on CPU, while peak-memory accounting and profiler traces remain CUDA-only.

Differential Revision: D58143260
gnahzg authored and facebook-github-bot committed Jun 6, 2024
1 parent 8393202 commit f8e2ff5
Showing 1 changed file with 84 additions and 56 deletions.
140 changes: 84 additions & 56 deletions torchrec/distributed/benchmark/benchmark_utils.py
@@ -17,6 +17,7 @@
import logging
import os
import time
import timeit
from dataclasses import dataclass

from enum import Enum
@@ -37,6 +38,7 @@
from torch.autograd.profiler import record_function
from torchrec.distributed import DistributedModelParallel
from torchrec.distributed.embedding_types import ShardingType
from torchrec.distributed.global_settings import set_propogate_device

from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
@@ -411,6 +413,7 @@ def init_argparse_and_args() -> argparse.Namespace:
parser.add_argument("--output_dir", type=str, default="/var/tmp/torchrec-bench")
parser.add_argument("--num_benchmarks", type=int, default=5)
parser.add_argument("--embedding_config_json", type=str, default="")
parser.add_argument("--device_type", type=str, default="cuda")

args = parser.parse_args()
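
The new --device_type flag is the single switch for device selection; the call sites below derive their torch.device objects from it. A minimal standalone sketch of how the flag flows (not the torchrec source; the choices= restriction is an assumption added for clarity, the real argument accepts any string):

import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument("--device_type", type=str, default="cuda", choices=["cuda", "cpu"])
args = parser.parse_args(["--device_type", "cpu"])

# Downstream, the string becomes a torch.device; per-rank placement later
# builds f"{device.type}:{rank}" strings from it.
device = torch.device(args.device_type)
print(device.type)  # "cpu"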

@@ -436,7 +439,9 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:
scripted_module = torch.jit.script(graph_module)
return scripted_module

topology: Topology = Topology(world_size=world_size, compute_device="cuda")
set_propogate_device(True)

topology: Topology = Topology(world_size=world_size, compute_device=device.type)
planner = EmbeddingShardingPlanner(
topology=topology,
batch_size=batch_size,
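
Two things change in the sharding setup: set_propogate_device(True) tells torchrec to carry the requested device through sharding rather than assuming CUDA, and Topology now takes compute_device from the target device instead of a hard-coded "cuda". A sketch of the planner setup on CPU, with illustrative world size and batch size (not values from this patch):

import torch

from torchrec.distributed.global_settings import set_propogate_device
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology

set_propogate_device(True)  # propagate the chosen device instead of defaulting to CUDA

device = torch.device("cpu")
topology = Topology(world_size=2, compute_device=device.type)
planner = EmbeddingShardingPlanner(topology=topology, batch_size=512)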
@@ -498,6 +503,7 @@ def benchmark(
benchmark_func_kwargs: Optional[Dict[str, Any]],
rank: int,
enable_logging: bool = True,
device_type: str = "cuda",
) -> BenchmarkResult:
max_mem_allocated: List[int] = []
if enable_logging:
@@ -506,47 +512,64 @@
for _input in warmup_inputs:
model(_input)

if rank == -1:
# Reset memory for measurement, no process per rank so do all
for di in range(world_size):
torch.cuda.reset_peak_memory_stats(di)
else:
torch.cuda.reset_peak_memory_stats(rank)
if device_type == "cuda":
if rank == -1:
# Reset memory for measurement, no process per rank so do all
for di in range(world_size):
torch.cuda.reset_peak_memory_stats(di)
else:
torch.cuda.reset_peak_memory_stats(rank)

# Measure time taken for batches in bench_inputs
start = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)]
end = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)]
start = []
end = []
if device_type == "cuda":
# Measure time taken for batches in bench_inputs
start = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)]
end = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)]

if benchmark_func_kwargs is None:
# Need this to unwrap
benchmark_func_kwargs = {}

for i in range(num_benchmarks):
start[i].record()
func_to_benchmark(model, bench_inputs, **benchmark_func_kwargs)
end[i].record()
times = []
if device_type == "cuda":
for i in range(num_benchmarks):
start[i].record()
func_to_benchmark(model, bench_inputs, **benchmark_func_kwargs)
end[i].record()
elif device_type == "cpu":
times = timeit.repeat(
lambda: func_to_benchmark(model, bench_inputs, **benchmark_func_kwargs),
number=1,
repeat=num_benchmarks,
)

if rank == -1:
for di in range(world_size):
torch.cuda.synchronize(di)
else:
torch.cuda.synchronize(rank)
if device_type == "cuda":
if rank == -1:
for di in range(world_size):
torch.cuda.synchronize(di)
else:
torch.cuda.synchronize(rank)

# TODO: First Benchmark Run for Eager Mode produces outlier
# Start counting after first as workaround for standard deviation
elapsed_time = torch.tensor(
[si.elapsed_time(ei) for si, ei in zip(start[1:], end[1:])]
)

if rank == -1:
# Add up all memory allocated in inference mode
for di in range(world_size):
b = torch.cuda.max_memory_allocated(di)
max_mem_allocated.append(b // 1024 // 1024)
if device_type == "cuda":
elapsed_time = torch.tensor(
[si.elapsed_time(ei) for si, ei in zip(start[1:], end[1:])]
)
else:
# Only add up memory allocated for current rank in training mode
b = torch.cuda.max_memory_allocated(rank)
max_mem_allocated.append(b // 1024 // 1024)
elapsed_time = torch.tensor(times)

if device_type == "cuda":
if rank == -1:
# Add up all memory allocated in inference mode
for di in range(world_size):
b = torch.cuda.max_memory_allocated(di)
max_mem_allocated.append(b // 1024 // 1024)
else:
# Only add up memory allocated for current rank in training mode
b = torch.cuda.max_memory_allocated(rank)
max_mem_allocated.append(b // 1024 // 1024)

if output_dir != "":
# Only do profiling if output_dir is set
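
The timing split is the heart of the change: on CUDA, paired torch.cuda.Event records bracket each run and report elapsed milliseconds after a synchronize, while on CPU timeit.repeat executes the function once per sample and returns wall-clock seconds. A self-contained re-creation of the pattern (not the torchrec source; note that the two paths report different units):

import timeit
from typing import Callable

import torch


def measure(func: Callable[[], None], num_benchmarks: int, device_type: str) -> torch.Tensor:
    if device_type == "cuda":
        starts = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)]
        ends = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)]
        for i in range(num_benchmarks):
            starts[i].record()
            func()
            ends[i].record()
        torch.cuda.synchronize()
        # elapsed_time() returns milliseconds; the first sample is dropped as
        # an eager-mode outlier, mirroring the TODO above.
        return torch.tensor([s.elapsed_time(e) for s, e in zip(starts[1:], ends[1:])])
    # CPU path: number=1 executes func once per sample and repeat collects
    # num_benchmarks samples; timeit reports seconds, not milliseconds.
    return torch.tensor(timeit.repeat(func, number=1, repeat=num_benchmarks))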
@@ -574,22 +597,23 @@ def trace_handler(prof) -> None:
# - cd FlameGraph
# - ./flamegraph.pl --title "CPU time" --countname "us." profiler.stacks > perf_viz.svg

with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=True,
profile_memory=True,
with_stack=True,
with_flops=True,
with_modules=True,
on_trace_ready=trace_handler,
) as p:
for _input in prof_inputs:
with record_function("## forward ##"):
model(_input)
p.step()
if device_type == "cuda":
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=True,
profile_memory=True,
with_stack=True,
with_flops=True,
with_modules=True,
on_trace_ready=trace_handler,
) as p:
for _input in prof_inputs:
with record_function("## forward ##"):
model(_input)
p.step()

if rank == -1:
for di in range(torch.cuda.device_count()):
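
On CPU the patch skips profiling entirely. torch.profiler itself does run without the CUDA activity, so an alternative (shown here as a sketch, not what this patch does) would be to gate only the activity list and keep CPU traces:

import torch


def build_activities(device_type: str):
    activities = [torch.profiler.ProfilerActivity.CPU]
    if device_type == "cuda":
        activities.append(torch.profiler.ProfilerActivity.CUDA)
    return activities


with torch.profiler.profile(activities=build_activities("cpu"), record_shapes=True) as p:
    torch.randn(1024, 1024) @ torch.randn(1024, 1024)  # stand-in workload
print(p.key_averages().table(sort_by="cpu_time_total"))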
@@ -659,26 +683,29 @@ def init_module_and_run_benchmark(

if rank >= 0:
warmup_inputs_cuda = [
warmup_input.to(torch.device(f"cuda:{rank}"))
warmup_input.to(torch.device(f"{device.type}:{rank}"))
for warmup_input in warmup_inputs[rank]
]
bench_inputs_cuda = [
bench_input.to(torch.device(f"cuda:{rank}"))
bench_input.to(torch.device(f"{device.type}:{rank}"))
for bench_input in bench_inputs[rank]
]
prof_inputs_cuda = [
prof_input.to(torch.device(f"cuda:{rank}"))
prof_input.to(torch.device(f"{device.type}:{rank}"))
for prof_input in prof_inputs[rank]
]
else:
warmup_inputs_cuda = [
warmup_input.to(torch.device("cuda:0")) for warmup_input in warmup_inputs[0]
warmup_input.to(torch.device(f"{device.type}:0"))
for warmup_input in warmup_inputs[0]
]
bench_inputs_cuda = [
bench_input.to(torch.device("cuda:0")) for bench_input in bench_inputs[0]
bench_input.to(torch.device(f"{device.type}:0"))
for bench_input in bench_inputs[0]
]
prof_inputs_cuda = [
prof_input.to(torch.device("cuda:0")) for prof_input in prof_inputs[0]
prof_input.to(torch.device(f"{device.type}:0"))
for prof_input in prof_inputs[0]
]

with (
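
Input placement now builds device strings from device.type, so the same code serves cuda:{rank} and cpu:{rank} (PyTorch accepts an index on CPU devices). A minimal sketch of the pattern, using a hypothetical place_inputs helper:

from typing import List

import torch


def place_inputs(inputs: List[torch.Tensor], device_type: str, rank: int) -> List[torch.Tensor]:
    # rank == -1 (single-process inference) maps to index 0, as in the
    # else-branch above; otherwise each rank targets its own index.
    index = rank if rank >= 0 else 0
    target = torch.device(f"{device_type}:{index}")
    return [t.to(target) for t in inputs]


batch = [torch.randn(4, 8) for _ in range(3)]
print(place_inputs(batch, "cpu", rank=-1)[0].device)  # cpu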
@@ -713,6 +740,7 @@ def init_module_and_run_benchmark(
func_to_benchmark=func_to_benchmark,
benchmark_func_kwargs=benchmark_func_kwargs,
rank=rank,
device_type=device.type,
)

if queue is not None:
@@ -801,6 +829,7 @@ def benchmark_module(
benchmark_func_kwargs: Optional[Dict[str, Any]] = None,
pooling_configs: Optional[List[int]] = None,
variable_batch_embeddings: bool = False,
device_type: str = "cuda",
) -> List[BenchmarkResult]:
"""
Args:
@@ -884,7 +913,7 @@
callable=init_module_and_run_benchmark,
module=wrapped_module,
sharder=sharder,
device=torch.device("cuda"),
device=torch.device(device_type),
sharding_type=sharding_type,
compile_mode=compile_mode,
world_size=world_size,
@@ -903,8 +932,7 @@
res = init_module_and_run_benchmark(
module=wrapped_module,
sharder=sharder,
# TODO: GPU hardcode for now, expand if needed for heter hardware
device=torch.device("cuda:0"),
device=torch.device(device_type),
sharding_type=sharding_type,
compile_mode=compile_mode,
world_size=world_size,
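
The single-process path above previously hard-coded cuda:0; both dispatch sites now build the device from the flag. A runnable sketch of the resulting behavior on CPU, with a plain nn.Linear standing in for the sharded module:

import torch

device_type = "cpu"  # as passed via --device_type

device = torch.device(device_type)
model = torch.nn.Linear(16, 4).to(device)
out = model(torch.randn(2, 16, device=device))
print(out.device)  # cpu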
