This repository has been archived by the owner on Sep 18, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
16 changed files
with
1,177 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
import tensorflow as tf | ||
from tensorflow.data import Dataset | ||
|
||
def get_dataset():
    """Load CIFAR-10 and return ``(train_set, valid_set)``, each an (images, labels) tuple.

    Image pixel values are rescaled from [0, 255] to [0, 1]; labels are returned unchanged.
    """
    train, valid = tf.keras.datasets.cifar10.load_data()
    x_train, y_train = train
    x_valid, y_valid = valid
    # Normalize pixel intensities to the unit interval.
    x_train = x_train / 255.0
    x_valid = x_valid / 255.0
    return (x_train, y_train), (x_valid, y_valid)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
import tensorflow as tf | ||
from tensorflow.keras import Model, Sequential | ||
from tensorflow.keras.layers import ( | ||
AveragePooling2D, | ||
BatchNormalization, | ||
Conv2D, | ||
Dense, | ||
Dropout, | ||
GlobalAveragePooling2D, | ||
MaxPool2D, | ||
ReLU, | ||
SeparableConv2D, | ||
) | ||
|
||
from nni.nas.tensorflow.mutables import InputChoice, LayerChoice, MutableScope | ||
|
||
|
||
def build_conv(filters, kernel_size, name=None):
    """Candidate op: 1x1 bottleneck conv, then a `kernel_size` conv.

    Each conv is followed by a frozen (trainable=False) batch norm and ReLU.
    """
    layers = [
        Conv2D(filters, kernel_size=1, use_bias=False),
        BatchNormalization(trainable=False),
        ReLU(),
        Conv2D(filters, kernel_size, padding='same'),
        BatchNormalization(trainable=False),
        ReLU(),
    ]
    return Sequential(layers, name)
|
||
def build_separable_conv(filters, kernel_size, name=None):
    """Candidate op: 1x1 bottleneck, depthwise-separable conv, then a 1x1 conv.

    Batch norms are frozen (trainable=False); ReLU follows each normalized stage.
    """
    ops = [
        Conv2D(filters, kernel_size=1, use_bias=False),
        BatchNormalization(trainable=False),
        ReLU(),
        SeparableConv2D(filters, kernel_size, padding='same', use_bias=False),
        Conv2D(filters, kernel_size=1, use_bias=False),
        BatchNormalization(trainable=False),
        ReLU(),
    ]
    return Sequential(ops, name)
|
||
def build_avg_pool(filters, name=None):
    """Candidate op: 1x1 conv (to `filters` channels) followed by 3x3 average pooling.

    stride 1 + 'same' padding keeps spatial size; batch norms are frozen.
    """
    model = Sequential(name=name)
    model.add(Conv2D(filters, kernel_size=1, use_bias=False))
    model.add(BatchNormalization(trainable=False))
    model.add(ReLU())
    model.add(AveragePooling2D(pool_size=3, strides=1, padding='same'))
    model.add(BatchNormalization(trainable=False))
    return model
|
||
def build_max_pool(filters, name=None):
    """Candidate op: 1x1 conv (to `filters` channels) followed by 3x3 max pooling.

    stride 1 + 'same' padding keeps spatial size; batch norms are frozen.
    """
    model = Sequential(name=name)
    model.add(Conv2D(filters, kernel_size=1, use_bias=False))
    model.add(BatchNormalization(trainable=False))
    model.add(ReLU())
    model.add(MaxPool2D(pool_size=3, strides=1, padding='same'))
    model.add(BatchNormalization(trainable=False))
    return model
|
||
|
||
class FactorizedReduce(Model):
    """Halve the spatial resolution while keeping `filters` output channels.

    Two stride-2 1x1 convolutions — the second applied to the input shifted by
    one pixel in both spatial dimensions — are concatenated on the channel axis
    and batch-normalized (frozen BN).
    """

    def __init__(self, filters):
        super().__init__()
        half = filters // 2
        self.conv1 = Conv2D(half, kernel_size=1, strides=2, use_bias=False)
        self.conv2 = Conv2D(half, kernel_size=1, strides=2, use_bias=False)
        self.bn = BatchNormalization(trainable=False)

    def call(self, x):
        # Offset the second path by one pixel so the two stride-2 paths
        # sample complementary spatial positions.
        shifted = x[:, 1:, 1:, :]
        merged = tf.concat([self.conv1(x), self.conv2(shifted)], axis=3)
        return self.bn(merged)
|
||
|
||
class ENASLayer(MutableScope):
    """One searchable layer of the macro space.

    Chooses one of six candidate ops for the latest input, optionally adds
    skip connections from any earlier layers' outputs, then batch-normalizes.
    """

    def __init__(self, key, prev_labels, filters):
        super().__init__(key)
        self.mutable = LayerChoice([
            build_conv(filters, 3, 'conv3'),
            build_separable_conv(filters, 3, 'sepconv3'),
            build_conv(filters, 5, 'conv5'),
            build_separable_conv(filters, 5, 'sepconv5'),
            build_avg_pool(filters, 'avgpool'),
            build_max_pool(filters, 'maxpool'),
        ])
        # Skip connections only exist once there are earlier layers to connect from.
        if prev_labels:
            self.skipconnect = InputChoice(choose_from=prev_labels, n_chosen=None)
        else:
            self.skipconnect = None
        self.batch_norm = BatchNormalization(trainable=False)

    def call(self, prev_layers):
        # The chosen op always consumes the most recent layer's output.
        out = self.mutable(prev_layers[-1])
        if self.skipconnect is not None:
            skip = self.skipconnect(prev_layers[:-1])
            if skip is not None:
                out += skip
        return self.batch_norm(out)
|
||
|
||
class GeneralNetwork(Model):
    """ENAS macro search space: a stem, `num_layers` searchable layers with two
    factorized-reduction (downsampling) points, and a classification head."""

    def __init__(self, num_layers=12, filters=24, num_classes=10, dropout_rate=0.0):
        super().__init__()
        self.num_layers = num_layers

        self.stem = Sequential([
            Conv2D(filters, kernel_size=3, padding='same', use_bias=False),
            BatchNormalization()
        ])

        # Each layer may skip-connect from any earlier layer, identified by label.
        labels = ['layer_{}'.format(i) for i in range(num_layers)]
        self.enas_layers = [ENASLayer(labels[i], labels[:i], filters)
                            for i in range(num_layers)]

        # Two reduction points split the stack into three equal segments.
        pool_num = 2
        self.pool_distance = num_layers // (pool_num + 1)
        self.pool_layers = [FactorizedReduce(filters) for _ in range(pool_num)]

        self.gap = GlobalAveragePooling2D()
        self.dropout = Dropout(dropout_rate)
        self.dense = Dense(num_classes)

    def call(self, x):
        cur = self.stem(x)
        prev_outputs = [cur]

        for i, layer in enumerate(self.enas_layers):
            if i > 0 and i % self.pool_distance == 0:
                # Downsample every accumulated output so skip connections
                # keep matching spatial sizes after the reduction point.
                pool = self.pool_layers[i // self.pool_distance - 1]
                prev_outputs = [pool(tensor) for tensor in prev_outputs]

            cur = layer(prev_outputs)
            prev_outputs.append(cur)

        features = self.gap(cur)
        features = self.dropout(features)
        return self.dense(features)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
import tensorflow as tf | ||
from tensorflow.keras import Model, Sequential | ||
from tensorflow.keras.layers import ( | ||
AveragePooling2D, | ||
BatchNormalization, | ||
Conv2D, | ||
Dense, | ||
Dropout, | ||
GlobalAveragePooling2D, | ||
MaxPool2D, | ||
ReLU, | ||
SeparableConv2D, | ||
) | ||
|
||
from nni.nas.tensorflow.mutables import InputChoice, LayerChoice, MutableScope | ||
|
||
|
||
def build_conv_1x1(filters, name=None):
    """1x1 convolution with frozen batch norm and ReLU, used to align channel counts."""
    model = Sequential(name=name)
    model.add(Conv2D(filters, kernel_size=1, use_bias=False))
    model.add(BatchNormalization(trainable=False))
    model.add(ReLU())
    return model
|
||
def build_sep_conv(filters, kernel_size, name=None):
    """Candidate op: ReLU -> depthwise-separable conv -> batch norm.

    NOTE(review): this batch norm is trainable=True while every other BN in
    these search spaces is frozen (trainable=False) — confirm this asymmetry
    is intentional.
    """
    model = Sequential(name=name)
    model.add(ReLU())
    model.add(SeparableConv2D(filters, kernel_size, padding='same'))
    model.add(BatchNormalization(trainable=True))
    return model
|
||
|
||
class FactorizedReduce(Model):
    """Downsample spatially by 2x without losing information to plain striding.

    Splits the output channels across two stride-2 1x1 convolutions, the
    second one reading the input offset by one pixel, then concatenates the
    halves on the channel axis and applies a frozen batch norm.
    """

    def __init__(self, filters):
        super().__init__()
        self.conv1 = Conv2D(filters // 2, kernel_size=1, strides=2, use_bias=False)
        self.conv2 = Conv2D(filters // 2, kernel_size=1, strides=2, use_bias=False)
        self.bn = BatchNormalization(trainable=False)

    def call(self, x):
        branch_a = self.conv1(x)
        # One-pixel offset makes this branch sample the positions the
        # stride-2 first branch skips.
        branch_b = self.conv2(x[:, 1:, 1:, :])
        return self.bn(tf.concat([branch_a, branch_b], axis=3))
|
||
|
||
class ReductionLayer(Model):
    """Downsample both cell inputs, each with its own factorized reduction."""

    def __init__(self, filters):
        super().__init__()
        self.reduce0 = FactorizedReduce(filters)
        self.reduce1 = FactorizedReduce(filters)

    def call(self, prevprev, prev):
        reduced_prevprev = self.reduce0(prevprev)
        reduced_prev = self.reduce1(prev)
        return reduced_prevprev, reduced_prev
|
||
|
||
class Calibration(Model):
    """Project the input to `filters` channels with a 1x1 conv when needed.

    Acts as the identity if the input already has `filters` channels; the
    projection is created lazily in build() once the input shape is known.
    """

    def __init__(self, filters):
        super().__init__()
        self.filters = filters
        self.process = None  # set in build() only if a channel projection is required

    def build(self, shape):
        # Expected layout: (batch_size, width, height, filters).
        assert len(shape) == 4
        if shape[3] != self.filters:
            self.process = build_conv_1x1(self.filters)

    def call(self, x):
        return x if self.process is None else self.process(x)
|
||
|
||
class Cell(Model):
    """A searchable edge: pick one input tensor and one op to apply to it.

    Returns both the op output and the input-selection mask so callers can
    track which predecessors were consumed.
    """

    def __init__(self, cell_name, prev_labels, filters):
        super().__init__()
        self.input_choice = InputChoice(choose_from=prev_labels, n_chosen=1,
                                        return_mask=True, key=cell_name + '_input')
        candidate_ops = [
            build_sep_conv(filters, 3),
            build_sep_conv(filters, 5),
            AveragePooling2D(pool_size=3, strides=1, padding='same'),
            MaxPool2D(pool_size=3, strides=1, padding='same'),
            Sequential(),  # Identity
        ]
        self.op_choice = LayerChoice(candidate_ops, key=cell_name + '_op')

    def call(self, prev_layers):
        chosen_input, chosen_mask = self.input_choice(prev_layers)
        return self.op_choice(chosen_input), chosen_mask
|
||
|
||
class Node(MutableScope):
    """Combine two cells: element-wise sum of outputs, union of input masks."""

    def __init__(self, node_name, prev_node_names, filters):
        super().__init__(node_name)
        self.cell_x = Cell(node_name + '_x', prev_node_names, filters)
        self.cell_y = Cell(node_name + '_y', prev_node_names, filters)

    def call(self, prev_layers):
        out_x, mask_x = self.cell_x(prev_layers)
        out_y, mask_y = self.cell_y(prev_layers)
        combined = out_x + out_y
        used_inputs = mask_x | mask_y
        return combined, used_inputs
|
||
|
||
class ENASLayer(Model):
    """One cell of the ENAS micro search space.

    Two calibrated inputs (the previous two layers' outputs) plus `num_nodes`
    searchable nodes form the candidate outputs; the cell output is built from
    the "loose ends" — candidates no node chose to consume.
    """

    def __init__(self, num_nodes, filters, reduction):
        super().__init__()
        # Align both inputs to `filters` channels (identity if already matching).
        self.preproc0 = Calibration(filters)
        self.preproc1 = Calibration(filters)

        self.nodes = []
        # The two cell inputs have no searchable key of their own.
        node_labels = [InputChoice.NO_KEY, InputChoice.NO_KEY]
        name_prefix = 'reduce' if reduction else 'normal'
        for i in range(num_nodes):
            # Each node may draw from the two inputs and all earlier nodes.
            node_labels.append('{}_node_{}'.format(name_prefix, i))
            self.nodes.append(Node(node_labels[-1], node_labels[:-1], filters))

        # One 1x1 projection per candidate output (2 inputs + num_nodes nodes).
        self.conv_ops = [Conv2D(filters, kernel_size=1, padding='same', use_bias=False) for _ in range(num_nodes + 2)]
        self.bn = BatchNormalization(trainable=False)

    def call(self, prevprev, prev):
        prev_nodes_out = [self.preproc0(prevprev), self.preproc1(prev)]
        # Boolean mask over all candidate outputs: True once consumed by a node.
        nodes_used_mask = tf.zeros(len(self.nodes) + 2, dtype=tf.bool)
        for i, node in enumerate(self.nodes):
            node_out, mask = node(prev_nodes_out)
            # Pad the node's (shorter) mask up to full length before OR-ing it in.
            nodes_used_mask |= tf.pad(mask, [[0, nodes_used_mask.shape[0] - mask.shape[0]]])
            prev_nodes_out.append(node_out)

        # Project and sum the unused candidates ("loose ends") into the cell output.
        # NOTE(review): .numpy() requires eager execution — confirm the trainer runs eagerly.
        outputs = []
        for used, out, conv in zip(nodes_used_mask.numpy(), prev_nodes_out, self.conv_ops):
            if not used:
                outputs.append(conv(out))
        out = tf.add_n(outputs)
        # Returns (new prevprev, new prev) for the next layer in the stack.
        return prev, self.bn(out)
|
||
|
||
class MicroNetwork(Model):
    """ENAS micro search space: a stem, a stack of searchable cells with two
    reduction layers interleaved, and a classification head."""

    def __init__(self, num_layers=6, num_nodes=5, out_channels=20, num_classes=10, dropout_rate=0.1):
        super().__init__()
        self.num_layers = num_layers
        self.stem = Sequential([
            Conv2D(out_channels * 3, kernel_size=3, padding='same', use_bias=False),
            BatchNormalization(),
        ])

        # Reduction layers sit roughly a third and two-thirds of the way down.
        pool_distance = num_layers // 3
        pool_layer_indices = [pool_distance, 2 * pool_distance + 1]

        self.enas_layers = []
        filters = out_channels
        for i in range(num_layers + 2):
            reduction = i in pool_layer_indices
            if reduction:
                # Double the channel count at each downsampling point.
                filters *= 2
                self.enas_layers.append(ReductionLayer(filters))
            else:
                self.enas_layers.append(ENASLayer(num_nodes, filters, reduction))

        self.gap = GlobalAveragePooling2D()
        self.dropout = Dropout(dropout_rate)
        self.dense = Dense(num_classes)

    def call(self, x):
        # Every layer consumes the previous two outputs and yields the next pair.
        prev = cur = self.stem(x)
        for layer in self.enas_layers:
            prev, cur = layer(prev, cur)

        features = tf.keras.activations.relu(cur)
        features = self.gap(features)
        features = self.dropout(features)
        return self.dense(features)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
|
||
from tensorflow.keras.losses import Reduction, SparseCategoricalCrossentropy | ||
from tensorflow.keras.optimizers import SGD | ||
|
||
from nni.nas.tensorflow import enas | ||
|
||
import datasets | ||
from macro import GeneralNetwork | ||
from micro import MicroNetwork | ||
from utils import accuracy, accuracy_metrics | ||
|
||
|
||
# TODO: argparse | ||
|
||
|
||
# Load CIFAR-10 train/validation splits (images rescaled to [0, 1]).
dataset_train, dataset_valid = datasets.get_dataset()
# Pick the search space: macro (GeneralNetwork) or micro (MicroNetwork).
#model = GeneralNetwork()
model = MicroNetwork()

# Reduction.NONE keeps per-sample losses — presumably the ENAS trainer
# aggregates them itself; confirm against EnasTrainer's implementation.
loss = SparseCategoricalCrossentropy(from_logits=True, reduction=Reduction.NONE)
optimizer = SGD(learning_rate=0.05, momentum=0.9)

# Run the ENAS search: `accuracy` is the controller's reward signal,
# `accuracy_metrics` supplies the logged metrics dict.
trainer = enas.EnasTrainer(model,
                           loss=loss,
                           metrics=accuracy_metrics,
                           reward_function=accuracy,
                           optimizer=optimizer,
                           batch_size=64,
                           num_epochs=310,
                           dataset_train=dataset_train,
                           dataset_valid=dataset_valid)
trainer.train()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
import tensorflow as tf | ||
|
||
|
||
def accuracy_metrics(y_true, logits):
    """Wrap accuracy() in the named-metrics dict format used by the trainer."""
    metrics = {}
    metrics['enas_acc'] = accuracy(y_true, logits)
    return metrics
|
||
def accuracy(y_true, logits):
    """Return the fraction of samples whose argmax prediction matches the label.

    Args:
        y_true: integer labels, shape (batch_size,) or (batch_size, 1).
        logits: float scores, shape (batch_size, num_of_classes).

    Returns:
        A float in [0, 1].
    """
    batch_size = y_true.shape[0]
    # Flatten explicitly rather than tf.squeeze(): squeeze with no axis would
    # collapse a batch of size 1 to a rank-0 scalar, making the comparison
    # below depend on broadcasting instead of matching shapes.
    y_true = tf.reshape(y_true, [-1])
    y_pred = tf.math.argmax(logits, axis=1)
    # argmax yields int64; match the label dtype before comparing.
    y_pred = tf.cast(y_pred, y_true.dtype)
    equal = tf.cast(y_pred == y_true, tf.int32)
    return tf.math.reduce_sum(equal).numpy() / batch_size
Oops, something went wrong.