diff --git a/sentence_transformers/util.py b/sentence_transformers/util.py index 6619318f1..945b808b2 100644 --- a/sentence_transformers/util.py +++ b/sentence_transformers/util.py @@ -348,7 +348,7 @@ def import_from_string(dotted_path): raise ImportError(msg) -def community_detection(embeddings, threshold=0.75, min_community_size=10, batch_size=1024): +def community_detection(embeddings, threshold=0.75, min_community_size=10, batch_size=1024, show_progress_bar=False): """ Function for Fast Community Detection Finds in the embeddings all communities, i.e. embeddings that are close (closer than threshold). @@ -366,7 +366,7 @@ def community_detection(embeddings, threshold=0.75, min_community_size=10, batch min_community_size = min(min_community_size, len(embeddings)) sort_max_size = min(max(2 * min_community_size, 50), len(embeddings)) - for start_idx in range(0, len(embeddings), batch_size): + for start_idx in tqdm(range(0, len(embeddings), batch_size), desc="Finding clusters", disable=not show_progress_bar): # Compute cosine similarity scores cos_scores = cos_sim(embeddings[start_idx:start_idx + batch_size], embeddings)