Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update to outlines010 #1092

Merged
merged 42 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
329b645
add outlines 0.1.0 support
davidberenstein1957 Jan 9, 2025
9dd4be9
update tests
davidberenstein1957 Jan 9, 2025
3ce1ff3
fix passing tokenizer to regex processor as well
davidberenstein1957 Jan 9, 2025
d8d7b35
fix test by specifically passing None as token to transformersllm
davidberenstein1957 Jan 9, 2025
2e0b42c
fix tests by increasing the temperature to avoid exploding beam sear…
davidberenstein1957 Jan 9, 2025
5ee7dce
fix logit processor assignment during generation
davidberenstein1957 Jan 9, 2025
0d26a1e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
47e38dc
add support transformers
davidberenstein1957 Jan 9, 2025
66ac934
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
61c3538
remove duplicate import
davidberenstein1957 Jan 9, 2025
a3b4f9c
Merge branch 'develop' into feat/1081-feature-update-to-outlines010
davidberenstein1957 Jan 9, 2025
0738b27
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
8e6613b
remove duplicate
davidberenstein1957 Jan 9, 2025
7db1b0b
Merge branch 'feat/1081-feature-update-to-outlines010' of https://git…
davidberenstein1957 Jan 9, 2025
cb4c2ce
remove duplicate import
davidberenstein1957 Jan 9, 2025
7f20d9f
return content when no chat template is present
davidberenstein1957 Jan 9, 2025
61aa597
refactor clean code
davidberenstein1957 Jan 9, 2025
b994f06
chore refactor
davidberenstein1957 Jan 9, 2025
a47963d
refactor logic if else statement
davidberenstein1957 Jan 9, 2025
a0f8acd
fix import when outlines is not present
davidberenstein1957 Jan 9, 2025
b41d6f0
chore pin transformers version
davidberenstein1957 Jan 9, 2025
d2fdd4c
chore add context w.r.t. logit processor
davidberenstein1957 Jan 9, 2025
2b8f634
chore bump version
davidberenstein1957 Jan 9, 2025
ed5f00f
add simplification of transformers implementation
davidberenstein1957 Jan 9, 2025
473de03
Update .gitignore to exclude .DS_Store files and remove vllm subproje…
davidberenstein1957 Jan 10, 2025
995e4d4
Refactor outlines version check and logits processor handling
davidberenstein1957 Jan 10, 2025
5960441
Refactor logits processor handling in LlamaCppLLM
davidberenstein1957 Jan 10, 2025
cfac574
Refactor outlines import and logits processor handling in Transformer…
davidberenstein1957 Jan 10, 2025
3378769
Refactor outlines version check and update function naming
davidberenstein1957 Jan 10, 2025
d56b6bc
Refactor processor handling in LlamaCppLLM and TransformersLLM based …
davidberenstein1957 Jan 10, 2025
110ecaf
Merge branch 'develop' into feat/1081-feature-update-to-outlines010
davidberenstein1957 Jan 10, 2025
4056f08
Refactor structured output return types in LlamaCppLLM, MlxLLM, and T…
davidberenstein1957 Jan 10, 2025
11a7957
Enhance MlxLLM integration and expand framework support
davidberenstein1957 Jan 10, 2025
e9fefc4
Refactor structured output handling in LlamaCppLLM and MlxLLM
davidberenstein1957 Jan 10, 2025
df24685
Refactor MlxLLM structured output handling and remove unused components
davidberenstein1957 Jan 10, 2025
65272bd
Refactor logits processor handling in TransformersLLM
davidberenstein1957 Jan 10, 2025
7fc1762
Merge branch 'develop' into feat/1081-feature-update-to-outlines010
davidberenstein1957 Jan 10, 2025
d2eda4e
Refactor type hints in outlines.py for improved clarity
davidberenstein1957 Jan 10, 2025
85494c4
Refactor type hint imports in outlines.py for improved clarity
davidberenstein1957 Jan 10, 2025
f6a50f0
Merge branch 'develop' into feat/1081-feature-update-to-outlines010
davidberenstein1957 Jan 10, 2025
01ea5f1
Refactor regex processor handling in prepare_guided_output function
davidberenstein1957 Jan 10, 2025
399154e
Update transformer dependency constraints in pyproject.toml
davidberenstein1957 Jan 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions src/distilabel/models/llms/huggingface/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from distilabel.utils.huggingface import HF_TOKEN_ENV_VAR

if TYPE_CHECKING:
from transformers import Pipeline
from transformers import LogitsProcessorList, Pipeline
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer

Expand Down Expand Up @@ -111,6 +111,7 @@ class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):

_pipeline: Optional["Pipeline"] = PrivateAttr(...)
_prefix_allowed_tokens_fn: Union[Callable, None] = PrivateAttr(default=None)
_logits_processor: Optional["LogitsProcessorList"] = PrivateAttr(default=None)

def load(self) -> None:
"""Loads the model and tokenizer and creates the text generation pipeline. In addition,
Expand All @@ -119,7 +120,7 @@ def load(self) -> None:
CudaDevicePlacementMixin.load(self)

try:
from transformers import pipeline
from transformers import LogitsProcessorList, pipeline
except ImportError as ie:
raise ImportError(
"Transformers is not installed. Please install it using `pip install transformers`."
Expand Down Expand Up @@ -149,10 +150,20 @@ def load(self) -> None:
self._pipeline.tokenizer.pad_token = self._pipeline.tokenizer.eos_token # type: ignore

if self.structured_output:
self._prefix_allowed_tokens_fn = self._prepare_structured_output(
self.structured_output
from distilabel.steps.tasks.structured_outputs.outlines import (
outlines_below_0_1_0,
)

if outlines_below_0_1_0:
self._prefix_allowed_tokens_fn = self._prepare_structured_output(
self.structured_output
)
else:
logits_processor = self._prepare_structured_output(
self.structured_output
)
self._logits_processor = LogitsProcessorList([logits_processor])

super().load()

def unload(self) -> None:
Expand Down Expand Up @@ -222,7 +233,7 @@ def generate( # type: ignore
"""
prepared_inputs = [self.prepare_input(input=input) for input in inputs]

outputs: List[List[Dict[str, str]]] = self._pipeline( # type: ignore
outputs: List[List[Dict[str, str]]] = self._pipeline(
prepared_inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
Expand All @@ -232,7 +243,8 @@ def generate( # type: ignore
do_sample=do_sample,
num_return_sequences=num_generations,
prefix_allowed_tokens_fn=self._prefix_allowed_tokens_fn,
pad_token_id=self._pipeline.tokenizer.eos_token_id, # type: ignore
logits_processor=self._logits_processor,
pad_token_id=self._pipeline.tokenizer.eos_token_id,
)
llm_output = [
[generation["generated_text"] for generation in output]
Expand Down
73 changes: 53 additions & 20 deletions src/distilabel/steps/tasks/structured_outputs/outlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
get_args,
)

import pkg_resources
from pydantic import BaseModel

from distilabel.errors import DistilabelUserError
Expand All @@ -36,7 +37,11 @@
from distilabel.steps.tasks.typing import OutlinesStructuredOutputType

Frameworks = Literal["transformers", "llamacpp", "vllm"]
"""Available frameworks for the structured output configuration. """
# Available frameworks for the structured output configuration.
_outlines_version = pkg_resources.get_distribution("outlines").version
outlines_below_0_1_0 = pkg_resources.parse_version(
_outlines_version
) < pkg_resources.parse_version("0.1.0")


def model_to_schema(schema: Type[BaseModel]) -> Dict[str, Any]:
Expand All @@ -46,31 +51,56 @@ def model_to_schema(schema: Type[BaseModel]) -> Dict[str, Any]:

def _get_logits_processor(framework: Frameworks) -> Tuple[Callable, Callable]:
"""Helper function to return the appropriate logits processor for the given framework."""
if framework == "transformers":
from outlines.integrations.transformers import (
JSONPrefixAllowedTokens,
RegexPrefixAllowedTokens,
if framework not in Frameworks.__args__:
davidberenstein1957 marked this conversation as resolved.
Show resolved Hide resolved
raise DistilabelUserError(
f"Invalid framework '{framework}'. Must be one of {get_args(Frameworks)}",
page="sections/how_to_guides/advanced/structured_generation/",
)

return JSONPrefixAllowedTokens, RegexPrefixAllowedTokens
if outlines_below_0_1_0:
if framework == "transformers":
from outlines.integrations.transformers import (
JSONPrefixAllowedTokens,
RegexPrefixAllowedTokens,
)

if framework == "llamacpp":
from outlines.integrations.llamacpp import (
JSONLogitsProcessor,
RegexLogitsProcessor,
)
return JSONPrefixAllowedTokens, RegexPrefixAllowedTokens

return JSONLogitsProcessor, RegexLogitsProcessor
if framework == "llamacpp":
from outlines.integrations.llamacpp import (
JSONLogitsProcessor,
RegexLogitsProcessor,
)

return JSONLogitsProcessor, RegexLogitsProcessor

if framework == "vllm":
from outlines.integrations.vllm import JSONLogitsProcessor, RegexLogitsProcessor
if framework == "vllm":
from outlines.integrations.vllm import (
JSONLogitsProcessor,
RegexLogitsProcessor,
)

return JSONLogitsProcessor, RegexLogitsProcessor
else:
from outlines.processors import JSONLogitsProcessor, RegexLogitsProcessor

return JSONLogitsProcessor, RegexLogitsProcessor

raise DistilabelUserError(
f"Invalid framework '{framework}'. Must be one of {get_args(Frameworks)}",
page="sections/how_to_guides/advanced/structured_generation/",
)

def _get_outlines_tokenizer_or_model(llm: Any, framework: Frameworks) -> Callable:
davidberenstein1957 marked this conversation as resolved.
Show resolved Hide resolved
if outlines_below_0_1_0:
return llm
else:
if framework == "llamacpp":
davidberenstein1957 marked this conversation as resolved.
Show resolved Hide resolved
from outlines.models.llamacpp import LlamaCppTokenizer

return LlamaCppTokenizer(llm)
elif framework == "transformers":
from outlines.models.transformers import TransformerTokenizer

return TransformerTokenizer(llm.tokenizer)
elif framework == "vllm":
return llm.get_tokenizer()


def prepare_guided_output(
Expand All @@ -97,13 +127,16 @@ def prepare_guided_output(
case of "json" will also include the schema as a dict, to simplify serialization
and deserialization.
"""

if not importlib.util.find_spec("outlines"):
davidberenstein1957 marked this conversation as resolved.
Show resolved Hide resolved
raise ImportError(
"Outlines is not installed. Please install it using `pip install outlines`."
)

json_processor, regex_processor = _get_logits_processor(framework)

tokenizer_or_model = _get_outlines_tokenizer_or_model(llm, framework)

format = structured_output.get("format")
schema = structured_output.get("schema")

Expand All @@ -120,14 +153,14 @@ def prepare_guided_output(
return {
"processor": json_processor(
schema,
llm,
tokenizer_or_model,
whitespace_pattern=structured_output.get("whitespace_pattern"),
),
"schema": schema_as_dict(schema),
}

if format == "regex":
return {"processor": regex_processor(schema, llm)}
return {"processor": regex_processor(schema, tokenizer_or_model)}

raise DistilabelUserError(
f"Invalid format '{format}'. Must be either 'json' or 'regex'.",
Expand Down
14 changes: 9 additions & 5 deletions tests/unit/steps/tasks/structured_outputs/test_outlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from distilabel.models.llms.huggingface.transformers import TransformersLLM
from distilabel.steps.tasks.structured_outputs.outlines import (
model_to_schema,
outlines_below_0_1_0,
)
from distilabel.steps.tasks.typing import OutlinesStructuredOutputType

Expand Down Expand Up @@ -100,9 +101,6 @@ class DummyUserTest(BaseModel):
}


@pytest.mark.skip(
reason="won't work until we update our code to work with `outlines>0.1.0`"
)
class TestOutlinesIntegration:
@pytest.mark.parametrize(
"format, schema, prompt",
Expand Down Expand Up @@ -138,7 +136,7 @@ def test_generation(
prompt = [
[{"role": "system", "content": ""}, {"role": "user", "content": prompt}]
]
result = llm.generate(prompt, max_new_tokens=30)
result = llm.generate(prompt, max_new_tokens=30, temperature=0.7)
assert isinstance(result, list)
assert isinstance(result[0], dict)
assert "generations" in result[0] and "statistics" in result[0]
Expand Down Expand Up @@ -174,6 +172,7 @@ def test_serialization(
structured_output=OutlinesStructuredOutputType(
format=format, schema=schema
),
token=None,
)
llm.load()
assert llm.dump() == dump
Expand All @@ -182,4 +181,9 @@ def test_load_from_dict(self) -> None:
llm = TransformersLLM.from_dict(DUMP_JSON)
assert isinstance(llm, TransformersLLM)
llm.load()
assert llm._prefix_allowed_tokens_fn is not None
if outlines_below_0_1_0:
assert llm._prefix_allowed_tokens_fn is not None
assert llm._logits_processor is None
else:
assert llm._prefix_allowed_tokens_fn is None
assert llm._logits_processor is not None
Loading