Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: issue with exact duplicates not being returned #35

Merged
merged 1 commit into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions semhash/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@ def __init__(self, vectors: np.ndarray, items: list[DictItem], backend: Abstract
self.backend = backend
self.vectors = vectors

def items_as_sequence(self) -> list[dict[str, str]]:
"""Return all items as a single sequence."""
return [item[0] for item in self.items]

@classmethod
def from_vectors_and_items(cls, vectors: np.ndarray, items: list[DictItem], backend_type: Backend) -> Index:
"""
Expand Down
5 changes: 5 additions & 0 deletions semhash/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,8 @@ def map_deduplication_result_to_strings(result: DeduplicationResult, columns: Se
)
)
return DeduplicationResult(deduplicated=deduplicated_str, duplicates=mapped, threshold=result.threshold)


def add_scores_to_records(records: list[dict[str, str]]) -> list[tuple[dict[str, str], float]]:
"""Add scores to records and return a DeduplicationResult."""
return [(record, 1.0) for record in records]
47 changes: 33 additions & 14 deletions semhash/semhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from semhash.datamodels import DeduplicationResult, DuplicateRecord, Record
from semhash.index import Index
from semhash.records import map_deduplication_result_to_strings, to_frozendict
from semhash.records import add_scores_to_records, map_deduplication_result_to_strings, to_frozendict
from semhash.utils import Encoder


Expand Down Expand Up @@ -57,8 +57,8 @@ def _remove_exact_duplicates(
cls,
records: Sequence[dict[str, str]],
columns: Sequence[str],
reference_records: list[dict[str, str]] | None = None,
) -> tuple[list[dict[str, str]], list[dict[str, str]]]:
reference_records: list[list[dict[str, str]]] | None = None,
) -> tuple[list[dict[str, str]], list[tuple[dict[str, str], list[dict[str, str]]]]]:
"""
Remove exact duplicates based on the unpacked string representation of each record.

Expand All @@ -74,18 +74,22 @@ def _remove_exact_duplicates(

column_set = set(columns)
# Build a seen set from reference_records if provided
seen = {to_frozendict(x, column_set) for x in reference_records} if reference_records else set()
seen: defaultdict[frozendict[str, str], list[dict[str, str]]] = defaultdict(list)
if reference_records is not None:
for record_set in reference_records:
key = to_frozendict(record_set[0], column_set)
seen[key] = list(record_set)
in_one_set = reference_records is None

for record in records:
frozen_record = frozendict({k: v for k, v in record.items() if k in column_set})
if frozen_record not in seen:
if duplicated_records := seen.get(frozen_record):
duplicates.append((record, duplicated_records))
else:
deduplicated.append(record)
# Only add current documents to seen if no reference set is used
if in_one_set:
seen.add(frozen_record)
else:
duplicates.append(record)
seen[frozen_record].append(record)

return deduplicated, duplicates

Expand Down Expand Up @@ -128,9 +132,10 @@ def from_records(
# Remove exact duplicates
deduplicated_records, duplicates = cls._remove_exact_duplicates(dict_records, columns)

col_set = set(columns)
duplicate_map = defaultdict(list)
for x in duplicates:
frozen_record = to_frozendict(x, set(columns))
for x, _ in duplicates:
frozen_record = to_frozendict(x, col_set)
duplicate_map[frozen_record].append(x)

items: list[list[dict[str, str]]] = []
Expand Down Expand Up @@ -180,9 +185,13 @@ def deduplicate(

# Remove exact duplicates before embedding
dict_records, exact_duplicates = self._remove_exact_duplicates(
records=dict_records, columns=self.columns, reference_records=self.index.items_as_sequence()
records=dict_records, columns=self.columns, reference_records=self.index.items
)
duplicate_records = [DuplicateRecord(record=record, duplicates=[], exact=True) for record in exact_duplicates]
duplicate_records = []
for record, duplicates in exact_duplicates:
duplicated_with_score = add_scores_to_records(duplicates)
duplicate_record = DuplicateRecord(record=record, duplicates=duplicated_with_score, exact=True)
duplicate_records.append(duplicate_record)

# If no records are left after removing exact duplicates, return early
if not dict_records:
Expand Down Expand Up @@ -240,8 +249,18 @@ def self_deduplicate(
# So if the an item has more than one record, it is an exact duplicate
# Crucially, we should count each instance separately.
record, *duplicates = item
for record in duplicates:
duplicate_records.append(DuplicateRecord(record=record, duplicates=[], exact=True))
# We need to compare all duplicates to all _items_.
# The first item in a list of duplicate is not duplicated, because otherwise
# we would remove the whole cluster. But it is a duplicate for the other items.

# Iterate from index 1.
for index, curr_record in enumerate(duplicates, 1):
# The use of indexing is intentional here, we want to check if the object is the same
# not if they have the same values. If we did != or is we would probably ignore lots
# of items.
items_to_keep = item[:index] + item[index + 1 :]
items_with_score = add_scores_to_records(items_to_keep)
duplicate_records.append(DuplicateRecord(record=curr_record, duplicates=items_with_score, exact=True))

# If we don't see any similar_items, we know the record is not a duplicate.
# in rare cases, the item itself might not be a duplicate of itself.
Expand Down
2 changes: 1 addition & 1 deletion semhash/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ def encode(
:param **kwargs: Additional keyword arguments.
:return: The embeddings of the sentences.
"""
...
... # pragma: no cover
6 changes: 3 additions & 3 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading