
Implementation of the MNASNet family of models #829

Merged: 23 commits, Jun 24, 2019

Conversation

@1e100 (Contributor) commented Apr 2, 2019

As described in https://arxiv.org/pdf/1807.11626.pdf

The training program and detailed training notes are posted separately, as requested by @fmassa. ImageNet results for depth multiplier 1.0 are as follows:

  • Top-1 = 73.512
  • Top-5 = 91.544

The paper's top-1 is 74%. The best result was achieved at epoch 198 out of 200, so it seems likely that the model would benefit from more epochs.

Trainer program, checkpoint, training log, and tensorboard files can be found here: https://github.com/1e100/mnasnet_trainer/tree/master

@1e100 marked this pull request as ready for review on April 2, 2019 at 08:04
@1e100 (Contributor, Author) commented Apr 5, 2019

MNASNet 0.5 result:

Top-1: 67.006% (paper number: 67.8%)
Top-5: 86.960%

Hyperparameters:

  • 250 epochs (then 50 more)
  • Peak LR = 0.8
  • Warmup: 5 epochs
  • No weight decay
  • SGD + Nesterov
  • Momentum 0.9

I can post the checkpoint, training log, and tensorboard dump.
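
For concreteness, those settings might map to PyTorch roughly as follows (a minimal sketch; the linear warmup and the decay shape after warmup are assumptions here, see the posted training log for the real schedule):

import math
import torch

# Sketch of the hyperparameters above. Only the optimizer, momentum, peak LR,
# warmup length, and epoch count are pinned down by the list; the rest is assumed.
model = torch.nn.Linear(10, 10)  # stand-in for the MNASNet being trained

optimizer = torch.optim.SGD(model.parameters(), lr=0.8, momentum=0.9,
                            weight_decay=0.0, nesterov=True)

def lr_lambda(epoch, warmup_epochs=5, total_epochs=250):
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs  # ramp linearly up to the peak LR
    t = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return 0.5 * (1.0 + math.cos(math.pi * t))  # assumed cosine decay shape

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)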

@depthwise

Seems like the linter failed on the existing models, not on the new files.

@fmassa (Member) left a comment

Hi,

Thanks for the PR!

I have left a couple of comments.
Also, did you try training any of those models yet?

"""

def __init__(self, num_classes, alpha, dropout=0.2):
super().__init__()
Member:
this is Python 3-only.
Can you replace it with super(MNASNet, self).__init__()?
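
For reference, a minimal sketch of the requested change (only the relevant lines shown; the rest of the class body is elided):

import torch.nn as nn

class MNASNet(nn.Module):

    def __init__(self, num_classes, alpha, dropout=0.2):
        # Python 2/3-compatible form, instead of the Python 3-only super().__init__()
        super(MNASNet, self).__init__()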

Contributor Author:

Done

                m.bias.data.zero_()


class MNASNet0_5(MNASNet):
Member:

What about having mnasnet0_5 etc as functions, instead of classes, in the same way as what we have for the other models?
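
For reference, the factory-function pattern used by the other models looks roughly like this (a sketch; it assumes the MNASNet class and _MODEL_URLS table defined in this file, and uses torch.hub's load_state_dict_from_url):

import torch.hub

def mnasnet0_5(pretrained=False, progress=True, **kwargs):
    # Build an MNASNet with depth multiplier 0.5; the constructor signature
    # follows the snippet above, the URL lookup is an assumption.
    model = MNASNet(num_classes=1000, alpha=0.5, **kwargs)
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            _MODEL_URLS["mnasnet0_5"], progress=progress)
        model.load_state_dict(state_dict)
    return model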

Contributor Author:

Done

@codecov-io commented Apr 13, 2019

Codecov Report

Merging #829 into master will increase coverage by 0.42%.
The diff coverage is 82.92%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #829      +/-   ##
==========================================
+ Coverage   60.26%   60.69%   +0.42%     
==========================================
  Files          63       64       +1     
  Lines        5001     5083      +82     
  Branches      745      758      +13     
==========================================
+ Hits         3014     3085      +71     
- Misses       1784     1791       +7     
- Partials      203      207       +4
Impacted Files                          Coverage Δ
torchvision/models/__init__.py          100% <100%> (ø) ⬆️
torchvision/models/mnasnet.py           82.71% <82.71%> (ø)
torchvision/transforms/transforms.py    82.54% <0%> (+0.64%) ⬆️

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 76b2667...c34df87.

@1e100 (Contributor, Author) commented Apr 13, 2019

Most of my attention was focused on 0.5 and 1.0 depth multipliers. The former is a good "fast" configuration, the latter is a good "precise" configuration.

I have also trained 0.75 once (current best top-1 is 70.31, achieved on the first try; the paper number is 71.5), and 1.3 is still training with 65 more epochs to go, as it takes a lot longer to train.

Based on my experience with 0.5 and 1.0, I think I just need to find better training hyperparameters and/or train for longer. I'm fundamentally limited by having only 8 consumer-grade GPUs at my disposal.

@rwightman (Contributor):

This model looks very promising, excited to try it out. Thanks for the impl. A few comments:

  1. The layout of the model changes based on whether dropout is 0 or > 0. This is problematic, not just because people often expect to be able to swap/reference the classifier by name as a Linear or 1x1 Conv2d layer, but also because it will prevent models from loading correctly with different settings for dropout. Also, the drop rate isn't passed through to the actual Dropout module.

  2. Less important, and more a matter of style and preference: when working with models it's definitely nice to leverage the advantages of PyTorch and build them with class modules as building blocks. Pushing towards the old Torch all-nn.Sequential style can make code compact, but it definitely has drawbacks when debugging, modifying, or working with the weights.

Compare a snippet of the weight names of this model:

layers.11.1.layers.4.weight
layers.11.1.layers.4.bias
layers.11.1.layers.6.weight
layers.11.1.layers.7.weight
layers.11.1.layers.7.bias
layers.12.0.layers.0.weight
layers.12.0.layers.1.weight
layers.12.0.layers.1.bias
layers.12.0.layers.3.weight
layers.12.0.layers.4.weight

to torchvision ResNet:

layer1.2.bn3.weight
layer1.2.bn3.bias
layer2.0.conv1.weight
layer2.0.bn1.weight
layer2.0.bn1.bias
layer2.0.conv2.weight
layer2.0.bn2.weight
layer2.0.bn2.bias
layer2.0.conv3.weight
layer2.0.bn3.weight
layer2.0.bn3.bias
layer2.0.downsample.0.weight
layer2.0.downsample.1.weight
layer2.0.downsample.1.bias
layer2.1.conv1.weight
layer2.1.bn1.weight
layer2.1.bn1.bias

@1e100 (Contributor, Author) commented Apr 14, 2019

I don't believe the presence or absence of dropout will prevent the model from loading, as dropout doesn't actually store anything in the checkpoint; in PyTorch, the dropout probability is defined programmatically. If I look at the model dict in the checkpoint, I can't see a parameter for dropout in there even when dropout is present in the model.

The "classifier" module stores the following in the checkpoint:

  • 'classifier.1.weight'
  • 'classifier.1.bias'

Even though dropout has non-zero probability.

@rwightman (Contributor):

> I don't believe the presence or absence of dropout will prevent the model from loading, as dropout doesn't actually store anything in the checkpoint; in PyTorch, the dropout probability is defined programmatically. If I look at the model dict in the checkpoint, I can't see a parameter for dropout in there even when dropout is present in the model.
>
> The "classifier" module stores the following in the checkpoint:
>
>   • 'classifier.1.weight'
>   • 'classifier.1.bias'
>
> Even though dropout has non-zero probability.

If you set dropout to 0.0 and try to load weights created with dropout > 0.0, you'll get an exception; the extra Sequential container creates another level in the parameter keys.

RuntimeError: Error(s) in loading state_dict for MNASNet:
Missing key(s) in state_dict: "classifier.weight", "classifier.bias".
Unexpected key(s) in state_dict: "classifier.1.weight", "classifier.1.bias".
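
A minimal sketch of the mismatch: wrapping the Linear in a Sequential (to add dropout) shifts every classifier key by one level.

import torch.nn as nn

# Dropout has no parameters, so the Linear lands at index 1 of the Sequential
# and its keys pick up the extra '1.' prefix.
with_dropout = nn.Sequential(nn.Dropout(0.2), nn.Linear(1280, 1000))
plain = nn.Linear(1280, 1000)

print(list(with_dropout.state_dict().keys()))  # ['1.weight', '1.bias']
print(list(plain.state_dict().keys()))         # ['weight', 'bias']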

@1e100 (Contributor, Author) commented Apr 14, 2019

I see what you mean, but the model would not work correctly in this case even if loaded, because the checkpoint would be loaded into a model that does not match it. I think I'd prefer the model to barf on load in such a case.

So I'm kind of reluctant to address this. @fmassa, what do you think?

@1e100 (Contributor, Author) commented Apr 14, 2019

Looks like PyTorch scales the activations during training (inverted dropout), so dropout can simply be omitted during inference; the model would, in fact, work correctly even if loaded with a different dropout probability.

To simplify things I just removed the condition for the dropout. It'll be identity if p is zero anyway.
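
A quick check of that behavior (a minimal sketch):

import torch
import torch.nn as nn

# With p=0, Dropout is a no-op even in train mode; in eval mode it is a no-op
# for any p, because PyTorch scales by 1/(1-p) at train time (inverted dropout),
# so no rescaling is needed at inference.
x = torch.randn(2, 4)
drop = nn.Dropout(p=0.0)
assert torch.allclose(drop(x), x)       # identity while training, since p=0
drop_eval = nn.Dropout(p=0.5).eval()
assert torch.allclose(drop_eval(x), x)  # identity in eval mode for any p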

@rwightman (Contributor) commented Apr 14, 2019

I'd still suggest leaving the classifier as nn.Linear and pulling the dropout out as an optional functional call. It's the most common approach, and it's nice not to astonish users; it also makes the model easier to work with. The current checkpoint can easily be remapped with a load/key-map/save.
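
A minimal sketch of what I mean (the class and helper names here are just illustrative, not the PR's actual code): keep the classifier as a bare nn.Linear and apply dropout functionally, so the state_dict keys do not depend on the dropout setting.

import torch.nn as nn
import torch.nn.functional as F

class Head(nn.Module):
    def __init__(self, in_features=1280, num_classes=1000, dropout=0.2):
        super(Head, self).__init__()
        self.dropout = dropout
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x):
        # Functional dropout: active only in train mode, absent from state_dict.
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.classifier(x)

def remap_classifier_keys(state_dict):
    # One-off remap of an older checkpoint: 'classifier.1.*' -> 'classifier.*'.
    return {k.replace("classifier.1.", "classifier."): v
            for k, v in state_dict.items()}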

@1e100 (Contributor, Author) commented Apr 15, 2019

@rwightman PTAL. This is more consistent with the other models in torchvision; thanks for the suggestion. I still need to update my published checkpoint, I didn't have the time to do it today.

@fmassa (Member) left a comment

I have a few more comments, in particular about the similarities with MobileNet V2.

@1e100 (Contributor, Author) commented Apr 18, 2019

Trained up another MNASNet 0.5 with the changes (and with more epochs), with substantially better results:

loss=1.367, metrics = prec1=67.484, prec5=87.482

The paper number for top-1 is 67.8. I will upload the checkpoint and training log to the trainer repo. MNASNet 1.0 is still training.

@1e100 (Contributor, Author) commented Apr 29, 2019

@fmassa is there anything else I need to change? I think I resolved everything, but I'm not sure you've seen it yet. PTAL. If there's anything you're not satisfied with, I have time this week to address it.

@fmassa (Member) left a comment

Hi,

Sorry for the delay in reviewing it.

I have a few comments; let me know what you think.

@1e100 (Contributor, Author) commented May 6, 2019

OK, @fmassa, I think I have addressed all of your suggestions other than initialization, and uploaded a better MNASNet 0.5 checkpoint.

Regarding initialization: these low-compute nets are more sensitive to it, and I'm not actually sure that using the more off-the-shelf initializers will result in the same accuracies. Just to test that out, I'm kicking off a training run on 0.5 with only the initialization changed. It'll take about 2.5 days.

Let's get the rest of this ironed out in the meanwhile.

@fmassa (Member) left a comment

I think this is looking very good, thanks!

Just a couple more comments, and then this is good to go.

@1e100 (Contributor, Author) commented May 12, 2019

Hi @fmassa. I've addressed your feedback. PTAL

@1e100 (Contributor, Author) commented May 21, 2019

Merged master. Good time to merge. :-)

@1e100 (Contributor, Author) commented May 23, 2019

@fmassa, is there anything I need to change here?

@1e100 (Contributor, Author) commented Jun 20, 2019

@fmassa, this has been sitting in the queue long enough that I'm considering withdrawing the PR. If this is not going to be merged, please let me know. If I don't hear back within a week, I'll close the PR.

@fmassa (Member) commented Jun 24, 2019

I'm very sorry for the delay in replying, I was pretty busy with the 0.3 release of torchvision, and some follow-up tasks.

I'm reviewing it today.

@fmassa (Member) left a comment

I have an update of the url, and I also included the progress option, as was done for the datasets.

Also, do you have an idea on how much the current LR scheduler that you used is important to match your results?

I'm merging this as is instead of pushing commits to your branch: the PR is on your master branch, and I might need to apply some extra changes, which could affect your branch, so to be safe I'm merging it as is.

Thanks a lot, and sorry for the delay in merging it!

Network             Top-1 error   Top-5 error
ShuffleNet V2       30.64         11.68
MobileNet V2        28.12         9.71
ResNeXt-50-32x4d    22.38         6.30
ResNeXt-101-32x8d   20.69         5.47
MNASNet 1.0         26.49         8.456
@fmassa (Member):

Quick question:

I've tried testing the models with the references/train.py file, and got

 * Acc@1 73.456 Acc@5 91.510

which corresponds to error rates of 26.54 and 8.49.

I wonder if there is a difference in our data, or if the model that I downloaded is not the same?

@depthwise:

I'll verify and report back (and train a better model if needed). Could be something as dumb as the wrong version of Pillow. To aid with investigation, please answer the following questions:

  • Are you using Pillow-SIMD?
  • Are you compiling it from source?
  • Are you using libjpeg-turbo8?

@depthwise commented Jun 24, 2019:

And also, did you change the number of epochs when training? I think that particular model actually used more epochs. (I'm 1e100, just under another account.)

@rwightman (Contributor):

FWIW, I trained this 'b1' variant of MNASNet to 74.66 and the 'a1' SE variant to 75.45. Both took over 400 epochs, using RMSprop with Google-like hyperparameters. Using an EMA of the model weights was necessary to match/surpass the papers.

Models only: https://github.com/rwightman/gen-efficientnet-pytorch#pretrained

Training code w/ models: https://github.com/rwightman/pytorch-image-models

@fmassa (Member):

@depthwise I don't think that Pillow / the different image processing is the cause.
I've done extensive experiments in the past with multiple models, and they were all fairly insensitive to using PIL / OpenCV. If that is indeed the cause here, I'd be surprised; it would indicate that the model is very fragile to small perturbations.

For completeness, I used Pillow 5.4.1, from pip, with libjpeg-turbo8.

@rwightman interesting. What happens if you do not apply EMA? Do you recall what accuracies you get?

Contributor Author:

So to investigate this I wrote a simple eval script, which I pushed to https://github.com/1e100/mnasnet_trainer/blob/master/eval.py.

The results with Pillow-SIMD/libjpeg-turbo-8 are as follows:

Dev/mnasnet_trainer % ./eval.py
Evaluating pretrained mnasnet1_0
1.0769143749256522 [('prec1', 73.490265), ('prec5', 91.53294)]
Evaluating pretrained mnasnet0_5
1.3720299355229553 [('prec1', 67.59815), ('prec5', 87.51842)]

Neither matches the published numbers exactly. MNASNet 1.0 is slightly worse than the doc says; MNASNet 0.5 is slightly better than the checkpoint name would imply (67.598% top-1 vs 67.592%).

The results with "plain" Pillow 6.0.0 from PyPI are as follows:

% ./eval.py
Evaluating pretrained mnasnet1_0
1.0772113243536072 [('prec1', 73.46037), ('prec5', 91.52099)]
Evaluating pretrained mnasnet0_5
1.372453243756781 [('prec1', 67.606026), ('prec5', 87.50845)]

So the top-1 for 1.0 gets a bit worse, and for 0.5 it gets a bit better. I've observed such sensitivity with other "efficient" models in the past. In particular, the resize algorithm (which is different in Pillow-SIMD) seems to make a noticeable difference; the effect on smaller models is easily measurable. Something as mundane as a different JPEG decoder affects them, and so do the software versions (CUDA/cuDNN, PyTorch, etc.), a number of which are different between then and now.

Just to be on the safe side, though, I kicked off another run for 1.0. I should have the results sometime over the weekend, one way or the other.
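
For context, Pillow enters the pipeline through the standard ImageNet eval preprocessing; a sketch (the exact transform used by the eval.py linked above lives in that repo and may differ slightly):

import torchvision.transforms as T

# De-facto standard ImageNet eval preprocessing. The Resize step is where the
# Pillow vs Pillow-SIMD resize implementations diverge.
preprocess = T.Compose([
    T.Resize(256),      # backend-sensitive resize
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])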

@fmassa (Member):

@1e100 if you are running the training again, would you mind using the code from references/classification, and maybe only changing the lr step?

For MobileNet V2, I used the following command-line flags on 8 GPUs:

--model mobilenet_v2 --epochs 300 --lr 0.045 --wd 0.00004 --lr-step-size 1 --lr-gamma 0.98

The most important thing we are trying to do here is to have a simple path for reproducible research, so having slightly worse accuracy (e.g., ~0.3-0.5%) but with a reproducible script available in torchvision would be preferable, I'd say.

This way, we can indeed compare apples with apples
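
For reference, a full launch with those flags might look like this (a sketch; the data path and the distributed-launcher invocation are assumptions, not the exact command used; see references/classification for details):

python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py \
    --data-path /path/to/imagenet \
    --model mobilenet_v2 --epochs 300 --lr 0.045 --wd 0.00004 \
    --lr-step-size 1 --lr-gamma 0.98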

Contributor Author:

@fmassa what version of Pillow should I use for this? I use custom-compiled Pillow-SIMD for my own training runs, but if we want things to be more repeatable, I could use the slower default install of Pillow-SIMD.

Here's how I compile it on my machines:

CC="cc -march=native -mtune=native -O3" pip3 install \
    --force-reinstall --user pillow-simd

@fmassa (Member):

@1e100 sorry, I missed your message.

You can use the current version that you have available; don't worry about potentially small differences.

Contributor Author:

Still training. The first training run wasn't too good (I tweaked the training regime, and this model is sensitive to that kind of thing), so I'm using it as initialization for the second run. I'll update (and send a PR) when I get a good result.


_MODEL_URLS = {
    "mnasnet0_5":
    "https://github.com/1e100/mnasnet_trainer/releases/download/v0.1/mnasnet0.5_top1_67.592-7c6cb539b9.pth",
@fmassa (Member):

I have a follow-up commit that updates this.

@fmassa merged commit 69b2857 into pytorch:master on Jun 24, 2019
@scb-vs5 commented Jan 10, 2020

Being new to CV, I just cannot find the code for the reinforcement learning described in the paper (the rewards or the way they are updated).
Could you give some tips about which part of the code executes the RNN controller?
Thank you so much!

@1e100 (Contributor, Author) commented Jan 10, 2020

@scb-vs5 There's no neural architecture search here, only the model, and only one kind of model (the "B" variant) at that. Unless you have access to immense amounts of spare GPU/TPU capacity like Google does, neural architecture search is not really feasible for you as an individual anyway. By my back-of-the-envelope estimate, it'd cost no less than $100K just to run the search for the model architecture, and that excludes the time and resources it'd take to build the infrastructure and get it running reliably. So while I could implement this, it'd be kind of pointless for me to do so, other than to educate. TorchVision, in my view, is more practically oriented.

@scb-vs5 commented Jan 11, 2020

Thank you for your reply. Then would it be recommended for an individual to start with NAS methods that use weight sharing or gradient-based policies, which speed up training and are more practically oriented, like ENAS or DARTS?

@fmassa (Member) commented Jan 13, 2020

@scb-vs5 I would recommend asking such questions in the PyTorch forums http://discuss.pytorch.org/

@azamatkhid:

Hello! Recently I tried to train mnasnet1_0 from scratch on the CIFAR-10 dataset (with some small augmentations), but the loss behavior was very strange, and the training accuracy does not exceed 10%.

@fmassa (Member) commented Apr 30, 2020

@azamatkhid I would recommend checking how the model was trained on ImageNet following the instructions from https://github.com/1e100/mnasnet_trainer

MNASNet might have a few training tricks that are not really standard.

@1e100 (Contributor, Author) commented Apr 30, 2020

@azamatkhid The training procedure there is tuned for large datasets like ImageNet. You may need to, for example, reduce the momentum for batch norm, or even leave it at its default value, for a small dataset like CIFAR. FWIW, I've been using this model quite a lot lately for object detection on smaller datasets with great success; I use a BN momentum of 0.01 on such datasets (i.e., 0.99 in TensorFlow terms). Try this and see what happens.

Beyond this, please share your training code on GitHub. It's difficult to see what's wrong with so little information; it could be a million different things.
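
A sketch of the BN-momentum suggestion (the helper name here is hypothetical):

import torch.nn as nn
import torchvision.models as models

def set_bn_momentum(model, momentum=0.01):
    # Set the momentum of every BatchNorm layer. PyTorch's momentum is the
    # weight given to the *new* batch statistics, so 0.01 here corresponds
    # to a decay of 0.99 in TensorFlow's convention.
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.momentum = momentum

net = models.mnasnet1_0(num_classes=10)  # e.g. a CIFAR-10 head
set_bn_momentum(net, momentum=0.01)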

@azamatkhid commented May 3, 2020

@fmassa and @1e100, thanks for your comments.
This is the repository with the code:
https://github.com/azamatkhid/mnasnet-pytorch
I will try to re-train with BN momentum 0.01 and let you guys know if it is successful.
In addition, I was wondering why there is no official MNASNet-A1 implementation in PyTorch?
