
add PBT tuner (#2139)
RayMeng8 authored Mar 30, 2020
1 parent c261146 commit a82b4a3
Showing 17 changed files with 659 additions and 96 deletions.
29 changes: 29 additions & 0 deletions docs/en_US/Tuner/BuiltinTuner.md
@@ -21,6 +21,7 @@ Currently, we support the following algorithms:
|[__BOHB__](#BOHB)|BOHB is a follow-up work to Hyperband. It targets the weakness of Hyperband that new configurations are generated randomly without leveraging finished trials. For the name BOHB, HB means Hyperband, BO means Bayesian Optimization. BOHB leverages finished trials by building multiple TPE models, a proportion of new configurations are generated through these models. [Reference Paper](https://arxiv.org/abs/1807.01774)|
|[__GP Tuner__](#GPTuner)|Gaussian Process Tuner is a sequential model-based optimization (SMBO) approach with Gaussian Process as the surrogate. [Reference Paper](https://papers.nips.cc/paper/4443-algorithms-for-hyper-parameter-optimization.pdf), [Github Repo](https://github.com/fmfn/BayesianOptimization)|
|[__PPO Tuner__](#PPOTuner)|PPO Tuner is a Reinforcement Learning tuner based on PPO algorithm. [Reference Paper](https://arxiv.org/abs/1707.06347)|
|[__PBT Tuner__](#PBTTuner)|PBT Tuner is a simple asynchronous optimization algorithm which effectively utilizes a fixed computational budget to jointly optimize a population of models and their hyperparameters to maximize performance. [Reference Paper](https://arxiv.org/abs/1711.09846v1)|

## Usage of Built-in Tuners

@@ -453,6 +454,34 @@ tuner:
  classArgs:
    optimize_mode: maximize
```

<a name="PBTTuner"></a>

![](https://placehold.it/15/1589F0/000000?text=+) `PBT Tuner`

> Built-in Tuner Name: **PBTTuner**

**Suggested scenario**

Population Based Training (PBT) bridges and extends parallel search methods and sequential optimization methods. Its wall-clock run time is no greater than that of a single optimization process, it does not require sequential runs, and it can also use fewer computational resources than naive search methods, so it is effective when you want to save computational resources and time. Note also that PBT produces a schedule of hyperparameters rather than a single configuration; if you do not need one specific configuration but simply want good results, this tuner is a good choice. Be aware that our implementation involves managing checkpoint storage locations: a trial is treated as several training epochs, so the loading and saving of checkpoints must be handled in the trial code (see the sketch below), which differs from other tuners. In addition, if the experiment is not in local mode, users should provide a path on shared storage that all trials can access. You could try it on a very simple task, such as the [mnist-pbt-tuner-pytorch](https://github.com/microsoft/nni/tree/master/examples/trials/mnist-pbt-tuner-pytorch) example. [See details](./PBTTuner.md)
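
The following is a minimal sketch of the checkpoint handling PBTTuner expects in trial code. The `load_checkpoint_dir`/`save_checkpoint_dir` parameter names follow the mnist-pbt-tuner-pytorch example in this commit; `build_model`, `train_one_epoch`, `evaluate`, and `epochs_per_step` are hypothetical placeholders for your own code:

```python
import os
import nni
import torch

params = nni.get_next_parameter()   # includes load/save checkpoint dirs set by PBTTuner
model = build_model(params)         # placeholder: construct your model from hyperparameters

load_path = os.path.join(params['load_checkpoint_dir'], 'model.pth')
if os.path.isfile(load_path):
    # Resume from the checkpoint PBTTuner selected (possibly another trial's).
    model.load_state_dict(torch.load(load_path))

for epoch in range(epochs_per_step):  # placeholder: one PBT step of training
    train_one_epoch(model, params)
metric = evaluate(model)              # placeholder: compute the reported metric

os.makedirs(params['save_checkpoint_dir'], exist_ok=True)
torch.save(model.state_dict(),
           os.path.join(params['save_checkpoint_dir'], 'model.pth'))
nni.report_final_result(metric)
```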

**classArgs requirements:**

* **optimize_mode** (*'maximize' or 'minimize'*) - If 'maximize', the tuner tries to maximize metrics; if 'minimize', it tries to minimize metrics.
* **all_checkpoint_dir** (*str, optional, default = None*) - Directory where trials load and save checkpoints. If not specified, it defaults to "~/nni/checkpoint/<exp-id>". Note that if the experiment is not in local mode, users should provide a path on shared storage that all trials can access.
* **population_size** (*int, optional, default = 10*) - Number of trials in the population, i.e., run in each step. In our implementation, one step means running each trial for a number of training epochs set by the user.
* **factors** (*tuple, optional, default = (1.2, 0.8)*) - Factors used to perturb hyperparameters during exploration.
* **fraction** (*float, optional, default = 0.2*) - Fraction of trials treated as bottom and top performers when deciding which trials to replace (see the sketch below).
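
To make `factors` and `fraction` concrete, here is a rough illustration (not the tuner's exact code) of what one perturbation looks like:

```python
import random

factors = (1.2, 0.8)
fraction = 0.2
population_size = 10

# The bottom `fraction` of trials (here 2 of 10) are replaced with copies of
# the top `fraction`, and the copied numerical hyperparameters are scaled by
# a randomly chosen factor:
num_replaced = int(population_size * fraction)  # 2
lr = 0.01
perturbed_lr = lr * random.choice(factors)      # 0.012 or 0.008
```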

**Usage example**

```yaml
# config.yml
tuner:
  builtinTunerName: PBTTuner
  classArgs:
    optimize_mode: maximize
```

## **Reference and Feedback**
* [Report a bug](https://github.com/microsoft/nni/issues/new?template=bug-report.md) for this feature on GitHub;
* [File a feature or improvement request](https://github.com/microsoft/nni/issues/new?template=enhancement.md) for this feature on GitHub;
12 changes: 12 additions & 0 deletions docs/en_US/Tuner/PBTTuner.md
@@ -0,0 +1,12 @@
PBT Tuner on NNI
===

## PBTTuner

Population Based Training (PBT) comes from [Population Based Training of Neural Networks](https://arxiv.org/abs/1711.09846v1). It's a simple asynchronous optimization algorithm which effectively utilizes a fixed computational budget to jointly optimize a population of models and their hyperparameters to maximize performance. Importantly, PBT discovers a schedule of hyperparameter settings rather than following the generally sub-optimal strategy of trying to find a single fixed set to use for the whole course of training.

PBTTuner initializes a population of several trials. Users set a specific number of training epochs per step; after each step, the parameters and hyperparameters of trials with poor metrics are replaced with those of better-performing trials (exploit), and the copied hyperparameters are then perturbed (explore), as sketched below.
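
A rough sketch of one exploit-and-explore pass, assuming maximize mode (an illustration of the idea rather than PBTTuner's exact code):

```python
import random

def exploit_and_explore(population, fraction=0.2, factors=(1.2, 0.8)):
    """population: list of (metric, hyperparams) pairs, one per trial."""
    population.sort(key=lambda p: p[0])           # ascending metric (maximize mode)
    cutoff = max(1, int(len(population) * fraction))
    top = population[-cutoff:]
    for i in range(cutoff):                       # the bottom trials are replaced
        _, good_params = random.choice(top)
        new_params = dict(good_params)            # exploit: copy a top trial's hyperparameters
        for name, value in new_params.items():
            if isinstance(value, float):          # explore: perturb numerical values
                new_params[name] = value * random.choice(factors)
        population[i] = (None, new_params)        # metric unknown until the next step
    return population
```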

In our implementation, a fixed number of training epochs in the trial code is regarded as one step of PBT, which differs from other tuners. At the end of each step, the PBT tuner performs exploitation and exploration, replacing some trials with new ones. This is implemented by constantly modifying the values of `load_checkpoint_dir` and `save_checkpoint_dir`: changing `load_checkpoint_dir` replaces a trial's parameters and hyperparameters, while `save_checkpoint_dir` saves a checkpoint that will be loaded in the next step. To this end, we need a shared folder that is accessible to all trials.

If the experiment is running in local mode, users can provide the argument `all_checkpoint_dir`, which serves as the base folder for `load_checkpoint_dir` and `save_checkpoint_dir` (each trial's `checkpoint_dir` is set to `all_checkpoint_dir/<population-id>/<step>`; see the sketch below). By default, `all_checkpoint_dir` is `~/nni/experiments/<exp-id>/checkpoint`. If the experiment is in non-local mode, users should instead provide a path in a shared storage folder that is mounted at `all_checkpoint_dir` on the worker machines (it does not need to be accessible on the machine that runs the tuner).
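
As an illustration of this layout (a sketch following the documented `all_checkpoint_dir/<population-id>/<step>` pattern; the `/shared/checkpoints` path is a hypothetical example):

```python
import os

def checkpoint_dir(all_checkpoint_dir, population_id, step):
    # all_checkpoint_dir/<population-id>/<step>, as documented above
    return os.path.join(all_checkpoint_dir, str(population_id), str(step))

# At step k a trial loads from a step k-1 directory: its own if it survived,
# or that of the better trial it copies if it was exploited. Either way it
# saves into its own step k directory.
load_dir = checkpoint_dir('/shared/checkpoints', population_id=3, step=1)
save_dir = checkpoint_dir('/shared/checkpoints', population_id=3, step=2)
```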
1 change: 1 addition & 0 deletions docs/en_US/builtin_tuner.rst
@@ -23,3 +23,4 @@ Tuner receives metrics from `Trial` to evaluate the performance of a specific pa
Hyperband <Tuner/HyperbandAdvisor>
BOHB <Tuner/BohbAdvisor>
PPO Tuner <Tuner/PPOTuner>
PBT Tuner <Tuner/PBTTuner>
1 change: 1 addition & 0 deletions examples/trials/mnist-pbt-tuner-pytorch/.gitignore
@@ -0,0 +1 @@
tmp
Empty file.
22 changes: 22 additions & 0 deletions examples/trials/mnist-pbt-tuner-pytorch/config.yml
@@ -0,0 +1,22 @@
authorName: default
experimentName: example_mnist_pbt_tuner_pytorch
trialConcurrency: 3
maxExecDuration: 2h
maxTrialNum: 100
#choice: local, remote, pai
trainingServicePlatform: local
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
  # codeDir: ~/nni/src/sdk/pynni/nni/pbt_tuner
  # classFileName: pbt_tuner.py
  # className: PBTTuner
  builtinTunerName: PBTTuner
  classArgs:
    #choice: maximize, minimize
    optimize_mode: maximize
trial:
  command: python3 mnist.py
  codeDir: .
  gpuNum: 1
187 changes: 187 additions & 0 deletions examples/trials/mnist-pbt-tuner-pytorch/mnist.py
@@ -0,0 +1,187 @@
import argparse
import logging

import os
import nni
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms


logger = logging.getLogger('mnist_pbt_tuner_pytorch_AutoML')

class Net(nn.Module):
    def __init__(self, hidden_size):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args['log_interval'] == 0:
            logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    accuracy = 100. * correct / len(test_loader.dataset)

    logger.info('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset), accuracy))

    return accuracy


def save_checkpoint(model, checkpoint_path):
    torch.save(model.state_dict(), checkpoint_path)


def load_checkpoint(checkpoint_path):
    model_state_dict = torch.load(checkpoint_path)
    return model_state_dict


def main(args):
    use_cuda = not args['no_cuda'] and torch.cuda.is_available()

    torch.manual_seed(args['seed'])

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    data_dir = os.path.join(args['data_dir'], nni.get_trial_id())

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args['batch_size'], shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=1000, shuffle=True, **kwargs)

    hidden_size = args['hidden_size']

    model = Net(hidden_size=hidden_size).to(device)

    # PBTTuner passes the checkpoint directories in the hyperparameters:
    # load the checkpoint the tuner selected, save one for the next step.
    save_checkpoint_dir = args['save_checkpoint_dir']
    save_checkpoint_path = os.path.join(save_checkpoint_dir, 'model.pth')
    load_checkpoint_path = os.path.join(args['load_checkpoint_dir'], 'model.pth')

    if os.path.isfile(load_checkpoint_path):
        model_state_dict = load_checkpoint(load_checkpoint_path)
        logger.info("Loading checkpoint from %s", load_checkpoint_path)
        model.load_state_dict(model_state_dict)

    optimizer = optim.SGD(model.parameters(), lr=args['lr'],
                          momentum=args['momentum'])

    # one epoch is the perturbation interval
    for epoch in range(1, args['epochs'] + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test_acc = test(args, model, device, test_loader)

        if epoch < args['epochs']:
            # report intermediate result
            nni.report_intermediate_result(test_acc)
            logger.debug('test accuracy %g', test_acc)
            logger.debug('Pipe send intermediate result done.')
        else:
            # report final result
            nni.report_final_result(test_acc)
            logger.debug('Final result is %g', test_acc)
            logger.debug('Send final result done.')

    if not os.path.exists(save_checkpoint_dir):
        os.makedirs(save_checkpoint_dir)
    save_checkpoint(model, save_checkpoint_path)


def get_params():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument("--data_dir", type=str,
                        default='./tmp/pytorch/mnist/input_data', help="data directory")
    parser.add_argument('--batch_size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument("--hidden_size", type=int, default=512, metavar='N',
                        help='hidden layer size (default: 512)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--log_interval', type=int, default=1000, metavar='N',
                        help='how many batches to wait before logging training status')

    parser.add_argument('--save_checkpoint_dir', type=str,
                        help='where to save checkpoint of this trial')
    parser.add_argument('--load_checkpoint_dir', type=str,
                        help='where to load the model')

    args, _ = parser.parse_known_args()
    return args


if __name__ == '__main__':
    try:
        # get parameters from tuner
        tuner_params = nni.get_next_parameter()
        logger.debug(tuner_params)
        params = vars(get_params())
        params.update(tuner_params)
        main(params)
    except Exception as exception:
        logger.exception(exception)
        raise
6 changes: 6 additions & 0 deletions examples/trials/mnist-pbt-tuner-pytorch/search_space.json
@@ -0,0 +1,6 @@
{
    "batch_size": {"_type": "choice", "_value": [16, 32, 64, 128]},
    "hidden_size": {"_type": "choice", "_value": [128, 256, 512, 1024]},
    "lr": {"_type": "choice", "_value": [0.0001, 0.001, 0.01, 0.1]},
    "momentum": {"_type": "uniform", "_value": [0, 1]}
}
2 changes: 1 addition & 1 deletion src/nni_manager/rest_server/restValidationSchemas.ts
Expand Up @@ -178,7 +178,7 @@ export namespace ValidationSchemas {
        gpuIndices: joi.string()
    }),
    tuner: joi.object({
        builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch', 'NetworkMorphism', 'MetisTuner', 'GPTuner', 'PPOTuner'),
        builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch', 'NetworkMorphism', 'MetisTuner', 'GPTuner', 'PPOTuner', 'PBTTuner'),
        codeDir: joi.string(),
        classFileName: joi.string(),
        className: joi.string(),
4 changes: 3 additions & 1 deletion src/sdk/pynni/nni/constants.py
@@ -15,7 +15,8 @@
    'Curvefitting': 'nni.curvefitting_assessor.curvefitting_assessor',
    'MetisTuner': 'nni.metis_tuner.metis_tuner',
    'GPTuner': 'nni.gp_tuner.gp_tuner',
    'PPOTuner': 'nni.ppo_tuner.ppo_tuner'
    'PPOTuner': 'nni.ppo_tuner.ppo_tuner',
    'PBTTuner': 'nni.pbt_tuner.pbt_tuner'
}

ClassName = {
@@ -30,6 +31,7 @@
    'MetisTuner':'MetisTuner',
    'GPTuner':'GPTuner',
    'PPOTuner': 'PPOTuner',
    'PBTTuner': 'PBTTuner',

    'Medianstop': 'MedianstopAssessor',
    'Curvefitting': 'CurvefittingAssessor'
