Add fittable #140
base: main
`model2vec/inference/README.md` (new file, +18 lines)
# Inference

This subpackage mainly contains helper functions for inference with trained models that have been exported to `scikit-learn` compatible pipelines.

If you're looking for information on how to train a model, see [here](../train/README.md).

# Usage

Let's assume you're using our `potion-edu` classifier.

```python
from model2vec.inference import StaticModelPipeline

s = StaticModelPipeline.from_pretrained("minishlab/potion-8m-edu-classifier")
label = s.predict("Attitudes towards cattle in the Alps: a study in letting go.")
```

This should just work.
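The pipeline also exposes `predict_proba` (see `model.py` in this diff), so you can get per-class probabilities instead of hard labels; a short sketch, assuming the same classifier as above:

```python
import numpy as np

from model2vec.inference import StaticModelPipeline

s = StaticModelPipeline.from_pretrained("minishlab/potion-8m-edu-classifier")
# One row of class probabilities per input text.
proba = s.predict_proba(["A lecture on alpine agriculture.", "A limerick about cows."])
print(proba.shape)               # (2, n_classes)
print(np.argmax(proba, axis=1))  # index of the most likely class per text
```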
`model2vec/inference/__init__.py` (new file, +10 lines)
```python
from model2vec.utils import get_package_extras, importable

_REQUIRED_EXTRA = "inference"

# Fail fast with a helpful error if the `inference` extra is not installed.
for extra_dependency in get_package_extras("model2vec", _REQUIRED_EXTRA):
    importable(extra_dependency, _REQUIRED_EXTRA)

from model2vec.inference.model import StaticModelPipeline

__all__ = ["StaticModelPipeline"]
```
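For context, a guard like `importable` presumably just attempts the import and fails with a pointer to the right extra; a hypothetical sketch of that pattern (not the actual model2vec implementation):

```python
from importlib import import_module


def importable(module_name: str, extra: str) -> None:
    """Raise an informative error when an optional dependency is missing."""
    try:
        import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            f"`{module_name}` is required but not installed. "
            f"Install it with: pip install model2vec[{extra}]"
        ) from exc
```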
`model2vec/inference/model.py` (new file, +143 lines)
```python
from __future__ import annotations

import re
from pathlib import Path
from tempfile import TemporaryDirectory

import huggingface_hub
import numpy as np
import skops.io
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import Pipeline

from model2vec.model import PathLike, StaticModel

_DEFAULT_TRUST_PATTERN = re.compile(r"sklearn\..+")
_DEFAULT_MODEL_FILENAME = "pipeline.skops"


class StaticModelPipeline:
    def __init__(self, model: StaticModel, head: Pipeline) -> None:
        """Create a pipeline with a StaticModel encoder."""
        self.model = model
        self.head = head

    @classmethod
    def from_pretrained(
        cls: type[StaticModelPipeline], path: PathLike, token: str | None = None
    ) -> StaticModelPipeline:
        """
        Load a StaticModelPipeline from a local path or Hugging Face Hub path.

        NOTE: if you load a private model from the Hugging Face Hub, you need to pass a token.

        :param path: The path to the folder containing the pipeline, or a repository on the Hugging Face Hub.
        :param token: The token to use to download the pipeline from the hub.
        :return: The loaded pipeline.
        """
        model, head = _load_pipeline(path, token)
        model.embedding = np.nan_to_num(model.embedding)

        return cls(model, head)

    def save_pretrained(self, path: str) -> None:
        """Save the model to a folder."""
        save_pipeline(self, path)

    def push_to_hub(self, repo_id: str, token: str | None = None, private: bool = False) -> None:
        """
        Save a model to a folder, and then push that folder to the hf hub.

        :param repo_id: The id of the repository to push to.
        :param token: The token to use to push to the hub.
        :param private: Whether the repository should be private.
        """
        from model2vec.hf_utils import push_folder_to_hub

        with TemporaryDirectory() as temp_dir:
            save_pipeline(self, temp_dir)
            self.model.save_pretrained(temp_dir)
            push_folder_to_hub(Path(temp_dir), repo_id, private, token)
```

> **Review comment (on `push_to_hub`):** I would add a model card and perhaps tags or a library reference; this helps a lot with visibility, usability and findability. See https://huggingface.co/docs/hub/model-cards#specifying-a-library
>
> **Author reply:** This actually already happens, because we push the underlying static model to the hub, which has a model card. The model card template is specified in the root of the code.

```python
    def _predict_and_coerce_to_2d(self, X: list[str] | str) -> np.ndarray:
        """Encode the input and coerce the output to a 2D matrix."""
        encoded = self.model.encode(X)
        if np.ndim(encoded) == 1:
            encoded = encoded[None, :]

        return encoded

    def predict(self, X: list[str] | str) -> np.ndarray:
        """Predict the labels of the input."""
        encoded = self._predict_and_coerce_to_2d(X)

        return self.head.predict(encoded)

    def predict_proba(self, X: list[str] | str) -> np.ndarray:
        """Predict the probabilities of the labels of the input."""
        encoded = self._predict_and_coerce_to_2d(X)

        return self.head.predict_proba(encoded)


def _load_pipeline(
    folder_or_repo_path: PathLike, token: str | None = None, trust_remote_code: bool = False
) -> tuple[StaticModel, Pipeline]:
    """
    Load a model and an sklearn pipeline.

    This assumes the following files are present in the repo:
    - `pipeline.skops`: The head of the pipeline.
    - `config.json`: The configuration of the model.
    - `model.safetensors`: The weights of the model.
    - `tokenizer.json`: The tokenizer of the model.

    :param folder_or_repo_path: The path to the folder containing the pipeline.
    :param token: The token to use to download the pipeline from the hub. If this is None, you will only
        be able to load the pipeline from a local folder, a public repository, or a repository that you
        have access to because you are logged in.
    :param trust_remote_code: Whether to trust the remote code. If this is False,
        we will only load components coming from `sklearn`. If this is True, we will load all components.
        If you set this to True, you are responsible for whatever happens.
    :return: The encoder model and the loaded head.
    :raises FileNotFoundError: If the pipeline file does not exist in the folder.
    :raises ValueError: If an untrusted type is found in the pipeline, and `trust_remote_code` is False.
    """
    folder_or_repo_path = Path(folder_or_repo_path)
    model_filename = _DEFAULT_MODEL_FILENAME
    if folder_or_repo_path.exists():
        head_pipeline_path = folder_or_repo_path / model_filename
        if not head_pipeline_path.exists():
            raise FileNotFoundError(f"Pipeline file does not exist in {folder_or_repo_path}")
    else:
        head_pipeline_path = huggingface_hub.hf_hub_download(
            folder_or_repo_path.as_posix(), model_filename, token=token
        )

    model = StaticModel.from_pretrained(folder_or_repo_path)

    unknown_types = skops.io.get_untrusted_types(file=head_pipeline_path)
    # If the user does not trust remote code, check that the unknown types are trusted.
    # By default, we trust everything coming from scikit-learn.
    if not trust_remote_code:
        for t in unknown_types:
            if not _DEFAULT_TRUST_PATTERN.match(t):
                raise ValueError(f"Untrusted type {t}.")
    head = skops.io.load(head_pipeline_path, trusted=unknown_types)

    return model, head


def save_pipeline(pipeline: StaticModelPipeline, folder_path: str | Path) -> None:
    """
    Save a pipeline to a folder.

    :param pipeline: The pipeline to save.
    :param folder_path: The path to the folder to save the pipeline to.
    """
    folder_path = Path(folder_path)
    folder_path.mkdir(parents=True, exist_ok=True)
    model_filename = _DEFAULT_MODEL_FILENAME
    head_pipeline_path = folder_path / model_filename
    skops.io.dump(pipeline.head, head_pipeline_path)
    pipeline.model.save_pretrained(folder_path)
```
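To make the trust check in `_load_pipeline` concrete, here is a small illustrative sketch of the filtering logic (the type strings are hypothetical examples of what `skops.io.get_untrusted_types` might report, not output from a real pipeline):

```python
import re

_DEFAULT_TRUST_PATTERN = re.compile(r"sklearn\..+")

# Hypothetical type names, of the kind skops reports for a serialized pipeline.
unknown_types = [
    "sklearn.pipeline.Pipeline",
    "sklearn.neural_network.MLPClassifier",
    "my_pkg.EvilEstimator",
]

for t in unknown_types:
    if not _DEFAULT_TRUST_PATTERN.match(t):
        # With trust_remote_code=False, _load_pipeline raises ValueError here.
        print(f"Untrusted type: {t}")  # -> my_pkg.EvilEstimator
```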
`model2vec/train/README.md` (new file, +137 lines)
# Training

Aside from [distillation](../../README.md#distillation), `model2vec` also supports training simple classifiers on top of static models, using [pytorch](https://pytorch.org/), [lightning](https://lightning.ai/) and [scikit-learn](https://scikit-learn.org/stable/index.html).

# Installation

To train, make sure you install the training extra:

```
pip install model2vec[training]
```
# Quickstart

To train a model, simply initialize it using a `StaticModel`, or from a pre-trained model, as follows:

```python
from model2vec.distill import distill
from model2vec.train import StaticModelForClassification

# From a distilled model
distilled_model = distill("baai/bge-base-en-v1.5")
classifier = StaticModelForClassification.from_static_model(distilled_model)

# From a pre-trained model: potion is the default
classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-8m")
```
This creates a very simple classifier: a `StaticModel` with a single 512-unit hidden layer on top. You can adjust the number of hidden layers and the number of units through parameters on both functions. Note that the default for `from_pretrained` is [potion-base-8m](https://huggingface.co/minishlab/potion-base-8M), our best model to date. This is our recommended path if you're working with general English data.
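For instance, a deeper head might be configured along these lines. The parameter names `n_layers` and `hidden_dim` here are assumptions for illustration, not confirmed API; check the signatures of `from_pretrained` and `from_static_model` for the real ones:

```python
from model2vec.train import StaticModelForClassification

# Hypothetical: a two-layer, 256-unit head instead of the default single
# 512-unit layer. `n_layers` and `hidden_dim` are assumed parameter names;
# consult the actual function signatures before using.
classifier = StaticModelForClassification.from_pretrained(
    model_name="minishlab/potion-base-8m",
    n_layers=2,
    hidden_dim=256,
)
```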
Now that you have created the classifier, let's train a model. The example below assumes you have the [`datasets`](https://github.com/huggingface/datasets) library installed.
```python
from time import perf_counter

import numpy as np
from datasets import load_dataset

# Load the subj dataset
ds = load_dataset("setfit/subj")
train = ds["train"]
test = ds["test"]

s = perf_counter()
classifier = classifier.fit(train["text"], train["label"])

predicted = classifier.predict(test["text"])
print(f"Training took {int(perf_counter() - s)} seconds.")
# Training took 81 seconds
accuracy = np.mean([x == y for x, y in zip(predicted, test["label"])]) * 100
print(f"Achieved {accuracy} test accuracy")
# Achieved 91.0 test accuracy
```
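Equivalently, the accuracy computation can be done with scikit-learn, which is already installed alongside the training extra:

```python
from sklearn.metrics import accuracy_score

# Same computation as the list comprehension above.
accuracy = accuracy_score(test["label"], predicted) * 100
print(f"Achieved {accuracy} test accuracy")
```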
As you can see, we got a pretty nice 91% accuracy, with only 81 seconds of training.
The training loop is handled by [`lightning`](https://pypi.org/project/lightning/). By default, the data is split into a train set (90%) and a validation set (10%), and training runs with early stopping on validation set accuracy, with a patience of 5.
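Conceptually, that early-stopping behavior corresponds to a standard lightning callback along these lines (an illustrative sketch; the monitored metric name is an assumption, not the exact one used internally):

```python
from lightning.pytorch.callbacks import EarlyStopping

# Stop training once validation accuracy has not improved for 5 checks.
# "val_accuracy" is an assumed metric name, for illustration only.
early_stopping = EarlyStopping(monitor="val_accuracy", patience=5, mode="max")
```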
Note that this model is as fast as you're used to from us:

```python
from time import perf_counter

s = perf_counter()
classifier.predict(test["text"])
print(f"Took {int((perf_counter() - s) * 1000)} milliseconds for {len(test)} instances on CPU.")
# Took 67 milliseconds for 2000 instances on CPU.
```
# Persistence

You can turn a classifier into a scikit-learn compatible pipeline, as follows:

```python
pipeline = classifier.to_pipeline()
```
This pipeline object can be persisted using standard pickle-based methods, such as [joblib](https://joblib.readthedocs.io/en/stable/). This makes it easy to use your model in inference pipelines (no installing torch!), although `joblib` and `pickle` should not be used to share models outside of your organization.
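For example, within your own project you could persist and reload it like this (a minimal sketch; the file name is arbitrary):

```python
import joblib

# Persist the scikit-learn compatible pipeline to disk...
joblib.dump(pipeline, "pipeline.joblib")

# ...and load it back later, e.g. in a lightweight inference service.
loaded = joblib.load("pipeline.joblib")
print(loaded.predict(["A text to classify."]))
```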
If you want to persist your pipeline to the Hugging Face hub, you can use our built-in functions:

```python
pipeline.save_pretrained(path)
pipeline.push_to_hub("my_cool/project")
```
Later, you can load these as follows:

```python
from model2vec.inference import StaticModelPipeline

pipeline = StaticModelPipeline.from_pretrained("my_cool/project")
```
Loading pipelines in this way is _extremely_ fast. It takes only 30ms to load a pipeline from disk.
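If you want to verify the load time on your own machine, here is a quick sketch using the standard library timer (assuming you pushed to `my_cool/project` as above):

```python
from time import perf_counter

from model2vec.inference import StaticModelPipeline

s = perf_counter()
pipeline = StaticModelPipeline.from_pretrained("my_cool/project")
print(f"Loaded in {int((perf_counter() - s) * 1000)} milliseconds")
```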
# Results

The main results are detailed in our training blog post, but we'll do a comparison with vanilla model2vec here. In a vanilla model2vec classifier, you just put a scikit-learn `LogisticRegressionCV` on top of the model encoder. In contrast, training a `StaticModelForClassification` fine-tunes the full model, including the `StaticModel` weights.

We use 15 classification datasets, using 1000 examples from the train set and the full test set. No parameters were tuned on any validation set. All datasets were taken from the [SetFit organization on Hugging Face](https://huggingface.co/datasets/SetFit).
| dataset_name               | model2vec logreg | setfit   | model2vec full finetune |
|:---------------------------|-----------------:|---------:|------------------------:|
| 20_newgroups               | 0.545312         | 0.595426 | 0.555459                |
| ade                        | 0.715725         | 0.788789 | 0.740307                |
| ag_news                    | 0.860154         | 0.880142 | 0.858304                |
| amazon_counterfactual      | 0.637754         | 0.873249 | 0.744288                |
| bbc                        | 0.955719         | 0.965823 | 0.965018                |
| emotion                    | 0.516267         | 0.598852 | 0.586328                |
| enron_spam                 | 0.951975         | 0.974498 | 0.964994                |
| hatespeech_offensive       | 0.543758         | 0.659873 | 0.592587                |
| imdb                       | 0.839002         | 0.860037 | 0.846198                |
| massive_scenario           | 0.797779         | 0.814601 | 0.822825                |
| senteval_cr                | 0.743436         | 0.8526   | 0.745863                |
| sst5                       | 0.290249         | 0.393179 | 0.363071                |
| student                    | 0.806069         | 0.889399 | 0.837581                |
| subj                       | 0.878394         | 0.937955 | 0.88941                 |
| tweet_sentiment_extraction | 0.638664         | 0.755296 | 0.632009                |
|         | logreg | full finetune |
|:--------|-------:|--------------:|
| average | 0.714  | 0.742         |
As you can see, full fine-tuning brings modest performance improvements in some cases, and very large ones in others, leading to a sizable increase in average score. Our advice: if you can use `potion-base-8m`, test both approaches; if you are starting from another base model, use full fine-tuning.
# Bring your own architecture

Our training architecture is set up to be extensible, with each task having a specific class. Right now, we only offer `StaticModelForClassification`, but in the future we'll also offer regression, etc.

The core functionality of `StaticModelForClassification` is contained in a couple of functions:

* `construct_head`: constructs the classifier head on top of the `StaticModel`. For example, if you want to create a model that has LayerNorm, just subclass and replace this function (see the sketch after this list). This should be the main function to update if you want to change model behavior.
* `train_test_split`: governs the train/test split before classification.
* `prepare_dataset`: selects the `torch.Dataset` that will be used in the `DataLoader` during training.
* `_encode`: the encoding function used in the model.
* `fit`: contains all the lightning-related fitting logic.
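As an illustration of the `construct_head` route, here is a minimal sketch of a subclass that adds LayerNorm. The exact method signature and the attribute names `embed_dim` and `out_dim` are assumptions for illustration; check the real class for the actual ones:

```python
from torch import nn

from model2vec.train import StaticModelForClassification


class LayerNormClassifier(StaticModelForClassification):
    def construct_head(self) -> nn.Module:
        # `embed_dim` (embedding width) and `out_dim` (number of classes)
        # are hypothetical attribute names; consult the actual class.
        return nn.Sequential(
            nn.LayerNorm(self.embed_dim),
            nn.Linear(self.embed_dim, self.out_dim),
        )
```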
The training of the model is done in a `lightning.LightningModule`, which can be modified but is very basic.
`model2vec/train/__init__.py` (new file, +10 lines)
```python
from model2vec.utils import get_package_extras, importable

_REQUIRED_EXTRA = "train"

# Fail fast with a helpful error if the `train` extra is not installed.
for extra_dependency in get_package_extras("model2vec", _REQUIRED_EXTRA):
    importable(extra_dependency, _REQUIRED_EXTRA)

from model2vec.train.classifier import StaticModelForClassification

__all__ = ["StaticModelForClassification"]
```
> **Review comment:** Can't we load it from the Hub? Perhaps we should align the arguments a bit with the `transformers` naming, given you've also adopted `from_pretrained`? For example, using `pretrained_model_name_or_path`. See https://huggingface.co/docs/transformers/v4.48.0/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained
>
> **Author reply:** `from_pretrained` loads from the hub. The arguments mimic the ones from `StaticModel` and, although they don't match transformers exactly, we're wary of introducing breaking changes.