Commit ea4c31d (1 parent: 8c975ea)
Co-authored-by: Parmida Atighehchian <[email protected]>
Showing 5 changed files with 236 additions and 49 deletions.
@@ -0,0 +1,140 @@
import types

import numpy as np
import structlog
from scipy.special import softmax
from scipy.stats import rankdata

from baal.active.heuristics import AbstractHeuristic, Sequence

log = structlog.get_logger(__name__)
EPSILON = 1e-8

class StochasticHeuristic(AbstractHeuristic):
    def __init__(self, base_heuristic: AbstractHeuristic, query_size):
        """Heuristic that is stochastic to improve diversity.

        Common acquisition functions are heavily impacted by duplicates. When using a
        `top-k` approach, where the most uncertain examples are selected, the
        acquisition function can select many duplicates. Techniques such as
        BADGE (Ash et al., 2019) or BatchBALD (Kirsch et al., 2019)
        are common solutions to this problem, but they are quite expensive.
        Stochastic acquisitions are cheap to compute and achieve similar performance.

        References:
            Stochastic Batch Acquisition for Deep Active Learning, Kirsch et al. (2022)
            https://arxiv.org/abs/2106.12059

        Args:
            base_heuristic: Heuristic to get uncertainty from before sampling.
            query_size: These heuristics will return `query_size` items.
        """
        # TODO handle reverse
        super().__init__(reverse=False)
        self._bh = base_heuristic
        self.query_size = query_size

    def get_ranks(self, predictions):
        # Get the raw uncertainty from the base heuristic.
        scores = self.get_scores(predictions)
        # Create the distribution to sample from.
        distributions = self._make_distribution(scores)
        # Force normalization for np.random.choice: clip negatives, then renormalize.
        distributions = np.clip(distributions, 0, None)
        distributions /= distributions.sum()

        # TODO Seed?
        if (distributions > 0).sum() < self.query_size:
            log.warning("Not enough non-zero values to sample from, falling back to uniform sampling")
            distributions = np.ones_like(distributions) / len(distributions)
        return (
            np.random.choice(len(distributions), self.query_size, replace=False, p=distributions),
            distributions,
        )

    def get_scores(self, predictions):
        if isinstance(predictions, types.GeneratorType):
            scores = self._bh.get_uncertainties_generator(predictions)
        else:
            scores = self._bh.get_uncertainties(predictions)
        if isinstance(scores, Sequence):
            scores = np.concatenate(scores)
        return scores

    def _make_distribution(self, scores: np.ndarray) -> np.ndarray:
        raise NotImplementedError

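Subclasses only implement `_make_distribution`, which maps raw uncertainty scores to sampling weights; `get_ranks` then clips, normalizes, and draws `query_size` indices without replacement. As a sketch of this extension point (hypothetical, not part of this commit), a uniform variant would be:

# Hypothetical sketch, not part of this commit: a subclass that ignores
# the base heuristic's scores and samples uniformly.
class UniformSampling(StochasticHeuristic):
    def _make_distribution(self, scores: np.ndarray) -> np.ndarray:
        # Equal weight for every item; get_ranks normalizes and samples.
        return np.ones_like(scores)
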
class PowerSampling(StochasticHeuristic):
    def __init__(self, base_heuristic: AbstractHeuristic, query_size, temperature=1.0):
        """Samples from the uncertainty distribution without modification besides
        temperature scaling and normalization.

        Stochastic heuristic that assumes that the uncertainty distribution
        is positive and that items with near-zero uncertainty are uninformative.
        Empirically, it worked the best in the paper.

        References:
            Stochastic Batch Acquisition for Deep Active Learning, Kirsch et al. (2022)
            https://arxiv.org/abs/2106.12059

        Args:
            base_heuristic: Heuristic to get uncertainty from before sampling.
            query_size: These heuristics will return `query_size` items.
            temperature: Value to temper the uncertainty distribution before sampling.
        """
        super().__init__(base_heuristic=base_heuristic, query_size=query_size)
        self.temperature = temperature

    def _make_distribution(self, scores: np.ndarray) -> np.ndarray:
        scores = scores ** (1 / self.temperature)
        scores = scores / scores.sum()
        return scores

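To see what the temperature does here, a minimal sketch (illustrative values, not from this commit): raising scores to `1 / temperature` sharpens the distribution when `temperature < 1` and flattens it towards uniform when `temperature > 1`.

# Illustrative sketch: temperature reshapes the power-sampling distribution.
import numpy as np

scores = np.array([0.1, 0.2, 0.4])
for temperature in (0.5, 1.0, 10.0):
    weights = scores ** (1 / temperature)
    print(temperature, weights / weights.sum())
# 0.5  -> [0.05 0.19 0.76]  (sharpened towards the most uncertain item)
# 1.0  -> [0.14 0.29 0.57]  (unchanged ordering and scale)
# 10.0 -> [0.31 0.33 0.36]  (close to uniform)
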
class GibbsSampling(StochasticHeuristic):
    def __init__(self, base_heuristic: AbstractHeuristic, query_size, temperature=1.0):
        """Samples from the uncertainty distribution after applying softmax.

        References:
            Stochastic Batch Acquisition for Deep Active Learning, Kirsch et al. (2022)
            https://arxiv.org/abs/2106.12059

        Args:
            base_heuristic: Heuristic to get uncertainty from before sampling.
            query_size: These heuristics will return `query_size` items.
            temperature: Value to temper the uncertainty distribution before sampling.
        """
        super().__init__(base_heuristic=base_heuristic, query_size=query_size)
        self.temperature = temperature

    def _make_distribution(self, scores: np.ndarray) -> np.ndarray:
        scores /= self.temperature
        # `scores` has shape [N]; softmax normalizes over all items.
        scores = softmax(scores)
        return scores

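Unlike the power transform above, softmax does not require positive scores, and it is invariant to adding a constant to every score, which helps when the base heuristic produces values near or below zero. A minimal sketch (illustrative, not from this commit):

# Illustrative sketch: softmax tolerates negative scores and ignores shifts.
import numpy as np
from scipy.special import softmax

scores = np.array([-0.3, 0.0, 0.5])
print(softmax(scores))          # a valid distribution despite the negative input
print(softmax(scores + 100.0))  # identical output: softmax is shift-invariant
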
class RankBasedSampling(StochasticHeuristic):
    def __init__(self, base_heuristic: AbstractHeuristic, query_size, temperature=1.0):
        """Samples from the ranks of the uncertainty distribution.

        References:
            Stochastic Batch Acquisition for Deep Active Learning, Kirsch et al. (2022)
            https://arxiv.org/abs/2106.12059

        Args:
            base_heuristic: Heuristic to get uncertainty from before sampling.
            query_size: These heuristics will return `query_size` items.
            temperature: Value to temper the uncertainty distribution before sampling.
        """
        super().__init__(base_heuristic=base_heuristic, query_size=query_size)
        self.temperature = temperature

    def _make_distribution(self, scores: np.ndarray) -> np.ndarray:
        # Rank 1 is the most uncertain item; weights decay as rank ** (-1 / temperature).
        rank = rankdata(-scores)
        weights = rank ** (-1 / self.temperature)
        normalized_weights: np.ndarray = weights / weights.sum()
        return normalized_weights
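
Since only the ordering of the scores matters here, the resulting weights are invariant to any strictly increasing rescaling of the scores, making this heuristic insensitive to the base heuristic's scale. A minimal sketch (illustrative, not from this commit):

# Illustrative sketch: rank-based weights depend only on the score ordering.
import numpy as np
from scipy.stats import rankdata

scores = np.array([0.1, 0.2, 0.4])
for s in (scores, scores ** 3):  # same ordering, very different scale
    weights = rankdata(-s) ** -1.0  # temperature=1.0
    print(weights / weights.sum())
# Both lines print the same distribution: [0.18 0.27 0.55]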
@@ -0,0 +1,37 @@
import numpy as np
import pytest
from scipy.stats import entropy

from baal.active.heuristics import BALD, Entropy
from baal.active.heuristics.stochastics import GibbsSampling, RankBasedSampling, PowerSampling

NUM_CLASSES = 10
NUM_ITERATIONS = 20
BATCH_SIZE = 32


@pytest.fixture
def sampled_predictions():
    # Fake MC predictions with shape [batch, n_class, n_iterations].
    predictions = np.stack(
        [
            np.histogram(np.random.rand(5), bins=np.linspace(-0.5, 0.5, NUM_CLASSES + 1))[0]
            for _ in range(BATCH_SIZE * NUM_ITERATIONS)
        ]
    ).reshape([BATCH_SIZE, NUM_ITERATIONS, NUM_CLASSES])
    return np.rollaxis(predictions, -1, 1)


@pytest.mark.parametrize("stochastic_heuristic", [GibbsSampling, RankBasedSampling, PowerSampling])
@pytest.mark.parametrize("base_heuristic", [BALD, Entropy])
def test_stochastic_heuristic(stochastic_heuristic, base_heuristic, sampled_predictions):
    heur_temp_1 = stochastic_heuristic(base_heuristic(), query_size=100, temperature=1.0)
    heur_temp_10 = stochastic_heuristic(base_heuristic(), query_size=100, temperature=10.0)
    heur_temp_001 = stochastic_heuristic(base_heuristic(), query_size=100, temperature=0.01)

    scores = heur_temp_1.get_scores(sampled_predictions)

    dist_temp_1, dist_temp_10, dist_temp_001 = (
        heur_temp_1._make_distribution(scores),
        heur_temp_10._make_distribution(scores),
        heur_temp_001._make_distribution(scores),
    )

    # Higher temperature flattens the distribution (higher entropy);
    # lower temperature sharpens it (lower entropy).
    assert entropy(dist_temp_1) < entropy(dist_temp_10)
    # NOTE: it is possible that this fails, as temp_1 can already have minimal entropy.
    # This is unlikely.
    assert entropy(dist_temp_1) > entropy(dist_temp_001)
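
For reference, an end-to-end sketch of a query step with these heuristics (illustrative; the prediction shape follows the fixture above, and `query_size` is kept small so that sampling without replacement is feasible):

# Illustrative sketch: selecting a query batch with a stochastic heuristic.
import numpy as np
from baal.active.heuristics import BALD
from baal.active.heuristics.stochastics import PowerSampling

# Fake MC predictions with shape [batch, n_class, n_iterations].
predictions = np.random.rand(32, 10, 20)
heuristic = PowerSampling(BALD(), query_size=5, temperature=1.0)
indices, distribution = heuristic.get_ranks(predictions)
# `indices` holds 5 distinct indices sampled in proportion to the
# tempered BALD scores; `distribution` is the sampling distribution used.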
@@ -0,0 +1,47 @@
import numpy as np

N_ITERATIONS = 50
IMG_SIZE = 3


def make_3d_fake_dist(means, stds, dims=10):
    d = np.stack(
        [make_fake_dist(means, stds, dims=dims) for _ in range(N_ITERATIONS)]
    )  # 50 iterations
    d = np.rollaxis(d, 0, 3)
    # [n_sample, n_class, n_iter]
    return d


def make_5d_fake_dist(means, stds, dims=10):
    d = np.stack(
        [make_3d_fake_dist(means, stds, dims=dims) for _ in range(IMG_SIZE ** 2)], -1
    )  # 3x3 image
    b, c, i, hw = d.shape
    d = np.reshape(d, [b, c, i, IMG_SIZE, IMG_SIZE])
    d = np.rollaxis(d, 2, 5)
    # [n_sample, n_class, H, W, n_iter]
    return d


def make_fake_dist(means, stds, dims=10):
    """Create some fake discrete distributions.

    Args:
        means: List of means.
        stds: List of standard deviations.
        dims: Dimensions of the distributions.

    Returns:
        Array of distributions with shape [len(means), dims].
    """
    n_trials = 100
    distributions = []
    for m, std in zip(means, stds):
        dist = np.zeros([dims])
        for i in range(n_trials):
            dist[
                np.round(np.clip(np.random.normal(m, std, 1), 0, dims - 1)).astype(int).item()
            ] += 1
        distributions.append(dist / n_trials)
    return np.array(distributions)
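
A quick shape check for these helpers (illustrative, not part of the commit):

# Illustrative sketch: expected output shapes of the fake-distribution helpers.
print(make_fake_dist([5.0], [1.0]).shape)                # (1, 10)
print(make_3d_fake_dist([5.0, 2.0], [1.0, 1.0]).shape)   # (2, 10, 50)
print(make_5d_fake_dist([5.0], [1.0]).shape)             # (1, 10, 3, 3, 50)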