From 10068614836710c4f799c7006a9ceac0b7af62e2 Mon Sep 17 00:00:00 2001
From: "fr.branchaud-charron"
Date: Thu, 3 Mar 2022 17:06:51 -0500
Subject: [PATCH 1/2] #189 Add unpatch for dropout and consistent dropout

---
 Makefile                                  |  2 +-
 baal/active/dataset/pytorch_dataset.py    | 14 ++---
 baal/bayesian/common.py                   | 22 ++++++++
 baal/bayesian/consistent_dropout.py       | 62 ++++++++++++++-------
 baal/bayesian/dropout.py                  | 66 ++++++++++++++---------
 baal/utils/plot_utils.py                  |  2 +-
 tests/active/active_loop_test.py          | 14 +++++
 tests/bayesian/common_test.py             | 64 ++++++++++++++++++++++
 tests/bayesian/consistent_dropout_test.py | 10 ++++
 tests/bayesian/dropconnect_test.py        | 18 ++++++-
 tests/bayesian/dropout_test.py            |  6 +++
 11 files changed, 224 insertions(+), 56 deletions(-)
 create mode 100644 baal/bayesian/common.py
 create mode 100644 tests/bayesian/common_test.py

diff --git a/Makefile b/Makefile
index 38e623e9..e4f557a1 100644
--- a/Makefile
+++ b/Makefile
@@ -4,7 +4,7 @@ lint:
 	check-mypy-error-count
 
 .PHONY: test
 test: lint
-	poetry run pytest tests --cov=baal
+	poetry run pytest tests --cov=baal -n 2
 
 .PHONY: format
 format:
diff --git a/baal/active/dataset/pytorch_dataset.py b/baal/active/dataset/pytorch_dataset.py
index b6558836..a8960d9a 100644
--- a/baal/active/dataset/pytorch_dataset.py
+++ b/baal/active/dataset/pytorch_dataset.py
@@ -32,13 +32,13 @@ class ActiveLearningDataset(SplittedDataset):
     """
 
     def __init__(
-            self,
-            dataset: torchdata.Dataset,
-            labelled: Optional[np.ndarray] = None,
-            make_unlabelled: Callable = _identity,
-            random_state=None,
-            pool_specifics: Optional[dict] = None,
-            last_active_steps: int = -1,
+        self,
+        dataset: torchdata.Dataset,
+        labelled: Optional[np.ndarray] = None,
+        make_unlabelled: Callable = _identity,
+        random_state=None,
+        pool_specifics: Optional[dict] = None,
+        last_active_steps: int = -1,
     ) -> None:
         self._dataset = dataset
diff --git a/baal/bayesian/common.py b/baal/bayesian/common.py
new file mode 100644
index 00000000..15d6c67c
--- /dev/null
+++ b/baal/bayesian/common.py
@@ -0,0 +1,22 @@
+from typing import Callable
+from torch import nn
+
+
+def replace_layers_in_module(module: nn.Module, mapping_fn: Callable) -> bool:
+    """
+    Recursively iterate over the children of a module and replace them according to `mapping_fn`.
+
+    Returns:
+        True if at least one layer was replaced.
+    """
+    changed = False
+    for name, child in module.named_children():
+        new_module = mapping_fn(child)
+
+        if new_module is not None:
+            changed = True
+            module.add_module(name, new_module)
+
+        # recursively apply to child
+        changed |= replace_layers_in_module(child, mapping_fn)
+    return changed
diff --git a/baal/bayesian/consistent_dropout.py b/baal/bayesian/consistent_dropout.py
index 2122b72c..e0869d8e 100644
--- a/baal/bayesian/consistent_dropout.py
+++ b/baal/bayesian/consistent_dropout.py
@@ -7,6 +7,8 @@
 from torch.nn import functional as F
 from torch.nn.modules.dropout import _DropoutNd
 
+from baal.bayesian.common import replace_layers_in_module
+
 
 class ConsistentDropout(_DropoutNd):
     """
@@ -115,32 +117,50 @@ def patch_module(module: torch.nn.Module, inplace: bool = True) -> torch.nn.Modu
     """
     if not inplace:
         module = copy.deepcopy(module)
-    changed = _patch_dropout_layers(module)
+    changed = replace_layers_in_module(module, _consistent_dropout_mapping_fn)
     if not changed:
         warnings.warn("No layer was modified by patch_module!", UserWarning)
     return module
 
 
-def _patch_dropout_layers(module: torch.nn.Module) -> bool:
-    """
-    Recursively iterate over the children of a module and replace them if
-    they are a dropout layer. This function operates in-place.
+def unpatch_module(module: torch.nn.Module, inplace: bool = True) -> torch.nn.Module:
+    """Replace ConsistentDropout layers in a model with Dropout layers.
+
+    Args:
+        module (torch.nn.Module):
+            The module in which you would like to replace dropout layers.
+        inplace (bool, optional):
+            Whether to modify the module in place or return a copy of the module.
+
+    Returns:
+        torch.nn.Module
+            The modified module, which is either the same object as you passed in
+            (if inplace = True) or a copy of that object.
     """
-    changed = False
-    for name, child in module.named_children():
-        new_module: Optional[nn.Module] = None
-        if isinstance(child, torch.nn.Dropout):
-            new_module = ConsistentDropout(p=child.p)
-        elif isinstance(child, torch.nn.Dropout2d):
-            new_module = ConsistentDropout2d(p=child.p)
+    if not inplace:
+        module = copy.deepcopy(module)
+    changed = replace_layers_in_module(module, _consistent_dropout_unmapping_fn)
+    if not changed:
+        warnings.warn("No layer was modified by unpatch_module!", UserWarning)
+    return module
+
 
-        if new_module is not None:
-            changed = True
-            module.add_module(name, new_module)
+def _consistent_dropout_mapping_fn(module: torch.nn.Module) -> Optional[nn.Module]:
+    new_module: Optional[nn.Module] = None
+    if isinstance(module, torch.nn.Dropout):
+        new_module = ConsistentDropout(p=module.p)
+    elif isinstance(module, torch.nn.Dropout2d):
+        new_module = ConsistentDropout2d(p=module.p)
+    return new_module
 
-        # recursively apply to child
-        changed |= _patch_dropout_layers(child)
-    return changed
+
+def _consistent_dropout_unmapping_fn(module: torch.nn.Module) -> Optional[nn.Module]:
+    new_module: Optional[nn.Module] = None
+    if isinstance(module, ConsistentDropout):
+        new_module = torch.nn.Dropout(p=module.p)
+    elif isinstance(module, ConsistentDropout2d):
+        new_module = torch.nn.Dropout2d(p=module.p)
+    return new_module
 
 
 class MCConsistentDropoutModule(torch.nn.Module):
@@ -152,8 +172,10 @@ def __init__(self, module: torch.nn.Module):
             A fully specified neural network.
""" super().__init__() - self.parent_module = module - _patch_dropout_layers(self.parent_module) + self.parent_module = patch_module(module) def forward(self, *args, **kwargs): return self.parent_module.forward(*args, **kwargs) + + def unpatch(self) -> torch.nn.Module: + return unpatch_module(self.parent_module) diff --git a/baal/bayesian/dropout.py b/baal/bayesian/dropout.py index 1a065dea..14993798 100644 --- a/baal/bayesian/dropout.py +++ b/baal/bayesian/dropout.py @@ -7,6 +7,8 @@ from torch.nn import functional as F from torch.nn.modules.dropout import _DropoutNd +from baal.bayesian.common import replace_layers_in_module + class Dropout(_DropoutNd): r"""Randomly zeroes some of the elements of the input @@ -85,7 +87,7 @@ def forward(self, input): def patch_module(module: torch.nn.Module, inplace: bool = True) -> torch.nn.Module: - """Replace dropout layers in a model with MC Dropout layers. + """Replace dropout layers in a model with MCDropout layers. Args: module (torch.nn.Module): @@ -93,9 +95,6 @@ def patch_module(module: torch.nn.Module, inplace: bool = True) -> torch.nn.Modu inplace (bool, optional): Whether to modify the module in place or return a copy of the module. - Raises: - UserWarning if no layer is modified. - Returns: torch.nn.Module The modified module, which is either the same object as you passed in @@ -103,35 +102,50 @@ def patch_module(module: torch.nn.Module, inplace: bool = True) -> torch.nn.Modu """ if not inplace: module = copy.deepcopy(module) - changed = _patch_dropout_layers(module) + changed = replace_layers_in_module(module, _dropout_mapping_fn) if not changed: warnings.warn("No layer was modified by patch_module!", UserWarning) return module -def _patch_dropout_layers(module: torch.nn.Module) -> bool: - """ - Recursively iterate over the children of a module and replace them if - they are a dropout layer. This function operates in-place. +def unpatch_module(module: torch.nn.Module, inplace: bool = True) -> torch.nn.Module: + """Replace MCDropout layers in a model with Dropout layers. + + Args: + module (torch.nn.Module): + The module in which you would like to replace dropout layers. + inplace (bool, optional): + Whether to modify the module in place or return a copy of the module. Returns: - Flag indicating if a layer was modified. + torch.nn.Module + The modified module, which is either the same object as you passed in + (if inplace = True) or a copy of that object. 
""" - changed = False - for name, child in module.named_children(): - new_module: Optional[nn.Module] = None - if isinstance(child, torch.nn.Dropout): - new_module = Dropout(p=child.p, inplace=child.inplace) - elif isinstance(child, torch.nn.Dropout2d): - new_module = Dropout2d(p=child.p, inplace=child.inplace) + if not inplace: + module = copy.deepcopy(module) + changed = replace_layers_in_module(module, _dropout_unmapping_fn) + if not changed: + warnings.warn("No layer was modified by patch_module!", UserWarning) + return module + - if new_module is not None: - changed = True - module.add_module(name, new_module) +def _dropout_mapping_fn(module: torch.nn.Module) -> Optional[nn.Module]: + new_module: Optional[nn.Module] = None + if isinstance(module, torch.nn.Dropout): + new_module = Dropout(p=module.p, inplace=module.inplace) + elif isinstance(module, torch.nn.Dropout2d): + new_module = Dropout2d(p=module.p, inplace=module.inplace) + return new_module - # recursively apply to child - changed |= _patch_dropout_layers(child) - return changed + +def _dropout_unmapping_fn(module: torch.nn.Module) -> Optional[nn.Module]: + new_module: Optional[nn.Module] = None + if isinstance(module, Dropout): + new_module = torch.nn.Dropout(p=module.p, inplace=module.inplace) + elif isinstance(module, Dropout2d): + new_module = torch.nn.Dropout2d(p=module.p, inplace=module.inplace) + return new_module class MCDropoutModule(torch.nn.Module): @@ -143,8 +157,10 @@ def __init__(self, module: torch.nn.Module): A fully specified neural network. """ super().__init__() - self.parent_module = module - _patch_dropout_layers(self.parent_module) + self.parent_module = patch_module(module) def forward(self, *args, **kwargs): return self.parent_module(*args, **kwargs) + + def unpatch(self) -> torch.nn.Module: + return unpatch_module(self.parent_module) diff --git a/baal/utils/plot_utils.py b/baal/utils/plot_utils.py index 869d6dc7..9a789d97 100644 --- a/baal/utils/plot_utils.py +++ b/baal/utils/plot_utils.py @@ -75,7 +75,7 @@ def make_animation_from_data( return frames -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover from sklearn.datasets import make_classification import imageio diff --git a/tests/active/active_loop_test.py b/tests/active/active_loop_test.py index b7d6bfdc..527e030c 100644 --- a/tests/active/active_loop_test.py +++ b/tests/active/active_loop_test.py @@ -1,5 +1,6 @@ import os import pickle +import warnings import numpy as np import pytest @@ -126,5 +127,18 @@ def test_file_saving(tmpdir): assert (data['dataset']['labelled'] != dataset.labelled).sum() == 10 +def test_deprecation(): + heur = heuristics.BALD() + ds = MyDataset() + dataset = ActiveLearningDataset(ds, make_unlabelled=lambda x: -1) + with warnings.catch_warnings(record=True) as w: + active_loop = ActiveLearningLoop(dataset, + get_probs_iter, + heur, + ndata_to_label=10, + dummy_param=1) + assert issubclass(w[-1].category, DeprecationWarning) + assert "ndata_to_label" in str(w[-1].message) + if __name__ == '__main__': pytest.main() diff --git a/tests/bayesian/common_test.py b/tests/bayesian/common_test.py new file mode 100644 index 00000000..0c25d801 --- /dev/null +++ b/tests/bayesian/common_test.py @@ -0,0 +1,64 @@ +import pytest +from torch import nn + +from baal.bayesian.common import replace_layers_in_module + + +@pytest.fixture +def a_model_deep(): + return nn.Sequential( + nn.Linear(32, 32), + nn.Sequential( + nn.Linear(32, 3), + nn.ReLU(), + nn.Linear(10, 3), + nn.ReLU(), + nn.Linear(3, 3) + )) + + 
+@pytest.fixture
+def a_model():
+    return nn.Sequential(
+        nn.Linear(32, 3),
+        nn.ReLU(),
+        nn.Linear(3, 3),
+        nn.ReLU(),
+        nn.Linear(3, 3)
+    )
+
+
+def test_replace_layers_in_module_swap_all_relu(a_model):
+    mapping = lambda mod: None if not isinstance(mod, nn.ReLU) else nn.Identity()
+    changed = replace_layers_in_module(a_model, mapping)
+    assert changed
+    assert not any(isinstance(m, nn.ReLU) for m in a_model.modules())
+    assert any(isinstance(m, nn.Identity) for m in a_model.modules())
+
+
+def test_replace_layers_in_module_swap_all_relu_deep(a_model_deep):
+    mapping = lambda mod: None if not isinstance(mod, nn.ReLU) else nn.Identity()
+    changed = replace_layers_in_module(a_model_deep, mapping)
+    assert changed
+    assert not any(isinstance(m, nn.ReLU) for m in a_model_deep.modules())
+    assert any(isinstance(m, nn.Identity) for m in a_model_deep.modules())
+
+
+def test_replace_layers_in_module_swap_no_relu_deep(a_model_deep):
+    mapping = lambda mod: None if not isinstance(mod, nn.ReLU6) else nn.Identity()
+    changed = replace_layers_in_module(a_model_deep, mapping)
+    assert not changed
+    assert any(isinstance(m, nn.ReLU) for m in a_model_deep.modules())
+    assert not any(isinstance(m, nn.Identity) for m in a_model_deep.modules())
+
+
+def test_replace_layers_in_module_swap_no_relu(a_model):
+    mapping = lambda mod: None if not isinstance(mod, nn.ReLU6) else nn.Identity()
+    changed = replace_layers_in_module(a_model, mapping)
+    assert not changed
+    assert any(isinstance(m, nn.ReLU) for m in a_model.modules())
+    assert not any(isinstance(m, nn.Identity) for m in a_model.modules())
+
+
+if __name__ == '__main__':
+    pytest.main()
diff --git a/tests/bayesian/consistent_dropout_test.py b/tests/bayesian/consistent_dropout_test.py
index 024875e9..cf02eb16 100644
--- a/tests/bayesian/consistent_dropout_test.py
+++ b/tests/bayesian/consistent_dropout_test.py
@@ -96,6 +96,16 @@ def test_module_class_replaces_dropout_layers(a_model_with_dropout):
         for _ in range(10)
     )
 
+    # Check that unpatch works
+    module = test_mc_module.unpatch()
+    module.eval()
+    with torch.no_grad():
+        assert all(
+            torch.allclose(module(dummy_input), module(dummy_input))
+            for _ in range(10)
+        )
+    assert not any(isinstance(mod, baal.bayesian.consistent_dropout.ConsistentDropout) for mod in module.modules())
+
 
 @pytest.mark.parametrize("inplace", (True, False))
 def test_patch_module_raise_warnings(inplace):
diff --git a/tests/bayesian/dropconnect_test.py b/tests/bayesian/dropconnect_test.py
index 2c0cb4b4..f2280083 100644
--- a/tests/bayesian/dropconnect_test.py
+++ b/tests/bayesian/dropconnect_test.py
@@ -3,7 +3,7 @@
 import pytest
 import torch
 
-from baal.bayesian.weight_drop import patch_module, WeightDropLinear
+from baal.bayesian.weight_drop import patch_module, WeightDropLinear, MCDropoutConnectModule
 
 
 class SimpleModel(torch.nn.Module):
@@ -70,7 +70,21 @@ def test_patch_module_replaces_all_dropout_layers(inplace):
     # objects should be the same if inplace is True and not otherwise:
     assert (mc_test_module is test_module) == inplace
     assert not any(
-        module.p != 0 for module in mc_test_module.modules() if isinstance(module, torch.nn.Dropout)
+        module.p != 0 for module in mc_test_module.modules() if isinstance(module, torch.nn.Dropout)
     )
     assert any(
         isinstance(module, WeightDropLinear)
         for module in mc_test_module.modules()
     )
+
+
+def test_mcdropconnect_replaces_all_dropout_layers_module():
+    test_module = SimpleModel()
+
+    mc_test_module = MCDropoutConnectModule(test_module, layers=['Conv2d', 'Linear', 'LSTM', 'GRU'])
+
+    assert not any(
+        module.p != 0 for module in mc_test_module.modules() if isinstance(module, torch.nn.Dropout)
+    )
+    assert any(
+        isinstance(module, WeightDropLinear)
+        for module in mc_test_module.modules()
+    )
diff --git a/tests/bayesian/dropout_test.py b/tests/bayesian/dropout_test.py
index ff59d3b9..4b827456 100644
--- a/tests/bayesian/dropout_test.py
+++ b/tests/bayesian/dropout_test.py
@@ -106,5 +106,11 @@ def test_module_class_replaces_dropout_layers(a_model_with_dropout):
     )
 
 
+    # Check that unpatch works
+    module = test_mc_module.unpatch()
+    assert not any(isinstance(mod, baal.bayesian.dropout.Dropout) for mod in module.modules())
+
+
 if __name__ == '__main__':
     pytest.main()

From f6f933c1dabdaea812eee307212c7d4d06141652 Mon Sep 17 00:00:00 2001
From: "fr.branchaud-charron"
Date: Fri, 4 Mar 2022 08:41:29 -0500
Subject: [PATCH 2/2] #189 remove -n 2 for xdist

---
 Makefile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Makefile b/Makefile
index e4f557a1..38e623e9 100644
--- a/Makefile
+++ b/Makefile
@@ -4,7 +4,7 @@ lint:
 	check-mypy-error-count
 
 .PHONY: test
 test: lint
-	poetry run pytest tests --cov=baal -n 2
+	poetry run pytest tests --cov=baal
 
 .PHONY: format
 format:
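
Note for reviewers: below is a minimal usage sketch of the API these patches introduce. The toy `nn.Sequential` models, the input tensor, and all variable names are illustrative assumptions; only `patch_module`, `unpatch_module`, `MCDropoutModule.unpatch()`, and `replace_layers_in_module` come from the diffs above.

    import torch
    from torch import nn

    from baal.bayesian import dropout as mc_dropout
    from baal.bayesian.common import replace_layers_in_module

    # Toy model with a standard dropout layer (illustrative only).
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(8, 2))

    # Swap nn.Dropout for the MC variant; inplace=False leaves `model` untouched.
    mc_model = mc_dropout.patch_module(model, inplace=False)
    mc_model.eval()

    x = torch.randn(4, 8)
    # The MC variant keeps sampling dropout masks at eval time,
    # so two forward passes (almost surely) differ.
    assert not torch.allclose(mc_model(x), mc_model(x))

    # Revert to plain nn.Dropout; eval-time predictions are deterministic again.
    plain_model = mc_dropout.unpatch_module(mc_model, inplace=False)
    plain_model.eval()
    assert torch.allclose(plain_model(x), plain_model(x))

    # The wrapper class now offers the same round trip.
    wrapped = mc_dropout.MCDropoutModule(nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5)))
    restored = wrapped.unpatch()
    assert not any(isinstance(m, mc_dropout.Dropout) for m in restored.modules())

    # The shared helper also takes arbitrary mappings, as the new tests exercise;
    # returning None from the mapping leaves a layer unchanged.
    swapped = replace_layers_in_module(
        model, lambda m: nn.Identity() if isinstance(m, nn.ReLU) else None
    )
    assert swapped

Since both directions reuse `replace_layers_in_module`, supporting a new stochastic layer type only needs a new mapping/unmapping pair rather than another hand-rolled recursive traversal.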