Fix default structured output #892

Merged: 5 commits, Aug 13, 2024
13 changes: 12 additions & 1 deletion src/distilabel/steps/tasks/base.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Union

@@ -63,7 +64,7 @@ class _Task(_Step, ABC):
    num_generations: RuntimeParameter[int] = Field(
        default=1, description="The number of generations to be produced per input."
    )
-    use_default_structured_output: bool = True
+    use_default_structured_output: bool = False

    def load(self) -> None:
        """Loads the LLM via the `LLM.load()` method."""
@@ -173,14 +174,24 @@ def _set_default_structured_output(self) -> None:
        from distilabel.llms import InferenceEndpointsLLM
        from distilabel.llms.base import AsyncLLM

        def check_dependency(module_name: str) -> None:
            if not importlib.util.find_spec(module_name):
                raise ImportError(
                    f"`{module_name}` is not installed and is needed for the structured generation with this LLM."
                    f" Please install it using `pip install {module_name}`."
                )

        dependency = "outlines"
        structured_output = {"schema": schema}
        # To determine instructor or outlines format
        if not (
            isinstance(self.llm, AsyncLLM)
            and not isinstance(self.llm, InferenceEndpointsLLM)
        ):
            dependency = "instructor"
            structured_output.update({"format": "json"})

        check_dependency(dependency)
        self.llm.structured_output = structured_output

    def get_structured_output(self) -> Union[Dict[str, Any], None]:
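The upshot of this file: default structured output becomes opt-in (`False`), and enabling it now fails fast with an actionable error instead of a late import failure. Below is a minimal sketch of the opt-in path; `TextGeneration`, `InferenceEndpointsLLM`, and `Pipeline` are distilabel's public classes, while the pipeline name and model ID are illustrative assumptions:

```python
from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import TextGeneration

with Pipeline(name="structured-output-example") as pipeline:
    task = TextGeneration(
        llm=InferenceEndpointsLLM(model_id="meta-llama/Meta-Llama-3-8B-Instruct"),
        # Defaults to False after this change; it previously defaulted to True,
        # implicitly requiring `outlines` or `instructor` for every task.
        use_default_structured_output=True,
    )

# For tasks that define a default schema, `load()` applies it via
# `_set_default_structured_output()`, which picks `outlines` or `instructor`
# depending on the LLM type and raises an ImportError with a `pip install`
# hint when that package is missing.
task.load()
```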
2 changes: 2 additions & 0 deletions tests/unit/steps/tasks/evol_instruct/test_base.py
@@ -134,6 +134,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"input_batch_size": task.input_batch_size,
"llm": {
"generation_kwargs": {},
"structured_output": None,
"type_info": {
"module": task.llm.__module__,
"name": task.llm.__class__.__name__,
@@ -152,6 +153,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"INCREASED_REASONING_STEPS": "I want you act as a Prompt Rewriter.\n\nYour objective is to rewrite a given prompt into a more complex version to make those famous AI systems (e.g., chatgpt and GPT4) a bit harder to handle.\n\nBut the rewritten prompt must be reasonable and must be understood and responded by humans.\n\nYour rewriting cannot omit the non-text parts such as the table and code in #The Given Prompt#:. Also, please do not omit the input in #The Given Prompt#.\n\nYou SHOULD complicate the given prompt using the following method: \nIf #The Given Prompt# can be solved with just a few simple thinking processes, you can rewrite it to explicitly request multiple-step reasoning.\n\nYou should try your best not to make the #Rewritten Prompt# become verbose, #Rewritten Prompt# can only add 10 to 20 words into #The Given Prompt#.\n\n'#The Given Prompt#', '#Rewritten Prompt#', 'given prompt' and 'rewritten prompt' are not allowed to appear in #Rewritten Prompt#\n\n#The Given Prompt#:\n<PROMPT>\n#Rewritten Prompt#:\n\n",
"BREADTH": "I want you act as a Prompt Creator.\n\nYour goal is to draw inspiration from the #Given Prompt# to create a brand new prompt.\n\nThis new prompt should belong to the same domain as the #Given Prompt# but be even more rare.\n\nThe LENGTH and complexity of the #Created Prompt# should be similar to that of the #Given Prompt#.\n\nThe #Created Prompt# must be reasonable and must be understood and responded by humans.\n\n'#Given Prompt#', '#Created Prompt#', 'given prompt' and 'created prompt' are not allowed to appear in #Created Prompt#\n\n#Given Prompt#:\n<PROMPT>\n#Created Prompt#:\n\n",
},
"use_default_structured_output": False,
"seed": task.seed,
"runtime_parameters_info": [
{
2 changes: 2 additions & 0 deletions tests/unit/steps/tasks/evol_instruct/test_generator.py
@@ -117,6 +117,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"name": "task",
"llm": {
"generation_kwargs": {},
"structured_output": None,
"type_info": {
"module": task.llm.__class__.__module__,
"name": task.llm.__class__.__name__,
@@ -148,6 +149,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"min_length": task.min_length,
"max_length": task.max_length,
"seed": task.seed,
"use_default_structured_output": False,
"runtime_parameters_info": [
{
"name": "resources",
2 changes: 2 additions & 0 deletions tests/unit/steps/tasks/evol_quality/test_base.py
@@ -105,6 +105,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"input_batch_size": task.input_batch_size,
"llm": {
"generation_kwargs": {},
"structured_output": None,
"type_info": {
"module": task.llm.__module__,
"name": task.llm.__class__.__name__,
@@ -117,6 +118,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"group_generations": task.group_generations,
"include_original_response": task.include_original_response,
"seed": task.seed,
"use_default_structured_output": False,
"runtime_parameters_info": [
{
"name": "resources",
1 change: 1 addition & 0 deletions tests/unit/steps/tasks/magpie/test_base.py
@@ -423,6 +423,7 @@ def test_serialization(self) -> None:
"group_generations": False,
"add_raw_output": True,
"num_generations": 1,
"use_default_structured_output": False,
"runtime_parameters_info": [
{
"name": "llm",
1 change: 1 addition & 0 deletions tests/unit/steps/tasks/magpie/test_generator.py
@@ -80,6 +80,7 @@ def test_serialization(self) -> None:
"add_raw_output": True,
"num_generations": 1,
"num_rows": None,
"use_default_structured_output": False,
"runtime_parameters_info": [
{
"name": "llm",
2 changes: 2 additions & 0 deletions tests/unit/steps/tasks/test_base.py
@@ -320,6 +320,7 @@ def test_serialization(self) -> None:
"input_batch_size": 50,
"llm": {
"generation_kwargs": {},
"structured_output": None,
"type_info": {
"module": "tests.unit.conftest",
"name": "DummyLLM",
@@ -389,6 +390,7 @@ def test_serialization(self) -> None:
"module": "tests.unit.steps.tasks.test_base",
"name": "DummyTask",
},
"use_default_structured_output": False,
}

with Pipeline(name="unit-test-pipeline") as pipeline:
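All of the test updates above pin down the same two serialization changes: tasks now dump the `use_default_structured_output` flag with its new `False` default, and LLMs dump a `structured_output` entry (`None` when unset). A minimal sketch of that shared pattern, under the same illustrative names as the sketch above:

```python
from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import TextGeneration

with Pipeline(name="serialization-example") as pipeline:
    task = TextGeneration(
        llm=InferenceEndpointsLLM(model_id="meta-llama/Meta-Llama-3-8B-Instruct"),
    )

dump = task.dump()
# Task-level flag, serialized with its new default:
assert dump["use_default_structured_output"] is False
# LLMs now serialize `structured_output` too (None when not configured):
assert dump["llm"]["structured_output"] is None
```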