[GRAPH EXECUTOR,VM] Add benchmarking function to graph executor and vm #8807

Merged · 6 commits · Aug 26, 2021
Diff shown: changes from 3 commits
59 changes: 59 additions & 0 deletions python/tvm/contrib/graph_executor.py
@@ -320,3 +320,62 @@ def __getitem__(self, key):
The key to the module.
"""
return self.module[key]

def benchmark(self, device, func_name="run", repeat=5, number=5, min_repeat_ms=None, **kwargs):
"""Calculate runtime of a function by repeatedly calling it.

Use this function to get an accurate measurement of the runtime of a function. The function
is run multiple times in order to account for variability in measurements, processor speed
or other external factors. Mean, median, standard deviation, min and max runtime are all
reported. On GPUs, CUDA and ROCm specifically, special on-device timers are used so that
synchronization and data transfer operations are not counted towards the runtime. This allows
for fair comparison of runtimes across different functions and models.

The benchmarking loop looks approximately like so:

.. code-block:: python

for r in range(repeat):
time_start = now()
for n in range(number):
func_name()
time_end = now()
total_times.append((time_end - time_start)/number)


Parameters
----------
func_name : str
The function to benchmark

repeat : int
Number of times to run the outer loop of the timing code (see above). The output will
contain `repeat` number of datapoints.

number : int
Number of times to run the inner loop of the timing code. This inner loop is run in
between the timer starting and stopping. In order to amortize any timing overhead,
`number` should be increased when the runtime of the function is small (less than 1/10
of a millisecond).

min_repeat_ms : Optional[float]
If set, the inner loop will be run until it takes longer than `min_repeat_ms`
milliseconds. This can be used to ensure that the function is run enough to get an
accurate measurement.

kwargs : Dict[str, Object]
Named arguments to the function. These are cached before running timing code, so that
data transfer costs are not counted in the runtime.

Returns
-------
timing_results : ProfileResult
Runtimes of the function. Use `.mean` to access the mean runtime, use `.results` to
access the individual runtimes.
"""

Contributor: I think this should be BenchmarkResult.
min_repeat_ms = 0 if min_repeat_ms is None else min_repeat_ms
if kwargs:
self.set_input(**kwargs)
return self.module.time_evaluator(
func_name, device, repeat=repeat, number=number, min_repeat_ms=min_repeat_ms
)()
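
For context, a minimal usage sketch of the new API (not part of the diff; it mirrors the test added to test_backend_graph_executor.py at the bottom of this PR):

import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor
from tvm.relay.testing import mlp

# Build a small example model and create a graph executor for it.
mod, params = mlp.get_workload(1)
lib = relay.build(mod, target="llvm", params=params)
dev = tvm.cpu()
exe = graph_executor.create(lib.get_graph_json(), lib.lib, dev)

# Inputs passed as kwargs are set once, before the timed loop starts.
data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32"))
result = exe.benchmark(dev, data=data, repeat=10, number=5)
print(result)       # formatted mean/median/max/min/std summary
print(result.mean)  # mean runtime in seconds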
27 changes: 7 additions & 20 deletions python/tvm/driver/tvmc/model.py
@@ -46,14 +46,15 @@
import os
import tarfile
import json
from typing import Optional, Union, List, Dict, Callable, TextIO
from typing import Optional, Union, Dict, Callable, TextIO
import numpy as np

import tvm
import tvm.contrib.cc
from tvm import relay
from tvm.contrib import utils
from tvm.relay.backend.executor_factory import GraphExecutorFactoryModule
from tvm.runtime.module import BenchmarkResult

try:
from tvm.micro import export_model_library_format
@@ -371,14 +372,14 @@ def import_package(self, package_path: str):
class TVMCResult(object):
"""A class that stores the results of tvmc.run and provides helper utilities."""

def __init__(self, outputs: Dict[str, np.ndarray], times: List[float]):
def __init__(self, outputs: Dict[str, np.ndarray], times: BenchmarkResult):
"""Create a convenience wrapper around the output of tvmc.run

Parameters
----------
outputs : dict
Outputs dictionary mapping the name of the output to its numpy value.
times : list of float
times : BenchmarkResult
The execution times measured by the time evaluator in seconds to produce outputs.
"""
self.outputs = outputs
@@ -390,29 +391,15 @@ def format_times(self):
This has the effect of producing a small table that looks like:
.. code-block::
Execution time summary:
mean (ms) max (ms) min (ms) std (ms)
0.14310 0.16161 0.12933 0.01004
mean (ms) median (ms) max (ms) min (ms) std (ms)
0.14310 0.14310 0.16161 0.12933 0.01004

Returns
-------
str
A formatted string containing the statistics.
"""

# timestamps
mean_ts = np.mean(self.times) * 1000
std_ts = np.std(self.times) * 1000
max_ts = np.max(self.times) * 1000
min_ts = np.min(self.times) * 1000

header = "Execution time summary:\n{0:^10} {1:^10} {2:^10} {3:^10}".format(
"mean (ms)", "max (ms)", "min (ms)", "std (ms)"
)
stats = "{0:^10.2f} {1:^10.2f} {2:^10.2f} {3:^10.2f}".format(
mean_ts, max_ts, min_ts, std_ts
)

return "%s\n%s\n" % (header, stats)
return str(self.times)
Contributor: I think it's probably a bad idea to change the semantics of time_evaluator, which is no doubt used in many places where it's expected to return the raw list of times. I would recommend moving the BenchmarkResult to the benchmark() functions and not changing time_evaluator itself.

Contributor Author: This is just the printing code in TVMC; it is not modifying time_evaluator at all. The only change to the printing code is that it now outputs the median.


def get_output(self, name: str):
"""A helper function to grab one of the outputs by name.
8 changes: 2 additions & 6 deletions python/tvm/driver/tvmc/runner.py
@@ -421,12 +421,8 @@ def run_module(
# This print is intentional
print(report)

# create the module time evaluator (returns a function)
timer = module.module.time_evaluator("run", dev, number=number, repeat=repeat)
# call the evaluator function to invoke the module and save execution times
prof_result = timer()
# collect a list of execution times from the profiling results
times = prof_result.results
# call the benchmarking function of the executor
times = module.benchmark(dev, number=number, repeat=repeat)

logger.debug("Collecting the output tensors.")
num_outputs = module.get_num_outputs()
69 changes: 63 additions & 6 deletions python/tvm/runtime/module.py
@@ -20,7 +20,8 @@
import os
import ctypes
import struct
from collections import namedtuple
from typing import Sequence
import numpy as np

import tvm._ffi
from tvm._ffi.base import _LIB, check_call, c_str, string_types, _RUNTIME_ONLY
@@ -30,8 +31,65 @@
from . import _ffi_api


# profile result of time evaluator
ProfileResult = namedtuple("ProfileResult", ["mean", "results"])
class BenchmarkResult:
"""Runtimes from benchmarking"""

def __init__(self, results: Sequence[float]):
"""Construct a new BenchmarkResult from a sequence of runtimes.

Parameters
----------
results : Sequence[float]
Raw times from benchmarking

Attributes
----------
min : float
Minimum runtime in seconds of all results.
mean : float
Mean runtime in seconds of all results. Note that this mean is not
necessarily statistically correct as it is the mean of mean
runtimes.
median : float
Median runtime in seconds of all results. Note that this is not necessarily
statistically correct as it is the median of mean runtimes.
max : float
Maximum runtime in seconds of all results.
std : float
Standard deviation in seconds of runtimes. Note that this is not necessarily
correct as it is the std of mean runtimes.
results : Sequence[float]
The collected runtimes (in seconds). This may be a series of mean runtimes if
the benchmark was run with `number` > 1.
"""

Contributor: I think this needs more explanation. Currently a BenchmarkResult object contains no information on the benchmark parameters that were used, and it would be best (IMHO) to avoid there being surprises in terms of the interpretation of the results based on how the object was created. My recommendation would be to either fully document the behavior of what it means to benchmark with `number` > 1 or ensure that the BenchmarkResult object itself contains the benchmark parameters used.
self.results = results
self.mean = np.mean(self.results)
self.std = np.std(self.results)
self.median = np.median(self.results)
self.min = np.min(self.results)
self.max = np.max(self.results)

def __repr__(self):
return "BenchmarkResult(min={}, mean={}, median={}, max={}, std={}, results={})".format(
self.min, self.mean, self.median, self.max, self.std, self.results
)

def __str__(self):
return """Execution time summary:
{:^12} {:^12} {:^12} {:^12} {:^12}
{:^12.2f} {:^12.2f} {:^12.2f} {:^12.2f} {:^12.2f}
""".format(
"mean (ms)",
"median (ms)",
"max (ms)",
"min (ms)",
"std (ms)",
self.mean * 1000,
self.median * 1000,
self.max * 1000,
self.min * 1000,
self.std * 1000,
)

Contributor: 2 significant digits is not really enough for models that run quite fast. Would recommend using .4 instead.

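To make the derived statistics concrete, here is a tiny illustration (the numbers are made up, not from the PR):

# Three hypothetical per-repeat runtimes, in seconds.
res = BenchmarkResult([0.0101, 0.0110, 0.0098])
assert res.min <= res.median <= res.max
print(res.mean)   # arithmetic mean of the three values, in seconds
print(res)        # the "Execution time summary" table, printed in milliseconds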

class Module(object):
@@ -209,7 +267,7 @@ def time_evaluator(self, func_name, dev, number=10, repeat=1, min_repeat_ms=0, f
Returns
-------
ftimer : function
The function that takes same argument as func and returns a ProfileResult.
The function that takes same argument as func and returns a BenchmarkResult.
The BenchmarkResult reports `repeat` time costs in seconds.
"""
try:
@@ -230,8 +288,7 @@ def evaluator(*args):
blob = feval(*args)
fmt = "@" + ("d" * repeat)
results = struct.unpack(fmt, blob)
mean = sum(results) / float(repeat)
return ProfileResult(mean=mean, results=results)
return BenchmarkResult(results)
Contributor: Does time_evaluator itself have a unit test that needs to be updated? If not, would it make sense to add a quick one?

Contributor Author: time_evaluator was not changed, so it makes no sense to create/modify a unit test for it.

Contributor: Um, unless I'm reading this diff incorrectly (sorry if I'm getting confused by GitHub diffs!) this is a change to time_evaluator. Am I confused?

Contributor Author: Oh yes, this is the Python binding for time_evaluator. The output is essentially the same, so there is no need to modify the tests.

Contributor: Well, OK, but the existing tests for time_evaluator seem kind of weak. Here we're adding a new class (BenchmarkResult) which calculates a number of new statistics, but there are no tests for it, meaning a regression in this behavior (either a change to time_evaluator or BenchmarkResult) would not be caught by CI.

At minimum I'd recommend a unit test for BenchmarkResult() itself, which is easy to write. Not trying to be too nitpicky, but I do feel like we should be constantly improving our test coverage, and adding a new class without corresponding tests tends to be a red flag for me :-)

WDYT?

Contributor Author: To clarify, BenchmarkResult is just a renaming of ProfileResult (which already existed) plus a constructor. It is tested in an ad hoc manner in a variety of places. I've added some tests for its constructor.


return evaluator
except NameError:
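The constructor tests the author mentions are not included in this three-commit snapshot; a plausible sketch of such a test (values chosen so the statistics come out exact):

from tvm.runtime.module import BenchmarkResult

def test_benchmark_result_statistics():
    # [1, 2, 2, 5] gives exact mean/median/std values to compare against.
    r = BenchmarkResult([1.0, 2.0, 2.0, 5.0])
    assert r.mean == 2.5
    assert r.median == 2.0
    assert r.min == 1.0
    assert r.max == 5.0
    assert r.std == 1.5
    assert list(r.results) == [1.0, 2.0, 2.0, 5.0]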
64 changes: 64 additions & 0 deletions python/tvm/runtime/vm.py
@@ -507,3 +507,67 @@
The input index. -1 will be returned if the given input name is not found.
"""
return self._get_input_index(input_name, func_name)

def benchmark(
self, device, *args, func_name="main", repeat=5, number=5, min_repeat_ms=None, **kwargs
):
"""Calculate runtime of a function by repeatedly calling it.

Use this function to get an accurate measurement of the runtime of a function. The function
is run multiple times in order to account for variability in measurements, processor speed
or other external factors. Mean, median, standard deviation, min and max runtime are all
reported. On GPUs, CUDA and ROCm specifically, special on-device timers are used so that
synchronization and data transfer operations are not counted towards the runtime. This allows
for fair comparison of runtimes across different functions and models.

The benchmarking loop looks approximately like so:

.. code-block:: python

for r in range(repeat):
time_start = now()
for n in range(number):
func_name()
time_end = now()
total_times.append((time_end - time_start)/number)


Parameters
----------
func_name : str
The function to benchmark

repeat : int
Number of times to run the outer loop of the timing code (see above). The output will
contain `repeat` number of datapoints.

number : int
Number of times to run the inner loop of the timing code. This inner loop is run in
between the timer starting and stopping. In order to amortize any timing overhead,
`number` should be increased when the runtime of the function is small (less than 1/10
of a millisecond).

min_repeat_ms : Optional[float]
If set, the inner loop will be run until it takes longer than `min_repeat_ms`
milliseconds. This can be used to ensure that the function is run enough to get an
accurate measurement.

args : Sequence[Object]
Arguments to the function. These are cached before running timing code, so that data
transfer costs are not counted in the runtime.

kwargs : Dict[str, Object]
Named arguments to the function. These are cached like `args`.

Returns
-------
timing_results : ProfileResult
Runtimes of the function. Use `.mean` to access the mean runtime, use `.results` to
access the individual runtimes.
"""
min_repeat_ms = 0 if min_repeat_ms is None else min_repeat_ms
if args or kwargs:
self.set_input(func_name, *args, **kwargs)
return self.module.time_evaluator(
"invoke", device, repeat=repeat, number=number, min_repeat_ms=min_repeat_ms
)(func_name)
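
An analogous usage sketch for the VM path (again not part of the diff; it mirrors the test added to test_vm.py below):

import numpy as np
import tvm
from tvm import runtime
from tvm.relay import vm
from tvm.relay.testing import mlp

# Compile a small example model for the VM and wrap it in a VirtualMachine.
mod, params = mlp.get_workload(1)
lib = vm.compile(mod, target="llvm", params=params)
dev = tvm.cpu()
exe = runtime.vm.VirtualMachine(lib, dev)

# Positional args are bound via set_input("main", ...) before the timed
# "invoke" loop runs, so data-transfer cost is excluded from the timings.
data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32"))
result = exe.benchmark(dev, data, func_name="main", repeat=10, number=5)
print(result)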
5 changes: 3 additions & 2 deletions src/runtime/rpc/rpc_module.cc
@@ -417,8 +417,9 @@ TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator")
<< "Cannot find " << f_preproc_name << " in the global function";
f_preproc = *pf_preproc;
}
return WrapTimeEvaluator(m.GetFunction(name, false), dev, number, repeat, min_repeat_ms,
f_preproc);
PackedFunc pf = m.GetFunction(name, false);
CHECK(pf != nullptr) << "Cannot find " << name << " in the global registry";
return WrapTimeEvaluator(pf, dev, number, repeat, min_repeat_ms, f_preproc);
}
} else {
auto* pf = runtime::Registry::Get(name);
13 changes: 13 additions & 0 deletions tests/python/relay/test_backend_graph_executor.py
@@ -23,6 +23,7 @@
from tvm.contrib import graph_executor
from tvm.relay.op import add
import tvm.testing
from tvm.relay.testing import mlp

# @tq, @jr should we put this in testing ns?
def check_rts(expr, args, expected_result, mod=None):
@@ -322,5 +323,17 @@ def test_graph_executor_api():
assert mod.get_input_index("Invalid") == -1


@tvm.testing.requires_llvm
def test_benchmark():
mod, params = mlp.get_workload(1)
lib = relay.build(mod, target="llvm", params=params)
exe = graph_executor.create(lib.get_graph_json(), lib.lib, tvm.cpu())
data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32"))
result = exe.benchmark(tvm.cpu(), data=data, func_name="run", repeat=2, number=1)
assert result.mean == result.median
assert result.mean > 0
assert len(result.results) == 2

Contributor: A better test here would be to use unittest.patch() on the time_evaluator function itself so it always returns deterministic times, and hence you can fully test the behavior of the benchmark() function by itself, independent of the actual module or time_evaluator behavior.
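
A rough sketch of the unittest.mock approach suggested above (hypothetical; it is not what this revision of the PR does, and the stubbed timings are invented):

from unittest import mock

import numpy as np
import pytest
import tvm
from tvm import relay
from tvm.contrib import graph_executor
from tvm.relay.testing import mlp
from tvm.runtime.module import BenchmarkResult


def test_benchmark_with_mocked_time_evaluator():
    mod, params = mlp.get_workload(1)
    lib = relay.build(mod, target="llvm", params=params)
    exe = graph_executor.create(lib.get_graph_json(), lib.lib, tvm.cpu())
    data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32"))

    # Stub out time_evaluator with fixed timings so the statistics reported
    # by benchmark() can be checked exactly, independent of the hardware.
    fixed = BenchmarkResult([0.1, 0.2, 0.3])
    with mock.patch("tvm.runtime.Module.time_evaluator", return_value=lambda: fixed):
        result = exe.benchmark(tvm.cpu(), data=data, repeat=3, number=1)

    assert result.mean == pytest.approx(0.2)
    assert list(result.results) == [0.1, 0.2, 0.3]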


if __name__ == "__main__":
pytest.main([__file__])
13 changes: 13 additions & 0 deletions tests/python/relay/test_vm.py
@@ -30,6 +30,7 @@
from tvm import rpc
import tvm.testing
from tvm.relay.transform import InferType
from tvm.relay.testing import mlp


def check_result(args, expected_result, mod=None):
@@ -955,5 +956,17 @@ def test_get_input_index():
assert vm_factory.get_input_index("invalid") == -1


@tvm.testing.requires_llvm
def test_benchmark():
mod, params = mlp.get_workload(1)
lib = vm.compile(mod, target="llvm", params=params)
exe = runtime.vm.VirtualMachine(lib, tvm.cpu())
data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32"))
result = exe.benchmark(tvm.cpu(), data, func_name="main", repeat=2, number=1)
assert result.mean == result.median
assert result.mean > 0
assert len(result.results) == 2

Contributor: See comment above on using unittest.patch() so you can test the benchmark() function directly.


if __name__ == "__main__":
pytest.main([__file__])
6 changes: 1 addition & 5 deletions tutorials/auto_scheduler/tune_network_arm.py
@@ -349,11 +349,7 @@ def tune_and_evaluate():

# Evaluate
print("Evaluate inference time cost...")
ftimer = module.module.time_evaluator("run", dev, repeat=3, min_repeat_ms=500)
prof_res = np.array(ftimer().results) * 1e3 # convert to millisecond
print(
"Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res))
)
print(module.benchmark(dev, repeat=3, min_repeat_ms=500))


# We do not run the tuning in our webpage server since the server doesn't have a Raspberry Pi,
4 changes: 1 addition & 3 deletions tutorials/auto_scheduler/tune_network_cuda.py
@@ -288,9 +288,7 @@ def run_tuning():

# Evaluate
print("Evaluate inference time cost...")
ftimer = module.module.time_evaluator("run", dev, repeat=3, min_repeat_ms=500)
prof_res = np.array(ftimer().results) * 1e3 # convert to millisecond
print("Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res)))
print(module.benchmark(dev, repeat=3, min_repeat_ms=500))


#################################################################