
Merge remote-tracking branch 'origin/dev'
lRomul committed Apr 25, 2024
2 parents a62c30d + 5a1924f commit c2efec9
Showing 7 changed files with 171 additions and 37 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -39,8 +39,8 @@ pip install -U git+https://github.com/lRomul/argus.git@dev
Simple image classification example with `create_model` from [pytorch-image-models](https://github.com/rwightman/pytorch-image-models):

```python
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize

import timm
@@ -100,6 +100,7 @@ if __name__ == "__main__":
```

You can find more examples [here](https://pytorch-argus.readthedocs.io/en/latest/examples.html).
Additional guides on how to customize and use argus components can be found in the [Guides](https://pytorch-argus.readthedocs.io/en/latest/guides.html) section.


## Why this name, Argus?
4 changes: 2 additions & 2 deletions argus/engine/engine.py
@@ -18,8 +18,8 @@


class EventEnum(Enum):
"""Base class for engine events. User defined custom events should also
inherit this class. Example of creating custom events you can find
"""Base class for engine events. User-defined custom events should also
inherit this class. An example of creating custom events is available
`here <https://github.com/lRomul/argus/blob/master/examples/custom_events.py>`_.
"""

9 changes: 9 additions & 0 deletions docs/source/api_reference/callbacks/lr_schedulers.rst
@@ -81,3 +81,12 @@ OneCycleLR
:members:

PyTorch docs on :class:`torch.optim.lr_scheduler.OneCycleLR`.

LRScheduler
-----------

Base learning rate scheduler callback. It can be used as a wrapper to adapt PyTorch or other custom
learning rate schedulers.

.. autoclass:: LRScheduler
:members:
2 changes: 1 addition & 1 deletion docs/source/examples.rst
@@ -16,7 +16,7 @@ Basic examples
Advanced examples
-----------------

* `CIFAR with DPP, mixed precision and gradient accumulation. <https://github.com/lRomul/argus/blob/master/examples/cifar_advanced.py>`_
* `CIFAR with DDP, mixed precision and gradient accumulation. <https://github.com/lRomul/argus/blob/master/examples/cifar_advanced.py>`_

Single GPU training:

184 changes: 154 additions & 30 deletions docs/source/guides.rst
@@ -147,7 +147,7 @@ The simplest use case allows loading a model with saved parameters and compo
However, the model loading process may require customizations; some cases are provided below.

1. Load the model to a specified device.
1. Load the model to a specific device.
Just provide the desired device name or a list of devices.

.. code:: python
Expand All @@ -156,9 +156,10 @@ However, the model loading process may require customizations; some cases are pr
model = load_model('/path/to/model/file', device='cuda:0')
The feature is helpful if one wants to load the model to a specific device for training or inference
and also to load the model on a machine that does not have the device, which was specified before the
model file was saved. For example, if the model was saved with ``device=='cuda:1'``,
while the target machine is equipped with the only GPU, so, ``device=='cuda:0'`` is the only valid option.
and also to load the model on a machine that does not have the device, which was used before the
model file was saved. For example, if a model was saved with ``device='cuda:1'`` but the target machine
only has one GPU, one would need to load the model on that GPU. In this case, the device
should be specified as ``device='cuda:0'``, as it is the only valid GPU option.
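
A list of devices can be passed as well. A minimal sketch, assuming the target machine has two GPUs:

.. code:: python

    from argus import load_model

    model = load_model('/path/to/model/file', device=['cuda:0', 'cuda:1'])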

.. note::

@@ -192,13 +193,86 @@ However, the model loading process may require customizations; some cases are pr
model = load_model('/path/to/model/file',
                   prediction_transform={'device': my_device},
                   device=my_device)
4. Partial weights loading and manipulation.
Sometimes it is necessary to load only some of the model's weights, for example,
to reuse a pretrained backbone while utilising new heads, or to load a subset of weights
from a saved model. This also applies to cases where the pretrained model was trained outside of
argus and some of its weights need to be reused. In such situations,
it is possible to perform arbitrary operations on the model or optimizer state dict during the loading process.

To do this, define a function that takes the original state dicts
and updates them as needed; then pass the function to :meth:`argus.model.load_model`
as the ``change_state_dict_func`` argument.

.. code:: python
from typing import Optional

from argus import load_model


def update_state_dict(nn_state_dict: dict,
                      optimizer_state_dict: Optional[dict] = None):
    # TODO custom operations on the state dicts
    return nn_state_dict, optimizer_state_dict


model = load_model('/path/to/model/file',
                   change_state_dict_func=update_state_dict)
In order to change some weights in an already created model, you can manipulate
the model's state dict directly and then load it using :meth:`torch.nn.Module.load_state_dict`:

.. code:: python
from argus import Model
model: Model = ... # The model to be manipulated
nn_state_dict = model.get_nn_module().state_dict()
nn_state_dict = ... # Perform required operations on the state dict
model.get_nn_module().load_state_dict(nn_state_dict)
5. Model import.
Sometimes it is required to load a model that is not a typical PyTorch argus model
and cannot be loaded with :func:`torch.load`, for example, because it was trained
using another framework or saved in a different format. In such cases, one can implement a converter
loading function that takes the path to the model file as input, reads the file, and converts
it to an appropriate state dictionary. The function should then be passed to
:meth:`argus.model.load_model` as the ``state_load_func`` argument, as in the sketch below.
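
A minimal sketch of such a loader. The pickle format here is a hypothetical stand-in for any
format that :func:`torch.load` cannot read, and the conversion step is an assumption to adapt
to the actual checkpoint layout:

.. code:: python

    import pickle

    from argus import load_model


    def pickle_state_load_func(file_path):
        # Hypothetical case: the checkpoint was saved with plain pickle,
        # so torch.load cannot read it directly.
        with open(file_path, 'rb') as file:
            state = pickle.load(file)
        # If the saved object differs from what argus expects,
        # convert it to an appropriate state dictionary here.
        return state


    model = load_model('/path/to/model/file',
                       state_load_func=pickle_state_load_func)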

.. seealso::
* For more information see the :func:`argus.model.load_model` documentation.
* More real-world examples of how to use `load_model` are available
`here <https://github.com/lRomul/argus/blob/master/examples/load_model.py>`_.

.. _model_export:

Model export
------------

.. _custom metrics:
:meth:`argus.model.Model.get_nn_module` allows getting the raw PyTorch ``nn.Module`` from an argus model.
This can be useful, for instance, for converting the model into another format for optimised inference.

The example below shows how to get the ``nn.Module`` and convert it to ONNX format with
a dynamic batch size using :func:`torch.onnx.export`.

.. code:: python
import torch

from argus import load_model

# Assuming the model has one input and one output.
model = load_model('/path/to/model/file', device='cpu', loss=None,
                   optimizer=None, prediction_transform=None)
nn_module = model.get_nn_module()
sample_input = torch.ones((1, 3, 224, 224))  # Model input tensor for batch_size=1

torch.onnx.export(nn_module, sample_input, '/path/to/save/onnx/file',
                  input_names=['input_0'], output_names=['output_0'],
                  dynamic_axes={'input_0': {0: 'batch_size'},
                                'output_0': {0: 'batch_size'}})
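
To sanity-check the exported file, one can run it with ``onnxruntime``; a sketch, assuming the
``onnxruntime`` package is installed and the model was exported as above:

.. code:: python

    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession('/path/to/save/onnx/file',
                                   providers=['CPUExecutionProvider'])
    # A batch of two samples checks that the dynamic batch axis works.
    batch = np.ones((2, 3, 224, 224), dtype=np.float32)
    outputs = session.run(None, {'input_0': batch})
    print(outputs[0].shape)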
.. _custom_metrics:

Custom metrics
--------------
@@ -233,10 +307,12 @@ The first attribute specifies the name of the evaluation metric, while the secon
(`max`) or a lower value (`min`) means improvement for this metric.

The code below demonstrates a top-K accuracy metric, which implements the required methods.
:class:`argus.utils.AverageMeter` is used to compute the average metric value over the predictions.

.. code-block:: python
from argus.metrics import Metric
from argus.utils import AverageMeter
class TopKAccuracy(Metric):
@@ -251,26 +327,24 @@ The code below demonstrates a top-K accuracy metric, which implements the requir
def __init__(self, k: int = 5):
self.k = k
self.correct = 0
self.count = 0
self.accuracy_meter = AverageMeter()
# Parametrized name allows having several instances of the metric with different k values
self.name = f'top_{self.k}_accuracy'
def reset(self):
self.correct = 0
self.count = 0
self.accuracy_meter.reset()
def update(self, step_output: dict):
indices = torch.topk(step_output['prediction'], k=self.k, dim=1)[1]
target = step_output['target'].unsqueeze(1)
correct = torch.any(indices == target, dim=1)
self.correct += torch.sum(correct).item()
self.count += correct.shape[0]
n_correct = torch.sum(torch.any(indices == target, dim=1)).item()
n_items = target.shape[0]
self.accuracy_meter.update(n_correct, n=n_items)
def compute(self) -> float:
if self.count == 0:
if self.accuracy_meter.count == 0:
raise RuntimeError('Must be at least one example for computation')
return self.correct / self.count
return self.accuracy_meter.average
In some more advanced use cases, it may be required to create a custom metric to
@@ -289,9 +363,10 @@ correct answer was present among the top-K predictions.
from argus.engine import State
from argus.metrics import Metric
from argus.utils import AverageMeter
class TopKAccuracy(Metric):
class TopKAccuracyRank(Metric):
"""Calculate the top-K accuracy for multiclass classification.
It also reports the average rank of the correct top-K predictions.
@@ -304,34 +379,83 @@
def __init__(self, k: int = 5):
self.k = k
self.correct = 0
self.rank = 0
self.count = 0
self.accuracy_meter = AverageMeter()
self.rank_meter = AverageMeter()
self.name = f'top_{self.k}_accuracy'
def reset(self):
self.correct = 0
self.rank = 0
self.count = 0
self.accuracy_meter.reset()
self.rank_meter.reset()
def update(self, step_output: dict):
indices = torch.topk(step_output['prediction'], k=self.k, dim=1)[1]
target = step_output['target'].unsqueeze(1)
correct = torch.any(indices == target, dim=1)
rank = torch.nonzero(indices == target)[:, 1]
self.correct += torch.sum(correct).item()
self.rank += torch.sum(rank).item()
self.count += correct.shape[0]
n_correct = torch.sum(torch.any(indices == target, dim=1)).item()
rank_sum = torch.sum(torch.nonzero(indices == target)[:, 1]).item()
n_items = target.shape[0]
self.accuracy_meter.update(n_correct, n=n_items)
self.rank_meter.update(rank_sum, n=n_items)
def compute(self) -> float:
if self.count == 0:
if self.accuracy_meter.count == 0:
raise RuntimeError('Must be at least one example for computation')
return self.correct / self.count
return self.accuracy_meter.average
def epoch_complete(self, state: State):
with torch.no_grad():
accuracy = self.compute()
rank = self.rank / self.count + 1.0 # +1.0 because ranks are 1-indexed
rank = self.rank_meter.average + 1.0 # +1.0 because ranks are 1-indexed
name_prefix = f"{state.phase}_" if state.phase else ''
state.metrics[f'{name_prefix}{self.name}'] = accuracy
state.metrics[f'{name_prefix}rank_{self.k}'] = rank
state.metrics[f'{name_prefix}rank_{self.k}'] = rank
.. _custom_callbacks:

Custom callbacks
----------------

Custom callbacks can be implemented in a similar way to custom metrics. A custom
callback class should inherit :class:`argus.callbacks.Callback` and redefine the methods
triggered by the required callback actions, such as ``epoch_complete`` or ``iteration_start``,
as in the sketch below. See details and an example in the :class:`argus.callbacks.Callback` documentation.
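
A minimal sketch of a custom callback; the timing logic and the ``epoch_time`` metric name
are illustrative assumptions rather than part of the argus API:

.. code-block:: python

    import time

    from argus.callbacks import Callback


    class EpochTimer(Callback):
        """Hypothetical callback that records the duration of each epoch."""

        def epoch_start(self, state):
            self.start_time = time.monotonic()

        def epoch_complete(self, state):
            # Store the elapsed time so it appears among the epoch metrics.
            state.metrics['epoch_time'] = time.monotonic() - self.start_time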

It is also possible to define custom events to trigger a custom callback action at any specific
moment of the training or validation loop. This requires registering the necessary custom events
in a subclass of :class:`argus.engine.EventEnum` and then raising them with ``state.engine.raise_event``.
Raising an event triggers every callback that implements a handler method for it.
See details in the `example <https://github.com/lRomul/argus/blob/master/examples/custom_events.py>`_ code; a schematic sketch follows.
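
A schematic sketch of the pieces involved, assuming that ``train_step`` receives the engine state
and that handler methods are named after the event values, as in the linked example. The event,
the placeholder network, and the counter callback are hypothetical:

.. code-block:: python

    import torch.nn as nn

    from argus import Model
    from argus.callbacks import Callback
    from argus.engine import EventEnum


    class CustomEvents(EventEnum):
        # Hypothetical event raised right after the backward pass.
        BACKWARD_COMPLETE = 'backward_complete'


    class MyModel(Model):
        nn_module = nn.Linear  # Hypothetical placeholder network

        def train_step(self, batch, state) -> dict:
            # ... forward pass, loss computation, loss.backward() ...
            state.engine.raise_event(CustomEvents.BACKWARD_COMPLETE)
            # ... optimizer step, prediction transform ...
            return {}


    class BackwardCounter(Callback):
        def __init__(self):
            self.num_backward_calls = 0

        # The handler name matches the event value 'backward_complete'.
        def backward_complete(self, state):
            self.num_backward_calls += 1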

.. _LR_schedulers:

Learning rate schedulers
------------------------

Argus learning rate schedulers can be used to adjust the learning rate during the training process.
There are many types provided with *argus*; for details, see :doc:`api_reference/callbacks/lr_schedulers`.
Once created, a scheduler should be added to the list of callbacks provided to :meth:`argus.model.Model.fit` as the
``callbacks`` argument.

The schedulers are implemented as special callbacks, inheriting from :class:`argus.callbacks.LRScheduler`.
The class can be used to create custom schedulers or adapt a PyTorch :class:`torch.optim.lr_scheduler.LRScheduler` scheduler.

The following shows an example of how to use :class:`argus.callbacks.LRScheduler`:

.. code-block:: python
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import ConstantLR

from argus.callbacks.lr_schedulers import LRScheduler


def get_lr_scheduler(optimizer: Optimizer):
    scheduler = ConstantLR(optimizer, factor=0.1, total_iters=2)
    return scheduler


lr_scheduler = LRScheduler(get_lr_scheduler)
model.fit(...,
          callbacks=[lr_scheduler])
A similar approach can be used to combine several schedulers with :class:`torch.optim.lr_scheduler.SequentialLR`
or :class:`torch.optim.lr_scheduler.ChainedScheduler`; a sketch follows. See also the
`sequential_lr_scheduler.py <https://github.com/lRomul/argus/blob/master/examples/sequential_lr_scheduler.py>`_ example.
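
A hedged sketch of such a combination; the warmup and annealing schedulers and their parameters
are illustrative choices, not taken from the linked example:

.. code-block:: python

    from torch.optim.optimizer import Optimizer
    from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

    from argus.callbacks.lr_schedulers import LRScheduler


    def get_lr_scheduler(optimizer: Optimizer):
        # Linear warmup for the first 5 epochs, then cosine annealing.
        warmup = LinearLR(optimizer, start_factor=0.1, total_iters=5)
        cosine = CosineAnnealingLR(optimizer, T_max=95)
        return SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[5])


    lr_scheduler = LRScheduler(get_lr_scheduler)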
4 changes: 2 additions & 2 deletions docs/source/quickstart.rst
@@ -192,5 +192,5 @@ Argus allows managing different combinations of pipeline parts.

* Override methods of :class:`argus.model.Model`. For example, :meth:`argus.model.Model.train_step`
and :meth:`argus.model.Model.val_step`. See :ref:`train_and_val_steps` guide for details.
* Create a custom :class:`argus.callbacks.Callback`.
* Use a custom :class:`argus.metrics.Metric`.
* Create a custom :class:`argus.callbacks.Callback`. See :ref:`custom_callbacks` guide.
* Implement a custom :class:`argus.metrics.Metric`. See :ref:`custom_metrics` guide.
2 changes: 1 addition & 1 deletion examples/README.md
@@ -11,7 +11,7 @@

## Advanced examples

* CIFAR with DPP, mixed precision and gradient accumulation [cifar_advanced.py](cifar_advanced.py).
* CIFAR with DDP, mixed precision and gradient accumulation [cifar_advanced.py](cifar_advanced.py).

Single GPU training:

