Skip to content

Commit

Permalink
Refactor TruncateRow to TruncateTextColumn
Browse files Browse the repository at this point in the history
  • Loading branch information
plaguss committed Aug 14, 2024
1 parent ab8e864 commit 8dd2790
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 11 deletions.
4 changes: 2 additions & 2 deletions src/distilabel/steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from distilabel.steps.generators.utils import make_generator_step
from distilabel.steps.globals.huggingface import PushToHub
from distilabel.steps.reward_model import RewardModelScore
from distilabel.steps.truncate import TruncateRow
from distilabel.steps.truncate import TruncateTextColumn
from distilabel.steps.typing import GeneratorStepOutput, StepOutput

__all__ = [
Expand Down Expand Up @@ -78,7 +78,7 @@
"Step",
"StepInput",
"RewardModelScore",
"TruncateRow",
"TruncateTextColumn",
"GeneratorStepOutput",
"StepOutput",
"step",
Expand Down
14 changes: 7 additions & 7 deletions src/distilabel/steps/truncate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@
from distilabel.steps.typing import StepOutput


class TruncateRow(Step):
class TruncateTextColumn(Step):
"""Truncate a row using a tokenizer or the number of characters.
`TruncateRow` is a `Step` that truncates a row according to the max length. If
`TruncateTextColumn` is a `Step` that truncates a row according to the max length. If
the `tokenizer` is provided, then the row will be truncated using the tokenizer,
and the `max_length` will be used as the maximum number of tokens, otherwise it will
be used as the maximum number of characters. The `TruncateRow` step is useful when one
be used as the maximum number of characters. The `TruncateTextColumn` step is useful when one
wants to truncate a row to a certain length, to avoid posterior errors in the model due
to the length.
Expand All @@ -55,9 +55,9 @@ class TruncateRow(Step):
Truncating a row to a given number of tokens:
```python
from distilabel.steps import TruncateRow
from distilabel.steps import TruncateTextColumn
trunc = TruncateRow(
trunc = TruncateTextColumn(
tokenizer="meta-llama/Meta-Llama-3.1-70B-Instruct",
max_length=4,
column="text"
Expand All @@ -79,9 +79,9 @@ class TruncateRow(Step):
Truncating a row to a given number of characters:
```python
from distilabel.steps import TruncateRow
from distilabel.steps import TruncateTextColumn
trunc = TruncateRow(max_length=10)
trunc = TruncateTextColumn(max_length=10)
trunc.load()
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/steps/test_truncate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import Optional

import pytest
from distilabel.steps.truncate import TruncateRow
from distilabel.steps.truncate import TruncateTextColumn


@pytest.mark.parametrize(
Expand All @@ -38,7 +38,9 @@
def test_truncate_row(
max_length: int, text: str, tokenizer: Optional[str], expected: str
) -> None:
trunc = TruncateRow(column="text", max_length=max_length, tokenizer=tokenizer)
trunc = TruncateTextColumn(
column="text", max_length=max_length, tokenizer=tokenizer
)
trunc.load()

assert next(trunc.process([{"text": text}])) == [{"text": expected}]

0 comments on commit 8dd2790

Please sign in to comment.