From ff7813f3495c9d676f2633b38af1945f13d721fd Mon Sep 17 00:00:00 2001 From: abthuy Date: Tue, 12 Sep 2023 19:59:54 +0200 Subject: [PATCH 1/3] fix: no pool preds for random heuristic --- baal/active/active_loop.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/baal/active/active_loop.py b/baal/active/active_loop.py index f07db47d..277b7f7c 100644 --- a/baal/active/active_loop.py +++ b/baal/active/active_loop.py @@ -79,7 +79,10 @@ def step(self, pool=None) -> bool: indices = None if len(pool) > 0: - probs = self.get_probabilities(pool, **self.kwargs) + if self.heuristic.__class__.__name__ == "Random": + probs = np.random.uniform(low=0, high=1, size=(len(pool), 1)) + else: + probs = self.get_probabilities(pool, **self.kwargs) if probs is not None and (isinstance(probs, types.GeneratorType) or len(probs) > 0): to_label, uncertainty = self.heuristic.get_ranks(probs) if indices is not None: From cf35bc79b34490e2afd5d04136875db9fdd7fd4b Mon Sep 17 00:00:00 2001 From: abthuy Date: Wed, 13 Sep 2023 08:47:00 +0200 Subject: [PATCH 2/3] use isinstance() instead of __name__ attribute --- baal/active/active_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/baal/active/active_loop.py b/baal/active/active_loop.py index 277b7f7c..cc59dcf2 100644 --- a/baal/active/active_loop.py +++ b/baal/active/active_loop.py @@ -79,7 +79,7 @@ def step(self, pool=None) -> bool: indices = None if len(pool) > 0: - if self.heuristic.__class__.__name__ == "Random": + if isinstance(self.heuristic, heuristics.Random): probs = np.random.uniform(low=0, high=1, size=(len(pool), 1)) else: probs = self.get_probabilities(pool, **self.kwargs) From f71312b7e39244e707554326784e21d1606b6310 Mon Sep 17 00:00:00 2001 From: abthuy Date: Wed, 13 Sep 2023 10:56:44 +0200 Subject: [PATCH 3/3] test: add test for get_probabilities with Random --- tests/active/active_loop_test.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/active/active_loop_test.py b/tests/active/active_loop_test.py index 527e030c..3a652339 100644 --- a/tests/active/active_loop_test.py +++ b/tests/active/active_loop_test.py @@ -1,6 +1,7 @@ import os import pickle import warnings +from unittest.mock import patch import numpy as np import pytest @@ -140,5 +141,24 @@ def test_deprecation(): assert issubclass(w[-1].category, DeprecationWarning) assert "ndata_to_label" in str(w[-1].message) + +@pytest.mark.parametrize('heur,num_get_probs', [(heuristics.Random(), 0), + (heuristics.BALD(), 1), + (heuristics.Entropy(), 1), + (heuristics.Variance(reduction='sum'), 1) + ]) +def test_get_probs(heur, num_get_probs): + dataset = ActiveLearningDataset(MyDataset(), make_unlabelled=lambda x: -1) + active_loop = ActiveLearningLoop(dataset, + get_probs_iter, + heur, + query_size=5, + dummy_param=1) + dataset.label_randomly(10) + with patch.object(active_loop, "get_probabilities") as mock_probs: + active_loop.step() + assert mock_probs.call_count == num_get_probs + + if __name__ == '__main__': pytest.main()