[Retiarii] Rewrite trainer with PyTorch Lightning #3359

Customize A New Trainer
=======================

Trainers are necessary to evaluate the performance of newly explored models. In the NAS scenario, this further divides into two use cases:

1. **Single-arch trainers**: trainers that are used to train and evaluate one single model.
2. **One-shot trainers**: trainers that handle training and searching simultaneously, from an end-to-end perspective.

Single-arch trainers
--------------------

With PyTorch-Lightning
^^^^^^^^^^^^^^^^^^^^^^

It's recommended to write training code in PyTorch-Lightning style, that is, to write a LightningModule that defines all elements needed for training (e.g., loss function, optimizer) and to define a trainer that takes (optional) dataloaders to execute the training. Before that, please read the `documentation of PyTorch-Lightning <https://pytorch-lightning.readthedocs.io/>`__ to learn its basic concepts and components.

In practice, a new training module in NNI should inherit ``nni.retiarii.trainer.pytorch.lightning.LightningModule``, which has a ``set_model`` method that will be called after ``__init__`` to save the candidate model (generated by a strategy) as ``self.model``. The rest of the process (like ``training_step``) should be the same as writing any other lightning module. Trainers should also communicate with strategies via two API calls (``nni.report_intermediate_result`` for periodical metrics and ``nni.report_final_result`` for final metrics), added in ``on_validation_epoch_end`` and ``teardown`` respectively.
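
For intuition, ``set_model`` behaves roughly as follows (a conceptual sketch, not the actual NNI source):

.. code-block:: python

    import pytorch_lightning as pl

    class LightningModule(pl.LightningModule):
        def set_model(self, model):
            # Called by the framework after __init__: the candidate model
            # produced by the strategy becomes available as self.model.
            self.model = model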

An example is as follows:

.. code-block:: python

    import nni
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from nni.retiarii import blackbox_module  # import path assumed for the decorator below
    from nni.retiarii.trainer.pytorch.lightning import LightningModule  # please import this one

    @blackbox_module
    class AutoEncoder(LightningModule):
        def __init__(self):
            super().__init__()
            self.decoder = nn.Sequential(
                nn.Linear(3, 64),
                nn.ReLU(),
                nn.Linear(64, 28*28)
            )

        def forward(self, x):
            embedding = self.model(x)  # let's search for the encoder
            return embedding

        def training_step(self, batch, batch_idx):
            # training_step defines the train loop.
            # It is independent of forward.
            x, y = batch
            x = x.view(x.size(0), -1)
            z = self.model(x)  # model is the one that is searched for
            x_hat = self.decoder(z)
            loss = F.mse_loss(x_hat, x)
            # Logging to TensorBoard by default
            self.log('train_loss', loss)
            return loss

        def validation_step(self, batch, batch_idx):
            x, y = batch
            x = x.view(x.size(0), -1)
            z = self.model(x)
            x_hat = self.decoder(z)
            loss = F.mse_loss(x_hat, x)
            self.log('val_loss', loss)

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            return optimizer

        def on_validation_epoch_end(self):
            # report intermediate metrics to the strategy after every validation epoch
            nni.report_intermediate_result(self.trainer.callback_metrics['val_loss'].item())

        def teardown(self, stage):
            if stage == 'fit':
                # report the final metric once fitting is done
                nni.report_final_result(self.trainer.callback_metrics['val_loss'].item())

Then, users need to wrap everything (including the LightningModule, trainer and dataloaders) into a ``Lightning`` object, and pass this object into a Retiarii experiment.

.. code-block:: python

    import nni.retiarii.trainer.pytorch.lightning as pl
    from nni.retiarii.experiment.pytorch import RetiariiExperiment

    lightning = pl.Lightning(AutoEncoder(),
                             pl.Trainer(max_epochs=10),
                             train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
                             val_dataloaders=pl.DataLoader(test_dataset, batch_size=100))
    experiment = RetiariiExperiment(base_model, lightning, mutators, strategy)
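
The experiment can then be launched like any other Retiarii experiment. A sketch, assuming the ``RetiariiExeConfig`` API from the NNI examples of the same release (the field values below are illustrative):

.. code-block:: python

    from nni.retiarii.experiment.pytorch import RetiariiExeConfig

    exp_config = RetiariiExeConfig('local')
    exp_config.experiment_name = 'autoencoder search'  # hypothetical settings
    exp_config.trial_concurrency = 2
    exp_config.max_trial_number = 20
    experiment.run(exp_config, 8081)  # web UI is served on port 8081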

With FunctionalTrainer
^^^^^^^^^^^^^^^^^^^^^^

There is another way to customize a new trainer with functional APIs, which provides more flexibility. Users only need to write a fit function that wraps everything. This function takes one positional argument (the model) and possibly additional keyword arguments. In this way, users get everything under their control, but expose less information to the framework, and thus fewer opportunities for possible optimization. An example is as follows:
.. code-block::python | ||
|
||
    import nni
    from torch.utils.data import DataLoader
    from nni.retiarii.trainer import FunctionalTrainer
    from nni.retiarii.experiment.pytorch import RetiariiExperiment

    def fit(model, dataloader):
        train(model, dataloader)       # train the candidate model
        acc = test(model, dataloader)  # then evaluate it
        nni.report_final_result(acc)

    trainer = FunctionalTrainer(fit, dataloader=DataLoader(foo, bar))
    experiment = RetiariiExperiment(base_model, trainer, mutators, strategy)
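
Here ``train`` and ``test`` are ordinary user-defined helpers. A minimal sketch of what they could look like (the optimizer, loss and metric are illustrative choices, not part of the Retiarii API):

.. code-block:: python

    import torch
    import torch.nn.functional as F

    def train(model, dataloader, num_epochs=10):
        # plain PyTorch training loop over the candidate model
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        model.train()
        for _ in range(num_epochs):
            for x, y in dataloader:
                optimizer.zero_grad()
                loss = F.cross_entropy(model(x), y)
                loss.backward()
                optimizer.step()

    def test(model, dataloader):
        # plain accuracy, which fit() reports as the final result
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in dataloader:
                correct += (model(x).argmax(dim=1) == y).sum().item()
                total += y.size(0)
        return correct / total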

One-shot trainers
-----------------

One-shot trainers should inherit ``nni.retiarii.trainer.BaseOneShotTrainer``, and need to implement the ``fit()`` method (used to conduct the fitting and searching process) and the ``export()`` method (used to return the searched best architecture).

Writing a one-shot trainer is very different from writing single-arch trainers. First of all, there are no more restrictions on the arguments of the init method; any Python arguments are acceptable. Secondly, the model fed into one-shot trainers might be a model with Retiarii-specific modules, such as LayerChoice and InputChoice. Such a model cannot directly forward-propagate, and trainers need to decide how to handle those modules.

A typical example is DartsTrainer, where learnable parameters are used to combine multiple choices in LayerChoice. Retiarii provides easy-to-use utility functions for module replacement, namely ``replace_layer_choice`` and ``replace_input_choice``. A simplified example is as follows:

.. code-block:: python

    from nni.retiarii.trainer.pytorch import BaseOneShotTrainer
    from nni.retiarii.trainer.pytorch.utils import replace_layer_choice, replace_input_choice
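
The rest of the example is omitted here. To give a feel for the shape of such a trainer, below is a minimal sketch assuming ``replace_layer_choice`` takes a callback that maps each ``LayerChoice`` to its replacement and returns the (name, replacement) pairs; ``DartsLayerChoice``, ``SimplifiedDartsTrainer`` and all hyperparameters are illustrative, not the actual NNI implementation:

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    from nni.retiarii.trainer.pytorch import BaseOneShotTrainer
    from nni.retiarii.trainer.pytorch.utils import replace_layer_choice

    class DartsLayerChoice(nn.Module):
        # Differentiable stand-in for a LayerChoice: a softmax-weighted
        # mixture of the candidate ops, as in DARTS.
        def __init__(self, layer_choice):
            super().__init__()
            self.ops = nn.ModuleList(layer_choice)  # assumes LayerChoice iterates over its candidates
            self.alpha = nn.Parameter(torch.randn(len(self.ops)) * 1e-3)

        def forward(self, x):
            weights = F.softmax(self.alpha, dim=-1)
            return sum(w * op(x) for w, op in zip(weights, self.ops))

    class SimplifiedDartsTrainer(BaseOneShotTrainer):
        def __init__(self, model, loss_fn, train_loader, num_epochs=10):
            self.model = model
            self.loss_fn = loss_fn
            self.train_loader = train_loader
            self.num_epochs = num_epochs
            # Swap every LayerChoice for its differentiable version; keep the
            # replacements around so that export() can read their alphas.
            self.choices = replace_layer_choice(model, lambda lc: DartsLayerChoice(lc))
            self.optimizer = torch.optim.SGD(model.parameters(), lr=0.025)

        def fit(self):
            # For brevity, weights and architecture parameters are optimized
            # jointly; real DARTS alternates them on separate data splits.
            for _ in range(self.num_epochs):
                for x, y in self.train_loader:
                    self.optimizer.zero_grad()
                    loss = self.loss_fn(self.model(x), y)
                    loss.backward()
                    self.optimizer.step()

        def export(self):
            # Pick the candidate with the largest architecture weight per choice.
            return {name: module.alpha.argmax().item() for name, module in self.choices}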