Enable Module level Inference Benchmark for CPU device #2080

Closed · wants to merge 1 commit
torchrec/distributed/benchmark/benchmark_utils.py (140 changes: 84 additions & 56 deletions)
@@ -17,6 +17,7 @@
 import logging
 import os
 import time
+import timeit
 from dataclasses import dataclass

 from enum import Enum
@@ -37,6 +38,7 @@
 from torch.autograd.profiler import record_function
 from torchrec.distributed import DistributedModelParallel
 from torchrec.distributed.embedding_types import ShardingType
+from torchrec.distributed.global_settings import set_propogate_device

 from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
 from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
@@ -411,6 +413,7 @@ def init_argparse_and_args() -> argparse.Namespace:
     parser.add_argument("--output_dir", type=str, default="/var/tmp/torchrec-bench")
     parser.add_argument("--num_benchmarks", type=int, default=5)
     parser.add_argument("--embedding_config_json", type=str, default="")
+    parser.add_argument("--device_type", type=str, default="cuda")

     args = parser.parse_args()
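The new flag defaults to the previous CUDA behavior. A minimal standalone sketch of how it surfaces through argparse (the real parser in `init_argparse_and_args()` carries many more arguments):

```python
# Minimal sketch: the parser reduced to the one new argument.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--device_type", type=str, default="cuda")

args = parser.parse_args(["--device_type", "cpu"])  # simulates `--device_type cpu`
assert args.device_type == "cpu"
```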
@@ -436,7 +439,9 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:
         scripted_module = torch.jit.script(graph_module)
         return scripted_module

-    topology: Topology = Topology(world_size=world_size, compute_device="cuda")
+    set_propogate_device(True)
+
+    topology: Topology = Topology(world_size=world_size, compute_device=device.type)
     planner = EmbeddingShardingPlanner(
         topology=topology,
         batch_size=batch_size,
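`set_propogate_device(True)` (spelled as in torchrec's own API) enables device propagation so that sharding honors the caller's device rather than assuming CUDA, and the `Topology` now takes its compute device from the benchmark's `device` argument. A hedged sketch of the same construction for a CPU run (assumes torchrec is installed; `world_size=1` is an illustrative value):

```python
# Sketch of the patched construction path, pinned to CPU.
import torch
from torchrec.distributed.global_settings import set_propogate_device
from torchrec.distributed.planner import Topology

set_propogate_device(True)  # propagate the requested device instead of assuming CUDA

device = torch.device("cpu")
topology = Topology(world_size=1, compute_device=device.type)
print(topology.compute_device)  # -> "cpu"
```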
@@ -498,6 +503,7 @@ def benchmark(
     benchmark_func_kwargs: Optional[Dict[str, Any]],
     rank: int,
     enable_logging: bool = True,
+    device_type: str = "cuda",
 ) -> BenchmarkResult:
     max_mem_allocated: List[int] = []
     if enable_logging:
@@ -506,47 +512,64 @@
     for _input in warmup_inputs:
         model(_input)

-    if rank == -1:
-        # Reset memory for measurement, no process per rank so do all
-        for di in range(world_size):
-            torch.cuda.reset_peak_memory_stats(di)
-    else:
-        torch.cuda.reset_peak_memory_stats(rank)
+    if device_type == "cuda":
+        if rank == -1:
+            # Reset memory for measurement, no process per rank so do all
+            for di in range(world_size):
+                torch.cuda.reset_peak_memory_stats(di)
+        else:
+            torch.cuda.reset_peak_memory_stats(rank)

-    # Measure time taken for batches in bench_inputs
-    start = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)]
-    end = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)]
+    start = []
+    end = []
+    if device_type == "cuda":
+        # Measure time taken for batches in bench_inputs
+        start = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)]
+        end = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)]

     if benchmark_func_kwargs is None:
         # Need this to unwrap
         benchmark_func_kwargs = {}

-    for i in range(num_benchmarks):
-        start[i].record()
-        func_to_benchmark(model, bench_inputs, **benchmark_func_kwargs)
-        end[i].record()
+    times = []
+    if device_type == "cuda":
+        for i in range(num_benchmarks):
+            start[i].record()
+            func_to_benchmark(model, bench_inputs, **benchmark_func_kwargs)
+            end[i].record()
+    elif device_type == "cpu":
+        times = timeit.repeat(
+            lambda: func_to_benchmark(model, bench_inputs, **benchmark_func_kwargs),
+            number=1,
+            repeat=num_benchmarks,
+        )

-    if rank == -1:
-        for di in range(world_size):
-            torch.cuda.synchronize(di)
-    else:
-        torch.cuda.synchronize(rank)
+    if device_type == "cuda":
+        if rank == -1:
+            for di in range(world_size):
+                torch.cuda.synchronize(di)
+        else:
+            torch.cuda.synchronize(rank)

     # TODO: First Benchmark Run for Eager Mode produces outlier
     # Start counting after first as workaround for standard deviation
-    elapsed_time = torch.tensor(
-        [si.elapsed_time(ei) for si, ei in zip(start[1:], end[1:])]
-    )
+    if device_type == "cuda":
+        elapsed_time = torch.tensor(
+            [si.elapsed_time(ei) for si, ei in zip(start[1:], end[1:])]
+        )
+    else:
+        elapsed_time = torch.tensor(times)

-    if rank == -1:
-        # Add up all memory allocated in inference mode
-        for di in range(world_size):
-            b = torch.cuda.max_memory_allocated(di)
-            max_mem_allocated.append(b // 1024 // 1024)
-    else:
-        # Only add up memory allocated for current rank in training mode
-        b = torch.cuda.max_memory_allocated(rank)
-        max_mem_allocated.append(b // 1024 // 1024)
+    if device_type == "cuda":
+        if rank == -1:
+            # Add up all memory allocated in inference mode
+            for di in range(world_size):
+                b = torch.cuda.max_memory_allocated(di)
+                max_mem_allocated.append(b // 1024 // 1024)
+        else:
+            # Only add up memory allocated for current rank in training mode
+            b = torch.cuda.max_memory_allocated(rank)
+            max_mem_allocated.append(b // 1024 // 1024)

     if output_dir != "":
         # Only do profiling if output_dir is set
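On CPU there are no events to record, so the patch falls back to `timeit.repeat` with `number=1`, which returns one wall-clock measurement per repeat, mirroring the per-iteration list the event path builds. Two unit details worth noting from the diff itself: `timeit` reports seconds while `Event.elapsed_time` reports milliseconds, and the CUDA path drops the first measurement as an outlier while the CPU path keeps all repeats. A self-contained sketch of the CPU timing branch (toy model and inputs):

```python
# Standalone sketch of the CPU timing branch with a toy model.
import timeit

import torch

def func_to_benchmark(model, bench_inputs):
    for x in bench_inputs:
        model(x)

model = torch.nn.Linear(8, 8)
bench_inputs = [torch.randn(4, 8) for _ in range(16)]
num_benchmarks = 5

times = timeit.repeat(
    lambda: func_to_benchmark(model, bench_inputs),
    number=1,          # one call of the benchmarked function per measurement
    repeat=num_benchmarks,
)
elapsed_time = torch.tensor(times)  # seconds, one entry per benchmark run
print(elapsed_time)
```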
@@ -574,22 +597,23 @@ def trace_handler(prof) -> None:
        # - cd FlameGraph
        # - ./flamegraph.pl --title "CPU time" --countname "us." profiler.stacks > perf_viz.svg

-        with torch.profiler.profile(
-            activities=[
-                torch.profiler.ProfilerActivity.CPU,
-                torch.profiler.ProfilerActivity.CUDA,
-            ],
-            record_shapes=True,
-            profile_memory=True,
-            with_stack=True,
-            with_flops=True,
-            with_modules=True,
-            on_trace_ready=trace_handler,
-        ) as p:
-            for _input in prof_inputs:
-                with record_function("## forward ##"):
-                    model(_input)
-                p.step()
+        if device_type == "cuda":
+            with torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                record_shapes=True,
+                profile_memory=True,
+                with_stack=True,
+                with_flops=True,
+                with_modules=True,
+                on_trace_ready=trace_handler,
+            ) as p:
+                for _input in prof_inputs:
+                    with record_function("## forward ##"):
+                        model(_input)
+                    p.step()

     if rank == -1:
         for di in range(torch.cuda.device_count()):
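The trace-based profiling block is now skipped entirely for CPU runs. For reference only, a CPU-only profile (not part of this patch) would simply drop the CUDA activity:

```python
# Illustrative CPU-only profiling, NOT part of this PR; the patch skips
# profiling altogether when device_type != "cuda".
import torch

model = torch.nn.Linear(8, 8)
prof_inputs = [torch.randn(4, 8) for _ in range(3)]

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU],
    record_shapes=True,
) as p:
    for x in prof_inputs:
        model(x)
        p.step()

print(p.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```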
@@ -659,26 +683,29 @@ def init_module_and_run_benchmark(

     if rank >= 0:
         warmup_inputs_cuda = [
-            warmup_input.to(torch.device(f"cuda:{rank}"))
+            warmup_input.to(torch.device(f"{device.type}:{rank}"))
             for warmup_input in warmup_inputs[rank]
         ]
         bench_inputs_cuda = [
-            bench_input.to(torch.device(f"cuda:{rank}"))
+            bench_input.to(torch.device(f"{device.type}:{rank}"))
             for bench_input in bench_inputs[rank]
         ]
         prof_inputs_cuda = [
-            prof_input.to(torch.device(f"cuda:{rank}"))
+            prof_input.to(torch.device(f"{device.type}:{rank}"))
             for prof_input in prof_inputs[rank]
         ]
     else:
         warmup_inputs_cuda = [
-            warmup_input.to(torch.device("cuda:0")) for warmup_input in warmup_inputs[0]
+            warmup_input.to(torch.device(f"{device.type}:0"))
+            for warmup_input in warmup_inputs[0]
         ]
         bench_inputs_cuda = [
-            bench_input.to(torch.device("cuda:0")) for bench_input in bench_inputs[0]
+            bench_input.to(torch.device(f"{device.type}:0"))
+            for bench_input in bench_inputs[0]
         ]
         prof_inputs_cuda = [
-            prof_input.to(torch.device("cuda:0")) for prof_input in prof_inputs[0]
+            prof_input.to(torch.device(f"{device.type}:0"))
+            for prof_input in prof_inputs[0]
         ]

     with (
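The f-string device targeting works unchanged for both backends, since PyTorch accepts an index on CPU devices as well. A quick check:

```python
# Both backends accept the "<type>:<index>" form used above; constructing
# a cuda device object does not require a GPU (only using it does).
import torch

for device_type, rank in [("cpu", 0), ("cuda", 0)]:
    d = torch.device(f"{device_type}:{rank}")
    print(d.type, d.index)  # -> "cpu 0" and "cuda 0"
```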
@@ -713,6 +740,7 @@ def init_module_and_run_benchmark(
             func_to_benchmark=func_to_benchmark,
             benchmark_func_kwargs=benchmark_func_kwargs,
             rank=rank,
+            device_type=device.type,
         )

     if queue is not None:
@@ -801,6 +829,7 @@ def benchmark_module(
     benchmark_func_kwargs: Optional[Dict[str, Any]] = None,
     pooling_configs: Optional[List[int]] = None,
     variable_batch_embeddings: bool = False,
+    device_type: str = "cuda",
 ) -> List[BenchmarkResult]:
     """
     Args:
@@ -884,7 +913,7 @@
                 callable=init_module_and_run_benchmark,
                 module=wrapped_module,
                 sharder=sharder,
-                device=torch.device("cuda"),
+                device=torch.device(device_type),
                 sharding_type=sharding_type,
                 compile_mode=compile_mode,
                 world_size=world_size,
@@ -903,8 +932,7 @@
             res = init_module_and_run_benchmark(
                 module=wrapped_module,
                 sharder=sharder,
-                # TODO: GPU hardcode for now, expand if needed for heter hardware
-                device=torch.device("cuda:0"),
+                device=torch.device(device_type),
                 sharding_type=sharding_type,
                 compile_mode=compile_mode,
                 world_size=world_size,