Add timer tool to Profiler #40386

Merged · 15 commits · Mar 30, 2022
7 changes: 7 additions & 0 deletions python/paddle/fluid/dataloader/dataloader_iter.py
@@ -41,6 +41,7 @@
_DatasetKind, _IterableDatasetStopIteration, _WorkerException, \
_ResumeIteration
from .flat import _flatten_batch, _restore_batch
from paddle.profiler.timer import benchmark

__all__ = ['get_worker_info']

@@ -256,6 +257,8 @@ def __next__(self):
event_type=profiler.TracerEventType.Dataloader)
trace_event.begin()
try:
benchmark().check_if_need_record(self)
benchmark().before_reader()
Contributor

Would it be better to have before_reader call check_if_need_record automatically?

Contributor Author

Because check_if_need_record has to take a reader argument, its parameter list differs from the basic hook interfaces. To keep those interfaces consistent, this function is factored out separately.
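
A minimal sketch of the call pattern discussed in this thread, assuming a simplified Benchmark: the method names come from this diff, while the internals are illustrative rather than the actual paddle.profiler.timer implementation.

```python
import time

class Benchmark(object):
    """Illustrative stand-in for the object returned by benchmark()."""

    def __init__(self):
        self.current_reader = None
        self.reader_start = None
        self.reader_cost = 0.0

    def check_if_need_record(self, reader):
        # Takes the reader object itself, so it cannot share the
        # zero-argument signature of the hooks below.
        if self.current_reader is not reader:
            self.current_reader = reader  # new reader: start a fresh record
            self.reader_cost = 0.0

    def before_reader(self):
        self.reader_start = time.perf_counter()

    def after_reader(self):
        self.reader_cost += time.perf_counter() - self.reader_start

# Call pattern mirroring the dataloader's __next__ above:
#   benchmark().check_if_need_record(self)  # reader-specific, hence separate
#   benchmark().before_reader()
#   data = self._reader.read_next()
#   benchmark().after_reader()
```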

if in_dygraph_mode():
data = core.eager.read_next_tensor_list(
self._reader.read_next_list()[0])
@@ -283,6 +286,7 @@ def __next__(self):
data = data[0]
else:
data = self._reader.read_next()
benchmark().after_reader()

return data
except StopIteration:
@@ -708,6 +712,8 @@ def __next__(self):
event_type=profiler.TracerEventType.Dataloader)
trace_event.begin()
try:
benchmark().check_if_need_record(self)
benchmark().before_reader()
# _batches_outstanding here records the total batch data number
# in 'from after _try_put_indices to before output data'; this
# value should be _outstanding_capacity if data is not drained,
@@ -750,6 +756,7 @@ def __next__(self):
else:
data = self._reader.read_next()
self._on_output_batch()
benchmark().after_reader()
return data
except StopIteration:
if not self._persistent_workers:
72 changes: 72 additions & 0 deletions python/paddle/fluid/tests/unittests/test_newprofiler.py
@@ -19,6 +19,9 @@

import paddle
import paddle.profiler as profiler
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.io import Dataset, DataLoader


class TestProfiler(unittest.TestCase):
@@ -125,5 +128,74 @@ def my_sheduler1(num_step):
result = profiler.utils.load_profiler_result('./test_profiler_pb.pb')


class RandomDataset(Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples

def __getitem__(self, idx):
image = np.random.random([100]).astype('float32')
label = np.random.randint(0, 10 - 1, (1, )).astype('int64')
return image, label

def __len__(self):
return self.num_samples


class SimpleNet(nn.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc = nn.Linear(100, 10)

def forward(self, image, label=None):
return self.fc(image)


class TestTimerOnly(unittest.TestCase):
def test_with_dataloader(self):
def train(step_num_samples=None):
dataset = RandomDataset(20 * 4)
simple_net = SimpleNet()
opt = paddle.optimizer.SGD(learning_rate=1e-3,
parameters=simple_net.parameters())
loader = DataLoader(
dataset,
batch_size=4,
shuffle=True,
drop_last=True,
num_workers=2)
step_info = ''
p = profiler.Profiler(timer_only=True)
p.start()
for i, (image, label) in enumerate(loader()):
out = simple_net(image)
loss = F.cross_entropy(out, label)
avg_loss = paddle.mean(loss)
avg_loss.backward()
opt.minimize(avg_loss)
simple_net.clear_gradients()
p.step(num_samples=step_num_samples)
if i % 10 == 0:
step_info = p.step_info()
print("Iter {}: {}".format(i, step_info))
p.stop()
return step_info

step_info = train(step_num_samples=None)
self.assertTrue('steps/s' in step_info)
step_info = train(step_num_samples=4)
self.assertTrue('samples/s' in step_info)

def test_without_dataloader(self):
x = paddle.to_tensor(np.random.randn(10, 10))
y = paddle.to_tensor(np.random.randn(10, 10))
p = profiler.Profiler(timer_only=True)
p.start()
step_info = ''
for i in range(20):
out = x + y
p.step()
p.stop()


if __name__ == '__main__':
unittest.main()
133 changes: 131 additions & 2 deletions python/paddle/profiler/profiler.py
@@ -25,6 +25,7 @@

from .utils import RecordEvent, wrap_optimizers
from .profiler_statistic import StatisticData, _build_table, SortedKeys
from .timer import benchmark


class ProfilerState(Enum):
@@ -269,6 +270,8 @@ class Profiler:
which means profiling range [start_batch, end_batch).
on_trace_ready (Callable, optional): Callable object, serves as callback function, and takes the Profiler object as parameter, which provides a way for users to do post-processing.
This callable object will be called when ``scheduler`` returns ``ProfilerState.RECORD_AND_RETURN``. The default value is :ref:`export_chrome_tracing <api_paddle_profiler_export_chrome_tracing>` (./profiler_log/).
timer_only (bool, optional): If True, the cost of the Dataloader and of each model step is timed without profiling. Otherwise, the model is
both timed and profiled. Default: False.

Examples:
1. profiling range [2, 5).
@@ -316,14 +319,78 @@ class Profiler:
p.stop()
p.summary()

4. Use profiler to get throughput and cost of the model

.. code-block:: python
:name: code-example-timer1

import paddle
import paddle.profiler as profiler
import numpy as np
Contributor

Please use Paddle's API to create the Tensor instead of numpy.

Contributor Author

done.


class RandomDataset(paddle.io.Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples

def __getitem__(self, idx):
image = np.random.random([100]).astype('float32')
label = np.random.randint(0, 10 - 1, (1, )).astype('int64')
return image, label

def __len__(self):
return self.num_samples

class SimpleNet(paddle.nn.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc = paddle.nn.Linear(100, 10)

def forward(self, image, label=None):
return self.fc(image)

dataset = RandomDataset(20 * 4)
simple_net = SimpleNet()
opt = paddle.optimizer.SGD(learning_rate=1e-3,
parameters=simple_net.parameters())
BATCH_SIZE = 4
Contributor

This naming style isn't commonly used in Python, is it? batch_size

Contributor Author

Because this is a constant name, and Python style conventionally uses uppercase for constants.

loader = paddle.io.DataLoader(
dataset,
batch_size=BATCH_SIZE)
p = profiler.Profiler(timer_only=True)
p.start()
for i, (image, label) in enumerate(loader()):
out = simple_net(image)
loss = paddle.nn.functional.cross_entropy(out, label)
avg_loss = paddle.mean(loss)
avg_loss.backward()
opt.minimize(avg_loss)
simple_net.clear_gradients()
p.step(num_samples=BATCH_SIZE)
if i % 10 == 0:
step_info = p.step_info(unit='images')
print("Iter {}: {}".format(i, step_info))
Contributor

Please write the output as comments.

Contributor Author

done

# Since "step_info" is called at 10-iteration intervals, it prints the average
# statistics over the 10 steps between the previous call and this one.
# The values you get may be different from the following.
# Iter 0: reader_cost: 0.51946 s batch_cost: 0.66077 s ips: 6.054 images/s
# Iter 10: reader_cost: 0.00014 s batch_cost: 0.00441 s ips: 907.009 images/s
p.stop()
# The performance summary will be automatically printed when the "stop" is called.
# Reader Ratio: 2.658%
# Time Unit: s, IPS Unit: images/s
# | | avg | max | min |
# | reader_cost | 0.00011 | 0.00013 | 0.00007 |
# | batch_cost | 0.00405 | 0.00434 | 0.00326 |
# | ips | 1086.42904 | 1227.30604 | 959.92796 |
"""

def __init__(
self,
*,
targets: Optional[Iterable[ProfilerTarget]]=None,
scheduler: Union[Callable[[int], ProfilerState], tuple, None]=None,
on_trace_ready: Optional[Callable[..., Any]]=None):
on_trace_ready: Optional[Callable[..., Any]]=None,
timer_only: Optional[bool]=False):
supported_targets = _get_supported_targets()
if targets:
self.targets = set(targets)
@@ -371,6 +438,7 @@ def __init__(
self.current_state = self.scheduler(self.step_num)
self.record_event = None
self.profiler_result = None
self.timer_only = timer_only

def __enter__(self):
self.start()
@@ -399,7 +467,12 @@ def start(self):
#train()
prof.step()
prof.stop()

'''
# Timing only without profiling
benchmark().begin()
if self.timer_only:
return
# CLOSED -> self.current_state
if self.current_state == ProfilerState.READY:
self.profiler.prepare()
@@ -435,6 +508,9 @@ def stop(self):
prof.step()
prof.stop()
'''
benchmark().end()
if self.timer_only:
return
# self.current_state -> CLOSED
# In this situation, RECORD state is regarded as RECORD_AND_RETURN
if self.record_event:
@@ -451,11 +527,15 @@ def stop(self):
if self.on_trace_ready:
self.on_trace_ready(self)

def step(self):
def step(self, num_samples: Optional[int]=None):
r"""
Signals the profiler that the next profiling step has started.
Get the new ProfilerState and trigger corresponding action.

Args:
num_samples (int|None, optional): Specifies the batch size of each step of the model,
used to compute the throughput when timer_only is True. Default: None.

Examples:
.. code-block:: python
:name: code-example6
@@ -473,6 +553,9 @@ def step(self):
prof.step()
prof.stop()
"""
benchmark().step(num_samples)
if self.timer_only:
return
if self.record_event:
self.record_event.end()
self.record_event = None
@@ -485,6 +568,52 @@ def step(self):
event_type=TracerEventType.ProfileStep)
self.record_event.begin()

def step_info(self, unit=None):
r"""
Get statistics for the current step. If this function is called at fixed iteration
intervals, the result is the average over all steps between the previous call and
this one. Statistics are as follows:

1. reader_cost: the cost of loading data measured in seconds.

2. batch_cost: the cost of step measured in seconds.

3. ips (Instances Per Second): the throughput of the model, measured in `samples/s`
or another unit depending on `unit`. When `num_samples` of `step()` is None, it is
measured in `steps/s`.

Args:
unit (string, optional): The unit of the input data, used only when `num_samples`
of `step()` is specified as a number. For example, when it is `images`, the unit
of throughput is `images/s`. Default: None, in which case the unit of throughput is `samples/s`.

Returns:
string: A string representing the statistic.

Examples:
Contributor

Please add a blank line before Examples.

Contributor Author

done

.. code-block:: python
:name: code-example-timer2

import paddle.profiler as profiler
prof = profiler.Profiler(timer_only=True)
prof.start()
for iter in range(20):
#train()
prof.step()
if iter % 10 == 0:
print("Iter {}: {}".format(iter, prof.step_info()))
# The example does not call the DataLoader, so there is no "reader_cost".
# Iter 0: batch_cost: 0.00001 s ips: 86216.623 steps/s
# Iter 10: batch_cost: 0.00001 s ips: 103645.034 steps/s
prof.stop()
# Time Unit: s, IPS Unit: steps/s
# | | avg | max | min |
# | batch_cost | 0.00000 | 0.00002 | 0.00000 |
# | ips | 267846.19437 | 712030.38727 | 45134.16662 |
"""
if unit is None:
unit = 'samples'
return benchmark().step_info(unit)

def _trigger_action(self):
if self.previous_state == ProfilerState.CLOSED:
if self.current_state == ProfilerState.READY: # CLOSED -> READY
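
For reference, a rough sketch of the interval-averaging semantics that step_info() documents, assuming a simplified timer; this is hypothetical and not the actual paddle.profiler.timer implementation.

```python
import time

class IntervalTimer(object):
    """Hypothetical sketch of the step()/step_info() averaging semantics."""

    def __init__(self):
        self.batch_costs = []   # per-step wall-clock costs since the last report
        self.num_samples = 0    # samples accumulated since the last report
        self.step_start = time.perf_counter()

    def step(self, num_samples=None):
        now = time.perf_counter()
        self.batch_costs.append(now - self.step_start)
        if num_samples is not None:
            self.num_samples += num_samples
        self.step_start = now

    def step_info(self, unit='samples'):
        if not self.batch_costs:
            return 'batch_cost: n/a'
        total = sum(self.batch_costs)
        batch_cost = total / len(self.batch_costs)
        if self.num_samples:            # num_samples was given: <unit>/s
            ips, ips_unit = self.num_samples / total, unit
        else:                           # otherwise throughput is steps/s
            ips, ips_unit = 1.0 / batch_cost, 'steps'
        self.batch_costs, self.num_samples = [], 0   # reset the interval
        return 'batch_cost: {:.5f} s ips: {:.3f} {}/s'.format(
            batch_cost, ips, ips_unit)
```

Usage mirrors the documented behavior: call step() once per iteration, and call step_info() every N iterations to get averages over that interval.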