Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
pdarts implementation (export is not included) (#1730)
Browse files Browse the repository at this point in the history
  • Loading branch information
squirrelsc authored Nov 14, 2019
1 parent d43fbe8 commit d1d10de
Show file tree
Hide file tree
Showing 18 changed files with 421 additions and 166 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ venv.bak/

# VSCode
.vscode
.vs

# In case you place source code in ~/nni/
/experiments
2 changes: 1 addition & 1 deletion examples/nas/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1 @@
data
data
141 changes: 0 additions & 141 deletions examples/nas/darts/model.py

This file was deleted.

12 changes: 5 additions & 7 deletions examples/nas/darts/search.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
from argparse import ArgumentParser

import datasets
import torch
import torch.nn as nn

from model import SearchCNN
from nni.nas.pytorch.darts import DartsTrainer
import datasets
from nni.nas.pytorch.darts import CnnNetwork, DartsTrainer
from utils import accuracy


if __name__ == "__main__":
parser = ArgumentParser("darts")
parser.add_argument("--layers", default=4, type=int)
parser.add_argument("--nodes", default=2, type=int)
parser.add_argument("--layers", default=5, type=int)
parser.add_argument("--nodes", default=4, type=int)
parser.add_argument("--batch-size", default=128, type=int)
parser.add_argument("--log-frequency", default=1, type=int)
args = parser.parse_args()

dataset_train, dataset_valid = datasets.get_dataset("cifar10")

model = SearchCNN(3, 16, 10, args.layers, n_nodes=args.nodes)
model = CnnNetwork(3, 16, 10, args.layers, n_nodes=args.nodes)
criterion = nn.CrossEntropyLoss()

optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
Expand Down
2 changes: 2 additions & 0 deletions examples/nas/pdarts/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
data/*
log
25 changes: 25 additions & 0 deletions examples/nas/pdarts/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from torchvision import transforms
from torchvision.datasets import CIFAR10


def get_dataset(cls):
MEAN = [0.49139968, 0.48215827, 0.44653124]
STD = [0.24703233, 0.24348505, 0.26158768]
transf = [
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip()
]
normalize = [
transforms.ToTensor(),
transforms.Normalize(MEAN, STD)
]

train_transform = transforms.Compose(transf + normalize)
valid_transform = transforms.Compose(normalize)

if cls == "cifar10":
dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
else:
raise NotImplementedError
return dataset_train, dataset_valid
65 changes: 65 additions & 0 deletions examples/nas/pdarts/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from argparse import ArgumentParser

import datasets
import torch
import torch.nn as nn
import nni.nas.pytorch as nas
from nni.nas.pytorch.pdarts import PdartsTrainer
from nni.nas.pytorch.darts import CnnNetwork, CnnCell


def accuracy(output, target, topk=(1,)):
""" Computes the precision@k for the specified values of k """
maxk = max(topk)
batch_size = target.size(0)

_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
# one-hot case
if target.ndimension() > 1:
target = target.max(1)[1]

correct = pred.eq(target.view(1, -1).expand_as(pred))

res = dict()
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
return res


if __name__ == "__main__":
parser = ArgumentParser("darts")
parser.add_argument("--layers", default=5, type=int)
parser.add_argument('--add_layers', action='append',
default=[0, 6, 12], help='add layers')
parser.add_argument("--nodes", default=4, type=int)
parser.add_argument("--batch-size", default=128, type=int)
parser.add_argument("--log-frequency", default=1, type=int)
args = parser.parse_args()

dataset_train, dataset_valid = datasets.get_dataset("cifar10")

def model_creator(layers, n_nodes):
model = CnnNetwork(3, 16, 10, layers, n_nodes=n_nodes, cell_type=CnnCell)
loss = nn.CrossEntropyLoss()

model_optim = torch.optim.SGD(model.parameters(), 0.025,
momentum=0.9, weight_decay=3.0E-4)
n_epochs = 50
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_optim, n_epochs, eta_min=0.001)
return model, loss, model_optim, lr_scheduler

trainer = PdartsTrainer(model_creator,
metrics=lambda output, target: accuracy(output, target, topk=(1,)),
num_epochs=50,
pdarts_num_layers=[0, 6, 12],
pdarts_num_to_drop=[3, 2, 2],
dataset_train=dataset_train,
dataset_valid=dataset_valid,
layers=args.layers,
n_nodes=args.nodes,
batch_size=args.batch_size,
log_frequency=args.log_frequency)
trainer.train()
trainer.export()
2 changes: 2 additions & 0 deletions src/sdk/pynni/nni/nas/pytorch/darts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from .mutator import DartsMutator
from .trainer import DartsTrainer
from .cnn_cell import CnnCell
from .cnn_network import CnnNetwork
69 changes: 69 additions & 0 deletions src/sdk/pynni/nni/nas/pytorch/darts/cnn_cell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@

import torch
import torch.nn as nn

import nni.nas.pytorch as nas
from nni.nas.pytorch.modules import RankedModule

from .cnn_ops import OPS, PRIMITIVES, FactorizedReduce, StdConv


class CnnCell(RankedModule):
"""
Cell for search.
"""

def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
"""
Initialization a search cell.
Parameters
----------
n_nodes: int
Number of nodes in current DAG.
channels_pp: int
Number of output channels from previous previous cell.
channels_p: int
Number of output channels from previous cell.
channels: int
Number of channels that will be used in the current DAG.
reduction_p: bool
Flag for whether the previous cell is reduction cell or not.
reduction: bool
Flag for whether the current cell is reduction cell or not.
"""
super(CnnCell, self).__init__(rank=1, reduction=reduction)
self.n_nodes = n_nodes

# If previous cell is reduction cell, current input size does not match with
# output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing.
if reduction_p:
self.preproc0 = FactorizedReduce(channels_pp, channels, affine=False)
else:
self.preproc0 = StdConv(channels_pp, channels, 1, 1, 0, affine=False)
self.preproc1 = StdConv(channels_p, channels, 1, 1, 0, affine=False)

# generate dag
self.mutable_ops = nn.ModuleList()
for depth in range(self.n_nodes):
self.mutable_ops.append(nn.ModuleList())
for i in range(2 + depth): # include 2 input nodes
# reduction should be used only for input node
stride = 2 if reduction and i < 2 else 1
m_ops = []
for primitive in PRIMITIVES:
op = OPS[primitive](channels, stride, False)
m_ops.append(op)
op = nas.mutables.LayerChoice(m_ops, key="r{}_d{}_i{}".format(reduction, depth, i))
self.mutable_ops[depth].append(op)

def forward(self, s0, s1):
# s0, s1 are the outputs of previous previous cell and previous cell, respectively.
tensors = [self.preproc0(s0), self.preproc1(s1)]
for ops in self.mutable_ops:
assert len(ops) == len(tensors)
cur_tensor = sum(op(tensor) for op, tensor in zip(ops, tensors))
tensors.append(cur_tensor)

output = torch.cat(tensors[2:], dim=1)
return output
Loading

0 comments on commit d1d10de

Please sign in to comment.