Add percentiles, optional logging (#1909)
Summary:
Pull Request resolved: #1909

Add dataclass methods to compute runtime and memory percentiles. Also make logging optional, which is useful when calling benchmark() from other benchmarking tooling.

Reviewed By: PaulZhang12

Differential Revision: D56432475

fbshipit-source-id: 84e808d28703283fb427dd2cfe8100489f672ebb
dstaay-fb authored and facebook-github-bot committed Apr 22, 2024
1 parent 7d6f3f4 commit cbde0b7
Showing 1 changed file with 13 additions and 3 deletions.
torchrec/distributed/benchmark/benchmark_utils.py
@@ -32,6 +32,8 @@
     Union,
 )
 
+import numpy as np
+
 import torch
 from torch import multiprocessing as mp
 from torch.autograd.profiler import record_function
@@ -102,10 +104,16 @@ class CompileMode(Enum):
 class BenchmarkResult:
     "Class for holding results of benchmark runs"
     short_name: str
-    elapsed_time: torch.Tensor
-    max_mem_allocated: List[int]
+    elapsed_time: torch.Tensor  # milliseconds
+    max_mem_allocated: List[int]  # megabytes
     rank: int = -1
 
+    def runtime_percentile(self, percentile: int = 50) -> torch.Tensor:
+        return np.percentile(self.elapsed_time, percentile)
+
+    def max_mem_percentile(self, percentile: int = 50) -> torch.Tensor:
+        return np.percentile(self.max_mem_allocated, percentile)
+
 
 class ECWrapper(torch.nn.Module):
     """
@@ -482,9 +490,11 @@ def benchmark(
     func_to_benchmark: Any,
     benchmark_func_kwargs: Optional[Dict[str, Any]],
     rank: int,
+    enable_logging: bool = True,
 ) -> BenchmarkResult:
     max_mem_allocated: List[int] = []
-    logger.info(f" BENCHMARK_MODEL[{name}]:\n{model}")
+    if enable_logging:
+        logger.info(f" BENCHMARK_MODEL[{name}]:\n{model}")
 
     for _input in warmup_inputs:
         model(_input)
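
A hedged sketch of the new flag (the diff truncates benchmark()'s full parameter list, so every argument below other than func_to_benchmark, benchmark_func_kwargs, rank, and enable_logging is an assumption based on names used in the function body):

result = benchmark(
    name="sparse_fwd",            # assumed parameter, referenced in the body
    model=model,                  # assumed parameter, referenced in the body
    warmup_inputs=warmup_inputs,  # assumed parameter, referenced in the body
    func_to_benchmark=train_step,  # hypothetical callable
    benchmark_func_kwargs=None,
    rank=0,
    enable_logging=False,  # suppress the BENCHMARK_MODEL log line for external tooling
)

Defaulting enable_logging to True keeps existing callers' behavior unchanged.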
