diff --git a/examples/nas/search_space_zoo/darts_example.py b/examples/nas/search_space_zoo/darts_example.py index 3106d7038a..dbf547fa26 100644 --- a/examples/nas/search_space_zoo/darts_example.py +++ b/examples/nas/search_space_zoo/darts_example.py @@ -14,7 +14,7 @@ from utils import accuracy from nni.nas.pytorch.search_space_zoo import DartsCell -from darts_search_space import DartsStackedCells +from darts_stack_cells import DartsStackedCells logger = logging.getLogger('nni') diff --git a/examples/nas/search_space_zoo/darts_stack_cells.py b/examples/nas/search_space_zoo/darts_stack_cells.py index 7366b5000a..04277acf43 100644 --- a/examples/nas/search_space_zoo/darts_stack_cells.py +++ b/examples/nas/search_space_zoo/darts_stack_cells.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import torch.nn as nn -import ops +from nni.nas.pytorch.search_space_zoo.darts_ops import DropPath class DartsStackedCells(nn.Module): @@ -79,5 +79,5 @@ def forward(self, x): def drop_path_prob(self, p): for module in self.modules(): - if isinstance(module, ops.DropPath): + if isinstance(module, DropPath): module.p = p diff --git a/examples/nas/search_space_zoo/enas_macro_example.py b/examples/nas/search_space_zoo/enas_macro_example.py index 3688a61a16..1dbcf2d442 100644 --- a/examples/nas/search_space_zoo/enas_macro_example.py +++ b/examples/nas/search_space_zoo/enas_macro_example.py @@ -58,7 +58,6 @@ def forward(self, x): parser = ArgumentParser("enas") parser.add_argument("--batch-size", default=128, type=int) parser.add_argument("--log-frequency", default=10, type=int) - # parser.add_argument("--search-for", choices=["macro", "micro"], default="macro") parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)") parser.add_argument("--visualization", default=False, action="store_true") args = parser.parse_args() @@ -71,7 +70,6 @@ def forward(self, x): criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4) lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001) - trainer = enas.EnasTrainer(model, loss=criterion, metrics=accuracy, diff --git a/examples/nas/search_space_zoo/enas_micro_example.py b/examples/nas/search_space_zoo/enas_micro_example.py index 385a19024d..07dea7eca2 100644 --- a/examples/nas/search_space_zoo/enas_micro_example.py +++ b/examples/nas/search_space_zoo/enas_micro_example.py @@ -62,7 +62,7 @@ def __init__(self, num_layers=2, num_nodes=5, out_channels=24, in_channels=3, nu reduction = False if layer_id in pool_layers: c_cur, reduction = c_p * 2, True - self.layers.append(ENASMicroLayer(self.layers, num_nodes, c_pp, c_p, c_cur, reduction)) + self.layers.append(ENASMicroLayer(num_nodes, c_pp, c_p, c_cur, reduction)) if reduction: c_pp = c_p = c_cur c_pp, c_p = c_p, c_cur @@ -98,7 +98,6 @@ def forward(self, x): parser = ArgumentParser("enas") parser.add_argument("--batch-size", default=128, type=int) parser.add_argument("--log-frequency", default=10, type=int) - # parser.add_argument("--search-for", choices=["macro", "micro"], default="macro") parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)") parser.add_argument("--visualization", default=False, action="store_true") args = parser.parse_args() diff --git a/src/sdk/pynni/nni/nas/pytorch/search_space_zoo/enas_cell.py b/src/sdk/pynni/nni/nas/pytorch/search_space_zoo/enas_cell.py index ef3de84385..c50666e5f7 100644 --- a/src/sdk/pynni/nni/nas/pytorch/search_space_zoo/enas_cell.py +++ b/src/sdk/pynni/nni/nas/pytorch/search_space_zoo/enas_cell.py @@ -79,7 +79,6 @@ class ENASMicroLayer(nn.Module): """ def __init__(self, num_nodes, in_channels_pp, in_channels_p, out_channels, reduction): super().__init__() - print(in_channels_pp, in_channels_p, out_channels, reduction) self.reduction = reduction if self.reduction: self.reduce0 = FactorizedReduce(in_channels_pp, out_channels, affine=False) @@ -160,7 +159,7 @@ def __init__(self, key, prev_labels, in_filters, out_filters): PoolBranch('avg', in_filters, out_filters, 3, 1, 1), PoolBranch('max', in_filters, out_filters, 3, 1, 1) ]) - if prev_labels > 0: + if prev_labels: self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None) else: self.skipconnect = None