diff --git a/src/tensorlibrary/learning/active.py b/src/tensorlibrary/learning/active.py
index 9a82993..c0999c7 100644
--- a/src/tensorlibrary/learning/active.py
+++ b/src/tensorlibrary/learning/active.py
@@ -82,27 +82,30 @@ def select_samples(self, **kwargs):
         n_samples_last_batch = self.n_samples - n_samples_per_batch * (n_batches - 1)
         indices = np.arange(0, total_samples)
         indices_select = []
+        min_selected_flag = False
         for k in range(n_batches):
             idx_start = k * self.batch_size
             if k == n_batches - 1:  # last batch
                 cur_indices = indices[idx_start:]
                 sel_indices = self.algorithm(self.data[idx_start:, :], self.outputs[idx_start:],
-                                             n_samples_per_batch,
-                                             break_at_pos=self.break_at_pos, labels=self.labels,
-                                             min_n_samples=self.min_n_samples, **kwargs)
-                indices_select.append(sel_indices)
+                                             n_samples_last_batch,
+                                             break_at_pos=self.break_at_pos, labels=self.labels[idx_start:],
+                                             min_n_samples=self.min_n_samples*(1-min_selected_flag), **kwargs)
+
             else:
                 idx_end = (k + 1) * self.batch_size
                 cur_indices = indices[idx_start:idx_end]
                 sel_indices = self.algorithm(self.data[idx_start:idx_end, :], self.outputs[idx_start:idx_end],
                                              n_samples_per_batch,
-                                             break_at_pos=self.break_at_pos, labels=self.labels, min_n_samples=self.min_n_samples, **kwargs)
+                                             break_at_pos=self.break_at_pos, labels=self.labels[idx_start:idx_end],
+                                             min_n_samples=self.min_n_samples*(1-min_selected_flag), **kwargs)
             indices_select.append(cur_indices[sel_indices])
-
+            if not min_selected_flag:
+                min_selected_flag = sum(len(arr) for arr in indices_select) > self.min_n_samples
             if self.break_at_pos:
                 # check where indices are positive
-                if np.any(self.labels[indices_select[-1]] == 1) and len(indices_select[-1]) > self.min_n_samples:
+                if np.any(self.labels[indices_select[-1]] == 1) and min_selected_flag:
                     break

         self.indices = tl.concatenate(indices_select)

@@ -174,7 +177,7 @@ def uncertainty_strategy(outputs, n_samples=-1, thresh=-1, *, break_at_pos=False

 def combined_strategy(x_feat, outputs, max_samples, l=0.5, m=10, sim_measure='cos', feature_map='rbf', map_param=1.0, \
-                      approx=False, min_div_max=0., break_at_pos=False, labels=None, min_n_samples=50):
+                      approx=False, max_min_sim=1., break_at_pos=False, labels=None, min_n_samples=50):
     """
     Perform combined strategy for active learning.

     Args:
@@ -186,7 +189,7 @@ def combined_strategy(x_feat, outputs, max_samples, l=0.5, m=10, sim_measure='co
         feature_map: feature map to use, default is rbf kernel: 'rbf'
         map_param: parameter for the feature map (or kernel function), default is 1.0
         approx: whether to use the approximate feature map instead of kernel function
-        min_div_max: minimum value the maximum diversity measure, break if all values are below this threshold
+        max_min_sim: similarity threshold; stop selecting once every remaining sample's maximum similarity to the selected set exceeds this value, default is 1.0

     Returns:
         indices: indices of the most uncertain samples
@@ -219,13 +222,13 @@ def combined_strategy(x_feat, outputs, max_samples, l=0.5, m=10, sim_measure='co
             # div[indices[k], k] = 0 # for max to work
             # take max over the columns o
             sim_max = tl.max(sim, axis=1)  # results in N_feats x 1
-            if tl.all(sim_max[~indices[k]] < min_div_max):
-                indices = indices[:k]
+            if tl.min(np.delete(sim_max, indices[:k+1])) > max_min_sim:
+                indices = indices[:k+1]
                 break
             indices[k+1] = tl.argmin(l*outputs + (1-l)*sim_max)
             if break_at_pos:
-                if tl.any(labels[indices[k]] == 1):
+                if tl.any(labels[indices[:k+2]] == 1):
                     breaktime = True
                 if breaktime and k >= min_n_samples-1:
                     indices = indices[:k + 2]
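
For review, a minimal standalone sketch of the control flow the first hunk establishes. Everything here is a hypothetical analogue, not the library's API: `batched_select` and its signature are invented for illustration. The points of the change are that `labels` are now sliced per batch, the last batch requests only its own quota, and `min_n_samples` is effectively disabled (passed as 0) once the running total of selections exceeds it.

```python
import numpy as np

def batched_select(data, outputs, labels, algorithm, batch_size,
                   n_per_batch, min_n_samples, break_at_pos=False, **kwargs):
    """Hypothetical free-function analogue of select_samples after this change."""
    n = data.shape[0]
    picks = []
    min_selected = False  # latches True once enough samples are collected overall
    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        quota = min(n_per_batch, end - start)  # last batch asks only for what it holds
        sel = algorithm(data[start:end], outputs[start:end], quota,
                        labels=labels[start:end],  # labels sliced to the batch, as in the fix
                        min_n_samples=min_n_samples * (not min_selected),
                        **kwargs)
        picks.append(np.arange(start, end)[sel])  # map batch-local picks to global indices
        if not min_selected:
            min_selected = sum(len(p) for p in picks) > min_n_samples
        # early exit: a positive label was picked and the quota is already met
        if break_at_pos and min_selected and np.any(labels[picks[-1]] == 1):
            break
    return np.concatenate(picks)
```

Note that the flag is latched: once enough samples have been collected, later batches are free to return fewer (or zero) samples.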
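The last hunk also inverts the old `min_div_max` test: rather than breaking when all remaining max-similarities fall below a floor, the loop now breaks when the smallest of them rises above `max_min_sim`. A numpy-only sketch of that rule, with hypothetical names (`greedy_diverse`, `sim`, `scores`) standing in for the library's tensors:

```python
import numpy as np

def greedy_diverse(sim, scores, max_samples, l=0.5, max_min_sim=0.95):
    """Hypothetical standalone version of the stopping rule in combined_strategy.

    sim: (N, N) pairwise similarities; scores: (N,) uncertainty scores (lower = more uncertain).
    """
    n = sim.shape[0]
    selected = [int(np.argmin(scores))]  # seed with the most uncertain sample
    for _ in range(max_samples - 1):
        sim_max = sim[:, selected].max(axis=1)  # each sample's max similarity to the selected set
        remaining = np.delete(np.arange(n), selected)
        # new rule: stop once even the most dissimilar remaining candidate
        # is more similar than max_min_sim to something already selected
        if remaining.size == 0 or sim_max[remaining].min() > max_min_sim:
            break
        combined = l * scores + (1 - l) * sim_max  # trade off uncertainty vs. redundancy
        combined[selected] = np.inf  # never re-pick an already selected sample
        selected.append(int(np.argmin(combined)))
    return np.array(selected)
```

Assuming similarities are bounded by 1 (e.g. cosine), the new default `max_min_sim=1.` leaves the break disabled unless the caller lowers it, which matches the changed default in the signature.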