Add caching mechanism for MCDropout #268

Merged · 15 commits · Jul 13, 2023
21 changes: 13 additions & 8 deletions baal/active/dataset/base.py
@@ -40,6 +40,7 @@ def __init__(
if last_active_steps == 0 or last_active_steps < -1:
raise ValueError("last_active_steps must be > 0 or -1 when disabled.")
self.last_active_steps = last_active_steps
self._indices_cache = (-1, None)

def get_indices_for_active_step(self) -> List[int]:
"""Returns the indices required for the active step.
@@ -49,14 +50,18 @@ def get_indices_for_active_step(self) -> List[int]:
Returns:
List of the selected indices for training.
"""
if self.last_active_steps == -1:
min_labelled_step = 0
else:
min_labelled_step = max(0, self.current_al_step - self.last_active_steps)

# we need to work with lists since arrow dataset is not compatible with np.int types!
indices = [indx for indx, val in enumerate(self.labelled_map) if val > min_labelled_step]
return indices
if (curr_al_step := self.current_al_step) != self._indices_cache[0]:
if self.last_active_steps == -1:
min_labelled_step = 0
else:
min_labelled_step = max(0, curr_al_step - self.last_active_steps)

# we need to work with lists since arrow dataset is not compatible with np.int types!
indices = [
indx for indx, val in enumerate(self.labelled_map) if val > min_labelled_step
]
self._indices_cache = (curr_al_step, indices)
return self._indices_cache[1]

def is_labelled(self, idx: int) -> bool:
"""Check if a datapoint is labelled."""
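To make the cached version above concrete, here is a minimal, self-contained sketch of the same pattern (`TinyActiveSet` and its fields are illustrative stand-ins, not Baal's `ActiveLearningDataset`): the indices are recomputed only when the active-learning step advances, and repeated calls within a step return the cached list.

```python
from typing import List


class TinyActiveSet:
    """Illustrative stand-in for the caching logic added above."""

    def __init__(self, labelled_map: List[int], last_active_steps: int = -1):
        # 0 = never labelled, k = labelled during active-learning step k.
        self.labelled_map = labelled_map
        self.last_active_steps = last_active_steps
        self._indices_cache = (-1, None)  # (al_step, cached indices)

    @property
    def current_al_step(self) -> int:
        return max(self.labelled_map)

    def get_indices_for_active_step(self) -> List[int]:
        if (curr_al_step := self.current_al_step) != self._indices_cache[0]:
            if self.last_active_steps == -1:
                min_labelled_step = 0
            else:
                min_labelled_step = max(0, curr_al_step - self.last_active_steps)
            indices = [i for i, val in enumerate(self.labelled_map) if val > min_labelled_step]
            self._indices_cache = (curr_al_step, indices)  # recomputed once per step
        return self._indices_cache[1]


active_set = TinyActiveSet(labelled_map=[0, 1, 2, 0, 2], last_active_steps=1)
first = active_set.get_indices_for_active_step()   # computed: [2, 4]
second = active_set.get_indices_for_active_step()  # cache hit: same list object
assert first == [2, 4] and first is second
```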
3 changes: 2 additions & 1 deletion baal/active/dataset/numpy.py
@@ -68,5 +68,6 @@ def label(self, index: Union[list, int], value: Optional[Any] = None) -> None:
if not isinstance(value, (list, tuple)):
value = [value]
indexes = self._pool_to_oracle_index(index)
active_step = self.current_al_step + 1
for index, val in zip_longest(indexes, value, fillvalue=None):
self.labelled_map[index] = 1
self.labelled_map[index] = active_step
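The one-line change above is what makes the `last_active_steps` filtering in `base.py` possible: `labelled_map` now records *when* an index was labelled rather than just whether it was. A small illustration with made-up values:

```python
# labelled_map after three labelling rounds:
# 0 = never labelled, k = labelled during active-learning step k.
labelled_map = [1, 0, 2, 3, 0, 3]

current_al_step = max(labelled_map)  # 3
last_active_steps = 2                # only keep the two most recent steps

min_labelled_step = max(0, current_al_step - last_active_steps)  # 1
recent = [i for i, step in enumerate(labelled_map) if step > min_labelled_step]
print(recent)  # [2, 3, 5] -> indices labelled at steps 2 and 3
```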
56 changes: 56 additions & 0 deletions baal/bayesian/caching_utils.py
@@ -0,0 +1,56 @@
from typing import Optional

import torch
from torch import nn, Tensor

from baal.bayesian.common import BayesianModule, _patching_wrapper


class LRUCacheModule(nn.Module):
def __init__(self, module, size=1):
super().__init__()
if size != 1:
raise ValueError("We do not support LRUCache bigger than 1.")
self.module = module
self._memory_input = None
self._memory_output = None

def _is_cache_void(self, x):
return self._memory_input is None or not torch.equal(self._memory_input, x)

def __call__(self, x: Tensor):
if self.training:
return self.module(x)
if self._is_cache_void(x):
self._memory_input = x
self._memory_output = self.module(x)
return self._memory_output


def _caching_mapping_fn(module: torch.nn.Module) -> Optional[nn.Module]:
new_module: Optional[nn.Module] = None
# Could add more
if isinstance(module, (nn.Linear, nn.Conv2d)):
new_module = LRUCacheModule(module=module)
return new_module


def _caching_unmapping_fn(module: torch.nn.Module) -> Optional[nn.Module]:
new_module: Optional[nn.Module] = None

if isinstance(module, LRUCacheModule):
new_module = module.module
return new_module


def patch_module(module: torch.nn.Module, inplace: bool = True) -> torch.nn.Module:
return _patching_wrapper(module, inplace=inplace, patching_fn=_caching_mapping_fn)


def unpatch_module(module: torch.nn.Module, inplace: bool = True) -> torch.nn.Module:
return _patching_wrapper(module, inplace=inplace, patching_fn=_caching_unmapping_fn)


class MCCachingModule(BayesianModule):
patching_function = patch_module
unpatch_function = unpatch_module
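A usage sketch of the module above (the toy `Sequential` is illustrative): in eval mode, every wrapped `nn.Linear` or `nn.Conv2d` returns its stored output whenever it receives a tensor equal to the previous one, so a second forward pass over the same batch is nearly free.

```python
import torch
from torch import nn

from baal.bayesian.caching_utils import MCCachingModule

# .eval() matters: LRUCacheModule only serves cached results outside training mode.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).eval()
x = torch.rand(8, 16)

with MCCachingModule(model) as cached_model:
    out1 = cached_model(x)  # Linear layers compute and cache their outputs for `x`.
    out2 = cached_model(x)  # Identical input: the cached outputs are reused.
    assert torch.equal(out1, out2)
```

In the MC-Dropout setting, only the layers downstream of a dropout layer see changing inputs across iterations; in architectures such as VGG16, where dropout sits in the classifier head, almost the entire feature extractor is served from the cache, which is presumably where the reported ~70% speedup comes from.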
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -68,6 +68,7 @@ nav:
- user_guide/index.md
- Cheat Sheet: user_guide/baal_cheatsheet.md
- Active data structure: notebooks/fundamentals/active-learning.ipynb
- Speeding up Monte-Carlo Inference With MCCachingModule: notebooks/mccaching_layer.ipynb
- Computing uncertainty:
- Stochastic models: notebooks/fundamentals/posteriors.ipynb
- Heuristics: user_guide/heuristics.md
149 changes: 149 additions & 0 deletions notebooks/mccaching_layer.ipynb
@@ -0,0 +1,149 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Speeding up Monte-Carlo Inference With MCCachingModule\n",
"\n",
"It is common knowledge that running MCDropout is slow and computationally expensive.\n",
"Baal proposes a new simple API called `MCCachingModule` to speedup MCDropout by more than 70%!\n",
"\n",
"**TLDR: MCCachingWrapper**\n",
"\n",
"```python\n",
">>> from baal.bayesian.caching_utils import MCCachingModule\n",
">>> # Regular code to perform MCDropout with Baal.\n",
">>> model = MCDropoutModule(original_module)\n",
">>> # To gain 70% speedup, simply do\n",
">>> model = MCCachingModule(model)\n",
"```\n",
"\n",
"Below we detail our approach in this toy example. We will use a `VGG16` model and run MCDropout for 20 iterations on the test set of CIFAR10.\n",
"\n",
"We get the following results on a GeForce 1060Ti:\n",
"\n",
"| Number of Iteration | 20 | 50 | 100 |\n",
"|---------------------|----------|----------|----------|\n",
"| Regular MC-Dropout | 2:58 | 7:27 | 13:45 |\n",
"| Ours | **0:50** | **1:46** | **3:32** |\n",
"\n",
"We are excited to see how the community uses this new feature!\n",
"\n",
"### Code!"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 8,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Files already downloaded and verified\n",
"[12777-MainThread] [baal.modelwrapper:predict_on_dataset_generator:239] \u001B[2m2023-07-13T21:09:33.828796Z\u001B[0m [\u001B[32m\u001B[1minfo \u001B[0m] \u001B[1mStart Predict \u001B[0m \u001B[36mdataset\u001B[0m=\u001B[35m10000\u001B[0m\n",
"100%|██████████| 313/313 [02:49<00:00, 1.85it/s]\n"
]
}
],
"source": [
"from torchvision.datasets import CIFAR10\n",
"from torchvision.models import vgg16\n",
"from torchvision.transforms import ToTensor\n",
"\n",
"from baal.bayesian.caching_utils import MCCachingModule\n",
"from baal.bayesian.dropout import MCDropoutModule\n",
"from baal.modelwrapper import ModelWrapper\n",
"\n",
"ITERATIONS = 20\n",
"\n",
"vgg = vgg16().cuda()\n",
"vgg.eval()\n",
"\n",
"ds = CIFAR10('/tmp', train=False, transform=ToTensor(), download=True)\n",
"\n",
"# Takes ~2:58 minutes.\n",
"with MCDropoutModule(vgg) as model_2:\n",
" wrapper = ModelWrapper(model_2, None, replicate_in_memory=False)\n",
" wrapper.predict_on_dataset(ds, batch_size=32, iterations=ITERATIONS, use_cuda=True)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-07-13T21:12:23.378811603Z",
"start_time": "2023-07-13T21:09:29.068365127Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"## Introducing MCCachingModule!\n",
"\n",
"By simply wrapping the module with `MCCachingModule` we run the same inference 70% faster!"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 9,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[12777-MainThread] [baal.modelwrapper:predict_on_dataset_generator:239] \u001B[2m2023-07-13T21:12:23.384108Z\u001B[0m [\u001B[32m\u001B[1minfo \u001B[0m] \u001B[1mStart Predict \u001B[0m \u001B[36mdataset\u001B[0m=\u001B[35m10000\u001B[0m\n",
"100%|██████████| 313/313 [00:47<00:00, 6.60it/s]\n"
]
}
],
"source": [
"# Takes ~50 seconds!.\n",
"with MCCachingModule(vgg) as model:\n",
" with MCDropoutModule(model) as model_2:\n",
" wrapper = ModelWrapper(model_2, None, replicate_in_memory=False)\n",
" wrapper.predict_on_dataset(ds, batch_size=32, iterations=ITERATIONS, use_cuda=True)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-07-13T21:13:11.076629413Z",
"start_time": "2023-07-13T21:12:23.387507076Z"
}
}
},
{
"cell_type": "markdown",
"source": [],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
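One detail worth spelling out (an inference from `LRUCacheModule`, not stated explicitly in the PR): the cache only hits when a layer receives a tensor equal to the previous one, so `replicate_in_memory=False` is what enables the speedup; with in-memory replication the batch is, as far as we understand, tiled into a single larger tensor, so no layer ever sees the same input twice. A minimal sketch of the combined pattern, with illustrative layer sizes:

```python
from torch import nn

from baal.bayesian.caching_utils import MCCachingModule
from baal.bayesian.dropout import MCDropoutModule
from baal.modelwrapper import ModelWrapper

model = nn.Sequential(nn.Linear(128, 64), nn.Dropout(0.5), nn.Linear(64, 10)).eval()

with MCCachingModule(model) as cached:
    with MCDropoutModule(cached) as stochastic:
        # replicate_in_memory=False feeds the same batch `iterations` times,
        # letting the deterministic layers return their cached outputs.
        wrapper = ModelWrapper(stochastic, None, replicate_in_memory=False)
        # wrapper.predict_on_dataset(...) then runs MC-Dropout with caching,
        # as shown in the notebook above.
```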
52 changes: 52 additions & 0 deletions tests/bayesian/test_caching.py
@@ -0,0 +1,52 @@
import pytest
import torch
from torch.nn import Sequential, Linear

from baal.bayesian.caching_utils import MCCachingModule


class LinearMocked(Linear):
call_count = 0

def __init__(self, in_features: int, out_features: int):
super().__init__(in_features, out_features)

def __call__(self, x):
LinearMocked.call_count += 1
return super().__call__(x)


@pytest.fixture()
def my_model():
return Sequential(
LinearMocked(10, 10),
LinearMocked(10, 10),
Sequential(
LinearMocked(10, 10),
LinearMocked(10, 10),
)
).eval()


def test_caching(my_model):
x = torch.rand(10)

# No Caching
my_model(x)
assert LinearMocked.call_count == 4
my_model(x)
assert LinearMocked.call_count == 8

with MCCachingModule(my_model) as model:
model(x)
assert LinearMocked.call_count == 12
model(x)
assert LinearMocked.call_count == 12

# No Caching
my_model(x)
assert LinearMocked.call_count == 16
my_model(x)
assert LinearMocked.call_count == 20