Skip to content

Commit

Permalink
Adding rule that checks if input data has been correctly normalized (a…
Browse files Browse the repository at this point in the history
…ws#293)

* add test and rule for checking input tensors
  • Loading branch information
NRauschmayr authored and Vikas-kum committed Oct 27, 2019
1 parent a4d7c8b commit 6837964
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 1 deletion.
22 changes: 21 additions & 1 deletion tests/analysis/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
$CODEBUILD_SRC_DIR/examples/mxnet/scripts/mnist_gluon_all_zero_demo.py,
mnist_gluon_vg_demo.py: &mnist_gluon_vg_demo
$CODEBUILD_SRC_DIR/examples/mxnet/scripts/mnist_gluon_vg_demo.py,
test_check_input_images.py: &test_check_input_images
$CODEBUILD_SRC_DIR/tests/analysis/rules/test_check_input_images.py,
test_dead_relu.py: &test_dead_relu
$CODEBUILD_SRC_DIR/tests/analysis/rules/test_dead_relu.py,
invoker.py: &invoker
Expand Down Expand Up @@ -180,6 +182,24 @@
*invoker,
--rule_name allzero --flag True --end_step 3 --collection_names ReluActivation
]
-
- check_input_images/mxnet/true
- mxnet
- *Enable
- [*test_check_input_images,
--flag True,
*invoker,
--rule_name checkinputimages --flag True --end_step 7 --collection_names input
]
-
- check_input_images/mxnet/false
- mxnet
- *Enable
- [*test_check_input_images,
--flag False,
*invoker,
--rule_name checkinputimages --flag False --end_step 936 --collection_names input
]
-
- dead_relu/mxnet/true
- mxnet
Expand All @@ -199,7 +219,7 @@
--rule_name deadrelu --flag False --end_step 4 --collection_names ReluActivation
]

# test cases for pytorch
# test cases for pytorch
-
- exploding_tensor/pytorch/false
- pytorch
Expand Down
104 changes: 104 additions & 0 deletions tests/analysis/rules/test_check_input_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import argparse
import mxnet as mx
import numpy as np
from mxnet import gluon, init, autograd
from mxnet.gluon import nn
from mxnet.gluon.data.vision import datasets, transforms
from tornasole.mxnet import TornasoleHook, SaveConfig, modes
import tornasole.mxnet as tm


def parse_args():
parser = argparse.ArgumentParser(description="Train a mxnet gluon model")
parser.add_argument("--batch-size", type=int, default=256, help="Batch size")
parser.add_argument(
"--output-s3-uri",
type=str,
default="s3://tornasole-testing/saveall-mxnet-hook",
help="S3 URI of the bucket where tensor data will be stored.",
)
parser.add_argument(
"--tornasole_path",
type=str,
default=None,
help="S3 URI of the bucket where tensor data will be stored.",
)
parser.add_argument("--random_seed", type=bool, default=True)
parser.add_argument(
"--flag",
type=bool,
default=True,
help="Bool variable that indicates whether parameters will be intialized to zero",
)
opt = parser.parse_args()
return opt


def create_gluon_model():
net = nn.HybridSequential()
net.add(
nn.Conv2D(channels=6, kernel_size=5, activation="relu"),
nn.MaxPool2D(pool_size=2, strides=2),
nn.Conv2D(channels=16, kernel_size=3, activation="relu"),
nn.MaxPool2D(pool_size=2, strides=2),
nn.Flatten(),
nn.Dense(120, activation="relu"),
nn.Dense(84, activation="relu"),
nn.Dense(10),
)
net.initialize(init=init.Uniform(1), ctx=mx.cpu())
return net


def train_model(batch_size, net, train_data, lr, hook):
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(net.collect_params(), "sgd", {"learning_rate": lr})
for epoch in range(1):
for data, label in train_data:
data = data.as_in_context(mx.cpu(0))
with autograd.record():
output = net(data)
loss = softmax_cross_entropy(output, label)
loss.backward()
trainer.step(batch_size)


def create_tornasole_hook(output_s3_uri):
save_config = SaveConfig(save_interval=1)
custom_collect = tm.get_collection("inputData")
custom_collect.set_save_config(save_config)
custom_collect.include([".*hybridsequential0_input_0"])
hook = TornasoleHook(
out_dir=output_s3_uri, save_config=save_config, include_collections=["inputData"]
)
return hook


def prepare_data(batch_size, flag):
mnist_train = datasets.FashionMNIST(train=True)
if flag:
transformer = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize(0.286, 0.352)]
)
else:
transformer = transforms.Compose([transforms.ToTensor()])
mnist_train = mnist_train.transform_first(transformer)
train_data = gluon.data.DataLoader(
mnist_train, batch_size=batch_size, shuffle=True, num_workers=4
)

return train_data


def main():
opt = parse_args()
net = create_gluon_model()
output_s3_uri = opt.tornasole_path if opt.tornasole_path is not None else opt.output_s3_uri
hook = create_tornasole_hook(output_s3_uri)
hook.register_hook(net)
train_data = prepare_data(64, opt.flag)
train_model(64, net, train_data, 0.1, hook)


if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions tornasole/rules/rule_invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def get_rule(rule_name):
return UnchangedTensor
elif rule_name_lower == "lossnotdecreasing":
return LossNotDecreasing
elif rule_name_lower == "checkinputimages":
return CheckInputImages
elif rule_name_lower == "deadrelu":
return DeadRelu
elif rule_name_lower == "confusion":
Expand Down

0 comments on commit 6837964

Please sign in to comment.