Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BREAKING - REFACTOR] argilla-server: remove user response status support #5163

Merged
Prev Previous commit
Next Next commit
[breaking] refactor: Remove metadata query param filter for records endpoint
frascuchon committed Jul 4, 2024
commit 93a7b92d3476628c0227a1336945c24cf740cbc6
Original file line number Diff line number Diff line change
@@ -27,8 +27,6 @@
Filters,
FilterScope,
MetadataFilterScope,
MetadataParsedQueryParam,
MetadataQueryParams,
Order,
RangeFilter,
RecordFilterScope,
@@ -51,19 +49,15 @@
)
from argilla_server.contexts import datasets, search
from argilla_server.database import get_async_db
from argilla_server.enums import MetadataPropertyType, RecordSortField, ResponseStatusFilter, SortOrder
from argilla_server.enums import RecordSortField, ResponseStatusFilter, SortOrder
from argilla_server.errors.future import MissingVectorError, NotFoundError, UnprocessableEntityError
from argilla_server.errors.future.base_errors import MISSING_VECTOR_ERROR_CODE
from argilla_server.models import Dataset, Field, MetadataProperty, Record, User, VectorSettings
from argilla_server.search_engine import (
AndFilter,
FloatMetadataFilter,
IntegerMetadataFilter,
MetadataFilter,
SearchEngine,
SearchResponses,
SortBy,
TermsMetadataFilter,
UserResponseStatusFilter,
get_search_engine,
)
@@ -106,7 +100,6 @@ async def _filter_records_using_search_engine(
db: "AsyncSession",
search_engine: "SearchEngine",
dataset: Dataset,
parsed_metadata: List[MetadataParsedQueryParam],
limit: int,
offset: int,
user: Optional[User] = None,
@@ -121,7 +114,6 @@ async def _filter_records_using_search_engine(
limit=limit,
offset=offset,
user=user,
parsed_metadata=parsed_metadata,
response_statuses=response_statuses,
sort_by_query_param=sort_by_query_param,
)
@@ -182,7 +174,6 @@ async def _get_search_responses(
db: "AsyncSession",
search_engine: "SearchEngine",
dataset: Dataset,
parsed_metadata: List[MetadataParsedQueryParam],
limit: int,
offset: int,
search_records_query: Optional[SearchRecordsQuery] = None,
@@ -228,7 +219,6 @@ async def _get_search_responses(
if text_query and text_query.field and not await Field.get_by(db, name=text_query.field, dataset_id=dataset.id):
raise UnprocessableEntityError(f"Field `{text_query.field}` not found in dataset `{dataset.id}`.")

metadata_filters = await _build_metadata_filters(db, dataset, parsed_metadata)
response_status_filter = await _build_response_status_filter_for_search(response_statuses, user=user)
sort_by = await _build_sort_by(db, dataset, sort_by_query_param)

@@ -240,7 +230,6 @@ async def _get_search_responses(
"record": record,
"query": text_query,
"order": vector_query.order,
"metadata_filters": metadata_filters,
"user_response_status_filter": response_status_filter,
"max_results": limit,
}
@@ -253,7 +242,6 @@ async def _get_search_responses(
search_params = {
"dataset": dataset,
"query": text_query,
"metadata_filters": metadata_filters,
"user_response_status_filter": response_status_filter,
"offset": offset,
"limit": limit,
@@ -271,32 +259,6 @@ async def _get_search_responses(
return await search_engine.search(**search_params)


async def _build_metadata_filters(
    db: "AsyncSession", dataset: Dataset, parsed_metadata: List[MetadataParsedQueryParam]
) -> List["MetadataFilter"]:
    """Translate parsed `metadata` query params into search-engine metadata filters.

    Each parsed param is matched by name against the metadata properties of
    `dataset`. Params whose name has no matching property are skipped rather
    than rejected.

    Args:
        db: Session used to look up the dataset's metadata properties.
        dataset: Dataset whose id scopes the metadata property lookup.
        parsed_metadata: Parsed query params; each one is read for its
            `name` and `value` attributes.

    Returns:
        One filter per param that matched a metadata property, in input order.

    Raises:
        UnprocessableEntityError: If a matched property has a type with no
            associated filter class, or if building a filter raises
            `ValueError` / `UnprocessableEntityError`.
    """
    try:
        metadata_filters = []
        for metadata_param in parsed_metadata:
            metadata_property = await MetadataProperty.get_by(db, name=metadata_param.name, dataset_id=dataset.id)
            if metadata_property is None:
                continue  # won't fail on unknown metadata filter name

            if metadata_property.type == MetadataPropertyType.terms:
                metadata_filter_class = TermsMetadataFilter
            elif metadata_property.type == MetadataPropertyType.integer:
                metadata_filter_class = IntegerMetadataFilter
            elif metadata_property.type == MetadataPropertyType.float:
                metadata_filter_class = FloatMetadataFilter
            else:
                raise ValueError(f"Not found filter for type {metadata_property.type}")

            metadata_filters.append(metadata_filter_class.from_string(metadata_property, metadata_param.value))
    except (UnprocessableEntityError, ValueError) as ex:
        # Chain explicitly so the underlying parsing failure is kept as the cause
        # instead of being reported as an error raised while handling another one.
        raise UnprocessableEntityError(f"Cannot parse provided metadata filters: {ex}") from ex

    return metadata_filters


async def _build_response_status_filter_for_search(
response_statuses: Optional[List[ResponseStatusFilter]] = None, user: Optional[User] = None
) -> Optional[UserResponseStatusFilter]:
@@ -359,7 +321,6 @@ async def list_dataset_records(
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
dataset_id: UUID,
metadata: MetadataQueryParams = Depends(),
sort_by_query_param: SortByQueryParamParsed,
include: Optional[RecordIncludeParam] = Depends(parse_record_include_param),
response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"),
@@ -375,7 +336,6 @@ async def list_dataset_records(
db,
search_engine,
dataset=dataset,
parsed_metadata=metadata.metadata_parsed,
limit=limit,
offset=offset,
response_statuses=response_statuses,
@@ -489,7 +449,6 @@ async def search_current_user_dataset_records(
telemetry_client: TelemetryClient = Depends(get_telemetry_client),
dataset_id: UUID,
body: SearchRecordsQuery,
metadata: MetadataQueryParams = Depends(),
sort_by_query_param: SortByQueryParamParsed,
include: Optional[RecordIncludeParam] = Depends(parse_record_include_param),
response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"),
@@ -515,7 +474,6 @@ async def search_current_user_dataset_records(
search_engine=search_engine,
dataset=dataset,
search_records_query=body,
parsed_metadata=metadata.metadata_parsed,
limit=limit,
offset=offset,
user=current_user,
@@ -563,7 +521,6 @@ async def search_dataset_records(
search_engine: SearchEngine = Depends(get_search_engine),
dataset_id: UUID,
body: SearchRecordsQuery,
metadata: MetadataQueryParams = Depends(),
sort_by_query_param: SortByQueryParamParsed,
include: Optional[RecordIncludeParam] = Depends(parse_record_include_param),
response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"),
@@ -584,7 +541,6 @@ async def search_dataset_records(
search_records_query=body,
limit=limit,
offset=offset,
parsed_metadata=metadata.metadata_parsed,
response_statuses=response_statuses,
sort_by_query_param=sort_by_query_param,
)
12 changes: 0 additions & 12 deletions argilla-server/src/argilla_server/api/schemas/v1/records.py
Original file line number Diff line number Diff line change
@@ -13,12 +13,9 @@
# limitations under the License.

from datetime import datetime

from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from uuid import UUID

import fastapi

from argilla_server.api.schemas.v1.commons import UpdateSchema
from argilla_server.api.schemas.v1.metadata_properties import MetadataPropertyName
from argilla_server.api.schemas.v1.responses import Response, ResponseFilterScope, UserResponseCreate
@@ -223,15 +220,6 @@ def __init__(self, string: str):
self.value: str = "".join(v).strip()


class MetadataQueryParams(BaseModel):
    """Container for the raw `metadata` query-string values of a request.

    Every raw entry is validated against the regex pattern declared on the
    `metadata` field; the default is an empty list.
    """

    metadata: List[str] = Field(fastapi.Query([], pattern=r"^(?=.*[a-z0-9])[a-z0-9_-]+:(.+(,(.+))*)$"))

    @property
    def metadata_parsed(self) -> List[MetadataParsedQueryParam]:
        """Wrap each raw `metadata` string in a `MetadataParsedQueryParam`, keeping order."""
        # TODO: Validate metadata fields names from query params
        parsed_params: List[MetadataParsedQueryParam] = []
        for raw_entry in self.metadata:
            parsed_params.append(MetadataParsedQueryParam(raw_entry))
        return parsed_params


class VectorQuery(BaseModel):
name: str
record_id: Optional[UUID] = None