Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BREAKING - REFACTOR] argilla-server: remove user response status support #5163

Merged
1 change: 1 addition & 0 deletions argilla-server/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ These are the section headers that we use:
### Removed

- Removed `GET /api/v1/me/datasets/:dataset_id/records` endpoint. ([#5153](https://github.com/argilla-io/argilla/pull/5153))
- [breaking] Removed support for `response_status` query param. ([#5163](https://github.com/argilla-io/argilla/pull/5163))
- [breaking] Removed support for `metadata` query param. ([#5156](https://github.com/argilla-io/argilla/pull/5156))

## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ async def _filter_records_using_search_engine(
limit: int,
offset: int,
user: Optional[User] = None,
response_statuses: Optional[List[ResponseStatusFilter]] = None,
include: Optional[RecordIncludeParam] = None,
sort_by_query_param: Optional[Dict[str, str]] = None,
) -> Tuple[List[Record], int]:
Expand All @@ -114,7 +113,6 @@ async def _filter_records_using_search_engine(
limit=limit,
offset=offset,
user=user,
response_statuses=response_statuses,
sort_by_query_param=sort_by_query_param,
)

Expand Down Expand Up @@ -178,7 +176,6 @@ async def _get_search_responses(
offset: int,
search_records_query: Optional[SearchRecordsQuery] = None,
user: Optional[User] = None,
response_statuses: Optional[List[ResponseStatusFilter]] = None,
sort_by_query_param: Optional[Dict[str, str]] = None,
) -> "SearchResponses":
search_records_query = search_records_query or SearchRecordsQuery()
Expand Down Expand Up @@ -219,7 +216,6 @@ async def _get_search_responses(
if text_query and text_query.field and not await Field.get_by(db, name=text_query.field, dataset_id=dataset.id):
raise UnprocessableEntityError(f"Field `{text_query.field}` not found in dataset `{dataset.id}`.")

response_status_filter = await _build_response_status_filter_for_search(response_statuses, user=user)
sort_by = await _build_sort_by(db, dataset, sort_by_query_param)

if vector_query and vector_settings:
Expand All @@ -230,7 +226,6 @@ async def _get_search_responses(
"record": record,
"query": text_query,
"order": vector_query.order,
"user_response_status_filter": response_status_filter,
"max_results": limit,
}

Expand All @@ -242,7 +237,6 @@ async def _get_search_responses(
search_params = {
"dataset": dataset,
"query": text_query,
"user_response_status_filter": response_status_filter,
"offset": offset,
"limit": limit,
"sort_by": sort_by,
Expand Down Expand Up @@ -323,7 +317,6 @@ async def list_dataset_records(
dataset_id: UUID,
sort_by_query_param: SortByQueryParamParsed,
include: Optional[RecordIncludeParam] = Depends(parse_record_include_param),
response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"),
offset: int = 0,
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE),
current_user: User = Security(auth.get_current_user),
Expand All @@ -338,7 +331,6 @@ async def list_dataset_records(
dataset=dataset,
limit=limit,
offset=offset,
response_statuses=response_statuses,
include=include,
sort_by_query_param=sort_by_query_param or LIST_DATASET_RECORDS_DEFAULT_SORT_BY,
)
Expand Down Expand Up @@ -451,7 +443,6 @@ async def search_current_user_dataset_records(
body: SearchRecordsQuery,
sort_by_query_param: SortByQueryParamParsed,
include: Optional[RecordIncludeParam] = Depends(parse_record_include_param),
response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"),
offset: int = Query(0, ge=0),
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE),
current_user: User = Security(auth.get_current_user),
Expand All @@ -477,7 +468,6 @@ async def search_current_user_dataset_records(
limit=limit,
offset=offset,
user=current_user,
response_statuses=response_statuses,
sort_by_query_param=sort_by_query_param,
)

Expand Down Expand Up @@ -523,7 +513,6 @@ async def search_dataset_records(
body: SearchRecordsQuery,
sort_by_query_param: SortByQueryParamParsed,
include: Optional[RecordIncludeParam] = Depends(parse_record_include_param),
response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"),
offset: int = Query(0, ge=0),
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE),
current_user: User = Security(auth.get_current_user),
Expand All @@ -541,7 +530,6 @@ async def search_dataset_records(
search_records_query=body,
limit=limit,
offset=offset,
response_statuses=response_statuses,
sort_by_query_param=sort_by_query_param,
)

Expand Down
4 changes: 0 additions & 4 deletions argilla-server/src/argilla_server/search_engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,6 @@ async def search(
filter: Optional[Filter] = None,
sort: Optional[List[Order]] = None,
# TODO: remove them and keep filter and order
user_response_status_filter: Optional[UserResponseStatusFilter] = None,
sort_by: Optional[List[SortBy]] = None,
# END TODO
offset: int = 0,
Expand Down Expand Up @@ -311,9 +310,6 @@ async def similarity_search(
record: Optional[Record] = None,
query: Optional[Union[TextQuery, str]] = None,
filter: Optional[Filter] = None,
# TODO: remove them and keep filter
user_response_status_filter: Optional[UserResponseStatusFilter] = None,
# END TODO
max_results: int = 100,
order: SimilarityOrder = SimilarityOrder.most_similar,
threshold: Optional[float] = None,
Expand Down
25 changes: 0 additions & 25 deletions argilla-server/src/argilla_server/search_engine/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,19 +199,6 @@ def es_path_for_vector_settings(vector_settings: VectorSettings) -> str:
return str(vector_settings.id)


# This function will be moved once the response status filter is removed from search and similarity_search methods
def _unify_user_response_status_filter_with_filter(
user_response_status_filter: UserResponseStatusFilter, filter: Optional[Filter] = None
) -> Filter:
scope = ResponseFilterScope(user=user_response_status_filter.user, property="status")
response_filter = TermsFilter(scope=scope, values=[status.value for status in user_response_status_filter.statuses])

if filter:
return AndFilter(filters=[filter, response_filter])
else:
return response_filter


# This function will be moved once the `sort_by` argument is removed from search and similarity_search methods
def _unify_sort_by_with_order(sort_by: List[SortBy], order: List[Order]) -> List[Order]:
if order:
Expand Down Expand Up @@ -393,18 +380,10 @@ async def similarity_search(
record: Optional[Record] = None,
query: Optional[Union[TextQuery, str]] = None,
filter: Optional[Filter] = None,
# TODO: remove them and keep filter
user_response_status_filter: Optional[UserResponseStatusFilter] = None,
# END TODO
max_results: int = 100,
order: SimilarityOrder = SimilarityOrder.most_similar,
threshold: Optional[float] = None,
) -> SearchResponses:
# TODO: This block will be moved (maybe to contexts/search.py), and only filter and order arguments will be kept
if user_response_status_filter and user_response_status_filter.statuses:
filter = _unify_user_response_status_filter_with_filter(user_response_status_filter, filter)
# END TODO

if bool(value) == bool(record):
raise ValueError("Must provide either vector value or record to compute the similarity search")

Expand Down Expand Up @@ -598,7 +577,6 @@ async def search(
filter: Optional[Filter] = None,
sort: Optional[List[Order]] = None,
# TODO: Remove these arguments
user_response_status_filter: Optional[UserResponseStatusFilter] = None,
sort_by: Optional[List[SortBy]] = None,
# END TODO
offset: int = 0,
Expand All @@ -608,9 +586,6 @@ async def search(
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/search-search.html

# TODO: This block will be moved (maybe to contexts/search.py), and only filter and order arguments will be kept
if user_response_status_filter and user_response_status_filter.statuses:
filter = _unify_user_response_status_filter_with_filter(user_response_status_filter, filter)

if sort_by:
sort = _unify_sort_by_with_order(sort_by, sort)
# END TODO
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,6 @@ async def test_with_filter(
limit=50,
query=None,
sort_by=None,
user_response_status_filter=None,
)

async def test_with_sort(
Expand Down Expand Up @@ -370,7 +369,6 @@ async def test_with_sort(
limit=50,
query=None,
sort_by=None,
user_response_status_filter=None,
)

async def test_with_invalid_filter(self, async_client: AsyncClient, owner_auth_header: dict):
Expand Down
38 changes: 23 additions & 15 deletions argilla-server/tests/unit/api/handlers/v1/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from uuid import UUID, uuid4

import pytest
from sqlalchemy import func, inspect, select

from argilla_server.api.handlers.v1.datasets.records import LIST_DATASET_RECORDS_LIMIT_DEFAULT
from argilla_server.api.schemas.v1.datasets import DATASET_GUIDELINES_MAX_LENGTH, DATASET_NAME_MAX_LENGTH
from argilla_server.api.schemas.v1.fields import FIELD_CREATE_NAME_MAX_LENGTH, FIELD_CREATE_TITLE_MAX_LENGTH
Expand Down Expand Up @@ -62,14 +64,12 @@
SearchResponses,
SortBy,
TextQuery,
UserResponseStatusFilter,
AndFilter,
TermsFilter,
MetadataFilterScope,
RangeFilter,
ResponseFilterScope,
)
from sqlalchemy import func, inspect, select

from tests.factories import (
AdminFactory,
AnnotatorFactory,
Expand All @@ -80,7 +80,6 @@
LabelSelectionQuestionFactory,
MetadataPropertyFactory,
MultiLabelSelectionQuestionFactory,
OwnerFactory,
QuestionFactory,
RatingQuestionFactory,
RecordFactory,
Expand Down Expand Up @@ -3650,7 +3649,6 @@ async def test_search_current_user_dataset_records(
mock_search_engine.search.assert_called_once_with(
dataset=dataset,
query=TextQuery(q="Hello", field="input"),
user_response_status_filter=None,
offset=0,
limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT,
sort_by=None,
Expand Down Expand Up @@ -3811,7 +3809,6 @@ async def test_search_current_user_dataset_records_with_metadata_filter(
dataset=dataset,
query=TextQuery(q="Hello", field="input"),
filter=AndFilter(filters=[expected_filter]),
user_response_status_filter=None,
offset=0,
limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT,
sort_by=None,
Expand Down Expand Up @@ -3884,7 +3881,6 @@ async def test_search_current_user_dataset_records_with_sort_by(
mock_search_engine.search.assert_called_once_with(
dataset=dataset,
query=TextQuery(q="Hello", field="input"),
user_response_status_filter=None,
offset=0,
limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT,
sort_by=expected_sorts_by,
Expand Down Expand Up @@ -4090,7 +4086,6 @@ async def test_search_current_user_dataset_records_with_include(
dataset=dataset,
query=TextQuery(q="Hello", field="input"),
sort_by=None,
user_response_status_filter=None,
offset=0,
limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT,
user_id=owner.id,
Expand Down Expand Up @@ -4293,18 +4288,35 @@ async def test_search_current_user_dataset_records_with_response_status_filter(
dataset, *_ = await self.create_dataset_with_user_responses(owner, workspace)
mock_search_engine.search.return_value = SearchResponses(items=[])

query_json = {"query": {"text": {"q": "Hello", "field": "input"}}}
query_json = {
"query": {"text": {"q": "Hello", "field": "input"}},
"filters": {
"and": [
{
"type": "terms",
"scope": {"entity": "response", "property": "status"},
"values": [ResponseStatus.submitted.value],
frascuchon marked this conversation as resolved.
Show resolved Hide resolved
}
]
},
}
response = await async_client.post(
f"/api/v1/me/datasets/{dataset.id}/records/search",
headers=owner_auth_header,
json=query_json,
params={"response_status": ResponseStatus.submitted.value},
)

mock_search_engine.search.assert_called_once_with(
dataset=dataset,
query=TextQuery(q="Hello", field="input"),
user_response_status_filter=UserResponseStatusFilter(user=owner, statuses=[ResponseStatusFilter.submitted]),
filter=AndFilter(
filters=[
TermsFilter(
scope=ResponseFilterScope(property="status", user=owner),
values=[ResponseStatusFilter.submitted],
)
]
),
offset=0,
limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT,
sort_by=None,
Expand Down Expand Up @@ -4350,7 +4362,6 @@ async def test_search_current_user_dataset_records_with_record_vector(
query=None,
order=SimilarityOrder.most_similar,
max_results=5,
user_response_status_filter=None,
)

async def test_search_current_user_dataset_records_with_vector_value(
Expand Down Expand Up @@ -4393,7 +4404,6 @@ async def test_search_current_user_dataset_records_with_vector_value(
query=None,
order=SimilarityOrder.most_similar,
max_results=10,
user_response_status_filter=None,
)

async def test_search_current_user_dataset_records_with_vector_value_and_query(
Expand Down Expand Up @@ -4441,7 +4451,6 @@ async def test_search_current_user_dataset_records_with_vector_value_and_query(
query=TextQuery(q="Test query"),
order=SimilarityOrder.most_similar,
max_results=10,
user_response_status_filter=None,
)

async def test_search_current_user_dataset_records_with_wrong_vector(
Expand Down Expand Up @@ -4533,7 +4542,6 @@ async def test_search_current_user_dataset_records_with_offset_and_limit(
mock_search_engine.search.assert_called_once_with(
dataset=dataset,
query=TextQuery(q="Hello", field="input"),
user_response_status_filter=None,
offset=0,
limit=5,
sort_by=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,6 @@ async def test_list_dataset_records_with_sort_by(
mock_search_engine.search.assert_called_once_with(
dataset=dataset,
query=None,
user_response_status_filter=None,
offset=0,
limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT,
sort_by=expected_sorts_by,
Expand Down
Loading