Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update argilla integration to use argilla_sdk v2 #705

Merged
merged 22 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
9662146
Update `_Argilla` base and `TextGenerationToArgilla`
alvarobartt Jun 6, 2024
ef189f8
Fix `_dataset.records.log` and rename to `ArgillaBase`
alvarobartt Jun 6, 2024
c83de4c
Update `TextGenerationToArgilla` subclass inheritance
alvarobartt Jun 6, 2024
2353768
Remove unused `logger.info` message
alvarobartt Jun 6, 2024
055a9de
Update `PreferenceToArgilla`
alvarobartt Jun 6, 2024
18761fb
Update `argilla` extra to install `argilla_sdk`
alvarobartt Jun 6, 2024
7d0f07d
Add `ArgillaBase` and subclasses unit tests
alvarobartt Jun 7, 2024
a97e310
Merge branch 'develop' into argilla-2.0
alvarobartt Jun 7, 2024
d77dd11
Install `argilla_sdk` from source and add `ipython`
alvarobartt Jun 10, 2024
d6f7131
Merge branch 'develop' into argilla-2.0
alvarobartt Jun 12, 2024
7d55576
upgrade argilla dep to latest rc
frascuchon Jul 17, 2024
78ca5f7
udate code with latest changes
frascuchon Jul 17, 2024
c9fc2a5
chore: remove unnecessary workspace definition
frascuchon Jul 17, 2024
06a3610
fix: wrong argilla module import
frascuchon Jul 17, 2024
58a2e8c
Merge branch 'develop' into argilla-2.0
gabrielmbmb Jul 30, 2024
5c1ce95
Update docstrings
gabrielmbmb Jul 30, 2024
1e16e38
Fix lint
gabrielmbmb Jul 30, 2024
20b92ab
Add check for `api_url` and `api_key`
gabrielmbmb Jul 30, 2024
ba13431
Fix unit tests
gabrielmbmb Jul 30, 2024
d088510
Fix unit tests
gabrielmbmb Jul 30, 2024
6e13f3a
Merge branch 'argilla-2.0' of https://github.com/argilla-io/rlxf into…
gabrielmbmb Jul 30, 2024
b0a6b71
Update argilla dependency version
gabrielmbmb Jul 30, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,11 @@ tests = [

# Optional LLMs, integrations, etc
anthropic = ["anthropic >= 0.20.0"]
argilla = ["argilla >= 1.29.0"]
argilla = [
# TODO(alvarobartt): update before the `argilla_sdk` or `argilla` release
"argilla~=2.0.0rc",
"ipython",
]
cohere = ["cohere >= 5.2.0"]
groq = ["groq >= 0.4.1"]
hf-inference-endpoints = ["huggingface_hub >= 0.19.0"]
Expand Down
47 changes: 25 additions & 22 deletions src/distilabel/steps/argilla/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,16 @@
from distilabel.steps.base import Step, StepInput

if TYPE_CHECKING:
from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset
from argilla import Argilla, Dataset

from distilabel.steps.typing import StepOutput


_ARGILLA_API_URL_ENV_VAR_NAME = "ARGILLA_API_URL"
_ARGILLA_API_KEY_ENV_VAR_NAME = "ARGILLA_API_KEY"


class Argilla(Step, ABC):
class ArgillaBase(Step, ABC):
"""Abstract step that provides a class to subclass from, that contains the boilerplate code
required to interact with Argilla, as well as some extra validations on top of it. It also defines
the abstract methods that need to be implemented in order to add a new dataset type as a step.
Expand Down Expand Up @@ -70,20 +71,21 @@ class Argilla(Step, ABC):
)
dataset_workspace: Optional[RuntimeParameter[str]] = Field(
default=None,
description="The workspace where the dataset will be created in Argilla. Defaults"
description="The workspace where the dataset will be created in Argilla. Defaults "
"to `None` which means it will be created in the default workspace.",
)

api_url: Optional[RuntimeParameter[str]] = Field(
default_factory=lambda: os.getenv("ARGILLA_API_URL"),
default_factory=lambda: os.getenv(_ARGILLA_API_URL_ENV_VAR_NAME),
description="The base URL to use for the Argilla API requests.",
)
api_key: Optional[RuntimeParameter[SecretStr]] = Field(
default_factory=lambda: os.getenv(_ARGILLA_API_KEY_ENV_VAR_NAME),
description="The API key to authenticate the requests to the Argilla API.",
)

_rg_dataset: Optional["RemoteFeedbackDataset"] = PrivateAttr(...)
_client: Optional["Argilla"] = PrivateAttr(...)
_dataset: Optional["Dataset"] = PrivateAttr(...)

def model_post_init(self, __context: Any) -> None:
"""Checks that the Argilla Python SDK is installed, and then filters the Argilla warnings."""
Expand All @@ -93,32 +95,33 @@ def model_post_init(self, __context: Any) -> None:
import argilla as rg # noqa
except ImportError as ie:
raise ImportError(
"Argilla is not installed. Please install it using `pip install argilla`."
"Argilla is not installed. Please install it using `pip install argilla --upgrade`."
) from ie

warnings.filterwarnings("ignore")

def _rg_init(self) -> None:
def _client_init(self) -> None:
"""Initializes the Argilla API client with the provided `api_url` and `api_key`."""
try:
if "hf.space" in self.api_url and "HF_TOKEN" in os.environ:
headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}
else:
headers = None
rg.init(
self._client = rg.Argilla( # type: ignore
api_url=self.api_url,
api_key=self.api_key.get_secret_value(),
extra_headers=headers,
) # type: ignore
api_key=self.api_key.get_secret_value(), # type: ignore
headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}
if isinstance(self.api_url, str)
and "hf.space" in self.api_url
and "HF_TOKEN" in os.environ
else {},
)
except Exception as e:
raise ValueError(f"Failed to initialize the Argilla API: {e}") from e

def _rg_dataset_exists(self) -> bool:
"""Checks if the dataset already exists in Argilla."""
return self.dataset_name in [
dataset.name
for dataset in rg.FeedbackDataset.list(workspace=self.dataset_workspace) # type: ignore
]
@property
def _dataset_exists_in_workspace(self) -> bool:
"""Checks if the dataset already exists in Argilla in the provided workspace if any."""
return self._client.datasets( # type: ignore
name=self.dataset_name, # type: ignore
workspace=self.dataset_workspace
gabrielmbmb marked this conversation as resolved.
Show resolved Hide resolved
) is not None

@property
def outputs(self) -> List[str]:
Expand All @@ -133,7 +136,7 @@ def load(self) -> None:
"""
super().load()

self._rg_init()
self._client_init()

@property
@abstractmethod
Expand Down
71 changes: 37 additions & 34 deletions src/distilabel/steps/argilla/preference.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,16 @@
except ImportError:
pass

from distilabel.steps.argilla.base import Argilla
from distilabel.steps.argilla.base import ArgillaBase
from distilabel.steps.base import StepInput

if TYPE_CHECKING:
from argilla import (
RatingQuestion,
SuggestionSchema,
TextField,
TextQuestion,
)
from argilla import RatingQuestion, Suggestion, TextField, TextQuestion

from distilabel.steps.typing import StepOutput


class PreferenceToArgilla(Argilla):
class PreferenceToArgilla(ArgillaBase):
"""Creates a preference dataset in Argilla.

Step that creates a dataset in Argilla during the load phase, and then pushes the input
Expand Down Expand Up @@ -97,16 +92,18 @@ def load(self) -> None:
self._ratings = self.input_mappings.get("ratings", "ratings")
self._rationales = self.input_mappings.get("rationales", "rationales")

if self._rg_dataset_exists():
_rg_dataset = rg.FeedbackDataset.from_argilla( # type: ignore
name=self.dataset_name,
workspace=self.dataset_workspace,
if self._dataset_exists_in_workspace:
_dataset = self._client.datasets( # type: ignore
name=self.dataset_name, # type: ignore
workspace=self.dataset_workspace # type: ignore
gabrielmbmb marked this conversation as resolved.
Show resolved Hide resolved
)

for field in _rg_dataset.fields:
for field in _dataset.fields:
if not isinstance(field, rg.TextField):
continue
if (
field.name
not in [self._id, self._instruction]
not in [self._id, self._instruction] # type: ignore
+ [
f"{self._generations}-{idx}"
for idx in range(self.num_generations)
Expand All @@ -116,23 +113,26 @@ def load(self) -> None:
raise ValueError(
f"The dataset {self.dataset_name} in the workspace {self.dataset_workspace} already exists,"
f" but contains at least a required field that is neither `{self._id}`, `{self._instruction}`,"
f" nor `{self._generations}`."
f" nor `{self._generations}` (one per generation starting from 0 up to {self.num_generations - 1})."
)

self._rg_dataset = _rg_dataset
self._dataset = _dataset
else:
_rg_dataset = rg.FeedbackDataset( # type: ignore
_settings = rg.Settings( # type: ignore
fields=[
rg.TextField(name=self._id, title=self._id), # type: ignore
rg.TextField(name=self._instruction, title=self._instruction), # type: ignore
*self._generation_fields(), # type: ignore
],
questions=self._rating_rationale_pairs(), # type: ignore
)
self._rg_dataset = _rg_dataset.push_to_argilla(
name=self.dataset_name, # type: ignore
_dataset = rg.Dataset( # type: ignore
name=self.dataset_name,
workspace=self.dataset_workspace,
settings=_settings,
client=self._client,
)
self._dataset = _dataset.create()

def _generation_fields(self) -> List["TextField"]:
"""Method to generate the fields for each of the generations."""
Expand Down Expand Up @@ -180,20 +180,23 @@ def inputs(self) -> List[str]:
provide the `ratings` and the `rationales` for the generations."""
return ["instruction", "generations"]

def _add_suggestions_if_any(
self, input: Dict[str, Any]
) -> List["SuggestionSchema"]:
"""Method to generate the suggestions for the `FeedbackRecord` based on the input."""
@property
def optional_inputs(self) -> List[str]:
"""The optional inputs for the step are the `ratings` and the `rationales` for the generations."""
return ["ratings", "rationales"]

def _add_suggestions_if_any(self, input: Dict[str, Any]) -> List["Suggestion"]:
"""Method to generate the suggestions for the `rg.Record` based on the input."""
# Since the `suggestions` i.e. answers to the `questions` are optional, will default to {}
suggestions = []
# If `ratings` is in `input`, then add those as suggestions
if self._ratings in input:
suggestions.extend(
[
{
"question_name": f"{self._generations}-{idx}-rating",
"value": rating,
}
rg.Suggestion( # type: ignore
value=rating,
question_name=f"{self._generations}-{idx}-rating",
)
for idx, rating in enumerate(input[self._ratings])
if rating is not None
and isinstance(rating, int)
Expand All @@ -204,10 +207,10 @@ def _add_suggestions_if_any(
if self._rationales in input:
suggestions.extend(
[
{
"question_name": f"{self._generations}-{idx}-rationale",
"value": rationale,
}
rg.Suggestion( # type: ignore
value=rationale,
question_name=f"{self._generations}-{idx}-rationale",
)
for idx, rationale in enumerate(input[self._rationales])
if rationale is not None and isinstance(rationale, str)
],
Expand All @@ -216,7 +219,7 @@ def _add_suggestions_if_any(

@override
def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
"""Creates and pushes the records as FeedbackRecords to the Argilla dataset.
"""Creates and pushes the records as `rg.Record`s to the Argilla dataset.

Args:
inputs: A list of Python dictionaries with the inputs of the task.
Expand All @@ -237,7 +240,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
}

records.append( # type: ignore
rg.FeedbackRecord( # type: ignore
rg.Record( # type: ignore
fields={
"id": instruction_id,
"instruction": input["instruction"], # type: ignore
Expand All @@ -246,5 +249,5 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
suggestions=self._add_suggestions_if_any(input), # type: ignore
)
)
self._rg_dataset.add_records(records) # type: ignore
self._dataset.records.log(records) # type: ignore
yield inputs
37 changes: 21 additions & 16 deletions src/distilabel/steps/argilla/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
except ImportError:
pass

from distilabel.steps.argilla.base import Argilla
from distilabel.steps.argilla.base import ArgillaBase
from distilabel.steps.base import StepInput

if TYPE_CHECKING:
from distilabel.steps.typing import StepOutput


class TextGenerationToArgilla(Argilla):
class TextGenerationToArgilla(ArgillaBase):
"""Creates a text generation dataset in Argilla.

`Step` that creates a dataset in Argilla during the load phase, and then pushes the input
Expand Down Expand Up @@ -74,26 +74,28 @@ def load(self) -> None:
self._instruction = self.input_mappings.get("instruction", "instruction")
self._generation = self.input_mappings.get("generation", "generation")

if self._rg_dataset_exists():
_rg_dataset = rg.FeedbackDataset.from_argilla( # type: ignore
name=self.dataset_name,
workspace=self.dataset_workspace,
if self._dataset_exists_in_workspace:
_dataset = self._client.datasets( # type: ignore
name=self.dataset_name, # type: ignore
workspace=self.dataset_workspace # type: ignore
gabrielmbmb marked this conversation as resolved.
Show resolved Hide resolved
)

for field in _rg_dataset.fields:
for field in _dataset.fields:
if not isinstance(field, rg.TextField): # type: ignore
continue
if (
field.name not in [self._id, self._instruction, self._generation]
and field.required
):
raise ValueError(
f"The dataset {self.dataset_name} in the workspace {self.dataset_workspace} already exists,"
f" but contains at least a required field that is neither `{self._id}`, `{self._instruction}`"
f", nor `{self._generation}`."
f", nor `{self._generation}`, so it cannot be reused for this dataset."
)

self._rg_dataset = _rg_dataset
self._dataset = _dataset
else:
_rg_dataset = rg.FeedbackDataset( # type: ignore
_settings = rg.Settings( # type: ignore
fields=[
rg.TextField(name=self._id, title=self._id), # type: ignore
rg.TextField(name=self._instruction, title=self._instruction), # type: ignore
Expand All @@ -103,14 +105,17 @@ def load(self) -> None:
rg.LabelQuestion( # type: ignore
name="quality",
title=f"What's the quality of the {self._generation} for the given {self._instruction}?",
labels={"bad": "👎", "good": "👍"},
labels={"bad": "👎", "good": "👍"}, # type: ignore
)
],
)
self._rg_dataset = _rg_dataset.push_to_argilla(
name=self.dataset_name, # type: ignore
_dataset = rg.Dataset( # type: ignore
name=self.dataset_name,
workspace=self.dataset_workspace,
settings=_settings,
client=self._client,
)
self._dataset = _dataset.create()

@property
def inputs(self) -> List[str]:
Expand Down Expand Up @@ -151,13 +156,13 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
generations_set.add(generation)

records.append(
rg.FeedbackRecord( # type: ignore
rg.Record( # type: ignore
fields={
self._id: instruction_id,
self._instruction: input["instruction"],
self._generation: generation,
},
)
),
)
self._rg_dataset.add_records(records) # type: ignore
self._dataset.records.log(records) # type: ignore
yield inputs
Loading
Loading