Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
cedrickchee committed Nov 4, 2017
0 parents commit 5a8da24
Show file tree
Hide file tree
Showing 9 changed files with 719 additions and 0 deletions.
107 changes: 107 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
.static_storage/
.media/
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# IDE settings for Visual Studio Code
.vscode
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Changelog
All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).

## 0.0.1 - 2017-11-04
### Added
- Initial release. The first beta version. API is stable. The code runs. So, I think it's safe to use for development but not ready for general production usage.

[Unreleased]: https://github.com/olivierlacan/keep-a-changelog/compare/v1.0.0...HEAD
[0.0.2]: https://github.com/cedrickchee/keep-a-changelog/compare/v0.0.1...v0.0.2
34 changes: 34 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
COPYRIGHT

All contributions by Cedric Chee:
Copyright (c) 2017, Cedric Chee.
All rights reserved.

All other contributions:
Copyright (c) 2017, the respective contributors.
All rights reserved.

Each contributor holds copyright over their respective contributions.
The project versioning (Git) records all such contribution source information.

LICENSE

The MIT License (MIT)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
62 changes: 62 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# PyTorch CapsNet: Capsule Network for PyTorch

[![license](https://img.shields.io/github/license/mashape/apistatus.svg?maxAge=2592000)](https://github.com/cedrickchee/capsule-net-pytorch/blob/master/LICENSE)
![completion](https://img.shields.io/badge/completion%20state-90%25-green.svg?style=plastic)

A PyTorch implementation of CapsNet (Capsule Network) based on this paper:
[Sara Sabour, Nicholas Frosst, Geoffrey E Hinton. Dynamic Routing Between Capsules. NIPS 2017](https://arxiv.org/abs/1710.09829)

Codes comes with ample comments and Python docstring.

**Status and Latest Updates:**

See the [CHANGELOG](CHANGELOG.md)

**Datasets**

The model was trained on the standard [MNIST](http://yann.lecun.com/exdb/mnist/) data.

*Note: you don't have to manually download and process the MNIST dataset as PyTorch will take care of this step for you.*

## Requirements
- Python
- [PyTorch](http://pytorch.org/)

## Usage

### Training and Evaluation
**Step 1.**
Clone this repository with ``git``.

```
$ git clone https://github.com/cedrickchee/capsule-net-pytorch.git
$ cd capsule-net-pytorch
```

**Step 2.**
Start the training and evaluation:
```
$ python main.py
```

## Results
Coming soon!

- training loss
![total_loss](internal/img/training/training_loss.png)

![margin_loss](internal/img/training/margin_loss.png)
![reconstruction_loss](internal/img/training/reconstruction_loss.png)

- evaluation accuracy
![test_img1](internal/img/evaluation/test_000.png)

**TODO**
- [WIP] Publish results.
- [WIP] More testing.
- Separate training and evaluation into independent command.
- Jupyter Notebook version
- Create a sample to show how we can apply CapsNet to real-world application.
- Experiment with CapsNet:
* Try using another dataset
* Come out a more creative model structure
134 changes: 134 additions & 0 deletions capsule_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""Capsule layer
PyTorch implementation of CapsNet in Sabour, Hinton et al.'s paper
Dynamic Routing Between Capsules. NIPS 2017.
https://arxiv.org/abs/1710.09829
Author: Cedric Chee
"""

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F


class CapsuleLayer(nn.Module):
"""
The core implementation of the idea of capsules
"""
def __init__(self, in_unit, in_channel, num_unit, unit_size, use_routing):
super(CapsuleLayer, self).__init__()

self.in_unit = in_unit
self.in_channel = in_channel
self.num_unit = num_unit
self.use_routing = use_routing

if self.use_routing:
"""
Based on the paper, DigitCaps which is capsule layer(s) with
capsule inputs use a routing algorithm that uses this weight matrix, Wij
"""
self.W = nn.Parameter(torch.randn(
1, in_channel, num_unit, unit_size, in_unit))
else:
"""
According to the CapsNet architecture section in the paper,
we have routing only between two consecutive capsule layers (e.g. PrimaryCapsules and DigitCaps).
No routing is used between Conv1 and PrimaryCapsules.
This means PrimaryCapsules is composed of several convolutional units.
So, implementation-wise, it uses normal convolutional layer with a nonlinearity (squash).
"""
def create_conv_unit(idx):
unit = nn.Conv2d(in_channels=in_channel,
out_channels=32,
kernel_size=9,
stride=2)
self.add_module("conv_unit" + str(idx), unit)
return unit

self.conv_units = [create_conv_unit(u) for u in range(self.num_unit)]

@staticmethod
def squash(sj):
"""
Non-linear 'squashing' function.
This implement equation 1 from the paper.
"""
sj_mag_sq = torch.sum(sj**2, dim=2, keepdim=True)
# ||sj ||
sj_mag = torch.sqrt(sj_mag_sq)
v_j = (sj_mag_sq / (1.0 + sj_mag_sq)) * (sj / sj_mag)
return v_j

def forward(self, x):
if self.use_routing:
return self.routing(x)
else:
return self.no_routing(x)

def routing(self, x):
"""
Routing algorithm for capsule.
:return: vector output of capsule j
"""
batch_size = x.size(0)

x = x.transpose(1, 2)
x = torch.stack([x] * self.num_unit, dim=2).unsqueeze(4)
W = torch.cat([self.W] * batch_size, dim=0)

# Transform inputs by weight matrix.
u_hat = torch.matmul(W, x)

# All the routing logits (b_ij in the paper) are initialized to zero.
b_ij = Variable(torch.zeros(
1, self.in_channel, self.num_unit, 1)).cuda()

# From the paper in the "Capsules on MNIST" section,
# the sample MNIST test reconstructions of a CapsNet with 3 routing iterations.
num_iterations = 3

for iteration in range(num_iterations):
# Routing algorithm

# Calculate routing or also known as coupling coefficients (c_ij).
c_ij = F.softmax(b_ij) # Convert routing logits (b_ij) to softmax.
c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)

# Implement equation 2 in the paper.
# u_hat is weighted inputs
s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)

v_j = CapsuleLayer.squash(s_j)

v_j1 = torch.cat([v_j] * self.in_channel, dim=1)

u_vj1 = torch.matmul(u_hat.transpose(3, 4), v_j1).squeeze(
4).mean(dim=0, keepdim=True)

# Update routing (b_ij)
b_ij = b_ij + u_vj1

return v_j.squeeze(1)

def no_routing(self, x):
"""
Get output for each unit.
A unit has batch, channels, height, width.
:return: vector output of capsule j
"""
unit = [self.conv_units[i](x) for i in range(self.num_unit)]

# Stack all unit outputs.
unit = torch.stack(unit, dim=1)

# Flatten
unit = unit.view(x.size(0), self.num_unit, -1)

# Return squashed outputs.
return CapsuleLayer.squash(unit)
29 changes: 29 additions & 0 deletions conv_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Convolutional layer
PyTorch implementation of CapsNet in Sabour, Hinton et al.'s paper
Dynamic Routing Between Capsules. NIPS 2017.
https://arxiv.org/abs/1710.09829
Author: Cedric Chee
"""

import torch
import torch.nn as nn


class ConvLayer(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size):
super(ConvLayer, self).__init__()

self.conv0 = nn.Conv2d(in_channels=in_channel,
out_channels=out_channel,
kernel_size=kernel_size,
stride=1)

self.relu = nn.ReLU(inplace=True)

def forward(self, x):
"""Forward pass"""
out_conv0 = self.conv0(x)
out_relu = self.relu(out_conv0)
return out_relu
Loading

0 comments on commit 5a8da24

Please sign in to comment.