Skip to content

Commit

Permalink
debug the active learning algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
sderooij committed Oct 4, 2024
1 parent dbbede2 commit 6d87014
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions src/tensorlibrary/learning/active.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,27 +82,30 @@ def select_samples(self, **kwargs):
n_samples_last_batch = self.n_samples - n_samples_per_batch * (n_batches - 1)
indices = np.arange(0, total_samples)
indices_select = []
min_selected_flag = False
for k in range(n_batches):
idx_start = k * self.batch_size
if k == n_batches - 1: # last batch
cur_indices = indices[idx_start:]
sel_indices = self.algorithm(self.data[idx_start:, :], self.outputs[idx_start:],
n_samples_per_batch,
break_at_pos=self.break_at_pos, labels=self.labels,
min_n_samples=self.min_n_samples, **kwargs)
indices_select.append(sel_indices)
n_samples_last_batch,
break_at_pos=self.break_at_pos, labels=self.labels[idx_start:],
min_n_samples=self.min_n_samples*(1-min_selected_flag), **kwargs)

else:
idx_end = (k + 1) * self.batch_size
cur_indices = indices[idx_start:idx_end]
sel_indices = self.algorithm(self.data[idx_start:idx_end, :], self.outputs[idx_start:idx_end],
n_samples_per_batch,
break_at_pos=self.break_at_pos, labels=self.labels, min_n_samples=self.min_n_samples, **kwargs)
break_at_pos=self.break_at_pos, labels=self.labels[idx_start:idx_end],
min_n_samples=self.min_n_samples*(1-min_selected_flag), **kwargs)

indices_select.append(cur_indices[sel_indices])

if not min_selected_flag:
min_selected_flag = sum(len(arr) for arr in indices_select) > self.min_n_samples
if self.break_at_pos:
# check where indices are positive
if np.any(self.labels[indices_select[-1]] == 1) and len(indices_select[-1]) > self.min_n_samples:
if np.any(self.labels[indices_select[-1]] == 1) and min_selected_flag:
break

self.indices = tl.concatenate(indices_select)
Expand Down Expand Up @@ -174,7 +177,7 @@ def uncertainty_strategy(outputs, n_samples=-1, thresh=-1, *, break_at_pos=False


def combined_strategy(x_feat, outputs, max_samples, l=0.5, m=10, sim_measure='cos', feature_map='rbf', map_param=1.0, \
approx=False, min_div_max=0., break_at_pos=False, labels=None, min_n_samples=50):
approx=False, max_min_sim=1., break_at_pos=False, labels=None, min_n_samples=50):
"""
Perform combined strategy for active learning.
Args:
Expand All @@ -186,7 +189,7 @@ def combined_strategy(x_feat, outputs, max_samples, l=0.5, m=10, sim_measure='co
feature_map: feature map to use, default is rbf kernel: 'rbf'
map_param: parameter for the feature map (or kernel function), default is 1.0
approx: whether to use the approximate feature map instead of kernel function
min_div_max: minimum value the maximum diversity measure, break if all values are below this threshold
max_min_sim: threshold on the smallest per-sample maximum similarity (sim_max) among the not-yet-selected samples; selection stops as soon as that minimum exceeds this threshold
Returns:
indices: indices of the most uncertain samples
Expand Down Expand Up @@ -219,13 +222,13 @@ def combined_strategy(x_feat, outputs, max_samples, l=0.5, m=10, sim_measure='co
# div[indices[k], k] = 0 # for max to work
# take max over the columns
sim_max = tl.max(sim, axis=1) # results in N_feats x 1
if tl.all(sim_max[~indices[k]] < min_div_max):
indices = indices[:k]
if tl.min(np.delete(sim_max, indices[:k+1])) > max_min_sim:
indices = indices[:k+1]
break

indices[k+1] = tl.argmin(l*outputs + (1-l)*sim_max)
if break_at_pos:
if tl.any(labels[indices[k]] == 1):
if tl.any(labels[indices[:k+2]] == 1):
breaktime = True
if breaktime and k >= min_n_samples-1:
indices = indices[:k + 2]
Expand Down

0 comments on commit 6d87014

Please sign in to comment.