Optimized Community Detection using Binary Search #1857

Closed
19 changes: 15 additions & 4 deletions sentence_transformers/util.py
@@ -381,11 +381,22 @@ def community_detection(embeddings, threshold=0.75, min_community_size=10, batch
                 sort_max_size = min(2 * sort_max_size, len(embeddings))
                 top_val_large, top_idx_large = cos_scores[i].topk(k=sort_max_size, largest=True)
 
-            for idx, val in zip(top_idx_large.tolist(), top_val_large):
-                if val < threshold:
+            # Binary search for the index of the last element greater than threshold
+            low, high = 0, len(top_val_large) - 1
+            while low <= high:
+                mid = (low + high) // 2
+
+                # If we reached the end of the list or the next element is smaller than or equal to threshold
+                if top_val_large[mid] > threshold and ((mid+1) >= len(top_val_large)-1 or (top_val_large[mid+1] <= threshold)):
                     break
-
-                new_cluster.append(idx)
+                elif top_val_large[mid] > threshold: # Next element is bigger than threshold
+                    low = mid + 1
+                else: # Next element is smaller than threshold
+                    high = mid - 1
+
+            # Last element bigger than threshold
+            threshold_idx = mid
+            new_cluster = top_idx_large.tolist()[:threshold_idx+1]
 
             extracted_communities.append(new_cluster)
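
Review note: two edge cases in the proposed search seem worth checking. First, the early-exit test `(mid+1) >= len(top_val_large)-1` fires when `mid+1` is the last index, so the last score is never inspected. With scores `[0.9, 0.8, 0.78]` and a threshold of 0.75 the loop stops at `mid = 1` and the third entry is dropped. Second, the removed loop kept scores equal to the threshold (it stopped only on `val < threshold`), while the new comparison uses a strict `>`. The sketch below is one possible way to express the same cutoff as a half-open binary search that avoids both cases. It is not the code from this PR: the helper name `count_at_or_above` and the sample values are hypothetical, and it uses plain Python lists rather than tensors.

```python
from typing import Sequence


def count_at_or_above(sorted_desc: Sequence[float], threshold: float) -> int:
    """Return how many leading values of a descending sequence are >= threshold.

    Hypothetical helper for illustration; assumes `sorted_desc` is sorted
    from largest to smallest, as the values returned by topk(largest=True) are.
    """
    low, high = 0, len(sorted_desc)
    # Invariant: every value before `low` is >= threshold,
    # every value from `high` onward is < threshold.
    while low < high:
        mid = (low + high) // 2
        if sorted_desc[mid] >= threshold:
            low = mid + 1  # mid belongs to the cluster, search to the right
        else:
            high = mid  # mid is below the threshold, search to the left
    return low


# Illustrative stand-ins for top_val_large / top_idx_large.
top_val_large = [0.99, 0.90, 0.80, 0.78]
top_idx_large = [7, 3, 12, 5]

cutoff = count_at_or_above(top_val_large, threshold=0.75)
new_cluster = top_idx_large[:cutoff]
```

Because the search returns a count rather than an index, an empty input or an input with no score at or above the threshold yields an empty slice without special handling.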
