diff --git a/.gitignore b/.gitignore index e96b14efc6..83049a476e 100644 --- a/.gitignore +++ b/.gitignore @@ -80,6 +80,7 @@ venv.bak/ # VSCode .vscode +.vs # In case you place source code in ~/nni/ /experiments diff --git a/examples/nas/.gitignore b/examples/nas/.gitignore index 1269488f7f..8705cba4d6 100644 --- a/examples/nas/.gitignore +++ b/examples/nas/.gitignore @@ -1 +1 @@ -data +data diff --git a/examples/nas/darts/model.py b/examples/nas/darts/model.py deleted file mode 100644 index 629831e0b7..0000000000 --- a/examples/nas/darts/model.py +++ /dev/null @@ -1,141 +0,0 @@ -import torch -import torch.nn as nn - -import ops -from nni.nas import pytorch as nas - - -class SearchCell(nn.Module): - """ - Cell for search. - """ - - def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction): - """ - Initialization a search cell. - - Parameters - ---------- - n_nodes: int - Number of nodes in current DAG. - channels_pp: int - Number of output channels from previous previous cell. - channels_p: int - Number of output channels from previous cell. - channels: int - Number of channels that will be used in the current DAG. - reduction_p: bool - Flag for whether the previous cell is reduction cell or not. - reduction: bool - Flag for whether the current cell is reduction cell or not. - """ - super().__init__() - self.reduction = reduction - self.n_nodes = n_nodes - - # If previous cell is reduction cell, current input size does not match with - # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing. - if reduction_p: - self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False) - else: - self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False) - self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False) - - # generate dag - self.mutable_ops = nn.ModuleList() - for depth in range(self.n_nodes): - self.mutable_ops.append(nn.ModuleList()) - for i in range(2 + depth): # include 2 input nodes - # reduction should be used only for input node - stride = 2 if reduction and i < 2 else 1 - op = nas.mutables.LayerChoice([ops.PoolBN('max', channels, 3, stride, 1, affine=False), - ops.PoolBN('avg', channels, 3, stride, 1, affine=False), - ops.Identity() if stride == 1 else - ops.FactorizedReduce(channels, channels, affine=False), - ops.SepConv(channels, channels, 3, stride, 1, affine=False), - ops.SepConv(channels, channels, 5, stride, 2, affine=False), - ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False), - ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False), - ops.Zero(stride)], - key="r{}_d{}_i{}".format(reduction, depth, i)) - self.mutable_ops[depth].append(op) - - def forward(self, s0, s1): - # s0, s1 are the outputs of previous previous cell and previous cell, respectively. - tensors = [self.preproc0(s0), self.preproc1(s1)] - for ops in self.mutable_ops: - assert len(ops) == len(tensors) - cur_tensor = sum(op(tensor) for op, tensor in zip(ops, tensors)) - tensors.append(cur_tensor) - - output = torch.cat(tensors[2:], dim=1) - return output - - -class SearchCNN(nn.Module): - """ - Search CNN model - """ - - def __init__(self, in_channels, channels, n_classes, n_layers, n_nodes=4, stem_multiplier=3): - """ - Initializing a search channelsNN. - - Parameters - ---------- - in_channels: int - Number of channels in images. - channels: int - Number of channels used in the network. - n_classes: int - Number of classes. - n_layers: int - Number of cells in the whole network. - n_nodes: int - Number of nodes in a cell. - stem_multiplier: int - Multiplier of channels in STEM. - """ - super().__init__() - self.in_channels = in_channels - self.channels = channels - self.n_classes = n_classes - self.n_layers = n_layers - - c_cur = stem_multiplier * self.channels - self.stem = nn.Sequential( - nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False), - nn.BatchNorm2d(c_cur) - ) - - # for the first cell, stem is used for both s0 and s1 - # [!] channels_pp and channels_p is output channel size, but c_cur is input channel size. - channels_pp, channels_p, c_cur = c_cur, c_cur, channels - - self.cells = nn.ModuleList() - reduction_p, reduction = False, False - for i in range(n_layers): - reduction_p, reduction = reduction, False - # Reduce featuremap size and double channels in 1/3 and 2/3 layer. - if i in [n_layers // 3, 2 * n_layers // 3]: - c_cur *= 2 - reduction = True - - cell = SearchCell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction) - self.cells.append(cell) - c_cur_out = c_cur * n_nodes - channels_pp, channels_p = channels_p, c_cur_out - - self.gap = nn.AdaptiveAvgPool2d(1) - self.linear = nn.Linear(channels_p, n_classes) - - def forward(self, x): - s0 = s1 = self.stem(x) - - for cell in self.cells: - s0, s1 = s1, cell(s0, s1) - - out = self.gap(s1) - out = out.view(out.size(0), -1) # flatten - logits = self.linear(out) - return logits diff --git a/examples/nas/darts/search.py b/examples/nas/darts/search.py index ad0650d156..0d7f995769 100644 --- a/examples/nas/darts/search.py +++ b/examples/nas/darts/search.py @@ -1,25 +1,23 @@ from argparse import ArgumentParser -import datasets import torch import torch.nn as nn -from model import SearchCNN -from nni.nas.pytorch.darts import DartsTrainer +import datasets +from nni.nas.pytorch.darts import CnnNetwork, DartsTrainer from utils import accuracy - if __name__ == "__main__": parser = ArgumentParser("darts") - parser.add_argument("--layers", default=4, type=int) - parser.add_argument("--nodes", default=2, type=int) + parser.add_argument("--layers", default=5, type=int) + parser.add_argument("--nodes", default=4, type=int) parser.add_argument("--batch-size", default=128, type=int) parser.add_argument("--log-frequency", default=1, type=int) args = parser.parse_args() dataset_train, dataset_valid = datasets.get_dataset("cifar10") - model = SearchCNN(3, 16, 10, args.layers, n_nodes=args.nodes) + model = CnnNetwork(3, 16, 10, args.layers, n_nodes=args.nodes) criterion = nn.CrossEntropyLoss() optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4) diff --git a/examples/nas/pdarts/.gitignore b/examples/nas/pdarts/.gitignore new file mode 100644 index 0000000000..054c274eeb --- /dev/null +++ b/examples/nas/pdarts/.gitignore @@ -0,0 +1,2 @@ +data/* +log diff --git a/examples/nas/pdarts/datasets.py b/examples/nas/pdarts/datasets.py new file mode 100644 index 0000000000..8fe0ab0fbf --- /dev/null +++ b/examples/nas/pdarts/datasets.py @@ -0,0 +1,25 @@ +from torchvision import transforms +from torchvision.datasets import CIFAR10 + + +def get_dataset(cls): + MEAN = [0.49139968, 0.48215827, 0.44653124] + STD = [0.24703233, 0.24348505, 0.26158768] + transf = [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip() + ] + normalize = [ + transforms.ToTensor(), + transforms.Normalize(MEAN, STD) + ] + + train_transform = transforms.Compose(transf + normalize) + valid_transform = transforms.Compose(normalize) + + if cls == "cifar10": + dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform) + dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform) + else: + raise NotImplementedError + return dataset_train, dataset_valid diff --git a/examples/nas/pdarts/main.py b/examples/nas/pdarts/main.py new file mode 100644 index 0000000000..68a59c8856 --- /dev/null +++ b/examples/nas/pdarts/main.py @@ -0,0 +1,65 @@ +from argparse import ArgumentParser + +import datasets +import torch +import torch.nn as nn +import nni.nas.pytorch as nas +from nni.nas.pytorch.pdarts import PdartsTrainer +from nni.nas.pytorch.darts import CnnNetwork, CnnCell + + +def accuracy(output, target, topk=(1,)): + """ Computes the precision@k for the specified values of k """ + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + # one-hot case + if target.ndimension() > 1: + target = target.max(1)[1] + + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = dict() + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item() + return res + + +if __name__ == "__main__": + parser = ArgumentParser("darts") + parser.add_argument("--layers", default=5, type=int) + parser.add_argument('--add_layers', action='append', + default=[0, 6, 12], help='add layers') + parser.add_argument("--nodes", default=4, type=int) + parser.add_argument("--batch-size", default=128, type=int) + parser.add_argument("--log-frequency", default=1, type=int) + args = parser.parse_args() + + dataset_train, dataset_valid = datasets.get_dataset("cifar10") + + def model_creator(layers, n_nodes): + model = CnnNetwork(3, 16, 10, layers, n_nodes=n_nodes, cell_type=CnnCell) + loss = nn.CrossEntropyLoss() + + model_optim = torch.optim.SGD(model.parameters(), 0.025, + momentum=0.9, weight_decay=3.0E-4) + n_epochs = 50 + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_optim, n_epochs, eta_min=0.001) + return model, loss, model_optim, lr_scheduler + + trainer = PdartsTrainer(model_creator, + metrics=lambda output, target: accuracy(output, target, topk=(1,)), + num_epochs=50, + pdarts_num_layers=[0, 6, 12], + pdarts_num_to_drop=[3, 2, 2], + dataset_train=dataset_train, + dataset_valid=dataset_valid, + layers=args.layers, + n_nodes=args.nodes, + batch_size=args.batch_size, + log_frequency=args.log_frequency) + trainer.train() + trainer.export() diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py b/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py index 34e5f6e81c..f28d2cd73c 100644 --- a/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py +++ b/src/sdk/pynni/nni/nas/pytorch/darts/__init__.py @@ -1,2 +1,4 @@ from .mutator import DartsMutator from .trainer import DartsTrainer +from .cnn_cell import CnnCell +from .cnn_network import CnnNetwork diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/cnn_cell.py b/src/sdk/pynni/nni/nas/pytorch/darts/cnn_cell.py new file mode 100644 index 0000000000..69dc28e8f0 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/darts/cnn_cell.py @@ -0,0 +1,69 @@ + +import torch +import torch.nn as nn + +import nni.nas.pytorch as nas +from nni.nas.pytorch.modules import RankedModule + +from .cnn_ops import OPS, PRIMITIVES, FactorizedReduce, StdConv + + +class CnnCell(RankedModule): + """ + Cell for search. + """ + + def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction): + """ + Initialization a search cell. + + Parameters + ---------- + n_nodes: int + Number of nodes in current DAG. + channels_pp: int + Number of output channels from previous previous cell. + channels_p: int + Number of output channels from previous cell. + channels: int + Number of channels that will be used in the current DAG. + reduction_p: bool + Flag for whether the previous cell is reduction cell or not. + reduction: bool + Flag for whether the current cell is reduction cell or not. + """ + super(CnnCell, self).__init__(rank=1, reduction=reduction) + self.n_nodes = n_nodes + + # If previous cell is reduction cell, current input size does not match with + # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing. + if reduction_p: + self.preproc0 = FactorizedReduce(channels_pp, channels, affine=False) + else: + self.preproc0 = StdConv(channels_pp, channels, 1, 1, 0, affine=False) + self.preproc1 = StdConv(channels_p, channels, 1, 1, 0, affine=False) + + # generate dag + self.mutable_ops = nn.ModuleList() + for depth in range(self.n_nodes): + self.mutable_ops.append(nn.ModuleList()) + for i in range(2 + depth): # include 2 input nodes + # reduction should be used only for input node + stride = 2 if reduction and i < 2 else 1 + m_ops = [] + for primitive in PRIMITIVES: + op = OPS[primitive](channels, stride, False) + m_ops.append(op) + op = nas.mutables.LayerChoice(m_ops, key="r{}_d{}_i{}".format(reduction, depth, i)) + self.mutable_ops[depth].append(op) + + def forward(self, s0, s1): + # s0, s1 are the outputs of previous previous cell and previous cell, respectively. + tensors = [self.preproc0(s0), self.preproc1(s1)] + for ops in self.mutable_ops: + assert len(ops) == len(tensors) + cur_tensor = sum(op(tensor) for op, tensor in zip(ops, tensors)) + tensors.append(cur_tensor) + + output = torch.cat(tensors[2:], dim=1) + return output diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/cnn_network.py b/src/sdk/pynni/nni/nas/pytorch/darts/cnn_network.py new file mode 100644 index 0000000000..d126e3353e --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/darts/cnn_network.py @@ -0,0 +1,73 @@ + +import torch.nn as nn + +from .cnn_cell import CnnCell + + +class CnnNetwork(nn.Module): + """ + Search CNN model + """ + + def __init__(self, in_channels, channels, n_classes, n_layers, n_nodes=4, stem_multiplier=3, cell_type=CnnCell): + """ + Initializing a search channelsNN. + + Parameters + ---------- + in_channels: int + Number of channels in images. + channels: int + Number of channels used in the network. + n_classes: int + Number of classes. + n_layers: int + Number of cells in the whole network. + n_nodes: int + Number of nodes in a cell. + stem_multiplier: int + Multiplier of channels in STEM. + """ + super().__init__() + self.in_channels = in_channels + self.channels = channels + self.n_classes = n_classes + self.n_layers = n_layers + + c_cur = stem_multiplier * self.channels + self.stem = nn.Sequential( + nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False), + nn.BatchNorm2d(c_cur) + ) + + # for the first cell, stem is used for both s0 and s1 + # [!] channels_pp and channels_p is output channel size, but c_cur is input channel size. + channels_pp, channels_p, c_cur = c_cur, c_cur, channels + + self.cells = nn.ModuleList() + reduction_p, reduction = False, False + for i in range(n_layers): + reduction_p, reduction = reduction, False + # Reduce featuremap size and double channels in 1/3 and 2/3 layer. + if i in [n_layers // 3, 2 * n_layers // 3]: + c_cur *= 2 + reduction = True + + cell = cell_type(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction) + self.cells.append(cell) + c_cur_out = c_cur * n_nodes + channels_pp, channels_p = channels_p, c_cur_out + + self.gap = nn.AdaptiveAvgPool2d(1) + self.linear = nn.Linear(channels_p, n_classes) + + def forward(self, x): + s0 = s1 = self.stem(x) + + for cell in self.cells: + s0, s1 = s1, cell(s0, s1) + + out = self.gap(s1) + out = out.view(out.size(0), -1) # flatten + logits = self.linear(out) + return logits diff --git a/examples/nas/darts/ops.py b/src/sdk/pynni/nni/nas/pytorch/darts/cnn_ops.py similarity index 93% rename from examples/nas/darts/ops.py rename to src/sdk/pynni/nni/nas/pytorch/darts/cnn_ops.py index ef25a6e830..02b4a3a94c 100644 --- a/examples/nas/darts/ops.py +++ b/src/sdk/pynni/nni/nas/pytorch/darts/cnn_ops.py @@ -1,29 +1,27 @@ import torch import torch.nn as nn - PRIMITIVES = [ + 'none', 'max_pool_3x3', 'avg_pool_3x3', - 'skip_connect', # identity + 'skip_connect', # identity 'sep_conv_3x3', 'sep_conv_5x5', 'dil_conv_3x3', 'dil_conv_5x5', - 'none' ] OPS = { 'none': lambda C, stride, affine: Zero(stride), 'avg_pool_3x3': lambda C, stride, affine: PoolBN('avg', C, 3, stride, 1, affine=affine), 'max_pool_3x3': lambda C, stride, affine: PoolBN('max', C, 3, stride, 1, affine=affine), - 'skip_connect': lambda C, stride, affine: \ - Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine), + 'skip_connect': lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine), 'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine), 'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine), 'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine), - 'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), # 5x5 - 'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), # 9x9 + 'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), # 5x5 + 'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), # 9x9 'conv_7x1_1x7': lambda C, stride, affine: FacConv(C, C, 7, stride, 3, affine=affine) } @@ -60,6 +58,7 @@ class PoolBN(nn.Module): """ AvgPool or MaxPool - BN """ + def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True): """ Args: @@ -85,6 +84,7 @@ class StdConv(nn.Module): """ Standard conv ReLU - Conv - BN """ + def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): super().__init__() self.net = nn.Sequential( @@ -101,6 +101,7 @@ class FacConv(nn.Module): """ Factorized conv ReLU - Conv(Kx1) - Conv(1xK) - BN """ + def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True): super().__init__() self.net = nn.Sequential( @@ -118,14 +119,14 @@ class DilConv(nn.Module): """ (Dilated) depthwise separable conv ReLU - (Dilated) depthwise separable - Pointwise - BN If dilation == 2, 3x3 conv => 5x5 receptive field - 5x5 conv => 9x9 receptive field + 5x5 conv => 9x9 receptive field """ + def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): super().__init__() self.net = nn.Sequential( nn.ReLU(), - nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in, - bias=False), + nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in, bias=False), nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False), nn.BatchNorm2d(C_out, affine=affine) ) @@ -138,6 +139,7 @@ class SepConv(nn.Module): """ Depthwise separable conv DilConv(dilation=1) * 2 """ + def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): super().__init__() self.net = nn.Sequential( @@ -172,6 +174,7 @@ class FactorizedReduce(nn.Module): """ Reduce feature map size by factorized pointwise(stride=2). """ + def __init__(self, C_in, C_out, affine=True): super().__init__() self.relu = nn.ReLU() diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py b/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py index 72ac427c11..75463ff23f 100644 --- a/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py @@ -94,7 +94,8 @@ def validate_epoch(self, epoch): with torch.no_grad(): for step, (X, y) in enumerate(self.valid_loader): X, y = X.to(self.device), y.to(self.device) - logits = self.model(X) + with self.mutator.forward_pass(): + logits = self.model(X) metrics = self.metrics(logits, y) meters.update(metrics) if self.log_frequency is not None and step % self.log_frequency == 0: diff --git a/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py b/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py index 93dad9c77c..a158886233 100644 --- a/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/enas/mutator.py @@ -40,7 +40,7 @@ def before_build(self, model): self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False) self.v_attn = nn.Linear(self.lstm_size, 1, bias=False) self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1) - self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False) + self.skip_targets = nn.Parameter(torch.Tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False) self.cross_entropy_loss = nn.CrossEntropyLoss() def after_build(self, model): @@ -79,7 +79,7 @@ def on_calc_layer_choice_mask(self, mutable): self._lstm_next_step() logit = self.soft(self._h[-1]) if self.tanh_constant is not None: - logit = self.tanh_constant * torch.tanh(logit) + logit = self.tanh_constant * torch.tanh(logit) branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) log_prob = self.cross_entropy_loss(logit, branch_id) self.sample_log_prob += log_prob diff --git a/src/sdk/pynni/nni/nas/pytorch/modules.py b/src/sdk/pynni/nni/nas/pytorch/modules.py new file mode 100644 index 0000000000..6570220e13 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/modules.py @@ -0,0 +1,9 @@ + +from torch import nn as nn + + +class RankedModule(nn.Module): + def __init__(self, rank=None, reduction=False): + super(RankedModule, self).__init__() + self.rank = rank + self.reduction = reduction diff --git a/src/sdk/pynni/nni/nas/pytorch/mutables.py b/src/sdk/pynni/nni/nas/pytorch/mutables.py index 456fe7a498..e28af84037 100644 --- a/src/sdk/pynni/nni/nas/pytorch/mutables.py +++ b/src/sdk/pynni/nni/nas/pytorch/mutables.py @@ -56,9 +56,6 @@ def _check_built(self): "Mutator not set for {}. Did you initialize a mutable on the fly in forward pass? Move to __init__" "so that trainer can locate all your mutables. See NNI docs for more details.".format(self)) - def __repr__(self): - return "{} ({})".format(self.name, self.key) - class MutableScope(PyTorchMutable): """ @@ -85,6 +82,9 @@ def __init__(self, op_candidates, reduction="mean", return_mask=False, key=None) self.reduction = reduction self.return_mask = return_mask + def __len__(self): + return self.length + def forward(self, *inputs): out, mask = self.mutator.on_forward(self, *inputs) if self.return_mask: @@ -116,4 +116,4 @@ def forward(self, optional_inputs, semantic_labels=None): def similar(self, other): return type(self) == type(other) and \ - self.n_candidates == other.n_candidates and self.n_selected and other.n_selected + self.n_candidates == other.n_candidates and self.n_selected and other.n_selected diff --git a/src/sdk/pynni/nni/nas/pytorch/pdarts/__init__.py b/src/sdk/pynni/nni/nas/pytorch/pdarts/__init__.py new file mode 100644 index 0000000000..27dd912ab3 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/pdarts/__init__.py @@ -0,0 +1 @@ +from .trainer import PdartsTrainer diff --git a/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py b/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py new file mode 100644 index 0000000000..6e385b1170 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py @@ -0,0 +1,93 @@ +import copy + +import numpy as np +import torch +from torch import nn as nn +from torch.nn import functional as F + +from nni.nas.pytorch.darts import DartsMutator +from nni.nas.pytorch.mutables import LayerChoice + + +class PdartsMutator(DartsMutator): + + def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches=None): + self.pdarts_epoch_index = pdarts_epoch_index + self.pdarts_num_to_drop = pdarts_num_to_drop + self.switches = switches + + super(PdartsMutator, self).__init__(model) + + def before_build(self, model): + self.choices = nn.ParameterDict() + if self.switches is None: + self.switches = {} + + def named_mutables(self, model): + key2module = dict() + for name, module in model.named_modules(): + if isinstance(module, LayerChoice): + key2module[module.key] = module + yield name, module, True + + def drop_paths(self): + for key in self.switches: + prob = F.softmax(self.choices[key], dim=-1).data.cpu().numpy() + + switches = self.switches[key] + idxs = [] + for j in range(len(switches)): + if switches[j]: + idxs.append(j) + if self.pdarts_epoch_index == len(self.pdarts_num_to_drop) - 1: + # for the last stage, drop all Zero operations + drop = self.get_min_k_no_zero(prob, idxs, self.pdarts_num_to_drop[self.pdarts_epoch_index]) + else: + drop = self.get_min_k(prob, self.pdarts_num_to_drop[self.pdarts_epoch_index]) + + for idx in drop: + switches[idxs[idx]] = False + return self.switches + + def on_init_layer_choice(self, mutable: LayerChoice): + switches = self.switches.get( + mutable.key, [True for j in range(mutable.length)]) + + for index in range(len(switches)-1, -1, -1): + if switches[index] == False: + del(mutable.choices[index]) + mutable.length -= 1 + + self.switches[mutable.key] = switches + + self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length)) + + def on_calc_layer_choice_mask(self, mutable: LayerChoice): + return F.softmax(self.choices[mutable.key], dim=-1) + + def get_min_k(self, input_in, k): + index = [] + for _ in range(k): + idx = np.argmin(input) + index.append(idx) + + return index + + def get_min_k_no_zero(self, w_in, idxs, k): + w = copy.deepcopy(w_in) + index = [] + if 0 in idxs: + zf = True + else: + zf = False + if zf: + w = w[1:] + index.append(0) + k = k - 1 + for _ in range(k): + idx = np.argmin(w) + w[idx] = 1 + if zf: + idx = idx + 1 + index.append(idx) + return index diff --git a/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py b/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py new file mode 100644 index 0000000000..6425e234d8 --- /dev/null +++ b/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py @@ -0,0 +1,54 @@ +from nni.nas.pytorch.darts import DartsTrainer +from nni.nas.pytorch.trainer import Trainer + +from .mutator import PdartsMutator + + +class PdartsTrainer(Trainer): + + def __init__(self, model_creator, metrics, num_epochs, dataset_train, dataset_valid, + layers=5, n_nodes=4, pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 2], + mutator=None, batch_size=64, workers=4, device=None, log_frequency=None): + self.model_creator = model_creator + self.layers = layers + self.n_nodes = n_nodes + self.pdarts_num_layers = pdarts_num_layers + self.pdarts_num_to_drop = pdarts_num_to_drop + self.pdarts_epoch = len(pdarts_num_to_drop) + self.darts_parameters = { + "metrics": metrics, + "num_epochs": num_epochs, + "dataset_train": dataset_train, + "dataset_valid": dataset_valid, + "batch_size": batch_size, + "workers": workers, + "device": device, + "log_frequency": log_frequency + } + + def train(self): + layers = self.layers + n_nodes = self.n_nodes + switches = None + for epoch in range(self.pdarts_epoch): + + layers = self.layers+self.pdarts_num_layers[epoch] + model, loss, model_optim, lr_scheduler = self.model_creator( + layers, n_nodes) + mutator = PdartsMutator( + model, epoch, self.pdarts_num_to_drop, switches) + + self.trainer = DartsTrainer(model, loss=loss, model_optim=model_optim, + lr_scheduler=lr_scheduler, mutator=mutator, **self.darts_parameters) + print("start pdrats training %s..." % epoch) + + self.trainer.train() + + # with open('log/parameters_%d.txt' % epoch, "w") as f: + # f.write(str(model.parameters)) + + switches = mutator.drop_paths() + + def export(self): + if (self.trainer is not None) and hasattr(self.trainer, "export"): + self.trainer.export()