[FEAT] Make adding and accessing suggestion and response from a record consistent (#5056)

This PR makes adding and accessing suggestion and response from a record
consistent. It does that by:

- implementing an `add` method on record suggestions and responses
- switching record suggestions and responses to key access instead of index
and attribute access
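
The two changes listed above can be illustrated with a minimal, self-contained sketch. It does not import Argilla: `SketchResponse`, `SketchSuggestion`, `ResponsesSketch` and `SuggestionsSketch` are hypothetical stand-ins, and the use of `defaultdict` for responses is an assumption, since the diff does not show how the by-question mapping is built. It is meant only to show key access and the different `add` semantics (responses accumulate, suggestions overwrite).

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class SketchResponse:
    """Hypothetical, simplified stand-in for a response."""
    question_name: str
    value: str
    user_id: str


@dataclass
class SketchSuggestion:
    """Hypothetical, simplified stand-in for a suggestion."""
    question_name: str
    value: str
    score: Optional[float] = None


class ResponsesSketch:
    """Keeps every response, grouped by question name."""

    def __init__(self) -> None:
        # Assumption: a defaultdict, so unknown questions yield an empty list.
        self._by_question: Dict[str, List[SketchResponse]] = defaultdict(list)

    def __getitem__(self, question_name: str) -> List[SketchResponse]:
        return self._by_question[question_name]

    def add(self, response: SketchResponse) -> None:
        # A question can hold several responses (for example one per user).
        self._by_question[response.question_name].append(response)


class SuggestionsSketch:
    """Keeps at most one suggestion per question name."""

    def __init__(self) -> None:
        self._by_question: Dict[str, SketchSuggestion] = {}

    def __getitem__(self, question_name: str) -> SketchSuggestion:
        return self._by_question[question_name]

    def add(self, suggestion: SketchSuggestion) -> None:
        # Adding a second suggestion for the same question replaces the first.
        self._by_question[suggestion.question_name] = suggestion


responses = ResponsesSketch()
responses.add(SketchResponse("label", "positive", user_id="user-1"))
responses.add(SketchResponse("label", "negative", user_id="user-2"))

suggestions = SuggestionsSketch()
suggestions.add(SketchSuggestion("label", "positive", score=0.5))
suggestions.add(SketchSuggestion("label", "negative", score=0.9))

print([r.value for r in responses["label"]])  # both responses are kept
print(suggestions["label"].value)             # only the latest suggestion remains
```

In the sketch, `responses["label"]` returns a list while `suggestions["label"]` returns a single object, which mirrors the `["label"][0].value` versus `["label"].value` pattern used in the updated tests.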

**Type of change**


- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] Refactor (change restructuring the codebase without changing
functionality)
- [x] Improvement (change adding some improvement to an existing
functionality)
- [ ] Documentation update

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

- updated and replaced `test_update_records`
- updated all other assertions in tests

**Checklist**

- [x] I added relevant documentation
- [x] follows the style guidelines of this project
- [x] I did a self-review of my code
- [x] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: Francisco Aranda <[email protected]>
Co-authored-by: Ben Burtenshaw <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
4 people authored Jun 20, 2024
1 parent dd1d81d commit 0348f4e
Showing 17 changed files with 298 additions and 224 deletions.
2 changes: 1 addition & 1 deletion argilla/docs/how_to_guides/record.md
@@ -415,7 +415,7 @@ for record in dataset.records(

# Access the responses of the record
for response in record.responses:
print(record.question_name.value)
print(record.["<question_name>"].value)
```

## Update records
3 changes: 2 additions & 1 deletion argilla/docs/reference/argilla/records/records.md
@@ -47,7 +47,8 @@ for record in dataset.records(with_metadata=True):
record.metadata = {"department": "toys"}
```

For changes to take effect, the user must call the `update` method on the `Dataset` object, or pass the updated records to `Dataset.records.log`.
For changes to take effect, the user must call the `update` method on the `Dataset` object, or pass the updated records to `Dataset.records.log`. All core record atttributes can be updated in this way. Check their respective documentation for more information: [Suggestions](suggestions.md), [Responses](responses.md), [Metadata](metadata.md), [Vectors](vectors/md).


---

13 changes: 10 additions & 3 deletions argilla/docs/reference/argilla/records/responses.md
@@ -44,19 +44,26 @@ Responses can be accessed from a `Record` via their question name as an attribut
# iterate over the records and responses

for record in dataset.records:
for response in record.responses.label:
for response in record.responses["label"]: # (1)
print(response.value)
print(response.user_id)

# validate that the record has a response

for record in dataset.records:
if record.responses.label:
for response in record.responses.label:
if record.responses["label"]:
for response in record.responses["label"]:
print(response.value)
print(response.user_id)
else:
record.responses.add(
rg.Response("label", "positive", user_id=user.id)
) # (2)

```
1. Access the responses for the question named `label` for each record like a dictionary containing a list of `Response` objects.
2. Add a response to the record if it does not already have one.


---

15 changes: 14 additions & 1 deletion argilla/docs/reference/argilla/records/suggestions.md
@@ -64,9 +64,22 @@ Just like responses, suggestions can be accessed from a `Record` via their quest

```python
for record in dataset.records(with_suggestions=True):
print(record.suggestions.label)
print(record.suggestions["label"].value)
```

We can also add suggestions to records as we iterate over them using the `add` method:

```python
for record in dataset.records(with_suggestions=True):
if not record.suggestions["label"]: # (1)
record.suggestions.add(
rg.Suggestion("positive", "label", score=0.9, agent="model_name")
) # (2)
```

1. Validate that the record has a suggestion
2. Add a suggestion to the record if it does not already have one

---

## Class Reference
4 changes: 3 additions & 1 deletion argilla/pyproject.toml
@@ -13,7 +13,9 @@ dynamic = ["version"]
dependencies = [
"httpx>=0.26.0",
"pydantic>=2.6.0, <3.0.0",
"argilla-v1[listeners]"
"argilla-v1[listeners]",
"tqdm>=4.60.0",
"rich>=10.0.0",
]

[project.optional-dependencies]
4 changes: 2 additions & 2 deletions argilla/src/argilla/records/_dataset_records.py
@@ -211,7 +211,7 @@ def log(
mapping: Optional[Dict[str, str]] = None,
user_id: Optional[UUID] = None,
batch_size: int = DEFAULT_BATCH_SIZE,
) -> List[Record]:
) -> "DatasetRecords":
"""Add or update records in a dataset on the server using the provided records.
If the record includes a known `id` field, the record will be updated.
If the record does not include a known `id` field, the record will be added as a new record.
@@ -253,7 +253,7 @@ def log(
level="info",
)

return created_or_updated
return self

def delete(
self,
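
Because `log` now returns the `DatasetRecords` object instead of a list of records, calls can be chained, and callers that previously used the returned list would presumably need to iterate over the records instead. A minimal sketch of that pattern follows; `DatasetRecordsSketch` is a hypothetical in-memory stand-in (no server, no batching), not the real class.

```python
from typing import Dict, Iterator, List


class DatasetRecordsSketch:
    """Hypothetical in-memory stand-in that upserts records by id."""

    def __init__(self) -> None:
        self._store: Dict[str, dict] = {}

    def log(self, records: List[dict]) -> "DatasetRecordsSketch":
        for record in records:
            # A known id updates the stored record, an unknown id adds a new one.
            self._store[record["id"]] = record
        # Returning self (rather than the logged records) allows chaining.
        return self

    def __iter__(self) -> Iterator[dict]:
        return iter(self._store.values())


records = (
    DatasetRecordsSketch()
    .log([{"id": "a", "text": "first version"}])
    .log([{"id": "a", "text": "second version"}, {"id": "b", "text": "new"}])
)

ids = [record["id"] for record in records]
texts = [record["text"] for record in records]
print(ids, texts)
```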
46 changes: 31 additions & 15 deletions argilla/src/argilla/records/_resource.py
@@ -321,10 +321,7 @@ def __init__(self, responses: List[Response], record: Record) -> None:
def __iter__(self):
return iter(self.__responses)

def __getitem__(self, index: int):
return self.__responses[index]

def __getattr__(self, name) -> List[Response]:
def __getitem__(self, name: str):
return self.__responses_by_question_name[name]

def __repr__(self) -> str:
@@ -352,6 +349,15 @@ def api_models(self) -> List[UserResponseModel]:
for responses in responses_by_user_id.values()
]

def add(self, response: Response) -> None:
"""Adds a response to the record and updates the record. Records can have multiple responses per question.
Args:
response: The response to add.
"""
response.record = self.record
self.__responses.append(response)
self.__responses_by_question_name[response.question_name].append(response)


class RecordSuggestions(Iterable[Suggestion]):
"""This is a container class for the suggestions of a Record.
@@ -360,17 +366,17 @@ class RecordSuggestions(Iterable[Suggestion]):

def __init__(self, suggestions: List[Suggestion], record: Record) -> None:
self.record = record

self.__suggestions = suggestions or []
for suggestion in self.__suggestions:
self._suggestion_by_question_name: Dict[str, Suggestion] = {}
suggestions = suggestions or []
for suggestion in suggestions:
suggestion.record = self.record
setattr(self, suggestion.question_name, suggestion)
self._suggestion_by_question_name[suggestion.question_name] = suggestion

def __iter__(self):
return iter(self.__suggestions)
return iter(self._suggestion_by_question_name.values())

def __getitem__(self, index: int):
return self.__suggestions[index]
def __getitem__(self, question_name: str):
return self._suggestion_by_question_name[question_name]

def __repr__(self) -> str:
return self.to_dict().__repr__()
@@ -380,14 +386,24 @@ def to_dict(self) -> Dict[str, List[str]]:
Returns:
A dictionary of suggestions.
"""
suggestion_dict: dict = {}
for suggestion in self.__suggestions:
suggestion_dict[suggestion.question_name] = {
suggestion_dict = {}
for question_name, suggestion in self._suggestion_by_question_name.items():
suggestion_dict[question_name] = {
"value": suggestion.value,
"score": suggestion.score,
"agent": suggestion.agent,
}
return suggestion_dict

def api_models(self) -> List[SuggestionModel]:
return [suggestion.api_model() for suggestion in self.__suggestions]
suggestions = self._suggestion_by_question_name.values()
return [suggestion.api_model() for suggestion in suggestions]

def add(self, suggestion: Suggestion) -> None:
"""Adds a suggestion to the record and updates the record. Records can have only one suggestion per question, so
adding a new suggestion will overwrite the previous suggestion.
Args:
suggestion: The suggestion to add.
"""
suggestion.record = self.record
self._suggestion_by_question_name[suggestion.question_name] = suggestion
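
One consequence of the `__getitem__` shown in this diff is that it indexes a plain `dict`, so looking up a question that has no suggestion would presumably raise `KeyError` rather than return a falsy value. The sketch below reproduces that behaviour with a hypothetical stand-in class (`PlainDictSuggestions`, storing bare strings for brevity) and shows one possible guard; the helper `suggestion_or_none` is illustrative and not part of the library.

```python
from typing import Dict, Optional


class PlainDictSuggestions:
    """Hypothetical stand-in mirroring the dict-backed lookup in the diff."""

    def __init__(self) -> None:
        self._suggestion_by_question_name: Dict[str, str] = {}

    def __getitem__(self, question_name: str) -> str:
        # Plain dict indexing: a missing question name raises KeyError.
        return self._suggestion_by_question_name[question_name]

    def add(self, question_name: str, value: str) -> None:
        self._suggestion_by_question_name[question_name] = value


def suggestion_or_none(suggestions: PlainDictSuggestions, question_name: str) -> Optional[str]:
    """Illustrative guard: return None when no suggestion exists."""
    try:
        return suggestions[question_name]
    except KeyError:
        return None


suggestions = PlainDictSuggestions()
before = suggestion_or_none(suggestions, "label")

# Only add a suggestion when the record does not already have one.
if before is None:
    suggestions.add("label", "positive")

after = suggestion_or_none(suggestions, "label")
print(before, after)
```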
96 changes: 48 additions & 48 deletions argilla/tests/integration/test_add_records.py
@@ -129,7 +129,7 @@ def test_add_dict_records(client: Argilla):

for record, data in zip(ds.records(batch_size=1, with_suggestions=True), mock_data):
assert record.id == data["id"]
assert record.suggestions.label.value == data["label"]
assert record.suggestions["label"].value == data["label"]


def test_add_records_with_suggestions(client) -> None:
@@ -188,21 +188,21 @@ def test_add_records_with_suggestions(client) -> None:
dataset_records = list(dataset.records(with_suggestions=True))

assert dataset_records[0].id == str(mock_data[0]["id"])
assert dataset_records[0].suggestions.comment.value == "I'm doing great, thank you!"
assert dataset_records[0].suggestions.comment.score is None
assert dataset_records[0].suggestions.topics.value == ["topic1", "topic2"]
assert dataset_records[0].suggestions.topics.score == [0.9, 0.8]
assert dataset_records[0].suggestions["comment"].value == "I'm doing great, thank you!"
assert dataset_records[0].suggestions["comment"].score is None
assert dataset_records[0].suggestions["topics"].value == ["topic1", "topic2"]
assert dataset_records[0].suggestions["topics"].score == [0.9, 0.8]

assert dataset_records[1].fields["text"] == mock_data[1]["text"]
assert dataset_records[1].suggestions.comment.value == "I'm doing great, thank you!"
assert dataset_records[1].suggestions.comment.score is None
assert dataset_records[1].suggestions.topics.value == ["topic3"]
assert dataset_records[1].suggestions.topics.score == [0.9]
assert dataset_records[1].suggestions["comment"].value == "I'm doing great, thank you!"
assert dataset_records[1].suggestions["comment"].score is None
assert dataset_records[1].suggestions["topics"].value == ["topic3"]
assert dataset_records[1].suggestions["topics"].score == [0.9]

assert dataset_records[2].suggestions.comment.value == "I'm doing great, thank you!"
assert dataset_records[2].suggestions.comment.score is None
assert dataset_records[2].suggestions.topics.value == ["topic1", "topic2", "topic3"]
assert dataset_records[2].suggestions.topics.score == [0.9, 0.8, 0.7]
assert dataset_records[2].suggestions["comment"].value == "I'm doing great, thank you!"
assert dataset_records[2].suggestions["comment"].score is None
assert dataset_records[2].suggestions["topics"].value == ["topic1", "topic2", "topic3"]
assert dataset_records[2].suggestions["topics"].score == [0.9, 0.8, 0.7]


def test_add_records_with_responses(client) -> None:
@@ -259,8 +259,8 @@ def test_add_records_with_responses(client) -> None:
for record, mock_record in zip(dataset_records, mock_data):
assert record.id == str(mock_record["id"])
assert record.fields["text"] == mock_record["text"]
assert record.responses.label[0].value == mock_record["my_label"]
assert record.responses.label[0].user_id == user.id
assert record.responses["label"][0].value == mock_record["my_label"]
assert record.responses["label"][0].user_id == user.id


def test_add_records_with_responses_and_suggestions(client) -> None:
@@ -320,9 +320,9 @@ def test_add_records_with_responses_and_suggestions(client) -> None:

assert dataset_records[0].id == str(mock_data[0]["id"])
assert dataset_records[1].fields["text"] == mock_data[1]["text"]
assert dataset_records[2].suggestions.label.value == "positive"
assert dataset_records[2].responses.label[0].value == "negative"
assert dataset_records[2].responses.label[0].user_id == user.id
assert dataset_records[2].suggestions["label"].value == "positive"
assert dataset_records[2].responses["label"][0].value == "negative"
assert dataset_records[2].responses["label"][0].user_id == user.id


def test_add_records_with_fields_mapped(client) -> None:
@@ -387,11 +387,11 @@ def test_add_records_with_fields_mapped(client) -> None:

assert dataset_records[0].id == str(mock_data[0]["id"])
assert dataset_records[1].fields["text"] == mock_data[1]["x"]
assert dataset_records[2].suggestions.label.value == "positive"
assert dataset_records[2].suggestions.label.score == 0.5
assert dataset_records[2].responses.label[0].value == "negative"
assert dataset_records[2].responses.label[0].value == "negative"
assert dataset_records[2].responses.label[0].user_id == user.id
assert dataset_records[2].suggestions["label"].value == "positive"
assert dataset_records[2].suggestions["label"].score == 0.5
assert dataset_records[2].responses["label"][0].value == "negative"
assert dataset_records[2].responses["label"][0].value == "negative"
assert dataset_records[2].responses["label"][0].user_id == user.id


def test_add_records_with_id_mapped(client) -> None:
@@ -448,9 +448,9 @@ def test_add_records_with_id_mapped(client) -> None:

assert dataset_records[0].id == str(mock_data[0]["uuid"])
assert dataset_records[1].fields["text"] == mock_data[1]["x"]
assert dataset_records[2].suggestions.label.value == "positive"
assert dataset_records[2].responses.label[0].value == "negative"
assert dataset_records[2].responses.label[0].user_id == user.id
assert dataset_records[2].suggestions["label"].value == "positive"
assert dataset_records[2].responses["label"][0].value == "negative"
assert dataset_records[2].responses["label"][0].user_id == user.id


def test_add_record_resources(client):
@@ -507,22 +507,22 @@ def test_add_record_resources(client):
assert dataset.name == mock_dataset_name

assert dataset_records[0].id == str(mock_resources[0].id)
assert dataset_records[0].suggestions.label.value == "positive"
assert dataset_records[0].suggestions.label.score == 0.9
assert dataset_records[0].suggestions.topics.value == ["topic1", "topic2"]
assert dataset_records[0].suggestions.topics.score == [0.9, 0.8]
assert dataset_records[0].suggestions["label"].value == "positive"
assert dataset_records[0].suggestions["label"].score == 0.9
assert dataset_records[0].suggestions["topics"].value == ["topic1", "topic2"]
assert dataset_records[0].suggestions["topics"].score == [0.9, 0.8]

assert dataset_records[1].id == str(mock_resources[1].id)
assert dataset_records[1].suggestions.label.value == "positive"
assert dataset_records[1].suggestions.label.score == 0.9
assert dataset_records[1].suggestions.topics.value == ["topic1", "topic2"]
assert dataset_records[1].suggestions.topics.score == [0.9, 0.8]
assert dataset_records[1].suggestions["label"].value == "positive"
assert dataset_records[1].suggestions["label"].score == 0.9
assert dataset_records[1].suggestions["topics"].value == ["topic1", "topic2"]
assert dataset_records[1].suggestions["topics"].score == [0.9, 0.8]

assert dataset_records[2].id == str(mock_resources[2].id)
assert dataset_records[2].suggestions.label.value == "positive"
assert dataset_records[2].suggestions.label.score == 0.9
assert dataset_records[2].suggestions.topics.value == ["topic1", "topic2"]
assert dataset_records[2].suggestions.topics.score == [0.9, 0.8]
assert dataset_records[2].suggestions["label"].value == "positive"
assert dataset_records[2].suggestions["label"].score == 0.9
assert dataset_records[2].suggestions["topics"].value == ["topic1", "topic2"]
assert dataset_records[2].suggestions["topics"].score == [0.9, 0.8]


def test_add_records_with_responses_and_same_schema_name(client: Argilla):
Expand Down Expand Up @@ -572,8 +572,8 @@ def test_add_records_with_responses_and_same_schema_name(client: Argilla):
dataset_records = list(dataset.records(with_responses=True))

assert dataset_records[0].fields["text"] == mock_data[1]["text"]
assert dataset_records[1].responses.label[0].value == "negative"
assert dataset_records[1].responses.label[0].user_id == user.id
assert dataset_records[1].responses["label"][0].value == "negative"
assert dataset_records[1].responses["label"][0].user_id == user.id


def test_add_records_objects_with_responses(client: Argilla):
Expand Down Expand Up @@ -631,17 +631,17 @@ def test_add_records_objects_with_responses(client: Argilla):

assert dataset.name == mock_dataset_name
assert dataset_records[0].id == records[0].id
assert dataset_records[0].responses.label[0].value == "negative"
assert dataset_records[0].responses.label[0].status == "submitted"
assert dataset_records[0].responses["label"][0].value == "negative"
assert dataset_records[0].responses["label"][0].status == "submitted"

assert dataset_records[1].id == records[1].id
assert dataset_records[1].responses.label[0].value == "positive"
assert dataset_records[1].responses.label[0].status == "discarded"
assert dataset_records[1].responses["label"][0].value == "positive"
assert dataset_records[1].responses["label"][0].status == "discarded"

assert dataset_records[2].id == records[2].id
assert dataset_records[2].responses.comment[0].value == "The comment"
assert dataset_records[2].responses.comment[0].status == "draft"
assert dataset_records[2].responses["comment"][0].value == "The comment"
assert dataset_records[2].responses["comment"][0].status == "draft"

assert dataset_records[3].id == records[3].id
assert dataset_records[3].responses.comment[0].value == "The comment"
assert dataset_records[3].responses.comment[0].status == "draft"
assert dataset_records[3].responses["comment"][0].value == "The comment"
assert dataset_records[3].responses["comment"][0].status == "draft"
2 changes: 1 addition & 1 deletion argilla/tests/integration/test_export_dataset.py
@@ -119,7 +119,7 @@ def test_import_dataset_from_disk(dataset: rg.Dataset, client):

for i, record in enumerate(new_dataset.records(with_suggestions=True)):
assert record.fields["text"] == mock_data[i]["text"]
assert record.suggestions.label.value == mock_data[i]["label"]
assert record.suggestions["label"].value == mock_data[i]["label"]

assert new_dataset.settings.fields[0].name == "text"
assert new_dataset.settings.questions[0].name == "label"
4 changes: 2 additions & 2 deletions argilla/tests/integration/test_export_records.py
@@ -273,7 +273,7 @@ def test_export_records_from_json(dataset: rg.Dataset):

for i, record in enumerate(dataset.records(with_suggestions=True)):
assert record.fields["text"] == mock_data[i]["text"]
assert record.suggestions.label.value == mock_data[i]["label"]
assert record.suggestions["label"].value == mock_data[i]["label"]
assert record.id == str(mock_data[i]["id"])


@@ -329,5 +329,5 @@ def test_import_records_from_hf_dataset(dataset: rg.Dataset) -> None:

for i, record in enumerate(dataset.records(with_suggestions=True)):
assert record.fields["text"] == mock_data[i]["text"]
assert record.suggestions.label.value == mock_data[i]["label"]
assert record.suggestions["label"].value == mock_data[i]["label"]
assert record.id == str(mock_data[i]["id"])
8 changes: 4 additions & 4 deletions argilla/tests/integration/test_list_records.py
@@ -81,8 +81,8 @@ def test_list_records_with_responses(client: Argilla, dataset: Dataset):
records = list(dataset.records(with_responses=True))
assert len(records) == 2

assert records[0].responses.comment[0].value == "The comment"
assert records[0].responses.sentiment[0].value == "positive"
assert records[0].responses["comment"][0].value == "The comment"
assert records[0].responses["sentiment"][0].value == "positive"

assert records[1].responses.comment[0].value == "The comment"
assert records[1].responses.sentiment[0].value == "negative"
assert records[1].responses["comment"][0].value == "The comment"
assert records[1].responses["sentiment"][0].value == "negative"