Skip to content

Commit

Permalink
Add optimizers and schedules to RTD and updated the corresponding par…
Browse files Browse the repository at this point in the history
…t in the website (deepspeedai#799)

* add optimizers and schedules to rtd

* update ds website and fix links

* add optimizers and schedules to rtd

* update ds website and fix links

* add flops profiler to rtd

* fix

Co-authored-by: Shaden Smith <[email protected]>
  • Loading branch information
cli99 and Shaden Smith authored Mar 11, 2021
1 parent f79b7b0 commit 46ee7da
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 7 deletions.
6 changes: 0 additions & 6 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@

class DeepSpeedFlopsProfilerConfig(object):
def __init__(self, param_dict):
"""
docstring
"""
super(DeepSpeedFlopsProfilerConfig, self).__init__()

self.enabled = None
Expand All @@ -27,9 +24,6 @@ def __init__(self, param_dict):
self._initialize(flops_profiler_dict)

def _initialize(self, flops_profiler_dict):
"""
docstring
"""
self.enabled = get_scalar_param(flops_profiler_dict,
FLOPS_PROFILER_ENABLED,
FLOPS_PROFILER_ENABLED_DEFAULT)
Expand Down
56 changes: 55 additions & 1 deletion flops_profiler/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,34 @@ class FlopsProfiler(object):
"""Measures the latency, number of estimated floating point operations and parameters of each module in a PyTorch model.
The flops-profiler profiles the forward pass of a PyTorch model and prints the model graph with the measured profile attached to each module. It shows how latency, flops and parameters are spent in the model and which modules or layers could be the bottleneck. It also outputs the names of the top k modules in terms of aggregated latency, flops, and parameters at depth l with k and l specified by the user. The output profile is computed for each batch of input.
The DeepSpeed flops profiler can be used with the DeepSpeed runtime or as a standalone package.
When using DeepSpeed for model training, the flops profiler can be configured in the deepspeed_config file and no user code change is required.
If using the profiler as a standalone package, one imports the flops_profiler package and uses the APIs.
Here is an example for usage in a typical training workflow:
.. code-block:: python
model = Model()
prof = FlopsProfiler(model)
for step, batch in enumerate(data_loader):
if step == profile_step:
prof.start_profile()
loss = model(batch)
if step == profile_step:
flops = prof.get_total_flops(as_string=True)
params = prof.get_total_params(as_string=True)
prof.print_model_profile(profile_step=profile_step)
prof.end_profile()
loss.backward()
optimizer.step()
To profile a trained model in inference, use the `get_model_profile` API.
Args:
object (torch.nn.Module): The PyTorch model to profile.
Expand Down Expand Up @@ -118,6 +146,9 @@ def get_total_flops(self, as_string=False):
Args:
as_string (bool, optional): whether to output the flops as string. Defaults to False.
Returns:
The number of multiply-accumulate operations of the model forward pass.
"""
total_flops = get_module_flops(self.model)
return macs_to_string(total_flops) if as_string else total_flops
Expand All @@ -127,6 +158,9 @@ def get_total_duration(self, as_string=False):
Args:
as_string (bool, optional): whether to output the duration as string. Defaults to False.
Returns:
The latency of the model forward pass.
"""
total_duration = self.model.__duration__
return duration_to_string(total_duration) if as_string else total_duration
Expand All @@ -136,6 +170,9 @@ def get_total_params(self, as_string=False):
Args:
as_string (bool, optional): whether to output the parameters as string. Defaults to False.
Returns:
The number of parameters in the model.
"""
return params_to_string(
self.model.__params__) if as_string else self.model.__params__
Expand All @@ -146,6 +183,12 @@ def print_model_profile(self,
top_modules=3,
detailed=True):
"""Prints the model graph with the measured profile attached to each module.
Args:
profile_step (int, optional): The global training step at which to profile. Note that warm up steps are needed for accurate time measurement.
module_depth (int, optional): The depth of the model at which to print the aggregated module information. When set to -1, it prints information on the innermost modules (with the maximum depth).
top_modules (int, optional): Limits the aggregated profile output to the number of top modules specified.
detailed (bool, optional): Whether to print the detailed model profile.
"""

total_flops = self.get_total_flops()
Expand Down Expand Up @@ -219,7 +262,7 @@ def del_extra_repr(module):
"\n------------------------------ Detailed Profile ------------------------------"
)
print(
"Each module profile is listed after its name in the follwing order: \nnumber of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency)."
"Each module profile is listed after its name in the following order: \nnumber of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency)."
)
print(
"Note: \n1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs(or latency) and the sum of its submodules'.\n2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throught.\n"
Expand Down Expand Up @@ -749,6 +792,14 @@ def get_model_profile(
):
"""Returns the total MACs and parameters of a model.
Example:
.. code-block:: python
model = torchvision.models.alexnet()
batch_size = 256
macs, params = get_model_profile(model=model, input_res=(batch_size, 3, 224, 224))
Args:
model ([torch.nn.Module]): the PyTorch model to be profiled.
input_res (tuple): input shape or input to the input_constructor
Expand All @@ -760,6 +811,9 @@ def get_model_profile(
warm_up (int, optional): the number of warm-up steps before measuring the latency of each module. Defaults to 1.
as_string (bool, optional): whether to print the output as string. Defaults to True.
ignore_modules ([type], optional): the list of modules to ignore during profiling. Defaults to None.
Returns:
The number of multiply-accumulate operations (MACs) and parameters in the model.
"""
assert type(input_res) is tuple
assert len(input_res) >= 1
Expand Down

0 comments on commit 46ee7da

Please sign in to comment.