Merge pull request #1104 from argilla-io/develop
`1.5.1`
gabrielmbmb authored Jan 17, 2025
2 parents b261b23 + 34e84e3 commit 1c6a854
Showing 14 changed files with 73 additions and 91 deletions.
7 changes: 3 additions & 4 deletions pyproject.toml
@@ -38,6 +38,7 @@ dependencies = [
"orjson >= 3.10.0",
"universal_pathlib >= 0.2.2",
"portalocker >= 2.8.2",
"setuptools",
]
dynamic = ["version"]

@@ -90,9 +91,7 @@ ray = ["ray[default] >= 2.31.0"]
vertexai = ["google-cloud-aiplatform >= 1.38.0"]
vllm = [
"vllm >= 0.5.3",
"filelock >= 3.13.4",
# `setuptools` is needed to be installed if installed with `uv pip install distilabel[vllm]`
"setuptools",
"filelock >= 3.13.4"
]
sentence-transformers = ["sentence-transformers >= 3.0.0"]
faiss-cpu = ["faiss-cpu >= 1.8.0"]
@@ -102,7 +101,7 @@ text-clustering = [
"scikit-learn >= 1.4.1",
"matplotlib >= 3.8.3", # For the figure (even though it's optional)
]
mlx = ["mlx >= 0.21.0", "mlx-lm"]
mlx = ["mlx >= 0.21.0", "mlx-lm >= 0.21.0, < 0.22.0"]
vision = ["Pillow >= 10.3.0"] # To work with images.

# minhash
2 changes: 1 addition & 1 deletion src/distilabel/__init__.py
@@ -14,6 +14,6 @@

from rich import traceback as rich_traceback

__version__ = "1.5.0"
__version__ = "1.6.0"

rich_traceback.install(show_locals=True)
2 changes: 2 additions & 0 deletions src/distilabel/llms.py
@@ -33,6 +33,7 @@
from distilabel.models.llms.litellm import LiteLLM
from distilabel.models.llms.llamacpp import LlamaCppLLM
from distilabel.models.llms.mistral import MistralLLM
from distilabel.models.llms.mlx import MlxLLM
from distilabel.models.llms.moa import MixtureOfAgentsLLM
from distilabel.models.llms.ollama import OllamaLLM
from distilabel.models.llms.openai import OpenAILLM
@@ -59,6 +60,7 @@
"LlamaCppLLM",
"MistralLLM",
"MixtureOfAgentsLLM",
"MlxLLM",
"OllamaLLM",
"OpenAILLM",
"TogetherLLM",
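
A quick check of the new export above (a sketch; it assumes `distilabel.models.llms` also re-exports `MlxLLM`, as the class docstring further down suggests, while `distilabel.llms` is the re-export module edited here):

```python
# Both paths should resolve to the same class after this change: the
# re-export module patched above and the canonical models subpackage.
from distilabel import llms
from distilabel.models.llms import MlxLLM

assert llms.MlxLLM is MlxLLM
assert "MlxLLM" in llms.__all__
```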
@@ -20,7 +20,6 @@
InferenceEndpointsBaseClient,
)
from distilabel.models.image_generation.base import AsyncImageGenerationModel
from distilabel.models.image_generation.utils import image_to_str

if TYPE_CHECKING:
from PIL.Image import Image
@@ -60,10 +59,14 @@ class InferenceEndpointsImageGeneration( # type: ignore
"""

def load(self) -> None:
from distilabel.models.image_generation.utils import image_to_str

# Sets the logger and calls the load method of the BaseClient
AsyncImageGenerationModel.load(self)
InferenceEndpointsBaseClient.load(self)

self._image_to_str = image_to_str

@validate_call
async def agenerate( # type: ignore
self,
@@ -101,6 +104,6 @@ async def agenerate( # type: ignore
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
)
img_str = image_to_str(image, image_format="JPEG")
img_str = self._image_to_str(image, image_format="JPEG")

return [{"images": [img_str]}]
4 changes: 2 additions & 2 deletions src/distilabel/models/image_generation/utils.py
@@ -18,14 +18,14 @@
from PIL import Image


def image_to_str(image: Image.Image, image_format: str = "JPEG") -> str:
def image_to_str(image: "Image.Image", image_format: str = "JPEG") -> str:
"""Converts a PIL Image to a base64 encoded string."""
buffered = io.BytesIO()
image.save(buffered, format=image_format)
return base64.b64encode(buffered.getvalue()).decode("utf-8")


def image_from_str(image_str: str) -> Image.Image:
def image_from_str(image_str: str) -> "Image.Image":
"""Converts a base64 encoded string to a PIL Image."""
image_bytes = base64.b64decode(image_str)
return Image.open(io.BytesIO(image_bytes))
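
A quick round-trip through the two helpers above (a sketch; assumes Pillow is installed and uses a synthetic in-memory image rather than a real generation result):

```python
from PIL import Image

from distilabel.models.image_generation.utils import image_from_str, image_to_str

# Synthetic 8x8 red image standing in for a generated one.
original = Image.new("RGB", (8, 8), color=(255, 0, 0))

encoded = image_to_str(original, image_format="JPEG")  # base64 str, JSON-friendly
decoded = image_from_str(encoded)                       # back to a PIL Image

print(type(encoded).__name__, decoded.size)  # str (8, 8)
```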
85 changes: 43 additions & 42 deletions src/distilabel/models/llms/mlx.py
@@ -19,9 +19,11 @@
Dict,
List,
Optional,
Union,
)

from pydantic import (
Field,
PrivateAttr,
validate_call,
)
@@ -42,7 +44,7 @@ class MlxLLM(LLM, MagpieChatTemplateMixin):
Attributes:
path_or_hf_repo: the path to the model or the Hugging Face Hub repo id.
tokenizer_config: the tokenizer configuration.
model_config: the model configuration.
mlx_model_config: the MLX model configuration.
adapter_path: the path to the adapter.
use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
template. Defaults to `False`.
@@ -60,7 +62,7 @@ class MlxLLM(LLM, MagpieChatTemplateMixin):
```python
from distilabel.models.llms import MlxLLM
llm = MlxLLM(model="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
llm = MlxLLM(path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
llm.load()
@@ -70,20 +72,22 @@ class MlxLLM(LLM, MagpieChatTemplateMixin):
"""

path_or_hf_repo: str
tokenizer_config: Dict[str, Any] = {}
model_config: Dict[str, Any] = {}
tokenizer_config: Dict[str, Any] = Field(default_factory=dict)
mlx_model_config: Dict[str, Any] = Field(default_factory=dict)
adapter_path: Optional[str] = None

_mlx_generate: Optional[Callable] = PrivateAttr(default=None)
_model: Optional["nn.Module"] = PrivateAttr(...)
_tokenizer: Optional["TokenizerWrapper"] = PrivateAttr(...)
_model: Optional["nn.Module"] = PrivateAttr(None)
_tokenizer: Optional["TokenizerWrapper"] = PrivateAttr(None)
_mlx_generate: Optional[Callable] = PrivateAttr(None)
_make_sampler: Optional[Callable] = PrivateAttr(None)

def load(self) -> None:
"""Loads the model and tokenizer and creates the text generation pipeline. In addition,
it will configure the tokenizer chat template."""
try:
import mlx # noqa
from mlx_lm import generate, load
from mlx_lm.utils import generate, load
from mlx_lm.sample_utils import make_sampler
except ImportError as ie:
raise ImportError(
"MLX is not installed. Please install it using `pip install 'distilabel[mlx]'`."
@@ -92,23 +96,23 @@ def load(self) -> None:
self._model, self._tokenizer = load(
self.path_or_hf_repo,
tokenizer_config=self.tokenizer_config,
model_config=self.model_config,
model_config=self.mlx_model_config,
adapter_path=self.adapter_path,
)

if self._tokenizer.pad_token is None:
self._tokenizer.pad_token = self._tokenizer.eos_token

self._mlx_generate = generate

self._make_sampler = make_sampler
super().load()

@property
def model_name(self) -> str:
"""Returns the model name used for the LLM."""
return self.path_or_hf_repo

def prepare_input(self, input: "StandardInput") -> str:
def prepare_input(self, input: Union["StandardInput", str]) -> str:
"""Prepares the input (applying the chat template and tokenization) for the provided
input.
@@ -118,11 +122,11 @@ def prepare_input(self, input: "StandardInput") -> str:
Returns:
The prompt to send to the LLM.
"""
if self._tokenizer.chat_template is None:
return input[0]["content"]
if isinstance(input, str):
return input

prompt: str = (
self._tokenizer.apply_chat_template(
self._tokenizer.apply_chat_template( # type: ignore
input,
tokenize=False,
add_generation_prompt=True,
Expand All @@ -133,12 +137,11 @@ def prepare_input(self, input: "StandardInput") -> str:
return super().apply_magpie_pre_query_template(prompt, input)

@validate_call
def generate(
def generate( # type: ignore
self,
inputs: List[StandardInput],
inputs: List[Union[StandardInput, str]],
num_generations: int = 1,
max_tokens: int = 256,
sampler: Optional[Callable] = None,
logits_processors: Optional[List[Callable]] = None,
max_kv_size: Optional[int] = None,
prompt_cache: Optional[Any] = None,
Expand All @@ -147,12 +150,11 @@ def generate(
kv_group_size: int = 64,
quantized_kv_start: int = 0,
prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
temp: Optional[float] = None,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = None,
top_p: Optional[float] = None,
min_p: Optional[float] = None,
min_tokens_to_keep: Optional[int] = None,
temp: float = 0.0,
top_p: float = 0.0,
min_p: float = 0.0,
min_tokens_to_keep: int = 1,
top_k: int = -1,
) -> List[GenerateOutput]:
"""Generates `num_generations` responses for each input using the text generation
pipeline.
@@ -163,7 +165,6 @@
`1`.
max_tokens: the maximum number of new tokens that the model will generate.
Defaults to `128`.
sampler: the sampler to use for the generation. Defaults to `None`.
logits_processors: the logits processors to use for the generation. Defaults to
`None`.
max_kv_size: the maximum size of the key-value cache. Defaults to `None`.
@@ -174,18 +175,24 @@
quantized_kv_start: the start of the quantized key-value cache. Defaults to `0`.
prompt_progress_callback: the callback to use for the generation. Defaults to
`None`.
temp: the temperature to use for the generation. Defaults to `None`.
repetition_penalty: the repetition penalty to use for the generation. Defaults to
`None`.
repetition_context_size: the context size for the repetition penalty. Defaults to
`None`.
top_p: the top-p value to use for the generation. Defaults to `None`.
min_p: the minimum p value to use for the generation. Defaults to `None`.
min_tokens_to_keep: the minimum number of tokens to keep. Defaults to `None`.
temp: The temperature for text generation. Defaults to `0.0`.
top_p: The top-p value used for the generation. Defaults to `0.0`.
min_p: The min-p value used for the generation. Defaults to `0.0`.
min_tokens_to_keep: Minimum number of tokens to keep for sampling after
filtering. Must be at least 1. Defaults to `1`.
top_k: The top-k value used for the generation. Defaults to `-1`.
Returns:
A list of lists of strings containing the generated responses for each input.
"""

sampler = self._make_sampler( # type: ignore
temp=temp,
top_p=top_p,
min_p=min_p,
min_tokens_to_keep=min_tokens_to_keep,
top_k=top_k,
)
structured_output = None
result = []
for input in inputs:
@@ -197,7 +204,7 @@ def generate(
if structured_output: # will raise a NotImplementedError
self._prepare_structured_output(structured_output)
prompt = self.prepare_input(input)
generation = self._mlx_generate(
generation = self._mlx_generate( # type: ignore
prompt=prompt,
model=self._model,
tokenizer=self._tokenizer,
@@ -211,24 +218,18 @@
kv_group_size=kv_group_size,
quantized_kv_start=quantized_kv_start,
prompt_progress_callback=prompt_progress_callback,
temp=temp,
repetition_penalty=repetition_penalty,
repetition_context_size=repetition_context_size,
top_p=top_p,
min_p=min_p,
min_tokens_to_keep=min_tokens_to_keep,
)

output.append(generation)

result.append(
prepare_output(
output,
input_tokens=[compute_tokens(input, self._tokenizer.encode)],
generations=output,
input_tokens=[compute_tokens(input, self._tokenizer.encode)], # type: ignore
output_tokens=[
compute_tokens(
text_or_messages=generation,
tokenizer=self._tokenizer.encode,
tokenizer=self._tokenizer.encode, # type: ignore
)
for generation in output
],
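
Putting the new `MlxLLM` pieces together, a usage sketch based on the class docstring above (Apple-silicon only; assumes `pip install 'distilabel[mlx]'` and that the quantized repo id from the docstring is still available):

```python
from distilabel.models.llms import MlxLLM

llm = MlxLLM(path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
llm.load()  # loads the model, tokenizer, mlx_lm generate and the sampler factory

# Sampling knobs (temp, top_p, min_p, min_tokens_to_keep, top_k) are now fed to
# `make_sampler` instead of being forwarded to `generate` directly.
outputs = llm.generate(
    inputs=[[{"role": "user", "content": "Hello, how are you?"}]],
    max_tokens=64,
    temp=0.7,
    top_p=0.9,
)
print(outputs[0])
```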
3 changes: 1 addition & 2 deletions src/distilabel/steps/__init__.py
@@ -26,7 +26,7 @@
from distilabel.steps.clustering.umap import UMAP
from distilabel.steps.columns.combine import CombineOutputs
from distilabel.steps.columns.expand import ExpandColumns
from distilabel.steps.columns.group import CombineColumns, GroupColumns
from distilabel.steps.columns.group import GroupColumns
from distilabel.steps.columns.keep import KeepColumns
from distilabel.steps.columns.merge import MergeColumns
from distilabel.steps.decorator import step
@@ -60,7 +60,6 @@
__all__ = [
"DBSCAN",
"UMAP",
"CombineColumns",
"CombineOutputs",
"ConversationTemplate",
"DataSampler",
15 changes: 1 addition & 14 deletions src/distilabel/steps/columns/group.py
@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import TYPE_CHECKING, Any, List, Optional
from typing import TYPE_CHECKING, List, Optional

from typing_extensions import override

@@ -125,15 +124,3 @@ def process(self, *inputs: StepInput) -> "StepOutput":
group_columns=self.inputs,
output_group_columns=self.outputs,
)


class CombineColumns(GroupColumns):
"""`CombineColumns` is deprecated and will be removed in version 1.5.0, use `GroupColumns` instead."""

def __init__(self, **data: Any) -> None:
warnings.warn(
"`CombineColumns` is deprecated and will be removed in version 1.5.0, use `GroupColumns` instead.",
DeprecationWarning,
stacklevel=2,
)
return super().__init__(**data)
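
Since the deprecated `CombineColumns` alias is removed above, pipelines still using it need the drop-in swap the old warning pointed to. A migration sketch (assuming `GroupColumns` keeps the `columns`/`output_columns` constructor arguments the alias accepted):

```python
from distilabel.pipeline import Pipeline
from distilabel.steps import GroupColumns

with Pipeline(name="migration-example") as pipeline:
    # Before (now raises ImportError):
    # from distilabel.steps import CombineColumns
    # combine = CombineColumns(columns=["generation", "model_name"])
    group = GroupColumns(
        columns=["generation", "model_name"],
        output_columns=["generations", "model_names"],
    )
```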
10 changes: 8 additions & 2 deletions src/distilabel/steps/tasks/image_generation.py
@@ -15,7 +15,6 @@
import hashlib
from typing import TYPE_CHECKING

from distilabel.models.image_generation.utils import image_from_str
from distilabel.steps.base import StepInput
from distilabel.steps.tasks.base import ImageTask

@@ -117,6 +116,13 @@ class ImageGeneration(ImageTask):
save_artifacts: bool = False
image_format: str = "JPEG"

def load(self) -> None:
from distilabel.models.image_generation.utils import image_from_str

super().load()

self._image_from_str = image_from_str

@property
def inputs(self) -> "StepColumns":
return ["prompt"]
@@ -166,7 +172,7 @@ def process(self, inputs: StepInput) -> "StepOutput":
# use prompt as filename
prompt_hash = hashlib.md5(input["prompt"].encode()).hexdigest()
# Build PIL image to save it
image = image_from_str(image)
image = self._image_from_str(image)

self.save_artifact(
name="images",