diff --git a/argilla/src/argilla/_api/_records.py b/argilla/src/argilla/_api/_records.py index 92b42a1617..4ed5bbeed4 100644 --- a/argilla/src/argilla/_api/_records.py +++ b/argilla/src/argilla/_api/_records.py @@ -30,6 +30,7 @@ class RecordsAPI(ResourceAPI[RecordModel]): MAX_RECORDS_PER_CREATE_BULK = 500 MAX_RECORDS_PER_UPSERT_BULK = 500 + MAX_RECORDS_PER_DELETE_BULK = 100 http_client: httpx.Client diff --git a/argilla/src/argilla/records/_dataset_records.py b/argilla/src/argilla/records/_dataset_records.py index b914b98fcc..d91a0459c1 100644 --- a/argilla/src/argilla/records/_dataset_records.py +++ b/argilla/src/argilla/records/_dataset_records.py @@ -138,6 +138,7 @@ class DatasetRecords(Iterable[Record], LoggingMixin): _api: RecordsAPI DEFAULT_BATCH_SIZE = 256 + DEFAULT_DELETE_BATCH_SIZE = 64 def __init__(self, client: "Argilla", dataset: "Dataset"): """Initializes a DatasetRecords object with a client and a dataset. @@ -258,12 +259,14 @@ def log( def delete( self, records: List[Record], + batch_size: int = DEFAULT_DELETE_BATCH_SIZE, ) -> List[Record]: """Delete records in a dataset on the server using the provided records and matching based on the id. Parameters: records: A list of `Record` objects representing the records to be deleted. + batch_size: The number of records to send in each batch. The default is 64. Returns: A list of Record objects representing the deleted records. @@ -271,17 +274,31 @@ def delete( """ mapping = None user_id = self.__client.me.id - record_models = self._ingest_records(records=records, mapping=mapping, user_id=user_id) + batch_size = self._normalize_batch_size( + batch_size=batch_size, + records_length=len(record_models), + max_value=self._api.MAX_RECORDS_PER_DELETE_BULK, + ) - self._api.delete_many(dataset_id=self.__dataset.id, records=record_models) + records_deleted = 0 + for batch in tqdm( + iterable=range(0, len(records), batch_size), + desc="Sending records...", + total=len(records) // batch_size, + unit="batch", + ): + self._log_message(message=f"Sending records from {batch} to {batch + batch_size}.") + batch_records = record_models[batch : batch + batch_size] + self._api.delete_many(dataset_id=self.__dataset.id, records=batch_records) + records_deleted += len(batch_records) self._log_message( message=f"Deleted {len(record_models)} records from dataset {self.__dataset.name}", level="info", ) - return record_models + return records def to_dict(self, flatten: bool = False, orient: str = "names") -> Dict[str, Any]: """ diff --git a/argilla/tests/integration/test_delete_records.py b/argilla/tests/integration/test_delete_records.py index 768c67e53c..97b4bb1805 100644 --- a/argilla/tests/integration/test_delete_records.py +++ b/argilla/tests/integration/test_delete_records.py @@ -13,6 +13,7 @@ # limitations under the License. import uuid + import pytest import argilla as rg @@ -100,3 +101,13 @@ def test_delete_single_record(client: rg.Argilla, dataset: rg.Dataset): assert dataset_records[0].id == str(mock_data[0]["id"]) assert dataset_records[1].id == str(mock_data[2]["id"]) assert mock_data[1]["id"] not in [record.id for record in dataset_records] + + +def test_delete_records_with_batch_support(client: rg.Argilla, dataset: rg.Dataset): + records = [rg.Record(id=uuid.uuid4(), fields={"text": f"Field for record {i}"}) for i in range(0, 1000)] + + dataset.records.log(records) + all_records = list(dataset.records) + dataset.records.delete(all_records) + + assert len(list(dataset.records)) == 0