diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index 811497a370..3b7a833ff9 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -39,6 +39,13 @@ These are the section headers that we use: - [breaking] Remove deprecated endpoint `POST /api/v1/datasets/:dataset_id/records`. ([#5206](https://github.com/argilla-io/argilla/pull/5206)) - [breaking] Remove deprecated endpoint `PATCH /api/v1/datasets/:dataset_id/records`. ([#5206](https://github.com/argilla-io/argilla/pull/5206)) +### Removed + +- Removed `GET /api/v1/me/datasets/:dataset_id/records` endpoint. ([#5153](https://github.com/argilla-io/argilla/pull/5153)) +- [breaking] Removed support for `response_status` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5163](https://github.com/argilla-io/argilla/pull/5163)) +- [breaking] Removed support for `metadata` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5156](https://github.com/argilla-io/argilla/pull/5156)) +- [breaking] Removed support for `sort_by` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5166](https://github.com/argilla-io/argilla/pull/5166)) + ## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1) ### Changed diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py index 8cc5ee2538..0fca256da4 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py @@ -19,7 +19,6 @@ from fastapi import APIRouter, Depends, Query, Security, status from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from typing_extensions import Annotated import argilla_server.search_engine as search_engine from argilla_server.api.policies.v1 import DatasetPolicy, RecordPolicy, authorize, is_authorized @@ -27,8 +26,6 @@ Filters, FilterScope, MetadataFilterScope, - MetadataParsedQueryParam, - MetadataQueryParams, Order, RangeFilter, RecordFilterScope, @@ -49,19 +46,14 @@ ) from argilla_server.contexts import datasets, search from argilla_server.database import get_async_db -from argilla_server.enums import MetadataPropertyType, RecordSortField, ResponseStatusFilter, SortOrder +from argilla_server.enums import RecordSortField, ResponseStatusFilter, SortOrder from argilla_server.errors.future import MissingVectorError, NotFoundError, UnprocessableEntityError from argilla_server.errors.future.base_errors import MISSING_VECTOR_ERROR_CODE -from argilla_server.models import Dataset, Field, MetadataProperty, Record, User, VectorSettings +from argilla_server.models import Dataset, Field, Record, User, VectorSettings from argilla_server.search_engine import ( AndFilter, - FloatMetadataFilter, - IntegerMetadataFilter, - MetadataFilter, SearchEngine, SearchResponses, - SortBy, - TermsMetadataFilter, UserResponseStatusFilter, get_search_engine, ) @@ -74,25 +66,6 @@ LIST_DATASET_RECORDS_DEFAULT_SORT_BY = {RecordSortField.inserted_at.value: "asc"} DELETE_DATASET_RECORDS_LIMIT = 100 -_RECORD_SORT_FIELD_VALUES = tuple(field.value for field in RecordSortField) -_VALID_SORT_VALUES = tuple(sort.value for sort in SortOrder) 
-_METADATA_PROPERTY_SORT_BY_REGEX = re.compile(r"^metadata\.(?P<name>(?=.*[a-z0-9])[a-z0-9_-]+)$") - -SortByQueryParamParsed = Annotated[ - Dict[str, str], - Depends( - parse_query_param( - name="sort_by", - description=( - "The field used to sort the records. Expected format is `field` or `field:{asc,desc}`, where `field`" - " can be 'inserted_at', 'updated_at' or the name of a metadata property" - ), - max_values_per_key=1, - group_keys_without_values=False, - ) - ), -] - parse_record_include_param = parse_query_param( name="include", help="Relationships to include in the response", model=RecordIncludeParam ) @@ -104,13 +77,10 @@ async def _filter_records_using_search_engine( db: "AsyncSession", search_engine: "SearchEngine", dataset: Dataset, - parsed_metadata: List[MetadataParsedQueryParam], limit: int, offset: int, user: Optional[User] = None, - response_statuses: Optional[List[ResponseStatusFilter]] = None, include: Optional[RecordIncludeParam] = None, - sort_by_query_param: Optional[Dict[str, str]] = None, ) -> Tuple[List[Record], int]: search_responses = await _get_search_responses( db=db, @@ -119,9 +89,6 @@ async def _filter_records_using_search_engine( limit=limit, offset=offset, user=user, - parsed_metadata=parsed_metadata, - response_statuses=response_statuses, - sort_by_query_param=sort_by_query_param, ) record_ids = [response.record_id for response in search_responses.items] @@ -180,13 +147,10 @@ async def _get_search_responses( db: "AsyncSession", search_engine: "SearchEngine", dataset: Dataset, - parsed_metadata: List[MetadataParsedQueryParam], limit: int, offset: int, search_records_query: Optional[SearchRecordsQuery] = None, user: Optional[User] = None, - response_statuses: Optional[List[ResponseStatusFilter]] = None, - sort_by_query_param: Optional[Dict[str, str]] = None, ) -> "SearchResponses": search_records_query = search_records_query or SearchRecordsQuery() @@ -226,10 +190,6 @@ async def _get_search_responses( if text_query and text_query.field and not await Field.get_by(db, name=text_query.field, dataset_id=dataset.id): raise UnprocessableEntityError(f"Field `{text_query.field}` not found in dataset `{dataset.id}`.") - metadata_filters = await _build_metadata_filters(db, dataset, parsed_metadata) - response_status_filter = await _build_response_status_filter_for_search(response_statuses, user=user) - sort_by = await _build_sort_by(db, dataset, sort_by_query_param) - if vector_query and vector_settings: similarity_search_params = { "dataset": dataset, @@ -238,8 +198,6 @@ async def _get_search_responses( "record": record, "query": text_query, "order": vector_query.order, - "metadata_filters": metadata_filters, - "user_response_status_filter": response_status_filter, "max_results": limit, } @@ -251,11 +209,8 @@ async def _get_search_responses( search_params = { "dataset": dataset, "query": text_query, - "metadata_filters": metadata_filters, - "user_response_status_filter": response_status_filter, "offset": offset, "limit": limit, - "sort_by": sort_by, } if user is not None: @@ -269,32 +224,6 @@ async def _get_search_responses( return await search_engine.search(**search_params) -async def _build_metadata_filters( - db: "AsyncSession", dataset: Dataset, parsed_metadata: List[MetadataParsedQueryParam] -) -> List["MetadataFilter"]: - try: - metadata_filters = [] - for metadata_param in parsed_metadata: - metadata_property = await MetadataProperty.get_by(db, name=metadata_param.name, dataset_id=dataset.id) - if metadata_property is None: - continue # won't fail on unknown
metadata filter name - - if metadata_property.type == MetadataPropertyType.terms: - metadata_filter_class = TermsMetadataFilter - elif metadata_property.type == MetadataPropertyType.integer: - metadata_filter_class = IntegerMetadataFilter - elif metadata_property.type == MetadataPropertyType.float: - metadata_filter_class = FloatMetadataFilter - else: - raise ValueError(f"Not found filter for type {metadata_property.type}") - - metadata_filters.append(metadata_filter_class.from_string(metadata_property, metadata_param.value)) - except (UnprocessableEntityError, ValueError) as ex: - raise UnprocessableEntityError(f"Cannot parse provided metadata filters: {ex}") - - return metadata_filters - - async def _build_response_status_filter_for_search( response_statuses: Optional[List[ResponseStatusFilter]] = None, user: Optional[User] = None ) -> Optional[UserResponseStatusFilter]: @@ -307,43 +236,6 @@ async def _build_response_status_filter_for_search( return user_response_status_filter -async def _build_sort_by( - db: "AsyncSession", dataset: Dataset, sort_by_query_param: Optional[Dict[str, str]] = None -) -> Union[List[SortBy], None]: - if sort_by_query_param is None: - return None - - sorts_by = [] - for sort_field, sort_order in sort_by_query_param.items(): - if sort_field in _RECORD_SORT_FIELD_VALUES: - field = sort_field - elif (match := _METADATA_PROPERTY_SORT_BY_REGEX.match(sort_field)) is not None: - metadata_property_name = match.group("name") - metadata_property = await MetadataProperty.get_by(db, name=metadata_property_name, dataset_id=dataset.id) - if not metadata_property: - raise UnprocessableEntityError( - f"Provided metadata property in 'sort_by' query param '{metadata_property_name}' not found in " - f"dataset with '{dataset.id}'." - ) - - field = metadata_property - else: - valid_sort_fields = ", ".join(f"'{sort_field}'" for sort_field in _RECORD_SORT_FIELD_VALUES) - raise UnprocessableEntityError( - f"Provided sort field in 'sort_by' query param '{sort_field}' is not valid. 
It must be either" - f" {valid_sort_fields} or `metadata.metadata-property-name`" - ) - - if sort_order is not None and sort_order not in _VALID_SORT_VALUES: - raise UnprocessableEntityError( - f"Provided sort order in 'sort_by' query param '{sort_order}' for field '{sort_field}' is not valid.", - ) - - sorts_by.append(SortBy(field=field, order=sort_order or SortOrder.asc.value)) - - return sorts_by - - async def _validate_search_records_query(db: "AsyncSession", query: SearchRecordsQuery, dataset_id: UUID): try: await search.validate_search_records_query(db, query, dataset_id) @@ -351,54 +243,13 @@ async def _validate_search_records_query(db: "AsyncSession", query: SearchRecord raise UnprocessableEntityError(str(e)) -@router.get("/me/datasets/{dataset_id}/records", response_model=Records, response_model_exclude_unset=True) -async def list_current_user_dataset_records( - *, - db: AsyncSession = Depends(get_async_db), - search_engine: SearchEngine = Depends(get_search_engine), - dataset_id: UUID, - metadata: MetadataQueryParams = Depends(), - sort_by_query_param: SortByQueryParamParsed, - include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), - response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"), - offset: int = 0, - limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), - current_user: User = Security(auth.get_current_user), -): - dataset = await Dataset.get_or_raise(db, dataset_id, options=[selectinload(Dataset.metadata_properties)]) - - await authorize(current_user, DatasetPolicy.get(dataset)) - - records, total = await _filter_records_using_search_engine( - db, - search_engine, - dataset=dataset, - parsed_metadata=metadata.metadata_parsed, - limit=limit, - offset=offset, - user=current_user, - response_statuses=response_statuses, - include=include, - sort_by_query_param=sort_by_query_param, - ) - - for record in records: - record.dataset = dataset - record.metadata_ = await _filter_record_metadata_for_user(record, current_user) - - return Records(items=records, total=total) - - @router.get("/datasets/{dataset_id}/records", response_model=Records, response_model_exclude_unset=True) async def list_dataset_records( *, db: AsyncSession = Depends(get_async_db), search_engine: SearchEngine = Depends(get_search_engine), dataset_id: UUID, - metadata: MetadataQueryParams = Depends(), - sort_by_query_param: SortByQueryParamParsed, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), - response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"), offset: int = 0, limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), current_user: User = Security(auth.get_current_user), @@ -411,12 +262,9 @@ async def list_dataset_records( db, search_engine, dataset=dataset, - parsed_metadata=metadata.metadata_parsed, limit=limit, offset=offset, - response_statuses=response_statuses, include=include, - sort_by_query_param=sort_by_query_param or LIST_DATASET_RECORDS_DEFAULT_SORT_BY, ) return Records(items=records, total=total) @@ -460,10 +308,7 @@ async def search_current_user_dataset_records( telemetry_client: TelemetryClient = Depends(get_telemetry_client), dataset_id: UUID, body: SearchRecordsQuery, - metadata: MetadataQueryParams = Depends(), - sort_by_query_param: SortByQueryParamParsed, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), - response_statuses: List[ResponseStatusFilter] = 
Query([], alias="response_status"), offset: int = Query(0, ge=0), limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), current_user: User = Security(auth.get_current_user), @@ -486,12 +331,9 @@ async def search_current_user_dataset_records( search_engine=search_engine, dataset=dataset, search_records_query=body, - parsed_metadata=metadata.metadata_parsed, limit=limit, offset=offset, user=current_user, - response_statuses=response_statuses, - sort_by_query_param=sort_by_query_param, ) record_id_score_map: Dict[UUID, Dict[str, Union[float, SearchRecord, None]]] = { @@ -534,10 +376,7 @@ async def search_dataset_records( search_engine: SearchEngine = Depends(get_search_engine), dataset_id: UUID, body: SearchRecordsQuery, - metadata: MetadataQueryParams = Depends(), - sort_by_query_param: SortByQueryParamParsed, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), - response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"), offset: int = Query(0, ge=0), limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), current_user: User = Security(auth.get_current_user), @@ -555,9 +394,6 @@ async def search_dataset_records( search_records_query=body, limit=limit, offset=offset, - parsed_metadata=metadata.metadata_parsed, - response_statuses=response_statuses, - sort_by_query_param=sort_by_query_param, ) record_id_score_map = { diff --git a/argilla-server/src/argilla_server/api/schemas/v1/records.py b/argilla-server/src/argilla_server/api/schemas/v1/records.py index 0cf215954a..b5ff7c3f4c 100644 --- a/argilla-server/src/argilla_server/api/schemas/v1/records.py +++ b/argilla-server/src/argilla_server/api/schemas/v1/records.py @@ -13,12 +13,9 @@ # limitations under the License. 
from datetime import datetime - from typing import Annotated, Any, Dict, List, Literal, Optional, Union from uuid import UUID -import fastapi - from argilla_server.api.schemas.v1.commons import UpdateSchema from argilla_server.api.schemas.v1.metadata_properties import MetadataPropertyName from argilla_server.api.schemas.v1.responses import Response, ResponseFilterScope, UserResponseCreate @@ -223,15 +220,6 @@ def __init__(self, string: str): self.value: str = "".join(v).strip() -class MetadataQueryParams(BaseModel): - metadata: List[str] = Field(fastapi.Query([], pattern=r"^(?=.*[a-z0-9])[a-z0-9_-]+:(.+(,(.+))*)$")) - - @property - def metadata_parsed(self) -> List[MetadataParsedQueryParam]: - # TODO: Validate metadata fields names from query params - return [MetadataParsedQueryParam(q) for q in self.metadata] - - class VectorQuery(BaseModel): name: str record_id: Optional[UUID] = None diff --git a/argilla-server/src/argilla_server/search_engine/base.py b/argilla-server/src/argilla_server/search_engine/base.py index ee1dbcc386..db5bc87e2a 100644 --- a/argilla-server/src/argilla_server/search_engine/base.py +++ b/argilla-server/src/argilla_server/search_engine/base.py @@ -15,17 +15,13 @@ from abc import ABCMeta, abstractmethod from contextlib import asynccontextmanager from typing import ( - Any, AsyncGenerator, - ClassVar, - Dict, Generic, Iterable, List, Optional, - Type, - TypeVar, Union, + TypeVar, ) from uuid import UUID @@ -38,16 +34,12 @@ SortOrder, ) from argilla_server.models import Dataset, MetadataProperty, Record, Response, Suggestion, User, Vector, VectorSettings -from argilla_server.pydantic_v1 import BaseModel, Field, root_validator +from argilla_server.pydantic_v1 import BaseModel, Field from argilla_server.pydantic_v1.generics import GenericModel __all__ = [ "SearchEngine", "TextQuery", - "MetadataFilter", - "TermsMetadataFilter", - "IntegerMetadataFilter", - "FloatMetadataFilter", "UserResponseStatusFilter", "SearchResponseItem", "SearchResponses", @@ -147,67 +139,6 @@ def has_pending_status(self) -> bool: return ResponseStatusFilter.pending in self.statuses or ResponseStatusFilter.missing in self.statuses -class MetadataFilter(BaseModel): - metadata_property: MetadataProperty - - class Config: - arbitrary_types_allowed = True - - @classmethod - @abstractmethod - def from_string(cls, metadata_property: MetadataProperty, string: str) -> "MetadataFilter": - pass - - -class TermsMetadataFilter(MetadataFilter): - values: List[str] - - @classmethod - def from_string(cls, metadata_property: MetadataProperty, string: str) -> "MetadataFilter": - return cls(metadata_property=metadata_property, values=string.split(",")) - - -NT = TypeVar("NT", int, float) - - -class _RangeModel(GenericModel, Generic[NT]): - ge: Optional[NT] - le: Optional[NT] - - -class NumericMetadataFilter(GenericModel, Generic[NT], MetadataFilter): - ge: Optional[NT] = None - le: Optional[NT] = None - - _json_model: ClassVar[Type[_RangeModel]] - - @root_validator(skip_on_failure=True) - def check_bounds(cls, values: Dict[str, Any]) -> Dict[str, Any]: - ge = values.get("ge") - le = values.get("le") - - if ge is None and le is None: - raise ValueError("One of 'ge' or 'le' values must be specified") - - if ge is not None and le is not None and ge > le: - raise ValueError(f"'ge' ({ge}) must be lower or equal than 'le' ({le})") - - return values - - @classmethod - def from_string(cls, metadata_property: MetadataProperty, string: str) -> "NumericMetadataFilter": - model = cls._json_model.parse_raw(string) - return 
cls(metadata_property=metadata_property, ge=model.ge, le=model.le) - - -class IntegerMetadataFilter(NumericMetadataFilter[int]): - _json_model = _RangeModel[int] - - -class FloatMetadataFilter(NumericMetadataFilter[float]): - _json_model = _RangeModel[float] - - class SearchResponseItem(BaseModel): record_id: UUID score: Optional[float] @@ -236,6 +167,9 @@ class TermCount(BaseModel): values: List[TermCount] = Field(default_factory=list) +NT = TypeVar("NT", int, float) + + class NumericMetadataMetrics(GenericModel, Generic[NT]): min: Optional[NT] max: Optional[NT] @@ -348,11 +282,6 @@ async def search( query: Optional[Union[TextQuery, str]] = None, filter: Optional[Filter] = None, sort: Optional[List[Order]] = None, - # TODO: remove them and keep filter and order - user_response_status_filter: Optional[UserResponseStatusFilter] = None, - metadata_filters: Optional[List[MetadataFilter]] = None, - sort_by: Optional[List[SortBy]] = None, - # END TODO offset: int = 0, limit: int = 100, ) -> SearchResponses: @@ -378,10 +307,6 @@ async def similarity_search( record: Optional[Record] = None, query: Optional[Union[TextQuery, str]] = None, filter: Optional[Filter] = None, - # TODO: remove them and keep filter - user_response_status_filter: Optional[UserResponseStatusFilter] = None, - metadata_filters: Optional[List[MetadataFilter]] = None, - # END TODO max_results: int = 100, order: SimilarityOrder = SimilarityOrder.most_similar, threshold: Optional[float] = None, diff --git a/argilla-server/src/argilla_server/search_engine/commons.py b/argilla-server/src/argilla_server/search_engine/commons.py index a081105d16..e6541309a0 100644 --- a/argilla-server/src/argilla_server/search_engine/commons.py +++ b/argilla-server/src/argilla_server/search_engine/commons.py @@ -38,11 +38,8 @@ AndFilter, Filter, FilterScope, - FloatMetadataFilter, FloatMetadataMetrics, - IntegerMetadataFilter, IntegerMetadataMetrics, - MetadataFilter, MetadataFilterScope, MetadataMetrics, Order, @@ -55,7 +52,6 @@ SortBy, SuggestionFilterScope, TermsFilter, - TermsMetadataFilter, TermsMetadataMetrics, TextQuery, UserResponseStatusFilter, @@ -97,9 +93,6 @@ def es_bool_query( if must_not: bool_query["must_not"] = must_not - if not bool_query: - raise ValueError("Cannot build a boolean query without any clause") - if minimum_should_match: bool_query["minimum_should_match"] = minimum_should_match @@ -210,55 +203,6 @@ def es_path_for_vector_settings(vector_settings: VectorSettings) -> str: return str(vector_settings.id) -# This function will be moved once the `metadata_filters` argument is removed from search and similarity_search methods -def _unify_metadata_filters_with_filter(metadata_filters: List[MetadataFilter], filter: Optional[Filter]) -> Filter: - filters = [] - if filter: - filters.append(filter) - - for metadata_filter in metadata_filters: - metadata_scope = MetadataFilterScope(metadata_property=metadata_filter.metadata_property.name) - if isinstance(metadata_filter, TermsMetadataFilter): - new_filter = TermsFilter(scope=metadata_scope, values=metadata_filter.values) - elif isinstance(metadata_filter, (IntegerMetadataFilter, FloatMetadataFilter)): - new_filter = RangeFilter(scope=metadata_scope, ge=metadata_filter.ge, le=metadata_filter.le) - else: - raise ValueError(f"Cannot process request for metadata filter {metadata_filter}") - filters.append(new_filter) - - return AndFilter(filters=filters) - - -# This function will be moved once the response status filter is removed from search and similarity_search methods -def 
_unify_user_response_status_filter_with_filter( - user_response_status_filter: UserResponseStatusFilter, filter: Optional[Filter] = None -) -> Filter: - scope = ResponseFilterScope(user=user_response_status_filter.user, property="status") - response_filter = TermsFilter(scope=scope, values=[status.value for status in user_response_status_filter.statuses]) - - if filter: - return AndFilter(filters=[filter, response_filter]) - else: - return response_filter - - -# This function will be moved once the `sort_by` argument is removed from search and similarity_search methods -def _unify_sort_by_with_order(sort_by: List[SortBy], order: List[Order]) -> List[Order]: - if order: - return order - - new_order = [] - for sort in sort_by: - if isinstance(sort.field, MetadataProperty): - scope = MetadataFilterScope(metadata_property=sort.field.name) - else: - scope = RecordFilterScope(property=sort.field) - - new_order.append(Order(scope=scope, order=sort.order)) - - return new_order - - def is_response_status_scope(scope: FilterScope) -> bool: return isinstance(scope, ResponseFilterScope) and scope.property == "status" and scope.question is None @@ -370,14 +314,14 @@ async def update_record_response(self, response: Response): es_responses = self._map_record_responses_to_es([response]) - await self._update_document_request(index_name, id=record.id, body={"doc": {"responses": es_responses}}) + await self._update_document_request(index_name, id=str(record.id), body={"doc": {"responses": es_responses}}) async def delete_record_response(self, response: Response): record = response.record index_name = await self._get_dataset_index(record.dataset) await self._update_document_request( - index_name, id=record.id, body={"script": es_script_for_delete_user_response(response.user)} + index_name, id=str(record.id), body={"script": es_script_for_delete_user_response(response.user)} ) async def update_record_suggestion(self, suggestion: Suggestion): @@ -387,7 +331,7 @@ async def update_record_suggestion(self, suggestion: Suggestion): await self._update_document_request( index_name, - id=suggestion.record_id, + id=str(suggestion.record_id), body={"doc": {"suggestions": es_suggestions}}, ) @@ -396,7 +340,7 @@ async def delete_record_suggestion(self, suggestion: Suggestion): await self._update_document_request( index_name, - id=suggestion.record_id, + id=str(suggestion.record_id), body={"script": f'ctx._source["suggestions"].remove("{suggestion.question.name}")'}, ) @@ -423,21 +367,10 @@ async def similarity_search( record: Optional[Record] = None, query: Optional[Union[TextQuery, str]] = None, filter: Optional[Filter] = None, - # TODO: remove them and keep filter - user_response_status_filter: Optional[UserResponseStatusFilter] = None, - metadata_filters: Optional[List[MetadataFilter]] = None, - # END TODO max_results: int = 100, order: SimilarityOrder = SimilarityOrder.most_similar, threshold: Optional[float] = None, ) -> SearchResponses: - # TODO: This block will be moved (maybe to contexts/search.py), and only filter and order arguments will be kept - if metadata_filters: - filter = _unify_metadata_filters_with_filter(metadata_filters, filter) - if user_response_status_filter and user_response_status_filter.statuses: - filter = _unify_user_response_status_filter_with_filter(user_response_status_filter, filter) - # END TODO - if bool(value) == bool(record): raise ValueError("Must provide either vector value or record to compute the similarity search") @@ -629,26 +562,11 @@ async def search( query: 
Optional[Union[TextQuery, str]] = None, filter: Optional[Filter] = None, sort: Optional[List[Order]] = None, - # TODO: Remove these arguments - user_response_status_filter: Optional[UserResponseStatusFilter] = None, - metadata_filters: Optional[List[MetadataFilter]] = None, - sort_by: Optional[List[SortBy]] = None, - # END TODO offset: int = 0, limit: int = 100, user_id: Optional[str] = None, ) -> SearchResponses: # See https://www.elastic.co/guide/en/elasticsearch/reference/current/search-search.html - - # TODO: This block will be moved (maybe to contexts/search.py), and only filter and order arguments will be kept - if metadata_filters: - filter = _unify_metadata_filters_with_filter(metadata_filters, filter) - if user_response_status_filter and user_response_status_filter.statuses: - filter = _unify_user_response_status_filter_with_filter(user_response_status_filter, filter) - - if sort_by: - sort = _unify_sort_by_with_order(sort_by, sort) - # END TODO index = await self._get_dataset_index(dataset) text_query = self._build_text_query(dataset, text=query) diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py index 73077c4381..5e3c6653de 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py @@ -316,12 +316,9 @@ async def test_with_filter( RangeFilter(scope=SuggestionFilterScope(question=question.name, property="score"), ge=0.5), ] ), - metadata_filters=[], offset=0, limit=50, query=None, - sort_by=None, - user_response_status_filter=None, ) async def test_with_sort( @@ -367,12 +364,9 @@ async def test_with_sort( Order(scope=ResponseFilterScope(question=question.name), order=SortOrder.asc), Order(scope=SuggestionFilterScope(question=question.name, property="score"), order=SortOrder.desc), ], - metadata_filters=[], offset=0, limit=50, query=None, - sort_by=None, - user_response_status_filter=None, ) async def test_with_invalid_filter(self, async_client: AsyncClient, owner_auth_header: dict): diff --git a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py index 557cb4de70..a259baa773 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py @@ -14,11 +14,13 @@ import math import uuid from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type from unittest.mock import ANY, MagicMock from uuid import UUID, uuid4 import pytest +from sqlalchemy import func, inspect, select + from argilla_server.api.handlers.v1.datasets.records import LIST_DATASET_RECORDS_LIMIT_DEFAULT from argilla_server.api.schemas.v1.datasets import DATASET_GUIDELINES_MAX_LENGTH, DATASET_NAME_MAX_LENGTH from argilla_server.api.schemas.v1.fields import FIELD_CREATE_NAME_MAX_LENGTH, FIELD_CREATE_TITLE_MAX_LENGTH @@ -41,6 +43,7 @@ ResponseStatusFilter, SimilarityOrder, RecordStatus, + SortOrder, ) from argilla_server.models import ( Dataset, @@ -57,19 +60,18 @@ VectorSettings, ) from argilla_server.search_engine import ( - FloatMetadataFilter, - IntegerMetadataFilter, - MetadataFilter, SearchEngine, SearchResponseItem, SearchResponses, - SortBy, - TermsMetadataFilter, TextQuery, - UserResponseStatusFilter, + AndFilter, + 
TermsFilter, + MetadataFilterScope, + RangeFilter, + ResponseFilterScope, + Order, + RecordFilterScope, ) -from sqlalchemy import func, inspect, select - from tests.factories import ( AdminFactory, AnnotatorFactory, @@ -80,7 +82,6 @@ LabelSelectionQuestionFactory, MetadataPropertyFactory, MultiLabelSelectionQuestionFactory, - OwnerFactory, QuestionFactory, RatingQuestionFactory, RecordFactory, @@ -3592,11 +3593,8 @@ async def test_search_current_user_dataset_records( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - metadata_filters=[], - user_response_status_filter=None, offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=None, user_id=owner.id, ) assert response.status_code == 200 @@ -3633,55 +3631,85 @@ async def test_search_current_user_dataset_records( } @pytest.mark.parametrize( - ("property_config", "param_value", "expected_filter_class", "expected_filter_args"), + ("property_config", "metadata_filter", "expected_filter"), [ ( {"name": "terms_prop", "settings": {"type": "terms"}}, - "value", - TermsMetadataFilter, - dict(values=["value"]), + { + "type": "terms", + "values": ["value"], + "scope": {"entity": "metadata", "metadata_property": "terms_prop"}, + }, + TermsFilter(scope=MetadataFilterScope(metadata_property="terms_prop"), values=["value"]), ), ( {"name": "terms_prop", "settings": {"type": "terms"}}, - "value1,value2", - TermsMetadataFilter, - dict(values=["value1", "value2"]), + { + "type": "terms", + "values": ["value1", "value2"], + "scope": {"entity": "metadata", "metadata_property": "terms_prop"}, + }, + TermsFilter(scope=MetadataFilterScope(metadata_property="terms_prop"), values=["value1", "value2"]), ), ( {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"ge": 10, "le": 20}', - IntegerMetadataFilter, - dict(ge=10, le=20), + { + "type": "range", + "ge": 10, + "le": 20, + "scope": {"entity": "metadata", "metadata_property": "integer_prop"}, + }, + RangeFilter( + scope=MetadataFilterScope(metadata_property="integer_prop"), + ge=10, + le=20, + ), ), ( {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"ge": 20}', - IntegerMetadataFilter, - dict(ge=20, high=None), + {"type": "range", "ge": 20, "scope": {"entity": "metadata", "metadata_property": "integer_prop"}}, + RangeFilter( + scope=MetadataFilterScope(metadata_property="integer_prop"), + ge=20, + ), ), ( {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"le": 20}', - IntegerMetadataFilter, - dict(low=None, le=20), + {"type": "range", "le": 20, "scope": {"entity": "metadata", "metadata_property": "integer_prop"}}, + RangeFilter( + scope=MetadataFilterScope(metadata_property="integer_prop"), + le=20, + ), ), ( {"name": "float_prop", "settings": {"type": "float"}}, - '{"ge": -1.30, "le": 23.23}', - FloatMetadataFilter, - dict(ge=-1.30, le=23.23), + { + "type": "range", + "ge": -1.30, + "le": 23.23, + "scope": {"entity": "metadata", "metadata_property": "float_prop"}, + }, + RangeFilter( + scope=MetadataFilterScope(metadata_property="float_prop"), + ge=-1.30, + le=23.23, + ), ), ( {"name": "float_prop", "settings": {"type": "float"}}, - '{"ge": 23.23}', - FloatMetadataFilter, - dict(ge=23.23, high=None), + {"type": "range", "ge": 23.23, "scope": {"entity": "metadata", "metadata_property": "float_prop"}}, + RangeFilter( + scope=MetadataFilterScope(metadata_property="float_prop"), + ge=23.23, + ), ), ( {"name": "float_prop", "settings": {"type": "float"}}, - '{"le": 11.32}', - FloatMetadataFilter, - dict(low=None, 
le=11.32), + {"type": "range", "le": 11.32, "scope": {"entity": "metadata", "metadata_property": "float_prop"}}, + RangeFilter( + scope=MetadataFilterScope(metadata_property="float_prop"), + le=11.32, + ), ), ], ) @@ -3691,15 +3719,14 @@ async def test_search_current_user_dataset_records_with_metadata_filter( mock_search_engine: SearchEngine, owner: User, owner_auth_header: dict, - property_config: dict, - param_value: str, - expected_filter_class: Type[MetadataFilter], - expected_filter_args: dict, + property_config, + metadata_filter: dict, + expected_filter: Any, ): workspace = await WorkspaceFactory.create() dataset, _, records, *_ = await self.create_dataset_with_user_responses(owner, workspace) - metadata_property = await MetadataPropertyFactory.create( + await MetadataPropertyFactory.create( name=property_config["name"], settings=property_config["settings"], dataset=dataset, @@ -3713,12 +3740,9 @@ async def test_search_current_user_dataset_records_with_metadata_filter( ], ) - params = {"metadata": [f"{metadata_property.name}:{param_value}"]} - - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + query_json = {"query": {"text": {"q": "Hello", "field": "input"}}, "filters": {"and": [metadata_filter]}} response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", - params=params, headers=owner_auth_header, json=query_json, ) @@ -3727,91 +3751,45 @@ async def test_search_current_user_dataset_records_with_metadata_filter( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - metadata_filters=[expected_filter_class(metadata_property=metadata_property, **expected_filter_args)], - user_response_status_filter=None, + filter=AndFilter(filters=[expected_filter]), offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=None, user_id=owner.id, ) @pytest.mark.parametrize( - ("property_config", "wrong_value"), + "sort,expected_sort", [ - ({"name": "terms_prop", "settings": {"type": "terms"}}, None), - ({"name": "terms_prop", "settings": {"type": "terms"}}, "terms_prop"), - ({"name": "terms_prop", "settings": {"type": "terms"}}, "terms_prop:"), - ({"name": "terms_prop", "settings": {"type": "terms"}}, "wrong-value"), - ({"name": "integer_prop", "settings": {"type": "integer"}}, None), - ({"name": "integer_prop", "settings": {"type": "integer"}}, "integer_prop"), - ({"name": "integer_prop", "settings": {"type": "integer"}}, "integer_prop:"), - ({"name": "integer_prop", "settings": {"type": "integer"}}, "integer_prop:{}"), - ({"name": "integer_prop", "settings": {"type": "integer"}}, "wrong-value"), - ({"name": "float_prop", "settings": {"type": "float"}}, None), - ({"name": "float_prop", "settings": {"type": "float"}}, "float_prop"), - ({"name": "float_prop", "settings": {"type": "float"}}, "float_prop:"), - ({"name": "float_prop", "settings": {"type": "float"}}, "float_prop:{}"), - ({"name": "float_prop", "settings": {"type": "float"}}, "wrong-value"), - ], - ) - async def test_search_current_user_dataset_records_with_wrong_metadata_filter_values( - self, - async_client: "AsyncClient", - mock_search_engine: SearchEngine, - owner: User, - owner_auth_header: dict, - property_config: dict, - wrong_value: str, - ): - workspace = await WorkspaceFactory.create() - dataset, _, _, records, *_ = await self.create_dataset_with_user_responses(owner, workspace) - - await MetadataPropertyFactory.create( - name=property_config["name"], - settings=property_config["settings"], - dataset=dataset, - ) - - 
mock_search_engine.search.return_value = SearchResponses( - items=[ - SearchResponseItem(record_id=records[0].id, score=14.2), - SearchResponseItem(record_id=records[1].id, score=12.2), - ], - total=2, - ) - - params = {"metadata": [wrong_value]} - - query_json = {"query": {"text": {"q": "Hello"}}} - response = await async_client.post( - f"/api/v1/me/datasets/{dataset.id}/records/search", - params=params, - headers=owner_auth_header, - json=query_json, - ) - assert response.status_code == 422, response.json() - - @pytest.mark.parametrize( - "sorts", - [ - [("inserted_at", None)], - [("inserted_at", "asc")], - [("inserted_at", "desc")], - [("updated_at", None)], - [("updated_at", "asc")], - [("updated_at", "desc")], - [("metadata.terms-metadata-property", None)], - [("metadata.terms-metadata-property", "asc")], - [("metadata.terms-metadata-property", "desc")], - [("inserted_at", "asc"), ("updated_at", "desc")], - [("inserted_at", "desc"), ("updated_at", "asc")], - [("inserted_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("inserted_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("inserted_at", "asc"), ("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("inserted_at", "desc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("inserted_at", "asc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], + ( + [{"scope": {"entity": "record", "property": "inserted_at"}, "order": "asc"}], + [Order(scope=RecordFilterScope(property="inserted_at"), order=SortOrder.asc)], + ), + ( + [{"scope": {"entity": "record", "property": "inserted_at"}, "order": "desc"}], + [Order(scope=RecordFilterScope(property="inserted_at"), order=SortOrder.desc)], + ), + ( + [{"scope": {"entity": "record", "property": "updated_at"}, "order": "asc"}], + [Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.asc)], + ), + ( + [{"scope": {"entity": "record", "property": "updated_at"}, "order": "desc"}], + [Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.desc)], + ), + ( + [{"scope": {"entity": "metadata", "metadata_property": "terms-metadata-property"}, "order": "asc"}], + [Order(scope=MetadataFilterScope(metadata_property="terms-metadata-property"), order=SortOrder.asc)], + ), + ( + [ + {"scope": {"entity": "record", "property": "updated_at"}, "order": "desc"}, + {"scope": {"entity": "metadata", "metadata_property": "terms-metadata-property"}, "order": "desc"}, + ], + [ + Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.desc), + Order(scope=MetadataFilterScope(metadata_property="terms-metadata-property"), order=SortOrder.desc), + ], + ), ], ) async def test_search_current_user_dataset_records_with_sort_by( @@ -3820,16 +3798,15 @@ async def test_search_current_user_dataset_records_with_sort_by( mock_search_engine: SearchEngine, owner: "User", owner_auth_header: dict, - sorts: List[Tuple[str, Union[str, None]]], + sort: List[dict], + expected_sort: List[Order], ): workspace = await WorkspaceFactory.create() dataset, _, records, *_ = await self.create_dataset_with_user_responses(owner, workspace) - expected_sorts_by = [] - for field, order in sorts: - if field not in ("inserted_at", "updated_at"): - field = await TermsMetadataPropertyFactory.create(name=field.split(".")[-1], dataset=dataset) - 
expected_sorts_by.append(SortBy(field=field, order=order or "asc")) + for order in expected_sort: + if isinstance(order.scope, MetadataFilterScope): + await TermsMetadataPropertyFactory.create(name=order.scope.metadata_property, dataset=dataset) mock_search_engine.search.return_value = SearchResponses( total=2, @@ -3839,15 +3816,13 @@ async def test_search_current_user_dataset_records_with_sort_by( ], ) - query_params = { - "sort_by": [f"{field}:{order}" if order is not None else f"{field}:asc" for field, order in sorts] + query_json = { + "query": {"text": {"q": "Hello", "field": "input"}}, + "sort": sort, } - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} - response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", - params=query_params, headers=owner_auth_header, json=query_json, ) @@ -3857,11 +3832,9 @@ async def test_search_current_user_dataset_records_with_sort_by( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - metadata_filters=[], - user_response_status_filter=None, offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=expected_sorts_by, + sort=expected_sort, user_id=owner.id, ) @@ -3871,18 +3844,17 @@ async def test_search_current_user_dataset_records_with_sort_by_with_wrong_sort_ workspace = await WorkspaceFactory.create() dataset, *_ = await self.create_dataset_with_user_responses(owner, workspace) - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + query_json = { + "query": {"text": {"q": "Hello", "field": "input"}}, + "sort": [{"scope": {"entity": "record", "property": "wrong_property"}, "order": "asc"}], + } response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", - params={"sort_by": "inserted_at:wrong"}, headers=owner_auth_header, json=query_json, ) assert response.status_code == 422 - assert response.json() == { - "detail": "Provided sort order in 'sort_by' query param 'wrong' for field 'inserted_at' is not valid." - } async def test_search_current_user_dataset_records_with_sort_by_with_non_existent_metadata_property( self, async_client: "AsyncClient", owner: "User", owner_auth_header: dict @@ -3890,17 +3862,19 @@ async def test_search_current_user_dataset_records_with_sort_by_with_non_existen workspace = await WorkspaceFactory.create() dataset, *_ = await self.create_dataset_with_user_responses(owner, workspace) - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + query_json = { + "query": {"text": {"q": "Hello", "field": "input"}}, + "sort": [{"scope": {"entity": "metadata", "metadata_property": "missing"}, "order": "asc"}], + } response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", - params={"sort_by": "metadata.i-do-not-exist:asc"}, headers=owner_auth_header, json=query_json, ) assert response.status_code == 422 assert response.json() == { - "detail": f"Provided metadata property in 'sort_by' query param 'i-do-not-exist' not found in dataset with '{dataset.id}'." 
+ "detail": f"MetadataProperty not found filtering by name=missing, dataset_id={dataset.id}" } async def test_search_current_user_dataset_records_with_sort_by_with_invalid_field( @@ -3909,19 +3883,19 @@ async def test_search_current_user_dataset_records_with_sort_by_with_invalid_fie workspace = await WorkspaceFactory.create() dataset, *_ = await self.create_dataset_with_user_responses(owner, workspace) - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + query_json = { + "query": {"text": {"q": "Hello", "field": "input"}}, + "sort": [ + {"scope": {"entity": "wrong", "property": "wrong"}, "order": "asc"}, + ], + } response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", - params={"sort_by": "not-valid"}, headers=owner_auth_header, json=query_json, ) assert response.status_code == 422 - assert response.json() == { - "detail": "Provided sort field in 'sort_by' query param 'not-valid' is not valid. " - "It must be either 'inserted_at', 'updated_at' or `metadata.metadata-property-name`" - } @pytest.mark.parametrize( "includes", @@ -4063,9 +4037,6 @@ async def test_search_current_user_dataset_records_with_include( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - metadata_filters=[], - sort_by=None, - user_response_status_filter=None, offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, user_id=owner.id, @@ -4268,22 +4239,37 @@ async def test_search_current_user_dataset_records_with_response_status_filter( dataset, *_ = await self.create_dataset_with_user_responses(owner, workspace) mock_search_engine.search.return_value = SearchResponses(items=[]) - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + query_json = { + "query": {"text": {"q": "Hello", "field": "input"}}, + "filters": { + "and": [ + { + "type": "terms", + "scope": {"entity": "response", "property": "status"}, + "values": [ResponseStatus.submitted], + } + ] + }, + } response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", headers=owner_auth_header, json=query_json, - params={"response_status": ResponseStatus.submitted.value}, ) mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - metadata_filters=[], - user_response_status_filter=UserResponseStatusFilter(user=owner, statuses=[ResponseStatusFilter.submitted]), + filter=AndFilter( + filters=[ + TermsFilter( + scope=ResponseFilterScope(property="status", user=owner), + values=[ResponseStatusFilter.submitted], + ) + ] + ), offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=None, user_id=owner.id, ) assert response.status_code == 200 @@ -4326,8 +4312,6 @@ async def test_search_current_user_dataset_records_with_record_vector( query=None, order=SimilarityOrder.most_similar, max_results=5, - metadata_filters=[], - user_response_status_filter=None, ) async def test_search_current_user_dataset_records_with_vector_value( @@ -4370,8 +4354,6 @@ async def test_search_current_user_dataset_records_with_vector_value( query=None, order=SimilarityOrder.most_similar, max_results=10, - metadata_filters=[], - user_response_status_filter=None, ) async def test_search_current_user_dataset_records_with_vector_value_and_query( @@ -4419,8 +4401,6 @@ async def test_search_current_user_dataset_records_with_vector_value_and_query( query=TextQuery(q="Test query"), order=SimilarityOrder.most_similar, max_results=10, - metadata_filters=[], - user_response_status_filter=None, ) async def 
test_search_current_user_dataset_records_with_wrong_vector( @@ -4512,11 +4492,8 @@ async def test_search_current_user_dataset_records_with_offset_and_limit( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - metadata_filters=[], - user_response_status_filter=None, offset=0, limit=5, - sort_by=None, user_id=owner.id, ) assert response.status_code == 200 diff --git a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py index 8f78940df3..4f989e5399 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py @@ -12,43 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Type, Union -from uuid import uuid4 +from typing import List, Optional, Tuple, Union import pytest -from argilla_server.api.handlers.v1.datasets.records import LIST_DATASET_RECORDS_LIMIT_DEFAULT -from argilla_server.constants import API_KEY_HEADER_NAME -from argilla_server.enums import RecordInclude, RecordSortField, ResponseStatus, UserRole, RecordStatus -from argilla_server.models import Dataset, Question, Record, Response, Suggestion, User, Workspace -from argilla_server.search_engine import ( - FloatMetadataFilter, - IntegerMetadataFilter, - MetadataFilter, - SearchEngine, - SearchResponseItem, - SearchResponses, - SortBy, - TermsMetadataFilter, -) from httpx import AsyncClient +from argilla_server.constants import API_KEY_HEADER_NAME +from argilla_server.enums import RecordInclude, ResponseStatus +from argilla_server.models import Dataset, Question, Record, Response, Suggestion, User, Workspace from tests.factories import ( AdminFactory, AnnotatorFactory, DatasetFactory, LabelSelectionQuestionFactory, - MetadataPropertyFactory, RecordFactory, ResponseFactory, SuggestionFactory, - TermsMetadataPropertyFactory, TextFieldFactory, TextQuestionFactory, - UserFactory, VectorFactory, VectorSettingsFactory, WorkspaceFactory, - WorkspaceUserFactory, ) @@ -398,108 +382,6 @@ async def create_records_with_response( for record in await RecordFactory.create_batch(size=num_records, dataset=dataset): await ResponseFactory.create(record=record, user=user, values=response_values, status=response_status) - @pytest.mark.parametrize( - ("property_config", "param_value", "expected_filter_class", "expected_filter_args"), - [ - ( - {"name": "terms_prop", "settings": {"type": "terms"}}, - "value", - TermsMetadataFilter, - dict(values=["value"]), - ), - ( - {"name": "terms_prop", "settings": {"type": "terms"}}, - "value1,value2", - TermsMetadataFilter, - dict(values=["value1", "value2"]), - ), - ( - {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"ge": 10, "le": 20}', - IntegerMetadataFilter, - dict(ge=10, le=20), - ), - ( - {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"ge": 20}', - IntegerMetadataFilter, - dict(ge=20, high=None), - ), - ( - {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"le": 20}', - IntegerMetadataFilter, - dict(ge=None, le=20), - ), - ( - {"name": "float_prop", "settings": {"type": "float"}}, - '{"ge": -1.30, "le": 23.23}', - FloatMetadataFilter, - dict(ge=-1.30, le=23.23), - ), - ( - {"name": "float_prop", "settings": {"type": "float"}}, - '{"ge": 23.23}', - FloatMetadataFilter, - dict(ge=23.23, high=None), - ), - ( - {"name": 
"float_prop", "settings": {"type": "float"}}, - '{"le": 11.32}', - FloatMetadataFilter, - dict(ge=None, le=11.32), - ), - ], - ) - async def test_list_dataset_records_with_metadata_filter( - self, - async_client: "AsyncClient", - mock_search_engine: SearchEngine, - owner: User, - owner_auth_header: dict, - property_config: dict, - param_value: str, - expected_filter_class: Type[MetadataFilter], - expected_filter_args: dict, - ): - workspace = await WorkspaceFactory.create() - dataset, _, records, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - - metadata_property = await MetadataPropertyFactory.create( - name=property_config["name"], - settings=property_config["settings"], - dataset=dataset, - ) - - mock_search_engine.search.return_value = SearchResponses( - total=2, - items=[ - SearchResponseItem(record_id=records[0].id, score=14.2), - SearchResponseItem(record_id=records[1].id, score=12.2), - ], - ) - - query_params = {"metadata": [f"{metadata_property.name}:{param_value}"]} - response = await async_client.get( - f"/api/v1/datasets/{dataset.id}/records", - params=query_params, - headers=owner_auth_header, - ) - assert response.status_code == 200 - - response_json = response.json() - assert response_json["total"] == 2 - - mock_search_engine.search.assert_called_once_with( - dataset=dataset, - query=None, - metadata_filters=[expected_filter_class(metadata_property=metadata_property, **expected_filter_args)], - user_response_status_filter=None, - offset=0, - limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=[SortBy(field=RecordSortField.inserted_at)], - ) - @pytest.mark.skip(reason="Factory integration with search engine") @pytest.mark.parametrize( "response_status_filter", ["missing", "pending", "discarded", "submitted", "draft", ["submitted", "draft"]] @@ -563,121 +445,6 @@ async def test_list_dataset_records_with_response_status_filter( ] ) - @pytest.mark.parametrize( - "sorts", - [ - [("inserted_at", None)], - [("inserted_at", "asc")], - [("inserted_at", "desc")], - [("updated_at", None)], - [("updated_at", "asc")], - [("updated_at", "desc")], - [("metadata.terms-metadata-property", None)], - [("metadata.terms-metadata-property", "asc")], - [("metadata.terms-metadata-property", "desc")], - [("inserted_at", "asc"), ("updated_at", "desc")], - [("inserted_at", "desc"), ("updated_at", "asc")], - [("inserted_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("inserted_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("inserted_at", "asc"), ("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("inserted_at", "desc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("inserted_at", "asc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - ], - ) - async def test_list_dataset_records_with_sort_by( - self, - async_client: "AsyncClient", - mock_search_engine: SearchEngine, - owner: "User", - owner_auth_header: dict, - sorts: List[Tuple[str, Union[str, None]]], - ): - workspace = await WorkspaceFactory.create() - dataset, _, records, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - - expected_sorts_by = [] - for field, order in sorts: - if field not in ("inserted_at", "updated_at"): - field = await TermsMetadataPropertyFactory.create(name=field.split(".")[-1], dataset=dataset) - 
expected_sorts_by.append(SortBy(field=field, order=order or "asc")) - - mock_search_engine.search.return_value = SearchResponses( - total=2, - items=[ - SearchResponseItem(record_id=records[0].id, score=14.2), - SearchResponseItem(record_id=records[1].id, score=12.2), - ], - ) - - query_params = { - "sort_by": [f"{field}:{order}" if order is not None else f"{field}:asc" for field, order in sorts] - } - - response = await async_client.get( - f"/api/v1/datasets/{dataset.id}/records", - params=query_params, - headers=owner_auth_header, - ) - assert response.status_code == 200 - assert response.json()["total"] == 2 - - mock_search_engine.search.assert_called_once_with( - dataset=dataset, - query=None, - metadata_filters=[], - user_response_status_filter=None, - offset=0, - limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=expected_sorts_by, - ) - - async def test_list_dataset_records_with_sort_by_with_wrong_sort_order_value( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - - response = await async_client.get( - f"/api/v1/datasets/{dataset.id}/records", params={"sort_by": "inserted_at:wrong"}, headers=owner_auth_header - ) - assert response.status_code == 422 - assert response.json() == { - "detail": "Provided sort order in 'sort_by' query param 'wrong' for field 'inserted_at' is not valid." - } - - async def test_list_dataset_records_with_sort_by_with_non_existent_metadata_property( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - - response = await async_client.get( - f"/api/v1/datasets/{dataset.id}/records", - params={"sort_by": "metadata.i-do-not-exist:asc"}, - headers=owner_auth_header, - ) - assert response.status_code == 422 - assert response.json() == { - "detail": f"Provided metadata property in 'sort_by' query param 'i-do-not-exist' not found in dataset with '{dataset.id}'." - } - - async def test_list_dataset_records_with_sort_by_with_invalid_field( - self, async_client: "AsyncClient", owner: "User", owner_auth_header: dict - ): - workspace = await WorkspaceFactory.create() - dataset, _, _, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - - response = await async_client.get( - f"/api/v1/datasets/{dataset.id}/records", - params={"sort_by": "not-valid"}, - headers=owner_auth_header, - ) - assert response.status_code == 422 - assert response.json() == { - "detail": "Provided sort field in 'sort_by' query param 'not-valid' is not valid. 
" - "It must be either 'inserted_at', 'updated_at' or `metadata.metadata-property-name`" - } - async def test_list_dataset_records_without_authentication(self, async_client: "AsyncClient"): dataset = await DatasetFactory.create() @@ -793,753 +560,3 @@ async def create_dataset_with_user_responses( ] return dataset, questions, records, responses, suggestions - - async def test_list_current_user_dataset_records( - self, async_client: "AsyncClient", mock_search_engine: SearchEngine, owner: User, owner_auth_header: dict - ): - workspace = await WorkspaceFactory.create() - dataset, _, records, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - record_a, record_b, record_c = records - - mock_search_engine.search.return_value = SearchResponses( - total=3, - items=[ - SearchResponseItem(record_id=record_a.id, score=14.2), - SearchResponseItem(record_id=record_b.id, score=12.2), - SearchResponseItem(record_id=record_c.id, score=10.2), - ], - ) - - other_dataset = await DatasetFactory.create() - await RecordFactory.create_batch(size=2, dataset=other_dataset) - - response = await async_client.get(f"/api/v1/me/datasets/{dataset.id}/records", headers=owner_auth_header) - - assert response.status_code == 200 - assert response.json() == { - "total": 3, - "items": [ - { - "id": str(record_a.id), - "status": RecordStatus.pending, - "fields": {"input": "input_a", "output": "output_a"}, - "metadata": None, - "dataset_id": str(dataset.id), - "external_id": record_a.external_id, - "inserted_at": record_a.inserted_at.isoformat(), - "updated_at": record_a.updated_at.isoformat(), - }, - { - "id": str(record_b.id), - "status": RecordStatus.pending, - "fields": {"input": "input_b", "output": "output_b"}, - "metadata": {"unit": "test"}, - "dataset_id": str(dataset.id), - "external_id": record_b.external_id, - "inserted_at": record_b.inserted_at.isoformat(), - "updated_at": record_b.updated_at.isoformat(), - }, - { - "id": str(record_c.id), - "status": RecordStatus.pending, - "fields": {"input": "input_c", "output": "output_c"}, - "metadata": None, - "dataset_id": str(dataset.id), - "external_id": record_c.external_id, - "inserted_at": record_c.inserted_at.isoformat(), - "updated_at": record_c.updated_at.isoformat(), - }, - ], - } - - async def test_list_current_user_dataset_records_with_filtered_metadata_as_annotator( - self, async_client: "AsyncClient", mock_search_engine: SearchEngine, owner: User - ): - workspace = await WorkspaceFactory.create() - user = await AnnotatorFactory.create() - await WorkspaceUserFactory.create(workspace_id=workspace.id, user_id=user.id) - - dataset, _, _, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - - await TermsMetadataPropertyFactory.create( - name="key1", - dataset=dataset, - allowed_roles=[UserRole.admin, UserRole.annotator], - ) - await TermsMetadataPropertyFactory.create( - name="key2", - dataset=dataset, - allowed_roles=[UserRole.admin], - ) - await TermsMetadataPropertyFactory.create( - name="key3", - dataset=dataset, - allowed_roles=[UserRole.admin], - ) - - record = await RecordFactory.create( - dataset=dataset, - fields={"input": "input_b", "output": "output_b"}, - metadata_={"key1": "value1", "key2": "value2", "key3": "value3", "extra": "extra"}, - ) - - mock_search_engine.search.return_value = SearchResponses( - total=1, - items=[SearchResponseItem(record_id=record.id, score=14.2)], - ) - - other_dataset = await DatasetFactory.create() - await RecordFactory.create_batch(size=2, dataset=other_dataset) - - response = await 
-            f"/api/v1/me/datasets/{dataset.id}/records", headers={API_KEY_HEADER_NAME: user.api_key}
-        )
-
-        assert response.status_code == 200
-        assert response.json() == {
-            "total": 1,
-            "items": [
-                {
-                    "id": str(record.id),
-                    "status": RecordStatus.pending,
-                    "fields": {"input": "input_b", "output": "output_b"},
-                    "metadata": {"key1": "value1"},
-                    "dataset_id": str(dataset.id),
-                    "external_id": record.external_id,
-                    "inserted_at": record.inserted_at.isoformat(),
-                    "updated_at": record.updated_at.isoformat(),
-                }
-            ],
-        }
-
-    @pytest.mark.skip(reason="Factory integration with search engine")
-    @pytest.mark.parametrize("role", [UserRole.annotator, UserRole.admin, UserRole.owner])
-    @pytest.mark.parametrize(
-        "includes",
-        [[RecordInclude.responses], [RecordInclude.suggestions], [RecordInclude.responses, RecordInclude.suggestions]],
-    )
-    async def test_list_current_user_dataset_records_with_include(
-        self, async_client: "AsyncClient", role: UserRole, includes: List[RecordInclude]
-    ):
-        workspace = await WorkspaceFactory.create()
-        user = await UserFactory.create(workspaces=[workspace], role=role)
-        dataset, questions, records, responses, suggestions = await self.create_dataset_with_user_responses(
-            user, workspace
-        )
-        record_a, record_b, record_c = records
-        response_a_user, response_b_user = responses[1], responses[3]
-        suggestion_a, suggestion_b = suggestions
-
-        other_dataset = await DatasetFactory.create()
-        await RecordFactory.create_batch(size=2, dataset=other_dataset)
-
-        params = [("include", include.value) for include in includes]
-        response = await async_client.get(
-            f"/api/v1/me/datasets/{dataset.id}/records", params=params, headers={API_KEY_HEADER_NAME: user.api_key}
-        )
-
-        expected = {
-            "total": 3,
-            "items": [
-                {
-                    "id": str(record_a.id),
-                    "fields": {"input": "input_a", "output": "output_a"},
-                    "metadata": None,
-                    "external_id": record_a.external_id,
-                    "inserted_at": record_a.inserted_at.isoformat(),
-                    "updated_at": record_a.updated_at.isoformat(),
-                },
-                {
-                    "id": str(record_b.id),
-                    "fields": {"input": "input_b", "output": "output_b"},
-                    "metadata": {"unit": "test"},
-                    "external_id": record_b.external_id,
-                    "inserted_at": record_b.inserted_at.isoformat(),
-                    "updated_at": record_b.updated_at.isoformat(),
-                },
-                {
-                    "id": str(record_c.id),
-                    "fields": {"input": "input_c", "output": "output_c"},
-                    "metadata": None,
-                    "external_id": record_c.external_id,
-                    "inserted_at": record_c.inserted_at.isoformat(),
-                    "updated_at": record_c.updated_at.isoformat(),
-                },
-            ],
-        }
-
-        if RecordInclude.responses in includes:
-            expected["items"][0]["responses"] = [
-                {
-                    "id": str(response_a_user.id),
-                    "values": None,
-                    "status": "discarded",
-                    "user_id": str(user.id),
-                    "inserted_at": response_a_user.inserted_at.isoformat(),
-                    "updated_at": response_a_user.updated_at.isoformat(),
-                }
-            ]
-            expected["items"][1]["responses"] = [
-                {
-                    "id": str(response_b_user.id),
-                    "values": {
-                        "input_ok": {"value": "no"},
-                        "output_ok": {"value": "no"},
-                    },
-                    "status": "submitted",
-                    "user_id": str(user.id),
-                    "inserted_at": response_b_user.inserted_at.isoformat(),
-                    "updated_at": response_b_user.updated_at.isoformat(),
-                },
-            ]
-            expected["items"][2]["responses"] = []
-
-        if RecordInclude.suggestions in includes:
-            expected["items"][0]["suggestions"] = [
-                {
-                    "id": str(suggestion_a.id),
-                    "value": "option-1",
-                    "score": None,
-                    "agent": None,
-                    "type": None,
-                    "question_id": str(questions[0].id),
-                }
-            ]
-            expected["items"][1]["suggestions"] = [
-                {
-                    "id": str(suggestion_b.id),
"value": "option-2", - "score": 0.75, - "agent": "unit-test-agent", - "type": "model", - "question_id": str(questions[0].id), - } - ] - expected["items"][2]["suggestions"] = [] - - assert response.status_code == 200 - assert response.json() == expected - - @pytest.mark.skip(reason="Factory integration with search engine") - async def test_list_current_user_dataset_records_with_include_vectors( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - record_a = await RecordFactory.create(dataset=dataset) - record_b = await RecordFactory.create(dataset=dataset) - record_c = await RecordFactory.create(dataset=dataset) - vector_settings_a = await VectorSettingsFactory.create(name="vector-a", dimensions=3, dataset=dataset) - vector_settings_b = await VectorSettingsFactory.create(name="vector-b", dimensions=2, dataset=dataset) - - await VectorFactory.create(value=[1.0, 2.0, 3.0], vector_settings=vector_settings_a, record=record_a) - await VectorFactory.create(value=[4.0, 5.0], vector_settings=vector_settings_b, record=record_a) - await VectorFactory.create(value=[1.0, 2.0], vector_settings=vector_settings_b, record=record_b) - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", - params={"include": RecordInclude.vectors.value}, - headers=owner_auth_header, - ) - - assert response.status_code == 200 - assert response.json() == { - "items": [ - { - "id": str(record_a.id), - "fields": {"text": "This is a text", "sentiment": "neutral"}, - "metadata": None, - "external_id": record_a.external_id, - "vectors": { - "vector-a": [1.0, 2.0, 3.0], - "vector-b": [4.0, 5.0], - }, - "inserted_at": record_a.inserted_at.isoformat(), - "updated_at": record_a.updated_at.isoformat(), - }, - { - "id": str(record_b.id), - "fields": {"text": "This is a text", "sentiment": "neutral"}, - "metadata": None, - "external_id": record_b.external_id, - "vectors": { - "vector-b": [1.0, 2.0], - }, - "inserted_at": record_b.inserted_at.isoformat(), - "updated_at": record_b.updated_at.isoformat(), - }, - { - "id": str(record_c.id), - "fields": {"text": "This is a text", "sentiment": "neutral"}, - "metadata": None, - "external_id": record_c.external_id, - "vectors": {}, - "inserted_at": record_c.inserted_at.isoformat(), - "updated_at": record_c.updated_at.isoformat(), - }, - ], - "total": 3, - } - - @pytest.mark.skip(reason="Factory integration with search engine") - async def test_list_current_user_dataset_records_with_include_specific_vectors( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - record_a = await RecordFactory.create(dataset=dataset) - record_b = await RecordFactory.create(dataset=dataset) - record_c = await RecordFactory.create(dataset=dataset) - vector_settings_a = await VectorSettingsFactory.create(name="vector-a", dimensions=3, dataset=dataset) - vector_settings_b = await VectorSettingsFactory.create(name="vector-b", dimensions=2, dataset=dataset) - vector_settings_c = await VectorSettingsFactory.create(name="vector-c", dimensions=4, dataset=dataset) - - await VectorFactory.create(value=[1.0, 2.0, 3.0], vector_settings=vector_settings_a, record=record_a) - await VectorFactory.create(value=[4.0, 5.0], vector_settings=vector_settings_b, record=record_a) - await VectorFactory.create(value=[6.0, 7.0, 8.0, 9.0], vector_settings=vector_settings_c, record=record_a) - await VectorFactory.create(value=[1.0, 2.0], vector_settings=vector_settings_b, record=record_b) - await 
-        await VectorFactory.create(value=[14.0, 15.0, 16.0, 17.0], vector_settings=vector_settings_c, record=record_c)
-
-        response = await async_client.get(
-            f"/api/v1/me/datasets/{dataset.id}/records",
-            params={"include": f"{RecordInclude.vectors.value}:{vector_settings_a.name},{vector_settings_b.name}"},
-            headers=owner_auth_header,
-        )
-
-        assert response.status_code == 200
-        assert response.json() == {
-            "items": [
-                {
-                    "id": str(record_a.id),
-                    "fields": {"text": "This is a text", "sentiment": "neutral"},
-                    "metadata": None,
-                    "external_id": record_a.external_id,
-                    "vectors": {
-                        "vector-a": [1.0, 2.0, 3.0],
-                        "vector-b": [4.0, 5.0],
-                    },
-                    "inserted_at": record_a.inserted_at.isoformat(),
-                    "updated_at": record_a.updated_at.isoformat(),
-                },
-                {
-                    "id": str(record_b.id),
-                    "fields": {"text": "This is a text", "sentiment": "neutral"},
-                    "metadata": None,
-                    "external_id": record_b.external_id,
-                    "vectors": {
-                        "vector-b": [1.0, 2.0],
-                    },
-                    "inserted_at": record_b.inserted_at.isoformat(),
-                    "updated_at": record_b.updated_at.isoformat(),
-                },
-                {
-                    "id": str(record_c.id),
-                    "fields": {"text": "This is a text", "sentiment": "neutral"},
-                    "metadata": None,
-                    "external_id": record_c.external_id,
-                    "vectors": {},
-                    "inserted_at": record_c.inserted_at.isoformat(),
-                    "updated_at": record_c.updated_at.isoformat(),
-                },
-            ],
-            "total": 3,
-        }
-
-    @pytest.mark.skip(reason="Factory integration with search engine")
-    async def test_list_current_user_dataset_records_with_offset(
-        self, async_client: "AsyncClient", owner_auth_header: dict
-    ):
-        dataset = await DatasetFactory.create()
-        await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset)
-        await RecordFactory.create(fields={"record_b": "value_b"}, dataset=dataset)
-        record_c = await RecordFactory.create(fields={"record_c": "value_c"}, dataset=dataset)
-
-        other_dataset = await DatasetFactory.create()
-        await RecordFactory.create_batch(size=2, dataset=other_dataset)
-
-        response = await async_client.get(
-            f"/api/v1/me/datasets/{dataset.id}/records", headers=owner_auth_header, params={"offset": 2}
-        )
-
-        assert response.status_code == 200
-
-        response_body = response.json()
-        assert [item["id"] for item in response_body["items"]] == [str(record_c.id)]
-
-    @pytest.mark.skip(reason="Factory integration with search engine")
-    async def test_list_current_user_dataset_records_with_limit(
-        self, async_client: "AsyncClient", owner_auth_header: dict
-    ):
-        dataset = await DatasetFactory.create()
-        record_a = await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset)
-        await RecordFactory.create(fields={"record_b": "value_b"}, dataset=dataset)
-        await RecordFactory.create(fields={"record_c": "value_c"}, dataset=dataset)
-
-        other_dataset = await DatasetFactory.create()
-        await RecordFactory.create_batch(size=2, dataset=other_dataset)
-
-        response = await async_client.get(
-            f"/api/v1/me/datasets/{dataset.id}/records", headers=owner_auth_header, params={"limit": 1}
-        )
-
-        assert response.status_code == 200
-
-        response_body = response.json()
-        assert [item["id"] for item in response_body["items"]] == [str(record_a.id)]
-
-    @pytest.mark.skip(reason="Factory integration with search engine")
-    async def test_list_current_user_dataset_records_with_offset_and_limit(
-        self, async_client: "AsyncClient", owner_auth_header: dict
-    ):
-        dataset = await DatasetFactory.create()
-        await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset)
"value_a"}, dataset=dataset) - record_c = await RecordFactory.create(fields={"record_b": "value_b"}, dataset=dataset) - await RecordFactory.create(fields={"record_c": "value_c"}, dataset=dataset) - - other_dataset = await DatasetFactory.create() - await RecordFactory.create_batch(size=2, dataset=other_dataset) - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", headers=owner_auth_header, params={"offset": 1, "limit": 1} - ) - - assert response.status_code == 200 - - response_body = response.json() - assert [item["id"] for item in response_body["items"]] == [str(record_c.id)] - - @pytest.mark.parametrize( - ("property_config", "param_value", "expected_filter_class", "expected_filter_args"), - [ - ( - {"name": "terms_prop", "settings": {"type": "terms"}}, - "value", - TermsMetadataFilter, - dict(values=["value"]), - ), - ( - {"name": "terms_prop", "settings": {"type": "terms"}}, - "value1,value2", - TermsMetadataFilter, - dict(values=["value1", "value2"]), - ), - ( - {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"ge": 10, "le": 20}', - IntegerMetadataFilter, - dict(ge=10, le=20), - ), - ( - {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"ge": 20}', - IntegerMetadataFilter, - dict(ge=20, le=None), - ), - ( - {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"le": 20}', - IntegerMetadataFilter, - dict(ge=None, le=20), - ), - ( - {"name": "float_prop", "settings": {"type": "float"}}, - '{"ge": -1.30, "le": 23.23}', - FloatMetadataFilter, - dict(ge=-1.30, le=23.23), - ), - ( - {"name": "float_prop", "settings": {"type": "float"}}, - '{"ge": 23.23}', - FloatMetadataFilter, - dict(ge=23.23, le=None), - ), - ( - {"name": "float_prop", "settings": {"type": "float"}}, - '{"le": 11.32}', - FloatMetadataFilter, - dict(ge=None, le=11.32), - ), - ], - ) - async def test_list_current_user_dataset_records_with_metadata_filter( - self, - async_client: "AsyncClient", - mock_search_engine: SearchEngine, - owner: User, - owner_auth_header: dict, - property_config: dict, - param_value: str, - expected_filter_class: Type[MetadataFilter], - expected_filter_args: dict, - ): - workspace = await WorkspaceFactory.create() - dataset, _, records, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - - metadata_property = await MetadataPropertyFactory.create( - name=property_config["name"], - settings=property_config["settings"], - dataset=dataset, - ) - - mock_search_engine.search.return_value = SearchResponses( - total=2, - items=[ - SearchResponseItem(record_id=records[0].id, score=14.2), - SearchResponseItem(record_id=records[1].id, score=12.2), - ], - ) - - query_params = {"metadata": [f"{metadata_property.name}:{param_value}"]} - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", - params=query_params, - headers=owner_auth_header, - ) - assert response.status_code == 200 - - response_json = response.json() - assert response_json["total"] == 2 - - mock_search_engine.search.assert_called_once_with( - dataset=dataset, - query=None, - metadata_filters=[expected_filter_class(metadata_property=metadata_property, **expected_filter_args)], - user_response_status_filter=None, - offset=0, - limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=None, - user_id=owner.id, - ) - - @pytest.mark.skip(reason="Factory integration with search engine") - @pytest.mark.parametrize("response_status_filter", ["missing", "pending", "discarded", "submitted", "draft"]) - async def 
-        self, async_client: "AsyncClient", owner: "User", owner_auth_header: dict, response_status_filter: str
-    ):
-        num_responses_per_status = 10
-        response_values = {"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}}
-
-        dataset = await DatasetFactory.create()
-        # missing responses
-        await RecordFactory.create_batch(size=num_responses_per_status, dataset=dataset)
-        # discarded responses
-        await self.create_records_with_response(num_responses_per_status, dataset, owner, ResponseStatus.discarded)
-        # submitted responses
-        await self.create_records_with_response(
-            num_responses_per_status, dataset, owner, ResponseStatus.submitted, response_values
-        )
-        # drafted responses
-        await self.create_records_with_response(
-            num_responses_per_status, dataset, owner, ResponseStatus.draft, response_values
-        )
-
-        other_dataset = await DatasetFactory.create()
-        await RecordFactory.create_batch(size=2, dataset=other_dataset)
-
-        response = await async_client.get(
-            f"/api/v1/me/datasets/{dataset.id}/records?response_status={response_status_filter}&include=responses",
-            headers=owner_auth_header,
-        )
-
-        assert response.status_code == 200
-        response_json = response.json()
-
-        assert len(response_json["items"]) == num_responses_per_status
-
-        if response_status_filter in ["missing", "pending"]:
-            assert all([len(record["responses"]) == 0 for record in response_json["items"]])
-        else:
-            assert all(
-                [record["responses"][0]["status"] == response_status_filter for record in response_json["items"]]
-            )
-
-    @pytest.mark.parametrize(
-        "sorts",
-        [
-            [("inserted_at", None)],
-            [("inserted_at", "asc")],
-            [("inserted_at", "desc")],
-            [("updated_at", None)],
-            [("updated_at", "asc")],
-            [("updated_at", "desc")],
-            [("metadata.terms-metadata-property", None)],
-            [("metadata.terms-metadata-property", "asc")],
-            [("metadata.terms-metadata-property", "desc")],
-            [("inserted_at", "asc"), ("updated_at", "desc")],
-            [("inserted_at", "desc"), ("updated_at", "asc")],
-            [("inserted_at", "asc"), ("metadata.terms-metadata-property", "desc")],
-            [("inserted_at", "desc"), ("metadata.terms-metadata-property", "asc")],
-            [("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")],
-            [("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")],
-            [("inserted_at", "asc"), ("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")],
-            [("inserted_at", "desc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")],
-            [("inserted_at", "asc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")],
-        ],
-    )
-    async def test_list_current_user_dataset_records_with_sort_by(
-        self,
-        async_client: "AsyncClient",
-        mock_search_engine: SearchEngine,
-        owner: "User",
-        owner_auth_header: dict,
-        sorts: List[Tuple[str, Union[str, None]]],
-    ):
-        workspace = await WorkspaceFactory.create()
-        dataset, _, records, _, _ = await self.create_dataset_with_user_responses(owner, workspace)
-
-        expected_sorts_by = []
-        for field, order in sorts:
-            if field not in ("inserted_at", "updated_at"):
-                field = await TermsMetadataPropertyFactory.create(name=field.split(".")[-1], dataset=dataset)
-            expected_sorts_by.append(SortBy(field=field, order=order or "asc"))
-
-        mock_search_engine.search.return_value = SearchResponses(
-            total=2,
-            items=[
-                SearchResponseItem(record_id=records[0].id, score=14.2),
-                SearchResponseItem(record_id=records[1].id, score=12.2),
-            ],
-        )
-
-        query_params = {
-            "sort_by": [f"{field}:{order}" if order is not None else f"{field}:asc" for field, order in sorts]
-        }
-
-        response = await async_client.get(
-            f"/api/v1/me/datasets/{dataset.id}/records",
-            params=query_params,
-            headers=owner_auth_header,
-        )
-        assert response.status_code == 200
-        assert response.json()["total"] == 2
-
-        mock_search_engine.search.assert_called_once_with(
-            dataset=dataset,
-            query=None,
-            metadata_filters=[],
-            user_response_status_filter=None,
-            offset=0,
-            limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT,
-            sort_by=expected_sorts_by,
-            user_id=owner.id,
-        )
-
-    async def test_list_current_user_dataset_records_with_sort_by_with_wrong_sort_order_value(
-        self, async_client: "AsyncClient", owner_auth_header: dict
-    ):
-        dataset = await DatasetFactory.create()
-
-        response = await async_client.get(
-            f"/api/v1/me/datasets/{dataset.id}/records",
-            params={"sort_by": "inserted_at:wrong"},
-            headers=owner_auth_header,
-        )
-        assert response.status_code == 422
-        assert response.json() == {
-            "detail": "Provided sort order in 'sort_by' query param 'wrong' for field 'inserted_at' is not valid."
-        }
-
-    async def test_list_current_user_dataset_records_with_sort_by_with_non_existent_metadata_property(
-        self, async_client: "AsyncClient", owner_auth_header: dict
-    ):
-        dataset = await DatasetFactory.create()
-
-        response = await async_client.get(
-            f"/api/v1/me/datasets/{dataset.id}/records",
-            params={"sort_by": "metadata.i-do-not-exist:asc"},
-            headers=owner_auth_header,
-        )
-        assert response.status_code == 422
-        assert response.json() == {
-            "detail": f"Provided metadata property in 'sort_by' query param 'i-do-not-exist' not found in dataset with '{dataset.id}'."
-        }
-
-    async def test_list_current_user_dataset_records_with_sort_by_with_invalid_field(
-        self, async_client: "AsyncClient", owner: "User", owner_auth_header: dict
-    ):
-        workspace = await WorkspaceFactory.create()
-        dataset, _, _, _, _ = await self.create_dataset_with_user_responses(owner, workspace)
-
-        response = await async_client.get(
-            f"/api/v1/me/datasets/{dataset.id}/records",
-            params={"sort_by": "not-valid"},
-            headers=owner_auth_header,
-        )
-        assert response.status_code == 422
-        assert response.json() == {
-            "detail": "Provided sort field in 'sort_by' query param 'not-valid' is not valid. "
-            "It must be either 'inserted_at', 'updated_at' or `metadata.metadata-property-name`"
-        }
-
-    async def test_list_current_user_dataset_records_without_authentication(self, async_client: "AsyncClient"):
-        dataset = await DatasetFactory.create()
-
-        response = await async_client.get(f"/api/v1/me/datasets/{dataset.id}/records")
-
-        assert response.status_code == 401
-
-    @pytest.mark.skip(reason="Factory integration with search engine")
-    @pytest.mark.parametrize("role", [UserRole.admin, UserRole.annotator])
-    async def test_list_current_user_dataset_records_as_restricted_user(
-        self, async_client: "AsyncClient", role: UserRole
-    ):
-        workspace = await WorkspaceFactory.create()
-        user = await UserFactory.create(workspaces=[workspace], role=role)
-        dataset = await DatasetFactory.create(workspace=workspace)
-        record_a = await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset)
-        record_b = await RecordFactory.create(
-            fields={"record_b": "value_b"}, metadata_={"unit": "test"}, dataset=dataset
-        )
-        record_c = await RecordFactory.create(fields={"record_c": "value_c"}, dataset=dataset)
-        expected_records = [record_a, record_b, record_c]
-
-        other_dataset = await DatasetFactory.create()
-        await RecordFactory.create_batch(size=2, dataset=other_dataset)
-
-        response = await async_client.get(
-            f"/api/v1/me/datasets/{dataset.id}/records", headers={API_KEY_HEADER_NAME: user.api_key}
-        )
-
-        assert response.status_code == 200
-
-        response_items = response.json()["items"]
-
-        for expected_record in expected_records:
-            found_items = [item for item in response_items if item["id"] == str(expected_record.id)]
-            assert found_items, expected_record
-
-            assert found_items[0] == {
-                "id": str(expected_record.id),
-                "fields": expected_record.fields,
-                "metadata": expected_record.metadata_,
-                "external_id": expected_record.external_id,
-                "inserted_at": expected_record.inserted_at.isoformat(),
-                "updated_at": expected_record.updated_at.isoformat(),
-            }
-
-    @pytest.mark.parametrize("role", [UserRole.annotator, UserRole.admin])
-    async def test_list_current_user_dataset_records_as_restricted_user_from_different_workspace(
-        self, async_client: "AsyncClient", role: UserRole
-    ):
-        dataset = await DatasetFactory.create()
-        workspace = await WorkspaceFactory.create()
-        user = await UserFactory.create(workspaces=[workspace], role=role)
-
-        response = await async_client.get(
-            f"/api/v1/me/datasets/{dataset.id}/records", headers={API_KEY_HEADER_NAME: user.api_key}
-        )
-
-        assert response.status_code == 403
-
-    async def test_list_current_user_dataset_records_with_nonexistent_dataset_id(
-        self, async_client: "AsyncClient", owner_auth_header: dict
-    ):
-        dataset_id = uuid4()
-
-        await DatasetFactory.create()
-
-        response = await async_client.get(
-            f"/api/v1/me/datasets/{dataset_id}/records",
-            headers=owner_auth_header,
-        )
-
-        assert response.status_code == 404
-        assert response.json() == {"detail": f"Dataset with id `{dataset_id}` not found"}
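The deleted tests above are, in effect, the last documentation of the query-param grammar that the removed GET record-listing endpoints accepted. A minimal sketch of the shapes those params took, reconstructed from the parametrized cases above; the helper name is illustrative only and not part of the test suite:

def legacy_list_query_params() -> dict:
    # Shapes exercised by the deleted tests: `metadata` takes "name:value"
    # pairs (terms properties use comma-separated values, numeric properties
    # a JSON range object with "ge"/"le"); `sort_by` takes "field" or
    # "field:{asc,desc}" where the field may be "inserted_at", "updated_at"
    # or a "metadata."-prefixed property name.
    return {
        "metadata": ["terms_prop:value1,value2", 'integer_prop:{"ge": 10, "le": 20}'],
        "sort_by": ["inserted_at:desc", "metadata.terms-metadata-property:asc"],
        "response_status": "submitted",
        "include": "responses",
        "offset": 0,
        "limit": 1,
    }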
diff --git a/argilla-server/tests/unit/search_engine/test_commons.py b/argilla-server/tests/unit/search_engine/test_commons.py
index 6c927d42a2..f57c115492 100644
--- a/argilla-server/tests/unit/search_engine/test_commons.py
+++ b/argilla-server/tests/unit/search_engine/test_commons.py
@@ -16,18 +16,27 @@
 import pytest
 import pytest_asyncio
-from argilla_server.enums import MetadataPropertyType, QuestionType, ResponseStatusFilter, SimilarityOrder, RecordStatus
+from argilla_server.enums import (
+    MetadataPropertyType,
+    QuestionType,
+    ResponseStatusFilter,
+    SimilarityOrder,
+    RecordStatus,
+    SortOrder,
+)
 from argilla_server.models import Dataset, Question, Record, User, VectorSettings
 from argilla_server.search_engine import (
-    FloatMetadataFilter,
-    IntegerMetadataFilter,
     ResponseFilterScope,
     SortBy,
     SuggestionFilterScope,
     TermsFilter,
-    TermsMetadataFilter,
     TextQuery,
     UserResponseStatusFilter,
+    Filter,
+    MetadataFilterScope,
+    RangeFilter,
+    Order,
+    RecordFilterScope,
 )
 from argilla_server.search_engine.commons import (
     ALL_RESPONSES_STATUSES_FIELD,
@@ -595,7 +604,7 @@ async def test_search_with_response_status_filter(
     result = await search_engine.search(
         test_banking_sentiment_dataset,
         query=TextQuery(q="payment"),
-        user_response_status_filter=UserResponseStatusFilter(user=user, statuses=statuses),
+        filter=TermsFilter(scope=ResponseFilterScope(property="status"), values=statuses),
     )
     assert len(result.items) == expected_items
     assert result.total == expected_items
@@ -669,26 +678,26 @@ async def test_search_with_response_status_filter_with_no_user(
     result = await search_engine.search(
         test_banking_sentiment_dataset,
-        user_response_status_filter=UserResponseStatusFilter(statuses=statuses, user=None),
+        filter=TermsFilter(ResponseFilterScope(property="status"), values=statuses),
     )
     assert len(result.items) == expected_items
     assert result.total == expected_items
 
 @pytest.mark.parametrize(
-    ("metadata_filters_config", "expected_items"),
+    ("filter", "expected_items"),
     [
-        ([{"name": "label", "values": ["neutral"]}], 4),
-        ([{"name": "label", "values": ["positive"]}], 1),
-        ([{"name": "label", "values": ["neutral", "positive"]}], 5),
-        ([{"name": "textId", "ge": 3, "le": 4}], 2),
-        ([{"name": "textId", "ge": 3, "le": 3}], 1),
-        ([{"name": "textId", "ge": 3}], 6),
-        ([{"name": "textId", "le": 4}], 5),
-        ([{"name": "seq_float", "ge": 0.0, "le": 12.03}], 3),
-        ([{"name": "seq_float", "ge": 0.13, "le": 0.13}], 1),
-        ([{"name": "seq_float", "ge": 0.0}], 7),
-        ([{"name": "seq_float", "le": 12.03}], 5),
+        (TermsFilter(scope=MetadataFilterScope(metadata_property="label"), values=["neutral"]), 4),
+        (TermsFilter(scope=MetadataFilterScope(metadata_property="label"), values=["positive"]), 1),
+        (TermsFilter(scope=MetadataFilterScope(metadata_property="label"), values=["neutral", "positive"]), 5),
+        (RangeFilter(scope=MetadataFilterScope(metadata_property="textId"), ge=3, le=4), 2),
+        (RangeFilter(scope=MetadataFilterScope(metadata_property="textId"), ge=3, le=3), 1),
+        (RangeFilter(scope=MetadataFilterScope(metadata_property="textId"), ge=3), 6),
+        (RangeFilter(scope=MetadataFilterScope(metadata_property="textId"), le=4), 5),
+        (RangeFilter(scope=MetadataFilterScope(metadata_property="seq_float"), ge=0, le=12.03), 3),
+        (RangeFilter(scope=MetadataFilterScope(metadata_property="seq_float"), ge=0.13, le=0.13), 1),
+        (RangeFilter(scope=MetadataFilterScope(metadata_property="seq_float"), ge=0.0), 7),
+        (RangeFilter(scope=MetadataFilterScope(metadata_property="seq_float"), le=12.03), 5),
     ],
 )
 async def test_search_with_metadata_filter(
@@ -696,24 +705,10 @@ async def test_search_with_metadata_filter(
     search_engine: BaseElasticAndOpenSearchEngine,
     opensearch: OpenSearch,
     test_banking_sentiment_dataset: Dataset,
-    metadata_filters_config: List[dict],
+    filter: Filter,
     expected_items: int,
 ):
-    metadata_filters = []
-    for metadata_filter_config in metadata_filters_config:
-        name = metadata_filter_config.pop("name")
-        for metadata_property in test_banking_sentiment_dataset.metadata_properties:
-            if name == metadata_property.name:
-                if metadata_property.type == MetadataPropertyType.terms:
-                    filter_cls = TermsMetadataFilter
-                elif metadata_property.type == MetadataPropertyType.integer:
-                    filter_cls = IntegerMetadataFilter
-                else:
-                    filter_cls = FloatMetadataFilter
-                metadata_filters.append(filter_cls(metadata_property=metadata_property, **metadata_filter_config))
-                break
-
-    result = await search_engine.search(test_banking_sentiment_dataset, metadata_filters=metadata_filters)
+    result = await search_engine.search(test_banking_sentiment_dataset, filter=filter)
     assert len(result.items) == expected_items
     assert result.total == expected_items
@@ -748,7 +743,7 @@ async def test_search_with_response_status_filter_does_not_affect_the_result_sco
     results = await search_engine.search(
         test_banking_sentiment_dataset,
         query=TextQuery(q="payment"),
-        user_response_status_filter=UserResponseStatusFilter(user=user, statuses=all_statuses),
+        filter=TermsFilter(scope=ResponseFilterScope(property="status", user=user), values=all_statuses),
     )
 
     assert len(no_filter_results.items) == len(results.items)
@@ -834,12 +829,12 @@ async def test_search_with_pagination(
     assert all_results.items[offset : offset + limit] == results.items
 
 @pytest.mark.parametrize(
-    ("sort_by"),
+    ("sort_order"),
     [
-        SortBy(field="inserted_at"),
-        SortBy(field="updated_at"),
-        SortBy(field="inserted_at", order="desc"),
-        SortBy(field="updated_at", order="desc"),
+        Order(scope=RecordFilterScope(property="inserted_at"), order=SortOrder.asc),
+        Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.asc),
+        Order(scope=RecordFilterScope(property="inserted_at"), order=SortOrder.desc),
+        Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.desc),
     ],
 )
 async def test_search_with_sort_by(
@@ -847,18 +842,15 @@ async def test_search_with_sort_by(
     search_engine: BaseElasticAndOpenSearchEngine,
     opensearch: OpenSearch,
     test_banking_sentiment_dataset: Dataset,
-    sort_by: SortBy,
+    sort_order: Order,
 ):
     def _local_sort_by(record: Record) -> Any:
-        if isinstance(sort_by.field, str):
-            return getattr(record, sort_by.field)
-        return record.metadata_[sort_by.field.name]
+        return getattr(record, sort_order.scope.property)
 
-    results = await search_engine.search(test_banking_sentiment_dataset, sort_by=[sort_by])
+    results = await search_engine.search(test_banking_sentiment_dataset, sort=[sort_order])
 
     records = test_banking_sentiment_dataset.records
-    if sort_by:
-        records = sorted(records, key=_local_sort_by, reverse=sort_by.order == "desc")
+    records = sorted(records, key=_local_sort_by, reverse=sort_order.order == "desc")
 
     assert [item.record_id for item in results.items] == [record.id for record in records]
@@ -1348,32 +1340,34 @@ async def test_similarity_search_by_vector_value_with_order(
     assert responses.items[0].record_id != selected_record.id
 
     @pytest.mark.parametrize(
-        "user_response_status_filter",
+        "statuses",
         [
-            None,
-            UserResponseStatusFilter(statuses=[ResponseStatusFilter.missing, ResponseStatusFilter.draft]),
+            [],
+            [ResponseStatusFilter.missing, ResponseStatusFilter.draft],
         ],
    )
-    async def test_similarity_search_by_record_and_user_response_filter(
+    async def test_similarity_search_by_record_and_response_status_filter(
        self,
        search_engine: BaseElasticAndOpenSearchEngine,
        opensearch: OpenSearch,
        test_banking_sentiment_dataset_with_vectors: Dataset,
-        user_response_status_filter: UserResponseStatusFilter,
+        statuses: List[ResponseStatusFilter],
    ):
        selected_record: Record = test_banking_sentiment_dataset_with_vectors.records[0]
        vector_settings: VectorSettings = test_banking_sentiment_dataset_with_vectors.vectors_settings[0]
-        if user_response_status_filter:
+        scope = ResponseFilterScope(property="status")
+
+        if statuses:
             test_user = await UserFactory.create()
-            user_response_status_filter.user = test_user
+            scope.user = test_user
 
         responses = await search_engine.similarity_search(
             dataset=test_banking_sentiment_dataset_with_vectors,
             vector_settings=vector_settings,
             record=selected_record,
             max_results=1,
-            user_response_status_filter=user_response_status_filter,
+            filter=TermsFilter(scope=scope, values=statuses),
         )
 
         assert responses.total == 1
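Taken together, these hunks migrate the engine tests from the old per-concern keywords (`metadata_filters`, `user_response_status_filter`, `sort_by`) to a unified `filter=` / `sort=` interface: every condition is a `Filter` (`TermsFilter`, `RangeFilter`) bound to a scope (`MetadataFilterScope`, `ResponseFilterScope`, `RecordFilterScope`), and ordering is expressed with `Order`. A sketch of how a caller might combine the pieces follows; the `AndFilter(filters=[...])` composition is an assumption based on the handler imports in this PR rather than a call shown in these tests:

from argilla_server.enums import ResponseStatusFilter, SortOrder
from argilla_server.search_engine import (
    AndFilter,  # assumed composition, per the handler imports in this PR
    MetadataFilterScope,
    Order,
    RangeFilter,
    RecordFilterScope,
    ResponseFilterScope,
    TermsFilter,
)

async def search_submitted_in_range(search_engine, dataset):
    # "textId" metadata between 3 and 4 AND at least one submitted response,
    # newest records first -- mirrors the RangeFilter/TermsFilter/Order
    # combinations exercised individually in the tests above.
    combined = AndFilter(
        filters=[
            RangeFilter(scope=MetadataFilterScope(metadata_property="textId"), ge=3, le=4),
            TermsFilter(scope=ResponseFilterScope(property="status"), values=[ResponseStatusFilter.submitted]),
        ]
    )
    sort = [Order(scope=RecordFilterScope(property="inserted_at"), order=SortOrder.desc)]
    return await search_engine.search(dataset, filter=combined, sort=sort)

The design choice visible in the diff is that scopes carry the "where" (which metadata property, whose response status, which record column) while filters and orders carry the "what", so the same two filter classes cover every case the eleven removed keyword combinations used to handle.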