diff --git a/.github/workflows/docs-pr-close.yml b/.github/workflows/docs-pr-close.yml index 71f4e5ff9..61008bcee 100644 --- a/.github/workflows/docs-pr-close.yml +++ b/.github/workflows/docs-pr-close.yml @@ -8,6 +8,10 @@ concurrency: group: distilabel-docs cancel-in-progress: false +permissions: + contents: write + pull-requests: write + jobs: cleanup: runs-on: ubuntu-latest diff --git a/.github/workflows/docs-pr.yml b/.github/workflows/docs-pr.yml index 48c7236a5..ec963ccf9 100644 --- a/.github/workflows/docs-pr.yml +++ b/.github/workflows/docs-pr.yml @@ -10,6 +10,10 @@ concurrency: group: distilabel-docs cancel-in-progress: false +permissions: + contents: write + pull-requests: write + jobs: publish: runs-on: ubuntu-latest diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index dd59a5129..93a17408e 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -12,6 +12,10 @@ concurrency: group: distilabel-docs cancel-in-progress: false +permissions: + contents: write + pull-requests: write + jobs: publish: runs-on: ubuntu-latest diff --git a/src/distilabel/__init__.py b/src/distilabel/__init__.py index f6ca72cd1..11a837825 100644 --- a/src/distilabel/__init__.py +++ b/src/distilabel/__init__.py @@ -14,6 +14,6 @@ from rich import traceback as rich_traceback -__version__ = "1.4.1" +__version__ = "1.4.2" rich_traceback.install(show_locals=True) diff --git a/src/distilabel/llms/huggingface/transformers.py b/src/distilabel/llms/huggingface/transformers.py index 27ab00e5b..a0582b415 100644 --- a/src/distilabel/llms/huggingface/transformers.py +++ b/src/distilabel/llms/huggingface/transformers.py @@ -174,7 +174,7 @@ def prepare_input(self, input: "StandardInput") -> str: Returns: The prompt to send to the LLM. """ - if self._pipeline.tokenizer.chat_template: # type: ignore + if self._pipeline.tokenizer.chat_template is None: # type: ignore return input[0]["content"] prompt: str = ( diff --git a/tests/unit/llms/huggingface/test_transformers.py b/tests/unit/llms/huggingface/test_transformers.py index 97214ef5f..79d6089f7 100644 --- a/tests/unit/llms/huggingface/test_transformers.py +++ b/tests/unit/llms/huggingface/test_transformers.py @@ -40,6 +40,21 @@ def test_model_name(self, transformers_llm: TransformersLLM) -> None: == "distilabel-internal-testing/tiny-random-mistral" ) + def test_prepare_input(self, transformers_llm: TransformersLLM) -> None: + assert ( + transformers_llm.prepare_input([{"role": "user", "content": "Hello"}]) + == " [INST] Hello [/INST]" + ) + + def test_prepare_input_no_chat_template( + self, transformers_llm: TransformersLLM + ) -> None: + transformers_llm._pipeline.tokenizer.chat_template = None + assert ( + transformers_llm.prepare_input([{"role": "user", "content": "Hello"}]) + == "Hello" + ) + def test_generate(self, transformers_llm: TransformersLLM) -> None: responses = transformers_llm.generate( inputs=[ diff --git a/tests/unit/steps/argilla/test_preference.py b/tests/unit/steps/argilla/test_preference.py index ab63ee541..1c99f2f5c 100644 --- a/tests/unit/steps/argilla/test_preference.py +++ b/tests/unit/steps/argilla/test_preference.py @@ -83,13 +83,23 @@ def test_process(self, mock_dataset) -> None: ) with patch.object(PreferenceToArgilla, "load"): step.load() + step._instruction = "instruction" step._generations = "generations" + step._ratings = "ratings" + step._rationales = "rationales" step._dataset = mock_dataset # type: ignore step._dataset.records.log = lambda x: x # type: ignore assert list( - step.process([{"instruction": "test", "generations": ["test", "test"]}]) + step.process( + [ + { + "instruction": "test", + "generations": ["test", "test"], + } + ] + ) ) == [[{"instruction": "test", "generations": ["test", "test"]}]] assert step._dataset.records # type: ignore diff --git a/tests/unit/steps/tasks/structured_outputs/test_outlines.py b/tests/unit/steps/tasks/structured_outputs/test_outlines.py index d2be053aa..ecdfd0424 100644 --- a/tests/unit/steps/tasks/structured_outputs/test_outlines.py +++ b/tests/unit/steps/tasks/structured_outputs/test_outlines.py @@ -101,6 +101,9 @@ class DummyUserTest(BaseModel): } +@pytest.mark.skip( + reason="won't work until we update our code to work with `outlines>0.1.0`" +) class TestOutlinesIntegration: @pytest.mark.parametrize( "format, schema, prompt",