Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENHANCEMENT] argilla: add record status property #5184

4 changes: 2 additions & 2 deletions argilla/src/argilla/_models/_record/_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union, Literal

from pydantic import Field, field_serializer, field_validator

Expand All @@ -30,12 +30,12 @@
class RecordModel(ResourceModel):
"""Schema for the records of a `Dataset`"""

status: Literal["pending", "completed"] = "pending"
fields: Optional[Dict[str, FieldValue]] = None
metadata: Optional[Union[List[MetadataModel], Dict[str, MetadataValue]]] = Field(default_factory=dict)
vectors: Optional[List[VectorModel]] = Field(default_factory=list)
responses: Optional[List[UserResponseModel]] = Field(default_factory=list)
suggestions: Optional[Union[Tuple[SuggestionModel], List[SuggestionModel]]] = Field(default_factory=tuple)

external_id: Optional[Any] = None

@field_serializer("external_id", when_used="unless-none")
Expand Down
22 changes: 17 additions & 5 deletions argilla/src/argilla/records/_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __init__(

def __repr__(self) -> str:
return (
f"Record(id={self.id},fields={self.fields},metadata={self.metadata},"
f"Record(id={self.id},status={self.status},fields={self.fields},metadata={self.metadata},"
f"suggestions={self.suggestions},responses={self.responses})"
)

Expand Down Expand Up @@ -147,6 +147,10 @@ def metadata(self) -> "RecordMetadata":
def vectors(self) -> "RecordVectors":
return self.__vectors

@property
def status(self) -> str:
return self._model.status

@property
def _server_id(self) -> Optional[UUID]:
return self._model.id
Expand All @@ -164,6 +168,7 @@ def api_model(self) -> RecordModel:
vectors=self.vectors.api_models(),
responses=self.responses.api_models(),
suggestions=self.suggestions.api_models(),
status=self.status,
)

def serialize(self) -> Dict[str, Any]:
Expand All @@ -185,6 +190,7 @@ def to_dict(self) -> Dict[str, Dict]:
"""
id = str(self.id) if self.id else None
server_id = str(self._model.id) if self._model.id else None
status = self.status
fields = self.fields.to_dict()
metadata = self.metadata.to_dict()
suggestions = self.suggestions.to_dict()
Expand All @@ -198,6 +204,7 @@ def to_dict(self) -> Dict[str, Dict]:
"suggestions": suggestions,
"responses": responses,
"vectors": vectors,
"status": status,
"_server_id": server_id,
}

Expand Down Expand Up @@ -245,7 +252,7 @@ def from_model(cls, model: RecordModel, dataset: "Dataset") -> "Record":
Returns:
A Record object.
"""
return cls(
instance = cls(
id=model.external_id,
fields=model.fields,
metadata={meta.name: meta.value for meta in model.metadata},
Expand All @@ -257,10 +264,15 @@ def from_model(cls, model: RecordModel, dataset: "Dataset") -> "Record":
for response in UserResponse.from_model(response_model, dataset=dataset)
],
suggestions=[Suggestion.from_model(model=suggestion, dataset=dataset) for suggestion in model.suggestions],
_dataset=dataset,
_server_id=model.id,
)

# set private attributes
instance._dataset = dataset
instance._model.id = model.id
instance._model.status = model.status

return instance


class RecordFields(dict):
"""This is a container class for the fields of a Record.
Expand Down Expand Up @@ -335,7 +347,7 @@ def to_dict(self) -> Dict[str, List[Dict]]:
response_dict = defaultdict(list)
for response in self.__responses:
response_dict[response.question_name].append({"value": response.value, "user_id": str(response.user_id)})
return response_dict
return dict(response_dict)

def api_models(self) -> List[UserResponseModel]:
"""Returns a list of ResponseModel objects."""
Expand Down
13 changes: 13 additions & 0 deletions argilla/tests/integration/test_list_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,19 @@ def test_list_records_with_start_offset(client: Argilla, dataset: Dataset):
records = list(dataset.records(start_offset=1))
assert len(records) == 1

assert [record.to_dict() for record in records] == [
{
"_server_id": str(records[0]._server_id),
"fields": {"text": "The record text field"},
"id": "2",
"status": "pending",
"metadata": {},
"responses": {},
"suggestions": {},
"vectors": {},
}
]


def test_list_records_with_responses(client: Argilla, dataset: Dataset):
dataset.records.log(
Expand Down
1 change: 1 addition & 0 deletions argilla/tests/unit/test_io/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def test_to_list_flatten(self):
assert records_list == [
{
"id": str(record.id),
"status": "pending",
"_server_id": None,
"field": "The field",
"key": "value",
Expand Down
1 change: 1 addition & 0 deletions argilla/tests/unit/test_io/test_hf_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def test_to_datasets_with_partial_values_in_records(self):

ds = HFDatasetsIO.to_datasets(records)
assert ds.features == {
"status": Value(dtype="string", id=None),
frascuchon marked this conversation as resolved.
Show resolved Hide resolved
"_server_id": Value(dtype="null", id=None),
"a": Value(dtype="string", id=None),
"b": Value(dtype="string", id=None),
Expand Down
10 changes: 10 additions & 0 deletions argilla/tests/unit/test_resources/test_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import uuid

import pytest

from argilla import Record, Suggestion, Response
from argilla._models import MetadataModel

Expand All @@ -31,6 +33,7 @@ def test_record_repr(self):
)
assert (
record.__repr__() == f"Record(id={record_id},"
"status=pending,"
"fields={'name': 'John', 'age': '30'},"
"metadata={'key': 'value'},"
"suggestions={'question': {'value': 'answer', 'score': None, 'agent': None}},"
Expand Down Expand Up @@ -62,3 +65,10 @@ def test_update_record_vectors(self):

record.vectors["new-vector"] = [1.0, 2.0, 3.0]
assert record.vectors == {"vector": [1.0, 2.0, 3.0], "new-vector": [1.0, 2.0, 3.0]}

def test_prevent_update_record(self):
record = Record(fields={"name": "John"})
assert record.status == "pending"

with pytest.raises(AttributeError):
record.status = "completed"
Loading