From 5a8da2409bc721f8fdffff40969722eacd29cca0 Mon Sep 17 00:00:00 2001
From: cedrickchee
Date: Sat, 4 Nov 2017 18:03:56 +0800
Subject: [PATCH] Initial commit

---
 .gitignore       | 107 +++++++++++++++++++++++++
 CHANGELOG.md     |  12 +++
 LICENSE          |  34 ++++++++
 README.md        |  62 +++++++++++++++
 capsule_layer.py | 134 ++++++++++++++++++++++++++++++++
 conv_layer.py    |  29 +++++++
 main.py          | 198 +++++++++++++++++++++++++++++++++++++++++++++++
 model.py         |  86 ++++++++++++++++++++
 utils.py         |  57 ++++++++++++++
 9 files changed, 719 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 CHANGELOG.md
 create mode 100644 LICENSE
 create mode 100644 README.md
 create mode 100644 capsule_layer.py
 create mode 100644 conv_layer.py
 create mode 100644 main.py
 create mode 100644 model.py
 create mode 100644 utils.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..4dbc50d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,107 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+.static_storage/
+.media/
+local_settings.py
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# IDE settings for Visual Studio Code
+.vscode
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..ab45347
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,12 @@
+# Changelog
+All notable changes to this project will be documented in this file.
+
+The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
+and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
+
+## 0.0.1 - 2017-11-04
+### Added
+- Initial release. The first beta version: the API is stable and the code runs, so I think it's safe for development use, but it is not ready for general production usage.
+
+[Unreleased]: https://github.com/cedrickchee/capsule-net-pytorch/compare/v0.0.1...HEAD
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..e06c514
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,34 @@
+COPYRIGHT
+
+All contributions by Cedric Chee:
+Copyright (c) 2017, Cedric Chee.
+All rights reserved.
+
+All other contributions:
+Copyright (c) 2017, the respective contributors.
+All rights reserved.
+
+Each contributor holds copyright over their respective contributions.
+The project versioning (Git) records all such contribution source information.
+
+LICENSE
+
+The MIT License (MIT)
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..edc7759
--- /dev/null
+++ b/README.md
@@ -0,0 +1,62 @@
+# PyTorch CapsNet: Capsule Network for PyTorch
+
+[![license](https://img.shields.io/github/license/mashape/apistatus.svg?maxAge=2592000)](https://github.com/cedrickchee/capsule-net-pytorch/blob/master/LICENSE)
+![completion](https://img.shields.io/badge/completion%20state-90%25-green.svg?style=plastic)
+
+A PyTorch implementation of CapsNet (Capsule Network) based on this paper:
+[Sara Sabour, Nicholas Frosst, Geoffrey E Hinton. Dynamic Routing Between Capsules. NIPS 2017](https://arxiv.org/abs/1710.09829)
+
+The code comes with ample comments and Python docstrings.
+
+**Status and Latest Updates:**
+
+See the [CHANGELOG](CHANGELOG.md).
+
+**Datasets**
+
+The model was trained on the standard [MNIST](http://yann.lecun.com/exdb/mnist/) data.
+
+*Note: you don't have to manually download and process the MNIST dataset; PyTorch (torchvision) takes care of this step for you.*
+
+## Requirements
+- Python
+- [PyTorch](http://pytorch.org/)
+
+## Usage
+
+### Training and Evaluation
+**Step 1.**
+Clone this repository with ``git``.
+
+```
+$ git clone https://github.com/cedrickchee/capsule-net-pytorch.git
+$ cd capsule-net-pytorch
+```
+
+**Step 2.**
+Start the training and evaluation:
+```
+$ python main.py
+```
+
+## Results
+Coming soon!
+
+- training loss
+![total_loss](internal/img/training/training_loss.png)
+
+![margin_loss](internal/img/training/margin_loss.png)
+![reconstruction_loss](internal/img/training/reconstruction_loss.png)
+
+- evaluation accuracy
+![test_img1](internal/img/evaluation/test_000.png)
+
+**TODO**
+- [WIP] Publish results.
+- [WIP] More testing.
+- Separate training and evaluation into independent commands.
+- Jupyter Notebook version.
+- Create a sample to show how we can apply CapsNet to a real-world application.
+- Experiment with CapsNet:
+    * Try using another dataset.
+    * Come up with a more creative model structure.
diff --git a/capsule_layer.py b/capsule_layer.py
new file mode 100644
index 0000000..1c4ff6d
--- /dev/null
+++ b/capsule_layer.py
@@ -0,0 +1,134 @@
+"""Capsule layer
+
+PyTorch implementation of CapsNet in Sabour, Hinton et al.'s paper
+Dynamic Routing Between Capsules. NIPS 2017.
+https://arxiv.org/abs/1710.09829
+
+Author: Cedric Chee
+"""
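+
+# Shape walkthrough for the default MNIST setup (batch size B), as a reading
+# aid. These numbers follow the paper and the defaults wired up in main.py
+# and model.py:
+#   input image:        (B, 1, 28, 28)
+#   after Conv1:        (B, 256, 20, 20)  -- 9x9 conv, stride 1
+#   after PrimaryCaps:  (B, 8, 1152)      -- 8 conv units, each 32x6x6 = 1152
+#   after DigitCaps:    (B, 10, 16, 1)    -- one 16-D capsule per digit class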
+
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+import torch.nn.functional as F
+
+
+class CapsuleLayer(nn.Module):
+    """
+    The core implementation of the idea of capsules
+    """
+    def __init__(self, in_unit, in_channel, num_unit, unit_size, use_routing):
+        super(CapsuleLayer, self).__init__()
+
+        self.in_unit = in_unit
+        self.in_channel = in_channel
+        self.num_unit = num_unit
+        self.use_routing = use_routing
+
+        if self.use_routing:
+            """
+            Based on the paper, a capsule layer whose inputs are themselves
+            capsules (DigitCaps) uses the routing algorithm, which relies on
+            this weight matrix, W_ij.
+            """
+            self.W = nn.Parameter(torch.randn(
+                1, in_channel, num_unit, unit_size, in_unit))
+        else:
+            """
+            According to the CapsNet architecture section in the paper,
+            routing happens only between two consecutive capsule layers
+            (e.g. PrimaryCapsules and DigitCaps). No routing is used between
+            Conv1 and PrimaryCapsules.
+
+            PrimaryCapsules is therefore composed of several convolutional
+            units. Implementation-wise, it uses ordinary convolutional
+            layers followed by a nonlinearity (squash).
+            """
+            def create_conv_unit(idx):
+                unit = nn.Conv2d(in_channels=in_channel,
+                                 out_channels=32,
+                                 kernel_size=9,
+                                 stride=2)
+                self.add_module("conv_unit" + str(idx), unit)
+                return unit
+
+            self.conv_units = [create_conv_unit(u) for u in range(self.num_unit)]
+
+    @staticmethod
+    def squash(sj):
+        """
+        Non-linear 'squashing' function.
+        This implements equation 1 from the paper.
+        """
+        # ||s_j||^2
+        sj_mag_sq = torch.sum(sj**2, dim=2, keepdim=True)
+        # ||s_j||
+        sj_mag = torch.sqrt(sj_mag_sq)
+        v_j = (sj_mag_sq / (1.0 + sj_mag_sq)) * (sj / sj_mag)
+        return v_j
+
+    def forward(self, x):
+        if self.use_routing:
+            return self.routing(x)
+        else:
+            return self.no_routing(x)
+
+    def routing(self, x):
+        """
+        Routing algorithm for capsule.
+
+        :return: vector output of capsule j
+        """
+        batch_size = x.size(0)
+
+        x = x.transpose(1, 2)
+        x = torch.stack([x] * self.num_unit, dim=2).unsqueeze(4)
+        W = torch.cat([self.W] * batch_size, dim=0)
+
+        # Transform inputs by weight matrix.
+        u_hat = torch.matmul(W, x)
+
+        # All the routing logits (b_ij in the paper) are initialized to zero.
+        # Create them on the same device as the input so the model also runs
+        # on the CPU.
+        b_ij = Variable(torch.zeros(1, self.in_channel, self.num_unit, 1))
+        if x.is_cuda:
+            b_ij = b_ij.cuda()
+
+        # The paper ("Capsules on MNIST" section) reports results for a
+        # CapsNet with 3 routing iterations.
+        num_iterations = 3
+
+        for iteration in range(num_iterations):
+            # Routing algorithm
+
+            # Calculate the routing (coupling) coefficients, c_ij, by taking
+            # a softmax of the routing logits (b_ij) over the output capsules
+            # (dim 2), as in the paper.
+            c_ij = F.softmax(b_ij, dim=2)
+            c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)
+
+            # Implement equation 2 in the paper.
+            # u_hat are the transformed (predicted) input vectors.
+            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
+
+            v_j = CapsuleLayer.squash(s_j)
+
+            v_j1 = torch.cat([v_j] * self.in_channel, dim=1)
+
+            u_vj1 = torch.matmul(u_hat.transpose(3, 4), v_j1).squeeze(
+                4).mean(dim=0, keepdim=True)
+
+            # Update the routing logits (b_ij) by the agreement.
+            b_ij = b_ij + u_vj1
+
+        return v_j.squeeze(1)
+
+    def no_routing(self, x):
+        """
+        Get output for each unit.
+        A unit's output has shape (batch, channels, height, width).
+
+        :return: vector output of capsule j
+        """
+        unit = [self.conv_units[i](x) for i in range(self.num_unit)]
+
+        # Stack all unit outputs.
+        unit = torch.stack(unit, dim=1)
+
+        # Flatten
+        unit = unit.view(x.size(0), self.num_unit, -1)
+
+        # Return squashed outputs.
+        return CapsuleLayer.squash(unit)
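+
+
+if __name__ == '__main__':
+    # Minimal CPU smoke test (illustrative only, not part of training): push
+    # a fake PrimaryCaps output through a routing capsule layer and check the
+    # output shape. The sizes below match the defaults wired up in model.py.
+    digit_caps = CapsuleLayer(in_unit=8, in_channel=1152, num_unit=10,
+                              unit_size=16, use_routing=True)
+    fake_primary_out = Variable(torch.randn(2, 8, 1152))
+    print(digit_caps(fake_primary_out).size())  # expected: (2, 10, 16, 1)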
diff --git a/conv_layer.py b/conv_layer.py
new file mode 100644
index 0000000..702e1dd
--- /dev/null
+++ b/conv_layer.py
@@ -0,0 +1,29 @@
+"""Convolutional layer
+
+PyTorch implementation of CapsNet in Sabour, Hinton et al.'s paper
+Dynamic Routing Between Capsules. NIPS 2017.
+https://arxiv.org/abs/1710.09829
+
+Author: Cedric Chee
+"""
+
+import torch
+import torch.nn as nn
+
+
+class ConvLayer(nn.Module):
+    """
+    Conventional convolutional layer (Conv1 in the paper):
+    a convolution followed by ReLU.
+    """
+    def __init__(self, in_channel, out_channel, kernel_size):
+        super(ConvLayer, self).__init__()
+
+        self.conv0 = nn.Conv2d(in_channels=in_channel,
+                               out_channels=out_channel,
+                               kernel_size=kernel_size,
+                               stride=1)
+
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        """Forward pass"""
+        out_conv0 = self.conv0(x)
+        out_relu = self.relu(out_conv0)
+        return out_relu
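+
+
+if __name__ == '__main__':
+    # Quick shape check (illustrative only): one fake 28x28 MNIST image
+    # through Conv1 should yield 256 feature maps of size 20x20.
+    from torch.autograd import Variable
+    conv1 = ConvLayer(in_channel=1, out_channel=256, kernel_size=9)
+    image = Variable(torch.randn(1, 1, 28, 28))
+    print(conv1(image).size())  # expected: (1, 256, 20, 20)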
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..52c3cf8
--- /dev/null
+++ b/main.py
@@ -0,0 +1,198 @@
+"""
+PyTorch implementation of CapsNet in Sabour, Hinton et al.'s paper
+Dynamic Routing Between Capsules. NIPS 2017.
+https://arxiv.org/abs/1710.09829
+
+Usage:
+    python main.py
+    python main.py --epochs 50
+    python main.py --epochs 50 --loss-threshold 0.0001
+
+Author: Cedric Chee
+"""
+
+from __future__ import print_function
+import argparse
+import sys
+import time
+
+import torch
+import torch.optim as optim
+from torch.autograd import Variable
+
+import utils
+from model import Net
+
+
+def train(model, data_loader, optimizer, epoch):
+    """Train CapsuleNet model on training set
+
+    :param model: The CapsuleNet model
+    :param data_loader: An iterator over the dataset; it combines a dataset and a sampler
+    :param optimizer: Optimization algorithm
+    :param epoch: Current epoch
+    :return: Loss of the last processed batch
+    """
+    print('===> Training mode')
+
+    last_loss = None
+
+    # Switch to train mode
+    model.train()
+
+    for batch_idx, (data, target) in enumerate(data_loader):
+        target_one_hot = utils.one_hot_encode(
+            target, length=model.digits.num_unit)
+
+        data, target = Variable(data), Variable(target_one_hot)
+
+        if args.cuda:
+            data = data.cuda()
+            target = target.cuda()
+
+        optimizer.zero_grad()
+        output = model(data)
+        loss = model.loss(output, target)
+        loss.backward()
+        last_loss = loss.data[0]
+        optimizer.step()
+
+        if batch_idx % args.log_interval == 0:
+            mesg = '{}\tEpoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\n'.format(
+                time.ctime(),
+                epoch,
+                batch_idx * len(data),
+                len(data_loader.dataset),
+                100. * batch_idx / len(data_loader),
+                loss.data[0])
+
+            print(mesg)
+
+        if last_loss < args.loss_threshold:
+            # Stop training early once the loss falls below the threshold.
+            break
+
+    return last_loss
+
+
+def test(model, data_loader):
+    """Evaluate model on validation set
+    """
+    print('===> Evaluate mode')
+
+    # Switch to evaluate mode
+    model.eval()
+
+    test_loss = 0
+    correct = 0
+    for data, target in data_loader:
+        target_indices = target
+        target_one_hot = utils.one_hot_encode(
+            target_indices, length=model.digits.num_unit)
+
+        data, target = Variable(data, volatile=True), Variable(target_one_hot)
+
+        if args.cuda:
+            data = data.cuda()
+            target = target.cuda()
+
+        output = model(data)
+
+        # Sum up batch loss.
+        test_loss += model.loss(output, target, size_average=False).data[0]
+
+        # Evaluate: the class whose DigitCaps output vector is longest is
+        # the prediction.
+        v_magnitude = torch.sqrt((output**2).sum(dim=2, keepdim=True))
+        pred = v_magnitude.data.max(1, keepdim=True)[1].cpu()
+        correct += pred.eq(target_indices.view_as(pred)).sum()
+
+    test_loss /= len(data_loader.dataset)
+
+    mesg = 'Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
+        test_loss,
+        correct,
+        len(data_loader.dataset),
+        100. * correct / len(data_loader.dataset))
+    print(mesg)
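+
+# NOTE (sketch): separating training and evaluation into independent commands
+# is still on the TODO list in the README. In the meantime, a checkpoint
+# written by utils.checkpoint() could be evaluated along these lines
+# (illustrative only, assuming a checkpoint from epoch 10 exists):
+#
+#     state = torch.load('model_epoch_10.pth')
+#     model.load_state_dict(state['state_dict'])
+#     test(model, test_loader)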
+
+
+def main():
+    """The main function
+    Entry point.
+    """
+    global args
+
+    # Setting the hyper parameters
+    parser = argparse.ArgumentParser(description='Example of Capsule Network')
+    parser.add_argument('--epochs', type=int, default=10,
+                        help='number of training epochs. default=10')
+    parser.add_argument('--lr', type=float, default=0.01,
+                        help='learning rate. default=0.01')
+    parser.add_argument('--batch-size', type=int, default=128,
+                        help='training batch size. default=128')
+    parser.add_argument('--test-batch-size', type=int,
+                        default=128, help='testing batch size. default=128')
+    parser.add_argument('--loss-threshold', type=float, default=0.0001,
+                        help='stop training if loss goes below this threshold. default=0.0001')
+    parser.add_argument("--log-interval", type=int, default=1,
+                        help='number of batches between logging the training loss. default=1')
+    parser.add_argument('--cuda', action='store_true',
+                        help='enable training on GPU with CUDA; omit to run on CPU')
+    parser.add_argument('--threads', type=int, default=4,
+                        help='number of threads for data loader to use')
+    parser.add_argument('--seed', type=int, default=42,
+                        help='random seed for training. default=42')
+    parser.add_argument('--num-conv-channel', type=int, default=256,
+                        help='number of convolutional channels. default=256')
+    parser.add_argument('--num-primary-unit', type=int, default=8,
+                        help='number of primary units. default=8')
+    parser.add_argument('--primary-unit-size', type=int,
+                        default=1152, help='primary unit size. default=1152')
+    parser.add_argument('--output-unit-size', type=int,
+                        default=16, help='output unit size. default=16')
+
+    args = parser.parse_args()
+
+    print(args)
+
+    # Check that CUDA is available when requested.
+    cuda = args.cuda
+    if cuda and not torch.cuda.is_available():
+        print("ERROR: CUDA is not available. Run without --cuda to train on the CPU.")
+        sys.exit(1)
+
+    torch.manual_seed(args.seed)
+    if cuda:
+        torch.cuda.manual_seed(args.seed)
+
+    # Load data
+    train_loader, test_loader = utils.load_mnist(args)
+
+    # Build Capsule Network
+    print('===> Building model')
+    model = Net(num_conv_channel=args.num_conv_channel,
+                num_primary_unit=args.num_primary_unit,
+                primary_unit_size=args.primary_unit_size,
+                output_unit_size=args.output_unit_size)
+
+    if cuda:
+        model = model.cuda()
+
+    optimizer = optim.Adam(model.parameters(), lr=args.lr)
+
+    # Train and test
+    for epoch in range(1, args.epochs + 1):
+        previous_loss = train(model, train_loader, optimizer, epoch)
+        test(model, test_loader)
+        utils.checkpoint({
+            'epoch': epoch + 1,
+            'state_dict': model.state_dict(),
+            'optimizer': optimizer.state_dict()
+        }, epoch)
+
+        if previous_loss < args.loss_threshold:
+            break
+
+
+if __name__ == "__main__":
+    main()
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..ed4cea4
--- /dev/null
+++ b/model.py
@@ -0,0 +1,86 @@
+"""CapsNet Architecture
+
+PyTorch implementation of CapsNet in Sabour, Hinton et al.'s paper
+Dynamic Routing Between Capsules. NIPS 2017.
+https://arxiv.org/abs/1710.09829
+
+Author: Cedric Chee
+"""
+
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+
+from conv_layer import ConvLayer
+from capsule_layer import CapsuleLayer
+
+
+class Net(nn.Module):
+    """
+    A simple CapsNet with 3 layers
+    """
+    def __init__(self, num_conv_channel, num_primary_unit, primary_unit_size, output_unit_size):
+        """
+        In the constructor we instantiate one ConvLayer module and two CapsuleLayer modules
+        and assign them as member variables.
+        """
+        super(Net, self).__init__()
+
+        self.conv1 = ConvLayer(in_channel=1,
+                               out_channel=num_conv_channel,
+                               kernel_size=9)
+
+        # PrimaryCaps (in_unit is unused when use_routing=False)
+        self.primary = CapsuleLayer(in_unit=0,
+                                    in_channel=num_conv_channel,
+                                    num_unit=num_primary_unit,
+                                    unit_size=primary_unit_size,
+                                    use_routing=False)
+
+        # DigitCaps
+        self.digits = CapsuleLayer(in_unit=num_primary_unit,
+                                   in_channel=primary_unit_size,
+                                   num_unit=10,
+                                   unit_size=output_unit_size,
+                                   use_routing=True)
+
+    def forward(self, x):
+        """
+        Defines the computation performed at every forward pass.
+        """
+        out_conv1 = self.conv1(x)
+        out_primary_caps = self.primary(out_conv1)
+        out_digit_caps = self.digits(out_primary_caps)
+        return out_digit_caps
+
+    def loss(self, input, target, size_average=True):
+        """Custom loss function"""
+        m_loss = self.margin_loss(input, target, size_average)
+        return m_loss
+
+    def margin_loss(self, input, target, size_average=True):
+        """Margin loss for digit existence
+        """
+        batch_size = input.size(0)
+
+        # Implement equation 4 in the paper.
+
+        # ||v_c||
+        v_c = torch.sqrt((input**2).sum(dim=2, keepdim=True))
+
+        # Calculate the left and right max() terms; both are squared in
+        # equation 4. Create the zero tensor on the same device as the input
+        # so the loss also works on the CPU.
+        zero = Variable(torch.zeros(1))
+        if input.is_cuda:
+            zero = zero.cuda()
+        m_plus = 0.9
+        m_minus = 0.1
+        loss_lambda = 0.5
+        max_left = torch.max(m_plus - v_c, zero).view(batch_size, -1) ** 2
+        max_right = torch.max(v_c - m_minus, zero).view(batch_size, -1) ** 2
+        t_c = target
+        # L_c is the margin loss for each digit of class c.
+        l_c = t_c * max_left + loss_lambda * (1.0 - t_c) * max_right
+        l_c = l_c.sum(dim=1)
+
+        if size_average:
+            l_c = l_c.mean()
+
+        return l_c
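+
+
+if __name__ == '__main__':
+    # Illustrative CPU smoke test (not part of the training pipeline): one
+    # fake MNIST batch through the whole network, then the margin loss
+    # against dummy one-hot targets, using the defaults from main.py.
+    net = Net(num_conv_channel=256, num_primary_unit=8,
+              primary_unit_size=1152, output_unit_size=16)
+    images = Variable(torch.randn(2, 1, 28, 28))
+    labels = torch.zeros(2, 10)
+    labels[0, 3] = 1.0
+    labels[1, 7] = 1.0
+    output = net(images)  # shape: (2, 10, 16, 1)
+    print(net.loss(output, Variable(labels)).data[0])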
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..a4c27b2
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,57 @@
+"""Utilities
+
+PyTorch implementation of CapsNet in Sabour, Hinton et al.'s paper
+Dynamic Routing Between Capsules. NIPS 2017.
+https://arxiv.org/abs/1710.09829
+
+Author: Cedric Chee
+"""
+
+import torch
+from torchvision import transforms, datasets
+from torch.utils.data import DataLoader
+
+
+def normalize_dataset():
+    """Return the transform that normalizes the MNIST dataset."""
+    return transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.1307,), (0.3081,))
+    ])
+
+
+def one_hot_encode(target, length):
+    """Converts batches of class indices to classes of one-hot vectors."""
+    batch_s = target.size(0)
+    one_hot_vec = torch.zeros(batch_s, length)
+
+    for i in range(batch_s):
+        one_hot_vec[i, target[i]] = 1.0
+
+    return one_hot_vec
+
+
+def checkpoint(state, epoch):
+    """Save checkpoint"""
+    model_out_path = 'model_epoch_{}.pth'.format(epoch)
+    torch.save(state, model_out_path)
+    print('Checkpoint saved to {}'.format(model_out_path))
+
+
+def load_mnist(args):
+    """Load MNIST dataset.
+    The train and test sets are normalized with the same statistics.
+    """
+    print('===> Loading training datasets')
+    # Note: normalize_dataset must be called -- the dataset expects the
+    # transform object itself, not the factory function.
+    training_set = datasets.MNIST(
+        '../data', train=True, download=True, transform=normalize_dataset())
+    training_data_loader = DataLoader(
+        training_set, num_workers=args.threads, batch_size=args.batch_size, shuffle=True)
+
+    print('===> Loading testing datasets')
+    testing_set = datasets.MNIST(
+        '../data', train=False, download=True, transform=normalize_dataset())
+    testing_data_loader = DataLoader(
+        testing_set, num_workers=args.threads, batch_size=args.test_batch_size, shuffle=True)
+
+    return training_data_loader, testing_data_loader
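+
+
+if __name__ == '__main__':
+    # Tiny sanity check (illustrative only): one-hot encode a batch of three
+    # labels into length-10 vectors.
+    labels = torch.LongTensor([0, 4, 9])
+    print(one_hot_encode(labels, length=10))
+    # expected: a 3x10 matrix with ones at columns 0, 4 and 9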