Showing 17 changed files with 659 additions and 96 deletions.
PBT Tuner on NNI
===

## PBTTuner
Population Based Training (PBT) comes from [Population Based Training of Neural Networks](https://arxiv.org/abs/1711.09846v1). It is a simple asynchronous optimization algorithm that effectively uses a fixed computational budget to jointly optimize a population of models and their hyperparameters to maximize performance. Importantly, PBT discovers a schedule of hyperparameter settings rather than following the generally sub-optimal strategy of trying to find a single fixed set to use for the whole course of training.

PBTTuner initializes a population of several trials, each training for a user-specified number of epochs. After each such interval, the parameters and hyperparameters of poorly performing trials are replaced with those of better-performing trials (exploit), and the copied hyperparameters are then perturbed (explore), as in the sketch below.
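The exploit/explore cycle can be pictured with a minimal sketch. This is not NNI's actual implementation; `population` is assumed to be a list of dicts holding each trial's latest metric, checkpoint directories, and hyperparameters:

```python
import random

def pbt_step(population, perturb_factors=(0.8, 1.2)):
    """One illustrative PBT step: exploit, then explore."""
    # Rank trials by their reported metric (maximization assumed here).
    ranked = sorted(population, key=lambda t: t['metric'], reverse=True)
    cutoff = max(1, len(ranked) // 4)
    top, bottom = ranked[:cutoff], ranked[-cutoff:]
    for trial in bottom:
        donor = random.choice(top)
        # Exploit: start the next step from a better trial's checkpoint.
        trial['load_checkpoint_dir'] = donor['save_checkpoint_dir']
        # Explore: perturb the donor's continuous hyperparameters.
        trial['hyperparameters'] = {
            name: value * random.choice(perturb_factors) if isinstance(value, float) else value
            for name, value in donor['hyperparameters'].items()
        }
```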
In our implementation, the training epochs run by the trial code are regarded as one step of PBT, which differs from other tuners. At the end of each step, the PBT tuner performs exploitation and exploration, replacing some trials with new ones. This is implemented by constantly modifying the values of `load_checkpoint_dir` and `save_checkpoint_dir`: changing `load_checkpoint_dir` replaces a trial's parameters and hyperparameters, while `save_checkpoint_dir` determines where the checkpoint loaded in the next step is saved. To this end, a shared folder that is accessible to all trials is required.
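On the trial side, honoring these two directories is all that is required. Below is a minimal PyTorch sketch; the helper names and the `model.pth` file name are placeholders, and the complete example appears in the trial code further down:

```python
import os
import torch

def restore_if_available(model, load_checkpoint_dir):
    # Load weights written in the previous step (possibly by another trial).
    path = os.path.join(load_checkpoint_dir, 'model.pth')
    if os.path.isfile(path):
        model.load_state_dict(torch.load(path))

def save_for_next_step(model, save_checkpoint_dir):
    # Write weights so the next PBT step can resume from them.
    os.makedirs(save_checkpoint_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_checkpoint_dir, 'model.pth'))
```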
If the experiment is running in local mode, users can provide an argument `all_checkpoint_dir`, which serves as the base folder for `load_checkpoint_dir` and `save_checkpoint_dir` (each checkpoint directory is set to `all_checkpoint_dir/<population-id>/<step>`). By default, `all_checkpoint_dir` is `~/nni/experiments/<exp-id>/checkpoint`. If the experiment runs in non-local mode, users should provide a path in shared storage that is mounted at `all_checkpoint_dir` on the worker machines (it does not need to be accessible from the machine that runs the tuner).
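For example, assuming `all_checkpoint_dir` is accepted as a tuner `classArgs` entry as described above, a non-local experiment could point it at a mount of the shared storage (the path below is illustrative):

```yaml
tuner:
  builtinTunerName: PBTTuner
  classArgs:
    optimize_mode: maximize
    all_checkpoint_dir: /mnt/nfs/nni/checkpoint   # shared folder mounted on every worker
```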
```yaml
authorName: default
experimentName: example_mnist_pbt_tuner_pytorch
trialConcurrency: 3
maxExecDuration: 2h
maxTrialNum: 100
#choice: local, remote, pai
trainingServicePlatform: local
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
  # codeDir: ~/nni/src/sdk/pynni/nni/pbt_tuner
  # classFileName: pbt_tuner.py
  # className: PBTTuner
  builtinTunerName: PBTTuner
  classArgs:
    #choice: maximize, minimize
    optimize_mode: maximize
trial:
  command: python3 mnist.py
  codeDir: .
  gpuNum: 1
```
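Assuming NNI is installed and this configuration is saved alongside the `mnist.py` and `search_space.json` it references, the experiment can typically be launched with `nnictl create --config <config-file>`.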
```python
import argparse
import logging

import os
import nni
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms


logger = logging.getLogger('mnist_pbt_tuner_pytorch_AutoML')


class Net(nn.Module):
    def __init__(self, hidden_size):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args['log_interval'] == 0:
            logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    accuracy = 100. * correct / len(test_loader.dataset)

    logger.info('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset), accuracy))

    return accuracy


def save_checkpoint(model, checkpoint_path):
    torch.save(model.state_dict(), checkpoint_path)


def load_checkpoint(checkpoint_path):
    model_state_dict = torch.load(checkpoint_path)
    return model_state_dict


def main(args):
    use_cuda = not args['no_cuda'] and torch.cuda.is_available()

    torch.manual_seed(args['seed'])

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    data_dir = os.path.join(args['data_dir'], nni.get_trial_id())

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args['batch_size'], shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=1000, shuffle=True, **kwargs)

    hidden_size = args['hidden_size']

    model = Net(hidden_size=hidden_size).to(device)

    # Checkpoint paths supplied by the PBT tuner: load the checkpoint from the
    # previous step (possibly written by another trial), save one for the next step.
    save_checkpoint_dir = args['save_checkpoint_dir']
    save_checkpoint_path = os.path.join(save_checkpoint_dir, 'model.pth')
    load_checkpoint_path = os.path.join(args['load_checkpoint_dir'], 'model.pth')

    if os.path.isfile(load_checkpoint_path):
        model_state_dict = load_checkpoint(load_checkpoint_path)
        logger.info("loaded checkpoint from %s", load_checkpoint_path)
        logger.debug("state dict type: %s", type(model_state_dict))
        model.load_state_dict(model_state_dict)

    optimizer = optim.SGD(model.parameters(), lr=args['lr'],
                          momentum=args['momentum'])

    # epochs is the perturbation interval: the whole loop is one PBT step
    for epoch in range(1, args['epochs'] + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test_acc = test(args, model, device, test_loader)

        if epoch < args['epochs']:
            # report intermediate result
            nni.report_intermediate_result(test_acc)
            logger.debug('test accuracy %g', test_acc)
            logger.debug('Pipe send intermediate result done.')
        else:
            # report final result
            nni.report_final_result(test_acc)
            logger.debug('Final result is %g', test_acc)
            logger.debug('Send final result done.')

    if not os.path.exists(save_checkpoint_dir):
        os.makedirs(save_checkpoint_dir)
    save_checkpoint(model, save_checkpoint_path)


def get_params():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument("--data_dir", type=str,
                        default='./tmp/pytorch/mnist/input_data', help="data directory")
    parser.add_argument('--batch_size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument("--hidden_size", type=int, default=512, metavar='N',
                        help='hidden layer size (default: 512)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--log_interval', type=int, default=1000, metavar='N',
                        help='how many batches to wait before logging training status')

    parser.add_argument('--save_checkpoint_dir', type=str,
                        help='where to save checkpoint of this trial')
    parser.add_argument('--load_checkpoint_dir', type=str,
                        help='where to load the model')

    args, _ = parser.parse_known_args()
    return args


if __name__ == '__main__':
    try:
        # get parameters from tuner
        tuner_params = nni.get_next_parameter()
        logger.debug(tuner_params)
        params = vars(get_params())
        params.update(tuner_params)
        main(params)
    except Exception as exception:
        logger.exception(exception)
        raise
```
```json
{
    "batch_size": {"_type": "choice", "_value": [16, 32, 64, 128]},
    "hidden_size": {"_type": "choice", "_value": [128, 256, 512, 1024]},
    "lr": {"_type": "choice", "_value": [0.0001, 0.001, 0.01, 0.1]},
    "momentum": {"_type": "uniform", "_value": [0, 1]}
}
```
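These keys match the argparse options defined in the trial code above, so `params.update(tuner_params)` simply overrides the defaults. With PBTTuner, the generated parameters also carry the two checkpoint directories that the trial relies on. An illustrative (not literal) example of what `nni.get_next_parameter()` might return, following the `all_checkpoint_dir/<population-id>/<step>` layout described earlier:

```python
{
    "batch_size": 64,
    "hidden_size": 512,
    "lr": 0.01,
    "momentum": 0.5,
    "load_checkpoint_dir": "/home/user/nni/experiments/<exp-id>/checkpoint/3/0",
    "save_checkpoint_dir": "/home/user/nni/experiments/<exp-id>/checkpoint/3/1"
}
```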