diff --git a/docs/api/pipeline/utils.md b/docs/api/pipeline/utils.md deleted file mode 100644 index c8ad6f2e5..000000000 --- a/docs/api/pipeline/utils.md +++ /dev/null @@ -1,3 +0,0 @@ -# Pipeline Utils - -::: distilabel.pipeline.utils diff --git a/docs/api/step_gallery/columns.md b/docs/api/step_gallery/columns.md index 9e5392d85..7b75053e6 100644 --- a/docs/api/step_gallery/columns.md +++ b/docs/api/step_gallery/columns.md @@ -6,3 +6,4 @@ This section contains the existing steps intended to be used for common column o ::: distilabel.steps.columns.keep ::: distilabel.steps.columns.merge ::: distilabel.steps.columns.group +::: distilabel.steps.columns.utils diff --git a/mkdocs.yml b/mkdocs.yml index b247cfa13..72c81ff8b 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -152,9 +152,9 @@ plugins: # Members inherited_members: false # allow looking up inherited methods members_order: source # order methods according to their order of definition in the source code, not alphabetical order - show_labels : true + show_labels: true # Docstring - docstring_style: google # more info: https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html + docstring_style: google # more info: https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html show_if_no_docstring: false # Signature separate_signature: false @@ -240,7 +240,6 @@ nav: - Routing Batch Function: "api/pipeline/routing_batch_function.md" - Typing: "api/pipeline/typing.md" - Step Wrapper: "api/pipeline/step_wrapper.md" - - Utils: "api/pipeline/utils.md" - Mixins: - RuntimeParametersMixin: "api/mixins/runtime_parameters.md" - RequirementsMixin: "api/mixins/requirements.md" diff --git a/src/distilabel/steps/__init__.py b/src/distilabel/steps/__init__.py index efa909532..6c0ad8bcf 100644 --- a/src/distilabel/steps/__init__.py +++ b/src/distilabel/steps/__init__.py @@ -21,6 +21,7 @@ StepInput, StepResources, ) +from distilabel.steps.columns.combine import CombineOutputs from 
distilabel.steps.columns.expand import ExpandColumns from distilabel.steps.columns.group import CombineColumns, GroupColumns from distilabel.steps.columns.keep import KeepColumns @@ -54,22 +55,26 @@ __all__ = [ "PreferenceToArgilla", "TextGenerationToArgilla", + "GeneratorStep", + "GlobalStep", + "Step", + "StepInput", "StepResources", + "CombineOutputs", + "ExpandColumns", + "CombineColumns", "GroupColumns", + "KeepColumns", "MergeColumns", - "CombineColumns", - "ConversationTemplate", + "step", "DeitaFiltering", "EmbeddingGeneration", "FaissNearestNeighbour", - "ExpandColumns", + "ConversationTemplate", "FormatChatGenerationDPO", - "FormatChatGenerationSFT", "FormatTextGenerationDPO", + "FormatChatGenerationSFT", "FormatTextGenerationSFT", - "GeneratorStep", - "GlobalStep", - "KeepColumns", "LoadDataFromDicts", "LoadDataFromDisk", "LoadDataFromFileSystem", @@ -78,11 +83,8 @@ "MinHashLSH", "make_generator_step", "PushToHub", - "Step", - "StepInput", "RewardModelScore", "TruncateTextColumn", "GeneratorStepOutput", "StepOutput", - "step", ] diff --git a/src/distilabel/steps/columns/combine.py b/src/distilabel/steps/columns/combine.py new file mode 100644 index 000000000..784beffe4 --- /dev/null +++ b/src/distilabel/steps/columns/combine.py @@ -0,0 +1,99 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class CombineOutputs(Step):
    """Combine the outputs of several upstream steps.

    `CombineOutputs` is a `Step` that receives the output batches of several upstream
    steps and merges them row by row, producing one dictionary per row that contains
    all the keys/columns of the upstream steps outputs.

    Input columns:
        - dynamic (based on the upstream `Step`s): All the columns of the upstream steps outputs.

    Output columns:
        - dynamic (based on the upstream `Step`s): All the columns of the upstream steps outputs.

    Categories:
        - columns

    Examples:

        Combine dictionaries of a dataset:

        ```python
        from distilabel.steps import CombineOutputs

        combine_outputs = CombineOutputs()
        combine_outputs.load()

        result = next(
            combine_outputs.process(
                [{"a": 1, "b": 2}, {"a": 3, "b": 4}],
                [{"c": 5, "d": 6}, {"c": 7, "d": 8}],
            )
        )
        # [
        #   {"a": 1, "b": 2, "c": 5, "d": 6},
        #   {"a": 3, "b": 4, "c": 7, "d": 8},
        # ]
        ```

        Combine upstream steps outputs in a pipeline:

        ```python
        from distilabel.pipeline import Pipeline
        from distilabel.steps import CombineOutputs

        with Pipeline() as pipeline:
            step_1 = ...
            step_2 = ...
            step_3 = ...
            combine = CombineOutputs()

            [step_1, step_2, step_3] >> combine
        ```
    """

    def process(self, *inputs: StepInput) -> "StepOutput":
        """Merge the rows found at the same position of every input batch.

        Rows are paired positionally with `zip`, so pairing stops at the shortest
        input. When the same column appears in more than one input, the value coming
        from the later input overwrites the earlier one.

        Args:
            *inputs: One batch (list of dictionaries) per upstream step.

        Yields:
            A single list with one merged dictionary per row position.
        """
        merged_rows = []
        for row_group in zip(*inputs):
            merged_row = {}
            has_metadata = False
            for row in row_group:
                for column, value in row.items():
                    if column == DISTILABEL_METADATA_KEY:
                        # Metadata is not copied as a regular column; it is merged
                        # separately once every row of the group has been visited.
                        has_metadata = True
                    else:
                        merged_row[column] = value

            # Only emit the metadata column when at least one upstream row carried it.
            if has_metadata:
                merged_row[DISTILABEL_METADATA_KEY] = merge_distilabel_metadata(
                    *row_group
                )
            merged_rows.append(merged_row)

        yield merged_rows
+ Categories: + - columns + Examples: - Combine columns of a dataset: + + Group columns of a dataset: ```python from distilabel.steps import GroupColumns diff --git a/src/distilabel/steps/columns/keep.py b/src/distilabel/steps/columns/keep.py index 88ae7d540..c12dfdd61 100644 --- a/src/distilabel/steps/columns/keep.py +++ b/src/distilabel/steps/columns/keep.py @@ -44,6 +44,9 @@ class KeepColumns(Step): Output columns: - dynamic (determined by `columns` attribute): The columns that were kept. + Categories: + - columns + Examples: Select the columns to keep: diff --git a/src/distilabel/steps/columns/merge.py b/src/distilabel/steps/columns/merge.py index 802b17a7d..54ab3e3c7 100644 --- a/src/distilabel/steps/columns/merge.py +++ b/src/distilabel/steps/columns/merge.py @@ -16,8 +16,8 @@ from typing_extensions import override -from distilabel.pipeline.utils import merge_columns from distilabel.steps.base import Step, StepInput +from distilabel.steps.columns.utils import merge_columns if TYPE_CHECKING: from distilabel.steps.typing import StepColumns, StepOutput @@ -47,6 +47,9 @@ class MergeColumns(Step): - dynamic (determined by `columns` and `output_column` attributes): The columns that were merged. + Categories: + - columns + Examples: Combine columns in rows of a dataset: diff --git a/src/distilabel/pipeline/utils.py b/src/distilabel/steps/columns/utils.py similarity index 66% rename from src/distilabel/pipeline/utils.py rename to src/distilabel/steps/columns/utils.py index b5bd9e1b0..7b3efe226 100644 --- a/src/distilabel/pipeline/utils.py +++ b/src/distilabel/steps/columns/utils.py @@ -12,16 +12,47 @@ # See the License for the specific language governing permissions and # limitations under the License. 
def merge_distilabel_metadata(*output_dicts: Dict[str, Any]) -> Dict[str, Any]:
    """Merge the `DISTILABEL_METADATA_KEY` entries of multiple output dictionaries.

    The values found under the same metadata key are gathered following the order of
    `output_dicts`. A metadata key that shows up in only one dictionary keeps its
    value untouched, while a key that shows up in several dictionaries is mapped to
    the list of all its values.

    Args:
        *output_dicts: Variable number of dictionaries that may contain distilabel
            metadata. Dictionaries without the `DISTILABEL_METADATA_KEY` are ignored.

    Returns:
        A dictionary with the merged distilabel metadata of the input dictionaries.
    """
    collected: Dict[str, List[Any]] = {}
    for row in output_dicts:
        for name, value in row.get(DISTILABEL_METADATA_KEY, {}).items():
            collected.setdefault(name, []).append(value)

    # Unwrap single occurrences so that non-conflicting keys keep their original value.
    return {
        name: values[0] if len(values) == 1 else values
        for name, values in collected.items()
    }
@@ -49,16 +80,30 @@ def group_columns( # Use zip to iterate over lists based on their index for dicts_at_index in zip(*inputs): combined_dict = {} + metadata_dicts = [] # Iterate over dicts at the same index for d in dicts_at_index: + # Extract metadata for merging + if DISTILABEL_METADATA_KEY in d: + metadata_dicts.append( + {DISTILABEL_METADATA_KEY: d[DISTILABEL_METADATA_KEY]} + ) # Iterate over key-value pairs in each dict for key, value in d.items(): + if key == DISTILABEL_METADATA_KEY: + continue # If the key is in the merge_keys, append the value to the existing list if key in group_columns_dict.keys(): combined_dict.setdefault(group_columns_dict[key], []).append(value) # If the key is not in the merge_keys, create a new key-value pair else: combined_dict[key] = value + + if metadata_dicts: + combined_dict[DISTILABEL_METADATA_KEY] = merge_distilabel_metadata( + *metadata_dicts + ) + result.append(combined_dict) return result diff --git a/src/distilabel/steps/embeddings/embedding_generation.py b/src/distilabel/steps/embeddings/embedding_generation.py index 30dff63ef..8db3bee2e 100644 --- a/src/distilabel/steps/embeddings/embedding_generation.py +++ b/src/distilabel/steps/embeddings/embedding_generation.py @@ -36,6 +36,9 @@ class EmbeddingGeneration(Step): Output columns: - embedding (`List[Union[float, int]]`): the generated sentence embedding. 
+ Categories: + - embedding + Examples: Generate sentence embeddings with Sentence Transformers: diff --git a/src/distilabel/utils/mkdocs/components_gallery.py b/src/distilabel/utils/mkdocs/components_gallery.py index d1637d87c..a7dba7e7d 100644 --- a/src/distilabel/utils/mkdocs/components_gallery.py +++ b/src/distilabel/utils/mkdocs/components_gallery.py @@ -86,6 +86,7 @@ "scorer": ":octicons-number-16:", "text-generation": ":material-text-box-edit:", "text-manipulation": ":material-receipt-text-edit:", + "columns": ":material-table-column:", } diff --git a/tests/unit/steps/columns/__init__.py b/tests/unit/steps/columns/__init__.py new file mode 100644 index 000000000..20ce00bda --- /dev/null +++ b/tests/unit/steps/columns/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/unit/steps/columns/test_combine.py b/tests/unit/steps/columns/test_combine.py new file mode 100644 index 000000000..817d89e90 --- /dev/null +++ b/tests/unit/steps/columns/test_combine.py @@ -0,0 +1,54 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class TestCombineOutputs:
    """Unit tests for the `CombineOutputs` step."""

    def test_process(self) -> None:
        """Columns of both batches end up in one row and their metadata is merged."""
        first_batch = [
            {
                "a": 1,
                "b": 2,
                DISTILABEL_METADATA_KEY: {"model": "model-1", "a": 1},
            }
        ]
        second_batch = [
            {
                "c": 3,
                "d": 4,
                DISTILABEL_METADATA_KEY: {"model": "model-2", "b": 1},
            }
        ]
        # "model" appears in both metadata dictionaries, so its values are listed;
        # "a" and "b" appear once each and keep their plain values.
        expected = [
            {
                "a": 1,
                "b": 2,
                "c": 3,
                "d": 4,
                DISTILABEL_METADATA_KEY: {
                    "model": ["model-1", "model-2"],
                    "a": 1,
                    "b": 1,
                },
            }
        ]

        step = CombineOutputs()
        batches = step.process(first_batch, second_batch)

        assert next(batches) == expected
["model-1", "model-2"]}, + } + ] def test_CombineColumns_deprecation_warning():