Skip to content

Commit

Permalink
[ENHANCEMENT - BUGFIX] argilla-sdk: Add batch support deleting reco…
Browse files Browse the repository at this point in the history
…rds (#5179)

# Description
<!-- Please include a summary of the changes and the related issue.
Please also include relevant motivation and context. List any
dependencies that are required for this change. -->

This PR adds batch support when deleting records. Otherwise, users can
experience errors when deleting a large number of records at once.

**Type of change**
<!-- Please delete options that are not relevant. Remember to title the
PR according to the type of change -->

- Bug fix (non-breaking change which fixes an issue)
- Improvement (change adding some improvement to an existing
functionality)

**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested.
-->

**Checklist**
<!-- Please go over the list and make sure you've taken everything into
account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm My changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)
  • Loading branch information
frascuchon authored Jul 9, 2024
1 parent c6b2065 commit fd0b012
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 3 deletions.
1 change: 1 addition & 0 deletions argilla/src/argilla/_api/_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class RecordsAPI(ResourceAPI[RecordModel]):

MAX_RECORDS_PER_CREATE_BULK = 500
MAX_RECORDS_PER_UPSERT_BULK = 500
MAX_RECORDS_PER_DELETE_BULK = 100

http_client: httpx.Client

Expand Down
23 changes: 20 additions & 3 deletions argilla/src/argilla/records/_dataset_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class DatasetRecords(Iterable[Record], LoggingMixin):
_api: RecordsAPI

DEFAULT_BATCH_SIZE = 256
DEFAULT_DELETE_BATCH_SIZE = 64

def __init__(self, client: "Argilla", dataset: "Dataset"):
"""Initializes a DatasetRecords object with a client and a dataset.
Expand Down Expand Up @@ -258,30 +259,46 @@ def log(
def delete(
self,
records: List[Record],
batch_size: int = DEFAULT_DELETE_BATCH_SIZE,
) -> List[Record]:
"""Delete records in a dataset on the server using the provided records
and matching based on the id.
Parameters:
records: A list of `Record` objects representing the records to be deleted.
batch_size: The number of records to send in each batch. The default is 64.
Returns:
A list of Record objects representing the deleted records.
"""
mapping = None
user_id = self.__client.me.id

record_models = self._ingest_records(records=records, mapping=mapping, user_id=user_id)
batch_size = self._normalize_batch_size(
batch_size=batch_size,
records_length=len(record_models),
max_value=self._api.MAX_RECORDS_PER_DELETE_BULK,
)

self._api.delete_many(dataset_id=self.__dataset.id, records=record_models)
records_deleted = 0
for batch in tqdm(
iterable=range(0, len(records), batch_size),
desc="Sending records...",
total=len(records) // batch_size,
unit="batch",
):
self._log_message(message=f"Sending records from {batch} to {batch + batch_size}.")
batch_records = record_models[batch : batch + batch_size]
self._api.delete_many(dataset_id=self.__dataset.id, records=batch_records)
records_deleted += len(batch_records)

self._log_message(
message=f"Deleted {len(record_models)} records from dataset {self.__dataset.name}",
level="info",
)

return record_models
return records

def to_dict(self, flatten: bool = False, orient: str = "names") -> Dict[str, Any]:
"""
Expand Down
11 changes: 11 additions & 0 deletions argilla/tests/integration/test_delete_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import uuid

import pytest

import argilla as rg
Expand Down Expand Up @@ -100,3 +101,13 @@ def test_delete_single_record(client: rg.Argilla, dataset: rg.Dataset):
assert dataset_records[0].id == str(mock_data[0]["id"])
assert dataset_records[1].id == str(mock_data[2]["id"])
assert mock_data[1]["id"] not in [record.id for record in dataset_records]


def test_delete_records_with_batch_support(client: rg.Argilla, dataset: rg.Dataset):
records = [rg.Record(id=uuid.uuid4(), fields={"text": f"Field for record {i}"}) for i in range(0, 1000)]

dataset.records.log(records)
all_records = list(dataset.records)
dataset.records.delete(all_records)

assert len(list(dataset.records)) == 0

0 comments on commit fd0b012

Please sign in to comment.