diff --git a/argilla-frontend/v1/infrastructure/repositories/RecordRepository.ts b/argilla-frontend/v1/infrastructure/repositories/RecordRepository.ts index e0e30adfd3..40ce2645eb 100644 --- a/argilla-frontend/v1/infrastructure/repositories/RecordRepository.ts +++ b/argilla-frontend/v1/infrastructure/repositories/RecordRepository.ts @@ -42,10 +42,8 @@ export class RecordRepository { constructor(private readonly axios: NuxtAxiosInstance) {} getRecords(criteria: RecordCriteria): Promise { - if (criteria.isFilteringByAdvanceSearch) - return this.getRecordsByAdvanceSearch(criteria); - - return this.getRecordsByDatasetId(criteria); + return this.getRecordsByAdvanceSearch(criteria); + // return this.getRecordsByDatasetId(criteria); } async getRecord(recordId: string): Promise { @@ -264,6 +262,30 @@ export class RecordRepository { }; } + body.filters = { + and: [ + { + type: "terms", + scope: { + entity: "response", + property: "status", + }, + values: [status], + }, + ], + }; + + if (status === "pending") { + body.filters.and.push({ + type: "terms", + scope: { + entity: "record", + property: "status", + }, + values: ["pending"], + }); + } + if ( isFilteringByMetadata || isFilteringByResponse || diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index de84587e41..827037a2c3 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -16,12 +16,17 @@ These are the section headers that we use: ## [Unreleased]() -## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1) +### Added + +- Added support to specify `distribution` attribute when creating a dataset. ([#5013](https://github.com/argilla-io/argilla/pull/5013)) +- Added support to change `distribution` attribute when updating a dataset. ([#5028](https://github.com/argilla-io/argilla/pull/5028)) ### Changed - Change `responses` table to delete rows on cascade when a user is deleted. ([#5126](https://github.com/argilla-io/argilla/pull/5126)) +## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1) + ### Removed - Removed all API v0 endpoints. ([#4852](https://github.com/argilla-io/argilla/pull/4852)) diff --git a/argilla-server/src/argilla_server/alembic/versions/237f7c674d74_add_status_column_to_records_table.py b/argilla-server/src/argilla_server/alembic/versions/237f7c674d74_add_status_column_to_records_table.py new file mode 100644 index 0000000000..767b277573 --- /dev/null +++ b/argilla-server/src/argilla_server/alembic/versions/237f7c674d74_add_status_column_to_records_table.py @@ -0,0 +1,60 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""add status column to records table + +Revision ID: 237f7c674d74 +Revises: 45a12f74448b +Create Date: 2024-06-18 17:59:36.992165 + +""" + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = "237f7c674d74" +down_revision = "45a12f74448b" +branch_labels = None +depends_on = None + + +record_status_enum = sa.Enum("pending", "completed", name="record_status_enum") + + +def upgrade() -> None: + record_status_enum.create(op.get_bind()) + + op.add_column("records", sa.Column("status", record_status_enum, server_default="pending", nullable=False)) + op.create_index(op.f("ix_records_status"), "records", ["status"], unique=False) + + # NOTE: Updating existent records to have "completed" status when they have + # at least one response with "submitted" status. + op.execute(""" + UPDATE records + SET status = 'completed' + WHERE id IN ( + SELECT DISTINCT record_id + FROM responses + WHERE status = 'submitted' + ); + """) + + +def downgrade() -> None: + op.drop_index(op.f("ix_records_status"), table_name="records") + op.drop_column("records", "status") + + record_status_enum.drop(op.get_bind()) diff --git a/argilla-server/src/argilla_server/alembic/versions/3ff6484f8b37_add_record_metadata_column.py b/argilla-server/src/argilla_server/alembic/versions/3ff6484f8b37_add_metadata_column_to_records_table.py similarity index 82% rename from argilla-server/src/argilla_server/alembic/versions/3ff6484f8b37_add_record_metadata_column.py rename to argilla-server/src/argilla_server/alembic/versions/3ff6484f8b37_add_metadata_column_to_records_table.py index 7ac80ad895..b5949f5364 100644 --- a/argilla-server/src/argilla_server/alembic/versions/3ff6484f8b37_add_record_metadata_column.py +++ b/argilla-server/src/argilla_server/alembic/versions/3ff6484f8b37_add_metadata_column_to_records_table.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""add record metadata column +"""add metadata column to records table Revision ID: 3ff6484f8b37 Revises: ae5522b4c674 @@ -31,12 +31,8 @@ def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### op.add_column("records", sa.Column("metadata", sa.JSON(), nullable=True)) - # ### end Alembic commands ### def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### op.drop_column("records", "metadata") - # ### end Alembic commands ### diff --git a/argilla-server/src/argilla_server/alembic/versions/45a12f74448b_add_distribution_column_to_datasets_table.py b/argilla-server/src/argilla_server/alembic/versions/45a12f74448b_add_distribution_column_to_datasets_table.py new file mode 100644 index 0000000000..791da07439 --- /dev/null +++ b/argilla-server/src/argilla_server/alembic/versions/45a12f74448b_add_distribution_column_to_datasets_table.py @@ -0,0 +1,45 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""add distribution column to datasets table + +Revision ID: 45a12f74448b +Revises: d00f819ccc67 +Create Date: 2024-06-13 11:23:43.395093 + +""" + +import json + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision = "45a12f74448b" +down_revision = "d00f819ccc67" +branch_labels = None +depends_on = None + +DISTRIBUTION_VALUE = json.dumps({"strategy": "overlap", "min_submitted": 1}) + + +def upgrade() -> None: + op.add_column("datasets", sa.Column("distribution", sa.JSON(), nullable=True)) + op.execute(f"UPDATE datasets SET distribution = '{DISTRIBUTION_VALUE}'") + with op.batch_alter_table("datasets") as batch_op: + batch_op.alter_column("distribution", nullable=False) + + +def downgrade() -> None: + op.drop_column("datasets", "distribution") diff --git a/argilla-server/src/argilla_server/alembic/versions/b8458008b60e_add_allow_extra_metadata_column_to_.py b/argilla-server/src/argilla_server/alembic/versions/b8458008b60e_add_allow_extra_metadata_column_to_datasets_table.py similarity index 81% rename from argilla-server/src/argilla_server/alembic/versions/b8458008b60e_add_allow_extra_metadata_column_to_.py rename to argilla-server/src/argilla_server/alembic/versions/b8458008b60e_add_allow_extra_metadata_column_to_datasets_table.py index 8b23340448..f8fa87536e 100644 --- a/argilla-server/src/argilla_server/alembic/versions/b8458008b60e_add_allow_extra_metadata_column_to_.py +++ b/argilla-server/src/argilla_server/alembic/versions/b8458008b60e_add_allow_extra_metadata_column_to_datasets_table.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""add allow_extra_metadata column to dataset table +"""add allow_extra_metadata column to datasets table Revision ID: b8458008b60e Revises: 7cbcccf8b57a @@ -31,14 +31,10 @@ def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### op.add_column( "datasets", sa.Column("allow_extra_metadata", sa.Boolean(), server_default=sa.text("true"), nullable=False) ) - # ### end Alembic commands ### def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### op.drop_column("datasets", "allow_extra_metadata") - # ### end Alembic commands ### diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py index 63f95391e1..0590b41bb4 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py @@ -189,7 +189,7 @@ async def create_dataset( ): await authorize(current_user, DatasetPolicy.create(dataset_create.workspace_id)) - return await datasets.create_dataset(db, dataset_create) + return await datasets.create_dataset(db, dataset_create.dict()) @router.post("/datasets/{dataset_id}/fields", status_code=status.HTTP_201_CREATED, response_model=Field) @@ -302,4 +302,4 @@ async def update_dataset( await authorize(current_user, DatasetPolicy.update(dataset)) - return await datasets.update_dataset(db, dataset, dataset_update) + return await datasets.update_dataset(db, dataset, dataset_update.dict(exclude_unset=True)) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/responses.py b/argilla-server/src/argilla_server/api/handlers/v1/responses.py index 56cb695c95..ddc389563a 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/responses.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/responses.py @@ -64,7 +64,9 @@ async def update_response( response = await Response.get_or_raise( db, response_id, - options=[selectinload(Response.record).selectinload(Record.dataset).selectinload(Dataset.questions)], + options=[ + selectinload(Response.record).selectinload(Record.dataset).selectinload(Dataset.questions), + ], ) await authorize(current_user, ResponsePolicy.update(response)) @@ -83,7 +85,9 @@ async def delete_response( response = await Response.get_or_raise( db, response_id, - options=[selectinload(Response.record).selectinload(Record.dataset).selectinload(Dataset.questions)], + options=[ + selectinload(Response.record).selectinload(Record.dataset).selectinload(Dataset.questions), + ], ) await authorize(current_user, ResponsePolicy.delete(response)) diff --git a/argilla-server/src/argilla_server/api/schemas/v1/datasets.py b/argilla-server/src/argilla_server/api/schemas/v1/datasets.py index 5cac33bdb7..1e1b69d836 100644 --- a/argilla-server/src/argilla_server/api/schemas/v1/datasets.py +++ b/argilla-server/src/argilla_server/api/schemas/v1/datasets.py @@ -13,11 +13,11 @@ # limitations under the License. 
from datetime import datetime -from typing import List, Optional +from typing import List, Literal, Optional, Union from uuid import UUID from argilla_server.api.schemas.v1.commons import UpdateSchema -from argilla_server.enums import DatasetStatus +from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus from argilla_server.pydantic_v1 import BaseModel, Field, constr try: @@ -44,6 +44,32 @@ ] +class DatasetOverlapDistribution(BaseModel): + strategy: Literal[DatasetDistributionStrategy.overlap] + min_submitted: int + + +DatasetDistribution = DatasetOverlapDistribution + + +class DatasetOverlapDistributionCreate(BaseModel): + strategy: Literal[DatasetDistributionStrategy.overlap] + min_submitted: int = Field( + ge=1, + description="Minimum number of submitted responses to consider a record as completed", + ) + + +DatasetDistributionCreate = DatasetOverlapDistributionCreate + + +class DatasetOverlapDistributionUpdate(DatasetDistributionCreate): + pass + + +DatasetDistributionUpdate = DatasetOverlapDistributionUpdate + + class RecordMetrics(BaseModel): count: int @@ -74,6 +100,7 @@ class Dataset(BaseModel): guidelines: Optional[str] allow_extra_metadata: bool status: DatasetStatus + distribution: DatasetDistribution workspace_id: UUID last_activity_at: datetime inserted_at: datetime @@ -91,6 +118,10 @@ class DatasetCreate(BaseModel): name: DatasetName guidelines: Optional[DatasetGuidelines] allow_extra_metadata: bool = True + distribution: DatasetDistributionCreate = DatasetOverlapDistributionCreate( + strategy=DatasetDistributionStrategy.overlap, + min_submitted=1, + ) workspace_id: UUID @@ -98,5 +129,6 @@ class DatasetUpdate(UpdateSchema): name: Optional[DatasetName] guidelines: Optional[DatasetGuidelines] allow_extra_metadata: Optional[bool] + distribution: Optional[DatasetDistributionUpdate] - __non_nullable_fields__ = {"name", "allow_extra_metadata"} + __non_nullable_fields__ = {"name", "allow_extra_metadata", "distribution"} diff --git a/argilla-server/src/argilla_server/api/schemas/v1/records.py b/argilla-server/src/argilla_server/api/schemas/v1/records.py index 13f37c3ae0..0cf215954a 100644 --- a/argilla-server/src/argilla_server/api/schemas/v1/records.py +++ b/argilla-server/src/argilla_server/api/schemas/v1/records.py @@ -23,7 +23,7 @@ from argilla_server.api.schemas.v1.metadata_properties import MetadataPropertyName from argilla_server.api.schemas.v1.responses import Response, ResponseFilterScope, UserResponseCreate from argilla_server.api.schemas.v1.suggestions import Suggestion, SuggestionCreate, SuggestionFilterScope -from argilla_server.enums import RecordInclude, RecordSortField, SimilarityOrder, SortOrder +from argilla_server.enums import RecordInclude, RecordSortField, SimilarityOrder, SortOrder, RecordStatus from argilla_server.pydantic_v1 import BaseModel, Field, StrictStr, root_validator, validator from argilla_server.pydantic_v1.utils import GetterDict from argilla_server.search_engine import TextQuery @@ -66,6 +66,7 @@ def get(self, key: str, default: Any) -> Any: class Record(BaseModel): id: UUID + status: RecordStatus fields: Dict[str, Any] metadata: Optional[Dict[str, Any]] external_id: Optional[str] @@ -196,7 +197,7 @@ def _has_relationships(self): class RecordFilterScope(BaseModel): entity: Literal["record"] - property: Union[Literal[RecordSortField.inserted_at], Literal[RecordSortField.updated_at]] + property: Union[Literal[RecordSortField.inserted_at], Literal[RecordSortField.updated_at], Literal["status"]] class Records(BaseModel): diff --git 
a/argilla-server/src/argilla_server/bulk/records_bulk.py b/argilla-server/src/argilla_server/bulk/records_bulk.py index 0e3d372be5..6acbc30031 100644 --- a/argilla-server/src/argilla_server/bulk/records_bulk.py +++ b/argilla-server/src/argilla_server/bulk/records_bulk.py @@ -29,6 +29,7 @@ ) from argilla_server.api.schemas.v1.responses import UserResponseCreate from argilla_server.api.schemas.v1.suggestions import SuggestionCreate +from argilla_server.contexts import distribution from argilla_server.contexts.accounts import fetch_users_by_ids_as_dict from argilla_server.contexts.records import ( fetch_records_by_external_ids_as_dict, @@ -67,6 +68,7 @@ async def create_records_bulk(self, dataset: Dataset, bulk_create: RecordsBulkCr await self._upsert_records_relationships(records, bulk_create.items) await _preload_records_relationships_before_index(self._db, records) + await distribution.update_records_status(self._db, records) await self._search_engine.index_records(dataset, records) await self._db.commit() @@ -207,6 +209,7 @@ async def upsert_records_bulk(self, dataset: Dataset, bulk_upsert: RecordsBulkUp await self._upsert_records_relationships(records, bulk_upsert.items) await _preload_records_relationships_before_index(self._db, records) + await distribution.update_records_status(self._db, records) await self._search_engine.index_records(dataset, records) await self._db.commit() @@ -237,6 +240,7 @@ async def _preload_records_relationships_before_index(db: "AsyncSession", record .filter(Record.id.in_([record.id for record in records])) .options( selectinload(Record.responses).selectinload(Response.user), + selectinload(Record.responses_submitted), selectinload(Record.suggestions).selectinload(Suggestion.question), selectinload(Record.vectors), ) diff --git a/argilla-server/src/argilla_server/contexts/datasets.py b/argilla-server/src/argilla_server/contexts/datasets.py index 34468c2b18..1dbf52fc53 100644 --- a/argilla-server/src/argilla_server/contexts/datasets.py +++ b/argilla-server/src/argilla_server/contexts/datasets.py @@ -37,10 +37,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import contains_eager, joinedload, selectinload -from argilla_server.api.schemas.v1.datasets import ( - DatasetCreate, - DatasetProgress, -) +from argilla_server.api.schemas.v1.datasets import DatasetProgress from argilla_server.api.schemas.v1.fields import FieldCreate from argilla_server.api.schemas.v1.metadata_properties import MetadataPropertyCreate, MetadataPropertyUpdate from argilla_server.api.schemas.v1.records import ( @@ -63,7 +60,7 @@ VectorSettingsCreate, ) from argilla_server.api.schemas.v1.vectors import Vector as VectorSchema -from argilla_server.contexts import accounts +from argilla_server.contexts import accounts, distribution from argilla_server.enums import DatasetStatus, RecordInclude, UserRole from argilla_server.errors.future import NotUniqueError, UnprocessableEntityError from argilla_server.models import ( @@ -82,6 +79,7 @@ ) from argilla_server.models.suggestions import SuggestionCreateWithRecordId from argilla_server.search_engine import SearchEngine +from argilla_server.validators.datasets import DatasetCreateValidator, DatasetUpdateValidator from argilla_server.validators.responses import ( ResponseCreateValidator, ResponseUpdateValidator, @@ -122,22 +120,18 @@ async def list_datasets_by_workspace_id(db: AsyncSession, workspace_id: UUID) -> return result.scalars().all() -async def create_dataset(db: AsyncSession, dataset_create: DatasetCreate): - if await 
Workspace.get(db, dataset_create.workspace_id) is None: - raise UnprocessableEntityError(f"Workspace with id `{dataset_create.workspace_id}` not found") +async def create_dataset(db: AsyncSession, dataset_attrs: dict): + dataset = Dataset( + name=dataset_attrs["name"], + guidelines=dataset_attrs["guidelines"], + allow_extra_metadata=dataset_attrs["allow_extra_metadata"], + distribution=dataset_attrs["distribution"], + workspace_id=dataset_attrs["workspace_id"], + ) - if await Dataset.get_by(db, name=dataset_create.name, workspace_id=dataset_create.workspace_id): - raise NotUniqueError( - f"Dataset with name `{dataset_create.name}` already exists for workspace with id `{dataset_create.workspace_id}`" - ) + await DatasetCreateValidator.validate(db, dataset) - return await Dataset.create( - db, - name=dataset_create.name, - guidelines=dataset_create.guidelines, - allow_extra_metadata=dataset_create.allow_extra_metadata, - workspace_id=dataset_create.workspace_id, - ) + return await dataset.save(db) async def _count_required_fields_by_dataset_id(db: AsyncSession, dataset_id: UUID) -> int: @@ -176,6 +170,12 @@ async def publish_dataset(db: AsyncSession, search_engine: SearchEngine, dataset return dataset +async def update_dataset(db: AsyncSession, dataset: Dataset, dataset_attrs: dict) -> Dataset: + await DatasetUpdateValidator.validate(db, dataset, dataset_attrs) + + return await dataset.update(db, **dataset_attrs) + + async def delete_dataset(db: AsyncSession, search_engine: SearchEngine, dataset: Dataset) -> Dataset: async with db.begin_nested(): dataset = await dataset.delete(db, autocommit=False) @@ -186,11 +186,6 @@ async def delete_dataset(db: AsyncSession, search_engine: SearchEngine, dataset: return dataset -async def update_dataset(db: AsyncSession, dataset: Dataset, dataset_update: "DatasetUpdate") -> Dataset: - params = dataset_update.dict(exclude_unset=True) - return await dataset.update(db, **params) - - async def create_field(db: AsyncSession, dataset: Dataset, field_create: FieldCreate) -> Field: if dataset.is_ready: raise UnprocessableEntityError("Field cannot be created for a published dataset") @@ -945,6 +940,9 @@ async def create_response( await db.flush([response]) await _touch_dataset_last_activity_at(db, record.dataset) await search_engine.update_record_response(response) + await db.refresh(record, attribute_names=[Record.responses_submitted.key]) + await distribution.update_record_status(db, record) + await search_engine.partial_record_update(record, status=record.status) await db.commit() @@ -968,6 +966,9 @@ async def update_response( await _load_users_from_responses(response) await _touch_dataset_last_activity_at(db, response.record.dataset) await search_engine.update_record_response(response) + await db.refresh(response.record, attribute_names=[Record.responses_submitted.key]) + await distribution.update_record_status(db, response.record) + await search_engine.partial_record_update(response.record, status=response.record.status) await db.commit() @@ -997,6 +998,9 @@ async def upsert_response( await _load_users_from_responses(response) await _touch_dataset_last_activity_at(db, response.record.dataset) await search_engine.update_record_response(response) + await db.refresh(record, attribute_names=[Record.responses_submitted.key]) + await distribution.update_record_status(db, record) + await search_engine.partial_record_update(record, status=record.status) await db.commit() @@ -1006,9 +1010,13 @@ async def upsert_response( async def delete_response(db: AsyncSession, 
search_engine: SearchEngine, response: Response) -> Response: async with db.begin_nested(): response = await response.delete(db, autocommit=False) + await _load_users_from_responses(response) await _touch_dataset_last_activity_at(db, response.record.dataset) await search_engine.delete_record_response(response) + await db.refresh(response.record, attribute_names=[Record.responses_submitted.key]) + await distribution.update_record_status(db, response.record) + await search_engine.partial_record_update(record=response.record, status=response.record.status) await db.commit() diff --git a/argilla-server/src/argilla_server/contexts/distribution.py b/argilla-server/src/argilla_server/contexts/distribution.py new file mode 100644 index 0000000000..92973801ce --- /dev/null +++ b/argilla-server/src/argilla_server/contexts/distribution.py @@ -0,0 +1,42 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +from sqlalchemy.ext.asyncio import AsyncSession + +from argilla_server.enums import DatasetDistributionStrategy, RecordStatus +from argilla_server.models import Record + + +# TODO: Do this with one single update statement for all records if possible to avoid too many queries. +async def update_records_status(db: AsyncSession, records: List[Record]): + for record in records: + await update_record_status(db, record) + + +async def update_record_status(db: AsyncSession, record: Record) -> Record: + if record.dataset.distribution_strategy == DatasetDistributionStrategy.overlap: + return await _update_record_status_with_overlap_strategy(db, record) + + raise NotImplementedError(f"unsupported distribution strategy `{record.dataset.distribution_strategy}`") + + +async def _update_record_status_with_overlap_strategy(db: AsyncSession, record: Record) -> Record: + if len(record.responses_submitted) >= record.dataset.distribution["min_submitted"]: + record.status = RecordStatus.completed + else: + record.status = RecordStatus.pending + + return await record.save(db, autocommit=False) diff --git a/argilla-server/src/argilla_server/enums.py b/argilla-server/src/argilla_server/enums.py index 13b4843280..fcf0b3142f 100644 --- a/argilla-server/src/argilla_server/enums.py +++ b/argilla-server/src/argilla_server/enums.py @@ -43,12 +43,21 @@ class DatasetStatus(str, Enum): ready = "ready" +class DatasetDistributionStrategy(str, Enum): + overlap = "overlap" + + class UserRole(str, Enum): owner = "owner" admin = "admin" annotator = "annotator" +class RecordStatus(str, Enum): + pending = "pending" + completed = "completed" + + class RecordInclude(str, Enum): responses = "responses" suggestions = "suggestions" diff --git a/argilla-server/src/argilla_server/models/database.py b/argilla-server/src/argilla_server/models/database.py index 468b682467..37bd7730c9 100644 --- a/argilla-server/src/argilla_server/models/database.py +++ b/argilla-server/src/argilla_server/models/database.py @@ -29,9 +29,12 @@ DatasetStatus, MetadataPropertyType, QuestionType, 
+ RecordStatus, ResponseStatus, SuggestionType, UserRole, + DatasetDistributionStrategy, ) from argilla_server.models.base import DatabaseModel from argilla_server.models.metadata_properties import MetadataPropertySettings @@ -180,11 +183,17 @@ def __repr__(self) -> str: ) +RecordStatusEnum = SAEnum(RecordStatus, name="record_status_enum") + + class Record(DatabaseModel): __tablename__ = "records" fields: Mapped[dict] = mapped_column(JSON, default={}) metadata_: Mapped[Optional[dict]] = mapped_column("metadata", MutableDict.as_mutable(JSON), nullable=True) + status: Mapped[RecordStatus] = mapped_column( + RecordStatusEnum, default=RecordStatus.pending, server_default=RecordStatus.pending, index=True + ) external_id: Mapped[Optional[str]] = mapped_column(index=True) dataset_id: Mapped[UUID] = mapped_column(ForeignKey("datasets.id", ondelete="CASCADE"), index=True) @@ -195,6 +204,13 @@ class Record(DatabaseModel): passive_deletes=True, order_by=Response.inserted_at.asc(), ) + responses_submitted: Mapped[List["Response"]] = relationship( + back_populates="record", + cascade="all, delete-orphan", + passive_deletes=True, + primaryjoin=f"and_(Record.id==Response.record_id, Response.status=='{ResponseStatus.submitted}')", + order_by=Response.inserted_at.asc(), + ) suggestions: Mapped[List["Suggestion"]] = relationship( back_populates="record", cascade="all, delete-orphan", @@ -210,17 +226,17 @@ class Record(DatabaseModel): __table_args__ = (UniqueConstraint("external_id", "dataset_id", name="record_external_id_dataset_id_uq"),) + def vector_value_by_vector_settings(self, vector_settings: "VectorSettings") -> Union[List[float], None]: + for vector in self.vectors: + if vector.vector_settings_id == vector_settings.id: + return vector.value + def __repr__(self): return ( f"Record(id={str(self.id)!r}, external_id={self.external_id!r}, dataset_id={str(self.dataset_id)!r}, " f"inserted_at={str(self.inserted_at)!r}, updated_at={str(self.updated_at)!r})" ) - def vector_value_by_vector_settings(self, vector_settings: "VectorSettings") -> Union[List[float], None]: - for vector in self.vectors: - if vector.vector_settings_id == vector_settings.id: - return vector.value - class Question(DatabaseModel): __tablename__ = "questions" @@ -304,6 +320,7 @@ class Dataset(DatabaseModel): guidelines: Mapped[Optional[str]] = mapped_column(Text) allow_extra_metadata: Mapped[bool] = mapped_column(default=True, server_default=sql.true()) status: Mapped[DatasetStatus] = mapped_column(DatasetStatusEnum, default=DatasetStatus.draft, index=True) + distribution: Mapped[dict] = mapped_column(MutableDict.as_mutable(JSON)) workspace_id: Mapped[UUID] = mapped_column(ForeignKey("workspaces.id", ondelete="CASCADE"), index=True) inserted_at: Mapped[datetime] = mapped_column(default=datetime.utcnow) updated_at: Mapped[datetime] = mapped_column(default=inserted_at_current_value, onupdate=datetime.utcnow) @@ -353,6 +370,10 @@ def is_draft(self): def is_ready(self): return self.status == DatasetStatus.ready + @property + def distribution_strategy(self) -> DatasetDistributionStrategy: + return DatasetDistributionStrategy(self.distribution["strategy"]) + def metadata_property_by_name(self, name: str) -> Union["MetadataProperty", None]: for metadata_property in self.metadata_properties: if metadata_property.name == name: diff --git a/argilla-server/src/argilla_server/search_engine/base.py b/argilla-server/src/argilla_server/search_engine/base.py index 7c9146cafe..ee1dbcc386 100644 ---
a/argilla-server/src/argilla_server/search_engine/base.py +++ b/argilla-server/src/argilla_server/search_engine/base.py @@ -317,6 +317,10 @@ async def configure_metadata_property(self, dataset: Dataset, metadata_property: async def index_records(self, dataset: Dataset, records: Iterable[Record]): pass + @abstractmethod + async def partial_record_update(self, record: Record, **update): + pass + @abstractmethod async def delete_records(self, dataset: Dataset, records: Iterable[Record]): pass diff --git a/argilla-server/src/argilla_server/search_engine/commons.py b/argilla-server/src/argilla_server/search_engine/commons.py index 2030b59ae5..b328224f19 100644 --- a/argilla-server/src/argilla_server/search_engine/commons.py +++ b/argilla-server/src/argilla_server/search_engine/commons.py @@ -346,6 +346,10 @@ async def index_records(self, dataset: Dataset, records: Iterable[Record]): await self._bulk_op_request(bulk_actions) + async def partial_record_update(self, record: Record, **update): + index_name = await self._get_dataset_index(record.dataset) + await self._update_document_request(index_name=index_name, id=str(record.id), body={"doc": update}) + async def delete_records(self, dataset: Dataset, records: Iterable[Record]): index_name = await self._get_dataset_index(dataset) @@ -552,6 +556,7 @@ def _map_record_to_es_document(self, record: Record) -> Dict[str, Any]: document = { "id": str(record.id), "fields": record.fields, + "status": record.status, "inserted_at": record.inserted_at, "updated_at": record.updated_at, } @@ -712,6 +717,7 @@ def _configure_index_mappings(self, dataset: Dataset) -> dict: "properties": { # See https://www.elastic.co/guide/en/elasticsearch/reference/current/explicit-mapping.html "id": {"type": "keyword"}, + "status": {"type": "keyword"}, RecordSortField.inserted_at.value: {"type": "date_nanos"}, RecordSortField.updated_at.value: {"type": "date_nanos"}, "responses": {"dynamic": True, "type": "object"}, diff --git a/argilla-server/src/argilla_server/validators/datasets.py b/argilla-server/src/argilla_server/validators/datasets.py new file mode 100644 index 0000000000..aae2a5fc83 --- /dev/null +++ b/argilla-server/src/argilla_server/validators/datasets.py @@ -0,0 +1,48 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
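+ +# Validators for dataset create/update: the create-time checks (workspace exists, name unique per workspace) previously lived inline in `contexts/datasets.py`; the update-time check keeps `distribution` immutable once a dataset is published.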
+ +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from argilla_server.errors.future import NotUniqueError, UnprocessableEntityError +from argilla_server.models import Dataset, Workspace + + +class DatasetCreateValidator: + @classmethod + async def validate(cls, db: AsyncSession, dataset: Dataset) -> None: + await cls._validate_workspace_is_present(db, dataset.workspace_id) + await cls._validate_name_is_not_duplicated(db, dataset.name, dataset.workspace_id) + + @classmethod + async def _validate_workspace_is_present(cls, db: AsyncSession, workspace_id: UUID) -> None: + if await Workspace.get(db, workspace_id) is None: + raise UnprocessableEntityError(f"Workspace with id `{workspace_id}` not found") + + @classmethod + async def _validate_name_is_not_duplicated(cls, db: AsyncSession, name: str, workspace_id: UUID) -> None: + if await Dataset.get_by(db, name=name, workspace_id=workspace_id): + raise NotUniqueError(f"Dataset with name `{name}` already exists for workspace with id `{workspace_id}`") + + +class DatasetUpdateValidator: + @classmethod + async def validate(cls, db: AsyncSession, dataset: Dataset, dataset_attrs: dict) -> None: + cls._validate_distribution(dataset, dataset_attrs) + + @classmethod + def _validate_distribution(cls, dataset: Dataset, dataset_attrs: dict) -> None: + if dataset.is_ready and dataset_attrs.get("distribution") is not None: + raise UnprocessableEntityError(f"Distribution settings cannot be modified for a published dataset") diff --git a/argilla-server/tests/factories.py b/argilla-server/tests/factories.py index 5c77b9a0f5..c429fed9af 100644 --- a/argilla-server/tests/factories.py +++ b/argilla-server/tests/factories.py @@ -16,7 +16,7 @@ import random import factory -from argilla_server.enums import FieldType, MetadataPropertyType, OptionsOrder +from argilla_server.enums import DatasetDistributionStrategy, FieldType, MetadataPropertyType, OptionsOrder from argilla_server.models import ( Dataset, Field, @@ -203,6 +203,7 @@ class Meta: model = Dataset name = factory.Sequence(lambda n: f"dataset-{n}") + distribution = {"strategy": DatasetDistributionStrategy.overlap, "min_submitted": 1} workspace = factory.SubFactory(WorkspaceFactory) diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_dataset_records_bulk.py b/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_dataset_records_bulk.py index d7e95520d5..3d1f0bf6da 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_dataset_records_bulk.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_dataset_records_bulk.py @@ -15,7 +15,7 @@ from uuid import UUID import pytest -from argilla_server.enums import DatasetStatus +from argilla_server.enums import DatasetStatus, RecordStatus from argilla_server.models import Dataset, Record from httpx import AsyncClient from sqlalchemy import func, select @@ -87,6 +87,7 @@ async def test_create_dataset_records_bulk( "items": [ { "id": str(record.id), + "status": RecordStatus.pending, "dataset_id": str(dataset.id), "external_id": record.external_id, "fields": record.fields, diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_create_dataset.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_create_dataset.py new file mode 100644 index 0000000000..4261145d0c --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_create_dataset.py @@ -0,0 +1,139 @@ +# Copyright 2021-present, the 
Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus +from argilla_server.models import Dataset +from httpx import AsyncClient +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from tests.factories import WorkspaceFactory + + +@pytest.mark.asyncio +class TestCreateDataset: + def url(self) -> str: + return "/api/v1/datasets" + + async def test_create_dataset_with_default_distribution( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + workspace = await WorkspaceFactory.create() + + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "name": "Dataset Name", + "workspace_id": str(workspace.id), + }, + ) + + dataset = (await db.execute(select(Dataset))).scalar_one() + + assert response.status_code == 201 + assert response.json() == { + "id": str(dataset.id), + "name": "Dataset Name", + "guidelines": None, + "allow_extra_metadata": True, + "status": DatasetStatus.draft, + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, + "workspace_id": str(workspace.id), + "last_activity_at": dataset.last_activity_at.isoformat(), + "inserted_at": dataset.inserted_at.isoformat(), + "updated_at": dataset.updated_at.isoformat(), + } + + async def test_create_dataset_with_overlap_distribution( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + workspace = await WorkspaceFactory.create() + + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "name": "Dataset Name", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 4, + }, + "workspace_id": str(workspace.id), + }, + ) + + dataset = (await db.execute(select(Dataset))).scalar_one() + + assert response.status_code == 201 + assert response.json() == { + "id": str(dataset.id), + "name": "Dataset Name", + "guidelines": None, + "allow_extra_metadata": True, + "status": DatasetStatus.draft, + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 4, + }, + "workspace_id": str(workspace.id), + "last_activity_at": dataset.last_activity_at.isoformat(), + "inserted_at": dataset.inserted_at.isoformat(), + "updated_at": dataset.updated_at.isoformat(), + } + + async def test_create_dataset_with_overlap_distribution_using_invalid_min_submitted_value( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + workspace = await WorkspaceFactory.create() + + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "name": "Dataset name", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 0, + }, + "workspace_id": str(workspace.id), + }, + ) + + assert response.status_code == 422 + assert (await db.execute(select(func.count(Dataset.id)))).scalar_one() == 0 + + async def 
test_create_dataset_with_invalid_distribution_strategy( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + workspace = await WorkspaceFactory.create() + + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "name": "Dataset Name", + "distribution": { + "strategy": "invalid_strategy", + }, + "workspace_id": str(workspace.id), + }, + ) + + assert response.status_code == 422 + assert (await db.execute(select(func.count(Dataset.id)))).scalar_one() == 0 diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_current_user_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_current_user_dataset_records.py index e70072d814..8d4981e828 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_current_user_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_current_user_dataset_records.py @@ -16,7 +16,7 @@ import pytest from argilla_server.constants import API_KEY_HEADER_NAME -from argilla_server.enums import UserRole +from argilla_server.enums import UserRole, RecordStatus from argilla_server.search_engine import SearchEngine, SearchResponseItem, SearchResponses from httpx import AsyncClient @@ -71,6 +71,7 @@ async def test_search_with_filtered_metadata( { "record": { "id": str(record.id), + "status": RecordStatus.pending, "fields": record.fields, "metadata": record.metadata_, "external_id": record.external_id, @@ -122,6 +123,7 @@ async def test_search_with_filtered_metadata_as_annotator( { "record": { "id": str(record.id), + "status": RecordStatus.pending, "fields": record.fields, "metadata": {"annotator_meta": "value"}, "external_id": record.external_id, @@ -173,6 +175,7 @@ async def test_search_with_filtered_metadata_as_admin( { "record": { "id": str(record.id), + "status": RecordStatus.pending, "fields": record.fields, "metadata": {"admin_meta": "value", "annotator_meta": "value", "extra": "value"}, "external_id": record.external_id, diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py index 3d22527c3b..73077c4381 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py @@ -17,7 +17,7 @@ import pytest from argilla_server.api.handlers.v1.datasets.records import LIST_DATASET_RECORDS_LIMIT_LE from argilla_server.constants import API_KEY_HEADER_NAME -from argilla_server.enums import RecordInclude, SortOrder +from argilla_server.enums import RecordInclude, SortOrder, RecordStatus from argilla_server.search_engine import ( AndFilter, Order, @@ -118,6 +118,7 @@ async def test_with_include_responses( { "record": { "id": str(record_a.id), + "status": RecordStatus.pending, "fields": { "sentiment": "neutral", "text": "This is a text", @@ -153,6 +154,7 @@ async def test_with_include_responses( { "record": { "id": str(record_b.id), + "status": RecordStatus.pending, "fields": { "sentiment": "neutral", "text": "This is a text", diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py new file mode 100644 index 0000000000..cdb9b06ea2 --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py @@ -0,0 +1,178 @@ +# Copyright 2021-present, the Recognai 
S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from uuid import UUID + +import pytest +from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus +from httpx import AsyncClient + +from tests.factories import DatasetFactory + + +@pytest.mark.asyncio +class TestUpdateDataset: + def url(self, dataset_id: UUID) -> str: + return f"/api/v1/datasets/{dataset_id}" + + async def test_update_dataset_distribution(self, async_client: AsyncClient, owner_auth_header: dict): + dataset = await DatasetFactory.create() + + response = await async_client.patch( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 4, + }, + }, + ) + + assert response.status_code == 200 + assert response.json()["distribution"] == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 4, + } + + assert dataset.distribution == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 4, + } + + async def test_update_dataset_without_distribution(self, async_client: AsyncClient, owner_auth_header: dict): + dataset = await DatasetFactory.create() + + response = await async_client.patch( + self.url(dataset.id), + headers=owner_auth_header, + json={"name": "Dataset updated name"}, + ) + + assert response.status_code == 200 + assert response.json()["distribution"] == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + + assert dataset.name == "Dataset updated name" + assert dataset.distribution == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + + async def test_update_dataset_without_distribution_for_published_dataset( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create(status=DatasetStatus.ready) + + response = await async_client.patch( + self.url(dataset.id), + headers=owner_auth_header, + json={"name": "Dataset updated name"}, + ) + + assert response.status_code == 200 + assert response.json()["distribution"] == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + + assert dataset.name == "Dataset updated name" + assert dataset.distribution == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + + async def test_update_dataset_distribution_for_published_dataset( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create(status=DatasetStatus.ready) + + response = await async_client.patch( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 4, + }, + }, + ) + + assert response.status_code == 422 + assert response.json() == {"detail": "Distribution settings cannot be modified for a published dataset"} + + assert dataset.distribution == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + + async def 
test_update_dataset_distribution_with_invalid_strategy( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create() + + response = await async_client.patch( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "distribution": { + "strategy": "invalid_strategy", + }, + }, + ) + + assert response.status_code == 422 + assert dataset.distribution == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + + async def test_update_dataset_distribution_with_invalid_min_submitted_value( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create() + + response = await async_client.patch( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 0, + }, + }, + ) + + assert response.status_code == 422 + assert dataset.distribution == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + + async def test_update_dataset_distribution_as_none(self, async_client: AsyncClient, owner_auth_header: dict): + dataset = await DatasetFactory.create() + + response = await async_client.patch( + self.url(dataset.id), + headers=owner_auth_header, + json={"distribution": None}, + ) + + assert response.status_code == 422 + assert dataset.distribution == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } diff --git a/argilla-server/tests/unit/api/handlers/v1/records/test_create_dataset_records_bulk.py b/argilla-server/tests/unit/api/handlers/v1/records/test_create_dataset_records_bulk.py new file mode 100644 index 0000000000..1aae133535 --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/records/test_create_dataset_records_bulk.py @@ -0,0 +1,145 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
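+ +# Exercises the records bulk-create endpoint against the new record `status` column: with `min_submitted: 2`, only the record that receives two submitted responses comes back "completed"; draft-only and response-less records stay "pending".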
+ +import pytest + +from uuid import UUID +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession + +from argilla_server.models import User, Record +from argilla_server.enums import DatasetDistributionStrategy, RecordStatus, ResponseStatus, DatasetStatus + +from tests.factories import AnnotatorFactory, DatasetFactory, TextFieldFactory, TextQuestionFactory + + +@pytest.mark.asyncio +class TestCreateDatasetRecordsBulk: + def url(self, dataset_id: UUID) -> str: + return f"/api/v1/datasets/{dataset_id}/records/bulk" + + async def test_create_dataset_records_bulk_updates_records_status( + self, db: AsyncSession, async_client: AsyncClient, owner: User, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + status=DatasetStatus.ready, + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 2, + }, + ) + + user = await AnnotatorFactory.create(workspaces=[dataset.workspace]) + + await TextFieldFactory.create(name="prompt", dataset=dataset) + await TextFieldFactory.create(name="response", dataset=dataset) + + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + response = await async_client.post( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "items": [ + { + "fields": { + "prompt": "Does exercise help reduce stress?", + "response": "Exercise can definitely help reduce stress.", + }, + "responses": [ + { + "user_id": str(owner.id), + "status": ResponseStatus.submitted, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + { + "user_id": str(user.id), + "status": ResponseStatus.submitted, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + ], + }, + { + "fields": { + "prompt": "Does exercise help reduce stress?", + "response": "Exercise can definitely help reduce stress.", + }, + "responses": [ + { + "user_id": str(owner.id), + "status": ResponseStatus.submitted, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + ], + }, + { + "fields": { + "prompt": "Does exercise help reduce stress?", + "response": "Exercise can definitely help reduce stress.", + }, + "responses": [ + { + "user_id": str(owner.id), + "status": ResponseStatus.draft, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + { + "user_id": str(user.id), + "status": ResponseStatus.draft, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + ], + }, + { + "fields": { + "prompt": "Does exercise help reduce stress?", + "response": "Exercise can definitely help reduce stress.", + }, + }, + ], + }, + ) + + assert response.status_code == 201 + + response_items = response.json()["items"] + assert response_items[0]["status"] == RecordStatus.completed + assert response_items[1]["status"] == RecordStatus.pending + assert response_items[2]["status"] == RecordStatus.pending + assert response_items[3]["status"] == RecordStatus.pending + + assert (await Record.get(db, UUID(response_items[0]["id"]))).status == RecordStatus.completed + assert (await Record.get(db, UUID(response_items[1]["id"]))).status == RecordStatus.pending + assert (await Record.get(db, UUID(response_items[2]["id"]))).status == RecordStatus.pending + assert (await Record.get(db, UUID(response_items[3]["id"]))).status == RecordStatus.pending diff --git a/argilla-server/tests/unit/api/handlers/v1/records/test_create_record_response.py 
b/argilla-server/tests/unit/api/handlers/v1/records/test_create_record_response.py index 98b3a864b9..ce433d036d 100644 --- a/argilla-server/tests/unit/api/handlers/v1/records/test_create_record_response.py +++ b/argilla-server/tests/unit/api/handlers/v1/records/test_create_record_response.py @@ -16,13 +16,15 @@ from uuid import UUID import pytest -from argilla_server.enums import ResponseStatusFilter -from argilla_server.models import Response, User + from httpx import AsyncClient from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession -from tests.factories import DatasetFactory, RecordFactory, SpanQuestionFactory +from argilla_server.enums import ResponseStatus, RecordStatus, DatasetDistributionStrategy +from argilla_server.models import Response, User + +from tests.factories import DatasetFactory, RecordFactory, SpanQuestionFactory, TextQuestionFactory @pytest.mark.asyncio @@ -52,7 +54,7 @@ async def test_create_record_response_for_span_question( ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -72,7 +74,7 @@ async def test_create_record_response_for_span_question( ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, "record_id": str(record.id), "user_id": str(owner.id), "inserted_at": datetime.fromisoformat(response_json["inserted_at"]).isoformat(), @@ -101,7 +103,7 @@ async def test_create_record_response_for_span_question_with_additional_value_at ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -121,7 +123,7 @@ async def test_create_record_response_for_span_question_with_additional_value_at ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, "record_id": str(record.id), "user_id": str(owner.id), "inserted_at": datetime.fromisoformat(response_json["inserted_at"]).isoformat(), @@ -146,7 +148,7 @@ async def test_create_record_response_for_span_question_with_empty_value( "value": [], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -162,7 +164,7 @@ async def test_create_record_response_for_span_question_with_empty_value( "value": [], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, "record_id": str(record.id), "user_id": str(owner.id), "inserted_at": datetime.fromisoformat(response_json["inserted_at"]).isoformat(), @@ -189,7 +191,7 @@ async def test_create_record_response_for_span_question_with_record_not_providin ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -219,7 +221,7 @@ async def test_create_record_response_for_span_question_with_invalid_value( ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -244,7 +246,7 @@ async def test_create_record_response_for_span_question_with_start_greater_than_ "value": [{"label": "label-a", "start": 5, "end": 6}], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -273,7 +275,7 @@ async def test_create_record_response_for_span_question_with_end_greater_than_ex "value": [{"label": "label-a", "start": 4, "end": 6}], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -304,7 +306,7 @@ async def test_create_record_response_for_span_question_with_invalid_start( ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -331,7 +333,7 @@ async def 
test_create_record_response_for_span_question_with_invalid_end( ], }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -358,7 +360,7 @@ async def test_create_record_response_for_span_question_with_equal_start_and_end ], }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -385,7 +387,7 @@ async def test_create_record_response_for_span_question_with_end_smaller_than_st ], }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -412,7 +414,7 @@ async def test_create_record_response_for_span_question_with_non_existent_label( ], }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -446,7 +448,7 @@ async def test_create_record_response_for_span_question_with_overlapped_values( ], }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -454,3 +456,63 @@ assert response.json() == {"detail": "overlapping values found between spans at index idx=0 and idx=2"} assert (await db.execute(select(func.count(Response.id)))).scalar() == 0 + + async def test_create_record_response_updates_record_status_to_completed( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + ) + + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + record = await RecordFactory.create(fields={"field-a": "Hello"}, dataset=dataset) + + response = await async_client.post( + self.url(record.id), + headers=owner_auth_header, + json={ + "values": { + "text-question": { + "value": "text question response", + }, + }, + "status": ResponseStatus.submitted, + }, + ) + + assert response.status_code == 201 + assert record.status == RecordStatus.completed + + async def test_create_record_response_does_not_update_record_status_to_completed( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 2, + } + ) + + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + record = await RecordFactory.create(fields={"field-a": "Hello"}, dataset=dataset) + + response = await async_client.post( + self.url(record.id), + headers=owner_auth_header, + json={ + "values": { + "text-question": { + "value": "text question response", + }, + }, + "status": ResponseStatus.submitted, + }, + ) + + assert response.status_code == 201 + assert record.status == RecordStatus.pending diff --git a/argilla-server/tests/unit/api/handlers/v1/records/test_upsert_dataset_records_bulk.py b/argilla-server/tests/unit/api/handlers/v1/records/test_upsert_dataset_records_bulk.py new file mode 100644 index 0000000000..82b035a58a --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/records/test_upsert_dataset_records_bulk.py @@ -0,0 +1,153 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from uuid import UUID
+from httpx import AsyncClient
+
+from argilla_server.models import User
+from argilla_server.enums import DatasetDistributionStrategy, ResponseStatus, DatasetStatus, RecordStatus
+
+from tests.factories import DatasetFactory, RecordFactory, TextQuestionFactory, ResponseFactory, AnnotatorFactory
+
+
+@pytest.mark.asyncio
+class TestUpsertDatasetRecordsBulk:
+    def url(self, dataset_id: UUID) -> str:
+        return f"/api/v1/datasets/{dataset_id}/records/bulk"
+
+    async def test_upsert_dataset_records_bulk_updates_records_status(
+        self, async_client: AsyncClient, owner: User, owner_auth_header: dict
+    ):
+        dataset = await DatasetFactory.create(
+            status=DatasetStatus.ready,
+            distribution={
+                "strategy": DatasetDistributionStrategy.overlap,
+                "min_submitted": 2,
+            },
+        )
+
+        user = await AnnotatorFactory.create(workspaces=[dataset.workspace])
+
+        await TextQuestionFactory.create(name="text-question", dataset=dataset)
+
+        record_a = await RecordFactory.create(dataset=dataset)
+        assert record_a.status == RecordStatus.pending
+
+        await ResponseFactory.create(
+            user=owner,
+            record=record_a,
+            status=ResponseStatus.submitted,
+            values={
+                "text-question": {
+                    "value": "text question response",
+                },
+            },
+        )
+
+        record_b = await RecordFactory.create(dataset=dataset)
+        assert record_b.status == RecordStatus.pending
+
+        record_c = await RecordFactory.create(dataset=dataset)
+        assert record_c.status == RecordStatus.pending
+
+        record_d = await RecordFactory.create(dataset=dataset)
+        assert record_d.status == RecordStatus.pending
+
+        response = await async_client.put(
+            self.url(dataset.id),
+            headers=owner_auth_header,
+            json={
+                "items": [
+                    {
+                        "id": str(record_a.id),
+                        "responses": [
+                            {
+                                "user_id": str(user.id),
+                                "status": ResponseStatus.submitted,
+                                "values": {
+                                    "text-question": {
+                                        "value": "text question response",
+                                    },
+                                },
+                            },
+                        ],
+                    },
+                    {
+                        "id": str(record_b.id),
+                        "responses": [
+                            {
+                                "user_id": str(owner.id),
+                                "status": ResponseStatus.submitted,
+                                "values": {
+                                    "text-question": {
+                                        "value": "text question response",
+                                    },
+                                },
+                            },
+                            {
+                                "user_id": str(user.id),
+                                "status": ResponseStatus.draft,
+                                "values": {
+                                    "text-question": {
+                                        "value": "text question response",
+                                    },
+                                },
+                            },
+                        ],
+                    },
+                    {
+                        "id": str(record_c.id),
+                        "responses": [
+                            {
+                                "user_id": str(owner.id),
+                                "status": ResponseStatus.draft,
+                                "values": {
+                                    "text-question": {
+                                        "value": "text question response",
+                                    },
+                                },
+                            },
+                            {
+                                "user_id": str(user.id),
+                                "status": ResponseStatus.draft,
+                                "values": {
+                                    "text-question": {
+                                        "value": "text question response",
+                                    },
+                                },
+                            },
+                        ],
+                    },
+                    {
+                        "id": str(record_d.id),
+                        "responses": [],
+                    },
+                ],
+            },
+        )
+
+        assert response.status_code == 200
+
+        response_items = response.json()["items"]
+        assert response_items[0]["status"] == RecordStatus.completed
+        assert response_items[1]["status"] == RecordStatus.pending
+        assert response_items[2]["status"] == RecordStatus.pending
+        assert response_items[3]["status"] == RecordStatus.pending
+
+        assert record_a.status == RecordStatus.completed
+        assert record_b.status == RecordStatus.pending
+        assert record_c.status == RecordStatus.pending
+        assert record_d.status == RecordStatus.pending
diff --git a/argilla-server/tests/unit/api/handlers/v1/responses/test_create_current_user_responses_bulk.py b/argilla-server/tests/unit/api/handlers/v1/responses/test_create_current_user_responses_bulk.py
index 009cec7d2e..07b4bf0199 100644
--- a/argilla-server/tests/unit/api/handlers/v1/responses/test_create_current_user_responses_bulk.py
+++ b/argilla-server/tests/unit/api/handlers/v1/responses/test_create_current_user_responses_bulk.py
@@ -18,7 +18,7 @@
 import pytest
 
 from argilla_server.constants import API_KEY_HEADER_NAME
-from argilla_server.enums import ResponseStatus
+from argilla_server.enums import ResponseStatus, RecordStatus
 from argilla_server.models import Response, User
 from argilla_server.search_engine import SearchEngine
 from argilla_server.use_cases.responses.upsert_responses_in_bulk import UpsertResponsesInBulkUseCase
@@ -111,7 +111,7 @@ async def test_multiple_responses(
                 "item": {
                     "id": str(response_to_create_id),
                     "values": {"prompt-quality": {"value": 5}},
-                    "status": ResponseStatus.submitted.value,
+                    "status": ResponseStatus.submitted,
                     "record_id": str(records[0].id),
                     "user_id": str(annotator.id),
                     "inserted_at": datetime.fromisoformat(resp_json["items"][0]["item"]["inserted_at"]).isoformat(),
@@ -123,7 +123,7 @@ async def test_multiple_responses(
                 "item": {
                     "id": str(response_to_update.id),
                     "values": {"prompt-quality": {"value": 10}},
-                    "status": ResponseStatus.submitted.value,
+                    "status": ResponseStatus.submitted,
                     "record_id": str(records[1].id),
                     "user_id": str(annotator.id),
                     "inserted_at": datetime.fromisoformat(resp_json["items"][1]["item"]["inserted_at"]).isoformat(),
@@ -146,6 +146,10 @@ async def test_multiple_responses(
             ],
         }
 
+        assert records[0].status == RecordStatus.completed
+        assert records[1].status == RecordStatus.completed
+        assert records[2].status == RecordStatus.pending
+
         assert (await db.execute(select(func.count(Response.id)))).scalar() == 2
 
         response_to_create = (await db.execute(select(Response).filter_by(id=response_to_create_id))).scalar_one()
diff --git a/argilla-server/tests/unit/api/handlers/v1/responses/test_delete_response.py b/argilla-server/tests/unit/api/handlers/v1/responses/test_delete_response.py
new file mode 100644
index 0000000000..6b9d4ec749
--- /dev/null
+++ b/argilla-server/tests/unit/api/handlers/v1/responses/test_delete_response.py
@@ -0,0 +1,66 @@
+# Copyright 2021-present, the Recognai S.L. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from uuid import UUID
+
+import pytest
+
+from httpx import AsyncClient
+
+from argilla_server.models import User
+from argilla_server.enums import DatasetDistributionStrategy, RecordStatus, ResponseStatus
+
+from tests.factories import DatasetFactory, RecordFactory, ResponseFactory, TextQuestionFactory
+
+
+@pytest.mark.asyncio
+class TestDeleteResponse:
+    def url(self, response_id: UUID) -> str:
+        return f"/api/v1/responses/{response_id}"
+
+    async def test_delete_response_updates_record_status_to_pending(
+        self, async_client: AsyncClient, owner_auth_header: dict
+    ):
+        dataset = await DatasetFactory.create(
+            distribution={
+                "strategy": DatasetDistributionStrategy.overlap,
+                "min_submitted": 1,
+            }
+        )
+
+        record = await RecordFactory.create(status=RecordStatus.completed, dataset=dataset)
+        response = await ResponseFactory.create(record=record)
+
+        resp = await async_client.delete(self.url(response.id), headers=owner_auth_header)
+
+        assert resp.status_code == 200
+        assert record.status == RecordStatus.pending
+
+    async def test_delete_response_does_not_update_record_status_to_pending(
+        self, async_client: AsyncClient, owner_auth_header: dict
+    ):
+        dataset = await DatasetFactory.create(
+            distribution={
+                "strategy": DatasetDistributionStrategy.overlap,
+                "min_submitted": 2,
+            }
+        )
+
+        record = await RecordFactory.create(status=RecordStatus.completed, dataset=dataset)
+        responses = await ResponseFactory.create_batch(3, record=record)
+
+        resp = await async_client.delete(self.url(responses[0].id), headers=owner_auth_header)
+
+        assert resp.status_code == 200
+        assert record.status == RecordStatus.completed
diff --git a/argilla-server/tests/unit/api/handlers/v1/responses/test_update_response.py b/argilla-server/tests/unit/api/handlers/v1/responses/test_update_response.py
index f5ffab7b31..d5097f8c7b 100644
--- a/argilla-server/tests/unit/api/handlers/v1/responses/test_update_response.py
+++ b/argilla-server/tests/unit/api/handlers/v1/responses/test_update_response.py
@@ -16,13 +16,15 @@
 from uuid import UUID
 
 import pytest
-from argilla_server.enums import ResponseStatus
-from argilla_server.models import Response, User
 from httpx import AsyncClient
+
 from sqlalchemy import select
 from sqlalchemy.ext.asyncio.session import AsyncSession
 
-from tests.factories import DatasetFactory, RecordFactory, ResponseFactory, SpanQuestionFactory
+from argilla_server.enums import ResponseStatus, DatasetDistributionStrategy, RecordStatus
+from argilla_server.models import Response, User
+
+from tests.factories import DatasetFactory, RecordFactory, ResponseFactory, SpanQuestionFactory, TextQuestionFactory
 
 
 @pytest.mark.asyncio
@@ -560,3 +562,66 @@ async def test_update_response_for_span_question_with_non_existent_label(
         }
 
         assert (await db.execute(select(Response).filter_by(id=response.id))).scalar_one().values == response_values
+
+    async def test_update_response_updates_record_status_to_completed(
+        self, async_client: AsyncClient, owner_auth_header: dict
+    ):
+        dataset = await DatasetFactory.create(
+            distribution={
+                "strategy": DatasetDistributionStrategy.overlap,
+                "min_submitted": 1,
+            },
+        )
+
+        await TextQuestionFactory.create(name="text-question", dataset=dataset)
+
+        record = await RecordFactory.create(fields={"field-a": "Hello"}, dataset=dataset)
+        response = await ResponseFactory.create(record=record, status=ResponseStatus.draft)
+
+        resp = await async_client.put(
+            self.url(response.id),
+            headers=owner_auth_header,
+            json={
+                "values": {
+                    "text-question": {
+                        "value": "text question updated response",
updated response", + }, + }, + "status": ResponseStatus.submitted, + }, + ) + + assert resp.status_code == 200 + assert record.status == RecordStatus.completed + + async def test_update_response_updates_record_status_to_pending( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, + ) + + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + record = await RecordFactory.create(fields={"field-a": "Hello"}, dataset=dataset, status=RecordStatus.completed) + response = await ResponseFactory.create( + values={ + "text-question": { + "value": "text question response", + }, + }, + record=record, + status=ResponseStatus.submitted, + ) + + resp = await async_client.put( + self.url(response.id), + headers=owner_auth_header, + json={"status": ResponseStatus.draft}, + ) + + assert resp.status_code == 200 + assert record.status == RecordStatus.pending diff --git a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py index 650e9f3808..e0c9fe4d5e 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py @@ -34,11 +34,13 @@ ) from argilla_server.constants import API_KEY_HEADER_NAME from argilla_server.enums import ( + DatasetDistributionStrategy, DatasetStatus, OptionsOrder, RecordInclude, ResponseStatusFilter, SimilarityOrder, + RecordStatus, ) from argilla_server.models import ( Dataset, @@ -116,6 +118,10 @@ async def test_list_current_user_datasets(self, async_client: "AsyncClient", own "guidelines": None, "allow_extra_metadata": True, "status": "draft", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, "workspace_id": str(dataset_a.workspace_id), "last_activity_at": dataset_a.last_activity_at.isoformat(), "inserted_at": dataset_a.inserted_at.isoformat(), @@ -127,6 +133,10 @@ async def test_list_current_user_datasets(self, async_client: "AsyncClient", own "guidelines": "guidelines", "allow_extra_metadata": True, "status": "draft", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, "workspace_id": str(dataset_b.workspace_id), "last_activity_at": dataset_b.last_activity_at.isoformat(), "inserted_at": dataset_b.inserted_at.isoformat(), @@ -138,6 +148,10 @@ async def test_list_current_user_datasets(self, async_client: "AsyncClient", own "guidelines": None, "allow_extra_metadata": True, "status": "ready", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, "workspace_id": str(dataset_c.workspace_id), "last_activity_at": dataset_c.last_activity_at.isoformat(), "inserted_at": dataset_c.inserted_at.isoformat(), @@ -653,8 +667,6 @@ async def test_list_dataset_vectors_settings_without_authentication(self, async_ assert response.status_code == 401 - # Helper function to create records with responses - async def test_get_dataset(self, async_client: "AsyncClient", owner_auth_header: dict): dataset = await DatasetFactory.create(name="dataset") @@ -667,6 +679,10 @@ async def test_get_dataset(self, async_client: "AsyncClient", owner_auth_header: "guidelines": None, "allow_extra_metadata": True, "status": "draft", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, "workspace_id": str(dataset.workspace_id), "last_activity_at": 
             "inserted_at": dataset.inserted_at.isoformat(),
@@ -839,13 +855,16 @@ async def test_create_dataset(self, async_client: "AsyncClient", db: "AsyncSessi
         await db.refresh(workspace)
 
         response_body = response.json()
-        assert (await db.execute(select(func.count(Dataset.id)))).scalar() == 1
         assert response_body == {
             "id": str(UUID(response_body["id"])),
             "name": "name",
             "guidelines": "guidelines",
             "allow_extra_metadata": False,
             "status": "draft",
+            "distribution": {
+                "strategy": DatasetDistributionStrategy.overlap,
+                "min_submitted": 1,
+            },
             "workspace_id": str(workspace.id),
             "last_activity_at": datetime.fromisoformat(response_body["last_activity_at"]).isoformat(),
             "inserted_at": datetime.fromisoformat(response_body["inserted_at"]).isoformat(),
@@ -3644,6 +3663,7 @@ async def test_search_current_user_dataset_records(
             {
                 "record": {
                     "id": str(records[0].id),
+                    "status": RecordStatus.pending,
                     "fields": {"input": "input_a", "output": "output_a"},
                     "metadata": None,
                     "external_id": records[0].external_id,
@@ -3656,6 +3676,7 @@ async def test_search_current_user_dataset_records(
             {
                 "record": {
                     "id": str(records[1].id),
+                    "status": RecordStatus.pending,
                     "fields": {"input": "input_b", "output": "output_b"},
                     "metadata": {"unit": "test"},
                     "external_id": records[1].external_id,
@@ -3997,6 +4018,7 @@ async def test_search_current_user_dataset_records_with_include(
             {
                 "record": {
                     "id": str(records[0].id),
+                    "status": RecordStatus.pending,
                     "fields": {
                         "input": "input_a",
                         "output": "output_a",
@@ -4012,6 +4034,7 @@ async def test_search_current_user_dataset_records_with_include(
             {
                 "record": {
                     "id": str(records[1].id),
+                    "status": RecordStatus.pending,
                     "fields": {
                         "input": "input_b",
                         "output": "output_b",
@@ -4151,6 +4174,7 @@ async def test_search_current_user_dataset_records_with_include_vectors(
             {
                 "record": {
                     "id": str(record_a.id),
+                    "status": RecordStatus.pending,
                     "fields": {"text": "This is a text", "sentiment": "neutral"},
                     "metadata": None,
                     "external_id": record_a.external_id,
@@ -4167,6 +4191,7 @@ async def test_search_current_user_dataset_records_with_include_vectors(
             {
                 "record": {
                     "id": str(record_b.id),
+                    "status": RecordStatus.pending,
                     "fields": {"text": "This is a text", "sentiment": "neutral"},
                     "metadata": None,
                     "external_id": record_b.external_id,
@@ -4182,6 +4207,7 @@ async def test_search_current_user_dataset_records_with_include_vectors(
             {
                 "record": {
                     "id": str(record_c.id),
+                    "status": RecordStatus.pending,
                     "fields": {"text": "This is a text", "sentiment": "neutral"},
                     "metadata": None,
                     "external_id": record_c.external_id,
@@ -4245,6 +4271,7 @@ async def test_search_current_user_dataset_records_with_include_specific_vectors
             {
                 "record": {
                     "id": str(record_a.id),
+                    "status": RecordStatus.pending,
                     "fields": {"text": "This is a text", "sentiment": "neutral"},
                     "metadata": None,
                     "external_id": record_a.external_id,
@@ -4261,6 +4288,7 @@ async def test_search_current_user_dataset_records_with_include_specific_vectors
             {
                 "record": {
                     "id": str(record_b.id),
+                    "status": RecordStatus.pending,
                     "fields": {"text": "This is a text", "sentiment": "neutral"},
                     "metadata": None,
                     "external_id": record_b.external_id,
@@ -4276,6 +4304,7 @@ async def test_search_current_user_dataset_records_with_include_specific_vectors
             {
                 "record": {
                     "id": str(record_c.id),
+                    "status": RecordStatus.pending,
                     "fields": {"text": "This is a text", "sentiment": "neutral"},
                     "metadata": None,
                     "external_id": record_c.external_id,
@@ -4752,6 +4781,10 @@ async def test_update_dataset(self, async_client: "AsyncClient", db: "AsyncSessi
"AsyncClient", db: "AsyncSessi "guidelines": guidelines, "allow_extra_metadata": allow_extra_metadata, "status": "ready", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, "workspace_id": str(dataset.workspace_id), "last_activity_at": dataset.last_activity_at.isoformat(), "inserted_at": dataset.inserted_at.isoformat(), diff --git a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py index f088cfcda9..8f78940df3 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py @@ -18,7 +18,7 @@ import pytest from argilla_server.api.handlers.v1.datasets.records import LIST_DATASET_RECORDS_LIMIT_DEFAULT from argilla_server.constants import API_KEY_HEADER_NAME -from argilla_server.enums import RecordInclude, RecordSortField, ResponseStatus, UserRole +from argilla_server.enums import RecordInclude, RecordSortField, ResponseStatus, UserRole, RecordStatus from argilla_server.models import Dataset, Question, Record, Response, Suggestion, User, Workspace from argilla_server.search_engine import ( FloatMetadataFilter, @@ -821,6 +821,7 @@ async def test_list_current_user_dataset_records( "items": [ { "id": str(record_a.id), + "status": RecordStatus.pending, "fields": {"input": "input_a", "output": "output_a"}, "metadata": None, "dataset_id": str(dataset.id), @@ -830,6 +831,7 @@ async def test_list_current_user_dataset_records( }, { "id": str(record_b.id), + "status": RecordStatus.pending, "fields": {"input": "input_b", "output": "output_b"}, "metadata": {"unit": "test"}, "dataset_id": str(dataset.id), @@ -839,6 +841,7 @@ async def test_list_current_user_dataset_records( }, { "id": str(record_c.id), + "status": RecordStatus.pending, "fields": {"input": "input_c", "output": "output_c"}, "metadata": None, "dataset_id": str(dataset.id), @@ -898,6 +901,7 @@ async def test_list_current_user_dataset_records_with_filtered_metadata_as_annot "items": [ { "id": str(record.id), + "status": RecordStatus.pending, "fields": {"input": "input_b", "output": "output_b"}, "metadata": {"key1": "value1"}, "dataset_id": str(dataset.id), diff --git a/argilla-server/tests/unit/api/handlers/v1/test_records.py b/argilla-server/tests/unit/api/handlers/v1/test_records.py index ed7d9f8cc2..3c361b1666 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_records.py @@ -19,7 +19,7 @@ import pytest from argilla_server.constants import API_KEY_HEADER_NAME -from argilla_server.enums import ResponseStatus +from argilla_server.enums import RecordStatus, ResponseStatus from argilla_server.models import Dataset, Record, Response, Suggestion, User, UserRole from argilla_server.search_engine import SearchEngine from sqlalchemy import func, select @@ -92,6 +92,7 @@ async def test_get_record(self, async_client: "AsyncClient", role: UserRole): assert response.status_code == 200 assert response.json() == { "id": str(record.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record.external_id, @@ -188,6 +189,7 @@ async def test_update_record(self, async_client: "AsyncClient", mock_search_engi assert response.status_code == 200 assert response.json() == { "id": str(record.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": 
"neutral"}, "metadata": { "terms-metadata-property": "c", @@ -228,6 +230,7 @@ async def test_update_record(self, async_client: "AsyncClient", mock_search_engi "inserted_at": record.inserted_at.isoformat(), "updated_at": record.updated_at.isoformat(), } + mock_search_engine.index_records.assert_called_once_with(dataset, [record]) async def test_update_record_with_null_metadata( @@ -251,6 +254,7 @@ async def test_update_record_with_null_metadata( assert response.status_code == 200 assert response.json() == { "id": str(record.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record.external_id, @@ -278,6 +282,7 @@ async def test_update_record_with_no_metadata( assert response.status_code == 200 assert response.json() == { "id": str(record.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record.external_id, @@ -310,6 +315,7 @@ async def test_update_record_with_list_terms_metadata( assert response.status_code == 200 assert response.json() == { "id": str(record.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": { "terms-metadata-property": ["a", "b", "c"], @@ -339,6 +345,7 @@ async def test_update_record_with_no_suggestions( assert response.status_code == 200 assert response.json() == { "id": str(record.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record.external_id, @@ -1413,6 +1420,7 @@ async def test_delete_record( assert response.status_code == 200 assert response.json() == { "id": str(record.id), + "status": RecordStatus.pending, "fields": record.fields, "metadata": None, "external_id": record.external_id, diff --git a/argilla-server/tests/unit/search_engine/test_commons.py b/argilla-server/tests/unit/search_engine/test_commons.py index ecba3232a6..c4376ca686 100644 --- a/argilla-server/tests/unit/search_engine/test_commons.py +++ b/argilla-server/tests/unit/search_engine/test_commons.py @@ -16,7 +16,7 @@ import pytest import pytest_asyncio -from argilla_server.enums import MetadataPropertyType, QuestionType, ResponseStatusFilter, SimilarityOrder +from argilla_server.enums import MetadataPropertyType, QuestionType, ResponseStatusFilter, SimilarityOrder, RecordStatus from argilla_server.models import Dataset, Question, Record, User, VectorSettings from argilla_server.search_engine import ( FloatMetadataFilter, @@ -263,6 +263,7 @@ async def refresh_records(records: List[Record]): for record in records: await record.awaitable_attrs.suggestions await record.awaitable_attrs.responses + await record.awaitable_attrs.responses_submitted await record.awaitable_attrs.vectors @@ -314,6 +315,7 @@ async def test_create_index_for_dataset( ], "properties": { "id": {"type": "keyword"}, + "status": {"type": "keyword"}, "inserted_at": {"type": "date_nanos"}, "updated_at": {"type": "date_nanos"}, ALL_RESPONSES_STATUSES_FIELD: {"type": "keyword"}, @@ -356,6 +358,7 @@ async def test_create_index_for_dataset_with_fields( ], "properties": { "id": {"type": "keyword"}, + "status": {"type": "keyword"}, "inserted_at": {"type": "date_nanos"}, "updated_at": {"type": "date_nanos"}, ALL_RESPONSES_STATUSES_FIELD: {"type": "keyword"}, @@ -428,6 +431,7 @@ async def test_create_index_for_dataset_with_metadata_properties( ], "properties": { "id": {"type": "keyword"}, + "status": {"type": "keyword"}, 
"inserted_at": {"type": "date_nanos"}, "updated_at": {"type": "date_nanos"}, ALL_RESPONSES_STATUSES_FIELD: {"type": "keyword"}, @@ -475,6 +479,7 @@ async def test_create_index_for_dataset_with_questions( "dynamic": "strict", "properties": { "id": {"type": "keyword"}, + "status": {"type": "keyword"}, "inserted_at": {"type": "date_nanos"}, "updated_at": {"type": "date_nanos"}, ALL_RESPONSES_STATUSES_FIELD: {"type": "keyword"}, @@ -879,6 +884,7 @@ async def test_index_records(self, search_engine: BaseElasticAndOpenSearchEngine assert es_docs == [ { "id": str(record.id), + "status": RecordStatus.pending, "fields": record.fields, "inserted_at": record.inserted_at.isoformat(), "updated_at": record.updated_at.isoformat(), @@ -937,6 +943,7 @@ async def test_index_records_with_suggestions( assert es_docs == [ { "id": str(records[0].id), + "status": RecordStatus.pending, "fields": records[0].fields, "inserted_at": records[0].inserted_at.isoformat(), "updated_at": records[0].updated_at.isoformat(), @@ -944,6 +951,7 @@ async def test_index_records_with_suggestions( }, { "id": str(records[1].id), + "status": RecordStatus.pending, "fields": records[1].fields, "inserted_at": records[1].inserted_at.isoformat(), "updated_at": records[1].updated_at.isoformat(), @@ -978,6 +986,7 @@ async def test_index_records_with_metadata( assert es_docs == [ { "id": str(record.id), + "status": RecordStatus.pending, "fields": record.fields, "inserted_at": record.inserted_at.isoformat(), "updated_at": record.updated_at.isoformat(), @@ -1017,6 +1026,7 @@ async def test_index_records_with_vectors( assert es_docs == [ { "id": str(record.id), + "status": RecordStatus.pending, "fields": record.fields, "inserted_at": record.inserted_at.isoformat(), "updated_at": record.updated_at.isoformat(), diff --git a/argilla/src/argilla/_models/_search.py b/argilla/src/argilla/_models/_search.py index f62dbff0b7..3c256805a0 100644 --- a/argilla/src/argilla/_models/_search.py +++ b/argilla/src/argilla/_models/_search.py @@ -17,6 +17,11 @@ from pydantic import BaseModel, Field +class RecordFilterScopeModel(BaseModel): + entity: Literal["record"] = "record" + property: Literal["status"] = "status" + + class ResponseFilterScopeModel(BaseModel): """Filter scope for filtering on a response entity.""" @@ -42,6 +47,7 @@ class MetadataFilterScopeModel(BaseModel): ScopeModel = Annotated[ Union[ + RecordFilterScopeModel, ResponseFilterScopeModel, SuggestionFilterScopeModel, MetadataFilterScopeModel, diff --git a/argilla/src/argilla/records/_search.py b/argilla/src/argilla/records/_search.py index adc56b5750..6ccdcee33a 100644 --- a/argilla/src/argilla/records/_search.py +++ b/argilla/src/argilla/records/_search.py @@ -26,6 +26,7 @@ FilterModel, AndFilterModel, QueryModel, + RecordFilterScopeModel, ) @@ -54,8 +55,9 @@ def model(self) -> FilterModel: @staticmethod def _extract_filter_scope(field: str) -> ScopeModel: field = field.strip() - if field == "status": + return RecordFilterScopeModel(property="status") + elif field == "responses.status": return ResponseFilterScopeModel(property="status") elif "metadata" in field: _, md_property = field.split(".")