Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CombineOutputs step #939

Merged
merged 10 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions docs/api/pipeline/utils.md

This file was deleted.

1 change: 1 addition & 0 deletions docs/api/step_gallery/columns.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ This section contains the existing steps intended to be used for common column o
::: distilabel.steps.columns.keep
::: distilabel.steps.columns.merge
::: distilabel.steps.columns.group
::: distilabel.steps.columns.utils
5 changes: 2 additions & 3 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ plugins:
# Members
inherited_members: false # allow looking up inherited methods
members_order: source # order methods according to their order of definition in the source code, not alphabetical order
show_labels : true
show_labels: true
# Docstring
docstring_style: google # more info: https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html
docstring_style: google # more info: https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html
show_if_no_docstring: false
# Signature
separate_signature: false
Expand Down Expand Up @@ -240,7 +240,6 @@ nav:
- Routing Batch Function: "api/pipeline/routing_batch_function.md"
- Typing: "api/pipeline/typing.md"
- Step Wrapper: "api/pipeline/step_wrapper.md"
- Utils: "api/pipeline/utils.md"
- Mixins:
- RuntimeParametersMixin: "api/mixins/runtime_parameters.md"
- RequirementsMixin: "api/mixins/requirements.md"
Expand Down
22 changes: 12 additions & 10 deletions src/distilabel/steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
StepInput,
StepResources,
)
from distilabel.steps.columns.combine import CombineOutputs
from distilabel.steps.columns.expand import ExpandColumns
from distilabel.steps.columns.group import CombineColumns, GroupColumns
from distilabel.steps.columns.keep import KeepColumns
Expand Down Expand Up @@ -54,22 +55,26 @@
__all__ = [
"PreferenceToArgilla",
"TextGenerationToArgilla",
"GeneratorStep",
"GlobalStep",
"Step",
"StepInput",
"StepResources",
"CombineOutputs",
"ExpandColumns",
"CombineColumns",
"GroupColumns",
"KeepColumns",
"MergeColumns",
"CombineColumns",
"ConversationTemplate",
"step",
"DeitaFiltering",
"EmbeddingGeneration",
"FaissNearestNeighbour",
"ExpandColumns",
"ConversationTemplate",
"FormatChatGenerationDPO",
"FormatChatGenerationSFT",
"FormatTextGenerationDPO",
"FormatChatGenerationSFT",
"FormatTextGenerationSFT",
"GeneratorStep",
"GlobalStep",
"KeepColumns",
"LoadDataFromDicts",
"LoadDataFromDisk",
"LoadDataFromFileSystem",
Expand All @@ -78,11 +83,8 @@
"MinHashLSH",
"make_generator_step",
"PushToHub",
"Step",
"StepInput",
"RewardModelScore",
"TruncateTextColumn",
"GeneratorStepOutput",
"StepOutput",
"step",
]
99 changes: 99 additions & 0 deletions src/distilabel/steps/columns/combine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from distilabel.constants import DISTILABEL_METADATA_KEY
from distilabel.steps.base import Step, StepInput
from distilabel.steps.columns.utils import merge_distilabel_metadata

if TYPE_CHECKING:
from distilabel.steps.typing import StepOutput


class CombineOutputs(Step):
"""Combine the outputs of several upstream steps.

`CombineOutputs` is a `Step` that takes the outputs of several upstream steps and combines
them to generate a new dictionary with all keys/columns of the upstream steps outputs.

Input columns:
- dynamic (based on the upstream `Step`s): All the columns of the upstream steps outputs.

Output columns:
- dynamic (based on the upstream `Step`s): All the columns of the upstream steps outputs.

Categories:
- columns

Examples:

Combine dictionaries of a dataset:

```python
from distilabel.steps import CombineOutputs

combine_outputs = CombineOutputs()
combine_outputs.load()

result = next(
combine_outputs.process(
[{"a": 1, "b": 2}, {"a": 3, "b": 4}],
[{"c": 5, "d": 6}, {"c": 7, "d": 8}],
)
)
# [
# {"a": 1, "b": 2, "c": 5, "d": 6},
# {"a": 3, "b": 4, "c": 7, "d": 8},
# ]
```

Combine upstream steps outputs in a pipeline:

```python
from distilabel.pipeline import Pipeline
from distilabel.steps import CombineOutputs

with Pipeline() as pipeline:
step_1 = ...
step_2 = ...
step_3 = ...
combine = CombineOutputs()

[step_1, step_2, step_3] >> combine
```
"""

def process(self, *inputs: StepInput) -> "StepOutput":
combined_outputs = []
for output_dicts in zip(*inputs):
combined_dict = {}
for output_dict in output_dicts:
combined_dict.update(
{
k: v
for k, v in output_dict.items()
if k != DISTILABEL_METADATA_KEY
}
)

if any(
DISTILABEL_METADATA_KEY in output_dict for output_dict in output_dicts
):
combined_dict[DISTILABEL_METADATA_KEY] = merge_distilabel_metadata(
*output_dicts
)
combined_outputs.append(combined_dict)

yield combined_outputs
3 changes: 3 additions & 0 deletions src/distilabel/steps/columns/expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ class ExpandColumns(Step):
Output columns:
- dynamic (determined by `columns` attribute): The expanded columns.

Categories:
- columns

Examples:
Expand the selected columns into multiple rows:

Expand Down
8 changes: 6 additions & 2 deletions src/distilabel/steps/columns/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

from typing_extensions import override

from distilabel.pipeline.utils import group_columns
from distilabel.steps.base import Step, StepInput
from distilabel.steps.columns.utils import group_columns

if TYPE_CHECKING:
from distilabel.steps.typing import StepColumns, StepOutput
Expand All @@ -43,8 +43,12 @@ class GroupColumns(Step):
- dynamic (determined by `columns` and `output_columns` attributes): The columns
that were grouped.

Categories:
- columns

Examples:
Combine columns of a dataset:

Group columns of a dataset:

```python
from distilabel.steps import GroupColumns
Expand Down
3 changes: 3 additions & 0 deletions src/distilabel/steps/columns/keep.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ class KeepColumns(Step):
Output columns:
- dynamic (determined by `columns` attribute): The columns that were kept.

Categories:
- columns

Examples:
Select the columns to keep:

Expand Down
5 changes: 4 additions & 1 deletion src/distilabel/steps/columns/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

from typing_extensions import override

from distilabel.pipeline.utils import merge_columns
from distilabel.steps.base import Step, StepInput
from distilabel.steps.columns.utils import merge_columns

if TYPE_CHECKING:
from distilabel.steps.typing import StepColumns, StepOutput
Expand Down Expand Up @@ -47,6 +47,9 @@ class MergeColumns(Step):
- dynamic (determined by `columns` and `output_column` attributes): The columns
that were merged.

Categories:
- columns

Examples:
Combine columns in rows of a dataset:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,47 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from distilabel.steps.base import StepInput
from distilabel.constants import DISTILABEL_METADATA_KEY

if TYPE_CHECKING:
from distilabel.steps.base import StepInput


def merge_distilabel_metadata(*output_dicts: Dict[str, Any]) -> Dict[str, Any]:
"""
Merge the `DISTILABEL_METADATA_KEY` from multiple output dictionaries.

Args:
*output_dicts: Variable number of dictionaries containing distilabel metadata.

Returns:
A merged dictionary containing all the distilabel metadata from the input dictionaries.
"""
merged_metadata = defaultdict(list)

for output_dict in output_dicts:
metadata = output_dict.get(DISTILABEL_METADATA_KEY, {})
for key, value in metadata.items():
merged_metadata[key].append(value)

final_metadata = {}
for key, value_list in merged_metadata.items():
if len(value_list) == 1:
final_metadata[key] = value_list[0]
else:
final_metadata[key] = value_list

return final_metadata


def group_columns(
*inputs: StepInput,
*inputs: "StepInput",
group_columns: List[str],
output_group_columns: Optional[List[str]] = None,
) -> StepInput:
) -> "StepInput":
"""Groups multiple list of dictionaries into a single list of dictionaries on the
specified `group_columns`. If `group_columns` are provided, then it will also rename
`group_columns`.
Expand Down Expand Up @@ -49,16 +80,30 @@ def group_columns(
# Use zip to iterate over lists based on their index
for dicts_at_index in zip(*inputs):
combined_dict = {}
metadata_dicts = []
# Iterate over dicts at the same index
for d in dicts_at_index:
# Extract metadata for merging
if DISTILABEL_METADATA_KEY in d:
metadata_dicts.append(
{DISTILABEL_METADATA_KEY: d[DISTILABEL_METADATA_KEY]}
)
# Iterate over key-value pairs in each dict
for key, value in d.items():
if key == DISTILABEL_METADATA_KEY:
continue
# If the key is in the merge_keys, append the value to the existing list
if key in group_columns_dict.keys():
combined_dict.setdefault(group_columns_dict[key], []).append(value)
# If the key is not in the merge_keys, create a new key-value pair
else:
combined_dict[key] = value

if metadata_dicts:
combined_dict[DISTILABEL_METADATA_KEY] = merge_distilabel_metadata(
*metadata_dicts
)

result.append(combined_dict)
return result

Expand Down
3 changes: 3 additions & 0 deletions src/distilabel/steps/embeddings/embedding_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class EmbeddingGeneration(Step):
Output columns:
- embedding (`List[Union[float, int]]`): the generated sentence embedding.

Categories:
- embedding

Examples:
Generate sentence embeddings with Sentence Transformers:

Expand Down
1 change: 1 addition & 0 deletions src/distilabel/utils/mkdocs/components_gallery.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
"scorer": ":octicons-number-16:",
"text-generation": ":material-text-box-edit:",
"text-manipulation": ":material-receipt-text-edit:",
"columns": ":material-table-column:",
}


Expand Down
14 changes: 14 additions & 0 deletions tests/unit/steps/columns/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Loading
Loading