Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add dataset support to be created using distribution settings #5013

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,8 @@ export class RecordRepository {
constructor(private readonly axios: NuxtAxiosInstance) {}

getRecords(criteria: RecordCriteria): Promise<BackendRecords> {
if (criteria.isFilteringByAdvanceSearch)
return this.getRecordsByAdvanceSearch(criteria);

return this.getRecordsByDatasetId(criteria);
return this.getRecordsByAdvanceSearch(criteria);
// return this.getRecordsByDatasetId(criteria);
}

async getRecord(recordId: string): Promise<BackendRecord> {
Expand Down Expand Up @@ -264,6 +262,30 @@ export class RecordRepository {
};
}

body.filters = {
and: [
{
type: "terms",
scope: {
entity: "response",
property: "status",
},
values: [status],
},
],
};

if (status === "pending") {
body.filters.and.push({
type: "terms",
scope: {
entity: "record",
property: "status",
},
values: ["pending"],
});
}

if (
isFilteringByMetadata ||
isFilteringByResponse ||
Expand Down
7 changes: 6 additions & 1 deletion argilla-server/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,17 @@ These are the section headers that we use:

## [Unreleased]()

### Added

- Added support to specify `distribution` attribute when creating a dataset. ([#5013](https://github.com/argilla-io/argilla/pull/5013))
- Added support to change `distribution` attribute when updating a dataset. ([#5028](https://github.com/argilla-io/argilla/pull/5028))

### Changed

- Changed `responses` table to delete rows on cascade when a user is deleted. ([#5126](https://github.com/argilla-io/argilla/pull/5126))

## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1)

### Removed

- Removed all API v0 endpoints. ([#4852](https://github.com/argilla-io/argilla/pull/4852))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""add status column to records table

Revision ID: 237f7c674d74
Revises: 45a12f74448b
Create Date: 2024-06-18 17:59:36.992165

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "237f7c674d74"
down_revision = "45a12f74448b"
branch_labels = None
depends_on = None


record_status_enum = sa.Enum("pending", "completed", name="record_status_enum")


def upgrade() -> None:
    """Add an indexed, non-nullable ``status`` column to the ``records`` table.

    The enum type is created explicitly before the column that uses it is
    added. Every row starts as ``pending`` through the server default, and
    records that already have at least one ``submitted`` response are then
    switched to ``completed``.
    """
    bind = op.get_bind()
    record_status_enum.create(bind)

    status_column = sa.Column(
        "status",
        record_status_enum,
        server_default="pending",
        nullable=False,
    )
    op.add_column("records", status_column)

    status_index_name = op.f("ix_records_status")
    op.create_index(status_index_name, "records", ["status"], unique=False)

    # NOTE: Backfill existing rows so that records with at least one
    # "submitted" response end up with "completed" status.
    backfill_completed_sql = (
        "\n"
        "UPDATE records\n"
        "SET status = 'completed'\n"
        "WHERE id IN (\n"
        "SELECT DISTINCT record_id\n"
        "FROM responses\n"
        "WHERE status = 'submitted'\n"
        ");\n"
    )
    op.execute(backfill_completed_sql)


def downgrade() -> None:
    """Drop the ``status`` column from ``records`` along with its index and enum type."""
    status_index_name = op.f("ix_records_status")
    op.drop_index(status_index_name, table_name="records")
    op.drop_column("records", "status")

    # The enum type is dropped last, once the column using it is gone.
    bind = op.get_bind()
    record_status_enum.drop(bind)
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""add record metadata column
"""add metadata column to records table

Revision ID: 3ff6484f8b37
Revises: ae5522b4c674
Expand All @@ -31,12 +31,8 @@


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("records", sa.Column("metadata", sa.JSON(), nullable=True))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("records", "metadata")
# ### end Alembic commands ###
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""add distribution column to datasets table

Revision ID: 45a12f74448b
Revises: d00f819ccc67
Create Date: 2024-06-13 11:23:43.395093

"""

import json

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "45a12f74448b"
down_revision = "d00f819ccc67"
branch_labels = None
depends_on = None

DISTRIBUTION_VALUE = json.dumps({"strategy": "overlap", "min_submitted": 1})


def upgrade() -> None:
    """Add a non-nullable ``distribution`` JSON column to the ``datasets`` table.

    The column is first added as nullable, every existing row is filled with
    ``DISTRIBUTION_VALUE``, and only then is the column altered to be
    non-nullable.
    """
    distribution_column = sa.Column("distribution", sa.JSON(), nullable=True)
    op.add_column("datasets", distribution_column)

    # DISTRIBUTION_VALUE is a module-level constant, not user input.
    backfill_statement = f"UPDATE datasets SET distribution = '{DISTRIBUTION_VALUE}'"
    op.execute(backfill_statement)

    with op.batch_alter_table("datasets") as datasets_batch:
        datasets_batch.alter_column("distribution", nullable=False)


def downgrade() -> None:
    """Remove the ``distribution`` column from the ``datasets`` table."""
    op.drop_column(table_name="datasets", column_name="distribution")
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""add allow_extra_metadata column to dataset table
"""add allow_extra_metadata column to datasets table

Revision ID: b8458008b60e
Revises: 7cbcccf8b57a
Expand All @@ -31,14 +31,10 @@


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"datasets", sa.Column("allow_extra_metadata", sa.Boolean(), server_default=sa.text("true"), nullable=False)
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("datasets", "allow_extra_metadata")
# ### end Alembic commands ###
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ async def create_dataset(
):
await authorize(current_user, DatasetPolicy.create(dataset_create.workspace_id))

return await datasets.create_dataset(db, dataset_create)
return await datasets.create_dataset(db, dataset_create.dict())


@router.post("/datasets/{dataset_id}/fields", status_code=status.HTTP_201_CREATED, response_model=Field)
Expand Down Expand Up @@ -302,4 +302,4 @@ async def update_dataset(

await authorize(current_user, DatasetPolicy.update(dataset))

return await datasets.update_dataset(db, dataset, dataset_update)
return await datasets.update_dataset(db, dataset, dataset_update.dict(exclude_unset=True))
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ async def update_response(
response = await Response.get_or_raise(
db,
response_id,
options=[selectinload(Response.record).selectinload(Record.dataset).selectinload(Dataset.questions)],
options=[
selectinload(Response.record).selectinload(Record.dataset).selectinload(Dataset.questions),
],
)

await authorize(current_user, ResponsePolicy.update(response))
Expand All @@ -83,7 +85,9 @@ async def delete_response(
response = await Response.get_or_raise(
db,
response_id,
options=[selectinload(Response.record).selectinload(Record.dataset).selectinload(Dataset.questions)],
options=[
selectinload(Response.record).selectinload(Record.dataset).selectinload(Dataset.questions),
],
)

await authorize(current_user, ResponsePolicy.delete(response))
Expand Down
38 changes: 35 additions & 3 deletions argilla-server/src/argilla_server/api/schemas/v1/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
# limitations under the License.

from datetime import datetime
from typing import List, Optional
from typing import List, Literal, Optional, Union
from uuid import UUID

from argilla_server.api.schemas.v1.commons import UpdateSchema
from argilla_server.enums import DatasetStatus
from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus
from argilla_server.pydantic_v1 import BaseModel, Field, constr

try:
Expand All @@ -44,6 +44,32 @@
]


# Overlap distribution settings; used (through the `DatasetDistribution`
# alias below) as the type of the `distribution` field on `Dataset`.
class DatasetOverlapDistribution(BaseModel):
    # Only the overlap strategy value is accepted for this schema.
    strategy: Literal[DatasetDistributionStrategy.overlap]
    # No bounds are declared on this field in this schema.
    min_submitted: int


# Alias that currently points at the overlap schema.
DatasetDistribution = DatasetOverlapDistribution


# Overlap distribution settings; used (through the `DatasetDistributionCreate`
# alias below) as the type of the `distribution` field on `DatasetCreate`.
class DatasetOverlapDistributionCreate(BaseModel):
    # Only the overlap strategy value is accepted for this schema.
    strategy: Literal[DatasetDistributionStrategy.overlap]
    # `ge=1` makes validation reject values below one.
    min_submitted: int = Field(
        description="Minimum number of submitted responses to consider a record as completed",
        ge=1,
    )


# Alias that currently points at the overlap create schema.
DatasetDistributionCreate = DatasetOverlapDistributionCreate


# Overlap distribution settings accepted when updating a dataset. It adds no
# fields of its own and reuses the validation of the overlap create schema.
#
# The base class is the concrete `DatasetOverlapDistributionCreate` rather than
# the `DatasetDistributionCreate` alias: both currently name the same class,
# but an alias can later be widened to a Union, which cannot be subclassed.
class DatasetOverlapDistributionUpdate(DatasetOverlapDistributionCreate):
    pass


# Alias that currently points at the overlap update schema.
DatasetDistributionUpdate = DatasetOverlapDistributionUpdate


class RecordMetrics(BaseModel):
count: int

Expand Down Expand Up @@ -74,6 +100,7 @@ class Dataset(BaseModel):
guidelines: Optional[str]
allow_extra_metadata: bool
status: DatasetStatus
distribution: DatasetDistribution
workspace_id: UUID
last_activity_at: datetime
inserted_at: datetime
Expand All @@ -91,12 +118,17 @@ class DatasetCreate(BaseModel):
name: DatasetName
guidelines: Optional[DatasetGuidelines]
allow_extra_metadata: bool = True
distribution: DatasetDistributionCreate = DatasetOverlapDistributionCreate(
strategy=DatasetDistributionStrategy.overlap,
min_submitted=1,
)
workspace_id: UUID


# Payload schema for updating a dataset; every field is optional.
class DatasetUpdate(UpdateSchema):
    name: Optional[DatasetName]
    guidelines: Optional[DatasetGuidelines]
    allow_extra_metadata: Optional[bool]
    distribution: Optional[DatasetDistributionUpdate]

    # Single definition of the non-nullable set (a superseded assignment that
    # lacked "distribution" was removed; it was overwritten immediately).
    # NOTE(review): presumably read by `UpdateSchema` to reject explicit nulls
    # for these fields; confirm against its implementation.
    __non_nullable_fields__ = {"name", "allow_extra_metadata", "distribution"}
5 changes: 3 additions & 2 deletions argilla-server/src/argilla_server/api/schemas/v1/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from argilla_server.api.schemas.v1.metadata_properties import MetadataPropertyName
from argilla_server.api.schemas.v1.responses import Response, ResponseFilterScope, UserResponseCreate
from argilla_server.api.schemas.v1.suggestions import Suggestion, SuggestionCreate, SuggestionFilterScope
from argilla_server.enums import RecordInclude, RecordSortField, SimilarityOrder, SortOrder
from argilla_server.enums import RecordInclude, RecordSortField, SimilarityOrder, SortOrder, RecordStatus
from argilla_server.pydantic_v1 import BaseModel, Field, StrictStr, root_validator, validator
from argilla_server.pydantic_v1.utils import GetterDict
from argilla_server.search_engine import TextQuery
Expand Down Expand Up @@ -66,6 +66,7 @@ def get(self, key: str, default: Any) -> Any:

class Record(BaseModel):
id: UUID
status: RecordStatus
fields: Dict[str, Any]
metadata: Optional[Dict[str, Any]]
external_id: Optional[str]
Expand Down Expand Up @@ -196,7 +197,7 @@ def _has_relationships(self):

# Filter scope that targets a property of the record entity itself.
class RecordFilterScope(BaseModel):
    entity: Literal["record"]
    # Accepted properties: the two timestamp sort fields and "status".
    # (A superseded annotation without "status" was removed; the later
    # annotation of the same name replaced it anyway.)
    property: Union[Literal[RecordSortField.inserted_at], Literal[RecordSortField.updated_at], Literal["status"]]


class Records(BaseModel):
Expand Down
4 changes: 4 additions & 0 deletions argilla-server/src/argilla_server/bulk/records_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from argilla_server.api.schemas.v1.responses import UserResponseCreate
from argilla_server.api.schemas.v1.suggestions import SuggestionCreate
from argilla_server.contexts import distribution
from argilla_server.contexts.accounts import fetch_users_by_ids_as_dict
from argilla_server.contexts.records import (
fetch_records_by_external_ids_as_dict,
Expand Down Expand Up @@ -67,6 +68,7 @@ async def create_records_bulk(self, dataset: Dataset, bulk_create: RecordsBulkCr

await self._upsert_records_relationships(records, bulk_create.items)
await _preload_records_relationships_before_index(self._db, records)
await distribution.update_records_status(self._db, records)
await self._search_engine.index_records(dataset, records)

await self._db.commit()
Expand Down Expand Up @@ -207,6 +209,7 @@ async def upsert_records_bulk(self, dataset: Dataset, bulk_upsert: RecordsBulkUp

await self._upsert_records_relationships(records, bulk_upsert.items)
await _preload_records_relationships_before_index(self._db, records)
await distribution.update_records_status(self._db, records)
await self._search_engine.index_records(dataset, records)

await self._db.commit()
Expand Down Expand Up @@ -237,6 +240,7 @@ async def _preload_records_relationships_before_index(db: "AsyncSession", record
.filter(Record.id.in_([record.id for record in records]))
.options(
selectinload(Record.responses).selectinload(Response.user),
selectinload(Record.responses_submitted),
selectinload(Record.suggestions).selectinload(Suggestion.question),
selectinload(Record.vectors),
)
Expand Down
Loading
Loading