add mlx support #1089

Merged
16 commits merged on Jan 10, 2025
1 change: 1 addition & 0 deletions README.md
@@ -94,6 +94,7 @@ In addition, the following extras are available:
- `vertexai`: for using [Google Vertex AI](https://cloud.google.com/vertex-ai) proprietary models via the `VertexAILLM` integration.
- `vllm`: for using [vllm](https://github.com/vllm-project/vllm) serving engine via the `vLLM` integration.
- `sentence-transformers`: for generating sentence embeddings using [sentence-transformers](https://github.com/UKPLab/sentence-transformers).
- `mlx`: for using [MLX](https://github.com/ml-explore/mlx) models via the `MlxLLM` integration.

### Structured generation

2 changes: 2 additions & 0 deletions docs/sections/getting_started/installation.md
@@ -57,6 +57,8 @@ Additionally, as part of `distilabel` some extra dependencies are available, mainly

- `sentence-transformers`: for generating sentence embeddings using [sentence-transformers](https://github.com/UKPLab/sentence-transformers).

- `mlx`: for using [MLX](https://github.com/ml-explore/mlx) models via the `MlxLLM` integration.

### Data processing

- `ray`: for scaling and distributing a pipeline with [Ray](https://github.com/ray-project/ray).
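For reference, a minimal sketch of how the new extra would be exercised once this lands (illustrative only; it assumes an Apple-silicon machine with access to the Hugging Face repo id used in the PR's docstring, and relies on the fact that the MLX imports are deferred until `load()`):

```python
# Install distilabel together with the new extra, which pulls in `mlx >= 0.21.0` and `mlx-lm`:
#   pip install "distilabel[mlx]"

from distilabel.models.llms import MlxLLM  # importable even without the extra installed

llm = MlxLLM(path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
llm.load()  # MLX dependencies are imported here; a missing extra raises the ImportError defined in mlx.py
```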
1 change: 1 addition & 0 deletions pyproject.toml
@@ -102,6 +102,7 @@ text-clustering = [
"scikit-learn >= 1.4.1",
"matplotlib >= 3.8.3", # For the figure (even though it's optional)
]
mlx = ["mlx >= 0.21.0", "mlx-lm"]

# minhash
minhash = ["datasketch >= 1.6.5", "nltk>3.8.1"]
4 changes: 4 additions & 0 deletions src/distilabel/models/__init__.py
@@ -14,6 +14,7 @@


from distilabel.models.embeddings.base import Embeddings
from distilabel.models.embeddings.llamacpp import LlamaCppEmbeddings
from distilabel.models.embeddings.sentence_transformers import (
SentenceTransformerEmbeddings,
)
@@ -28,6 +29,7 @@
from distilabel.models.llms.litellm import LiteLLM
from distilabel.models.llms.llamacpp import LlamaCppLLM
from distilabel.models.llms.mistral import MistralLLM
from distilabel.models.llms.mlx import MlxLLM
from distilabel.models.llms.moa import MixtureOfAgentsLLM
from distilabel.models.llms.ollama import OllamaLLM
from distilabel.models.llms.openai import OpenAILLM
@@ -52,9 +54,11 @@
"HiddenState",
"InferenceEndpointsLLM",
"LiteLLM",
"LlamaCppEmbeddings",
"LlamaCppLLM",
"MistralLLM",
"MixtureOfAgentsLLM",
"MlxLLM",
"OllamaLLM",
"OpenAILLM",
"SentenceTransformerEmbeddings",
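A quick check of the re-exports this hunk adds (a sketch that relies only on the two names added to `__all__` above and on the submodule imports shown in the other hunks):

```python
# Both new names are importable from the top-level models package as well as
# from their original submodules, and they resolve to the same classes.
from distilabel.models import LlamaCppEmbeddings, MlxLLM
from distilabel.models.embeddings import LlamaCppEmbeddings as _LlamaCppEmbeddings
from distilabel.models.llms import MlxLLM as _MlxLLM

assert LlamaCppEmbeddings is _LlamaCppEmbeddings
assert MlxLLM is _MlxLLM
```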
2 changes: 1 addition & 1 deletion src/distilabel/models/embeddings/__init__.py
@@ -21,7 +21,7 @@

__all__ = [
"Embeddings",
"LlamaCppEmbeddings",
"SentenceTransformerEmbeddings",
"vLLMEmbeddings",
"LlamaCppEmbeddings",
]
2 changes: 2 additions & 0 deletions src/distilabel/models/llms/__init__.py
@@ -22,6 +22,7 @@
from distilabel.models.llms.litellm import LiteLLM
from distilabel.models.llms.llamacpp import LlamaCppLLM
from distilabel.models.llms.mistral import MistralLLM
from distilabel.models.llms.mlx import MlxLLM
from distilabel.models.llms.moa import MixtureOfAgentsLLM
from distilabel.models.llms.ollama import OllamaLLM
from distilabel.models.llms.openai import OpenAILLM
@@ -48,6 +49,7 @@
"LlamaCppLLM",
"MistralLLM",
"MixtureOfAgentsLLM",
"MlxLLM",
"OllamaLLM",
"OpenAILLM",
"TogetherLLM",
288 changes: 288 additions & 0 deletions src/distilabel/models/llms/mlx.py
@@ -0,0 +1,288 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Union,
)

from pydantic import (
Field,
PrivateAttr,
validate_call,
)

from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.llms.base import LLM
from distilabel.models.llms.typing import GenerateOutput
from distilabel.models.llms.utils import compute_tokens, prepare_output
from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.steps.tasks.typing import (
OutlinesStructuredOutputType,
StandardInput,
)

if TYPE_CHECKING:
import mlx.nn as nn
from mlx_lm.tokenizer_utils import TokenizerWrapper


class MlxLLM(LLM, MagpieChatTemplateMixin):
"""Apple MLX LLM implementation.

Attributes:
path_or_hf_repo: the path to the model or the Hugging Face Hub repo id.
tokenizer_config: the tokenizer configuration.
model_config: the model configuration.
adapter_path: the path to the adapter.
structured_output: a dictionary containing the structured output configuration or if more
fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None.
use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
template. Defaults to `False`.
magpie_pre_query_template: the pre-query template to be applied to the prompt or
sent to the LLM to generate an instruction or a follow up user message. Valid
values are "llama3", "qwen2" or another pre-query template provided. Defaults
to `None`.

Icon:
`:apple:`

Examples:
Generate text:

```python
from distilabel.models.llms import MlxLLM

llm = MlxLLM(path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")

llm.load()

# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
"""

path_or_hf_repo: str
tokenizer_config: Dict[str, Any] = {}
model_config: Dict[str, Any] = {}
adapter_path: Optional[str] = None
structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field(
default=None,
description="The structured output format to use across all the generations.",
)

_mlx_generate: Optional[Callable] = PrivateAttr(default=None)
_model: Optional["nn.Module"] = PrivateAttr(...)
_tokenizer: Optional["TokenizerWrapper"] = PrivateAttr(...)
_structured_output_logits_processor: Union[Callable, None] = PrivateAttr(
default=None
)

    def load(self) -> None:
        """Loads the model and tokenizer via `mlx_lm.load` and, if a structured output
        configuration is provided, prepares the corresponding logits processor."""
try:
import mlx # noqa
from mlx_lm import generate, load
except ImportError as ie:
raise ImportError(
"MLX is not installed. Please install it using `pip install 'distilabel[mlx]'`."
) from ie

self._model, self._tokenizer = load(
self.path_or_hf_repo,
tokenizer_config=self.tokenizer_config,
model_config=self.model_config,
adapter_path=self.adapter_path,
)

if self.structured_output:
self._structured_output_logits_processor = self._prepare_structured_output(
self.structured_output
)

if self._tokenizer.pad_token is None:
self._tokenizer.pad_token = self._tokenizer.eos_token

self._mlx_generate = generate

super().load()

@property
def model_name(self) -> str:
"""Returns the model name used for the LLM."""
return self.path_or_hf_repo

    def prepare_input(self, input: "StandardInput") -> str:
        """Prepares the input (applying the chat template, when available) for the provided
        input.

Args:
input: the input list containing chat items.

Returns:
The prompt to send to the LLM.
"""
if self._tokenizer.chat_template is None:
return input[0]["content"]

prompt: str = (
self._tokenizer.apply_chat_template(
input,
tokenize=False,
add_generation_prompt=True,
)
if input
else ""
)
return super().apply_magpie_pre_query_template(prompt, input)

@validate_call
def generate(
self,
inputs: List[StandardInput],
num_generations: int = 1,
max_tokens: int = 256,
sampler: Optional[Callable] = None,
logits_processors: Optional[List[Callable]] = None,
max_kv_size: Optional[int] = None,
prompt_cache: Optional[Any] = None,
prefill_step_size: int = 512,
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
temp: Optional[float] = None,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = None,
top_p: Optional[float] = None,
min_p: Optional[float] = None,
min_tokens_to_keep: Optional[int] = None,
    ) -> List[GenerateOutput]:
        """Generates `num_generations` responses for each input using `mlx_lm.generate`.

Args:
inputs: the inputs to generate responses for.
num_generations: the number of generations to create per input. Defaults to
`1`.
            max_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `256`.
sampler: the sampler to use for the generation. Defaults to `None`.
logits_processors: the logits processors to use for the generation. Defaults to
`None`.
max_kv_size: the maximum size of the key-value cache. Defaults to `None`.
prompt_cache: the prompt cache to use for the generation. Defaults to `None`.
prefill_step_size: the prefill step size. Defaults to `512`.
kv_bits: the number of bits to use for the key-value cache. Defaults to `None`.
kv_group_size: the group size for the key-value cache. Defaults to `64`.
quantized_kv_start: the start of the quantized key-value cache. Defaults to `0`.
prompt_progress_callback: the callback to use for the generation. Defaults to
`None`.
temp: the temperature to use for the generation. Defaults to `None`.
repetition_penalty: the repetition penalty to use for the generation. Defaults to
`None`.
repetition_context_size: the context size for the repetition penalty. Defaults to
`None`.
top_p: the top-p value to use for the generation. Defaults to `None`.
min_p: the minimum p value to use for the generation. Defaults to `None`.
min_tokens_to_keep: the minimum number of tokens to keep. Defaults to `None`.

Returns:
A list of lists of strings containing the generated responses for each input.
"""
        # Start from any user-supplied processors instead of silently discarding them.
        logits_processors = list(logits_processors) if logits_processors else []
        if self._structured_output_logits_processor:
            logits_processors.append(self._structured_output_logits_processor)

        structured_output = None
        result = []
        for input in inputs:
            if isinstance(input, tuple):
                input, structured_output = input

            # Build the processor list once per input so per-row structured output
            # processors do not accumulate across generations.
            input_logits_processors = list(logits_processors)
            if structured_output:
                input_logits_processors.append(
                    self._prepare_structured_output(structured_output)
                )

            output: List[str] = []
            for _ in range(num_generations):
prompt = self.prepare_input(input)

generation = self._mlx_generate(
prompt=prompt,
model=self._model,
tokenizer=self._tokenizer,
                    logits_processors=input_logits_processors,
max_tokens=max_tokens,
sampler=sampler,
max_kv_size=max_kv_size,
prompt_cache=prompt_cache,
prefill_step_size=prefill_step_size,
kv_bits=kv_bits,
kv_group_size=kv_group_size,
quantized_kv_start=quantized_kv_start,
prompt_progress_callback=prompt_progress_callback,
temp=temp,
repetition_penalty=repetition_penalty,
repetition_context_size=repetition_context_size,
top_p=top_p,
min_p=min_p,
min_tokens_to_keep=min_tokens_to_keep,
)

output.append(generation)

result.append(
prepare_output(
output,
input_tokens=[compute_tokens(input, self._tokenizer.encode)],
output_tokens=[
compute_tokens(
text_or_messages=generation,
tokenizer=self._tokenizer.encode,
)
for generation in output
],
)
)
return result

def _prepare_structured_output(
self, structured_output: Optional[OutlinesStructuredOutputType] = None
) -> Union[Callable, None]:
"""Creates the appropriate function to filter tokens to generate structured outputs.

Args:
structured_output: the configuration dict to prepare the structured output.

Returns:
The callable that will be used to guide the generation of the model.
"""
from distilabel.steps.tasks.structured_outputs.outlines import (
prepare_guided_output,
)

result = prepare_guided_output(
            structured_output, "transformers", self._model
)
if schema := result.get("schema"):
self.structured_output["schema"] = schema
return result["processor"]
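Beyond the docstring example, a usage sketch exercising a few of the keyword arguments that `generate` exposes above (illustrative only; the repo id is the one from the docstring, and the results are assumed to be the `GenerateOutput` dicts packed by `prepare_output`):

```python
from distilabel.models.llms import MlxLLM

llm = MlxLLM(
    path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
    tokenizer_config={},  # forwarded to `mlx_lm.load`
    model_config={},      # forwarded to `mlx_lm.load`
)
llm.load()

# `generate` is called directly here; every keyword maps to a parameter of the
# signature above and is forwarded to `mlx_lm.generate`.
outputs = llm.generate(
    inputs=[
        [{"role": "user", "content": "What is MLX?"}],
        [{"role": "user", "content": "Write a haiku about Apple silicon."}],
    ],
    num_generations=2,  # two completions per conversation
    max_tokens=256,
    temp=0.7,
    top_p=0.9,
)

for result in outputs:
    print(result["generations"])  # raw strings; token statistics are packed alongside them
```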
1 change: 1 addition & 0 deletions src/distilabel/models/llms/utils.py
@@ -50,6 +50,7 @@ def prepare_output(
generations: The outputs from an LLM.
input_tokens: The number of tokens of the inputs. Defaults to `None`.
output_tokens: The number of tokens of the LLM response. Defaults to `None`.
logprobs: The logprobs of the LLM response. Defaults to `None`.

Returns:
Output generation from an LLM.
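For context on the newly documented `logprobs` argument and the token accounting done at the end of `MlxLLM.generate`, a rough sketch of how `prepare_output` gets fed (illustrative; `fake_encode` is a stand-in for the MLX `TokenizerWrapper.encode` used in `mlx.py`):

```python
from distilabel.models.llms.utils import compute_tokens, prepare_output

def fake_encode(text: str) -> list:
    # Stand-in tokenizer: only the length of the returned sequence matters for counting.
    return text.split()

generations = ["Hello! How can I help you today?"]
result = prepare_output(
    generations,
    input_tokens=[compute_tokens([{"role": "user", "content": "Hello world!"}], fake_encode)],
    output_tokens=[compute_tokens(text, fake_encode) for text in generations],
)
print(result)  # the generations plus token-count statistics (and logprobs, when provided)
```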