From dbcb4a8d737684a8888af37d2dc1bd63f93cd647 Mon Sep 17 00:00:00 2001 From: Arpan Biswas Date: Mon, 6 Mar 2023 18:26:54 +0530 Subject: [PATCH] optimized community detection - applied binary search to get the new cluster --- sentence_transformers/util.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/sentence_transformers/util.py b/sentence_transformers/util.py index 6361ec9a5..d2e7ebecf 100644 --- a/sentence_transformers/util.py +++ b/sentence_transformers/util.py @@ -381,11 +381,22 @@ def community_detection(embeddings, threshold=0.75, min_community_size=10, batch sort_max_size = min(2 * sort_max_size, len(embeddings)) top_val_large, top_idx_large = cos_scores[i].topk(k=sort_max_size, largest=True) - for idx, val in zip(top_idx_large.tolist(), top_val_large): - if val < threshold: + # Binary search for the index of the last element greater than threshold + low, high = 0, len(top_val_large) - 1 + while low <= high: + mid = (low + high) // 2 + + # If we reached the end of the list or the next element is smaller than or equal to threshold + if top_val_large[mid] > threshold and ((mid+1) >= len(top_val_large)-1 or (top_val_large[mid+1] <= threshold)): break - - new_cluster.append(idx) + elif top_val_large[mid] > threshold: # Next element is bigger than threshold + low = mid + 1 + else: # Next element is smaller than threshold + high = mid - 1 + + # Last element bigger than threshold + threshold_idx = mid + new_cluster = top_idx_large.tolist()[:threshold_idx+1] extracted_communities.append(new_cluster)