forked from microsoft/nni
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
pdarts implementation (export is not included) (microsoft#1730)
- Loading branch information
1 parent
d43fbe8
commit d1d10de
Showing
18 changed files
with
421 additions
and
166 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -80,6 +80,7 @@ venv.bak/ | |
|
||
# VSCode | ||
.vscode | ||
.vs | ||
|
||
# In case you place source code in ~/nni/ | ||
/experiments |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
data | ||
data |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
data/* | ||
log |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from torchvision import transforms | ||
from torchvision.datasets import CIFAR10 | ||
|
||
|
||
def get_dataset(cls):
    """
    Build the training and validation datasets for the given dataset name.

    Parameters
    ----------
    cls: str
        Name of the dataset. Only ``"cifar10"`` is supported.

    Returns
    -------
    tuple
        ``(dataset_train, dataset_valid)``. The training split is augmented with a padded
        random crop and a random horizontal flip; both splits are converted to tensors
        and normalized. Data is stored under ``./data`` and downloaded if missing.

    Raises
    ------
    NotImplementedError
        If ``cls`` is not a supported dataset name.
    """
    # Validate first so that an unsupported name fails immediately with a message
    # that identifies the offending value, before any transform is constructed.
    if cls != "cifar10":
        raise NotImplementedError("Unsupported dataset: {!r}".format(cls))

    # Per-channel statistics handed to Normalize for both splits.
    mean = [0.49139968, 0.48215827, 0.44653124]
    std = [0.24703233, 0.24348505, 0.26158768]

    # Augmentation is applied to the training split only.
    augment = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]

    train_transform = transforms.Compose(augment + normalize)
    valid_transform = transforms.Compose(normalize)

    dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
    dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
    return dataset_train, dataset_valid
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from argparse import ArgumentParser | ||
|
||
import datasets | ||
import torch | ||
import torch.nn as nn | ||
import nni.nas.pytorch as nas | ||
from nni.nas.pytorch.pdarts import PdartsTrainer | ||
from nni.nas.pytorch.darts import CnnNetwork, CnnCell | ||
|
||
|
||
def accuracy(output, target, topk=(1,)):
    """
    Computes the precision@k for the specified values of k.

    Parameters
    ----------
    output: torch.Tensor
        Prediction scores with samples along dim 0 and classes along dim 1.
    target: torch.Tensor
        Ground-truth labels, either class indices (1-D) or one-hot / score rows
        (more than one dimension), in which case the arg-max along dim 1 is used.
    topk: tuple of int
        Values of k for which the top-k accuracy is computed.

    Returns
    -------
    dict
        Maps ``"acc{k}"`` to the fraction of samples whose target is among the
        k highest-scoring classes, as a Python float.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk best classes per sample, transposed to (maxk, batch).
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # one-hot case
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = {}
    for k in topk:
        # ``correct`` can carry the transposed (non-contiguous) layout of ``pred``;
        # ``view(-1)`` raises on such a slice when k > 1, ``reshape(-1)`` handles both.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
    return res
|
||
|
||
if __name__ == "__main__":
    # Single source of truth for the epoch count, shared by the LR scheduler
    # and the trainer so the cosine schedule spans exactly one training run.
    num_epochs = 50

    parser = ArgumentParser("darts")
    parser.add_argument("--layers", default=5, type=int)
    # ``action='append'`` with a non-empty list default would append the user's
    # values onto the defaults, so the default is None and is resolved after parsing.
    # ``type=int`` keeps user-supplied values numeric instead of strings.
    parser.add_argument('--add_layers', action='append', type=int,
                        default=None, help='add layers')
    parser.add_argument("--nodes", default=4, type=int)
    parser.add_argument("--batch-size", default=128, type=int)
    parser.add_argument("--log-frequency", default=1, type=int)
    args = parser.parse_args()

    # NOTE(review): presumably one entry per P-DARTS stage, paired element-wise with
    # ``pdarts_num_to_drop`` below -- confirm against PdartsTrainer before overriding.
    add_layers = [0, 6, 12] if args.add_layers is None else args.add_layers

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    def model_creator(layers, n_nodes):
        """Create the network, loss, optimizer and LR scheduler for the given depth and node count."""
        model = CnnNetwork(3, 16, 10, layers, n_nodes=n_nodes, cell_type=CnnCell)
        loss = nn.CrossEntropyLoss()

        model_optim = torch.optim.SGD(model.parameters(), 0.025,
                                      momentum=0.9, weight_decay=3.0E-4)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_optim, num_epochs, eta_min=0.001)
        return model, loss, model_optim, lr_scheduler

    trainer = PdartsTrainer(model_creator,
                            metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                            num_epochs=num_epochs,
                            pdarts_num_layers=add_layers,
                            pdarts_num_to_drop=[3, 2, 2],
                            dataset_train=dataset_train,
                            dataset_valid=dataset_valid,
                            layers=args.layers,
                            n_nodes=args.nodes,
                            batch_size=args.batch_size,
                            log_frequency=args.log_frequency)
    trainer.train()
    trainer.export()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,4 @@ | ||
from .mutator import DartsMutator | ||
from .trainer import DartsTrainer | ||
from .cnn_cell import CnnCell | ||
from .cnn_network import CnnNetwork |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
import nni.nas.pytorch as nas | ||
from nni.nas.pytorch.modules import RankedModule | ||
|
||
from .cnn_ops import OPS, PRIMITIVES, FactorizedReduce, StdConv | ||
|
||
|
||
class CnnCell(RankedModule):
    """
    Search cell: a small DAG whose edges are ``LayerChoice`` mutables over the candidate
    operations listed in ``PRIMITIVES``.
    """

    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
        """
        Build the input preprocessing layers and the candidate operations of a search cell.

        Parameters
        ----------
        n_nodes: int
            Number of intermediate nodes in the DAG of this cell.
        channels_pp: int
            Number of output channels of the cell before the previous one.
        channels_p: int
            Number of output channels of the previous cell.
        channels: int
            Number of channels used inside this cell.
        reduction_p: bool
            Whether the previous cell is a reduction cell.
        reduction: bool
            Whether this cell is a reduction cell.
        """
        super().__init__(rank=1, reduction=reduction)
        self.n_nodes = n_nodes

        # When the previous cell reduced, the output of cell[k-2] no longer matches the
        # current input size, so it is brought in line by a reducing preprocessing layer.
        self.preproc0 = (
            FactorizedReduce(channels_pp, channels, affine=False)
            if reduction_p
            else StdConv(channels_pp, channels, 1, 1, 0, affine=False)
        )
        self.preproc1 = StdConv(channels_p, channels, 1, 1, 0, affine=False)

        # Build the DAG: node ``depth`` receives one edge from each of the two cell
        # inputs and one from every earlier intermediate node.
        self.mutable_ops = nn.ModuleList()
        for depth in range(self.n_nodes):
            node_edges = nn.ModuleList()
            for src in range(depth + 2):
                # Only edges leaving the two input nodes are strided in a reduction cell.
                stride = 2 if reduction and src < 2 else 1
                candidates = [OPS[name](channels, stride, False) for name in PRIMITIVES]
                edge_key = "r{}_d{}_i{}".format(reduction, depth, src)
                node_edges.append(nas.mutables.LayerChoice(candidates, key=edge_key))
            self.mutable_ops.append(node_edges)

    def forward(self, s0, s1):
        """
        Run the cell on ``s0`` and ``s1``, the outputs of the cell before the previous one
        and of the previous cell, and return the intermediate node outputs concatenated
        along dim 1.
        """
        states = [self.preproc0(s0), self.preproc1(s1)]
        for node_edges in self.mutable_ops:
            # Each node has exactly one incoming edge per state produced so far.
            assert len(node_edges) == len(states)
            states.append(sum(edge(state) for edge, state in zip(node_edges, states)))

        # The two preprocessed inputs are not part of the cell output.
        return torch.cat(states[2:], dim=1)
Oops, something went wrong.