diff --git a/examples/notebooks/Retiarii_example_multi-trial_NAS.ipynb b/examples/notebooks/Retiarii_example_multi-trial_NAS.ipynb index 2143b842b4..9394e54251 100644 --- a/examples/notebooks/Retiarii_example_multi-trial_NAS.ipynb +++ b/examples/notebooks/Retiarii_example_multi-trial_NAS.ipynb @@ -109,7 +109,9 @@ "source": [ "import torch.nn.functional as F\n", "import nni.retiarii.nn.pytorch as nn\n", + "from nni.retiarii import model_wrapper\n", "\n", + "@model_wrapper\n", "class Net(nn.Module):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", @@ -949,4 +951,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/examples/notebooks/tabular_data_classification_in_AML.ipynb b/examples/notebooks/tabular_data_classification_in_AML.ipynb index a84b45784a..5526639134 100644 --- a/examples/notebooks/tabular_data_classification_in_AML.ipynb +++ b/examples/notebooks/tabular_data_classification_in_AML.ipynb @@ -127,7 +127,9 @@ "source": [ "import nni.retiarii.nn.pytorch as nn\n", "import torch.nn.functional as F\n", + "from nni.retiarii import model_wrapper\n", "\n", + "@model_wrapper\n", "class Net(nn.Module):\n", "\n", " def __init__(self, input_size):\n", @@ -1069,4 +1071,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/test/retiarii_test/darts/darts_model.py b/test/retiarii_test/darts/darts_model.py index 1354bdc14b..9a01e78b40 100644 --- a/test/retiarii_test/darts/darts_model.py +++ b/test/retiarii_test/darts/darts_model.py @@ -1,4 +1,5 @@ from collections import OrderedDict +from nni.retiarii.serializer import model_wrapper from typing import (List, Optional) import torch @@ -7,7 +8,7 @@ import ops import nni.retiarii.nn.pytorch as nn -from nni.retiarii import basic_unit +from nni.retiarii import basic_unit, model_wrapper @basic_unit class AuxiliaryHead(nn.Module): @@ -98,6 +99,7 @@ def forward(self, s0, s1): output = torch.cat(new_tensors, dim=1) return output +@model_wrapper class CNN(nn.Module): def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4,