fix(transformers): support trust_remote_code and added unit tests (#4271)

* transformers import model tests

* accommodate trust_remote_code

* address PR comments

* remove irrelevant logs

* add reference to issue opened in accelerate

* Update src/bentoml/_internal/frameworks/transformers.py

Co-authored-by: Aaron Pham <[email protected]>

* Update transformers.py

---------

Co-authored-by: Aaron Pham <[email protected]>
MingLiangDai and aarnphm authored Nov 28, 2023
1 parent e9e4475 commit 2929789
Showing 2 changed files with 143 additions and 23 deletions.
86 changes: 63 additions & 23 deletions src/bentoml/_internal/frameworks/transformers.py
@@ -1,5 +1,6 @@
from __future__ import annotations

+import inspect
import logging
import os
import platform
@@ -112,7 +113,6 @@ def from_pretrained(
API_VERSION = "v2"
PIPELINE_PICKLE_NAME = f"pipeline.{API_VERSION}.pkl"
PRETRAINED_PROTOCOL_NAME = f"pretrained.{API_VERSION}.pkl"
-CONFIG_JSON_FILE_NAME = "config.json"

logger = logging.getLogger(__name__)

@@ -528,7 +528,7 @@ def load_model(bento_model: str | Tag | Model, *args: t.Any, **kwargs: t.Any) ->
raise


-def make_default_signatures(pretrained: t.Any) -> ModelSignaturesType:
+def make_default_signatures(pretrained_cls: t.Any) -> ModelSignaturesType:
default_config = ModelSignature(batchable=False)
infer_fn = ("__call__",)

Expand All @@ -540,14 +540,14 @@ def make_default_signatures(pretrained: t.Any) -> ModelSignaturesType:
)
return {}

-if transformers.processing_utils.ProcessorMixin in pretrained.__class__.__bases__:
+if transformers.processing_utils.ProcessorMixin in pretrained_cls.__bases__:
logger.info(
"Given '%s' extends the 'transformers.ProcessorMixin'. Make sure to specify the signatures manually if it has additional functions.",
-pretrained.__class__.__name__,
+pretrained_cls.__name__,
)
return {k: default_config for k in ("__call__", "batch_decode", "decode")}

-if isinstance(pretrained, transformers.PreTrainedTokenizerBase):
+if issubclass(pretrained_cls, transformers.PreTrainedTokenizerBase):
infer_fn = (
"__call__",
"tokenize",
@@ -566,7 +566,7 @@ def make_default_signatures(pretrained: t.Any) -> ModelSignaturesType:
"clean_up_tokenization",
"prepare_seq2seq_batch",
)
-elif isinstance(pretrained, transformers.PreTrainedModel):
+elif issubclass(pretrained_cls, transformers.PreTrainedModel):
infer_fn = (
"__call__",
"forward",
@@ -579,7 +579,7 @@ def make_default_signatures(pretrained: t.Any) -> ModelSignaturesType:
"group_beam_search",
"constrained_beam_search",
)
-elif isinstance(pretrained, transformers.TFPreTrainedModel):
+elif issubclass(pretrained_cls, transformers.TFPreTrainedModel):
infer_fn = (
"__call__",
"predict",
@@ -591,16 +591,18 @@ def make_default_signatures(pretrained: t.Any) -> ModelSignaturesType:
"beam_search",
"contrastive_search",
)
-elif isinstance(pretrained, transformers.FlaxPreTrainedModel):
+elif issubclass(pretrained_cls, transformers.FlaxPreTrainedModel):
infer_fn = ("__call__", "generate")
-elif isinstance(pretrained, transformers.image_processing_utils.BaseImageProcessor):
+elif issubclass(
+pretrained_cls, transformers.image_processing_utils.BaseImageProcessor
+):
infer_fn = ("__call__", "preprocess")
-elif isinstance(pretrained, transformers.SequenceFeatureExtractor):
+elif issubclass(pretrained_cls, transformers.SequenceFeatureExtractor):
infer_fn = ("pad",)
-elif not isinstance(pretrained, transformers.Pipeline):
+elif not issubclass(pretrained_cls, transformers.Pipeline):
logger.warning(
"Unable to infer default signatures for '%s'. Make sure to specify it manually.",
-pretrained,
+pretrained_cls,
)
return {}

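For context: switching from isinstance to issubclass lets signature inference run on a bare class, which import_model needs because it may never instantiate the model (e.g. under init_empty_weights). A minimal sketch of the dispatch idea, using a hypothetical helper name pick_infer_fns:

import transformers

def pick_infer_fns(pretrained_cls: type) -> tuple[str, ...]:
    # Dispatch on the class itself; no instance (and no weights) required.
    if issubclass(pretrained_cls, transformers.PreTrainedTokenizerBase):
        return ("__call__", "tokenize", "decode")
    if issubclass(pretrained_cls, transformers.PreTrainedModel):
        return ("__call__", "generate")
    return ("__call__",)

# Works with a bare class -- nothing is loaded into memory:
assert "tokenize" in pick_infer_fns(transformers.BertTokenizerFast)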
@@ -611,11 +613,11 @@ def import_model(
name: Tag | str,
model_name_or_path: str | os.PathLike[str],
*,
-pretrained_model_class: t.Type[BaseAutoModelClass] | None = None,
proxies: dict[str, str] | None = None,
revision: str = "main",
force_download: bool = False,
resume_download: bool = False,
+trust_remote_code: bool = False,
clone_repository: bool = False,
sync_with_hub_version: bool = False,
signatures: ModelSignaturesType | None = None,
@@ -639,8 +641,6 @@ def import_model(
`CompVis/ldm-text2im-large-256`.
- A path to a *directory* containing weights saved using
[`~transformers.AutoModel.save_pretrained`], e.g., `./my_pretrained_directory/`.
-pretrained_model_class:
-The pretrained class/architecture to load the model weights. This determines what LM heads are available to the model.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -652,6 +652,10 @@ def import_model(
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
resume_download (`boolean`, *optional*, defaults to False):
Do not delete incompletely received file. Attempt to resume the download if such a file exists.
+trust_remote_code (`boolean`, *optional*, defaults to False):
+Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+should only be set to `True` for repositories you trust and in which you have read the code, as it will
+execute code present on the Hub on your local machine.
clone_repository: (`boolean`, *optional*, defaults to False):
Download all files from the huggingface repository of the given model
sync_with_hub_version (`bool`, default to False):
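With the new keyword argument in place, a hedged usage sketch (the tag and repository names below are illustrative placeholders, not from this commit):

import bentoml

# Illustrative only: "some-org/model-with-custom-code" is a placeholder repo.
# Enable trust_remote_code solely for repositories whose code you have reviewed.
bento_model = bentoml.transformers.import_model(
    "my-custom-model",                  # tag in the local BentoML model store
    "some-org/model-with-custom-code",  # Hub repo shipping its own modeling files
    trust_remote_code=True,
    clone_repository=True,              # per the warning below, avoids loading weights into memory
)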
@@ -742,18 +746,38 @@ def import_model(
framework_name="transformers", framework_versions=framework_versions
)

+if trust_remote_code:
+logger.warning(
+"trust_remote_code is set to True. Bentoml will load the specified model into memory by default. To avoid loading the model to memory, try setting the keyword argument clone_repository=True."
+)

config = transformers.AutoConfig.from_pretrained(
model_name_or_path,
proxies=proxies,
revision=revision,
force_download=force_download,
resume_download=resume_download,
+trust_remote_code=trust_remote_code,
**extra_hf_hub_kwargs,
)

+is_auto_class = False
+pretrained_model_class = None
+if getattr(config, "architectures", None):
+try:
+pretrained_model_class = getattr(transformers, config.architectures[0])
+except AttributeError:
+pass
+if getattr(config, "auto_map", None):
+for auto_class in config.auto_map:
+if "AutoModel" in auto_class:
+pretrained_model_class = getattr(transformers, auto_class)
+is_auto_class = True
+break

if pretrained_model_class is None:
-pretrained_model_class = getattr(transformers, config.architectures[0])
-logger.info(
-f"pretrained_model_class is not provided, bentoml will create a model with the following pretrained model class {config.architectures[0]}. Available pretrained classes for this model: {config.architectures}."
+raise BentoMLException(
+"BentoML cannot automatically determine the pretrained model class/architecture for the given model."
)

model = None
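The resolution order above, restated as a standalone sketch (simplified; requires network access or a local cache for the config):

import transformers

config = transformers.AutoConfig.from_pretrained("bert-base-uncased")

pretrained_cls = None
# 1) Prefer a concrete architecture exported by transformers itself.
if getattr(config, "architectures", None):
    pretrained_cls = getattr(transformers, config.architectures[0], None)  # e.g. BertForMaskedLM
# 2) An Auto* entry in config.auto_map takes over for remote-code models.
if getattr(config, "auto_map", None):
    for auto_name in config.auto_map:
        if "AutoModel" in auto_name:
            pretrained_cls = getattr(transformers, auto_name)  # e.g. AutoModelForCausalLM
            break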
@@ -768,13 +792,19 @@ def import_model(
if clone_repository:
from huggingface_hub import snapshot_download

+# filter out kwargs as snapshot_download may not accept some kwargs
+params = inspect.signature(snapshot_download).parameters.values()
+param_names = {param.name for param in params}
+input_params = dict(
+filter(lambda x: x[0] in param_names, extra_hf_hub_kwargs.items())
+)
src_dir = snapshot_download(
model_name_or_path,
proxies=proxies,
revision=revision,
force_download=force_download,
resume_download=resume_download,
-**extra_hf_hub_kwargs,
+**input_params,
)
else:
with init_empty_weights():
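The inspect.signature filtering above is a general idiom: forward only the keyword arguments a callable actually declares. A self-contained sketch (filter_kwargs is a hypothetical name):

import inspect

def filter_kwargs(fn, kwargs: dict) -> dict:
    # Keep only keys that match fn's declared parameter names.
    accepted = {p.name for p in inspect.signature(fn).parameters.values()}
    return {k: v for k, v in kwargs.items() if k in accepted}

def demo(a, b=1):
    return a + b

assert filter_kwargs(demo, {"a": 1, "bogus": 2}) == {"a": 1}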
@@ -784,11 +814,12 @@
revision=revision,
force_download=force_download,
resume_download=resume_download,
+trust_remote_code=trust_remote_code,
**extra_hf_hub_kwargs,
)
path_to_config = transformers.utils.cached_file(
model_name_or_path,
-CONFIG_JSON_FILE_NAME,
+transformers.CONFIG_NAME,
proxies=proxies,
revision=revision,
force_download=force_download,
@@ -812,7 +843,16 @@

if model is None:
with init_empty_weights():
-model = pretrained_model_class(config=config)
+if is_auto_class:
+# NOTE: Under `init_empty_weights`, `.from_config` won't load the model weights. Whereas for `.from_pretrained`, transformers needs to find out what layers to load with the architecture, thereby loading the models into memory, then unloading afterwards.
+# See https://github.com/huggingface/accelerate/issues/2163 for more information.
+model = pretrained_model_class.from_config(
+trust_remote_code=trust_remote_code,
+config=config,
+**extra_hf_hub_kwargs,
+)
+else:
+model = pretrained_model_class(config=config)

pretrained = t.cast("PreTrainedProtocol", model)
assert all(
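To see why the NOTE above distinguishes from_config from from_pretrained: under accelerate's init_empty_weights, from_config builds the module on the meta device, so no weight tensors are materialized. A sketch (assumes accelerate is installed and the config is cached or downloadable):

import transformers
from accelerate import init_empty_weights

config = transformers.AutoConfig.from_pretrained("bert-base-uncased")
with init_empty_weights():
    model = transformers.AutoModel.from_config(config)

# Parameters live on the meta device: shapes exist, but no memory is allocated.
print(next(model.parameters()).device)  # -> meta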
@@ -822,7 +862,7 @@
metadata = {}

if signatures is None:
-signatures = make_default_signatures(pretrained)
+signatures = make_default_signatures(pretrained.__class__)
# NOTE: ``make_default_signatures`` can return an empty dict, hence we will only
# log when signatures are available.
if signatures:
@@ -995,7 +1035,7 @@ def save_model(
)

if signatures is None:
-signatures = make_default_signatures(pretrained_or_pipeline)
+signatures = make_default_signatures(pretrained_or_pipeline.__class__)
# NOTE: ``make_default_signatures`` can return an empty dict, hence we will only
# log when signatures are available.
if signatures:
80 changes: 80 additions & 0 deletions tests/integration/frameworks/test_transformers_unit.py
@@ -291,3 +291,83 @@ def test_custom_pipeline(pair_classification_pipeline: PairClassificationPipelin
) == pair_classification_pipeline("I hate you", second_text="I love you")

runner.destroy()


def test_import_model_with_synced_version():
revision = "3956d303d3cddf0708ff20660c1ea5f6ec30e434"
bento_model = bentoml.transformers.import_model(
"tiny-bert",
"hf-internal-testing/tiny-random-BertModel",
sync_with_hub_version=True,
revision=revision,
)

assert bento_model.tag.version == revision
bentoml.models.delete("tiny-bert:3956d303d3cddf0708ff20660c1ea5f6ec30e434")

revision = "3956d303d3cddf0708ff20660c1ea5f6ec30e434"
bento_model = bentoml.transformers.import_model(
"tiny-bert:asdf",
"hf-internal-testing/tiny-random-BertModel",
sync_with_hub_version=True,
revision=revision,
)

assert bento_model.tag.version == revision
bentoml.models.delete("tiny-bert:3956d303d3cddf0708ff20660c1ea5f6ec30e434")

revision = "3956d303d3cddf0708ff20660c1ea5f6ec30e434"
bento_model = bentoml.transformers.import_model(
"tiny-bert:asdf",
"hf-internal-testing/tiny-random-BertModel",
revision=revision,
)

assert bento_model.tag.version == "asdf"
bentoml.models.delete("tiny-bert:asdf")


def test_import_model_has_required_files():
revision = "3956d303d3cddf0708ff20660c1ea5f6ec30e434"
labels = {"framework": "transformers"}
bento_model = bentoml.transformers.import_model(
"tiny-bert",
"hf-internal-testing/tiny-random-BertModel",
sync_with_hub_version=True,
revision=revision,
labels=labels,
)

assert os.path.exists(bento_model.path_of("config.json"))
assert os.path.exists(bento_model.path_of("pytorch_model.bin"))
assert os.path.exists(bento_model.path_of("pretrained.v2.pkl"))
assert os.path.exists(bento_model.path_of("model.yaml"))
assert bento_model.info.labels == labels
bentoml.models.delete("tiny-bert:3956d303d3cddf0708ff20660c1ea5f6ec30e434")


def test_import_model_with_pretrained_class():
revision = "3956d303d3cddf0708ff20660c1ea5f6ec30e434"
pretrained_class = transformers.BertForMaskedLM
bento_model = bentoml.transformers.import_model(
"tiny-bert",
"hf-internal-testing/tiny-random-BertModel",
sync_with_hub_version=True,
pretrained_model_class=pretrained_class,
revision=revision,
)

assert bento_model.info.metadata["_pretrained_class"] == pretrained_class.__name__
bentoml.models.delete("tiny-bert:3956d303d3cddf0708ff20660c1ea5f6ec30e434")

pretrained_class = transformers.BertForNextSentencePrediction
bento_model = bentoml.transformers.import_model(
"tiny-bert",
"hf-internal-testing/tiny-random-BertModel",
sync_with_hub_version=True,
pretrained_model_class=pretrained_class,
revision=revision,
)

assert bento_model.info.metadata["_pretrained_class"] == pretrained_class.__name__
bentoml.models.delete("tiny-bert:3956d303d3cddf0708ff20660c1ea5f6ec30e434")
