From 8dd2790d8240b12e508e52e05f3a04c7a8e48a9a Mon Sep 17 00:00:00 2001 From: plaguss Date: Wed, 14 Aug 2024 16:36:22 +0200 Subject: [PATCH] Refactor TruncateRow to TruncateTextColumn --- src/distilabel/steps/__init__.py | 4 ++-- src/distilabel/steps/truncate.py | 14 +++++++------- tests/unit/steps/test_truncate.py | 6 ++++-- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/distilabel/steps/__init__.py b/src/distilabel/steps/__init__.py index 3627162e5..420334bca 100644 --- a/src/distilabel/steps/__init__.py +++ b/src/distilabel/steps/__init__.py @@ -47,7 +47,7 @@ from distilabel.steps.generators.utils import make_generator_step from distilabel.steps.globals.huggingface import PushToHub from distilabel.steps.reward_model import RewardModelScore -from distilabel.steps.truncate import TruncateRow +from distilabel.steps.truncate import TruncateTextColumn from distilabel.steps.typing import GeneratorStepOutput, StepOutput __all__ = [ @@ -78,7 +78,7 @@ "Step", "StepInput", "RewardModelScore", - "TruncateRow", + "TruncateTextColumn", "GeneratorStepOutput", "StepOutput", "step", diff --git a/src/distilabel/steps/truncate.py b/src/distilabel/steps/truncate.py index 007f6b34d..bb1785b8a 100644 --- a/src/distilabel/steps/truncate.py +++ b/src/distilabel/steps/truncate.py @@ -23,13 +23,13 @@ from distilabel.steps.typing import StepOutput -class TruncateRow(Step): +class TruncateTextColumn(Step): """Truncate a row using a tokenizer or the number of characters. - `TruncateRow` is a `Step` that truncates a row according to the max length. If + `TruncateTextColumn` is a `Step` that truncates a row according to the max length. If the `tokenizer` is provided, then the row will be truncated using the tokenizer, and the `max_length` will be used as the maximum number of tokens, otherwise it will - be used as the maximum number of characters. The `TruncateRow` step is useful when one + be used as the maximum number of characters. The `TruncateTextColumn` step is useful when one wants to truncate a row to a certain length, to avoid posterior errors in the model due to the length. @@ -55,9 +55,9 @@ class TruncateRow(Step): Truncating a row to a given number of tokens: ```python - from distilabel.steps import TruncateRow + from distilabel.steps import TruncateTextColumn - trunc = TruncateRow( + trunc = TruncateTextColumn( tokenizer="meta-llama/Meta-Llama-3.1-70B-Instruct", max_length=4, column="text" @@ -79,9 +79,9 @@ class TruncateRow(Step): Truncating a row to a given number of characters: ```python - from distilabel.steps import TruncateRow + from distilabel.steps import TruncateTextColumn - trunc = TruncateRow(max_length=10) + trunc = TruncateTextColumn(max_length=10) trunc.load() diff --git a/tests/unit/steps/test_truncate.py b/tests/unit/steps/test_truncate.py index 9512c1966..b26808033 100644 --- a/tests/unit/steps/test_truncate.py +++ b/tests/unit/steps/test_truncate.py @@ -15,7 +15,7 @@ from typing import Optional import pytest -from distilabel.steps.truncate import TruncateRow +from distilabel.steps.truncate import TruncateTextColumn @pytest.mark.parametrize( @@ -38,7 +38,9 @@ def test_truncate_row( max_length: int, text: str, tokenizer: Optional[str], expected: str ) -> None: - trunc = TruncateRow(column="text", max_length=max_length, tokenizer=tokenizer) + trunc = TruncateTextColumn( + column="text", max_length=max_length, tokenizer=tokenizer + ) trunc.load() assert next(trunc.process([{"text": text}])) == [{"text": expected}]